fix auto scale batch size not working with precision=16 (#3045)

* add test

* test

* add fix

* changelog

* check batch size changed
Adrian Wälchli, 2020-08-19 22:41:33 +02:00 (committed by GitHub)
parent 9031dc3b81
commit 89a5d8fee9
3 changed files with 23 additions and 1 deletion


@@ -148,6 +148,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020))
+- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))

 ## [0.8.5] - 2020-07-09

 ### Added


@@ -354,7 +354,7 @@ class TrainerIOMixin(ABC):
         checkpoint['lr_schedulers'] = lr_schedulers

         # save native amp scaling
-        if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
+        if self.amp_backend == AMPType.NATIVE and not self.use_tpu and self.scaler is not None:
             checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
         elif self.amp_backend == AMPType.APEX:
             checkpoint['amp_scaling_state'] = amp.state_dict()
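
Why the guard is needed, in a minimal self-contained sketch (a hypothetical stand-in for the real TrainerIOMixin, not Lightning's actual code): with the native AMP backend the GradScaler is created lazily, so when the batch-size finder dumps a checkpoint before training has been set up, `self.scaler` is still `None` and the unguarded call would raise `AttributeError: 'NoneType' object has no attribute 'state_dict'`.

    # Minimal sketch, assuming a simplified trainer (hypothetical; not the real
    # TrainerIOMixin): the scaler is created lazily, so a checkpoint dumped by
    # the batch-size finder before training starts must tolerate `scaler is None`.
    class MiniTrainer:
        def __init__(self, precision=32):
            self.use_native_amp = precision == 16
            self.scaler = None  # created later, e.g. torch.cuda.amp.GradScaler()

        def dump_checkpoint(self):
            checkpoint = {}
            # pre-fix, this condition lacked the `is not None` check and crashed
            # when the batch-size finder checkpointed before training began
            if self.use_native_amp and self.scaler is not None:
                checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
            return checkpoint

    # The batch-size finder checkpoints before any optimizer step has run:
    assert 'native_amp_scaling_state' not in MiniTrainer(precision=16).dump_checkpoint()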


@@ -5,6 +5,7 @@ from torch.utils.data import RandomSampler, SequentialSampler, DataLoader

 import tests.base.develop_utils as tutils
 from pytorch_lightning import Trainer
+from pytorch_lightning.utilities import AMPType, NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
@@ -257,3 +258,22 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
     with pytest.raises(MisconfigurationException):
         trainer.fit(model, **fit_options)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="test requires native AMP.")
+def test_auto_scale_batch_size_with_amp(tmpdir):
+    model = EvalModelTemplate()
+    batch_size_before = model.batch_size
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        auto_scale_batch_size=True,
+        gpus=1,
+        precision=16
+    )
+    trainer.fit(model)
+    batch_size_after = model.batch_size
+    assert trainer.amp_backend == AMPType.NATIVE
+    assert trainer.scaler is not None
+    assert batch_size_after != batch_size_before
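
For reference, a usage sketch of the now-working combination, mirroring the test above (hedged: this is the 0.9-era API; later Lightning versions moved batch-size scaling behind `trainer.tune`, and the `gpus`/`precision` arguments changed in later releases):

    from pytorch_lightning import Trainer
    from tests.base import EvalModelTemplate

    # Any module whose batch size the finder can adjust; EvalModelTemplate
    # exposes it as `model.batch_size`.
    model = EvalModelTemplate()
    trainer = Trainer(
        auto_scale_batch_size=True,  # run the batch-size finder before training
        gpus=1,
        precision=16,                # native AMP; this combination crashed pre-fix
        max_steps=1,
    )
    trainer.fit(model)               # the finder now checkpoints/restores safely
    print(model.batch_size)          # largest batch size that fit in memory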