parent
97e6f35b34
commit
28f79d9f7a
|
@ -347,6 +347,23 @@ class Result(Dict):
|
|||
if 'hiddens' in self:
|
||||
del self['hiddens']
|
||||
|
||||
def rename_keys(self, map_dict: dict):
|
||||
"""
|
||||
Maps key values to the target values. Useful when renaming variables in mass.
|
||||
|
||||
Args:
|
||||
map_dict:
|
||||
"""
|
||||
meta = self.meta
|
||||
for source, dest in map_dict.items():
|
||||
# map the main keys
|
||||
self[dest] = self[source]
|
||||
del self[source]
|
||||
|
||||
# map meta
|
||||
meta[dest] = meta[source]
|
||||
del meta[source]
|
||||
|
||||
|
||||
def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
|
||||
for out in outputs:
|
||||
|
|
|
@ -532,3 +532,14 @@ def test_full_train_loop_with_results_obj_dp(tmpdir):
|
|||
assert 'train_step_metric' in seen_keys
|
||||
assert 'train_step_end_metric' in seen_keys
|
||||
assert 'epoch_train_epoch_end_metric' in seen_keys
|
||||
|
||||
|
||||
def test_result_map(tmpdir):
|
||||
result = TrainResult()
|
||||
result.log_dict({'x1': torch.tensor(1), 'x2': torch.tensor(2)})
|
||||
result.rename_keys({'x1': 'y1', 'x2': 'y2'})
|
||||
|
||||
assert 'x1' not in result
|
||||
assert 'x2' not in result
|
||||
assert 'y1' in result
|
||||
assert 'y2' in result
|
||||
|
|
Loading…
Reference in New Issue