[bugfix] Apex never instantiated. (#7274)

* update

* update

* update apex

* update

* update

* update

* remove test.py

* update

* update

* update on comments

* update changelog

* update

* update

* typo
thomas chaton 2021-04-30 18:16:28 +01:00 committed by GitHub
parent 44fd01734c
commit 16d6c9828d
7 changed files with 98 additions and 53 deletions

View File

@@ -434,6 +434,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `LightningModule.save_hyperparameters()` when attempting to save an empty container ([#7268](https://github.com/PyTorchLightning/pytorch-lightning/pull/7268))
- Fixed `apex` not properly instantiated when running with `ddp` ([#7274](https://github.com/PyTorchLightning/pytorch-lightning/pull/7274))
## [1.2.7] - 2021-04-06
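
Root cause, in short: the connect() method removed further down in the apex precision plugin skipped amp.initialize whenever the model was not yet on a CUDA device, which is exactly its state under ddp, so apex was never set up. The fix defers initialization to a new dispatch hook guarded by a flag. A minimal, generic sketch of that lazy-initialization pattern (plain Python, no Lightning or apex required; all names are illustrative):

    class LazyInitPlugin:
        """Initialize heavy state the first time dispatch() fires, never earlier."""

        def __init__(self) -> None:
            self._connected = False  # mirrors ApexMixedPrecisionPlugin._connected

        def dispatch(self, model, optimizers):
            if not self._connected:
                # stand-in for amp.initialize(model, optimizers, opt_level=...)
                print("initializing mixed precision once")
                self._connected = True
            return model, optimizers

    plugin = LazyInitPlugin()
    plugin.dispatch("model", ["opt"])  # initializes on the first stage
    plugin.dispatch("model", ["opt"])  # no-op for later stages (fit -> test, etc.)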

View File

@@ -107,6 +107,11 @@ class Accelerator:
self.setup_optimizers(trainer)
self.precision_plugin.pre_dispatch()
def dispatch(self, trainer: 'pl.Trainer') -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self.training_type_plugin.dispatch(trainer)
self.precision_plugin.dispatch(trainer)
def post_dispatch(self, trainer: 'pl.Trainer') -> None:
"""Hook to do something after the training/evaluation/prediction starts."""
self.training_type_plugin.post_dispatch()
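
Accelerator.dispatch simply fans the new hook out to the training-type and precision plugins, so any plugin can react the moment run_stage begins. A hedged sketch of a user-defined precision plugin using the hook (assumes the pytorch_lightning ~1.3 plugin API shown in this diff; the Trainer usage is illustrative and left commented out):

    import pytorch_lightning as pl
    from pytorch_lightning.plugins import PrecisionPlugin

    class TracingPrecisionPlugin(PrecisionPlugin):
        """Toy plugin: records when the new dispatch hook fires."""

        def dispatch(self, trainer: "pl.Trainer") -> None:
            print("precision plugin dispatch: run_stage is starting")
            super().dispatch(trainer)

    # trainer = pl.Trainer(fast_dev_run=True, plugins=[TracingPrecisionPlugin()])
    # trainer.fit(model)  # `model` would be any LightningModule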

View File

@@ -15,6 +15,8 @@ import contextlib
from abc import ABC
from typing import Generator
import pytorch_lightning as pl
class Plugin(ABC):
"""Basic class for all precision- and training type plugins."""
@@ -22,6 +24,9 @@ class Plugin(ABC):
def pre_dispatch(self) -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
def dispatch(self, trainer: "pl.Trainer") -> None:
"""Hook to do something at trainer run_stage starts."""
def post_dispatch(self) -> None:
"""Hook to do something after the training/evaluation/prediction finishes."""

View File

@@ -11,14 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Callable, ContextManager, List, Sequence, Tuple, Type
+from typing import Any, Callable, ContextManager, Sequence
import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from pytorch_lightning.core import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
from pytorch_lightning.utilities.types import _PARAMETERS
@@ -34,24 +34,19 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
super().__init__()
self.backend = AMPType.APEX
self.amp_level = amp_level
self._connected = False
def master_params(self, optimizer: Optimizer) -> _PARAMETERS:
return amp.master_params(optimizer)
def connect(
self,
model: Module,
optimizers: List[Optimizer],
lr_schedulers: List[Any],
) -> Tuple[Module, List[Optimizer], List[Any]]:
"""Connects the precision plugin to the training process,
configures apex and reinits the schedulers
"""
if model.device.type != "cuda":
return model, optimizers, lr_schedulers
model, optimizers = self.configure_apex(amp, model, list(optimizers), self.amp_level)
self.reinit_scheduler_properties(optimizers, lr_schedulers)
return model, optimizers, lr_schedulers
def dispatch(self, trainer: "pl.Trainer") -> None:
if not self._connected:
accelerator = trainer.accelerator
_, accelerator.optimizers = amp.initialize(
trainer.lightning_module, accelerator.optimizers, opt_level=self.amp_level
)
self._connected = True
return super().dispatch(trainer)
def backward(
self,
@@ -99,40 +94,6 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
closure_loss = closure_loss.detach()
return closure_loss
def configure_apex(
self,
amp: Type,
model: Module,
optimizers: List[Optimizer],
amp_level: str,
) -> Tuple[Module, List[Optimizer]]:
r"""
Override to init AMP your own way.
Must return a model and list of optimizers.
Args:
amp: pointer to amp library object.
model: pointer to current :class:`torch.nn.Module`.
optimizers: list of optimizers passed in :meth:`configure_optimizers`.
amp_level: AMP mode chosen ('O1', 'O2', etc...)
Return:
Apex wrapped model and optimizers
Examples:
.. code-block:: python
# Default implementation used by Trainer.
def configure_apex(self, amp, model, optimizers, amp_level):
model, optimizers = amp.initialize(
model, optimizers, opt_level=amp_level,
)
return model, optimizers
"""
model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
return model, optimizers
@staticmethod
def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Sequence[Any]) -> None:
"""Reinitializes schedulers with correct properties"""

View File

@@ -516,7 +516,8 @@ class Trainer(
else:
self.accelerator.start_training(self)
-def run_stage(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
+def run_stage(self):
+self.accelerator.dispatch(self)
self.profile_connector.setup()
if self.evaluating:

View File

@@ -1,3 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
@@ -37,7 +51,7 @@ class MyApexPlugin(ApexMixedPrecisionPlugin):
pytest.param('native', False, NativeMixedPrecisionPlugin, marks=RunIf(amp_native=True)),
pytest.param('native', True, MyNativeAMP, marks=RunIf(amp_native=True)),
pytest.param('apex', False, ApexMixedPrecisionPlugin, marks=RunIf(amp_apex=True)),
-pytest.param('apex', True, MyApexPlugin, marks=RunIf(amp_apex=True))
+pytest.param('apex', True, MyApexPlugin, marks=RunIf(amp_apex=True)),
]
)
def test_amp_apex_ddp(
@@ -83,3 +97,47 @@ def test_amp_gradient_unscale(tmpdir, accum: int):
accumulate_grad_batches=accum,
)
trainer.fit(model)
@RunIf(min_gpus=2, amp_apex=True, special=True)
@pytest.mark.parametrize("amp_level", ['O2'])
def test_amp_apex_ddp_fit(amp_level, tmpdir):
class CustomBoringModel(BoringModel):
def training_step(self, batch, batch_idx):
assert self.layer.weight.dtype == torch.float16
assert self.trainer.precision_plugin._connected
return super().training_step(batch, batch_idx)
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
precision=16,
amp_backend="apex",
gpus=2,
accelerator='ddp',
plugins=ApexMixedPrecisionPlugin(amp_level=amp_level),
)
assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
model = CustomBoringModel()
trainer.fit(model)
trainer.test(model)
@RunIf(min_gpus=2, amp_apex=True)
@pytest.mark.parametrize("amp_level", ['O2'])
def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
precision=16,
amp_backend="apex",
gpus=2,
accelerator='ddp_spawn',
plugins=ApexMixedPrecisionPlugin(amp_level=amp_level),
)
assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
model = BoringModel()
trainer.fit(model)
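
These tests pass the plugin explicitly; the same code path now also covers the plain configuration most users write, roughly as below (a sketch against the ~1.3 Trainer API; requires apex and 2 GPUs, and the amp_level Trainer argument is assumed from that era's API):

    import pytorch_lightning as pl

    # The accelerator connector builds an ApexMixedPrecisionPlugin from these arguments,
    # and the new dispatch hook runs amp.initialize once run_stage starts.
    trainer = pl.Trainer(
        precision=16,
        amp_backend="apex",  # apex instead of native torch.cuda.amp
        amp_level="O2",
        gpus=2,
        accelerator="ddp",
        fast_dev_run=True,
    )
    # trainer.fit(model)  # `model` would be any LightningModule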

View File

@@ -1,3 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock