parent
97e6f35b34
commit
28f79d9f7a
|
@ -347,6 +347,23 @@ class Result(Dict):
|
||||||
if 'hiddens' in self:
|
if 'hiddens' in self:
|
||||||
del self['hiddens']
|
del self['hiddens']
|
||||||
|
|
||||||
|
def rename_keys(self, map_dict: dict):
|
||||||
|
"""
|
||||||
|
Maps key values to the target values. Useful when renaming variables in mass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
map_dict:
|
||||||
|
"""
|
||||||
|
meta = self.meta
|
||||||
|
for source, dest in map_dict.items():
|
||||||
|
# map the main keys
|
||||||
|
self[dest] = self[source]
|
||||||
|
del self[source]
|
||||||
|
|
||||||
|
# map meta
|
||||||
|
meta[dest] = meta[source]
|
||||||
|
del meta[source]
|
||||||
|
|
||||||
|
|
||||||
def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
|
def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
|
||||||
for out in outputs:
|
for out in outputs:
|
||||||
|
|
|
@ -532,3 +532,14 @@ def test_full_train_loop_with_results_obj_dp(tmpdir):
|
||||||
assert 'train_step_metric' in seen_keys
|
assert 'train_step_metric' in seen_keys
|
||||||
assert 'train_step_end_metric' in seen_keys
|
assert 'train_step_end_metric' in seen_keys
|
||||||
assert 'epoch_train_epoch_end_metric' in seen_keys
|
assert 'epoch_train_epoch_end_metric' in seen_keys
|
||||||
|
|
||||||
|
|
||||||
|
def test_result_map(tmpdir):
|
||||||
|
result = TrainResult()
|
||||||
|
result.log_dict({'x1': torch.tensor(1), 'x2': torch.tensor(2)})
|
||||||
|
result.rename_keys({'x1': 'y1', 'x2': 'y2'})
|
||||||
|
|
||||||
|
assert 'x1' not in result
|
||||||
|
assert 'x2' not in result
|
||||||
|
assert 'y1' in result
|
||||||
|
assert 'y2' in result
|
||||||
|
|
Loading…
Reference in New Issue