diff --git a/README.md b/README.md index 85b27c204d..864ed27c43 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,10 @@
- The Keras for ML researchers using PyTorch. More control. Less boilerplate. + The PyTorch Keras for ML researchers. More control. Less boilerplate.
@@ -31,30 +31,35 @@ Lightning defers training and validation loop logic to you. It guarantees correc ## Why do I want to use lightning? -When starting a new project the last thing you want to do is recode a training loop, model loading/saving, distributed training, when to validate, etc... You're likely to spend a long time ironing out all the bugs without even getting to the core of your research. +When starting a new project the last thing you want to do is recode a training loop, multi-cluster training, 16-bit precision, early-stopping, model loading/saving, when to validate, etc... You're likely to spend a long time ironing out all the bugs without even getting to the core of your research. -With lightning, you guarantee those parts of your code work so you can focus on what the meat of the research: Data and training, validation loop logic. Don't worry about multiple gpus or speeding up your code, lightning will do that for you! +With lightning, you guarantee those parts of your code work so you can focus on what the meat of the research: The data and the training/validation loop logic. + +Don't worry about training on multiple gpus or speeding up your code, lightning will do that for you! ## How do I do use it? To use lightning do 2 things: 1. [Define a LightningModel](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/) ```python -import pytorch_lightning as ptl +import os import torch from torch.nn import functional as F from torch.utils.data import DataLoader from torchvision.datasets import MNIST +import torchvision.transforms as transforms + +import pytorch_lightning as ptl class CoolModel(ptl.LightningModule): - def __init(self): + def __init__(self): super(CoolModel, self).__init__() # not the best model... self.l1 = torch.nn.Linear(28 * 28, 10) def forward(self, x): - return torch.relu(self.l1(x)) + return torch.relu(self.l1(x.view(x.size(0), -1))) def my_loss(self, y_hat, y): return F.cross_entropy(y_hat, y) @@ -62,7 +67,7 @@ class CoolModel(ptl.LightningModule): def training_step(self, batch, batch_nb): x, y = batch y_hat = self.forward(x) - return {'tng_loss': self.my_loss(y_hat, y)} + return {'loss': self.my_loss(y_hat, y)} def validation_step(self, batch, batch_nb): x, y = batch @@ -70,23 +75,23 @@ class CoolModel(ptl.LightningModule): return {'val_loss': self.my_loss(y_hat, y)} def validation_end(self, outputs): - avg_loss = torch.stack([x for x in outputs['val_loss']]).mean() - return avg_loss + avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() + return {'avg_val_loss': avg_loss} def configure_optimizers(self): return [torch.optim.Adam(self.parameters(), lr=0.02)] @ptl.data_loader def tng_dataloader(self): - return DataLoader(MNIST('path/to/save', train=True), batch_size=32) + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @ptl.data_loader def val_dataloader(self): - return DataLoader(MNIST('path/to/save', train=False), batch_size=32) + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @ptl.data_loader def test_dataloader(self): - return DataLoader(MNIST('path/to/save', train=False), batch_size=32) + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) ``` 2. Fit with a [trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/) @@ -95,15 +100,23 @@ from pytorch_lightning import Trainer from test_tube import Experiment model = CoolModel() +exp = Experiment(save_dir=os.getcwd()) -# fit on 32 gpus across 4 nodes -exp = Experiment(save_dir='some/dir') -trainer = Trainer(experiment=exp, nb_gpu_nodes=4, gpus=[0,1,2,3,4,5,6,7]) +# train on cpu using only 10% of the data (for demo purposes) +trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1) +# train on 4 gpus +# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=[0, 1, 2, 3]) + +# train on 32 gpus across 4 nodes (make sure to submit appropriate SLURM job) +# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=[0, 1, 2, 3, 4, 5, 6, 7], nb_gpu_nodes=4) + +# train (1 epoch only here for demo) trainer.fit(model) -# see all experiment metrics here -# tensorboard --log_dir some/dir +# view tensorflow logs +print(f'View tensorboard logs by running\ntensorboard --logdir {os.getcwd()}') +print('and going to http://localhost:6006 on your browser') ``` @@ -222,6 +235,7 @@ tensorboard --logdir /some/path ## Lightning automates all of the following ([each is also configurable](https://williamfalcon.github.io/pytorch-lightning/Trainer/)): + ###### Checkpointing - [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving) @@ -254,9 +268,9 @@ tensorboard --logdir /some/path ###### Experiment Logging - [Display metrics in progress bar](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#display-metrics-in-progress-bar) -- Log arbitrary metrics - [Log metric row every k batches](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#log-metric-row-every-k-batches) - [Process position](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#process-position) +- [Tensorboard support](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#tensorboard-support) - [Save a snapshot of all hyperparameters](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#save-a-snapshot-of-all-hyperparameters) - [Snapshot code for a training run](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#snapshot-code-for-a-training-run) - [Write logs file to csv every k batches](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#write-logs-file-to-csv-every-k-batches) @@ -264,22 +278,25 @@ tensorboard --logdir /some/path ###### Training loop - [Accumulate gradients](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#accumulated-gradients) -- [Anneal Learning rate](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#anneal-learning-rate) - [Force training for min or max epochs](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-training-for-min-or-max-epochs) - [Force disable early stop](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-disable-early-stop) - [Gradient Clipping](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#gradient-clipping) -- [Use multiple optimizers (like GANs)](https://williamfalcon.github.io/pytorch-lightning/Pytorch-Lightning/LightningModule/#configure_optimizers) +- [Hooks](https://williamfalcon.github.io/pytorch-lightning/Trainer/hooks/) +- [Learning rate scheduling](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/#configure_optimizers) +- [Use multiple optimizers (like GANs)](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/#configure_optimizers) - [Set how much of the training set to check (1-100%)](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#set-how-much-of-the-training-set-to-check) ###### Validation loop - [Check validation every n epochs](https://williamfalcon.github.io/pytorch-lightning/Trainer/Validation%20loop/#check-validation-every-n-epochs) +- [Hooks](https://williamfalcon.github.io/pytorch-lightning/Trainer/hooks/) - [Set how much of the validation set to check](https://williamfalcon.github.io/pytorch-lightning/Trainer/Validation%20loop/#set-how-much-of-the-validation-set-to-check) - [Set how much of the test set to check](https://williamfalcon.github.io/pytorch-lightning/Trainer/Validation%20loop/#set-how-much-of-the-test-set-to-check) - [Set validation check frequency within 1 training epoch](https://williamfalcon.github.io/pytorch-lightning/Trainer/Validation%20loop/#set-validation-check-frequency-within-1-training-epoch) - [Set the number of validation sanity steps](https://williamfalcon.github.io/pytorch-lightning/Trainer/Validation%20loop/#set-the-number-of-validation-sanity-steps) + ## Demo ```bash # install lightning @@ -301,8 +318,21 @@ python single_gpu_node_template.py --gpus "0,1" python multi_node_cluster_template.py --nb_gpu_nodes 4 --gpus '0,1,2,3,4,5,6,7' ``` +## Contributing +Welcome to the PTL community! We're building the most advanced research platform on the planet to implement the latest, best practices that the amazing PyTorch team rolls out! + +#### Bug fixes: +1. Submit a github issue. +2. Fix it. +3. Submit a PR! + +#### New Features: +1. Submit a github issue. +2. We'll agree on the feature scope. +3. Submit a PR! (with updated docs and tests 🙃). + ## Bleeding edge If you can't wait for the next release, install the most up to date code with: ```bash pip install git+https://github.com/williamFalcon/pytorch-lightning.git@master --upgrade -``` \ No newline at end of file +``` diff --git a/docs/LightningModule/RequiredTrainerInterface.md b/docs/LightningModule/RequiredTrainerInterface.md index b9bb1d36fc..abc015caf0 100644 --- a/docs/LightningModule/RequiredTrainerInterface.md +++ b/docs/LightningModule/RequiredTrainerInterface.md @@ -3,7 +3,7 @@ A lightning module is a strict superclass of nn.Module, it provides a standard interface for the trainer to interact with the model. -The easiest thing to do is copy [this template](../../pytorch_lightning/examples/new_project_templates/lightning_module_template.py) and modify accordingly. +The easiest thing to do is copy the [minimal example](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/#minimal-example) below and modify accordingly. Otherwise, to Define a Lightning Module, implement the following methods: @@ -27,23 +27,26 @@ Otherwise, to Define a Lightning Module, implement the following methods: - [add_model_specific_args](RequiredTrainerInterface.md#add_model_specific_args) --- -**Minimal example** +### Minimal example ```python -import pytorch_lightning as ptl +import os import torch from torch.nn import functional as F from torch.utils.data import DataLoader from torchvision.datasets import MNIST +import torchvision.transforms as transforms + +import pytorch_lightning as ptl class CoolModel(ptl.LightningModule): - def __init(self): + def __init__(self): super(CoolModel, self).__init__() # not the best model... self.l1 = torch.nn.Linear(28 * 28, 10) def forward(self, x): - return torch.relu(self.l1(x)) + return torch.relu(self.l1(x.view(x.size(0), -1))) def my_loss(self, y_hat, y): return F.cross_entropy(y_hat, y) @@ -51,7 +54,7 @@ class CoolModel(ptl.LightningModule): def training_step(self, batch, batch_nb): x, y = batch y_hat = self.forward(x) - return {'tng_loss': self.my_loss(y_hat, y)} + return {'loss': self.my_loss(y_hat, y)} def validation_step(self, batch, batch_nb): x, y = batch @@ -59,23 +62,23 @@ class CoolModel(ptl.LightningModule): return {'val_loss': self.my_loss(y_hat, y)} def validation_end(self, outputs): - avg_loss = torch.stack([x for x in outputs['val_loss']]).mean() - return avg_loss + avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() + return {'avg_val_loss': avg_loss} def configure_optimizers(self): return [torch.optim.Adam(self.parameters(), lr=0.02)] @ptl.data_loader def tng_dataloader(self): - return DataLoader(MNIST('path/to/save', train=True), batch_size=32) + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @ptl.data_loader def val_dataloader(self): - return DataLoader(MNIST('path/to/save', train=False), batch_size=32) + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @ptl.data_loader def test_dataloader(self): - return DataLoader(MNIST('path/to/save', train=False), batch_size=32) + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) ``` --- @@ -222,26 +225,27 @@ def validation_end(self, outputs): def configure_optimizers(self) ``` -Set up as many optimizers as you need. Normally you'd need one. But in the case of GANs or something more esoteric you might have multiple. -Lightning will call .backward() and .step() on each one. If you use 16 bit precision it will also handle that. +Set up as many optimizers and (optionally) learning rate schedulers as you need. Normally you'd need one. But in the case of GANs or something more esoteric you might have multiple. +Lightning will call .backward() and .step() on each one in every epoch. If you use 16 bit precision it will also handle that. ##### Return -List - List of optimizers +List or Tuple - List of optimizers with an optional second list of learning-rate schedulers **Example** ``` {.python} # most cases def configure_optimizers(self): - opt = Adam(lr=0.01) + opt = Adam(self.parameters(), lr=0.01) return [opt] -# gan example +# gan example, with scheduler for discriminator def configure_optimizers(self): - generator_opt = Adam(lr=0.01) - disriminator_opt = Adam(lr=0.02) - return [generator_opt, disriminator_opt] + generator_opt = Adam(self.model_gen.parameters(), lr=0.01) + disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02) + discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10) + return [generator_opt, disriminator_opt], [discriminator_sched] ``` --- @@ -296,7 +300,7 @@ def tng_dataloader(self) Called by lightning during training loop. Make sure to use the @ptl.data_loader decorator, this ensures not calling this function until the data are needed. ##### Return -Pytorch DataLoader +PyTorch DataLoader **Example** @@ -323,7 +327,7 @@ def tng_dataloader(self) Called by lightning during validation loop. Make sure to use the @ptl.data_loader decorator, this ensures not calling this function until the data are needed. ##### Return -Pytorch DataLoader +PyTorch DataLoader **Example** @@ -351,7 +355,7 @@ def test_dataloader(self) Called by lightning during test loop. Make sure to use the @ptl.data_loader decorator, this ensures not calling this function until the data are needed. ##### Return -Pytorch DataLoader +PyTorch DataLoader **Example** @@ -428,4 +432,4 @@ def add_model_specific_args(parent_parser, root_dir): parser.opt_list('--batch_size', default=256, type=int, options=[32, 64, 128, 256], tunable=False) parser.opt_list('--optimizer_name', default='adam', type=str, options=['adam'], tunable=False) return parser -``` \ No newline at end of file +``` diff --git a/docs/LightningModule/methods.md b/docs/LightningModule/methods.md index d57c695034..cb96ea7a1d 100644 --- a/docs/LightningModule/methods.md +++ b/docs/LightningModule/methods.md @@ -31,7 +31,7 @@ y_hat = pretrained_model(x) | Param | description | |---|---| -| weights_path | Path to a pytorch checkpoint | +| weights_path | Path to a PyTorch checkpoint | | tags_csv | Path to meta_tags.csv file generated by the test-tube Experiment | | on_gpu | if True, puts model on GPU. Make sure to use transforms option if model devices have changed | | map_location | A dictionary mapping saved weight GPU devices to new GPU devices | diff --git a/docs/Trainer/Distributed training.md b/docs/Trainer/Distributed training.md index a7d487b525..aedbd20ed1 100644 --- a/docs/Trainer/Distributed training.md +++ b/docs/Trainer/Distributed training.md @@ -23,6 +23,16 @@ have configuration issues depending on your cluster. For a deeper understanding of what lightning is doing, feel free to read [this guide](https://medium.com/@_willfalcon/9-tips-for-training-lightning-fast-neural-networks-in-pytorch-8e63a502f565). +--- +#### CUDA flags +CUDA flags make certain GPUs visible to your script. +Lightning sets these for you automatically, there's NO NEED to do this yourself. +```python +# lightning will set according to what you give the trainer +# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +# os.environ["CUDA_VISIBLE_DEVICES"] = "0" +``` + --- #### 16-bit mixed precision 16 bit precision can cut your memory footprint by half. If using volta architecture GPUs it can give a dramatic training speed-up as well. @@ -43,10 +53,6 @@ trainer = Trainer(amp_level='O2', use_amp=False) #### Single-gpu Make sure you're on a GPU machine. ```python -# set these flags -os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -os.environ["CUDA_VISIBLE_DEVICES"] = "0" - # DEFAULT trainer = Trainer(gpus=[0]) ``` @@ -56,13 +62,6 @@ trainer = Trainer(gpus=[0]) Make sure you're on a GPU machine. You can set as many GPUs as you want. In this setting, the model will run on all 8 GPUs at once using DataParallel under the hood. ```python -# set these flags -# lightning sets these flags for you automatically -# no need to set yourself -# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" - - # to use DataParallel (default) trainer = Trainer(gpus=[0,1,2,3,4,5,6,7], distributed_backend='dp') diff --git a/docs/Trainer/Logging.md b/docs/Trainer/Logging.md index 1596251f8b..2edbab16f1 100644 --- a/docs/Trainer/Logging.md +++ b/docs/Trainer/Logging.md @@ -50,6 +50,33 @@ exp = Experiment(create_git_tag=True) Trainer(experiment=exp) ``` +--- +### Tensorboard support +The experiment object is a strict subclass of PyTorch SummaryWriter. However, this class +also snapshots every detail about the experiment (data folder paths, code, hyperparams), +and allows you to visualize it using tensorboard. +``` {.python} +from test_tube import Experiment, HyperOptArgumentParser + +# exp hyperparams +args = HyperOptArgumentParser() +hparams = args.parse_args() + +# this is a summaryWriter with nicer logging structure +exp = Experiment(save_dir='/some/path', create_git_tag=True) + +# track experiment details (must be ArgumentParser or HyperOptArgumentParser). +# each option in the parser is tracked +exp.argparse(hparams) +exp.tag({'description': 'running demo'}) + +# trainer uses the exp object to log exp data +trainer = Trainer(experiment=exp) +trainer.fit(model) + +# view logs at: +# tensorboard --logdir /some/path +``` --- #### Write logs file to csv every k batches diff --git a/docs/Trainer/Training Loop.md b/docs/Trainer/Training Loop.md index 2be8da9edd..c2b6dc35a0 100644 --- a/docs/Trainer/Training Loop.md +++ b/docs/Trainer/Training Loop.md @@ -1,4 +1,4 @@ -The lightning training loop handles everything except the actual computations of your model. To decide what will happen in your training loop, define the [training_step function](../../Pytorch-lightning/LightningModule/#training_step). +The lightning training loop handles everything except the actual computations of your model. To decide what will happen in your training loop, define the [training_step function](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/#training_step). Below are all the things lightning automates for you in the training loop. @@ -11,17 +11,6 @@ Accumulated gradients runs K small batches of size N before doing a backwards pa trainer = Trainer(accumulate_grad_batches=1) ``` ---- -#### Anneal Learning rate -Cut the learning rate by 10 at every epoch listed in this list. -``` {.python} -# DEFAULT (don't anneal) -trainer = Trainer(lr_scheduler_milestones=None) - -# cut LR by 10 at 100, 200, and 300 epochs -trainer = Trainer(lr_scheduler_milestones='100, 200, 300') -``` - --- #### Force training for min or max epochs It can be useful to force training for a minimum number of epochs or limit to a max number diff --git a/docs/Trainer/Validation loop.md b/docs/Trainer/Validation loop.md index 693df88904..8d6c5bac46 100644 --- a/docs/Trainer/Validation loop.md +++ b/docs/Trainer/Validation loop.md @@ -1,4 +1,4 @@ -The lightning validation loop handles everything except the actual computations of your model. To decide what will happen in your validation loop, define the [validation_step function](../../Pytorch-lightning/LightningModule/#validation_step). +The lightning validation loop handles everything except the actual computations of your model. To decide what will happen in your validation loop, define the [validation_step function](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/#validation_step). Below are all the things lightning automates for you in the validation loop. **Note** diff --git a/docs/Trainer/hooks.md b/docs/Trainer/hooks.md index e69de29bb2..dd08b30b45 100644 --- a/docs/Trainer/hooks.md +++ b/docs/Trainer/hooks.md @@ -0,0 +1,86 @@ +# Hooks +[[Github Code](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/root_module/hooks.py)] + +There are cases when you might want to do something different at different parts of the training/validation loop. +To enable a hook, simply override the method in your LightningModule and the trainer will call it at the correct time. + +**Contributing** If there's a hook you'd like to add, simply: +1. Fork PyTorchLightning. +2. Add the hook [here](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/root_module/hooks.py). +3. Add the correct place in the [Trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/models/trainer.py) where it should be called. + +--- +#### on_epoch_start +Called in the training loop at the very beginning of the epoch. +```python +def on_epoch_start(self): + # do something when the epoch starts +``` + +--- +#### on_batch_end +Called in the training loop at the very end of the epoch. +```python +def on_epoch_end(self): + # do something when the epoch ends +``` + +--- +#### on_batch_start +Called in the training loop before anything happens for that batch. +```python +def on_batch_start(self): + # do something when the batch starts +``` + +--- +#### on_pre_performance_check +Called at the very beginning of the validation loop. +```python +def on_pre_performance_check(self): + # do something before validation starts +``` + +--- +#### on_post_performance_check +Called at the very end of the validation loop. +```python +def on_post_performance_check(self): + # do something before validation end +``` + +--- +#### on_tng_metrics +Called in the training loop, right before metrics are logged. +Although you can log at any time by using self.experiment, you can use +this callback to modify what will be logged. +```python +def on_tng_metrics(self, metrics): + # do something before validation end +``` + +--- +#### on_before_zero_grad +Called in the training loop after taking an optimizer step and before zeroing grads. +Good place to inspect weight information with weights updated. + +Called once per optimizer +```python +def on_before_zero_grad(self, optimizer): + # do something with the optimizer or inspect it. +``` + +--- +#### on_after_backward +Called in the training loop after model.backward() +This is the ideal place to inspect or log gradient information +```python +def on_after_backward(self): + # example to inspect gradient information in tensorboard + if self.trainer.global_step % 25 == 0: # don't make the tf file huge + params = self.state_dict() + for k, v in params.items(): + grads = v + name = k + self.experiment.add_histogram(tag=name, values=grads, global_step=self.trainer.global_step) +``` diff --git a/docs/Trainer/index.md b/docs/Trainer/index.md index 1b30da1966..3983daae88 100644 --- a/docs/Trainer/index.md +++ b/docs/Trainer/index.md @@ -49,25 +49,28 @@ But of course the fun is in all the advanced things it can do: **Experiment Logging** - [Display metrics in progress bar](Logging/#display-metrics-in-progress-bar) -- Log arbitrary metrics - [Log metric row every k batches](Logging/#log-metric-row-every-k-batches) - [Process position](Logging/#process-position) +- [Tensorboard support](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#tensorboard-support) - [Save a snapshot of all hyperparameters](Logging/#save-a-snapshot-of-all-hyperparameters) - [Snapshot code for a training run](Logging/#snapshot-code-for-a-training-run) - [Write logs file to csv every k batches](Logging/#write-logs-file-to-csv-every-k-batches) **Training loop** -- [Accumulate gradients](Training%20Loop/#accumulated-gradients) -- [Anneal Learning rate](Training%20Loop/#anneal-learning-rate) -- [Force training for min or max epochs](Training%20Loop/#force-training-for-min-or-max-epochs) -- [Force disable early stop](Training%20Loop/#force-disable-early-stop) -- [Use multiple optimizers (like GANs)](../Pytorch-lightning/LightningModule/#configure_optimizers) -- [Set how much of the training set to check (1-100%)](Training%20Loop/#set-how-much-of-the-training-set-to-check) +- [Accumulate gradients](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#accumulated-gradients) +- [Force training for min or max epochs](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-training-for-min-or-max-epochs) +- [Force disable early stop](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-disable-early-stop) +- [Gradient Clipping](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#gradient-clipping) +- [Hooks](hooks) +- [Learning rate scheduling](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/#configure_optimizers) +- [Use multiple optimizers (like GANs)](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/#configure_optimizers) +- [Set how much of the training set to check (1-100%)](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#set-how-much-of-the-training-set-to-check) **Validation loop** - [Check validation every n epochs](Validation%20Loop/#check-validation-every-n-epochs) +- [Hooks](hooks) - [Set how much of the validation set to check](Validation%20Loop/#set-how-much-of-the-validation-set-to-check) - [Set how much of the test set to check](Validation%20Loop/#set-how-much-of-the-test-set-to-check) - [Set validation check frequency within 1 training epoch](Validation%20Loop/#set-validation-check-frequency-within-1-training-epoch) diff --git a/docs/index.md b/docs/index.md index 0e25fa79d5..ae08df4bf8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,8 +1,14 @@ ###### New project Quick Start -To start a new project define these two files. +To start a new project you define two files, a LightningModule and a Trainer file. -1. [Define a LightningModule](/pytorch-lightning/LightningModule/RequiredTrainerInterface/) -2. [Define a trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/) +A separate trainer file allows to run many LightningModules. Each LightningModule has the core +logic to a particular research project. + +For example, one lightningModule could be an image classifier, the other +one could be a seq-2-seq model, both (optionally) ran by the same trainer file. + +1. [MNIST LightningModule](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/#minimal-example) +2. [Trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/) - [Basic CPU Trainer Template](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/examples/new_project_templates/single_cpu_template.py) - [Multi-GPU Trainer Template](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/examples/new_project_templates/single_gpu_node_template.py) - [GPU cluster Trainer Template](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/examples/new_project_templates/multi_node_cluster_template.py) @@ -50,9 +56,9 @@ To start a new project define these two files. ###### Experiment Logging - [Display metrics in progress bar](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#display-metrics-in-progress-bar) -- Log arbitrary metrics - [Log metric row every k batches](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#log-metric-row-every-k-batches) - [Process position](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#process-position) +- [Tensorboard support](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#tensorboard-support) - [Save a snapshot of all hyperparameters](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#save-a-snapshot-of-all-hyperparameters) - [Snapshot code for a training run](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#snapshot-code-for-a-training-run) - [Write logs file to csv every k batches](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#write-logs-file-to-csv-every-k-batches) @@ -60,16 +66,18 @@ To start a new project define these two files. ###### Training loop - [Accumulate gradients](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#accumulated-gradients) -- [Anneal Learning rate](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#anneal-learning-rate) - [Force training for min or max epochs](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-training-for-min-or-max-epochs) - [Force disable early stop](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-disable-early-stop) - [Gradient Clipping](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#gradient-clipping) -- [Use multiple optimizers (like GANs)](https://williamfalcon.github.io/pytorch-lightning/Pytorch-Lightning/LightningModule/#configure_optimizers) +- [Hooks](https://williamfalcon.github.io/pytorch-lightning/Trainer/hooks/) +- [Learning rate scheduling](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/#configure_optimizers) +- [Use multiple optimizers (like GANs)](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/#configure_optimizers) - [Set how much of the training set to check (1-100%)](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#set-how-much-of-the-training-set-to-check) -######Validation loop +###### Validation loop - [Check validation every n epochs](https://williamfalcon.github.io/pytorch-lightning/Trainer/Validation%20loop/#check-validation-every-n-epochs) +- [Hooks](https://williamfalcon.github.io/pytorch-lightning/Trainer/hooks/) - [Set how much of the validation set to check](https://williamfalcon.github.io/pytorch-lightning/Trainer/Validation%20loop/#set-how-much-of-the-validation-set-to-check) - [Set how much of the test set to check](https://williamfalcon.github.io/pytorch-lightning/Trainer/Validation%20loop/#set-how-much-of-the-test-set-to-check) - [Set validation check frequency within 1 training epoch](https://williamfalcon.github.io/pytorch-lightning/Trainer/Validation%20loop/#set-validation-check-frequency-within-1-training-epoch) diff --git a/mkdocs.yml b/mkdocs.yml index 5675a64e01..539714d4b5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,10 +1,10 @@ -site_name: Pytorch lightning Documentation +site_name: PyTorch lightning Documentation theme: name: 'material' docs_dir: docs repo_url: https://github.com/williamFalcon/pytorch-lightning site_dir: 'site' -site_description: 'Documentation for Pytorch LightningModule, the researcher version of keras.' +site_description: 'Documentation for PyTorch LightningModule, the researcher version of keras.' dev_addr: '0.0.0.0:8000' #google_analytics: ['UA-aasd', 'sitename'] diff --git a/pytorch_lightning/examples/new_project_templates/lightning_module_template.py b/pytorch_lightning/examples/new_project_templates/lightning_module_template.py index 608e534e0f..ca8f604b23 100644 --- a/pytorch_lightning/examples/new_project_templates/lightning_module_template.py +++ b/pytorch_lightning/examples/new_project_templates/lightning_module_template.py @@ -163,7 +163,8 @@ class LightningTemplateModel(LightningModule): :return: list of optimizers """ optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - return [optimizer] + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) + return [optimizer], [scheduler] def __dataloader(self, train): # init data generators @@ -220,7 +221,6 @@ class LightningTemplateModel(LightningModule): # parser.set_defaults(gradient_clip=5.0) # network params - parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False) parser.add_argument('--in_features', default=28*28, type=int) parser.add_argument('--out_features', default=10, type=int) parser.add_argument('--hidden_dim', default=50000, type=int) # use 500 for CPU, 50000 for GPU to see speed difference diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index d686e39f59..aba6219fd6 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -10,7 +10,6 @@ import re import torch from torch.utils.data.distributed import DistributedSampler -from torch.optim.lr_scheduler import MultiStepLR import torch.multiprocessing as mp import torch.distributed as dist import numpy as np @@ -71,7 +70,6 @@ class Trainer(TrainerIO): train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0, val_check_interval=0.95, log_save_interval=100, add_log_row_interval=10, - lr_scheduler_milestones=None, distributed_backend='dp', use_amp=False, print_nan_grads=False, @@ -104,7 +102,6 @@ class Trainer(TrainerIO): :param val_check_interval: :param log_save_interval: :param add_log_row_interval: - :param lr_scheduler_milestones: :param distributed_backend: 'np' to use DistributedParallel, 'ddp' to use DistributedDataParallel :param use_amp: :param print_nan_grads: @@ -141,7 +138,6 @@ class Trainer(TrainerIO): self.early_stop_callback = early_stop_callback self.min_nb_epochs = min_nb_epochs self.nb_sanity_val_steps = nb_sanity_val_steps - self.lr_scheduler_milestones = [] if lr_scheduler_milestones is None else [int(x.strip()) for x in lr_scheduler_milestones.split(',')] self.lr_schedulers = [] self.amp_level = amp_level self.print_nan_grads = print_nan_grads @@ -442,8 +438,10 @@ class Trainer(TrainerIO): raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option') # CHOOSE OPTIMIZER - # filter out the weights that were done on gpu so we can load on good old cpus + # allow for lr schedulers as well self.optimizers = model.configure_optimizers() + if len(self.optimizers) == 2: + self.optimizers, self.lr_schedulers = self.optimizers self.__run_pretrain_routine(model) @@ -454,8 +452,10 @@ class Trainer(TrainerIO): def __dp_train(self, model): # CHOOSE OPTIMIZER - # filter out the weights that were done on gpu so we can load on good old cpus + # allow for lr schedulers as well self.optimizers = model.configure_optimizers() + if len(self.optimizers) == 2: + self.optimizers, self.lr_schedulers = self.optimizers model.cuda(self.data_parallel_device_ids[0]) @@ -508,8 +508,10 @@ class Trainer(TrainerIO): self.__init_tcp_connection() # CHOOSE OPTIMIZER - # filter out the weights that were done on gpu so we can load on good old cpus + # allow for lr schedulers as well self.optimizers = model.configure_optimizers() + if len(self.optimizers) == 2: + self.optimizers, self.lr_schedulers = self.optimizers # MODEL # copy model to each gpu @@ -589,12 +591,6 @@ class Trainer(TrainerIO): # init training constants self.__layout_bookeeping() - # add lr schedulers - if self.lr_scheduler_milestones is not None: - for optimizer in self.optimizers: - scheduler = MultiStepLR(optimizer, self.lr_scheduler_milestones) - self.lr_schedulers.append(scheduler) - # print model summary if self.proc_rank == 0 and self.print_weights_summary: ref_model.summarize() @@ -628,8 +624,9 @@ class Trainer(TrainerIO): # run all epochs for epoch_nb in range(self.current_epoch, self.max_nb_epochs): # update the lr scheduler - for lr_scheduler in self.lr_schedulers: - lr_scheduler.step() + if self.lr_schedulers is not None: + for lr_scheduler in self.lr_schedulers: + lr_scheduler.step() model = self.__get_model() model.current_epoch = epoch_nb @@ -775,7 +772,7 @@ class Trainer(TrainerIO): output = self.model.training_step(data_batch, batch_nb) try: - model_specific_tqdm_metrics_dic = output['tqdm_metrics'] + model_specific_tqdm_metrics_dic = output['prog'] except Exception as e: model_specific_tqdm_metrics_dic = {} diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py index c5831317e5..39c5ae7b70 100644 --- a/pytorch_lightning/root_module/model_saving.py +++ b/pytorch_lightning/root_module/model_saving.py @@ -71,11 +71,19 @@ class TrainerIO(object): checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait checkpoint['early_stop_callback_patience'] = self.early_stop_callback.patience + # save optimizers optimizer_states = [] for i, optimizer in enumerate(self.optimizers): optimizer_states.append(optimizer.state_dict()) checkpoint['optimizer_states'] = optimizer_states + + # save lr schedulers + lr_schedulers = [] + for i, scheduler in enumerate(self.lr_schedulers): + lr_schedulers.append(scheduler.state_dict()) + + checkpoint['lr_schedulers'] = lr_schedulers # add the state_dict from the model model = self.__get_model() @@ -94,13 +102,16 @@ class TrainerIO(object): return # allow test tube to handle model check pointing automatically - self.cluster.set_checkpoint_save_function( - self.hpc_save, - kwargs={ - 'folderpath': self.checkpoint_callback.filepath, - 'experiment': self.experiment - } - ) + # only if proc 0 so we don't trigger world_size resubmits + if self.proc_rank == 0: + self.cluster.set_checkpoint_save_function( + self.hpc_save, + kwargs={ + 'folderpath': self.checkpoint_callback.filepath, + 'experiment': self.experiment + } + ) + self.cluster.set_checkpoint_load_function( self.hpc_load, kwargs={ @@ -130,6 +141,11 @@ class TrainerIO(object): optimizer_states = checkpoint['optimizer_states'] for optimizer, opt_state in zip(self.optimizers, optimizer_states): optimizer.load_state_dict(opt_state) + + # restore the lr schedulers + lr_schedulers = checkpoint['lr_schedulers'] + for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers): + scheduler.load_state_dict(lrs_state) # ---------------------------------- # PRIVATE OPS diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py index b49dd3af90..c860685eea 100644 --- a/pytorch_lightning/root_module/root_module.py +++ b/pytorch_lightning/root_module/root_module.py @@ -58,7 +58,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks): def configure_optimizers(self): """ - Return array of optimizers + Return a list of optimizers and a list of schedulers (could be empty) :return: """ raise NotImplementedError diff --git a/pytorch_lightning/testing_models/lm_test_module.py b/pytorch_lightning/testing_models/lm_test_module.py index 39fcbb4c3d..8e0718b972 100644 --- a/pytorch_lightning/testing_models/lm_test_module.py +++ b/pytorch_lightning/testing_models/lm_test_module.py @@ -96,12 +96,15 @@ class LightningTestModel(LightningModule): if self.trainer.use_dp: loss_val = loss_val.unsqueeze(0) - output = OrderedDict({ - 'loss': loss_val - }) - - # can also return just a scalar instead of a dict (return loss_val) - return output + # alternate possible outputs to test + if self.trainer.batch_nb % 1 == 0: + output = OrderedDict({ + 'loss': loss_val, + 'prog': {'some_val': loss_val * loss_val} + }) + return output + if self.trainer.batch_nb % 2 == 0: + return loss_val def validation_step(self, data_batch, batch_i): """ @@ -179,7 +182,10 @@ class LightningTestModel(LightningModule): return whatever optimizers we want here :return: list of optimizers """ + # try no scheduler for this model (testing purposes) optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + + # test returning only 1 list instead of 2 return [optimizer] def __dataloader(self, train): diff --git a/pytorch_lightning/trainer_main.py b/pytorch_lightning/trainer_main.py deleted file mode 100644 index 0c30a41968..0000000000 --- a/pytorch_lightning/trainer_main.py +++ /dev/null @@ -1,210 +0,0 @@ -import os -import sys - -import torch -import numpy as np -from test_tube import HyperOptArgumentParser, Experiment, SlurmCluster -from pytorch_lightning.models.trainer import Trainer -from pytorch_lightning.utils.arg_parse import add_default_args -from time import sleep - -from pytorch_lightning.callbacks.pt_callbacks import EarlyStopping, ModelCheckpoint -SEED = 2334 -torch.manual_seed(SEED) -np.random.seed(SEED) - -# --------------------- -# DEFINE MODEL HERE -# --------------------- -from pytorch_lightning.models.sample_model_template.model_template import ExampleModel1 -# --------------------- - -AVAILABLE_MODELS = { - 'model_1': ExampleModel1 -} - - -""" -Allows training by using command line arguments - -Run by: -# TYPE YOUR RUN COMMAND HERE -""" - - -def main_local(hparams): - main(hparams, None, None) - - -def main(hparams, cluster, results_dict): - """ - Main training routine specific for this project - :param hparams: - :return: - """ - on_gpu = torch.cuda.is_available() - if hparams.disable_cuda: - on_gpu = False - - device = 'cuda' if on_gpu else 'cpu' - hparams.__setattr__('device', device) - hparams.__setattr__('on_gpu', on_gpu) - hparams.__setattr__('nb_gpus', torch.cuda.device_count()) - hparams.__setattr__('inference_mode', hparams.model_load_weights_path is not None) - - # init experiment - exp = Experiment( - name=hparams.tt_name, - debug=hparams.debug, - save_dir=hparams.tt_save_path, - version=hparams.hpc_exp_number, - autosave=False, - description=hparams.tt_description - ) - - exp.argparse(hparams) - exp.save() - - # build model - print('loading model...') - model = TRAINING_MODEL(hparams) - print('model built') - - # callbacks - early_stop = EarlyStopping( - monitor=hparams.early_stop_metric, - patience=hparams.early_stop_patience, - verbose=True, - mode=hparams.early_stop_mode - ) - - model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version) - checkpoint = ModelCheckpoint( - filepath=model_save_path, - save_function=None, - save_best_only=True, - verbose=True, - monitor=hparams.model_save_monitor_value, - mode=hparams.model_save_monitor_mode - ) - - # configure trainer - trainer = Trainer( - experiment=exp, - on_gpu=on_gpu, - cluster=cluster, - progress_bar=hparams.enable_tqdm, - overfit_pct=hparams.overfit, - track_grad_norm=hparams.track_grad_norm, - fast_dev_run=hparams.fast_dev_run, - check_val_every_n_epoch=hparams.check_val_every_n_epoch, - accumulate_grad_batches=hparams.accumulate_grad_batches, - process_position=process_position, - current_gpu_name=current_gpu, - checkpoint_callback=checkpoint, - early_stop_callback=early_stop, - enable_early_stop=hparams.enable_early_stop, - max_nb_epochs=hparams.max_nb_epochs, - min_nb_epochs=hparams.min_nb_epochs, - train_percent_check=hparams.train_percent_check, - val_percent_check=hparams.val_percent_check, - test_percent_check=hparams.test_percent_check, - val_check_interval=hparams.val_check_interval, - log_save_interval=hparams.log_save_interval, - add_log_row_interval=hparams.add_log_row_interval, - lr_scheduler_milestones=hparams.lr_scheduler_milestones - ) - - # train model - trainer.fit(model) - - -def get_default_parser(strategy, root_dir): - - possible_model_names = list(AVAILABLE_MODELS.keys()) - parser = HyperOptArgumentParser(strategy=strategy, add_help=False) - add_default_args(parser, root_dir, possible_model_names, SEED) - return parser - - -def get_model_name(args): - for i, arg in enumerate(args): - if 'model_name' in arg: - return args[i+1] - - -def optimize_on_cluster(hyperparams): - # enable cluster training - cluster = SlurmCluster( - hyperparam_optimizer=hyperparams, - log_path=hyperparams.tt_save_path, - test_tube_exp_name=hyperparams.tt_name - ) - - # email for cluster coms - cluster.notify_job_status(email='add_email_here', on_done=True, on_fail=True) - - # configure cluster - cluster.per_experiment_nb_gpus = hyperparams.per_experiment_nb_gpus - cluster.job_time = '48:00:00' - cluster.gpu_type = '1080ti' - cluster.memory_mb_per_node = 48000 - - # any modules for code to run in env - cluster.add_command('source activate pytorch_lightning') - - # name of exp - job_display_name = hyperparams.tt_name.split('_')[0] - job_display_name = job_display_name[0:3] - - # run hopt - print('submitting jobs...') - cluster.optimize_parallel_cluster_gpu( - main, - nb_trials=hyperparams.nb_hopt_trials, - job_name=job_display_name - ) - - -if __name__ == '__main__': - - model_name = get_model_name(sys.argv) - - # use default args - root_dir = os.path.split(os.path.dirname(sys.modules['__main__'].__file__))[0] - parent_parser = get_default_parser(strategy='random_search', root_dir=root_dir) - - # allow model to overwrite or extend args - TRAINING_MODEL = AVAILABLE_MODELS[model_name] - parser = TRAINING_MODEL.add_model_specific_args(parent_parser) - parser.json_config('-c', '--config', default=root_dir + '/run_configs/local.json') - hyperparams = parser.parse_args() - - # format GPU layout - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - gpu_ids = hyperparams.gpus.split(';') - - # RUN TRAINING - if hyperparams.on_cluster: - print('RUNNING ON SLURM CLUSTER') - os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpu_ids) - optimize_on_cluster(hyperparams) - - elif hyperparams.single_run_gpu: - print(f'RUNNING 1 TRIAL ON GPU. gpu: {gpu_ids[0]}') - os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[0] - main(hyperparams, None, None) - - elif hyperparams.local or hyperparams.single_run: - os.environ["CUDA_VISIBLE_DEVICES"] = '0' - print('RUNNING LOCALLY') - main(hyperparams, None, None) - - else: - print(f'RUNNING MULTI GPU. GPU ids: {gpu_ids}') - hyperparams.optimize_parallel_gpu( - main_local, - gpu_ids=gpu_ids, - nb_trials=hyperparams.nb_hopt_trials, - nb_workers=len(gpu_ids) - ) diff --git a/setup.py b/setup.py index 50a050570a..769d607a60 100755 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ from setuptools import setup, find_packages # http://blog.ionelmc.ro/2014/05/25/python-packaging/ setup( name="pytorch-lightning", - version='0.3.6.4', + version='0.3.6.8', description="The Keras for ML researchers using PyTorch", author="William Falcon", author_email="waf2107@columbia.edu", diff --git a/tests/README.md b/tests/README.md index 20f783bf7e..f7a85b51a0 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,4 +1,4 @@ -# Pytorch-Lightning Tests +# PyTorch-Lightning Tests ## Running tests The automatic travis tests ONLY run CPU-based tests. Although these cover most of the use cases, diff --git a/tests/test_models.py b/tests/test_models.py index e8fa339b54..9d40c0da93 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -24,6 +24,33 @@ np.random.seed(SEED) # ------------------------------------------------------------------------ # TESTS # ------------------------------------------------------------------------ +def test_amp_gpu_ddp(): + """ + Make sure DDP + AMP work + :return: + """ + if not torch.cuda.is_available(): + warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a GPU node to run this test') + return + if not torch.cuda.device_count() > 1: + warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a node with 2+ GPUs to run this test') + return + + os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0]) + + hparams = get_hparams() + model = LightningTestModel(hparams) + + trainer_options = dict( + progress_bar=True, + max_nb_epochs=1, + gpus=[0, 1], + distributed_backend='ddp', + use_amp=True + ) + + run_gpu_model_test(trainer_options, model, hparams) + def test_cpu_slurm_save_load(): """ @@ -280,7 +307,7 @@ def test_amp_gpu_ddp_slurm_managed(): if trainer.use_ddp: # on hpc this would work fine... but need to hack it for the purpose of the test trainer.model = pretrained_model - trainer.optimizers = pretrained_model.configure_optimizers() + trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers() # test HPC loading / saving trainer.hpc_save(save_dir, exp) @@ -477,33 +504,6 @@ def test_multi_gpu_model_ddp(): run_gpu_model_test(trainer_options, model, hparams) -def test_amp_gpu_ddp(): - """ - Make sure DDP + AMP work - :return: - """ - if not torch.cuda.is_available(): - warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a GPU node to run this test') - return - if not torch.cuda.device_count() > 1: - warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a node with 2+ GPUs to run this test') - return - - os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0]) - - hparams = get_hparams() - model = LightningTestModel(hparams) - - trainer_options = dict( - progress_bar=True, - max_nb_epochs=1, - gpus=[0, 1], - distributed_backend='ddp', - use_amp=True - ) - - run_gpu_model_test(trainer_options, model, hparams) - def test_ddp_sampler_error(): """ @@ -574,7 +574,7 @@ def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True): if trainer.use_ddp: # on hpc this would work fine... but need to hack it for the purpose of the test trainer.model = pretrained_model - trainer.optimizers = pretrained_model.configure_optimizers() + trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers() # test HPC loading / saving trainer.hpc_save(save_dir, exp)