From e7134a91358c0fda0d6adaf0171512c5c2008b53 Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Wed, 18 Nov 2020 16:38:35 +0000
Subject: [PATCH] Sharded Plugin 2/n: Allow ddp plugin to modify optimizer state saving (#4675)

* Allow ddp plugin to modify optimizer state saving

* Rely on the accelerator for optimizer states

* Ensure we init the accelerator for the saving function

* Better comment for optim state dump

* Revert "Ensure we init the accelerator for the saving function"

This reverts commit af65effa

* Added accelerator check to initialize tuner before saving model checkpoint

* Simplify comment

* Revert "Added accelerator check to initialize tuner before saving model checkpoint"

This reverts commit f9929c0c

* Return single optimizer state to reduce duplication

* Fixed docstring

* Fixed typing

* Fixed comment

* Added CHANGELOG.md

Co-authored-by: chaton
---
 CHANGELOG.md                                           |  3 +++
 pytorch_lightning/accelerators/accelerator.py          | 13 ++++++++++++-
 pytorch_lightning/plugins/ddp_plugin.py                |  5 +++++
 .../trainer/connectors/checkpoint_connector.py         |  6 ++++--
 4 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 72aa5cfcbd..797569f711 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     [#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
 
 
+- Added ability for DDP plugin to modify optimizer state saving ([#4675](https://github.com/PyTorchLightning/pytorch-lightning/pull/4675))
+
+
 ### Changed
 
 - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index e765c2ab62..f2751cc2da 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from enum import Enum
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, List
 
 import torch
 from torch.optim import Optimizer
@@ -202,6 +202,17 @@ class Accelerator(object):
         """
         raise NotImplementedError()
 
+    def optimizer_state(self, optimizer: Optimizer) -> dict:
+        """
+        Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
+        plugins.
+        Return:
+            Optimizer state dict
+        """
+        if self.ddp_plugin:
+            return self.ddp_plugin.optimizer_state(optimizer)
+        return optimizer.state_dict()
+
     def __getstate__(self):
         return {
             'trainer': self.trainer,
diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index 4d73d4bdde..8afb80cd67 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -1,5 +1,7 @@
 from typing import List, Dict, Any
 
+from torch.optim import Optimizer
+
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 
@@ -80,3 +82,6 @@ class DDPPlugin(object):
         Returns: args moved to correct device if needed.
         """
         return args
+
+    def optimizer_state(self, optimizer: Optimizer) -> dict:
+        return optimizer.state_dict()
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 3b44ce96c0..d98a3137be 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -298,10 +298,12 @@ class CheckpointConnector:
         callback_states = self.trainer.on_save_checkpoint()
         checkpoint['callbacks'] = callback_states
 
-        # dump optimizers
         optimizer_states = []
        for i, optimizer in enumerate(self.trainer.optimizers):
-            optimizer_states.append(optimizer.state_dict())
+            # Rely on accelerator to dump optimizer state
+            optimizer_state = self.trainer.accelerator_backend.optimizer_state(optimizer)
+            optimizer_states.append(optimizer_state)
+
         checkpoint['optimizer_states'] = optimizer_states
 
         # dump lr schedulers