added trainer docs

William Falcon 2019-06-27 11:03:53 -04:00
parent fa12098c5f
commit 39af973bd4
9 changed files with 101 additions and 4 deletions


@@ -1 +0,0 @@
-# Trainer

docs/Trainer/index.md (new file, 95 lines)

@@ -0,0 +1,95 @@
# Trainer
[[Github Code](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/models/trainer.py)]
The Lightning trainer abstracts best practices for running the training, validation, and test routines. It calls into your model when it needs to hand over full control, and otherwise makes training assumptions that are now standard practice in AI research.
This is the basic use of the trainer:
``` {.python}
from pytorch_lightning import Trainer
model = LightningTemplate()  # your own LightningModule subclass
trainer = Trainer()
trainer.fit(model)
```
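`LightningTemplate` stands in for your own `LightningModule` subclass. Below is a minimal sketch of what such a module could look like; the hook names (`training_step`, `configure_optimizers`, `train_dataloader`), the import path, and the random data are assumptions for illustration, and early versions of the API used slightly different names:
``` {.python}
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule


class LightningTemplate(LightningModule):
    # hypothetical minimal module; hook names are assumptions
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        return nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        # random tensors purely for illustration
        x = torch.randn(256, 28 * 28)
        y = torch.randint(0, 10, (256,))
        return DataLoader(TensorDataset(x, y), batch_size=32)
```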
But of course the fun is in all the advanced things it can do:
``` {.python}
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from test_tube import Experiment, SlurmCluster
trainer = Trainer(
    experiment=Experiment,
    checkpoint_callback=ModelCheckpoint,
    early_stop_callback=EarlyStopping,
    cluster=SlurmCluster,
    process_position=0,
    current_gpu_name=0,
    gpus=None,
    enable_tqdm=True,
    overfit_pct=0.0,
    track_grad_norm=-1,
    check_val_every_n_epoch=1,
    fast_dev_run=False,
    accumulate_grad_batches=1,
    enable_early_stop=True,
    max_nb_epochs=5,
    min_nb_epochs=1,
    train_percent_check=1.0,
    val_percent_check=1.0,
    test_percent_check=1.0,
    val_check_interval=0.95,
    log_save_interval=1,
    add_log_row_interval=1,
    lr_scheduler_milestones=None,
    use_amp=False,
    check_grad_nans=False,
    amp_level='O2',
    nb_sanity_val_steps=5
)
```
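Note that the snippet above passes the classes themselves as arguments; in practice you would pass configured instances. Here is a hedged sketch of what that could look like, where the constructor arguments (`save_dir`, `name`, `filepath`, `monitor`, `patience`) are assumptions based on the test-tube and Keras-style callback APIs:
``` {.python}
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from test_tube import Experiment

# constructor arguments below are assumptions, not the verified API
exp = Experiment(save_dir='./lightning_logs', name='my_run')
checkpoint = ModelCheckpoint(filepath='./checkpoints', monitor='val_loss')
early_stop = EarlyStopping(monitor='val_loss', patience=3)

trainer = Trainer(
    experiment=exp,
    checkpoint_callback=checkpoint,
    early_stop_callback=early_stop,
)
trainer.fit(model)
```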
Things you can do with the trainer module:
**Training loop**
- Accumulate gradients (see the sketch after these lists)
- Check GPU usage
- Check which gradients are nan
- Check validation every n epochs
- Display metrics in progress bar
- Force training for min or max epochs
- Inspect gradient norms
- Learning rate annealing
- Make model overfit on subset of data
- Multiple optimizers (like GANs)
- Set how much of the training set to check (1-100%)
- Show progress bar
- training_step function
**Validation loop**
- Display metrics in progress bar
- Set how much of the validation set to check (1-100%)
- Set validation check frequency within 1 training epoch (1-100%)
- validation_step function
- Why does validation run first for 5 steps?
**Distributed training**
- Single-gpu
- Multi-gpu
- Multi-node
- 16-bit mixed precision
**Checkpointing**
- Model saving
- Model loading
**Computing cluster (SLURM)**
- Automatic checkpointing
- Automatic saving, loading
- Running grid search on a cluster
- Walltime auto-resubmit
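
Many of these features are thin wrappers around standard PyTorch. For example, the gradient accumulation listed above (the `accumulate_grad_batches` flag) amounts to something like the following plain-PyTorch sketch; this is an assumption about the semantics, not the trainer's actual implementation:
``` {.python}
import torch

# assumed to exist: model (an nn.Module) and train_loader (a DataLoader)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
accumulate_grad_batches = 4  # mirrors the Trainer flag

optimizer.zero_grad()
for batch_nb, (x, y) in enumerate(train_loader):
    loss = torch.nn.functional.cross_entropy(model(x), y)
    # scale so the accumulated gradient matches one large batch
    (loss / accumulate_grad_batches).backward()
    if (batch_nb + 1) % accumulate_grad_batches == 0:
        optimizer.step()
        optimizer.zero_grad()
```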


@@ -4,7 +4,7 @@ import sys
 from test_tube import HyperOptArgumentParser, Experiment
 from pytorch_lightning.models.trainer import Trainer
 from pytorch_lightning.utils.arg_parse import add_default_args
-from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks.pt_callbacks import EarlyStopping, ModelCheckpoint
 from docs.source.examples.example_model import ExampleModel


@@ -8,7 +8,7 @@ from test_tube import HyperOptArgumentParser, Experiment, SlurmCluster
 from pytorch_lightning.models.trainer import Trainer
 from pytorch_lightning.utils.arg_parse import add_default_args
-from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 SEED = 2334
 torch.manual_seed(SEED)


@@ -0,0 +1 @@
+from .models import Trainer


@@ -0,0 +1 @@
+from .trainer import Trainer


@@ -9,6 +9,7 @@ from pytorch_lightning.root_module.optimization import OptimizerConfig
+from pytorch_lightning.root_module.hooks import ModelHooks
 class LightningModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
     def __init__(self, hparams):


@@ -8,7 +8,7 @@ from pytorch_lightning.models.trainer import Trainer
 from pytorch_lightning.utils.arg_parse import add_default_args
 from time import sleep
-from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks.pt_callbacks import EarlyStopping, ModelCheckpoint
 SEED = 2334
 torch.manual_seed(SEED)
 np.random.seed(SEED)