added arg docs
This commit is contained in:
parent
f01cb63234
commit
d7409afed9
|
@ -86,7 +86,9 @@ class LightningTemplateModel(LightningModule):
|
|||
output = OrderedDict({
|
||||
'loss': loss_val
|
||||
})
|
||||
return loss_val
|
||||
|
||||
# can also return just a scalar instead of a dict (return loss_val)
|
||||
return output
|
||||
|
||||
def validation_step(self, data_batch, batch_i):
|
||||
"""
|
||||
|
@ -108,7 +110,9 @@ class LightningTemplateModel(LightningModule):
|
|||
'val_loss': loss_val,
|
||||
'val_acc': torch.tensor(val_acc).cuda(loss_val.device.index),
|
||||
})
|
||||
return loss_val
|
||||
|
||||
# can also return just a scalar instead of a dict (return loss_val)
|
||||
return output
|
||||
|
||||
def validation_end(self, outputs):
|
||||
"""
|
||||
|
@ -116,7 +120,9 @@ class LightningTemplateModel(LightningModule):
|
|||
:param outputs: list of individual outputs of each validation step
|
||||
:return:
|
||||
"""
|
||||
return torch.stack(outputs).mean()
|
||||
# if returned a scalar from validation_step, outputs is a list of tensor scalars
|
||||
# we return just the average in this case (if we want)
|
||||
# return torch.stack(outputs).mean()
|
||||
|
||||
val_loss_mean = 0
|
||||
val_acc_mean = 0
|
||||
|
|
Loading…
Reference in New Issue