From 4b7c0fae00084b72dffe37fdd0ea7d2e9b60d103 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 19 Feb 2021 18:00:27 +0100
Subject: [PATCH] Fix amp autocast (#6080)

* precision fixes
* add amp test model
* fix test
* revert
* move assert to training step
* fix test
* fix test
* remove unrelated changes
* add changelog
* remove unused import
---
 CHANGELOG.md                                     |  2 ++
 .../plugins/precision/native_amp.py              |  3 ++-
 tests/models/test_amp.py                         | 22 +++++++++++++------
 3 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2ad54381a0..7dad863d41 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
+
 
 ## [1.2.0] - 2021-02-18
 
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 60c0f5f846..94e6cf376b 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -91,4 +91,5 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
     @contextmanager
     def train_step_context(self) -> Generator[autocast, None, None]:
         """Enable autocast context"""
-        yield torch.cuda.amp.autocast()
+        with torch.cuda.amp.autocast():
+            yield
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 2dd6c9d997..53ec32764f 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -27,6 +27,16 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 
 
+class AMPTestModel(BoringModel):
+
+    def training_step(self, batch, batch_idx):
+        assert torch.is_autocast_enabled()
+        output = self(batch)
+        assert output.dtype == torch.float16
+        loss = self.loss(batch, output)
+        return {"loss": loss}
+
+
 @pytest.mark.skip(reason='dp + amp not supported currently')  # TODO
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_amp_single_gpu_dp(tmpdir):
@@ -41,7 +51,7 @@ def test_amp_single_gpu_dp(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
 
@@ -60,10 +70,9 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
-
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
@@ -81,7 +90,7 @@ def test_amp_multi_gpu_dp(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
 
@@ -100,10 +109,9 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
-
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
@@ -122,7 +130,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
     # simulate setting slurm flags
     tutils.set_random_master_port()
 
-    model = BoringModel()
+    model = AMPTestModel()
 
     # exp file to get meta
     logger = tutils.get_default_logger(tmpdir)
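
Note for reviewers, summarizing the root cause (commentary, not part of the patch): a `contextlib.contextmanager` generator must yield from *inside* the context it manages. The old code yielded the `autocast()` object itself, which constructs the context manager but never calls its `__enter__`, so the training step never actually ran under autocast. A minimal standalone sketch of the broken vs. fixed pattern follows; the helper names are illustrative, and the assertions assume a CUDA-capable machine, since `torch.cuda.amp.autocast` has no effect without CUDA:

    from contextlib import contextmanager
    from typing import Generator

    import torch

    @contextmanager
    def broken_context() -> Generator:
        # Bug: constructs the autocast object but never enters it,
        # so autocast stays disabled in the caller's `with` body.
        yield torch.cuda.amp.autocast()

    @contextmanager
    def fixed_context() -> Generator[None, None, None]:
        # Fix: enter autocast first, then yield, so the caller's
        # `with` body executes while autocast is active.
        with torch.cuda.amp.autocast():
            yield

    with broken_context():
        assert not torch.is_autocast_enabled()  # autocast was never enabled

    with fixed_context():
        assert torch.is_autocast_enabled()  # autocast active as intended

This is also why the new `AMPTestModel` asserts `torch.is_autocast_enabled()` and a `float16` output inside `training_step`: it exercises exactly the code path that the old yield pattern silently skipped.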