Misleading exception raised during batch scaling (#1973)
* Misleading exception raised during batch scaling: use the batch size from `model.hparams.batch_size` instead of `model.batch_size`
* Improvements considering #1896
* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
parent e1f238a097
commit f8103f9c7d
```diff
@@ -136,7 +136,10 @@ class TrainerTrainingTricksMixin(ABC):
         """
         if not hasattr(model, batch_arg_name):
-            raise MisconfigurationException(f'Field {batch_arg_name} not found in `model.hparams`')
+            if not hasattr(model.hparams, batch_arg_name):
+                raise MisconfigurationException(
+                    'Neither of `model.batch_size` and `model.hparams.batch_size` found.'
+                )

         if hasattr(model.train_dataloader, 'patch_loader_code'):
             raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
```
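To see the effect of the new check, here is a minimal sketch using stand-in objects rather than real LightningModules; the helper name `check_batch_arg` is illustrative, not part of the commit, and the `MisconfigurationException` import path is assumed for this release line:

```python
from types import SimpleNamespace

from pytorch_lightning.utilities.exceptions import MisconfigurationException


def check_batch_arg(model, batch_arg_name='batch_size'):
    # Same shape as the committed check: raise only when the field is
    # missing from both the model and model.hparams.
    if not hasattr(model, batch_arg_name):
        if not hasattr(model.hparams, batch_arg_name):
            raise MisconfigurationException(
                'Neither of `model.batch_size` and `model.hparams.batch_size` found.'
            )


check_batch_arg(SimpleNamespace(batch_size=32))                            # ok: model.batch_size
check_batch_arg(SimpleNamespace(hparams=SimpleNamespace(batch_size=32)))   # ok: model.hparams.batch_size
check_batch_arg(SimpleNamespace(hparams=SimpleNamespace()))                # raises the new, accurate message
```

Before this change, a model that kept `batch_size` only in `hparams` hit the misleading message `Field batch_size not found in 'model.hparams'` even though the field was there; the check now accepts either location and only complains when both are missing.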
```diff
@@ -245,9 +248,15 @@ def _adjust_batch_size(trainer,
     """
     model = trainer.get_model()
-    batch_size = getattr(model, batch_arg_name)
+    if hasattr(model, batch_arg_name):
+        batch_size = getattr(model, batch_arg_name)
+    else:
+        batch_size = getattr(model.hparams, batch_arg_name)
     if value:
-        setattr(model, batch_arg_name, value)
+        if hasattr(model, batch_arg_name):
+            setattr(model, batch_arg_name, value)
+        else:
+            setattr(model.hparams, batch_arg_name, value)
         new_size = value
         if desc:
             log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
```
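The read/write fallback introduced here, restated as a self-contained runnable sketch; the helper names `_get_batch_size` and `_set_batch_size` are illustrative, not from the commit:

```python
from types import SimpleNamespace


def _get_batch_size(model, batch_arg_name='batch_size'):
    # Mirrors the committed lookup order: the attribute on the model wins,
    # otherwise fall back to the same field on model.hparams.
    if hasattr(model, batch_arg_name):
        return getattr(model, batch_arg_name)
    return getattr(model.hparams, batch_arg_name)


def _set_batch_size(model, value, batch_arg_name='batch_size'):
    # Write back to wherever the field actually lives.
    if hasattr(model, batch_arg_name):
        setattr(model, batch_arg_name, value)
    else:
        setattr(model.hparams, batch_arg_name, value)


# Demo with stand-in objects covering both layouts:
direct = SimpleNamespace(batch_size=32)                            # model.batch_size
nested = SimpleNamespace(hparams=SimpleNamespace(batch_size=32))   # model.hparams.batch_size

for model in (direct, nested):
    _set_batch_size(model, _get_batch_size(model) * 2)
    print(_get_batch_size(model))  # prints 64 for both layouts
```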
```diff
@@ -255,7 +264,7 @@ def _adjust_batch_size(trainer,
         new_size = int(batch_size * factor)
         if desc:
             log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
-        setattr(model, batch_arg_name, new_size)
+        setattr(model.hparams, batch_arg_name, new_size)
     return new_size
```
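These paths are exercised through the batch-size finder on the trainer. A hedged usage sketch, assuming the `scale_batch_size` entry point of `TrainerTrainingTricksMixin` from this release line, with `MyLitModel` as a placeholder for any LightningModule:

```python
from pytorch_lightning import Trainer

model = MyLitModel()  # placeholder: keeps batch_size on itself or in hparams
trainer = Trainer()

# After this commit, the finder works whether the field lives at
# model.batch_size or model.hparams.batch_size, and the adjusted
# size is written back to the matching location.
new_size = trainer.scale_batch_size(model)
trainer.fit(model)
```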