diff --git a/examples/new_project_templates/lightning_module_template.py b/examples/new_project_templates/lightning_module_template.py index a0c65af546..7f5459e161 100644 --- a/examples/new_project_templates/lightning_module_template.py +++ b/examples/new_project_templates/lightning_module_template.py @@ -86,7 +86,9 @@ class LightningTemplateModel(LightningModule): output = OrderedDict({ 'loss': loss_val }) - return loss_val + + # can also return just a scalar instead of a dict (return loss_val) + return output def validation_step(self, data_batch, batch_i): """ @@ -108,7 +110,9 @@ class LightningTemplateModel(LightningModule): 'val_loss': loss_val, 'val_acc': torch.tensor(val_acc).cuda(loss_val.device.index), }) - return loss_val + + # can also return just a scalar instead of a dict (return loss_val) + return output def validation_end(self, outputs): """ @@ -116,7 +120,9 @@ class LightningTemplateModel(LightningModule): :param outputs: list of individual outputs of each validation step :return: """ - return torch.stack(outputs).mean() + # if returned a scalar from validation_step, outputs is a list of tensor scalars + # we return just the average in this case (if we want) + # return torch.stack(outputs).mean() val_loss_mean = 0 val_acc_mean = 0