fix setting batch_size attribute in batch_size finder (finishing PR #2523) (#3043)

* lightning attr fix

* revert refactor

* create test

* separate test

* changelog update

* tests

* revert

* Update pytorch_lightning/trainer/training_tricks.py

Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
Adrian Wälchli 2020-08-20 01:01:55 +02:00 committed by GitHub
parent ee4eae884b
commit 7b917de946
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 27 deletions

View File

@ -148,8 +148,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed gathering of results with tensors of varying shape ([#3020](https://github.com/PyTorchLightning/pytorch-lightning/pull/3020))
- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043))
- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
## [0.8.5] - 2020-07-09
### Added

View File

@ -24,9 +24,10 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_getattr, lightning_setattr
try:
from apex import amp
@ -158,11 +159,15 @@ class TrainerTrainingTricksMixin(ABC):
algorithm is terminated
"""
if not hasattr(model, batch_arg_name):
if not hasattr(model.hparams, batch_arg_name):
raise MisconfigurationException(
'Neither of `model.batch_size` and `model.hparams.batch_size` found.'
)
if not lightning_hasattr(model, batch_arg_name):
raise MisconfigurationException(
f'Field {batch_arg_name} not found in both `model` and `model.hparams`')
if hasattr(model, batch_arg_name) and hasattr(model, "hparams") and batch_arg_name in model.hparams:
rank_zero_warn(
f'Field `model.{batch_arg_name}` and `model.hparams.{batch_arg_name}` are mutually exclusive!'
f' `model.{batch_arg_name}` will be used as the initial batch size for scaling.'
f' If this is not the intended behavior, please remove either one.'
)
if hasattr(model.train_dataloader, 'patch_loader_code'):
raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
@ -268,15 +273,9 @@ def _adjust_batch_size(trainer,
"""
model = trainer.get_model()
if hasattr(model, batch_arg_name):
batch_size = getattr(model, batch_arg_name)
else:
batch_size = getattr(model.hparams, batch_arg_name)
batch_size = lightning_getattr(model, batch_arg_name)
if value:
if hasattr(model, batch_arg_name):
setattr(model, batch_arg_name, value)
else:
setattr(model.hparams, batch_arg_name, value)
lightning_setattr(model, batch_arg_name, value)
new_size = value
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
@ -284,7 +283,7 @@ def _adjust_batch_size(trainer,
new_size = int(batch_size * factor)
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
setattr(model.hparams, batch_arg_name, new_size)
lightning_setattr(model, batch_arg_name, new_size)
return new_size

View File

@ -196,28 +196,59 @@ def test_trainer_reset_correctly(tmpdir):
f'Attribute {key} was not reset correctly after learning rate finder'
@pytest.mark.parametrize('scale_arg', ['power', 'binsearch'])
def test_trainer_arg(tmpdir, scale_arg):
""" Check that trainer arg works with bool input. """
@pytest.mark.parametrize('scale_arg', ['power', 'binsearch', True])
def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg):
""" Test possible values for 'batch size auto scaling' Trainer argument. """
tutils.reset_seed()
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
before_batch_size = hparams.get('batch_size')
# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
auto_scale_batch_size=scale_arg,
)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=scale_arg)
trainer.fit(model)
after_batch_size = model.batch_size
assert before_batch_size != after_batch_size, \
'Batch size was not altered after running auto scaling of batch size'
@pytest.mark.parametrize('use_hparams', [True, False])
def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams):
""" Test that new batch size gets written to the correct hyperparameter attribute. """
tutils.reset_seed()
hparams = EvalModelTemplate.get_default_hparams()
before_batch_size = hparams.get('batch_size')
class HparamsEvalModelTemplate(EvalModelTemplate):
def dataloader(self, *args, **kwargs):
# artificially set batch_size so we can get a dataloader
# remove it immediately after, because we want only self.hparams.batch_size
setattr(self, "batch_size", before_batch_size)
dataloader = super().dataloader(*args, **kwargs)
del self.batch_size
return dataloader
model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate
model = model_class(**hparams)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
trainer.fit(model)
after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
assert before_batch_size != after_batch_size
def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir):
""" Test for a warning when model.batch_size and model.hparams.batch_size both present. """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
model.hparams = hparams
# now we have model.batch_size and model.hparams.batch_size
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, auto_scale_batch_size=True)
expected_message = "Field `model.batch_size` and `model.hparams.batch_size` are mutually exclusive!"
with pytest.warns(UserWarning, match=expected_message):
trainer.fit(model)
@pytest.mark.parametrize('scale_method', ['power', 'binsearch'])
def test_call_to_trainer_method(tmpdir, scale_method):
""" Test that calling the trainer method itself works. """