From e0c64f0ef639a8b9f46e0d8e32e5e0a6b7532cff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 3 May 2021 00:27:17 +0200
Subject: [PATCH] Fix Adagrad optimizer not working with DDP/GPU (#7277)

Co-authored-by: ananthsub
Co-authored-by: thomas chaton
---
 CHANGELOG.md                                  |  3 +++
 pytorch_lightning/accelerators/accelerator.py | 14 +++++++++-
 tests/trainer/optimization/test_optimizers.py | 26 +++++++++++++++++++
 3 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a015cad4f5..b629e9e72a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -440,6 +440,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `apex` not properly instantiated when running with `ddp` ([#7274](https://github.com/PyTorchLightning/pytorch-lightning/pull/7274))
 
 
+- Fixed optimizer `state` not moved to `GPU` ([#7277](https://github.com/PyTorchLightning/pytorch-lightning/pull/7277))
+
+
 ## [1.2.7] - 2021-04-06
 
 ### Fixed
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index f26d0d9d51..ab846a562a 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+from collections import defaultdict
 from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
 
 import torch
@@ -25,7 +26,7 @@ from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, Native
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn
-from pytorch_lightning.utilities.apply_func import move_data_to_device
+from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
 
@@ -102,11 +103,22 @@ class Accelerator:
 
     def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
+        self._move_optimizer_state()
+
         self.training_type_plugin.pre_dispatch()
         if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
             self.setup_optimizers(trainer)
+
         self.precision_plugin.pre_dispatch()
 
+    def _move_optimizer_state(self) -> None:
+        """ Moves the state of the optimizers to the GPU if needed. """
+        for opt in self.optimizers:
+            state = defaultdict(dict)
+            for p, v in opt.state.items():
+                state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
+            opt.state = state
+
     def dispatch(self, trainer: 'pl.Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.dispatch(trainer)
diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py
index 22ecdfe801..71ef6e4938 100644
--- a/tests/trainer/optimization/test_optimizers.py
+++ b/tests/trainer/optimization/test_optimizers.py
@@ -19,6 +19,7 @@ from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 from tests.helpers.boring_model import BoringModel
+from tests.helpers.runif import RunIf
 
 
 def test_optimizer_with_scheduling(tmpdir):
@@ -498,3 +499,28 @@ def test_warn_invalid_scheduler_key_in_manual_optimization(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     with pytest.warns(RuntimeWarning, match='the keys will be ignored'):
         trainer.fit(model)
+
+
+class TestModel(BoringModel):
+
+    def configure_optimizers(self):
+        # Adagrad creates state tensors immediately, model is not yet on GPU.
+        return torch.optim.Adagrad(self.parameters())
+
+    def on_train_start(self, *args, **kwargs):
+        opt = self.optimizers()
+        _, state = next(iter(opt.state.items()))
+        assert state["sum"].device == torch.device("cuda", self.local_rank) == self.device
+
+
+@RunIf(min_gpus=2, special=True)
+def test_optimizer_state_on_device(tmpdir):
+    """ Test that optimizers that create state initially at instantiation still end up with the state on the GPU. """
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        gpus=2,
+        accelerator="ddp",
+        fast_dev_run=True,
+    )
+    trainer.fit(model)