parent
eef1b5dbc8
commit
2d5a7f5e7d
|
@ -476,6 +476,9 @@ class Result(Dict):
|
|||
else:
|
||||
tbptt_reduce_fx = meta[k]['tbptt_reduce_fx']
|
||||
|
||||
if isinstance(value, list):
|
||||
value = torch.tensor(value)
|
||||
|
||||
if isinstance(value, dict):
|
||||
# TODO: recursive reduce:
|
||||
_recursive_fx_apply(value, tbptt_reduce_fx)
|
||||
|
|
|
@ -21,6 +21,8 @@ def test__validation_step__log(tmpdir):
|
|||
acc = self.step(batch, batch_idx)
|
||||
acc = acc + batch_idx
|
||||
self.log('a', acc, on_step=True, on_epoch=True)
|
||||
self.log('a2', 2)
|
||||
|
||||
self.training_step_called = True
|
||||
return acc
|
||||
|
||||
|
@ -50,6 +52,7 @@ def test__validation_step__log(tmpdir):
|
|||
# make sure all the metrics are available for callbacks
|
||||
expected_logged_metrics = {
|
||||
'a',
|
||||
'a2',
|
||||
'a_step',
|
||||
'a_epoch',
|
||||
'b',
|
||||
|
@ -65,7 +68,7 @@ def test__validation_step__log(tmpdir):
|
|||
# on purpose DO NOT allow step_b... it's silly to monitor val step metrics
|
||||
callback_metrics = set(trainer.callback_metrics.keys())
|
||||
callback_metrics.remove('debug_epoch')
|
||||
expected_cb_metrics = {'a', 'b', 'a_epoch', 'b_epoch', 'a_step'}
|
||||
expected_cb_metrics = {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'}
|
||||
assert expected_cb_metrics == callback_metrics
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue