Sharded Plugin 2/n: Allow ddp plugin to modify optimizer state saving (#4675)

* Allow ddp plugin to modify optimizer state saving

* Rely on the accelerator for optimizer states

* Ensure we init the accelerator for the saving function

* Better comment for optim state dump

* Revert "Ensure we init the accelerator for the saving function"

This reverts commit af65effa

* Added accelerator check to initialize tuner before saving model checkpoint

* Simplify comment

* Revert "Added accelerator check to initialize tuner before saving model checkpoint"

This reverts commit f9929c0c

* Return single optimizer state to reduce duplication

* Fixed docstring

* Fixed typing

* Fixed comment

* Added CHANGELOG.md

Co-authored-by: chaton <thomas@grid.ai>
Sean Naren authored on 2020-11-18 16:38:35 +00:00, committed by GitHub
parent 8283680aa0
commit e7134a9135
4 changed files with 24 additions and 3 deletions

CHANGELOG.md

@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     [#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
 
+- Added ability for DDP plugin to modify optimizer state saving ([#4675](https://github.com/PyTorchLightning/pytorch-lightning/pull/4675))
+
 ### Changed
 
 - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))

pytorch_lightning/accelerators/accelerator.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from enum import Enum
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, List
 
 import torch
 from torch.optim import Optimizer

@@ -202,6 +202,17 @@ class Accelerator(object):
         """
         raise NotImplementedError()
 
+    def optimizer_state(self, optimizer: Optimizer) -> dict:
+        """
+        Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
+        plugins.
+        Return:
+            Optimizer state dict
+        """
+        if self.ddp_plugin:
+            return self.ddp_plugin.optimizer_state(optimizer)
+        return optimizer.state_dict()
+
     def __getstate__(self):
         return {
             'trainer': self.trainer,
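The new method is a thin dispatch layer: when a DDP plugin is attached, its view of the optimizer state wins; otherwise the accelerator falls back to the optimizer's own `state_dict()`. A standalone sketch of that fallback logic, with no Lightning imports (`_StubPlugin` and the `collated_by_plugin` marker are illustrative only, not part of this commit):

```python
import torch


class _StubPlugin:
    """Stands in for a DDP plugin that post-processes optimizer state."""

    def optimizer_state(self, optimizer: torch.optim.Optimizer) -> dict:
        state = optimizer.state_dict()
        state["collated_by_plugin"] = True  # marker used only by this example
        return state


def optimizer_state(ddp_plugin, optimizer: torch.optim.Optimizer) -> dict:
    # Mirrors the fallback above: defer to the plugin when present,
    # otherwise return the plain optimizer state dict.
    if ddp_plugin:
        return ddp_plugin.optimizer_state(optimizer)
    return optimizer.state_dict()


param = torch.nn.Parameter(torch.zeros(2))
opt = torch.optim.SGD([param], lr=0.1)

print("collated_by_plugin" in optimizer_state(None, opt))           # False
print("collated_by_plugin" in optimizer_state(_StubPlugin(), opt))  # True
```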

pytorch_lightning/plugins/ddp_plugin.py

@@ -1,5 +1,7 @@
 from typing import List, Dict, Any
 
+from torch.optim import Optimizer
+
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel

@@ -80,3 +82,6 @@ class DDPPlugin(object):
         Returns: args moved to correct device if needed.
         """
         return args
+
+    def optimizer_state(self, optimizer: Optimizer) -> dict:
+        return optimizer.state_dict()
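The base implementation is deliberately a plain passthrough; the value of the hook comes from subclasses. A hedged sketch of how a sharded-training plugin might override it (the class name, the `consolidate_state_dict` probe, and the import path are assumptions for illustration, not part of this commit):

```python
from torch.optim import Optimizer

from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class ShardedStatePlugin(DDPPlugin):
    """Hypothetical plugin that collates sharded optimizer state before saving."""

    def optimizer_state(self, optimizer: Optimizer) -> dict:
        # Optimizers that shard state across ranks (e.g. fairscale's OSS) expose
        # a consolidation step; run it if available, then save as usual.
        consolidate = getattr(optimizer, "consolidate_state_dict", None)
        if callable(consolidate):
            consolidate()
        return optimizer.state_dict()
```

Such a plugin would typically be handed to the `Trainer` via its `plugins` argument, so every checkpoint save flows through the override.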

pytorch_lightning/trainer/connectors/checkpoint_connector.py

@@ -298,10 +298,12 @@ class CheckpointConnector:
         callback_states = self.trainer.on_save_checkpoint()
         checkpoint['callbacks'] = callback_states
 
         # dump optimizers
         optimizer_states = []
         for i, optimizer in enumerate(self.trainer.optimizers):
-            optimizer_states.append(optimizer.state_dict())
+            # Rely on accelerator to dump optimizer state
+            optimizer_state = self.trainer.accelerator_backend.optimizer_state(optimizer)
+            optimizer_states.append(optimizer_state)
 
         checkpoint['optimizer_states'] = optimizer_states
 
         # dump lr schedulers
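With the default plugin behaviour, each entry appended to `checkpoint['optimizer_states']` is just a regular PyTorch optimizer state dict, one per configured optimizer. A minimal, Lightning-free illustration of that per-optimizer payload:

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Take one step so the momentum buffers show up in the state.
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

state = optimizer.state_dict()   # what the default DDPPlugin.optimizer_state returns
print(sorted(state.keys()))      # ['param_groups', 'state']

# The checkpoint connector collects one such dict per optimizer:
checkpoint = {'optimizer_states': [state]}
```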