From f8103f9c7dfc35b4198e951a1789cae534c8b1db Mon Sep 17 00:00:00 2001
From: Tejasvi S Tomar <45873379+tejasvi@users.noreply.github.com>
Date: Wed, 17 Jun 2020 17:31:04 +0530
Subject: [PATCH] Misleading exception raised during batch scaling (#1973)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Misleading exception raised during batch scaling

Use the batch size from `model.hparams.batch_size` instead of `model.batch_size`

* Improvements considering #1896

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli
Co-authored-by: Rohit Gupta
Co-authored-by: Jirka Borovec
---
 pytorch_lightning/trainer/training_tricks.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 8172152029..f69b6230aa 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -136,7 +136,10 @@ class TrainerTrainingTricksMixin(ABC):
 
         """
         if not hasattr(model, batch_arg_name):
-            raise MisconfigurationException(f'Field {batch_arg_name} not found in `model.hparams`')
+            if not hasattr(model.hparams, batch_arg_name):
+                raise MisconfigurationException(
+                    'Neither `model.batch_size` nor `model.hparams.batch_size` found.'
+                )
 
         if hasattr(model.train_dataloader, 'patch_loader_code'):
             raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
@@ -245,9 +248,15 @@ def _adjust_batch_size(trainer,
 
     """
     model = trainer.get_model()
-    batch_size = getattr(model, batch_arg_name)
+    if hasattr(model, batch_arg_name):
+        batch_size = getattr(model, batch_arg_name)
+    else:
+        batch_size = getattr(model.hparams, batch_arg_name)
     if value:
-        setattr(model, batch_arg_name, value)
+        if hasattr(model, batch_arg_name):
+            setattr(model, batch_arg_name, value)
+        else:
+            setattr(model.hparams, batch_arg_name, value)
         new_size = value
         if desc:
             log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
@@ -255,7 +264,7 @@ def _adjust_batch_size(trainer,
         new_size = int(batch_size * factor)
         if desc:
             log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
-        setattr(model, batch_arg_name, new_size)
+        setattr(model.hparams, batch_arg_name, new_size)
     return new_size
 
 
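
The change above resolves the batch size wherever it actually lives: directly on the module first, then on `model.hparams`, raising only when both lookups fail. The sketch below is a minimal, self-contained illustration of that lookup-then-fallback pattern; `ToyModel`, `resolve_batch_size`, and `set_batch_size` are hypothetical names used only for illustration, and `argparse.Namespace` stands in for Lightning's hparams container.

```python
# Minimal sketch (not part of the patch) of the attribute-resolution order
# this patch introduces: prefer `model.<name>`, fall back to
# `model.hparams.<name>`, and raise otherwise.
from argparse import Namespace


class ToyModel:
    def __init__(self):
        # The batch size lives only on hparams, as in the misleading-exception
        # scenario this patch fixes.
        self.hparams = Namespace(batch_size=32)


def resolve_batch_size(model, batch_arg_name: str = 'batch_size') -> int:
    # Read path: module attribute first, then the hparams container.
    if hasattr(model, batch_arg_name):
        return getattr(model, batch_arg_name)
    if hasattr(model.hparams, batch_arg_name):
        return getattr(model.hparams, batch_arg_name)
    raise ValueError(
        f'Neither `model.{batch_arg_name}` nor `model.hparams.{batch_arg_name}` found.'
    )


def set_batch_size(model, value: int, batch_arg_name: str = 'batch_size') -> None:
    # Write path: mirror the read path so the value lands wherever it was found.
    if hasattr(model, batch_arg_name):
        setattr(model, batch_arg_name, value)
    else:
        setattr(model.hparams, batch_arg_name, value)


model = ToyModel()
assert resolve_batch_size(model) == 32  # found via model.hparams
set_batch_size(model, 64)               # e.g. the scaler doubled it
assert model.hparams.batch_size == 64
```

Mirroring the read path in the write path keeps reads and writes consistent regardless of whether the user stored the batch size on the module or on its hparams.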