Misleading exception raised during batch scaling (#1973)

* Misleading exception raised during batch scaling

Use batch_size from `model.hparams.batch_size` instead of `model.batch_size`

* Improvements considering #1896

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
Tejasvi S Tomar 2020-06-17 17:31:04 +05:30 committed by GitHub
parent e1f238a097
commit f8103f9c7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 13 additions and 4 deletions

View File

@ -136,7 +136,10 @@ class TrainerTrainingTricksMixin(ABC):
"""
if not hasattr(model, batch_arg_name):
raise MisconfigurationException(f'Field {batch_arg_name} not found in `model.hparams`')
if not hasattr(model.hparams, batch_arg_name):
raise MisconfigurationException(
'Neither of `model.batch_size` and `model.hparams.batch_size` found.'
)
if hasattr(model.train_dataloader, 'patch_loader_code'):
raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
@ -245,9 +248,15 @@ def _adjust_batch_size(trainer,
"""
model = trainer.get_model()
if hasattr(model, batch_arg_name):
batch_size = getattr(model, batch_arg_name)
else:
batch_size = getattr(model.hparams, batch_arg_name)
if value:
if hasattr(model, batch_arg_name):
setattr(model, batch_arg_name, value)
else:
setattr(model.hparams, batch_arg_name, value)
new_size = value
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
@ -255,7 +264,7 @@ def _adjust_batch_size(trainer,
new_size = int(batch_size * factor)
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
setattr(model, batch_arg_name, new_size)
setattr(model.hparams, batch_arg_name, new_size)
return new_size