From 28f79d9f7ad7a249c7cbde2b98d6cb38fb4d755a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 9 Aug 2020 18:50:39 -0400 Subject: [PATCH] Mapkeys (#2900) * added a map dict * added a map dict --- pytorch_lightning/core/step_result.py | 17 +++++++++++++++++ .../trainer/test_trainer_steps_result_return.py | 11 +++++++++++ 2 files changed, 28 insertions(+) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 5174f8aa44..8b482f0436 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -347,6 +347,23 @@ class Result(Dict): if 'hiddens' in self: del self['hiddens'] + def rename_keys(self, map_dict: dict): + """ + Maps key values to the target values. Useful when renaming variables in mass. + + Args: + map_dict: + """ + meta = self.meta + for source, dest in map_dict.items(): + # map the main keys + self[dest] = self[source] + del self[source] + + # map meta + meta[dest] = meta[source] + del meta[source] + def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]: for out in outputs: diff --git a/tests/trainer/test_trainer_steps_result_return.py b/tests/trainer/test_trainer_steps_result_return.py index 62b0b6e483..3e291df73a 100644 --- a/tests/trainer/test_trainer_steps_result_return.py +++ b/tests/trainer/test_trainer_steps_result_return.py @@ -532,3 +532,14 @@ def test_full_train_loop_with_results_obj_dp(tmpdir): assert 'train_step_metric' in seen_keys assert 'train_step_end_metric' in seen_keys assert 'epoch_train_epoch_end_metric' in seen_keys + + +def test_result_map(tmpdir): + result = TrainResult() + result.log_dict({'x1': torch.tensor(1), 'x2': torch.tensor(2)}) + result.rename_keys({'x1': 'y1', 'x2': 'y2'}) + + assert 'x1' not in result + assert 'x2' not in result + assert 'y1' in result + assert 'y2' in result