Merge branch 'master' of https://github.com/williamFalcon/pytorch-lightning
To use lightning, do 2 things:

1. [Define a LightningModule](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/)

```python
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import pytorch_lightning as ptl


class CoolModel(ptl.LightningModule):

    def __init__(self):
        super(CoolModel, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        # flatten the image before the linear layer
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def my_loss(self, y_hat, y):
        return F.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': self.my_loss(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': self.my_loss(y_hat, y)}

    def validation_end(self, outputs):
        # average the per-batch losses collected from validation_step
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'avg_val_loss': avg_loss}

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    @ptl.data_loader
    def tng_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    @ptl.data_loader
    def val_dataloader(self):
        # use the held-out split (train=False) for validation
        return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)

    @ptl.data_loader
    def test_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
```
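
Before handing the model to a trainer, a quick way to sanity-check the forward pass is to run a random batch through it. A minimal sketch, assuming MNIST's 1x28x28 single-channel images:

```python
import torch

model = CoolModel()
x = torch.rand(32, 1, 28, 28)  # a fake batch of 32 MNIST-sized images
out = model.forward(x)
print(out.shape)  # torch.Size([32, 10]) -- one logit per digit class
```
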
2. Fit with a [trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/)

```python
import os

from pytorch_lightning import Trainer
from test_tube import Experiment

model = CoolModel()
exp = Experiment(save_dir=os.getcwd())

# train on cpu using only 10% of the data (for demo purposes)
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1)

# train on 4 gpus
# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=[0, 1, 2, 3])

# train on 32 gpus across 4 nodes (make sure to submit the appropriate SLURM job)
# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=[0, 1, 2, 3, 4, 5, 6, 7], nb_gpu_nodes=4)

# train (1 epoch only here for demo)
trainer.fit(model)

# view tensorboard logs
print(f'View tensorboard logs by running\ntensorboard --logdir {os.getcwd()}')
print('and going to http://localhost:6006 on your browser')
```
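
The `loss` key returned from `training_step` is what the trainer backpropagates, while `validation_end` receives the whole list of `validation_step` outputs at once and reduces them to epoch-level metrics. A hypothetical illustration of that contract (the Trainer normally does this for you; `torch` and `model` come from the snippets above):

```python
# hypothetical illustration only -- inside fit(), the Trainer collects one
# dict per validation batch and hands the list to validation_end
outputs = [
    {'val_loss': torch.tensor(0.9)},
    {'val_loss': torch.tensor(0.7)},
]
print(model.validation_end(outputs))  # {'avg_val_loss': tensor(0.8000)}
```
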
If you can't wait for the next release, install the most up-to-date code with:

```bash
pip install git+https://github.com/williamFalcon/pytorch-lightning.git@master --upgrade
```
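
To confirm the bleeding-edge install imports cleanly, a quick check (assuming the package exposes `__version__`):

```python
# quick post-install sanity check; assumes pytorch_lightning exposes __version__
import pytorch_lightning
print(pytorch_lightning.__version__)
```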