diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 8172152029..f69b6230aa 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -136,7 +136,10 @@ class TrainerTrainingTricksMixin(ABC):
         """
         if not hasattr(model, batch_arg_name):
-            raise MisconfigurationException(f'Field {batch_arg_name} not found in `model.hparams`')
+            if not hasattr(model.hparams, batch_arg_name):
+                raise MisconfigurationException(
+                    'Neither of `model.batch_size` and `model.hparams.batch_size` found.'
+                )
 
         if hasattr(model.train_dataloader, 'patch_loader_code'):
             raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
@@ -245,9 +248,15 @@ def _adjust_batch_size(trainer,
     """
     model = trainer.get_model()
-    batch_size = getattr(model, batch_arg_name)
+    if hasattr(model, batch_arg_name):
+        batch_size = getattr(model, batch_arg_name)
+    else:
+        batch_size = getattr(model.hparams, batch_arg_name)
     if value:
-        setattr(model, batch_arg_name, value)
+        if hasattr(model, batch_arg_name):
+            setattr(model, batch_arg_name, value)
+        else:
+            setattr(model.hparams, batch_arg_name, value)
         new_size = value
         if desc:
             log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
@@ -255,7 +264,7 @@
         new_size = int(batch_size * factor)
         if desc:
             log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
-        setattr(model, batch_arg_name, new_size)
+        setattr(model.hparams, batch_arg_name, new_size)
     return new_size
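
For illustration only, a minimal sketch of the lookup order this patch introduces: the batch size is read from the attribute named by `batch_arg_name` directly on the model when present, otherwise from `model.hparams`. The classes `ModelWithAttr` and `ModelWithHparams` and the helper `resolve_batch_size` are hypothetical stand-ins for a LightningModule, not code from the PR.

    from argparse import Namespace


    class ModelWithAttr:
        """Batch size stored directly on the module (hypothetical example)."""
        def __init__(self):
            self.batch_size = 32
            self.hparams = Namespace()


    class ModelWithHparams:
        """Batch size stored only inside model.hparams (hypothetical example)."""
        def __init__(self):
            self.hparams = Namespace(batch_size=32)


    def resolve_batch_size(model, batch_arg_name='batch_size'):
        # Mirror of the fallback order in the patch: prefer the direct
        # attribute, then fall back to model.hparams.
        if hasattr(model, batch_arg_name):
            return getattr(model, batch_arg_name)
        if hasattr(model.hparams, batch_arg_name):
            return getattr(model.hparams, batch_arg_name)
        raise AttributeError(
            f'Neither `model.{batch_arg_name}` nor `model.hparams.{batch_arg_name}` found.'
        )


    print(resolve_batch_size(ModelWithAttr()))     # 32, read from the attribute
    print(resolve_batch_size(ModelWithHparams()))  # 32, read from hparams

Both storage styles resolve to the same value, which is the behaviour the patched `scale_batch_size` check and `_adjust_batch_size` rely on.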