cleaned readme

This commit is contained in:
William Falcon 2019-07-25 10:11:51 -04:00
parent b989358c9b
commit 74817c2fb1
1 changed files with 59 additions and 2 deletions

View File

@ -38,8 +38,65 @@ gpu training, etc... every time you start a project. Let lightning handle all of
data and what happens in the training, testing and validation loop and lightning will do the rest.
To use lightning do 2 things:
1. [Define a Trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/examples/new_project_templates/trainer_cpu_template.py).
2. [Define a LightningModel](https://github.com/williamFalcon/pytorch-lightning/blob/master/examples/new_project_templates/lightning_module_template.py).
1. [Define a LightningModel](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/)
```python
from pytorch_lightning import LightningModule
import torch
class CoolModel(LightningModule):
def __init(self):
self.l1 = torch.nn.Linear(28*28, 10)
def forward(self, x):
return self.l1(x)
def training_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'tng_loss': some_loss(y_hat, y)}
def validation_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'val_loss': some_loss(y_hat, y)}
def configure_optimizers(self):
return [optim.Adam(self.parameters(), lr=0.02)]
@property
def tng_dataloader(self):
mnist = MNIST('path/to/save', train=True)
return DataLoader(mnist, batch_size=32)
@property
def val_dataloader(self):
mnist = MNIST('path/to/save', train=False)
return DataLoader(mnist, batch_size=32)
@property
def test_dataloader(self):
mnist = MNIST('sam/as/val/for/simplicity', train=False)
return DataLoader(mnist, batch_size=32)
```
2. Fit with a [trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/)
```python
from pytorch_lightning import Trainer
from test_tube import Experiment
model = CoolModel()
# fit on 32 gpus across 4 nodes
exp = Experiment(save_dir='some/dir')
trainer = Trainer(experiment=exp, nb_gpu_nodes=4, gpus=[0,1,2,3,4,5,6,7])
trainer.fit(model)
# see all experiment metrics here
# tensorboard --log_dir some/dir
```
## What does lightning control for me?
Everything!