updated docs
This commit is contained in: parent 600c755460, commit d272f29c88

README.md | 18
@@ -42,27 +42,34 @@ To use lightning do 2 things:

```python
import pytorch_lightning as ptl
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST


class CoolModel(ptl.LightningModule):

    def __init__(self):
        super().__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28*28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x))

    def my_loss(self, y_hat, y):
        return F.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'tng_loss': self.my_loss(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': self.my_loss(y_hat, y)}

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    @ptl.data_loader
    def tng_dataloader(self):
        return DataLoader(MNIST('path/to/save', train=True), batch_size=32)

    @ptl.data_loader
    def val_dataloader(self):
        return DataLoader(MNIST('path/to/save', train=False), batch_size=32)

    @ptl.data_loader
    def test_dataloader(self):
        return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
```

2. Fit with a [trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/)
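For completeness, a minimal fit call might look like the sketch below; the bare `Trainer()` construction is an assumption here, and the supported configuration arguments are in the linked Trainer docs.

```python
# a minimal sketch of step 2, assuming default Trainer arguments
from pytorch_lightning import Trainer

model = CoolModel()
trainer = Trainer()  # see the Trainer docs for the configuration flags
trainer.fit(model)   # runs the tng/val loops defined by CoolModel
```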
@@ -26,6 +26,53 @@ Otherwise, to define a Lightning Module, implement the following methods:

- [update_tng_log_metrics](RequiredTrainerInterface.md#update_tng_log_metrics)
- [add_model_specific_args](RequiredTrainerInterface.md#add_model_specific_args)

A sketch of these two optional hooks follows the minimal example below.

---
**Minimal example**

```python
import pytorch_lightning as ptl
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST


class CoolModel(ptl.LightningModule):

    def __init__(self):
        super().__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28*28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x))

    def my_loss(self, y_hat, y):
        return F.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'tng_loss': self.my_loss(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': self.my_loss(y_hat, y)}

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    @ptl.data_loader
    def tng_dataloader(self):
        return DataLoader(MNIST('path/to/save', train=True), batch_size=32)

    @ptl.data_loader
    def val_dataloader(self):
        return DataLoader(MNIST('path/to/save', train=False), batch_size=32)

    @ptl.data_loader
    def test_dataloader(self):
        return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
```
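Neither optional hook appears in the minimal example, so here is a hedged sketch of the two; the signatures and argument names are assumptions, not the documented interface (follow the links above for the exact definitions).

```python
from argparse import ArgumentParser


class CoolModel(ptl.LightningModule):
    # ... the methods from the minimal example above ...

    def update_tng_log_metrics(self, logs):
        # assumed shape of the hook: amend the metrics dict before it is logged
        logs['example_metric'] = 0.0  # hypothetical extra metric
        return logs

    @staticmethod
    def add_model_specific_args(parent_parser, root_dir):
        # assumed shape of the hook: register model-specific CLI arguments
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', default=0.02, type=float)
        return parser
```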
---
### training_step
@@ -64,14 +64,6 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         """
         raise NotImplementedError
 
-    def loss(self, *args, **kwargs):
-        """
-        Expand model_out into your components
-        :param model_out:
-        :return:
-        """
-        raise NotImplementedError
-
     def summarize(self):
         model_summary = ModelSummary(self)
         print(model_summary)
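The `summarize` helper shown above takes no arguments and just prints a `ModelSummary`; a usage sketch against the CoolModel example, assuming it works on a freshly constructed module:

```python
# usage sketch: print a layer summary of the example model
model = CoolModel()
model.summarize()  # builds ModelSummary(self) and prints it
```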