* added a map dict

* added a map dict
This commit is contained in:
William Falcon 2020-08-09 18:50:39 -04:00 committed by GitHub
parent 97e6f35b34
commit 28f79d9f7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 0 deletions

View File

@ -347,6 +347,23 @@ class Result(Dict):
if 'hiddens' in self:
del self['hiddens']
def rename_keys(self, map_dict: dict):
"""
Maps key values to the target values. Useful when renaming variables in mass.
Args:
map_dict:
"""
meta = self.meta
for source, dest in map_dict.items():
# map the main keys
self[dest] = self[source]
del self[source]
# map meta
meta[dest] = meta[source]
del meta[source]
def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
for out in outputs:

View File

@ -532,3 +532,14 @@ def test_full_train_loop_with_results_obj_dp(tmpdir):
assert 'train_step_metric' in seen_keys
assert 'train_step_end_metric' in seen_keys
assert 'epoch_train_epoch_end_metric' in seen_keys
def test_result_map(tmpdir):
result = TrainResult()
result.log_dict({'x1': torch.tensor(1), 'x2': torch.tensor(2)})
result.rename_keys({'x1': 'y1', 'x2': 'y2'})
assert 'x1' not in result
assert 'x2' not in result
assert 'y1' in result
assert 'y2' in result