diff --git a/CHANGELOG.md b/CHANGELOG.md
index 131f8adbce..abd870c1f9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -148,6 +148,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020))
 
+- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
+
 ## [0.8.5] - 2020-07-09
 
 ### Added
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 3b03d14dbd..d103ee18fb 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -354,7 +354,7 @@ class TrainerIOMixin(ABC):
         checkpoint['lr_schedulers'] = lr_schedulers
 
         # save native amp scaling
-        if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
+        if self.amp_backend == AMPType.NATIVE and not self.use_tpu and self.scaler is not None:
             checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
         elif self.amp_backend == AMPType.APEX:
             checkpoint['amp_scaling_state'] = amp.state_dict()
diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py
index 7e6665d2be..aca8a458ac 100755
--- a/tests/trainer/test_trainer_tricks.py
+++ b/tests/trainer/test_trainer_tricks.py
@@ -5,6 +5,7 @@ from torch.utils.data import RandomSampler, SequentialSampler, DataLoader
 
 import tests.base.develop_utils as tutils
 from pytorch_lightning import Trainer
+from pytorch_lightning.utilities import AMPType, NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 
@@ -257,3 +258,22 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
 
     with pytest.raises(MisconfigurationException):
         trainer.fit(model, **fit_options)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="test requires native AMP.")
+def test_auto_scale_batch_size_with_amp(tmpdir):
+    model = EvalModelTemplate()
+    batch_size_before = model.batch_size
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        auto_scale_batch_size=True,
+        gpus=1,
+        precision=16
+    )
+    trainer.fit(model)
+    batch_size_after = model.batch_size
+    assert trainer.amp_backend == AMPType.NATIVE
+    assert trainer.scaler is not None
+    assert batch_size_after != batch_size_before