Sharded Plugin 2/n: Allow ddp plugin to modify optimizer state saving (#4675)
* Allow ddp plugin to modify optimizer state saving
* Rely on the accelerator for optimizer states
* Ensure we init the accelerator for the saving function
* Better comment for optim state dump
* Revert "Ensure we init the accelerator for the saving function"
  This reverts commit af65effa
* Added accelerator check to initialize tuner before saving model checkpoint
* Simplify comment
* Revert "Added accelerator check to initialize tuner before saving model checkpoint"
  This reverts commit f9929c0c
* Return single optimizer state to reduce duplication
* Fixed docstring
* Fixed typing
* Fixed comment
* Added CHANGELOG.md

Co-authored-by: chaton <thomas@grid.ai>
commit e7134a9135 (parent 8283680aa0)
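The diffs below route optimizer-state dumping through the accelerator, which in turn defers to the attached DDP plugin when one is present. As a minimal, self-contained sketch of that call chain (the Toy* classes are simplified stand-ins for illustration, not the real Lightning classes):

import torch
from torch.optim import SGD, Optimizer


class ToyDDPPlugin:
    """Stand-in for DDPPlugin: the default behaviour is just the raw state dict."""

    def optimizer_state(self, optimizer: Optimizer) -> dict:
        return optimizer.state_dict()


class ToyAccelerator:
    """Stand-in for Accelerator: defer to the plugin when one is attached."""

    def __init__(self, ddp_plugin=None):
        self.ddp_plugin = ddp_plugin

    def optimizer_state(self, optimizer: Optimizer) -> dict:
        if self.ddp_plugin:
            return self.ddp_plugin.optimizer_state(optimizer)
        return optimizer.state_dict()


# The checkpoint connector's loop then reduces to: one dict per optimizer,
# always obtained through the accelerator.
model = torch.nn.Linear(2, 2)
opt = SGD(model.parameters(), lr=0.1)
accelerator = ToyAccelerator(ddp_plugin=ToyDDPPlugin())
optimizer_states = [accelerator.optimizer_state(o) for o in [opt]]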
@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     [#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
 
 
+- Added ability for DDP plugin to modify optimizer state saving ([#4675](https://github.com/PyTorchLightning/pytorch-lightning/pull/4675))
+
+
 ### Changed
 
 - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from enum import Enum
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, List
 
 import torch
 from torch.optim import Optimizer

@@ -202,6 +202,17 @@ class Accelerator(object):
         """
         raise NotImplementedError()
 
+    def optimizer_state(self, optimizer: Optimizer) -> dict:
+        """
+        Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
+        plugins.
+        Return:
+            Optimizer state dict
+        """
+        if self.ddp_plugin:
+            return self.ddp_plugin.optimizer_state(optimizer)
+        return optimizer.state_dict()
+
     def __getstate__(self):
         return {
             'trainer': self.trainer,
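The docstring above leaves the "syncing/collating optimizer state from processes" part to custom plugins. As a rough sketch only, not part of this commit: a plugin that gathers per-rank state before checkpointing might look like the following, assuming torch.distributed is initialized and torch.distributed.all_gather_object is available (PyTorch >= 1.8). CollatingDDPPlugin and the merge strategy are hypothetical, and the DDPPlugin import path mirrors the module touched by this diff.

import torch.distributed as dist
from torch.optim import Optimizer

from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class CollatingDDPPlugin(DDPPlugin):
    """Hypothetical plugin: gather per-rank optimizer state so the checkpoint
    holds the full state rather than a single rank's shard."""

    def optimizer_state(self, optimizer: Optimizer) -> dict:
        local_state = optimizer.state_dict()
        gathered = [None] * dist.get_world_size()
        # Collect every rank's state dict (assumes an initialized process group).
        dist.all_gather_object(gathered, local_state)
        # Merge the per-rank 'state' entries; keep rank 0's param_groups as a
        # representative copy.
        merged = {'state': {}, 'param_groups': gathered[0]['param_groups']}
        for rank_state in gathered:
            merged['state'].update(rank_state['state'])
        return merged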
@@ -1,5 +1,7 @@
 from typing import List, Dict, Any
 
+from torch.optim import Optimizer
+
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 

@@ -80,3 +82,6 @@ class DDPPlugin(object):
         Returns: args moved to correct device if needed.
         """
         return args
+
+    def optimizer_state(self, optimizer: Optimizer) -> dict:
+        return optimizer.state_dict()
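With this hook on DDPPlugin, a user-defined plugin can override optimizer_state and hand a modified dict back to the accelerator. A small usage sketch follows; LoggingDDPPlugin is hypothetical, and the exact Trainer flags (accelerator, gpus, plugins) are assumptions about how plugins were wired in at the time rather than something this diff shows.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class LoggingDDPPlugin(DDPPlugin):
    # Hypothetical override: inspect the state before it is checkpointed.
    def optimizer_state(self, optimizer):
        state = optimizer.state_dict()
        print(f"checkpointing {len(state['param_groups'])} param group(s)")
        return state


# Assumption: plugin instances passed here end up on the accelerator as
# `ddp_plugin`, so Accelerator.optimizer_state picks this override up on save.
trainer = Trainer(accelerator='ddp', gpus=2, plugins=[LoggingDDPPlugin()])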
@@ -298,10 +298,12 @@ class CheckpointConnector:
             callback_states = self.trainer.on_save_checkpoint()
             checkpoint['callbacks'] = callback_states
 
             # dump optimizers
             optimizer_states = []
             for i, optimizer in enumerate(self.trainer.optimizers):
-                optimizer_states.append(optimizer.state_dict())
+                # Rely on accelerator to dump optimizer state
+                optimizer_state = self.trainer.accelerator_backend.optimizer_state(optimizer)
+                optimizer_states.append(optimizer_state)
 
             checkpoint['optimizer_states'] = optimizer_states
 
             # dump lr schedulers
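The connector still stores one plain state dict per optimizer under checkpoint['optimizer_states']; only the source of that dict changed, so restoring remains the standard load_state_dict call. A minimal round-trip sketch (the toy model and file name are arbitrary):

import torch
from torch.optim import SGD

# Build a toy optimizer, mimic the structure dumped above, and restore it.
model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)

checkpoint = {'optimizer_states': [optimizer.state_dict()]}
torch.save(checkpoint, 'optim_only.ckpt')

restored = torch.load('optim_only.ckpt')
optimizer.load_state_dict(restored['optimizer_states'][0])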