Fix Adagrad optimizer not working with DDP/GPU (#7277)
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
parent 29357ba94e
commit e0c64f0ef6
CHANGELOG.md
@@ -440,6 +440,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed `apex` not properly instantiated when running with `ddp` ([#7274](https://github.com/PyTorchLightning/pytorch-lightning/pull/7274))

+- Fixed optimizer `state` not moved to `GPU` ([#7277](https://github.com/PyTorchLightning/pytorch-lightning/pull/7277))

 ## [1.2.7] - 2021-04-06

 ### Fixed
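For context, the failure mode behind this entry: unlike most optimizers, `torch.optim.Adagrad` allocates its per-parameter state in `__init__`, so an optimizer built before the model is moved to the GPU keeps CPU state tensors, and the first `step()` fails with a device mismatch. A minimal standalone repro (our illustration, not part of the commit; requires a CUDA device):

    import torch

    model = torch.nn.Linear(4, 4)
    # Adagrad eagerly creates its "sum" buffers here, on the CPU,
    # because that is where the parameters currently live.
    optimizer = torch.optim.Adagrad(model.parameters(), lr=0.1)

    model.cuda()  # parameters move to the GPU; optimizer state does not

    param = next(iter(optimizer.state))
    print(optimizer.state[param]["sum"].device)  # cpu

    model(torch.randn(2, 4, device="cuda")).sum().backward()
    optimizer.step()  # RuntimeError: tensors on different devices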
pytorch_lightning/accelerators/accelerator.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+from collections import defaultdict
 from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union

 import torch
@@ -25,7 +26,7 @@ from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, Native
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn
-from pytorch_lightning.utilities.apply_func import move_data_to_device
+from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
@@ -102,11 +103,22 @@ class Accelerator:

     def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
+        self._move_optimizer_state()
+
         self.training_type_plugin.pre_dispatch()
         if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
             self.setup_optimizers(trainer)

         self.precision_plugin.pre_dispatch()

+    def _move_optimizer_state(self) -> None:
+        """ Moves the state of the optimizers to the GPU if needed. """
+        for opt in self.optimizers:
+            state = defaultdict(dict)
+            for p, v in opt.state.items():
+                state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
+            opt.state = state
+
     def dispatch(self, trainer: 'pl.Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.dispatch(trainer)
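The heart of the fix: `pre_dispatch` now calls `_move_optimizer_state`, which rebuilds each optimizer's state with every `torch.Tensor` leaf relocated to `root_device` via `apply_to_collection`. A framework-free sketch of the same idea (the helper name and the explicit recursive walk are ours; the commit itself relies on Lightning's `apply_to_collection`):

    from collections import defaultdict

    import torch


    def move_optimizer_state(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
        """Relocate every tensor in the optimizer's state to `device`."""

        def to_device(value):
            # Recurse through containers, move tensors, and pass everything
            # else (e.g. Adagrad's integer "step" counter) through untouched.
            if isinstance(value, torch.Tensor):
                return value.to(device)
            if isinstance(value, dict):
                return {k: to_device(v) for k, v in value.items()}
            if isinstance(value, (list, tuple)):
                return type(value)(to_device(v) for v in value)
            return value

        # Rebuild as defaultdict(dict): Optimizer.state is lazily populated,
        # and swapping in a plain dict would break optimizers that create
        # state for a parameter on its first step().
        state = defaultdict(dict)
        for param, param_state in optimizer.state.items():
            state[param] = to_device(param_state)
        optimizer.state = state

Applied after `model.cuda()` in the repro above, this makes `optimizer.step()` succeed, which is exactly the guarantee the patch establishes by hooking the move into `pre_dispatch` before training starts.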
tests/trainer/optimization/test_optimizers.py
@@ -19,6 +19,7 @@ from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 from tests.helpers.boring_model import BoringModel
+from tests.helpers.runif import RunIf


 def test_optimizer_with_scheduling(tmpdir):
@@ -498,3 +499,28 @@ def test_warn_invalid_scheduler_key_in_manual_optimization(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     with pytest.warns(RuntimeWarning, match='the keys will be ignored'):
         trainer.fit(model)
+
+
+class TestModel(BoringModel):
+
+    def configure_optimizers(self):
+        # Adagrad creates state tensors immediately; the model is not yet on the GPU.
+        return torch.optim.Adagrad(self.parameters())
+
+    def on_train_start(self, *args, **kwargs):
+        opt = self.optimizers()
+        _, state = next(iter(opt.state.items()))
+        assert state["sum"].device == torch.device("cuda", self.local_rank) == self.device
+
+
+@RunIf(min_gpus=2, special=True)
+def test_optimizer_state_on_device(tmpdir):
+    """Test that optimizers which create their state at instantiation still end up with that state on the GPU."""
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        gpus=2,
+        accelerator="ddp",
+        fast_dev_run=True,
+    )
+    trainer.fit(model)
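The committed test needs two GPUs plus the special-test harness. A single-GPU variant of the same check, adapted by us for a quick local run (everything beyond Lightning's public API and the test helpers is our naming, not part of the commit):

    import torch
    from pytorch_lightning import Trainer
    from tests.helpers.boring_model import BoringModel


    class AdagradModel(BoringModel):

        def configure_optimizers(self):
            # State exists from construction, before any GPU transfer.
            return torch.optim.Adagrad(self.parameters())

        def on_train_start(self, *args, **kwargs):
            opt = self.optimizers()
            _, state = next(iter(opt.state.items()))
            # With the fix, pre_dispatch has already moved the state over.
            assert state["sum"].device == self.device == torch.device("cuda", 0)


    def test_optimizer_state_on_single_gpu(tmpdir):
        trainer = Trainer(default_root_dir=tmpdir, gpus=1, fast_dev_run=True)
        trainer.fit(AdagradModel())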