Fix amp autocast (#6080)

* precision fixes

* add amp test model

* fix test

* revert

* move assert to training step

* fix test

* fix test

* remove unrelated changes

* add changelog

* remove unused import
This commit is contained in:
Adrian Wälchli 2021-02-19 18:00:27 +01:00 committed by GitHub
parent 0b271474e5
commit 4b7c0fae00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 8 deletions

View File

@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed
- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
## [1.2.0] - 2021-02-18

View File

@ -91,4 +91,5 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
@contextmanager
def train_step_context(self) -> Generator[autocast, None, None]:
"""Enable autocast context"""
yield torch.cuda.amp.autocast()
with torch.cuda.amp.autocast():
yield

View File

@ -27,6 +27,16 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
class AMPTestModel(BoringModel):
def training_step(self, batch, batch_idx):
assert torch.is_autocast_enabled()
output = self(batch)
assert output.dtype == torch.float16
loss = self.loss(batch, output)
return {"loss": loss}
@pytest.mark.skip(reason='dp + amp not supported currently') # TODO
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_amp_single_gpu_dp(tmpdir):
@ -41,7 +51,7 @@ def test_amp_single_gpu_dp(tmpdir):
precision=16,
)
model = BoringModel()
model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)
@ -60,10 +70,9 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
precision=16,
)
model = BoringModel()
model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
@ -81,7 +90,7 @@ def test_amp_multi_gpu_dp(tmpdir):
precision=16,
)
model = BoringModel()
model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)
@ -100,10 +109,9 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
precision=16,
)
model = BoringModel()
model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
@ -122,7 +130,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
# simulate setting slurm flags
tutils.set_random_master_port()
model = BoringModel()
model = AMPTestModel()
# exp file to get meta
logger = tutils.get_default_logger(tmpdir)