Update training_tricks.py (#3151)
* Update training_tricks.py * pep Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
parent
cb0c60bf7a
commit
0112355055
|
@ -213,6 +213,7 @@ class TrainerTrainingTricksMixin(ABC):
|
||||||
def __scale_batch_dump_params(self):
|
def __scale_batch_dump_params(self):
|
||||||
# Prevent going into infinite loop
|
# Prevent going into infinite loop
|
||||||
self.__dumped_params = {
|
self.__dumped_params = {
|
||||||
|
'auto_lr_find': self.auto_lr_find,
|
||||||
'max_steps': self.max_steps,
|
'max_steps': self.max_steps,
|
||||||
'weights_summary': self.weights_summary,
|
'weights_summary': self.weights_summary,
|
||||||
'logger': self.logger,
|
'logger': self.logger,
|
||||||
|
@ -226,6 +227,7 @@ class TrainerTrainingTricksMixin(ABC):
|
||||||
|
|
||||||
def __scale_batch_reset_params(self, model, steps_per_trial):
|
def __scale_batch_reset_params(self, model, steps_per_trial):
|
||||||
self.auto_scale_batch_size = None # prevent recursion
|
self.auto_scale_batch_size = None # prevent recursion
|
||||||
|
self.auto_lr_find = False # avoid lr find being called multiple times
|
||||||
self.max_steps = steps_per_trial # take few steps
|
self.max_steps = steps_per_trial # take few steps
|
||||||
self.weights_summary = None # not needed before full run
|
self.weights_summary = None # not needed before full run
|
||||||
self.logger = DummyLogger()
|
self.logger = DummyLogger()
|
||||||
|
@ -237,6 +239,7 @@ class TrainerTrainingTricksMixin(ABC):
|
||||||
self.model = model # required for saving
|
self.model = model # required for saving
|
||||||
|
|
||||||
def __scale_batch_restore_params(self):
|
def __scale_batch_restore_params(self):
|
||||||
|
self.auto_lr_find = self.__dumped_params['auto_lr_find']
|
||||||
self.max_steps = self.__dumped_params['max_steps']
|
self.max_steps = self.__dumped_params['max_steps']
|
||||||
self.weights_summary = self.__dumped_params['weights_summary']
|
self.weights_summary = self.__dumped_params['weights_summary']
|
||||||
self.logger = self.__dumped_params['logger']
|
self.logger = self.__dumped_params['logger']
|
||||||
|
|
Loading…
Reference in New Issue