Merge pull request #24 from williamFalcon/keys

Keys
This commit is contained in:
William Falcon 2019-07-28 12:11:49 -04:00 committed by GitHub
commit 3ffeba4caa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 7 deletions

View File

@ -772,7 +772,7 @@ class Trainer(TrainerIO):
output = self.model.training_step(data_batch, batch_nb)
try:
model_specific_tqdm_metrics_dic = output['tqdm_metrics']
model_specific_tqdm_metrics_dic = output['prog']
except Exception as e:
model_specific_tqdm_metrics_dic = {}

View File

@ -96,12 +96,15 @@ class LightningTestModel(LightningModule):
if self.trainer.use_dp:
loss_val = loss_val.unsqueeze(0)
output = OrderedDict({
'loss': loss_val
})
# can also return just a scalar instead of a dict (return loss_val)
return output
# alternate possible outputs to test
if self.trainer.batch_nb % 1 == 0:
output = OrderedDict({
'loss': loss_val,
'prog': {'some_val': loss_val * loss_val}
})
return output
if self.trainer.batch_nb % 2 == 0:
return loss_val
def validation_step(self, data_batch, batch_i):
"""