fixed correct module on hpc save
This commit is contained in:
parent
423bc5c6c9
commit
a63f74281a
|
@ -41,6 +41,11 @@ class ModelIO(object):
|
|||
|
||||
class TrainerIO(object):
|
||||
|
||||
def __get_model(self):
|
||||
is_dp_module = type(self.model) is LightningDistributedDataParallel or type(self.model) is LightningDataParallel
|
||||
model = self.model.module if is_dp_module else self.model
|
||||
return model
|
||||
|
||||
# --------------------
|
||||
# MODEL SAVE CHECKPOINT
|
||||
# --------------------
|
||||
|
@ -71,8 +76,7 @@ class TrainerIO(object):
|
|||
checkpoint['optimizer_states'] = optimizer_states
|
||||
|
||||
# request what to save from the model
|
||||
is_dp_module = type(self.model) is LightningDistributedDataParallel or type(self.model) is LightningDataParallel
|
||||
model = self.model.module if is_dp_module else self.model
|
||||
model = self.__get_model()
|
||||
checkpoint_dict = model.get_save_dict()
|
||||
|
||||
# merge trainer and model saving items
|
||||
|
@ -144,7 +148,7 @@ class TrainerIO(object):
|
|||
filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)
|
||||
|
||||
# give model a chance to do something on hpc_save
|
||||
model = self.model.module if type(self.model) is LightningDataParallel else self.model
|
||||
model = self.__get_model()
|
||||
model.on_hpc_save()
|
||||
|
||||
# request what to save from the model
|
||||
|
@ -165,7 +169,7 @@ class TrainerIO(object):
|
|||
self.restore_training_state(checkpoint)
|
||||
|
||||
# load model state
|
||||
model = self.model.module if type(self.model) is LightningDataParallel else self.model
|
||||
model = self.__get_model()
|
||||
model.load_model_specific(checkpoint)
|
||||
|
||||
# call model hook
|
||||
|
|
Loading…
Reference in New Issue