From 74817c2fb138570f27c5b006b54aacc5bd2bb32e Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 25 Jul 2019 10:11:51 -0400
Subject: [PATCH] cleaned readme

---
 README.md | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 59 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 1c3a0949ae..24e0c731c8 100644
--- a/README.md
+++ b/README.md
@@ -38,8 +38,65 @@ gpu training, etc... every time you start a project. Let lightning handle all of
 data and what happens in the training, testing and validation loop and lightning will do the rest.
 
 To use lightning do 2 things:
-1. [Define a Trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/examples/new_project_templates/trainer_cpu_template.py).
-2. [Define a LightningModel](https://github.com/williamFalcon/pytorch-lightning/blob/master/examples/new_project_templates/lightning_module_template.py).
+1. [Define a LightningModule](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/)
+```python
+from pytorch_lightning import LightningModule
+import torch
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from torchvision import transforms
+from torchvision.datasets import MNIST
+
+class CoolModel(LightningModule):
+
+    def __init__(self):
+        super().__init__()
+        self.l1 = torch.nn.Linear(28*28, 10)
+
+    def forward(self, x):
+        # flatten the image to a vector before the linear layer
+        return self.l1(x.view(x.size(0), -1))
+
+    def training_step(self, batch, batch_nb):
+        x, y = batch
+        y_hat = self.forward(x)
+        # any classification loss works; cross entropy shown here
+        return {'tng_loss': F.cross_entropy(y_hat, y)}
+
+    def validation_step(self, batch, batch_nb):
+        x, y = batch
+        y_hat = self.forward(x)
+        return {'val_loss': F.cross_entropy(y_hat, y)}
+
+    def configure_optimizers(self):
+        return [optim.Adam(self.parameters(), lr=0.02)]
+
+    @property
+    def tng_dataloader(self):
+        mnist = MNIST('path/to/save', train=True, download=True, transform=transforms.ToTensor())
+        return DataLoader(mnist, batch_size=32)
+
+    @property
+    def val_dataloader(self):
+        mnist = MNIST('path/to/save', train=False, download=True, transform=transforms.ToTensor())
+        return DataLoader(mnist, batch_size=32)
+
+    @property
+    def test_dataloader(self):
+        # same as val, for simplicity
+        mnist = MNIST('path/to/save', train=False, download=True, transform=transforms.ToTensor())
+        return DataLoader(mnist, batch_size=32)
+```
+
+2. Fit it with a [Trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/)
+```python
+from pytorch_lightning import Trainer
+from test_tube import Experiment
+
+model = CoolModel()
+
+# fit on 32 GPUs across 4 nodes (8 GPUs per node)
+exp = Experiment(save_dir='some/dir')
+trainer = Trainer(experiment=exp, nb_gpu_nodes=4, gpus=[0, 1, 2, 3, 4, 5, 6, 7])
+
+trainer.fit(model)
+
+# see all experiment metrics here:
+# tensorboard --logdir some/dir
+```
+
 ## What does lightning control for me?
 
 Everything!