Update training_tricks.py (#3151)

* Update training_tricks.py

* pep

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
LiJiezhi 2020-08-26 15:57:34 +08:00 committed by GitHub
parent cb0c60bf7a
commit 0112355055
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 0 deletions

View File

@ -213,6 +213,7 @@ class TrainerTrainingTricksMixin(ABC):
def __scale_batch_dump_params(self): def __scale_batch_dump_params(self):
# Prevent going into infinite loop # Prevent going into infinite loop
self.__dumped_params = { self.__dumped_params = {
'auto_lr_find': self.auto_lr_find,
'max_steps': self.max_steps, 'max_steps': self.max_steps,
'weights_summary': self.weights_summary, 'weights_summary': self.weights_summary,
'logger': self.logger, 'logger': self.logger,
@ -226,6 +227,7 @@ class TrainerTrainingTricksMixin(ABC):
def __scale_batch_reset_params(self, model, steps_per_trial): def __scale_batch_reset_params(self, model, steps_per_trial):
self.auto_scale_batch_size = None # prevent recursion self.auto_scale_batch_size = None # prevent recursion
self.auto_lr_find = False # avoid lr find being called multiple times
self.max_steps = steps_per_trial # take few steps self.max_steps = steps_per_trial # take few steps
self.weights_summary = None # not needed before full run self.weights_summary = None # not needed before full run
self.logger = DummyLogger() self.logger = DummyLogger()
@ -237,6 +239,7 @@ class TrainerTrainingTricksMixin(ABC):
self.model = model # required for saving self.model = model # required for saving
def __scale_batch_restore_params(self): def __scale_batch_restore_params(self):
self.auto_lr_find = self.__dumped_params['auto_lr_find']
self.max_steps = self.__dumped_params['max_steps'] self.max_steps = self.__dumped_params['max_steps']
self.weights_summary = self.__dumped_params['weights_summary'] self.weights_summary = self.__dumped_params['weights_summary']
self.logger = self.__dumped_params['logger'] self.logger = self.__dumped_params['logger']