cleaned readme
This commit is contained in:
parent
b989358c9b
commit
74817c2fb1
61
README.md
61
README.md
|
@ -38,8 +38,65 @@ gpu training, etc... every time you start a project. Let lightning handle all of
|
||||||
data and what happens in the training, testing and validation loop and lightning will do the rest.
|
data and what happens in the training, testing and validation loop and lightning will do the rest.
|
||||||
|
|
||||||
To use lightning do 2 things:
|
To use lightning do 2 things:
|
||||||
1. [Define a Trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/examples/new_project_templates/trainer_cpu_template.py).
|
1. [Define a LightningModel](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/)
|
||||||
2. [Define a LightningModel](https://github.com/williamFalcon/pytorch-lightning/blob/master/examples/new_project_templates/lightning_module_template.py).
|
```python
|
||||||
|
from pytorch_lightning import LightningModule
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class CoolModel(LightningModule):
|
||||||
|
|
||||||
|
def __init(self):
|
||||||
|
self.l1 = torch.nn.Linear(28*28, 10)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.l1(x)
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_nb):
|
||||||
|
x, y = batch
|
||||||
|
y_hat = self.forward(x)
|
||||||
|
return {'tng_loss': some_loss(y_hat, y)}
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_nb):
|
||||||
|
x, y = batch
|
||||||
|
y_hat = self.forward(x)
|
||||||
|
return {'val_loss': some_loss(y_hat, y)}
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
return [optim.Adam(self.parameters(), lr=0.02)]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tng_dataloader(self):
|
||||||
|
mnist = MNIST('path/to/save', train=True)
|
||||||
|
return DataLoader(mnist, batch_size=32)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def val_dataloader(self):
|
||||||
|
mnist = MNIST('path/to/save', train=False)
|
||||||
|
return DataLoader(mnist, batch_size=32)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def test_dataloader(self):
|
||||||
|
mnist = MNIST('sam/as/val/for/simplicity', train=False)
|
||||||
|
return DataLoader(mnist, batch_size=32)
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Fit with a [trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/)
|
||||||
|
```python
|
||||||
|
from pytorch_lightning import Trainer
|
||||||
|
from test_tube import Experiment
|
||||||
|
|
||||||
|
model = CoolModel()
|
||||||
|
|
||||||
|
# fit on 32 gpus across 4 nodes
|
||||||
|
exp = Experiment(save_dir='some/dir')
|
||||||
|
trainer = Trainer(experiment=exp, nb_gpu_nodes=4, gpus=[0,1,2,3,4,5,6,7])
|
||||||
|
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
# see all experiment metrics here
|
||||||
|
# tensorboard --log_dir some/dir
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## What does lightning control for me?
|
## What does lightning control for me?
|
||||||
Everything!
|
Everything!
|
||||||
|
|
Loading…
Reference in New Issue