From d272f29c881d1f422aaf4ffa2b3fec6222b67d59 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 25 Jul 2019 11:52:54 -0400
Subject: [PATCH] updated docs

---
 README.md                                    | 18 ++++---
 .../RequiredTrainerInterface.md              | 47 +++++++++++++++++++
 pytorch_lightning/root_module/root_module.py |  8 ----
 3 files changed, 59 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 3f62526eb6..8dc89cd4f7 100644
--- a/README.md
+++ b/README.md
@@ -42,27 +42,34 @@ To use lightning do 2 things:
 ```python
 import pytorch_lightning as ptl
 import torch
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from torchvision.datasets import MNIST
 
 class CoolModel(ptl.LightningModule):
 
     def __init(self):
+        # not the best model...
         self.l1 = torch.nn.Linear(28*28, 10)
 
     def forward(self, x):
-        return self.l1(x)
+        return torch.relu(self.l1(x))
+
+    def my_loss(self, y_hat, y):
+        return F.cross_entropy(y_hat, y)
 
     def training_step(self, batch, batch_nb):
         x, y = batch
         y_hat = self.forward(x)
-        return {'tng_loss': some_loss(y_hat, y)}
+        return {'tng_loss': self.my_loss(y_hat, y)}
 
     def validation_step(self, batch, batch_nb):
         x, y = batch
         y_hat = self.forward(x)
-        return {'val_loss': some_loss(y_hat, y)}
+        return {'val_loss': self.my_loss(y_hat, y)}
 
     def configure_optimizers(self):
-        return [optim.Adam(self.parameters(), lr=0.02)]
+        return [torch.optim.Adam(self.parameters(), lr=0.02)]
 
     @ptl.data_loader
     def tng_dataloader(self):
@@ -74,8 +81,7 @@ class CoolModel(ptl.LightningModule):
 
     @ptl.data_loader
     def test_dataloader(self):
-        return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
-
+        return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
 ```
 
 2. Fit with a [trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/)

diff --git a/docs/LightningModule/RequiredTrainerInterface.md b/docs/LightningModule/RequiredTrainerInterface.md
index 7625be2494..5faa67a701 100644
--- a/docs/LightningModule/RequiredTrainerInterface.md
+++ b/docs/LightningModule/RequiredTrainerInterface.md
@@ -26,6 +26,53 @@ Otherwise, to Define a Lightning Module, implement the following methods:
 - [update_tng_log_metrics](RequiredTrainerInterface.md#update_tng_log_metrics)
 - [add_model_specific_args](RequiredTrainerInterface.md#add_model_specific_args)
 
+---
+**Minimal example**
+```python
+import pytorch_lightning as ptl
+import torch
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from torchvision.datasets import MNIST
+
+class CoolModel(ptl.LightningModule):
+
+    def __init(self):
+        # not the best model...
+        self.l1 = torch.nn.Linear(28*28, 10)
+
+    def forward(self, x):
+        return torch.relu(self.l1(x))
+
+    def my_loss(self, y_hat, y):
+        return F.cross_entropy(y_hat, y)
+
+    def training_step(self, batch, batch_nb):
+        x, y = batch
+        y_hat = self.forward(x)
+        return {'tng_loss': self.my_loss(y_hat, y)}
+
+    def validation_step(self, batch, batch_nb):
+        x, y = batch
+        y_hat = self.forward(x)
+        return {'val_loss': self.my_loss(y_hat, y)}
+
+    def configure_optimizers(self):
+        return [torch.optim.Adam(self.parameters(), lr=0.02)]
+
+    @ptl.data_loader
+    def tng_dataloader(self):
+        return DataLoader(MNIST('path/to/save', train=True), batch_size=32)
+
+    @ptl.data_loader
+    def val_dataloader(self):
+        return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
+
+    @ptl.data_loader
+    def test_dataloader(self):
+        return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
+```
+
 ---
 
 ### training_step

diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py
index d6e0ec9666..584c5aa4b4 100644
--- a/pytorch_lightning/root_module/root_module.py
+++ b/pytorch_lightning/root_module/root_module.py
@@ -64,14 +64,6 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         """
         raise NotImplementedError
 
-    def loss(self, *args, **kwargs):
-        """
-        Expand model_out into your components
-        :param model_out:
-        :return:
-        """
-        raise NotImplementedError
-
     def summarize(self):
         model_summary = ModelSummary(self)
         print(model_summary)
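One caveat with the minimal example this patch adds to both the README and `RequiredTrainerInterface.md`: the constructor is declared as `def __init(self):`. Python only treats `__init__` (double underscores on both sides) as the constructor, and the method also never calls the parent initializer, so as written `self.l1` would never be created when `CoolModel()` is instantiated. Below is a minimal corrected sketch; the class and layer come straight from the patch, while the `super()` call is an assumed fix that does not appear in this diff.

```python
import pytorch_lightning as ptl
import torch

class CoolModel(ptl.LightningModule):

    def __init__(self):  # double underscores, unlike the patch's `__init`
        # the parent initializer must run so the module is registered properly;
        # this call is assumed here, it is not present in the patch
        super(CoolModel, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x))

    # training_step, validation_step, configure_optimizers and the
    # @ptl.data_loader methods would follow exactly as in the patch above
```

With the remaining methods from the patch filled in, step 2 of the README ("Fit with a trainer") would then amount to something like `trainer = ptl.Trainer()` followed by `trainer.fit(CoolModel())`, assuming the Trainer API described in the linked docs.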