updated doc indexes
This commit is contained in:
parent
cdb4de3606
commit
e89975d19e
|
@ -29,21 +29,24 @@ Otherwise, to Define a Lightning Module, implement the following methods:
|
||||||
---
|
---
|
||||||
**Minimal example**
|
**Minimal example**
|
||||||
```python
|
```python
|
||||||
import pytorch_lightning as ptl
|
import os
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision.datasets import MNIST
|
from torchvision.datasets import MNIST
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
|
||||||
|
import pytorch_lightning as ptl
|
||||||
|
|
||||||
class CoolModel(ptl.LightningModule):
|
class CoolModel(ptl.LightningModule):
|
||||||
|
|
||||||
def __init(self):
|
def __init__(self):
|
||||||
super(CoolModel, self).__init__()
|
super(CoolModel, self).__init__()
|
||||||
# not the best model...
|
# not the best model...
|
||||||
self.l1 = torch.nn.Linear(28 * 28, 10)
|
self.l1 = torch.nn.Linear(28 * 28, 10)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return torch.relu(self.l1(x))
|
return torch.relu(self.l1(x.view(x.size(0), -1)))
|
||||||
|
|
||||||
def my_loss(self, y_hat, y):
|
def my_loss(self, y_hat, y):
|
||||||
return F.cross_entropy(y_hat, y)
|
return F.cross_entropy(y_hat, y)
|
||||||
|
@ -51,7 +54,7 @@ class CoolModel(ptl.LightningModule):
|
||||||
def training_step(self, batch, batch_nb):
|
def training_step(self, batch, batch_nb):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
y_hat = self.forward(x)
|
y_hat = self.forward(x)
|
||||||
return {'tng_loss': self.my_loss(y_hat, y)}
|
return {'loss': self.my_loss(y_hat, y)}
|
||||||
|
|
||||||
def validation_step(self, batch, batch_nb):
|
def validation_step(self, batch, batch_nb):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
|
@ -59,23 +62,23 @@ class CoolModel(ptl.LightningModule):
|
||||||
return {'val_loss': self.my_loss(y_hat, y)}
|
return {'val_loss': self.my_loss(y_hat, y)}
|
||||||
|
|
||||||
def validation_end(self, outputs):
|
def validation_end(self, outputs):
|
||||||
avg_loss = torch.stack([x for x in outputs['val_loss']]).mean()
|
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
|
||||||
return avg_loss
|
return {'avg_val_loss': avg_loss}
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
return [torch.optim.Adam(self.parameters(), lr=0.02)]
|
return [torch.optim.Adam(self.parameters(), lr=0.02)]
|
||||||
|
|
||||||
@ptl.data_loader
|
@ptl.data_loader
|
||||||
def tng_dataloader(self):
|
def tng_dataloader(self):
|
||||||
return DataLoader(MNIST('path/to/save', train=True), batch_size=32)
|
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
|
||||||
|
|
||||||
@ptl.data_loader
|
@ptl.data_loader
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
|
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
|
||||||
|
|
||||||
@ptl.data_loader
|
@ptl.data_loader
|
||||||
def test_dataloader(self):
|
def test_dataloader(self):
|
||||||
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
|
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
Loading…
Reference in New Issue