fix auto scale batch size not working with precision=16 (#3045)
* add test * test * test * add fix * changelog * check batch size changed
This commit is contained in:
parent
9031dc3b81
commit
89a5d8fee9
|
@ -148,6 +148,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020))
|
||||
|
||||
- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
|
||||
|
||||
## [0.8.5] - 2020-07-09
|
||||
|
||||
### Added
|
||||
|
|
|
@ -354,7 +354,7 @@ class TrainerIOMixin(ABC):
|
|||
checkpoint['lr_schedulers'] = lr_schedulers
|
||||
|
||||
# save native amp scaling
|
||||
if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
|
||||
if self.amp_backend == AMPType.NATIVE and not self.use_tpu and self.scaler is not None:
|
||||
checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
|
||||
elif self.amp_backend == AMPType.APEX:
|
||||
checkpoint['amp_scaling_state'] = amp.state_dict()
|
||||
|
|
|
@ -5,6 +5,7 @@ from torch.utils.data import RandomSampler, SequentialSampler, DataLoader
|
|||
|
||||
import tests.base.develop_utils as tutils
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.utilities import AMPType, NATIVE_AMP_AVALAIBLE
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.base import EvalModelTemplate
|
||||
|
||||
|
@ -257,3 +258,22 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
|
|||
|
||||
with pytest.raises(MisconfigurationException):
|
||||
trainer.fit(model, **fit_options)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
|
||||
@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="test requires native AMP.")
|
||||
def test_auto_scale_batch_size_with_amp(tmpdir):
|
||||
model = EvalModelTemplate()
|
||||
batch_size_before = model.batch_size
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_steps=1,
|
||||
auto_scale_batch_size=True,
|
||||
gpus=1,
|
||||
precision=16
|
||||
)
|
||||
trainer.fit(model)
|
||||
batch_size_after = model.batch_size
|
||||
assert trainer.amp_backend == AMPType.NATIVE
|
||||
assert trainer.scaler is not None
|
||||
assert batch_size_after != batch_size_before
|
||||
|
|
Loading…
Reference in New Issue