cleaned readme

This commit is contained in:
William Falcon 2019-07-25 10:11:51 -04:00
parent b989358c9b
commit 74817c2fb1
1 changed files with 59 additions and 2 deletions

View File

@ -38,8 +38,65 @@ gpu training, etc... every time you start a project. Let lightning handle all of
data and what happens in the training, testing and validation loop and lightning will do the rest. data and what happens in the training, testing and validation loop and lightning will do the rest.
To use lightning do 2 things: To use lightning do 2 things:
1. [Define a Trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/examples/new_project_templates/trainer_cpu_template.py). 1. [Define a LightningModel](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/)
2. [Define a LightningModel](https://github.com/williamFalcon/pytorch-lightning/blob/master/examples/new_project_templates/lightning_module_template.py). ```python
from pytorch_lightning import LightningModule
import torch
class CoolModel(LightningModule):
def __init(self):
self.l1 = torch.nn.Linear(28*28, 10)
def forward(self, x):
return self.l1(x)
def training_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'tng_loss': some_loss(y_hat, y)}
def validation_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'val_loss': some_loss(y_hat, y)}
def configure_optimizers(self):
return [optim.Adam(self.parameters(), lr=0.02)]
@property
def tng_dataloader(self):
mnist = MNIST('path/to/save', train=True)
return DataLoader(mnist, batch_size=32)
@property
def val_dataloader(self):
mnist = MNIST('path/to/save', train=False)
return DataLoader(mnist, batch_size=32)
@property
def test_dataloader(self):
mnist = MNIST('sam/as/val/for/simplicity', train=False)
return DataLoader(mnist, batch_size=32)
```
2. Fit with a [trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/)
```python
from pytorch_lightning import Trainer
from test_tube import Experiment
model = CoolModel()
# fit on 32 gpus across 4 nodes
exp = Experiment(save_dir='some/dir')
trainer = Trainer(experiment=exp, nb_gpu_nodes=4, gpus=[0,1,2,3,4,5,6,7])
trainer.fit(model)
# see all experiment metrics here
# tensorboard --log_dir some/dir
```
## What does lightning control for me? ## What does lightning control for me?
Everything! Everything!