Sharded Plugin 2/n: Allow ddp plugin to modify optimizer state saving (#4675)

* Allow ddp plugin to modify optimizer state saving

* Rely on the accelerator for optimizer states

* Ensure we init the accelerator for the saving function

* Better comment for optim state dump

* Revert "Ensure we init the accelerator for the saving function"

This reverts commit af65effa

* Added accelerator check to initialize tuner before saving model checkpoint

* Simplify comment

* Revert "Added accelerator check to initialize tuner before saving model checkpoint"

This reverts commit f9929c0c

* Return single optimizer state to reduce duplication

* Fixed docstring

* Fixed typing

* Fixed comment

* Added CHANGELOG.md

Co-authored-by: chaton <thomas@grid.ai>
Sean Naren authored on 2020-11-18 16:38:35 +00:00, committed by GitHub
parent 8283680aa0
commit e7134a9135
4 changed files with 24 additions and 3 deletions
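
Taken together, the four diffs below introduce a small dispatch chain: the checkpoint connector asks the accelerator for each optimizer's state, and the accelerator defers to the DDP plugin when one is attached. A minimal standalone sketch of that flow (the simplified classes and the `dump_optimizer_states` helper are illustrative stand-ins, not the actual Lightning code):

```python
from typing import List, Optional

from torch.optim import Optimizer


class DDPPlugin:
    """Default hook: subclasses may sync/collate optimizer state across processes."""

    def optimizer_state(self, optimizer: Optimizer) -> dict:
        return optimizer.state_dict()


class Accelerator:
    def __init__(self, ddp_plugin: Optional[DDPPlugin] = None):
        self.ddp_plugin = ddp_plugin

    def optimizer_state(self, optimizer: Optimizer) -> dict:
        # Defer to the plugin when one is attached, otherwise fall back to the raw state dict.
        if self.ddp_plugin:
            return self.ddp_plugin.optimizer_state(optimizer)
        return optimizer.state_dict()


def dump_optimizer_states(accelerator: Accelerator, optimizers: List[Optimizer]) -> List[dict]:
    # Mirrors what the checkpoint connector now does: one state dict per optimizer.
    return [accelerator.optimizer_state(opt) for opt in optimizers]
```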

CHANGELOG.md

@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   [#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
+- Added ability for DDP plugin to modify optimizer state saving ([#4675](https://github.com/PyTorchLightning/pytorch-lightning/pull/4675))
 ### Changed
 - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))

pytorch_lightning/accelerators/accelerator.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from enum import Enum
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, List
 import torch
 from torch.optim import Optimizer
@@ -202,6 +202,17 @@ class Accelerator(object):
         """
         raise NotImplementedError()
+    def optimizer_state(self, optimizer: Optimizer) -> dict:
+        """
+        Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
+        plugins.
+        Return:
+            Optimizer state dict
+        """
+        if self.ddp_plugin:
+            return self.ddp_plugin.optimizer_state(optimizer)
+        return optimizer.state_dict()
     def __getstate__(self):
         return {
             'trainer': self.trainer,
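
For reference, the default path is just `optimizer.state_dict()`. A quick plain-PyTorch illustration of what that dictionary contains (independent of Lightning; the toy parameter is only for demonstration):

```python
import torch
from torch.optim import SGD

# A toy parameter and one optimizer step, just to populate some optimizer state.
param = torch.nn.Parameter(torch.zeros(3))
optimizer = SGD([param], lr=0.1, momentum=0.9)

param.sum().backward()
optimizer.step()

state = optimizer.state_dict()
print(list(state.keys()))   # ['state', 'param_groups']
print(state['state'][0])    # per-parameter buffers, e.g. {'momentum_buffer': tensor([...])}
```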

pytorch_lightning/plugins/ddp_plugin.py

@@ -1,5 +1,7 @@
 from typing import List, Dict, Any
+from torch.optim import Optimizer
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
@@ -80,3 +82,6 @@ class DDPPlugin(object):
         Returns: args moved to correct device if needed.
         """
         return args
+    def optimizer_state(self, optimizer: Optimizer) -> dict:
+        return optimizer.state_dict()
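
This default is deliberately a pass-through; it is the extension point the rest of the sharded-plugin series builds on. Purely as an illustration (the actual sharded plugin arrives in a later PR), a subclass could consolidate a sharded optimizer's state onto rank zero before it is checkpointed. The `ShardedDDPPlugin` name and the assumption of a `fairscale.optim.OSS`-style optimizer exposing `consolidate_state_dict()` are mine, not part of this diff:

```python
from torch.optim import Optimizer

from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class ShardedDDPPlugin(DDPPlugin):
    """Hypothetical override: gather sharded optimizer state before saving."""

    def optimizer_state(self, optimizer: Optimizer) -> dict:
        # Assumes an OSS-style optimizer whose state is sharded across ranks:
        # pull the full state onto rank 0 so the checkpoint is complete.
        if hasattr(optimizer, "consolidate_state_dict"):
            optimizer.consolidate_state_dict(recipient_rank=0)
        return optimizer.state_dict()
```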

pytorch_lightning/trainer/connectors/checkpoint_connector.py

@@ -298,10 +298,12 @@ class CheckpointConnector:
         callback_states = self.trainer.on_save_checkpoint()
         checkpoint['callbacks'] = callback_states
         # dump optimizers
         optimizer_states = []
         for i, optimizer in enumerate(self.trainer.optimizers):
-            optimizer_states.append(optimizer.state_dict())
+            # Rely on accelerator to dump optimizer state
+            optimizer_state = self.trainer.accelerator_backend.optimizer_state(optimizer)
+            optimizer_states.append(optimizer_state)
         checkpoint['optimizer_states'] = optimizer_states
         # dump lr schedulers
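
Each entry in `checkpoint['optimizer_states']` is simply the dict produced by the hook above, so restoring it later is symmetric. A minimal sketch, with the checkpoint path and the rebuilt model/optimizer as placeholders:

```python
import torch
from torch.optim import SGD

# Placeholders: rebuild the model and optimizers exactly as they were configured for training.
model = torch.nn.Linear(4, 2)
optimizers = [SGD(model.parameters(), lr=0.1)]

# 'example.ckpt' stands in for a checkpoint written by the Trainer.
checkpoint = torch.load("example.ckpt", map_location="cpu")

# One state dict per optimizer, in the same order as trainer.optimizers.
for optimizer, opt_state in zip(optimizers, checkpoint["optimizer_states"]):
    optimizer.load_state_dict(opt_state)
```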