Retrieve last logged val from result by key (#3049)
* return last logged value * Update test_results.py * Update step_result.py * Update step_result.py * pep8 * pep8
This commit is contained in:
parent
89a5d8fee9
commit
7358d456f3
|
@ -45,6 +45,12 @@ class Result(Dict):
|
|||
}
|
||||
}
|
||||
|
||||
def __getitem__(self, key: Union[str, Any]) -> Any:
|
||||
try:
|
||||
return super().__getitem__(key)
|
||||
except KeyError:
|
||||
return super().__getitem__(f'step_{key}')
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
try:
|
||||
if key == 'callback_metrics':
|
||||
|
|
|
@ -236,3 +236,11 @@ def test_result_gather_mixed_types():
|
|||
expected = [1.2, ["bar", None], torch.tensor(1)]
|
||||
assert isinstance(result["foo"], list)
|
||||
assert result["foo"] == expected
|
||||
|
||||
|
||||
def test_result_retrieve_last_logged_item():
|
||||
result = Result()
|
||||
result.log('a', 5., on_step=True, on_epoch=True)
|
||||
assert result['epoch_a'] == 5.
|
||||
assert result['step_a'] == 5.
|
||||
assert result['a'] == 5.
|
||||
|
|
Loading…
Reference in New Issue