updated output of test models

This commit is contained in:
William Falcon 2019-07-28 11:51:32 -04:00
parent d372f9a2e2
commit 6bb3c0306a
1 changed files with 9 additions and 6 deletions

View File

@ -96,12 +96,15 @@ class LightningTestModel(LightningModule):
if self.trainer.use_dp:
loss_val = loss_val.unsqueeze(0)
output = OrderedDict({
'loss': loss_val
})
# can also return just a scalar instead of a dict (return loss_val)
return output
# alternate possible outputs to test
if self.trainer.batch_nb % 1 == 0:
output = OrderedDict({
'loss': loss_val,
'prog': {'some_val': loss_val * loss_val}
})
return output
if self.trainer.batch_nb % 2 == 0:
return loss_val
def validation_step(self, data_batch, batch_i):
"""