Sharded Plugin 2/n: Allow ddp plugin to modify optimizer state saving (#4675)

* Allow ddp plugin to modify optimizer state saving
* Rely on the accelerator for optimizer states
* Ensure we init the accelerator for the saving function
* Better comment for optim state dump
* Revert "Ensure we init the accelerator for the saving function"
  This reverts commit af65effa.
* Added accelerator check to initialize tuner before saving model checkpoint
* Simplify comment
* Revert "Added accelerator check to initialize tuner before saving model checkpoint"
  This reverts commit f9929c0c.
* Return single optimizer state to reduce duplication
* Fixed docstring
* Fixed typing
* Fixed comment
* Added CHANGELOG.md

Co-authored-by: chaton <thomas@grid.ai>
parent 8283680aa0
commit e7134a9135
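
The point of the new hook (see the commit message above) is to let a DDP plugin collate optimizer state that is sharded across processes before it is written into a checkpoint. The sketch below is illustrative only and not part of this commit: the plugin class is hypothetical, `consolidate_state_dict()` stands in for whatever collation step a sharded optimizer actually provides, and the import path assumes the plugin module layout of this release.

    from torch.optim import Optimizer

    from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


    class MyShardedDDPPlugin(DDPPlugin):
        """Illustrative subclass: collate sharded optimizer state before saving."""

        def optimizer_state(self, optimizer: Optimizer) -> dict:
            # Hypothetical sharded optimizers expose a step that gathers state
            # shards from every rank onto the process performing the save.
            if hasattr(optimizer, "consolidate_state_dict"):
                optimizer.consolidate_state_dict()
            return optimizer.state_dict()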
@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     [#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
 
 
+- Added ability for DDP plugin to modify optimizer state saving ([#4675](https://github.com/PyTorchLightning/pytorch-lightning/pull/4675))
+
+
 ### Changed
 
 - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from enum import Enum
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, List
 
 import torch
 from torch.optim import Optimizer
@@ -202,6 +202,17 @@ class Accelerator(object):
         """
         raise NotImplementedError()
 
+    def optimizer_state(self, optimizer: Optimizer) -> dict:
+        """
+        Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
+        plugins.
+        Return:
+            Optimizer state dict
+        """
+        if self.ddp_plugin:
+            return self.ddp_plugin.optimizer_state(optimizer)
+        return optimizer.state_dict()
+
     def __getstate__(self):
         return {
             'trainer': self.trainer,
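
Read on its own, the accelerator-level hook is just a guarded delegation: when a DDP plugin is attached, the plugin decides what the optimizer state looks like; otherwise the plain `optimizer.state_dict()` is returned, so runs without a plugin are unchanged. A standalone, runnable mirror of that fallback logic (the free function below is an illustration, not library code):

    import torch


    def optimizer_state(optimizer, ddp_plugin=None):
        # Mirrors Accelerator.optimizer_state: defer to the plugin when one is
        # attached, otherwise fall back to the optimizer's own state_dict().
        if ddp_plugin:
            return ddp_plugin.optimizer_state(optimizer)
        return optimizer.state_dict()


    params = [torch.nn.Parameter(torch.zeros(2))]
    sgd = torch.optim.SGD(params, lr=0.1)
    assert optimizer_state(sgd) == sgd.state_dict()  # no plugin -> plain state dict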
@@ -1,5 +1,7 @@
 from typing import List, Dict, Any
 
+from torch.optim import Optimizer
+
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 
@@ -80,3 +82,6 @@ class DDPPlugin(object):
         Returns: args moved to correct device if needed.
         """
         return args
+
+    def optimizer_state(self, optimizer: Optimizer) -> dict:
+        return optimizer.state_dict()
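
Because the base `DDPPlugin.optimizer_state` is a pure passthrough, checkpoints from existing DDP runs are unaffected by this change; only plugins that override the method alter what gets saved. A quick check of that default (assuming `DDPPlugin` can be constructed with no arguments, as in this version of the code):

    import torch

    from pytorch_lightning.plugins.ddp_plugin import DDPPlugin

    adam = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
    assert DDPPlugin().optimizer_state(adam) == adam.state_dict()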
@@ -298,10 +298,12 @@ class CheckpointConnector:
         callback_states = self.trainer.on_save_checkpoint()
         checkpoint['callbacks'] = callback_states
 
-        # dump optimizers
         optimizer_states = []
         for i, optimizer in enumerate(self.trainer.optimizers):
-            optimizer_states.append(optimizer.state_dict())
+            # Rely on accelerator to dump optimizer state
+            optimizer_state = self.trainer.accelerator_backend.optimizer_state(optimizer)
+            optimizer_states.append(optimizer_state)
 
         checkpoint['optimizer_states'] = optimizer_states
 
         # dump lr schedulers
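
The connector change above only alters where each optimizer's state dict comes from: `checkpoint['optimizer_states']` is still one dict per optimizer, in the order of `trainer.optimizers`, so the existing restore path keeps working unchanged. A self-contained sketch of that round trip (plain torch, no Lightning objects):

    import torch

    params = [torch.nn.Parameter(torch.zeros(3))]
    optimizers = [torch.optim.SGD(params, lr=0.1)]

    # Saving: one state dict per optimizer, matching the loop in the hunk above.
    checkpoint = {"optimizer_states": [opt.state_dict() for opt in optimizers]}

    # Restoring: states are fed back in the same order via load_state_dict().
    for opt, opt_state in zip(optimizers, checkpoint["optimizer_states"]):
        opt.load_state_dict(opt_state)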