diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 04d77916d7..32d0c59434 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -213,6 +213,7 @@ class TrainerTrainingTricksMixin(ABC):
     def __scale_batch_dump_params(self):
         # Prevent going into infinite loop
         self.__dumped_params = {
+            'auto_lr_find': self.auto_lr_find,
             'max_steps': self.max_steps,
             'weights_summary': self.weights_summary,
             'logger': self.logger,
@@ -226,6 +227,7 @@ class TrainerTrainingTricksMixin(ABC):

     def __scale_batch_reset_params(self, model, steps_per_trial):
         self.auto_scale_batch_size = None  # prevent recursion
+        self.auto_lr_find = False  # avoid lr find being called multiple times
         self.max_steps = steps_per_trial  # take few steps
         self.weights_summary = None  # not needed before full run
         self.logger = DummyLogger()
@@ -237,6 +239,7 @@ class TrainerTrainingTricksMixin(ABC):
         self.model = model  # required for saving

     def __scale_batch_restore_params(self):
+        self.auto_lr_find = self.__dumped_params['auto_lr_find']
         self.max_steps = self.__dumped_params['max_steps']
         self.weights_summary = self.__dumped_params['weights_summary']
         self.logger = self.__dumped_params['logger']
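
For context, a minimal self-contained sketch of the dump/reset/restore pattern this patch extends: the batch-size scaler snapshots every trainer flag its short trial fits will mutate, flips them off for the trials, and restores them afterwards. Without the new 'auto_lr_find' entries, each trial fit() would re-enter the learning-rate finder. The class below and its scale_batch_size driver are illustrative stand-ins, not Lightning's actual implementation; only the three helper roles come from the diff.

class BatchSizeFinderSketch:
    """Illustrative stand-in for TrainerTrainingTricksMixin (hypothetical)."""

    def __init__(self):
        self.auto_lr_find = True               # user requested the LR finder
        self.auto_scale_batch_size = 'binsearch'
        self.max_steps = None

    def __dump_params(self):
        # Snapshot every flag the trials will mutate, so restore is exact.
        self.__dumped_params = {
            'auto_lr_find': self.auto_lr_find,
            'auto_scale_batch_size': self.auto_scale_batch_size,
            'max_steps': self.max_steps,
        }

    def __reset_params(self, steps_per_trial):
        self.auto_scale_batch_size = None      # prevent recursion into the scaler
        self.auto_lr_find = False              # the fix: trials must not trigger lr find
        self.max_steps = steps_per_trial       # take only a few steps per trial

    def __restore_params(self):
        for name, value in self.__dumped_params.items():
            setattr(self, name, value)
        del self.__dumped_params

    def scale_batch_size(self, steps_per_trial=3):
        self.__dump_params()
        self.__reset_params(steps_per_trial)
        try:
            pass  # run short fit() trials at candidate batch sizes here
        finally:
            self.__restore_params()            # put auto_lr_find etc. back


finder = BatchSizeFinderSketch()
finder.scale_batch_size()
assert finder.auto_lr_find is True            # flag survives the trials intact

The try/finally mirrors why a snapshot dict is preferable to hard-coded defaults: whatever the user configured is put back verbatim, even if a trial raises.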