From 16d6c9828d8203337a9323c5d7ad61e1e102ccf1 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Fri, 30 Apr 2021 18:16:28 +0100
Subject: [PATCH] [bugfix] Apex never instantiated. (#7274)

* update
* update
* update apex
* update
* update
* update
* remove test.py
* update
* update
* update on comments
* update changelog
* update
* update
* typo
---
 CHANGELOG.md                                  |  2 +
 pytorch_lightning/accelerators/accelerator.py |  5 ++
 pytorch_lightning/plugins/base_plugin.py      |  5 ++
 .../plugins/precision/apex_amp.py             | 63 ++++---------------
 pytorch_lightning/trainer/trainer.py          |  3 +-
 tests/plugins/test_amp_plugins.py             | 60 +++++++++++++++++-
 tests/plugins/test_cluster_integration.py     | 13 ++++
 7 files changed, 98 insertions(+), 53 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b47c8a8644..b141e5ed16 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -434,6 +434,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `LightningModule.save_hyperparameters()` when attempting to save an empty container ([#7268](https://github.com/PyTorchLightning/pytorch-lightning/pull/7268))
 
+- Fixed `apex` not properly instantiated when running with `ddp` ([#7274](https://github.com/PyTorchLightning/pytorch-lightning/pull/7274))
+
 
 ## [1.2.7] - 2021-04-06
 
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 9d99da26b9..f26d0d9d51 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -107,6 +107,11 @@ class Accelerator:
         self.setup_optimizers(trainer)
         self.precision_plugin.pre_dispatch()
 
+    def dispatch(self, trainer: 'pl.Trainer') -> None:
+        """Hook to do something when ``Trainer.run_stage()`` starts."""
+        self.training_type_plugin.dispatch(trainer)
+        self.precision_plugin.dispatch(trainer)
+
     def post_dispatch(self, trainer: 'pl.Trainer') -> None:
         """Hook to do something after the training/evaluation/prediction starts."""
         self.training_type_plugin.post_dispatch()
diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py
index ed45d0bc68..515fc29d0e 100644
--- a/pytorch_lightning/plugins/base_plugin.py
+++ b/pytorch_lightning/plugins/base_plugin.py
@@ -15,6 +15,8 @@ import contextlib
 from abc import ABC
 from typing import Generator
 
+import pytorch_lightning as pl
+
 
 class Plugin(ABC):
     """Basic class for all precision- and training type plugins."""
@@ -22,6 +24,9 @@
     def pre_dispatch(self) -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
 
+    def dispatch(self, trainer: "pl.Trainer") -> None:
+        """Hook to do something when ``Trainer.run_stage()`` starts."""
+
     def post_dispatch(self) -> None:
         """Hook to do something after the training/evaluation/prediction finishes."""
 
diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py
index 762095e10e..71c2119e73 100644
--- a/pytorch_lightning/plugins/precision/apex_amp.py
+++ b/pytorch_lightning/plugins/precision/apex_amp.py
@@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Callable, ContextManager, List, Sequence, Tuple, Type
+from typing import Any, Callable, ContextManager, Sequence
 
 import torch
 from torch import Tensor
-from torch.nn import Module
 from torch.optim import Optimizer
 
-from pytorch_lightning.core import LightningModule
+import pytorch_lightning as pl
+from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
 from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
 from pytorch_lightning.utilities.types import _PARAMETERS
@@ -34,24 +34,19 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
         super().__init__()
         self.backend = AMPType.APEX
         self.amp_level = amp_level
+        self._connected = False
 
     def master_params(self, optimizer: Optimizer) -> _PARAMETERS:
         return amp.master_params(optimizer)
 
-    def connect(
-        self,
-        model: Module,
-        optimizers: List[Optimizer],
-        lr_schedulers: List[Any],
-    ) -> Tuple[Module, List[Optimizer], List[Any]]:
-        """Connects the precision plugin to the training process,
-        configures apex and reinits the schedulers
-        """
-        if model.device.type != "cuda":
-            return model, optimizers, lr_schedulers
-        model, optimizers = self.configure_apex(amp, model, list(optimizers), self.amp_level)
-        self.reinit_scheduler_properties(optimizers, lr_schedulers)
-        return model, optimizers, lr_schedulers
+    def dispatch(self, trainer: "pl.Trainer") -> None:
+        if not self._connected:
+            accelerator = trainer.accelerator
+            _, accelerator.optimizers = amp.initialize(
+                trainer.lightning_module, accelerator.optimizers, opt_level=self.amp_level
+            )
+            self._connected = True
+        return super().dispatch(trainer)
 
     def backward(
         self,
@@ -99,40 +94,6 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
         closure_loss = closure_loss.detach()
         return closure_loss
 
-    def configure_apex(
-        self,
-        amp: Type,
-        model: Module,
-        optimizers: List[Optimizer],
-        amp_level: str,
-    ) -> Tuple[Module, List[Optimizer]]:
-        r"""
-        Override to init AMP your own way.
-        Must return a model and list of optimizers.
-
-        Args:
-            amp: pointer to amp library object.
-            model: pointer to current :class:`torch.nn.Module`.
-            optimizers: list of optimizers passed in :meth:`configure_optimizers`.
-            amp_level: AMP mode chosen ('O1', 'O2', etc...)
-
-        Return:
-            Apex wrapped model and optimizers
-
-        Examples:
-            .. code-block:: python
-
-                # Default implementation used by Trainer.
-                def configure_apex(self, amp, model, optimizers, amp_level):
-                    model, optimizers = amp.initialize(
-                        model, optimizers, opt_level=amp_level,
-                    )
-
-                    return model, optimizers
-        """
-        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
-        return model, optimizers
-
     @staticmethod
     def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Sequence[Any]) -> None:
         """Reinitializes schedulers with correct properties"""
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b98c1c0c55..35397dc35d 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -516,7 +516,8 @@ class Trainer(
         else:
             self.accelerator.start_training(self)
 
-    def run_stage(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
+    def run_stage(self):
+        self.accelerator.dispatch(self)
         self.profile_connector.setup()
 
         if self.evaluating:
diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py
index 328cb0a59f..6d0dbed2cf 100644
--- a/tests/plugins/test_amp_plugins.py
+++ b/tests/plugins/test_amp_plugins.py
@@ -1,3 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 from unittest import mock
 
@@ -37,7 +51,7 @@ class MyApexPlugin(ApexMixedPrecisionPlugin):
         pytest.param('native', False, NativeMixedPrecisionPlugin, marks=RunIf(amp_native=True)),
         pytest.param('native', True, MyNativeAMP, marks=RunIf(amp_native=True)),
         pytest.param('apex', False, ApexMixedPrecisionPlugin, marks=RunIf(amp_apex=True)),
-        pytest.param('apex', True, MyApexPlugin, marks=RunIf(amp_apex=True))
+        pytest.param('apex', True, MyApexPlugin, marks=RunIf(amp_apex=True)),
     ]
 )
 def test_amp_apex_ddp(
@@ -83,3 +97,47 @@ def test_amp_gradient_unscale(tmpdir, accum: int):
         accumulate_grad_batches=accum,
     )
     trainer.fit(model)
+
+
+@RunIf(min_gpus=2, amp_apex=True, special=True)
+@pytest.mark.parametrize("amp_level", ['O2'])
+def test_amp_apex_ddp_fit(amp_level, tmpdir):
+
+    class CustomBoringModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            assert self.layer.weight.dtype == torch.float16
+            assert self.trainer.precision_plugin._connected
+            return super().training_step(batch, batch_idx)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        precision=16,
+        amp_backend="apex",
+        gpus=2,
+        accelerator='ddp',
+        plugins=ApexMixedPrecisionPlugin(amp_level=amp_level),
+    )
+    assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
+    model = CustomBoringModel()
+    trainer.fit(model)
+    trainer.test(model)
+
+
+@RunIf(min_gpus=2, amp_apex=True)
+@pytest.mark.parametrize("amp_level", ['O2'])
+def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        precision=16,
+        amp_backend="apex",
+        gpus=2,
+        accelerator='ddp_spawn',
+        plugins=ApexMixedPrecisionPlugin(amp_level=amp_level),
+    )
+    assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
+    model = BoringModel()
+    trainer.fit(model)
diff --git a/tests/plugins/test_cluster_integration.py b/tests/plugins/test_cluster_integration.py
index 032276dd67..fda6b67ad5 100644
--- a/tests/plugins/test_cluster_integration.py
+++ b/tests/plugins/test_cluster_integration.py
@@ -1,3 +1,16 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import os
 from unittest import mock
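
A minimal usage sketch of the configuration this fix targets, not part of the patch itself: with the new `dispatch` hook, `amp.initialize` runs when `Trainer.run_stage()` starts, i.e. after the model has been moved to the GPU, whereas the removed `connect` path returned early whenever the model was still on CPU. `ToyModel` and `RandomDataset` below are illustrative placeholders, and the sketch assumes apex is installed and two CUDA GPUs are available.

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import ApexMixedPrecisionPlugin


class RandomDataset(Dataset):
    # Placeholder dataset: random float32 features, no labels.
    def __init__(self, size: int = 32, length: int = 64):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class ToyModel(LightningModule):
    # Placeholder LightningModule with a single linear layer.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # With apex O2, the layer weights are float16 here because apex was
        # initialized at the start of ``Trainer.run_stage()`` via ``Accelerator.dispatch``.
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=4)


if __name__ == "__main__":
    # Mirrors the configuration exercised by the new test_amp_apex_ddp_fit test.
    trainer = Trainer(
        fast_dev_run=True,
        precision=16,
        amp_backend="apex",
        gpus=2,
        accelerator="ddp",
        plugins=ApexMixedPrecisionPlugin(amp_level="O2"),
    )
    trainer.fit(ToyModel())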