updated docs
This commit is contained in:
parent
600c755460
commit
d272f29c88
16
README.md
16
README.md
|
@ -42,27 +42,34 @@ To use lightning do 2 things:
|
|||
```python
|
||||
import pytorch_lightning as ptl
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
class CoolModel(ptl.LightningModule):
|
||||
|
||||
def __init(self):
|
||||
# not the best model...
|
||||
self.l1 = torch.nn.Linear(28*28, 10)
|
||||
|
||||
def forward(self, x):
|
||||
return self.l1(x)
|
||||
return torch.relu(self.l1(x))
|
||||
|
||||
def my_loss(self, y_hat, y):
|
||||
return F.cross_entropy(y_hat, y)
|
||||
|
||||
def training_step(self, batch, batch_nb):
|
||||
x, y = batch
|
||||
y_hat = self.forward(x)
|
||||
return {'tng_loss': some_loss(y_hat, y)}
|
||||
return {'tng_loss': self.my_loss(y_hat, y)}
|
||||
|
||||
def validation_step(self, batch, batch_nb):
|
||||
x, y = batch
|
||||
y_hat = self.forward(x)
|
||||
return {'val_loss': some_loss(y_hat, y)}
|
||||
return {'val_loss': self.my_loss(y_hat, y)}
|
||||
|
||||
def configure_optimizers(self):
|
||||
return [optim.Adam(self.parameters(), lr=0.02)]
|
||||
return [torch.optim.Adam(self.parameters(), lr=0.02)]
|
||||
|
||||
@ptl.data_loader
|
||||
def tng_dataloader(self):
|
||||
|
@ -75,7 +82,6 @@ class CoolModel(ptl.LightningModule):
|
|||
@ptl.data_loader
|
||||
def test_dataloader(self):
|
||||
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
|
||||
|
||||
```
|
||||
|
||||
2. Fit with a [trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/)
|
||||
|
|
|
@ -26,6 +26,53 @@ Otherwise, to Define a Lightning Module, implement the following methods:
|
|||
- [update_tng_log_metrics](RequiredTrainerInterface.md#update_tng_log_metrics)
|
||||
- [add_model_specific_args](RequiredTrainerInterface.md#add_model_specific_args)
|
||||
|
||||
---
|
||||
**Minimal example**
|
||||
```python
|
||||
import pytorch_lightning as ptl
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
class CoolModel(ptl.LightningModule):
|
||||
|
||||
def __init(self):
|
||||
# not the best model...
|
||||
self.l1 = torch.nn.Linear(28*28, 10)
|
||||
|
||||
def forward(self, x):
|
||||
return torch.relu(self.l1(x))
|
||||
|
||||
def my_loss(self, y_hat, y):
|
||||
return F.cross_entropy(y_hat, y)
|
||||
|
||||
def training_step(self, batch, batch_nb):
|
||||
x, y = batch
|
||||
y_hat = self.forward(x)
|
||||
return {'tng_loss': self.my_loss(y_hat, y)}
|
||||
|
||||
def validation_step(self, batch, batch_nb):
|
||||
x, y = batch
|
||||
y_hat = self.forward(x)
|
||||
return {'val_loss': self.my_loss(y_hat, y)}
|
||||
|
||||
def configure_optimizers(self):
|
||||
return [torch.optim.Adam(self.parameters(), lr=0.02)]
|
||||
|
||||
@ptl.data_loader
|
||||
def tng_dataloader(self):
|
||||
return DataLoader(MNIST('path/to/save', train=True), batch_size=32)
|
||||
|
||||
@ptl.data_loader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
|
||||
|
||||
@ptl.data_loader
|
||||
def test_dataloader(self):
|
||||
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### training_step
|
||||
|
|
|
@ -64,14 +64,6 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def loss(self, *args, **kwargs):
|
||||
"""
|
||||
Expand model_out into your components
|
||||
:param model_out:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def summarize(self):
|
||||
model_summary = ModelSummary(self)
|
||||
print(model_summary)
|
||||
|
|
Loading…
Reference in New Issue