[docs] Added description of saving using ddp (#4660)

* Added description of saving using ddp

* Added code block example to explain DDP saving logic

* Fixed underline

* Added verbose explanation

* Apply suggestions from code review

* Added caveat when using custom saving functions

* flake8

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
Sean Naren 2020-12-04 16:59:38 +00:00 committed by GitHub
parent 62903717a4
commit ed5bda3eda
3 changed files with 21 additions and 3 deletions

@@ -14,6 +14,7 @@ Lightning automates saving and loading checkpoints. Checkpoints capture the exact
Checkpointing your training allows you to resume a training process if it was interrupted, fine-tune a model, or use a pre-trained model for inference without having to retrain it.
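For example, a minimal sketch (not part of this commit), assuming a ``MyModel`` LightningModule, an ``example.ckpt`` file written by an earlier run, and that this version of the ``Trainer`` accepts the ``resume_from_checkpoint`` argument:

.. code-block:: python

    # Resume the interrupted training run from the saved checkpoint
    model = MyModel()  # constructor arguments omitted for brevity
    trainer = Trainer(resume_from_checkpoint="example.ckpt")
    trainer.fit(model)

    # Or load the pre-trained weights for inference without retraining
    pretrained = MyModel.load_from_checkpoint("example.ckpt")
    pretrained.eval()
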
*****************
Checkpoint saving
*****************
@@ -138,6 +139,23 @@ You can manually save checkpoints and restore your model from the checkpointed state
trainer.save_checkpoint("example.ckpt")
new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")
Manual saving with accelerators
===============================
Lightning also handles accelerators where multiple processes are running, such as DDP. For example, when using the DDP accelerator the training script runs across multiple devices at the same time.
Lightning automatically ensures that the model is saved only on the main process, while the other processes do not interfere with checkpoint saving. This requires no code changes, as shown below.
.. code-block:: python

    trainer = Trainer(accelerator="ddp")
    model = MyLightningModule(hparams)
    trainer.fit(model)

    # Saves only on the main process
    trainer.save_checkpoint("example.ckpt")
Not using ``trainer.save_checkpoint`` can lead to unexpected behaviour and potential deadlocks: other saving functions will cause all processes to attempt to save the checkpoint at once. As a result, we highly recommend using the trainer's save functionality.
If using custom saving functions cannot be avoided, we recommend using :func:`~pytorch_lightning.loggers.base.rank_zero_only` to ensure saving occurs only on the main process.
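As a rough sketch of that pattern (not part of this commit; it assumes ``rank_zero_only`` can be imported from ``pytorch_lightning.utilities``, and the ``save_raw_weights`` helper is purely illustrative):

.. code-block:: python

    import torch
    from pytorch_lightning.utilities import rank_zero_only

    @rank_zero_only
    def save_raw_weights(model, path):
        # The decorator makes this a no-op on every process except rank zero,
        # so only the main process writes the file.
        torch.save(model.state_dict(), path)

    # Safe to call from every process; only rank zero actually saves.
    save_raw_weights(model, "raw_weights.pt")
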
******************
Checkpoint loading
******************

@@ -72,12 +72,12 @@ class ImageNetLightningModel(LightningModule):
    def training_step(self, batch, batch_idx):
        images, target = batch
        output = self(images)
-        loss_train= F.cross_entropy(output, target)
+        loss_train = F.cross_entropy(output, target)
        acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
        self.log('train_loss', loss_train, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc1', acc1, on_step=True, prog_bar=True, on_epoch=True, logger=True)
        self.log('train_acc5', acc5, on_step=True, on_epoch=True, logger=True)
-        return loss_val
+        return loss_train

    def validation_step(self, batch, batch_idx):
        images, target = batch

@@ -7,7 +7,7 @@ import pytest
import torch
from pytorch_lightning import LightningModule, Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.metrics.functional.classification import auc
from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler
from pytorch_lightning.utilities.exceptions import MisconfigurationException