Refactor `Strategy._move_optimizer_states` as utility functions (#11758)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
ananthsub 2022-02-18 00:36:07 -08:00 committed by GitHub
parent d61371922b
commit cf64f34434
9 changed files with 78 additions and 25 deletions


@@ -117,6 +117,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `Accelerator.is_available` to check device availability ([#11797](https://github.com/PyTorchLightning/pytorch-lightning/pull/11797))
+- Added utility functions for moving optimizers to devices ([#11758](https://github.com/PyTorchLightning/pytorch-lightning/pull/11758))
### Changed
- Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332))


@@ -288,6 +288,7 @@ Utilities API
finite_checks
memory
model_summary
+optimizer
parsing
rank_zero
seed


@@ -39,6 +39,7 @@ from pytorch_lightning.utilities.enums import AMPType, PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT
@@ -349,7 +350,7 @@ class DeepSpeedStrategy(DDPStrategy):
self.accelerator.setup(trainer)
self.setup_optimizers(trainer)
self.setup_precision_plugin()
-self._move_optimizer_state()
+optimizers_to_device(self.optimizers, self.root_device)
self.init_deepspeed()
self.barrier()


@@ -25,6 +25,7 @@ from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.types import STEP_OUTPUT
if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
@@ -136,7 +137,7 @@ class DDPFullyShardedStrategy(DDPStrategy):
self.accelerator.setup(trainer)
self.setup_optimizers(trainer)
self.setup_precision_plugin()
-self._move_optimizer_state()
+optimizers_to_device(self.optimizers, self.root_device)
if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)


@@ -30,9 +30,10 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_deprecation
-from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
+from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device
from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT
TBroadcast = TypeVar("TBroadcast")
@@ -138,7 +139,7 @@ class Strategy(ABC):
self.accelerator.setup(trainer)
self.setup_optimizers(trainer)
self.setup_precision_plugin()
-self._move_optimizer_state()
+optimizers_to_device(self.optimizers, self.root_device)
def setup_precision_plugin(self) -> None:
"""Attaches the precision plugin to the accelerator."""
@@ -149,14 +150,6 @@ class Strategy(ABC):
self.optimizers = optimizers
self.lr_scheduler_configs = lr_scheduler_configs
-def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-    """Moves the state of the optimizers to the appropriate device if needed."""
-    for opt in self.optimizers:
-        for p, v in opt.state.items():
-            # `self.root_device` would raise error if called outside the spawn process
-            # while training on 8 and more cores.
-            opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device or self.root_device)
def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
"""Returns state of an optimizer.
@@ -330,6 +323,7 @@ class Strategy(ABC):
optimizer_states = checkpoint["optimizer_states"]
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
optimizer.load_state_dict(opt_state)
+optimizer_to_device(optimizer, self.root_device)
def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
"""The actual training step.
@@ -445,7 +439,7 @@ class Strategy(ABC):
It is the right place to release memory and free other resources.
"""
-self._move_optimizer_state(torch.device("cpu"))
+optimizers_to_device(self.optimizers, torch.device("cpu"))
self.precision_plugin.teardown()
@classmethod
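
For orientation, here is a minimal usage sketch (not part of this commit) of the lifecycle the refactor enables: `Strategy.setup` now moves optimizer state to the root device via `optimizers_to_device`, and `teardown` moves it back to CPU. The model, optimizer, and device below are illustrative placeholders.

```python
# Minimal sketch: the setup/teardown flow with the new utilities instead of
# the removed Strategy._move_optimizer_state.
import torch
from torch.optim import SGD

from pytorch_lightning.utilities.optimizer import optimizers_to_device

model = torch.nn.Linear(4, 4)  # placeholder module
optimizers = [SGD(model.parameters(), lr=0.1)]
root_device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# setup(): move any restored optimizer state onto the strategy's root device
optimizers_to_device(optimizers, root_device)

# ... training would happen here ...

# teardown(): release accelerator memory by moving the state back to CPU
optimizers_to_device(optimizers, torch.device("cpu"))
```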


@@ -31,6 +31,7 @@ from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
@@ -126,7 +127,7 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
self.accelerator.setup(trainer)
self.setup_optimizers(trainer)
self.setup_precision_plugin()
-self._move_optimizer_state()
+optimizers_to_device(self.optimizers, self.root_device)
if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)


@@ -296,17 +296,6 @@ class CheckpointConnector:
# restore the optimizers
self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint)
-for optimizer in self.trainer.optimizers:
-    # move optimizer to GPU 1 weight at a time
-    # avoids OOM
-    if self.trainer.root_gpu is not None:
-        for param, state in optimizer.state.items():
-            if isinstance(state, dict):
-                for k, v in state.items():
-                    if isinstance(v, torch.Tensor):
-                        state[k] = v.cuda(self.trainer.root_gpu)
-            elif isinstance(state, torch.Tensor):
-                optimizer.state[param] = state.cuda(self.trainer.root_gpu)
def restore_lr_schedulers(self) -> None:
"""Restores the learning rate scheduler states from the pre-loaded checkpoint."""


@@ -0,0 +1,33 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable

import torch
from torch.optim import Optimizer

from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.types import _DEVICE


def optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> None:
    """Moves optimizer states for a sequence of optimizers to the device."""
    for opt in optimizers:
        optimizer_to_device(opt, device)


def optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
    """Moves the state of a single optimizer to the device."""
    for p, v in optimizer.state.items():
        optimizer.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
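
A small usage sketch of the new module (illustrative only; the layer size and SGD settings are arbitrary): one optimizer step populates the per-parameter momentum buffers, after which `optimizers_to_device` moves every state tensor to the target device.

```python
# Usage sketch for pytorch_lightning.utilities.optimizer.
import torch

from pytorch_lightning.utilities.optimizer import optimizers_to_device

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One step populates per-parameter state (the momentum buffers).
model(torch.randn(2, 16)).sum().backward()
optimizer.step()

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
optimizers_to_device([optimizer], device)

# Every tensor in the optimizer state now lives on `device`.
assert all(
    t.device == device
    for state in optimizer.state.values()
    for t in state.values()
    if isinstance(t, torch.Tensor)
)
```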


@@ -0,0 +1,30 @@
import collections

import torch

from pytorch_lightning.utilities.optimizer import optimizer_to_device


def test_optimizer_to_device():
    class TestOptimizer(torch.optim.SGD):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.state["dummy"] = torch.tensor(0)

    layer = torch.nn.Linear(32, 2)
    opt = TestOptimizer(layer.parameters(), lr=0.1)
    optimizer_to_device(opt, "cpu")
    if torch.cuda.is_available():
        optimizer_to_device(opt, "cuda")
        assert_opt_parameters_on_device(opt, "cuda")


def assert_opt_parameters_on_device(opt, device: str):
    for param in opt.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(param, torch.Tensor):
            assert param.data.device.type == device
        elif isinstance(param, collections.abc.Mapping):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    assert subparam.data.device.type == device