From 89a5d8fee94fe9ac4989d3bc9cc9c7ca58a781d1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 19 Aug 2020 22:41:33 +0200
Subject: [PATCH] fix auto scale batch size not working with precision=16
 (#3045)

* add test

* test

* test

* add fix

* changelog

* check batch size changed
---
 CHANGELOG.md                             |  2 ++
 pytorch_lightning/trainer/training_io.py |  2 +-
 tests/trainer/test_trainer_tricks.py     | 20 ++++++++++++++++++++
 3 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 131f8adbce..abd870c1f9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -148,6 +148,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020))
 
+- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
+
 ## [0.8.5] - 2020-07-09
 
 ### Added
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 3b03d14dbd..d103ee18fb 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -354,7 +354,7 @@ class TrainerIOMixin(ABC):
         checkpoint['lr_schedulers'] = lr_schedulers
 
         # save native amp scaling
-        if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
+        if self.amp_backend == AMPType.NATIVE and not self.use_tpu and self.scaler is not None:
             checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
         elif self.amp_backend == AMPType.APEX:
             checkpoint['amp_scaling_state'] = amp.state_dict()
diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py
index 7e6665d2be..aca8a458ac 100755
--- a/tests/trainer/test_trainer_tricks.py
+++ b/tests/trainer/test_trainer_tricks.py
@@ -5,6 +5,7 @@ from torch.utils.data import RandomSampler, SequentialSampler, DataLoader
 
 import tests.base.develop_utils as tutils
 from pytorch_lightning import Trainer
+from pytorch_lightning.utilities import AMPType, NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 
@@ -257,3 +258,22 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
 
     with pytest.raises(MisconfigurationException):
         trainer.fit(model, **fit_options)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="test requires native AMP.")
+def test_auto_scale_batch_size_with_amp(tmpdir):
+    model = EvalModelTemplate()
+    batch_size_before = model.batch_size
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        auto_scale_batch_size=True,
+        gpus=1,
+        precision=16
+    )
+    trainer.fit(model)
+    batch_size_after = model.batch_size
+    assert trainer.amp_backend == AMPType.NATIVE
+    assert trainer.scaler is not None
+    assert batch_size_after != batch_size_before