diff --git a/CHANGELOG.md b/CHANGELOG.md
index 72aa5cfcbd..797569f711 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     [#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
 
+- Added ability for DDP plugin to modify optimizer state saving ([#4675](https://github.com/PyTorchLightning/pytorch-lightning/pull/4675))
+
+
 ### Changed
 
 - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index e765c2ab62..f2751cc2da 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from enum import Enum
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, List
 
 import torch
 from torch.optim import Optimizer
@@ -202,6 +202,17 @@ class Accelerator(object):
         """
         raise NotImplementedError()
 
+    def optimizer_state(self, optimizer: Optimizer) -> dict:
+        """
+        Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
+        plugins.
+        Return:
+            Optimizer state dict
+        """
+        if self.ddp_plugin:
+            return self.ddp_plugin.optimizer_state(optimizer)
+        return optimizer.state_dict()
+
     def __getstate__(self):
         return {
             'trainer': self.trainer,
diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index 4d73d4bdde..8afb80cd67 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -1,5 +1,7 @@
 from typing import List, Dict, Any
 
+from torch.optim import Optimizer
+
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 
@@ -80,3 +82,6 @@ class DDPPlugin(object):
         Returns: args moved to correct device if needed.
         """
         return args
+
+    def optimizer_state(self, optimizer: Optimizer) -> dict:
+        return optimizer.state_dict()
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 3b44ce96c0..d98a3137be 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -298,10 +298,12 @@ class CheckpointConnector:
         callback_states = self.trainer.on_save_checkpoint()
         checkpoint['callbacks'] = callback_states
 
-        # dump optimizers
         optimizer_states = []
         for i, optimizer in enumerate(self.trainer.optimizers):
-            optimizer_states.append(optimizer.state_dict())
+            # Rely on accelerator to dump optimizer state
+            optimizer_state = self.trainer.accelerator_backend.optimizer_state(optimizer)
+            optimizer_states.append(optimizer_state)
+
         checkpoint['optimizer_states'] = optimizer_states
 
         # dump lr schedulers
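The diff adds an `optimizer_state` hook that the `Accelerator` delegates to the DDP plugin when `dump_checkpoint` collects optimizer states. Below is a minimal sketch of how a downstream plugin could use this extension point, assuming the 1.0.x plugin API shown above; the class name `CPUOffloadingDDPPlugin` and the CPU-offloading behaviour are illustrative only and are not part of this PR.

```python
import torch
from torch.optim import Optimizer

from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class CPUOffloadingDDPPlugin(DDPPlugin):
    """Illustrative plugin: copies optimizer state tensors to CPU before checkpointing."""

    def optimizer_state(self, optimizer: Optimizer) -> dict:
        # Start from the default per-process state dict.
        state_dict = optimizer.state_dict()
        # Copy tensor-valued entries (e.g. Adam's exp_avg buffers) to CPU so the
        # saved checkpoint does not depend on this process's GPU placement.
        state_dict["state"] = {
            param_id: {
                key: value.cpu() if torch.is_tensor(value) else value
                for key, value in param_state.items()
            }
            for param_id, param_state in state_dict["state"].items()
        }
        return state_dict
```

Under those assumptions, the plugin would be wired in through the existing `plugins` argument, roughly `Trainer(gpus=2, accelerator='ddp', plugins=[CPUOffloadingDDPPlugin()])`, and the checkpoint connector would then store whatever the override returns instead of the raw `optimizer.state_dict()`.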