[bugfix] Apex never instantiated. (#7274)
* update * update * update apex * update * update * update * remove test.py * update * update * update on comments * update changelog * update * update * typo
Parent: 44fd01734c
Commit: 16d6c9828d
@@ -434,6 +434,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `LightningModule.save_hyperparameters()` when attempting to save an empty container ([#7268](https://github.com/PyTorchLightning/pytorch-lightning/pull/7268))

- Fixed `apex` not properly instantiated when running with `ddp` ([#7274](https://github.com/PyTorchLightning/pytorch-lightning/pull/7274))


## [1.2.7] - 2021-04-06
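For context, the root cause is visible in the `ApexMixedPrecisionPlugin` hunks below: the removed `connect()` returned early whenever the model was not yet on a CUDA device, and under `ddp` the model only reaches the GPU later, so `amp.initialize` was never called. A minimal sketch of that failure mode — simplified stand-in code, not the actual Lightning source:

import torch


def connect(model: torch.nn.Module) -> bool:
    """Simplified stand-in: return True if apex-style initialization would have run."""
    if next(model.parameters()).device.type != "cuda":
        return False  # early return: amp.initialize(...) is never reached
    return True


model = torch.nn.Linear(2, 2)   # still on the CPU, as it was when connect() ran under ddp
assert connect(model) is False  # hence "apex never instantiated"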
@@ -107,6 +107,11 @@ class Accelerator:
        self.setup_optimizers(trainer)
        self.precision_plugin.pre_dispatch()

    def dispatch(self, trainer: 'pl.Trainer') -> None:
        """Hook to do something before the training/evaluation/prediction starts."""
        self.training_type_plugin.dispatch(trainer)
        self.precision_plugin.dispatch(trainer)

    def post_dispatch(self, trainer: 'pl.Trainer') -> None:
        """Hook to do something after the training/evaluation/prediction starts."""
        self.training_type_plugin.post_dispatch()
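A rough sketch of the call order these hooks follow, using toy classes rather than the real `Accelerator` API: `pre_dispatch` during setup, `dispatch` when the run stage starts, `post_dispatch` once the stage is done, with the accelerator fanning each call out to its plugins.

class TinyPlugin:
    """Toy stand-in for a precision or training-type plugin."""

    def pre_dispatch(self) -> None:
        print("pre_dispatch: before the run starts")

    def dispatch(self) -> None:
        print("dispatch: the run stage is starting")

    def post_dispatch(self) -> None:
        print("post_dispatch: the run stage finished")


class TinyAccelerator:
    """Fans each hook out to its plugin, mirroring the hunk above."""

    def __init__(self, plugin: TinyPlugin) -> None:
        self.plugin = plugin

    def setup(self) -> None:
        self.plugin.pre_dispatch()

    def run_stage(self) -> None:
        self.plugin.dispatch()  # the trainer-side counterpart is accelerator.dispatch(self)
        print("... training loop ...")
        self.plugin.post_dispatch()


acc = TinyAccelerator(TinyPlugin())
acc.setup()
acc.run_stage()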
@@ -15,6 +15,8 @@ import contextlib
from abc import ABC
from typing import Generator

import pytorch_lightning as pl


class Plugin(ABC):
    """Basic class for all precision- and training type plugins."""

@@ -22,6 +24,9 @@ class Plugin(ABC):
    def pre_dispatch(self) -> None:
        """Hook to do something before the training/evaluation/prediction starts."""

    def dispatch(self, trainer: "pl.Trainer") -> None:
        """Hook to do something when trainer run_stage starts."""

    def post_dispatch(self) -> None:
        """Hook to do something after the training/evaluation/prediction finishes."""
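Unlike `pre_dispatch` and `post_dispatch`, the new `dispatch` hook receives the trainer, so a plugin can rewrite state that only exists once setup has run. A toy sketch of why that matters — hypothetical names, not Lightning classes:

from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyTrainer:
    """Hypothetical stand-in for pl.Trainer, just enough for this sketch."""
    optimizers: List[str] = field(default_factory=lambda: ["sgd"])


class OptimizerPatchingPlugin:
    def dispatch(self, trainer: ToyTrainer) -> None:
        # Because dispatch() gets the trainer, the plugin can patch objects created
        # during setup - the apex plugin below does this via trainer.accelerator.optimizers.
        trainer.optimizers = [f"wrapped({opt})" for opt in trainer.optimizers]


trainer = ToyTrainer()
OptimizerPatchingPlugin().dispatch(trainer)
assert trainer.optimizers == ["wrapped(sgd)"]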
@@ -11,14 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, ContextManager, List, Sequence, Tuple, Type
from typing import Any, Callable, ContextManager, Sequence

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
from pytorch_lightning.utilities.types import _PARAMETERS
@@ -34,24 +34,19 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
        super().__init__()
        self.backend = AMPType.APEX
        self.amp_level = amp_level
        self._connected = False

    def master_params(self, optimizer: Optimizer) -> _PARAMETERS:
        return amp.master_params(optimizer)

    def connect(
        self,
        model: Module,
        optimizers: List[Optimizer],
        lr_schedulers: List[Any],
    ) -> Tuple[Module, List[Optimizer], List[Any]]:
        """Connects the precision plugin to the training process,
        configures apex and reinits the schedulers
        """
        if model.device.type != "cuda":
            return model, optimizers, lr_schedulers
        model, optimizers = self.configure_apex(amp, model, list(optimizers), self.amp_level)
        self.reinit_scheduler_properties(optimizers, lr_schedulers)
        return model, optimizers, lr_schedulers

    def dispatch(self, trainer: "pl.Trainer") -> None:
        if not self._connected:
            accelerator = trainer.accelerator
            _, accelerator.optimizers = amp.initialize(
                trainer.lightning_module, accelerator.optimizers, opt_level=self.amp_level
            )
            self._connected = True
        return super().dispatch(trainer)

    def backward(
        self,
@@ -99,40 +94,6 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
        closure_loss = closure_loss.detach()
        return closure_loss

    def configure_apex(
        self,
        amp: Type,
        model: Module,
        optimizers: List[Optimizer],
        amp_level: str,
    ) -> Tuple[Module, List[Optimizer]]:
        r"""
        Override to init AMP your own way.
        Must return a model and list of optimizers.

        Args:
            amp: pointer to amp library object.
            model: pointer to current :class:`torch.nn.Module`.
            optimizers: list of optimizers passed in :meth:`configure_optimizers`.
            amp_level: AMP mode chosen ('O1', 'O2', etc...)

        Return:
            Apex wrapped model and optimizers

        Examples:
            .. code-block:: python

                # Default implementation used by Trainer.
                def configure_apex(self, amp, model, optimizers, amp_level):
                    model, optimizers = amp.initialize(
                        model, optimizers, opt_level=amp_level,
                    )

                    return model, optimizers
        """
        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
        return model, optimizers

    @staticmethod
    def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Sequence[Any]) -> None:
        """Reinitializes schedulers with correct properties"""
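The removed `configure_apex` docstring above shows the underlying apex call; the new `dispatch()` makes essentially the same `amp.initialize` call, just later, once the model already lives on its CUDA device. A hedged usage sketch of that call — it requires NVIDIA apex and a GPU, and the model and optimizer here are placeholders:

import torch
from apex import amp  # NVIDIA apex; only importable when apex is installed

model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# amp.initialize patches the model and optimizer for mixed precision. Under ddp the
# model reaches the GPU only after setup, which is why the plugin now waits until
# dispatch() to make this call instead of doing it in connect().
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")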
@@ -516,7 +516,8 @@ class Trainer(
        else:
            self.accelerator.start_training(self)

    def run_stage(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
    def run_stage(self):
        self.accelerator.dispatch(self)
        self.profile_connector.setup()

        if self.evaluating:
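`run_stage()` is entered once per Trainer entry point (`fit`, `test`, and so on), so `accelerator.dispatch(self)` fires for each of them; the `_connected` flag in the apex plugin is what keeps `amp.initialize` from running more than once. A simplified sketch of that guard, not the actual plugin:

class ApexLikePlugin:
    """Toy model of the one-time initialization guard."""

    def __init__(self) -> None:
        self._connected = False
        self.init_calls = 0

    def dispatch(self) -> None:
        if not self._connected:
            self.init_calls += 1  # stand-in for amp.initialize(...)
            self._connected = True


plugin = ApexLikePlugin()
for _stage in ("fit", "test"):  # e.g. trainer.fit(model) followed by trainer.test(model)
    plugin.dispatch()
assert plugin.init_calls == 1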
@@ -1,3 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from unittest import mock
@@ -37,7 +51,7 @@ class MyApexPlugin(ApexMixedPrecisionPlugin):
        pytest.param('native', False, NativeMixedPrecisionPlugin, marks=RunIf(amp_native=True)),
        pytest.param('native', True, MyNativeAMP, marks=RunIf(amp_native=True)),
        pytest.param('apex', False, ApexMixedPrecisionPlugin, marks=RunIf(amp_apex=True)),
        pytest.param('apex', True, MyApexPlugin, marks=RunIf(amp_apex=True))
        pytest.param('apex', True, MyApexPlugin, marks=RunIf(amp_apex=True)),
    ]
)
def test_amp_apex_ddp(
@@ -83,3 +97,47 @@ def test_amp_gradient_unscale(tmpdir, accum: int):
        accumulate_grad_batches=accum,
    )
    trainer.fit(model)


@RunIf(min_gpus=2, amp_apex=True, special=True)
@pytest.mark.parametrize("amp_level", ['O2'])
def test_amp_apex_ddp_fit(amp_level, tmpdir):

    class CustomBoringModel(BoringModel):

        def training_step(self, batch, batch_idx):
            assert self.layer.weight.dtype == torch.float16
            assert self.trainer.precision_plugin._connected
            return super().training_step(batch, batch_idx)

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        precision=16,
        amp_backend="apex",
        gpus=2,
        accelerator='ddp',
        plugins=ApexMixedPrecisionPlugin(amp_level=amp_level),
    )
    assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
    model = CustomBoringModel()
    trainer.fit(model)
    trainer.test(model)


@RunIf(min_gpus=2, amp_apex=True)
@pytest.mark.parametrize("amp_level", ['O2'])
def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        precision=16,
        amp_backend="apex",
        gpus=2,
        accelerator='ddp_spawn',
        plugins=ApexMixedPrecisionPlugin(amp_level=amp_level),
    )
    assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
    model = BoringModel()
    trainer.fit(model)
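For reference, a hedged sketch of the user-facing configuration these tests exercise — enabling apex AMP with `ddp` on this version of Lightning. It needs two GPUs and NVIDIA apex installed; `TinyModel` below is a minimal stand-in for the `BoringModel` test helper used above:

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import ApexMixedPrecisionPlugin


class TinyModel(LightningModule):
    """Minimal stand-in for the BoringModel used in the tests."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def train_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=2)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = Trainer(
    precision=16,
    amp_backend="apex",
    gpus=2,
    accelerator="ddp",
    plugins=ApexMixedPrecisionPlugin(amp_level="O2"),
    fast_dev_run=True,
)
trainer.fit(TinyModel())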
@@ -1,3 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock