removed bad hook call
This commit is contained in:
parent
a931ded310
commit
d5fd16a478
|
@ -244,9 +244,6 @@ class Trainer(TrainerIO):
|
||||||
'''
|
'''
|
||||||
raise ModuleNotFoundError(msg)
|
raise ModuleNotFoundError(msg)
|
||||||
|
|
||||||
# restore training and model
|
|
||||||
self.restore_state_if_existing_checkpoint()
|
|
||||||
|
|
||||||
def restore_state_if_existing_checkpoint(self):
|
def restore_state_if_existing_checkpoint(self):
|
||||||
# restore trainer state and model if there is a weight for this experiment
|
# restore trainer state and model if there is a weight for this experiment
|
||||||
last_epoch = -1
|
last_epoch = -1
|
||||||
|
@ -624,6 +621,9 @@ class Trainer(TrainerIO):
|
||||||
ref_model.trainer = self
|
ref_model.trainer = self
|
||||||
ref_model.experiment = self.experiment
|
ref_model.experiment = self.experiment
|
||||||
|
|
||||||
|
# restore training and model
|
||||||
|
self.restore_state_if_existing_checkpoint()
|
||||||
|
|
||||||
# run tiny validation to make sure program won't crash during val
|
# run tiny validation to make sure program won't crash during val
|
||||||
_ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)
|
_ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue