diff --git a/README.md b/README.md index 306b214151..4bab2881a6 100644 --- a/README.md +++ b/README.md @@ -28,31 +28,16 @@ This is a research tool I built for myself internally while doing my PhD. The AP ## What is it? Keras is too abstract for researchers. Lightning abstracts the full training loop but gives you control in the critical points. + +## Why do I want to use lightning? +Because you want to use best practices and get gpu training, multi-node training, checkpointing, mixed-precision, etc... for free. + To use lightning do 2 things: 1. Define a model with the lightning interface. 2. Feed this model to the lightning trainer. *Example model definition* ```python -from pytorch_lightning import RootModule - -class MyModel(RootModule): - - def init(self): # define model - def training_step(self, data_batch, batch_nb): # what to do with a training batch - def validation_step(self, data_batch, batch_nb): # what to do with a val/test batch - def validation_end(self, data_batch, batch_nb): # collate all val batch outputs - def get_save_dict(self): # return what to save in a checkpoint - def load_model_specific(self, checkpoint): # use the checkpoint to reset your model state - def configure_optimizers(self): # return a list of optimizers - def tng_dataloader(self): # return a pytorch dataloader for each split - def val_dataloader(self): - def test_dataloader(self): - def add_model_specific_args(parent_parser): # add args for this model to your argparse -``` - -*Example trainer* -```python from pytorch_lightning import Trainer from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint @@ -63,16 +48,10 @@ trainer = Trainer( early_stop_callback=EarlyStopping(...), gpus=[0,1] ) + +trainer.fit(model) ``` -Pytorch -<-- Lightning -Your model. - -## Why do I want to use lightning? -Because you want to use best practices and get gpu training, multi-node training, checkpointing, mixed-precision, etc... for free. - - ## What are some key lightning features? - Automatic training loop @@ -86,11 +65,6 @@ def training_step(self, data_batch, batch_nb): # define what happens for validation here def validation_step(self, data_batch, batch_nb): ``` -- Automatic early stopping -```python -callback = EarlyStopping(...) -Trainer(early_stopping=callback) -``` - Automatic early stopping ```python @@ -105,7 +79,9 @@ Trainer(lr_scheduler_milestones=[100, 200]) ``` - 16 bit precision training -```--use_amp``` +```python +Trainer(use_amp=True, amp_level='O2') +``` - multi-gpu training ```python