diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 1e4023af07..b4502e1b70 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -45,6 +45,12 @@ class Result(Dict): } } + def __getitem__(self, key: Union[str, Any]) -> Any: + try: + return super().__getitem__(key) + except KeyError: + return super().__getitem__(f'step_{key}') + def __getattr__(self, key: str) -> Any: try: if key == 'callback_metrics': diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 187423c5ce..bbae6e6114 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -236,3 +236,11 @@ def test_result_gather_mixed_types(): expected = [1.2, ["bar", None], torch.tensor(1)] assert isinstance(result["foo"], list) assert result["foo"] == expected + + +def test_result_retrieve_last_logged_item(): + result = Result() + result.log('a', 5., on_step=True, on_epoch=True) + assert result['epoch_a'] == 5. + assert result['step_a'] == 5. + assert result['a'] == 5.