From ed5bda3eda0664b9d5b525211617bd87a085bad7 Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Fri, 4 Dec 2020 16:59:38 +0000
Subject: [PATCH] [docs] Added description of saving using ddp (#4660)

* Added description of saving using ddp

* Added code block example to explain DDP saving logic

* Fixed underline

* Added verbose explanation

* Apply suggestions from code review

* Added caveat when using custom saving functions

* flake8

Co-authored-by: Rohit Gupta
Co-authored-by: Jirka Borovec
---
 docs/source/weights_loading.rst          | 18 ++++++++++++++++++
 pl_examples/domain_templates/imagenet.py |  4 ++--
 tests/test_deprecated.py                 |  2 +-
 3 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/docs/source/weights_loading.rst b/docs/source/weights_loading.rst
index 627ede828d..5dc80b51e5 100644
--- a/docs/source/weights_loading.rst
+++ b/docs/source/weights_loading.rst
@@ -14,6 +14,7 @@ Lightning automates saving and loading checkpoints. Checkpoints capture the exac
 
 Checkpointing your training allows you to resume a training process in case it was interrupted, fine-tune a model or use a pre-trained model for inference without having to retrain the model.
 
+
 *****************
 Checkpoint saving
 *****************
@@ -138,6 +139,23 @@ You can manually save checkpoints and restore your model from the checkpointed s
     trainer.save_checkpoint("example.ckpt")
     new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")
 
+Manual saving with accelerators
+===============================
+
+Lightning also handles accelerators where multiple processes are running, such as DDP. For example, when using the DDP accelerator, the training script runs across multiple devices at the same time.
+Lightning automatically ensures that the model is saved only on the main process, whilst the other processes do not interfere with checkpoint saving. No code changes are required, as shown below.
+
+.. code-block:: python
+
+    trainer = Trainer(accelerator="ddp")
+    model = MyLightningModule(hparams)
+    trainer.fit(model)
+    # Saves only on the main process
+    trainer.save_checkpoint("example.ckpt")
+
+Not using ``trainer.save_checkpoint`` can lead to unexpected behaviour and potential deadlock: any other saving function will result in all processes attempting to save the checkpoint simultaneously. As a result, we highly recommend using the trainer's save functionality.
+If a custom saving function cannot be avoided, we recommend decorating it with :func:`~pytorch_lightning.loggers.base.rank_zero_only` to ensure saving occurs only on the main process.
+
 ******************
 Checkpoint loading
 ******************
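For reference, here is a minimal sketch of the ``rank_zero_only`` pattern recommended above. It assumes the decorator as exported from ``pytorch_lightning.utilities``; ``save_custom_checkpoint`` is a hypothetical helper, not part of the library's API.

.. code-block:: python

    import torch
    from pytorch_lightning.utilities import rank_zero_only

    @rank_zero_only
    def save_custom_checkpoint(model, path):
        # Executes on the main (rank zero) process only; all other ranks
        # return immediately, so the processes never race to write one file.
        torch.save(model.state_dict(), path)

In a DDP run, every process can safely call ``save_custom_checkpoint(model, "example.ckpt")``: only rank zero writes the checkpoint.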
diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py
index e4ac3c490d..ba4b292fb4 100644
--- a/pl_examples/domain_templates/imagenet.py
+++ b/pl_examples/domain_templates/imagenet.py
@@ -72,12 +72,12 @@ class ImageNetLightningModel(LightningModule):
     def training_step(self, batch, batch_idx):
         images, target = batch
         output = self(images)
-        loss_train= F.cross_entropy(output, target)
+        loss_train = F.cross_entropy(output, target)
         acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
         self.log('train_loss', loss_train, on_step=True, on_epoch=True, logger=True)
         self.log('train_acc1', acc1, on_step=True, prog_bar=True, on_epoch=True, logger=True)
         self.log('train_acc5', acc5, on_step=True, on_epoch=True, logger=True)
-        return loss_val
+        return loss_train
 
     def validation_step(self, batch, batch_idx):
         images, target = batch
diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py
index d103c13b7e..2c89e19880 100644
--- a/tests/test_deprecated.py
+++ b/tests/test_deprecated.py
@@ -7,7 +7,7 @@ import pytest
 import torch
 
 from pytorch_lightning import LightningModule, Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.metrics.functional.classification import auc
 from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
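As context for the ``imagenet.py`` hunk above: ``training_step`` must return the loss it just computed, because Lightning runs the backward pass on the returned value, and the old ``return loss_val`` referenced a name that was never defined. A minimal sketch of the contract, using a hypothetical ``LitClassifier``:

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from pytorch_lightning import LightningModule

    class LitClassifier(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.layer(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            images, target = batch
            output = self(images)
            loss_train = F.cross_entropy(output, target)
            self.log('train_loss', loss_train, on_step=True, on_epoch=True, logger=True)
            # Return the tensor computed above: Lightning backpropagates this
            # value. Returning an undefined name raises a NameError at runtime.
            return loss_train

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)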