diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3552838f54..267895c407 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -117,6 +117,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `Accelerator.is_available` to check device availability ([#11797](https://github.com/PyTorchLightning/pytorch-lightning/pull/11797))
 
+- Added utility functions for moving optimizers to devices ([#11758](https://github.com/PyTorchLightning/pytorch-lightning/pull/11758))
+
+
 ### Changed
 
 - Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332))
diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst
index 67208a127f..e11136118e 100644
--- a/docs/source/api_references.rst
+++ b/docs/source/api_references.rst
@@ -288,6 +288,7 @@ Utilities API
     finite_checks
     memory
     model_summary
+    optimizer
     parsing
     rank_zero
     seed
diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py
index cbf66b7040..f8266189c3 100644
--- a/pytorch_lightning/strategies/deepspeed.py
+++ b/pytorch_lightning/strategies/deepspeed.py
@@ -39,6 +39,7 @@ from pytorch_lightning.utilities.enums import AMPType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, LRSchedulerTypeUnion, STEP_OUTPUT
@@ -349,7 +350,7 @@ class DeepSpeedStrategy(DDPStrategy):
         self.accelerator.setup(trainer)
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
-        self._move_optimizer_state()
+        optimizers_to_device(self.optimizers, self.root_device)
         self.init_deepspeed()
         self.barrier()
 
diff --git a/pytorch_lightning/strategies/fully_sharded.py b/pytorch_lightning/strategies/fully_sharded.py
index af2d6d74bf..41d036bdcb 100644
--- a/pytorch_lightning/strategies/fully_sharded.py
+++ b/pytorch_lightning/strategies/fully_sharded.py
@@ -25,6 +25,7 @@ from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
 from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
@@ -136,7 +137,7 @@ class DDPFullyShardedStrategy(DDPStrategy):
         self.accelerator.setup(trainer)
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
-        self._move_optimizer_state()
+        optimizers_to_device(self.optimizers, self.root_device)
 
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)
diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py
index e4d8827e9a..3ff8f10c65 100644
--- a/pytorch_lightning/strategies/strategy.py
+++ b/pytorch_lightning/strategies/strategy.py
@@ -30,9 +30,10 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.launchers.base import _Launcher
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import rank_zero_deprecation
-from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
+from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device
 from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT
 
 TBroadcast = TypeVar("TBroadcast")
@@ -138,7 +139,7 @@ class Strategy(ABC):
         self.accelerator.setup(trainer)
         self.setup_optimizers(trainer)
         self.setup_precision_plugin()
-        self._move_optimizer_state()
+        optimizers_to_device(self.optimizers, self.root_device)
 
     def setup_precision_plugin(self) -> None:
         """Attaches the precision plugin to the accelerator."""
@@ -149,14 +150,6 @@
         self.optimizers = optimizers
         self.lr_scheduler_configs = lr_scheduler_configs
 
-    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
-        """Moves the state of the optimizers to the appropriate device if needed."""
-        for opt in self.optimizers:
-            for p, v in opt.state.items():
-                # `self.root_device` would raise error if called outside the spawn process
-                # while training on 8 and more cores.
-                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device or self.root_device)
-
     def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """Returns state of an optimizer.
 
@@ -330,6 +323,7 @@ class Strategy(ABC):
         optimizer_states = checkpoint["optimizer_states"]
         for optimizer, opt_state in zip(self.optimizers, optimizer_states):
             optimizer.load_state_dict(opt_state)
+            optimizer_to_device(optimizer, self.root_device)
 
     def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
         """The actual training step.
@@ -445,7 +439,7 @@ class Strategy(ABC):
 
         It is the right place to release memory and free other resources.
""" - self._move_optimizer_state(torch.device("cpu")) + optimizers_to_device(self.optimizers, torch.device("cpu")) self.precision_plugin.teardown() @classmethod diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index d97797f92d..1606f3bb44 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -31,6 +31,7 @@ from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_only from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT @@ -126,7 +127,7 @@ class TPUSpawnStrategy(DDPSpawnStrategy): self.accelerator.setup(trainer) self.setup_optimizers(trainer) self.setup_precision_plugin() - self._move_optimizer_state() + optimizers_to_device(self.optimizers, self.root_device) if self.debug: os.environ["PT_XLA_DEBUG"] = str(1) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 8533fc6ee3..dbe0539e4d 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -296,17 +296,6 @@ class CheckpointConnector: # restore the optimizers self.trainer.strategy.load_optimizer_state_dict(self._loaded_checkpoint) - for optimizer in self.trainer.optimizers: - # move optimizer to GPU 1 weight at a time - # avoids OOM - if self.trainer.root_gpu is not None: - for param, state in optimizer.state.items(): - if isinstance(state, dict): - for k, v in state.items(): - if isinstance(v, torch.Tensor): - state[k] = v.cuda(self.trainer.root_gpu) - elif isinstance(state, torch.Tensor): - optimizer.state[param] = state.cuda(self.trainer.root_gpu) def restore_lr_schedulers(self) -> None: """Restores the learning rate scheduler states from the pre-loaded checkpoint.""" diff --git a/pytorch_lightning/utilities/optimizer.py b/pytorch_lightning/utilities/optimizer.py new file mode 100644 index 0000000000..75efa57fc0 --- /dev/null +++ b/pytorch_lightning/utilities/optimizer.py @@ -0,0 +1,33 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import Iterable
+
+import torch
+from torch.optim import Optimizer
+
+from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
+from pytorch_lightning.utilities.types import _DEVICE
+
+
+def optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> None:
+    """Moves the state of a sequence of optimizers to the device."""
+    for opt in optimizers:
+        optimizer_to_device(opt, device)
+
+
+def optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
+    """Moves the state of a single optimizer to the device."""
+    for p, v in optimizer.state.items():
+        optimizer.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
diff --git a/tests/utilities/test_optimizer.py b/tests/utilities/test_optimizer.py
new file mode 100644
index 0000000000..6d4c0ec54e
--- /dev/null
+++ b/tests/utilities/test_optimizer.py
@@ -0,0 +1,30 @@
+import collections
+
+import torch
+
+from pytorch_lightning.utilities.optimizer import optimizer_to_device
+
+
+def test_optimizer_to_device():
+    class TestOptimizer(torch.optim.SGD):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.state["dummy"] = torch.tensor(0)
+
+    layer = torch.nn.Linear(32, 2)
+    opt = TestOptimizer(layer.parameters(), lr=0.1)
+    optimizer_to_device(opt, "cpu")
+    if torch.cuda.is_available():
+        optimizer_to_device(opt, "cuda")
+        assert_opt_parameters_on_device(opt, "cuda")
+
+
+def assert_opt_parameters_on_device(opt, device: str):
+    for param in opt.state.values():
+        # the state may hold top-level tensors (e.g. the "dummy" entry above) or per-parameter dicts of tensors
+        if isinstance(param, torch.Tensor):
+            assert param.data.device.type == device
+        elif isinstance(param, collections.Mapping):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    assert subparam.data.device.type == device
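
Reviewer note: below is a minimal usage sketch (not part of the diff) of the helpers introduced above. It assumes a checkout with this patch applied so that `pytorch_lightning.utilities.optimizer` is importable; the model, optimizer, and device choices are illustrative only.

import torch

from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device

# Build a tiny model/optimizer and take one step so the optimizer has state to move.
model = torch.nn.Linear(32, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(4, 32)).sum().backward()
optimizer.step()  # populates optimizer.state with exp_avg / exp_avg_sq tensors

# As in Strategy.setup() / load_optimizer_state_dict(): move optimizer state to the training device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer_to_device(optimizer, device)

# As in Strategy.teardown(): move all optimizer states back to CPU before releasing the device.
optimizers_to_device([optimizer], torch.device("cpu"))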