added arg docs

This commit is contained in:
William Falcon 2019-07-18 12:11:59 -04:00
parent f01cb63234
commit d7409afed9
1 changed files with 9 additions and 3 deletions

View File

@ -86,7 +86,9 @@ class LightningTemplateModel(LightningModule):
output = OrderedDict({
'loss': loss_val
})
return loss_val
# can also return just a scalar instead of a dict (return loss_val)
return output
def validation_step(self, data_batch, batch_i):
"""
@ -108,7 +110,9 @@ class LightningTemplateModel(LightningModule):
'val_loss': loss_val,
'val_acc': torch.tensor(val_acc).cuda(loss_val.device.index),
})
return loss_val
# can also return just a scalar instead of a dict (return loss_val)
return output
def validation_end(self, outputs):
"""
@ -116,7 +120,9 @@ class LightningTemplateModel(LightningModule):
:param outputs: list of individual outputs of each validation step
:return:
"""
return torch.stack(outputs).mean()
# if returned a scalar from validation_step, outputs is a list of tensor scalars
# we return just the average in this case (if we want)
# return torch.stack(outputs).mean()
val_loss_mean = 0
val_acc_mean = 0