Fix amp autocast (#6080)
* precision fixes
* add amp test model
* fix test
* revert
* move assert to training step
* fix test
* fix test
* remove unrelated changes
* add changelog
* remove unused import
This commit is contained in:
parent
0b271474e5
commit
4b7c0fae00
|
@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Fixed
|
||||
|
||||
- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
|
||||
|
||||
|
||||
## [1.2.0] - 2021-02-18
|
||||
|
||||
|
|
|
@ -91,4 +91,5 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
|
|||
@contextmanager
def train_step_context(self) -> Generator[None, None, None]:
    """Enable the CUDA AMP autocast context for the duration of the training step.

    Yields:
        None. The bare ``yield`` runs inside ``torch.cuda.amp.autocast()`` so the
        wrapped training step executes with autocast enabled; the context is
        correctly exited afterwards even if the step raises.
    """
    # Entering the autocast context here (rather than yielding the unentered
    # context-manager object) is the actual fix: the caller's code must run
    # *inside* autocast. A bare yield produces None, hence the annotation
    # Generator[None, None, None] — not Generator[autocast, ...].
    with torch.cuda.amp.autocast():
        yield
|
||||
|
|
|
@ -27,6 +27,16 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
|||
from tests.helpers import BoringModel
|
||||
|
||||
|
||||
class AMPTestModel(BoringModel):
    """BoringModel variant asserting that AMP autocast is active during training."""

    def training_step(self, batch, batch_idx):
        # With precision=16, autocast must be enabled inside the training step.
        assert torch.is_autocast_enabled()
        prediction = self(batch)
        # Under autocast the forward output is expected to be half precision.
        assert prediction.dtype == torch.float16
        return {"loss": self.loss(batch, prediction)}
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='dp + amp not supported currently') # TODO
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
|
||||
def test_amp_single_gpu_dp(tmpdir):
|
||||
|
@ -41,7 +51,7 @@ def test_amp_single_gpu_dp(tmpdir):
|
|||
precision=16,
|
||||
)
|
||||
|
||||
model = BoringModel()
|
||||
model = AMPTestModel()
|
||||
# tutils.run_model_test(trainer_options, model)
|
||||
trainer.fit(model)
|
||||
|
||||
|
@ -60,10 +70,9 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
|
|||
precision=16,
|
||||
)
|
||||
|
||||
model = BoringModel()
|
||||
model = AMPTestModel()
|
||||
# tutils.run_model_test(trainer_options, model)
|
||||
trainer.fit(model)
|
||||
|
||||
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
|
||||
|
||||
|
||||
|
@ -81,7 +90,7 @@ def test_amp_multi_gpu_dp(tmpdir):
|
|||
precision=16,
|
||||
)
|
||||
|
||||
model = BoringModel()
|
||||
model = AMPTestModel()
|
||||
# tutils.run_model_test(trainer_options, model)
|
||||
trainer.fit(model)
|
||||
|
||||
|
@ -100,10 +109,9 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
|
|||
precision=16,
|
||||
)
|
||||
|
||||
model = BoringModel()
|
||||
model = AMPTestModel()
|
||||
# tutils.run_model_test(trainer_options, model)
|
||||
trainer.fit(model)
|
||||
|
||||
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
|
||||
|
||||
|
||||
|
@ -122,7 +130,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
|
|||
# simulate setting slurm flags
|
||||
tutils.set_random_master_port()
|
||||
|
||||
model = BoringModel()
|
||||
model = AMPTestModel()
|
||||
|
||||
# exp file to get meta
|
||||
logger = tutils.get_default_logger(tmpdir)
|
||||
|
|
Loading…
Reference in New Issue