commit
3ffeba4caa
|
@ -772,7 +772,7 @@ class Trainer(TrainerIO):
|
|||
output = self.model.training_step(data_batch, batch_nb)
|
||||
|
||||
try:
|
||||
model_specific_tqdm_metrics_dic = output['tqdm_metrics']
|
||||
model_specific_tqdm_metrics_dic = output['prog']
|
||||
except Exception as e:
|
||||
model_specific_tqdm_metrics_dic = {}
|
||||
|
||||
|
|
|
@ -96,12 +96,15 @@ class LightningTestModel(LightningModule):
|
|||
if self.trainer.use_dp:
|
||||
loss_val = loss_val.unsqueeze(0)
|
||||
|
||||
output = OrderedDict({
|
||||
'loss': loss_val
|
||||
})
|
||||
|
||||
# can also return just a scalar instead of a dict (return loss_val)
|
||||
return output
|
||||
# alternate possible outputs to test
|
||||
if self.trainer.batch_nb % 1 == 0:
|
||||
output = OrderedDict({
|
||||
'loss': loss_val,
|
||||
'prog': {'some_val': loss_val * loss_val}
|
||||
})
|
||||
return output
|
||||
if self.trainer.batch_nb % 2 == 0:
|
||||
return loss_val
|
||||
|
||||
def validation_step(self, data_batch, batch_i):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue