From 39af973bd4df064c8d51b966dbc20b4d959618ff Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 27 Jun 2019 11:03:53 -0400
Subject: [PATCH] added trainer docs

---
 docs/Pytorch-Lightning/Trainer.md              |  1 -
 docs/Trainer/index.md                          | 95 +++++++++++++++++++
 docs/source/examples/basic_trainer.py          |  2 +-
 .../source/examples/fully_featured_trainer.py  |  2 +-
 pytorch_lightning/__init__.py                  |  1 +
 .../{utils => callbacks}/pt_callbacks.py       |  0
 pytorch_lightning/models/__init__.py           |  1 +
 pytorch_lightning/root_module/root_module.py   |  1 +
 pytorch_lightning/trainer_main.py              |  2 +-
 9 files changed, 101 insertions(+), 4 deletions(-)
 delete mode 100644 docs/Pytorch-Lightning/Trainer.md
 create mode 100644 docs/Trainer/index.md
 rename pytorch_lightning/{utils => callbacks}/pt_callbacks.py (100%)

diff --git a/docs/Pytorch-Lightning/Trainer.md b/docs/Pytorch-Lightning/Trainer.md
deleted file mode 100644
index 47d6904184..0000000000
--- a/docs/Pytorch-Lightning/Trainer.md
+++ /dev/null
@@ -1 +0,0 @@
-# Trainer
\ No newline at end of file
diff --git a/docs/Trainer/index.md b/docs/Trainer/index.md
new file mode 100644
index 0000000000..4e51b13c9e
--- /dev/null
+++ b/docs/Trainer/index.md
@@ -0,0 +1,95 @@
# Trainer
[[Github Code](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/models/trainer.py)]

The Lightning trainer abstracts best practices for running a training, validation, and test routine. It calls into your model at the points where it hands over full control, and otherwise makes training assumptions that are now standard practice in AI research.

This is the basic use of the trainer:

``` {.python}
from pytorch_lightning import Trainer

model = LightningTemplate()

trainer = Trainer()
trainer.fit(model)
```

But of course the fun is in all the advanced things it can do:

``` {.python}
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from test_tube import Experiment, SlurmCluster

trainer = Trainer(
    experiment=Experiment,
    checkpoint_callback=ModelCheckpoint,
    early_stop_callback=EarlyStopping,
    cluster=SlurmCluster,
    process_position=0,
    current_gpu_name=0,
    gpus=None,
    enable_tqdm=True,
    overfit_pct=0.0,
    track_grad_norm=-1,
    check_val_every_n_epoch=1,
    fast_dev_run=False,
    accumulate_grad_batches=1,
    enable_early_stop=True,
    max_nb_epochs=5,
    min_nb_epochs=1,
    train_percent_check=1.0,
    val_percent_check=1.0,
    test_percent_check=1.0,
    val_check_interval=0.95,
    log_save_interval=1,
    add_log_row_interval=1,
    lr_scheduler_milestones=None,
    use_amp=False,
    check_grad_nans=False,
    amp_level='O2',
    nb_sanity_val_steps=5,
)
```


Things you can do with the trainer module:

**Training loop**

- Accumulate gradients
- Check GPU usage
- Check which gradients are NaN
- Check validation every n epochs
- Display metrics in progress bar
- Force training for min or max epochs
- Inspect gradient norms
- Learning rate annealing
- Make model overfit on subset of data
- Multiple optimizers (like GANs)
- Set how much of the training set to check (1-100%)
- Show progress bar
- training_step function

**Validation loop**

- Display metrics in progress bar
- Set how much of the validation set to check (1-100%)
- Set validation check frequency within 1 training epoch (1-100%)
- validation_step function (a sketch of both step hooks follows this list)
- Why does validation run first for 5 steps?
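
The `training_step` and `validation_step` hooks listed above are the main points where the trainer hands control back to your model. As a rough illustration, here is a minimal sketch of a model defining both hooks; the hook signatures and the dict-based return values are assumptions made for illustration, not something this patch defines:

``` {.python}
# Minimal sketch of the two step hooks. The (batch, batch_nb)
# signatures and the {'loss': ...} / {'val_loss': ...} return
# dicts are illustrative assumptions about this version's API.
import torch.nn.functional as F
from torch import nn

from pytorch_lightning.root_module.root_module import LightningModule


class LightningTemplate(LightningModule):
    def __init__(self, hparams=None):
        # LightningModule.__init__ takes hparams per this patch
        super().__init__(hparams)
        self.l1 = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.l1(x.view(x.size(0), -1))

    def training_step(self, batch, batch_nb):
        # called by the trainer for every training batch
        x, y = batch
        return {'loss': F.cross_entropy(self.forward(x), y)}

    def validation_step(self, batch, batch_nb):
        # called by the trainer for every validation batch
        x, y = batch
        return {'val_loss': F.cross_entropy(self.forward(x), y)}
```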
**Distributed training**

- Single-gpu
- Multi-gpu
- Multi-node
- 16-bit mixed precision

**Checkpointing**

- Model saving
- Model loading

**Computing cluster (SLURM)**

- Automatic checkpointing
- Automatic saving, loading
- Running grid search on a cluster
- Walltime auto-resubmit
\ No newline at end of file
diff --git a/docs/source/examples/basic_trainer.py b/docs/source/examples/basic_trainer.py
index 308f24eb8e..a1f6d85151 100644
--- a/docs/source/examples/basic_trainer.py
+++ b/docs/source/examples/basic_trainer.py
@@ -4,7 +4,7 @@ import sys
 from test_tube import HyperOptArgumentParser, Experiment
 from pytorch_lightning.models.trainer import Trainer
 from pytorch_lightning.utils.arg_parse import add_default_args
-from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks.pt_callbacks import EarlyStopping, ModelCheckpoint
 
 from docs.source.examples.example_model import ExampleModel
 
diff --git a/docs/source/examples/fully_featured_trainer.py b/docs/source/examples/fully_featured_trainer.py
index 0ef013df13..bbf05c5f45 100644
--- a/docs/source/examples/fully_featured_trainer.py
+++ b/docs/source/examples/fully_featured_trainer.py
@@ -8,7 +8,7 @@ from test_tube import HyperOptArgumentParser, Experiment, SlurmCluster
 from pytorch_lightning.models.trainer import Trainer
 from pytorch_lightning.utils.arg_parse import add_default_args
 
-from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 
 SEED = 2334
 torch.manual_seed(SEED)
diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py
index e69de29bb2..ce4b32be27 100644
--- a/pytorch_lightning/__init__.py
+++ b/pytorch_lightning/__init__.py
@@ -0,0 +1 @@
+from .models import Trainer
\ No newline at end of file
diff --git a/pytorch_lightning/utils/pt_callbacks.py b/pytorch_lightning/callbacks/pt_callbacks.py
similarity index 100%
rename from pytorch_lightning/utils/pt_callbacks.py
rename to pytorch_lightning/callbacks/pt_callbacks.py
diff --git a/pytorch_lightning/models/__init__.py b/pytorch_lightning/models/__init__.py
index e69de29bb2..9ec9ed0e3f 100644
--- a/pytorch_lightning/models/__init__.py
+++ b/pytorch_lightning/models/__init__.py
@@ -0,0 +1 @@
+from .trainer import Trainer
\ No newline at end of file
diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py
index 68eeb54e50..1d1f8039aa 100644
--- a/pytorch_lightning/root_module/root_module.py
+++ b/pytorch_lightning/root_module/root_module.py
@@ -9,6 +9,7 @@ from pytorch_lightning.root_module.optimization import OptimizerConfig
 from pytorch_lightning.root_module.hooks import ModelHooks
 
+
 class LightningModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
 
     def __init__(self, hparams):
diff --git a/pytorch_lightning/trainer_main.py b/pytorch_lightning/trainer_main.py
index 09e470210c..d10e148809 100644
--- a/pytorch_lightning/trainer_main.py
+++ b/pytorch_lightning/trainer_main.py
@@ -8,7 +8,7 @@ from pytorch_lightning.models.trainer import Trainer
 from pytorch_lightning.utils.arg_parse import add_default_args
 from time import sleep
 
-from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks.pt_callbacks import EarlyStopping, ModelCheckpoint
 
 SEED = 2334
 torch.manual_seed(SEED)
 np.random.seed(SEED)
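
Since this patch moves `pt_callbacks` from `pytorch_lightning/utils/` to `pytorch_lightning/callbacks/`, here is a minimal usage sketch of the relocated callbacks wired into the trainer. Only the import paths and the `early_stop_callback`/`checkpoint_callback` trainer arguments come from this patch; the callback constructor arguments (`monitor`, `patience`, `filepath`, `save_best_only`) are Keras-style assumptions about the `pt_callbacks` API, not something this patch confirms:

``` {.python}
# Usage sketch for the new callbacks location. Constructor
# arguments below are assumptions; only the import path and the
# Trainer keyword arguments are taken from this patch.
from pytorch_lightning.callbacks.pt_callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.models.trainer import Trainer

early_stop = EarlyStopping(monitor='val_loss', patience=3)
checkpoint = ModelCheckpoint(filepath='/some/checkpoint/dir', monitor='val_loss', save_best_only=True)

trainer = Trainer(
    early_stop_callback=early_stop,
    checkpoint_callback=checkpoint,
)
trainer.fit(model)  # model: a LightningModule subclass, as sketched earlier
```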