From e89975d19eba4d418260205dede7a352176c18dc Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 28 Jul 2019 08:14:50 -0400 Subject: [PATCH] updated doc indexes --- .../RequiredTrainerInterface.md | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/docs/LightningModule/RequiredTrainerInterface.md b/docs/LightningModule/RequiredTrainerInterface.md index b9bb1d36fc..ff4adf57a6 100644 --- a/docs/LightningModule/RequiredTrainerInterface.md +++ b/docs/LightningModule/RequiredTrainerInterface.md @@ -29,21 +29,24 @@ Otherwise, to Define a Lightning Module, implement the following methods: --- **Minimal example** ```python -import pytorch_lightning as ptl +import os import torch from torch.nn import functional as F from torch.utils.data import DataLoader from torchvision.datasets import MNIST +import torchvision.transforms as transforms + +import pytorch_lightning as ptl class CoolModel(ptl.LightningModule): - def __init(self): + def __init__(self): super(CoolModel, self).__init__() # not the best model... self.l1 = torch.nn.Linear(28 * 28, 10) def forward(self, x): - return torch.relu(self.l1(x)) + return torch.relu(self.l1(x.view(x.size(0), -1))) def my_loss(self, y_hat, y): return F.cross_entropy(y_hat, y) @@ -51,7 +54,7 @@ class CoolModel(ptl.LightningModule): def training_step(self, batch, batch_nb): x, y = batch y_hat = self.forward(x) - return {'tng_loss': self.my_loss(y_hat, y)} + return {'loss': self.my_loss(y_hat, y)} def validation_step(self, batch, batch_nb): x, y = batch @@ -59,23 +62,23 @@ class CoolModel(ptl.LightningModule): return {'val_loss': self.my_loss(y_hat, y)} def validation_end(self, outputs): - avg_loss = torch.stack([x for x in outputs['val_loss']]).mean() - return avg_loss + avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() + return {'avg_val_loss': avg_loss} def configure_optimizers(self): return [torch.optim.Adam(self.parameters(), lr=0.02)] @ptl.data_loader def tng_dataloader(self): - return DataLoader(MNIST('path/to/save', train=True), batch_size=32) + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @ptl.data_loader def val_dataloader(self): - return DataLoader(MNIST('path/to/save', train=False), batch_size=32) + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @ptl.data_loader def test_dataloader(self): - return DataLoader(MNIST('path/to/save', train=False), batch_size=32) + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) ``` ---