diff --git a/docs/LightningModule/RequiredTrainerInterface.md b/docs/LightningModule/RequiredTrainerInterface.md index 06a8445948..06c57dd5ac 100644 --- a/docs/LightningModule/RequiredTrainerInterface.md +++ b/docs/LightningModule/RequiredTrainerInterface.md @@ -262,6 +262,11 @@ def validation_step(self, data_batch, batch_nb): out = self.forward(x) loss = self.loss(out, x) + # log 6 example images + sample_imgs = x[:6] + grid = torchvision.utils.make_grid(sample_imgs) + self.experiment.add_image('example_images', grid, 0) + # calculate acc labels_hat = torch.argmax(out, dim=1) val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)