fixed correct module on hpc save

William Falcon 2019-07-24 18:03:19 -04:00
parent 423bc5c6c9
commit a63f74281a
1 changed file with 8 additions and 4 deletions

@@ -41,6 +41,11 @@ class ModelIO(object):
 
 class TrainerIO(object):
 
+    def __get_model(self):
+        is_dp_module = type(self.model) is LightningDistributedDataParallel or type(self.model) is LightningDataParallel
+        model = self.model.module if is_dp_module else self.model
+        return model
+
     # --------------------
     # MODEL SAVE CHECKPOINT
     # --------------------
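Note: __get_model() deduplicates the wrapper check that the hunks below previously copy-pasted, with drift between the copies. Both Lightning wrapper classes, like torch.nn.DataParallel, expose the wrapped user model as .module. A minimal standalone sketch of the unwrapping pattern, using a hypothetical stand-in wrapper rather than the real Lightning classes:

    class FakeDataParallel:
        # stand-in for LightningDataParallel / LightningDistributedDataParallel;
        # like torch.nn.DataParallel, it keeps the user's model on .module
        def __init__(self, module):
            self.module = module

    class UserModel:
        pass

    def get_model(model):
        # return the underlying user model, unwrapping if needed
        return model.module if type(model) is FakeDataParallel else model

    assert type(get_model(UserModel())) is UserModel
    assert type(get_model(FakeDataParallel(UserModel()))) is UserModel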
@@ -71,8 +76,7 @@ class TrainerIO(object):
         checkpoint['optimizer_states'] = optimizer_states
 
         # request what to save from the model
-        is_dp_module = type(self.model) is LightningDistributedDataParallel or type(self.model) is LightningDataParallel
-        model = self.model.module if is_dp_module else self.model
+        model = self.__get_model()
         checkpoint_dict = model.get_save_dict()
 
         # merge trainer and model saving items
@@ -144,7 +148,7 @@ class TrainerIO(object):
         filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)
 
         # give model a chance to do something on hpc_save
-        model = self.model.module if type(self.model) is LightningDataParallel else self.model
+        model = self.__get_model()
         model.on_hpc_save()
 
         # request what to save from the model
@@ -165,7 +169,7 @@ class TrainerIO(object):
         self.restore_training_state(checkpoint)
 
         # load model state
-        model = self.model.module if type(self.model) is LightningDataParallel else self.model
+        model = self.__get_model()
         model.load_model_specific(checkpoint)
 
         # call model hook
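For context on the commit title: before this change, the hpc save and load paths (the last two hunks) tested only type(self.model) is LightningDataParallel, so a model wrapped in LightningDistributedDataParallel was never unwrapped, and hooks such as on_hpc_save() and load_model_specific() were looked up on the DDP wrapper instead of the user's module. A small illustration of that failure mode, again with hypothetical stand-in classes:

    class DP:                          # stand-in for LightningDataParallel
        def __init__(self, module):
            self.module = module

    class DDP:                         # stand-in for LightningDistributedDataParallel
        def __init__(self, module):
            self.module = module

    class UserModel:
        pass

    wrapped = DDP(UserModel())

    # old hpc_save check: only unwraps the DP wrapper, so DDP slips through
    old = wrapped.module if type(wrapped) is DP else wrapped
    assert type(old) is DDP            # bug: still the wrapper

    # new check, as in __get_model(): handles both wrapper types
    is_dp_module = type(wrapped) is DP or type(wrapped) is DDP
    new = wrapped.module if is_dp_module else wrapped
    assert type(new) is UserModel      # fixed: the real model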