Refactor `Strategy._move_optimizer_states` as utility functions (#11758)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: thomas chaton <thomas@grid.ai>
This commit is contained in:
parent
d61371922b
commit
cf64f34434
|
@ -117,6 +117,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added `Accelerator.is_available` to check device availability ([#11797](https://github.com/PyTorchLightning/pytorch-lightning/pull/11797))
|
||||
|
||||
|
||||
- Added utility functions for moving optimizers to devices ([#11758](https://github.com/PyTorchLightning/pytorch-lightning/pull/11758))
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
- Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332))
|
||||
|
|
|
@ -288,6 +288,7 @@ Utilities API
|
|||
finite_checks
|
||||
memory
|
||||
model_summary
|
||||
optimizer
|
||||
parsing
|
||||
rank_zero
|
||||
seed
|
||||
|
|
|
@ -39,6 +39,7 @@ from pytorch_lightning.utilities.enums import AMPType, PrecisionType
|
|||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.optimizer import optimizers_to_device
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_info
|
||||
from pytorch_lightning.utilities.seed import reset_seed
|
||||
from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT
|
||||
|
@ -349,7 +350,7 @@ class DeepSpeedStrategy(DDPStrategy):
|
|||
self.accelerator.setup(trainer)
|
||||
self.setup_optimizers(trainer)
|
||||
self.setup_precision_plugin()
|
||||
self._move_optimizer_state()
|
||||
optimizers_to_device(self.optimizers, self.root_device)
|
||||
self.init_deepspeed()
|
||||
self.barrier()
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ from pytorch_lightning.strategies.ddp import DDPStrategy
|
|||
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
|
||||
from pytorch_lightning.utilities.enums import PrecisionType
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.optimizer import optimizers_to_device
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
|
||||
if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
|
||||
|
@ -136,7 +137,7 @@ class DDPFullyShardedStrategy(DDPStrategy):
|
|||
self.accelerator.setup(trainer)
|
||||
self.setup_optimizers(trainer)
|
||||
self.setup_precision_plugin()
|
||||
self._move_optimizer_state()
|
||||
optimizers_to_device(self.optimizers, self.root_device)
|
||||
|
||||
if self.sync_batchnorm:
|
||||
self.model = self.configure_sync_batchnorm(self.model)
|
||||
|
|
|
@ -30,9 +30,10 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin
|
|||
from pytorch_lightning.strategies.launchers.base import _Launcher
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
|
||||
from pytorch_lightning.utilities.apply_func import move_data_to_device
|
||||
from pytorch_lightning.utilities.distributed import ReduceOp
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device
|
||||
from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT
|
||||
|
||||
TBroadcast = TypeVar("TBroadcast")
|
||||
|
@ -138,7 +139,7 @@ class Strategy(ABC):
|
|||
self.accelerator.setup(trainer)
|
||||
self.setup_optimizers(trainer)
|
||||
self.setup_precision_plugin()
|
||||
self._move_optimizer_state()
|
||||
optimizers_to_device(self.optimizers, self.root_device)
|
||||
|
||||
def setup_precision_plugin(self) -> None:
|
||||
"""Attaches the precision plugin to the accelerator."""
|
||||
|
@ -149,14 +150,6 @@ class Strategy(ABC):
|
|||
self.optimizers = optimizers
|
||||
self.lr_scheduler_configs = lr_scheduler_configs
|
||||
|
||||
def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
|
||||
"""Moves the state of the optimizers to the appropriate device if needed."""
|
||||
for opt in self.optimizers:
|
||||
for p, v in opt.state.items():
|
||||
# `self.root_device` would raise error if called outside the spawn process
|
||||
# while training on 8 and more cores.
|
||||
opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device or self.root_device)
|
||||
|
||||
def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
|
||||
"""Returns state of an optimizer.
|
||||
|
||||
|
@ -330,6 +323,7 @@ class Strategy(ABC):
|
|||
optimizer_states = checkpoint["optimizer_states"]
|
||||
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
|
||||
optimizer.load_state_dict(opt_state)
|
||||
optimizer_to_device(optimizer, self.root_device)
|
||||
|
||||
def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
|
||||
"""The actual training step.
|
||||
|
@ -445,7 +439,7 @@ class Strategy(ABC):
|
|||
|
||||
It is the right place to release memory and free other resources.
|
||||
"""
|
||||
self._move_optimizer_state(torch.device("cpu"))
|
||||
optimizers_to_device(self.optimizers, torch.device("cpu"))
|
||||
self.precision_plugin.teardown()
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -31,6 +31,7 @@ from pytorch_lightning.utilities.data import has_len
|
|||
from pytorch_lightning.utilities.distributed import ReduceOp
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.optimizer import optimizers_to_device
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||
from pytorch_lightning.utilities.seed import reset_seed
|
||||
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
|
||||
|
@ -126,7 +127,7 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
|
|||
self.accelerator.setup(trainer)
|
||||
self.setup_optimizers(trainer)
|
||||
self.setup_precision_plugin()
|
||||
self._move_optimizer_state()
|
||||
optimizers_to_device(self.optimizers, self.root_device)
|
||||
|
||||
if self.debug:
|
||||
os.environ["PT_XLA_DEBUG"] = str(1)
|
||||
|
|
|
@ -296,17 +296,6 @@ class CheckpointConnector:
|
|||
|
||||
# restore the optimizers
|
||||
self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint)
|
||||
for optimizer in self.trainer.optimizers:
|
||||
# move optimizer to GPU 1 weight at a time
|
||||
# avoids OOM
|
||||
if self.trainer.root_gpu is not None:
|
||||
for param, state in optimizer.state.items():
|
||||
if isinstance(state, dict):
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
state[k] = v.cuda(self.trainer.root_gpu)
|
||||
elif isinstance(state, torch.Tensor):
|
||||
optimizer.state[param] = state.cuda(self.trainer.root_gpu)
|
||||
|
||||
def restore_lr_schedulers(self) -> None:
|
||||
"""Restores the learning rate scheduler states from the pre-loaded checkpoint."""
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
|
||||
from pytorch_lightning.utilities.types import _DEVICE
|
||||
|
||||
|
||||
def optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> None:
|
||||
"""Moves optimizer states for a sequence of optimizers to the device."""
|
||||
for opt in optimizers:
|
||||
optimizer_to_device(opt, device)
|
||||
|
||||
|
||||
def optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
|
||||
"""Moves the state of a single optimizer to the device."""
|
||||
for p, v in optimizer.state.items():
|
||||
optimizer.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
|
|
@ -0,0 +1,30 @@
|
|||
import collections
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.utilities.optimizer import optimizer_to_device
|
||||
|
||||
|
||||
def test_optimizer_to_device():
|
||||
class TestOptimizer(torch.optim.SGD):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.state["dummy"] = torch.tensor(0)
|
||||
|
||||
layer = torch.nn.Linear(32, 2)
|
||||
opt = TestOptimizer(layer.parameters(), lr=0.1)
|
||||
optimizer_to_device(opt, "cpu")
|
||||
if torch.cuda.is_available():
|
||||
optimizer_to_device(opt, "cuda")
|
||||
assert_opt_parameters_on_device(opt, "cuda")
|
||||
|
||||
|
||||
def assert_opt_parameters_on_device(opt, device: str):
|
||||
for param in opt.state.values():
|
||||
# Not sure there are any global tensors in the state dict
|
||||
if isinstance(param, torch.Tensor):
|
||||
assert param.data.device.type == device
|
||||
elif isinstance(param, collections.Mapping):
|
||||
for subparam in param.values():
|
||||
if isinstance(subparam, torch.Tensor):
|
||||
assert param.data.device.type == device
|
Loading…
Reference in New Issue