[docs] Added description of saving using ddp (#4660)
* Added description of saving using ddp
* Added code block example to explain DDP saving logic
* Fixed underline
* Added verbose explanation
* Apply suggestions from code review
* Added caveat when using custom saving functions
* flake8

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
This commit is contained in: parent 62903717a4, commit ed5bda3eda
@ -14,6 +14,7 @@ Lightning automates saving and loading checkpoints. Checkpoints capture the exac
 
+Checkpointing your training allows you to resume a training process in case it was interrupted, fine-tune a model, or use a pre-trained model for inference without having to retrain the model.
 
 *****************
 Checkpoint saving
 *****************
 
@ -138,6 +139,23 @@ You can manually save checkpoints and restore your model from the checkpointed s
 trainer.save_checkpoint("example.ckpt")
 new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")
 
+Manual saving with accelerators
+===============================
+
+Lightning also handles accelerators where multiple processes are running, such as DDP. For example, when using the DDP accelerator, the training script runs across multiple devices at the same time.
+Lightning automatically ensures that the model is saved only on the main process, while the other processes do not interfere with saving checkpoints. This requires no code changes, as shown below.
+
+.. code-block:: python
+
+    trainer = Trainer(accelerator="ddp")
+    model = MyLightningModule(hparams)
+    trainer.fit(model)
+    # Saves only on the main process
+    trainer.save_checkpoint("example.ckpt")
+
+Not using ``trainer.save_checkpoint`` can lead to unexpected behaviour and a potential deadlock, since other saving functions will result in all devices attempting to save the checkpoint. We therefore highly recommend using the Trainer's save functionality.
+If custom saving functions cannot be avoided, we recommend using :func:`~pytorch_lightning.loggers.base.rank_zero_only` to ensure saving occurs only on the main process.
+
 ******************
 Checkpoint loading
 ******************
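The caveat above can be sketched in plain Python. The ``rank_zero_only`` below is a minimal stand-in for Lightning's decorator (which identifies the main process by its distributed rank, read here from the ``LOCAL_RANK`` environment variable), and ``save_custom`` is a hypothetical custom saving function, not a Lightning API:

```python
import functools
import os


def rank_zero_only(fn):
    """Minimal stand-in for Lightning's decorator: call ``fn`` only on the
    main process, identified here by the LOCAL_RANK environment variable."""
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        if int(os.environ.get("LOCAL_RANK", "0")) == 0:
            return fn(*args, **kwargs)
        return None  # non-zero ranks skip the call entirely
    return wrapped


@rank_zero_only
def save_custom(state, path):
    # Hypothetical custom saving function; with the guard above only
    # rank 0 ever opens the file, so the other ranks cannot collide
    # on the same path.
    with open(path, "w") as f:
        f.write(repr(state))
    return path
```

On rank 0, ``save_custom`` writes the file and returns the path; on any other rank the decorator skips the call and returns ``None``, which is the behaviour the docs text asks custom saving functions to have.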
@ -72,12 +72,12 @@ class ImageNetLightningModel(LightningModule):
     def training_step(self, batch, batch_idx):
         images, target = batch
         output = self(images)
-        loss_train= F.cross_entropy(output, target)
+        loss_train = F.cross_entropy(output, target)
         acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
         self.log('train_loss', loss_train, on_step=True, on_epoch=True, logger=True)
         self.log('train_acc1', acc1, on_step=True, prog_bar=True, on_epoch=True, logger=True)
         self.log('train_acc5', acc5, on_step=True, on_epoch=True, logger=True)
-        return loss_val
+        return loss_train
 
     def validation_step(self, batch, batch_idx):
         images, target = batch
@ -7,7 +7,7 @@ import pytest
 import torch
 
 from pytorch_lightning import LightningModule, Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.metrics.functional.classification import auc
 from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler
 from pytorch_lightning.utilities.exceptions import MisconfigurationException