This commit is contained in:
William Falcon 2020-10-13 06:42:11 -04:00 committed by GitHub
parent eef1b5dbc8
commit 2d5a7f5e7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 1 deletions

View File

@ -476,6 +476,9 @@ class Result(Dict):
else:
tbptt_reduce_fx = meta[k]['tbptt_reduce_fx']
if isinstance(value, list):
value = torch.tensor(value)
if isinstance(value, dict):
# TODO: recursive reduce:
_recursive_fx_apply(value, tbptt_reduce_fx)

View File

@ -21,6 +21,8 @@ def test__validation_step__log(tmpdir):
acc = self.step(batch, batch_idx)
acc = acc + batch_idx
self.log('a', acc, on_step=True, on_epoch=True)
self.log('a2', 2)
self.training_step_called = True
return acc
@ -50,6 +52,7 @@ def test__validation_step__log(tmpdir):
# make sure all the metrics are available for callbacks
expected_logged_metrics = {
'a',
'a2',
'a_step',
'a_epoch',
'b',
@ -65,7 +68,7 @@ def test__validation_step__log(tmpdir):
# on purpose DO NOT allow step_b... it's silly to monitor val step metrics
callback_metrics = set(trainer.callback_metrics.keys())
callback_metrics.remove('debug_epoch')
expected_cb_metrics = {'a', 'b', 'a_epoch', 'b_epoch', 'a_step'}
expected_cb_metrics = {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'}
assert expected_cb_metrics == callback_metrics