updated output of test models
This commit is contained in:
parent
d372f9a2e2
commit
6bb3c0306a
|
@ -96,12 +96,15 @@ class LightningTestModel(LightningModule):
|
|||
if self.trainer.use_dp:
|
||||
loss_val = loss_val.unsqueeze(0)
|
||||
|
||||
# alternate possible outputs to test
|
||||
if self.trainer.batch_nb % 1 == 0:
|
||||
output = OrderedDict({
|
||||
'loss': loss_val
|
||||
'loss': loss_val,
|
||||
'prog': {'some_val': loss_val * loss_val}
|
||||
})
|
||||
|
||||
# can also return just a scalar instead of a dict (return loss_val)
|
||||
return output
|
||||
if self.trainer.batch_nb % 2 == 0:
|
||||
return loss_val
|
||||
|
||||
def validation_step(self, data_batch, batch_i):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue