Fix Adagrad optimizer not working with DDP/GPU (#7277)

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
Adrian Wälchli 2021-05-03 00:27:17 +02:00 committed by GitHub
parent 29357ba94e
commit e0c64f0ef6
3 changed files with 42 additions and 1 deletion


@@ -440,6 +440,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `apex` not properly instantiated when running with `ddp` ([#7274](https://github.com/PyTorchLightning/pytorch-lightning/pull/7274))
- Fixed optimizer `state` not moved to `GPU` ([#7277](https://github.com/PyTorchLightning/pytorch-lightning/pull/7277))
## [1.2.7] - 2021-04-06
### Fixed

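For context (not part of the commit itself): Adagrad allocates its per-parameter `sum` accumulators as soon as the optimizer is constructed, so when `configure_optimizers` runs before the model has been moved to the GPU, those state tensors stay on the CPU and `optimizer.step()` fails with a device mismatch under DDP. A minimal sketch of the failure in plain PyTorch, outside Lightning (the exact error text is illustrative):

```python
import torch

model = torch.nn.Linear(4, 4)                  # parameters start on the CPU
opt = torch.optim.Adagrad(model.parameters())  # Adagrad fills opt.state[param]["sum"] right here, on the CPU

model.cuda()                                   # the model is moved to the GPU only afterwards, as DDP setup does

loss = model(torch.randn(2, 4, device="cuda")).sum()
loss.backward()
opt.step()  # raises a RuntimeError: the CUDA gradients meet the CPU "sum" buffers
```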

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from collections import defaultdict
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
import torch
@@ -25,7 +26,7 @@ from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, Native
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
@@ -102,11 +103,22 @@ class Accelerator:
    def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
        """Hook to do something before the training/evaluation/prediction starts."""
        self._move_optimizer_state()

        self.training_type_plugin.pre_dispatch()
        if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
            self.setup_optimizers(trainer)

        self.precision_plugin.pre_dispatch()

    def _move_optimizer_state(self) -> None:
        """ Moves the state of the optimizers to the GPU if needed. """
        for opt in self.optimizers:
            state = defaultdict(dict)
            for p, v in opt.state.items():
                state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
            opt.state = state

    def dispatch(self, trainer: 'pl.Trainer') -> None:
        """Hook to do something before the training/evaluation/prediction starts."""
        self.training_type_plugin.dispatch(trainer)
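A rough standalone equivalent of `_move_optimizer_state` above, written without Lightning's `apply_to_collection`/`move_data_to_device` helpers (those additionally recurse into nested containers); the function name is illustrative, not part of the commit:

```python
from collections import defaultdict

import torch


def move_optimizer_state(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    """Move every tensor in the optimizer's state dict to the given device."""
    moved = defaultdict(dict)
    for param, param_state in optimizer.state.items():
        moved[param] = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in param_state.items()
        }
    optimizer.state = moved
```

With the repro sketch further up, calling `move_optimizer_state(opt, torch.device("cuda", 0))` after `model.cuda()` lets `opt.step()` run; after this commit, Lightning does the equivalent automatically in `Accelerator.pre_dispatch` via `_move_optimizer_state`.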


@@ -19,6 +19,7 @@ from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
def test_optimizer_with_scheduling(tmpdir):
@@ -498,3 +499,28 @@ def test_warn_invalid_scheduler_key_in_manual_optimization(tmpdir):
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    with pytest.warns(RuntimeWarning, match='the keys will be ignored'):
        trainer.fit(model)


class TestModel(BoringModel):

    def configure_optimizers(self):
        # Adagrad creates state tensors immediately, model is not yet on GPU.
        return torch.optim.Adagrad(self.parameters())

    def on_train_start(self, *args, **kwargs):
        opt = self.optimizers()
        _, state = next(iter(opt.state.items()))
        assert state["sum"].device == torch.device("cuda", self.local_rank) == self.device


@RunIf(min_gpus=2, special=True)
def test_optimizer_state_on_device(tmpdir):
    """ Test that optimizers that create state initially at instantiation still end up with the state on the GPU. """
    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        gpus=2,
        accelerator="ddp",
        fast_dev_run=True,
    )
    trainer.fit(model)
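A side note on why the regression test targets Adagrad in particular: Adagrad is one of the built-in optimizers that populates its state eagerly at construction time, whereas e.g. Adam creates `exp_avg`/`exp_avg_sq` lazily on the first `step()` and therefore never sees the stale CPU tensors. A small illustrative check (printed values assume a plain CPU process):

```python
import torch

model = torch.nn.Linear(4, 4)  # two parameters: weight and bias

adagrad = torch.optim.Adagrad(model.parameters())
adam = torch.optim.Adam(model.parameters())

print(len(adagrad.state))  # 2 -- one state entry per parameter, allocated immediately (on the CPU here)
print(len(adam.state))     # 0 -- Adam's state only appears after the first step()
```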