From 011235505544f95dacf42cffd545bb53dd4b6b15 Mon Sep 17 00:00:00 2001
From: LiJiezhi
Date: Wed, 26 Aug 2020 15:57:34 +0800
Subject: [PATCH] Update training_tricks.py (#3151)

* Update training_tricks.py

* pep

Co-authored-by: Rohit Gupta
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
---
 pytorch_lightning/trainer/training_tricks.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 04d77916d7..32d0c59434 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -213,6 +213,7 @@ class TrainerTrainingTricksMixin(ABC):
     def __scale_batch_dump_params(self):
         # Prevent going into infinite loop
         self.__dumped_params = {
+            'auto_lr_find': self.auto_lr_find,
             'max_steps': self.max_steps,
             'weights_summary': self.weights_summary,
             'logger': self.logger,
@@ -226,6 +227,7 @@ class TrainerTrainingTricksMixin(ABC):
 
     def __scale_batch_reset_params(self, model, steps_per_trial):
         self.auto_scale_batch_size = None  # prevent recursion
+        self.auto_lr_find = False  # avoid lr find being called multiple times
         self.max_steps = steps_per_trial  # take few steps
         self.weights_summary = None  # not needed before full run
         self.logger = DummyLogger()
@@ -237,6 +239,7 @@ class TrainerTrainingTricksMixin(ABC):
         self.model = model  # required for saving
 
     def __scale_batch_restore_params(self):
+        self.auto_lr_find = self.__dumped_params['auto_lr_find']
        self.max_steps = self.__dumped_params['max_steps']
        self.weights_summary = self.__dumped_params['weights_summary']
        self.logger = self.__dumped_params['logger']
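
Context for reviewers: a minimal sketch (not part of the patch) of the
configuration this change guards against. Before this fix, each short trial
fit launched by the batch size scaler could re-trigger the learning rate
finder, since `auto_lr_find` was never disabled during the trials. The flag
values and the `model` placeholder below are illustrative assumptions.

    # Hypothetical usage (assumed names/values): both tuners enabled on one
    # Trainer, the combination that previously re-ran the LR finder inside
    # every batch-size trial.
    import pytorch_lightning as pl

    trainer = pl.Trainer(
        auto_scale_batch_size='power',  # search for a batch size that fits in memory
        auto_lr_find=True,              # also run the learning rate finder
        max_epochs=3,
    )
    # trainer.fit(model)  # `model` stands in for a user LightningModule

With the patch applied, `__scale_batch_reset_params` sets `auto_lr_find` to
False for the duration of the trials, and `__scale_batch_restore_params`
restores the user's original value afterwards.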