added docs page
parent 56f7ebf07e
commit 92f5e026bb

README.md

This is a research tool I built for myself internally while doing my PhD.
## What is it?
Keras is too abstract for researchers. Lightning abstracts the full training loop but gives you control at the critical points.
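To make that concrete, the division of labor looks roughly like the sketch below. This is illustrative pseudocode under my own naming (`fit`, `train_loader`, `val_loader`, and `max_nb_epochs` are placeholders, not Lightning's internals): the trainer owns the loops, while the `*_step` methods on your model are the points you control.

```python
# Illustrative sketch only; not Lightning's actual Trainer implementation.
# Lightning owns the loops, you write what happens for a single batch.
def fit(model, train_loader, val_loader, max_nb_epochs=10):
    for epoch in range(max_nb_epochs):                        # handled for you
        for batch_nb, data_batch in enumerate(train_loader):
            model.training_step(data_batch, batch_nb)         # your code
        for batch_nb, data_batch in enumerate(val_loader):
            model.validation_step(data_batch, batch_nb)       # your code
        # optimizer steps, checkpointing, early stopping, etc. also live here
```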
## Why do I want to use lightning?
Because you want to use best practices and get GPU training, multi-node training, checkpointing, mixed-precision training, etc. for free.
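Concretely, these features are Trainer arguments rather than loop code you maintain. The sketch below just gathers the flags shown later in this README into one call; exact argument names can vary between Lightning versions, so treat it as an illustration:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.utils.pt_callbacks import EarlyStopping

# Illustration only: combines the Trainer options that appear later in this README.
trainer = Trainer(
    gpus=[0, 1],                             # multi-gpu training
    use_amp=True, amp_level='O2',            # 16-bit / mixed-precision training
    early_stop_callback=EarlyStopping(...),  # early stopping
    lr_scheduler_milestones=[100, 200],      # learning rate scheduling
)
# trainer.fit(model)  # `model` is a RootModule subclass (see below)
```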
To use lightning, do 2 things:

1. Define a model with the lightning interface.
2. Feed this model to the lightning trainer.
*Example model definition*

```python
from pytorch_lightning import RootModule


class MyModel(RootModule):

    def __init__(self): ...                                # define model
    def training_step(self, data_batch, batch_nb): ...     # what to do with a training batch
    def validation_step(self, data_batch, batch_nb): ...   # what to do with a val/test batch
    def validation_end(self, outputs): ...                 # collate all val batch outputs
    def get_save_dict(self): ...                           # return what to save in a checkpoint
    def load_model_specific(self, checkpoint): ...         # use the checkpoint to reset your model state
    def configure_optimizers(self): ...                    # return a list of optimizers
    def tng_dataloader(self): ...                          # return a pytorch dataloader for the train split
    def val_dataloader(self): ...                          # return a pytorch dataloader for the val split
    def test_dataloader(self): ...                         # return a pytorch dataloader for the test split

    @staticmethod
    def add_model_specific_args(parent_parser): ...        # add args for this model to your argparse
```
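To show what a filled-in version might look like, here is a minimal hedged sketch of a concrete model. The MNIST classifier, the dict return values from `training_step`/`validation_step`, and the hard-coded `data/` path are my own illustrative assumptions, not conventions stated in this README; only the method names come from the skeleton above.

```python
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from pytorch_lightning import RootModule


class ExampleMNISTModel(RootModule):
    """Hedged sketch: a tiny classifier implementing the interface above."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 10)   # define model

    def forward(self, x):
        return self.l1(x.view(x.size(0), -1))

    def training_step(self, data_batch, batch_nb):
        x, y = data_batch
        loss = F.cross_entropy(self(x), y)
        return {'loss': loss}              # assumed return convention

    def validation_step(self, data_batch, batch_nb):
        x, y = data_batch
        return {'val_loss': F.cross_entropy(self(x), y)}

    def validation_end(self, outputs):
        # collate all val batch outputs into one number
        avg = torch.stack([out['val_loss'] for out in outputs]).mean()
        return {'avg_val_loss': avg}

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=1e-3)]

    def tng_dataloader(self):
        dataset = MNIST('data/', train=True, download=True, transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=32, shuffle=True)

    def val_dataloader(self):
        dataset = MNIST('data/', train=False, download=True, transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=32)
```

The checkpointing and argparse hooks (`get_save_dict`, `load_model_specific`, `add_model_specific_args`) are omitted here for brevity.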
*Example trainer*

```python
from pytorch_lightning import Trainer
from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint

trainer = Trainer(
    early_stop_callback=EarlyStopping(...),
    gpus=[0, 1]
)

trainer.fit(model)
```
*(Diagram: Pytorch <-- Lightning <-- your model)*
## What are some key lightning features?
- Automatic training loop
```python
# define what happens for training here
def training_step(self, data_batch, batch_nb):
    ...

# define what happens for validation here
def validation_step(self, data_batch, batch_nb):
    ...
```
- Automatic early stopping
```python
callback = EarlyStopping(...)
Trainer(early_stop_callback=callback)
```
- Learning rate scheduling
```python
Trainer(lr_scheduler_milestones=[100, 200])
```
- 16-bit precision training (`--use_amp`)
```python
Trainer(use_amp=True, amp_level='O2')
```
- multi-gpu training
```python
Trainer(gpus=[0, 1])
```