testing hpc save load

This commit is contained in:
William Falcon 2019-07-24 18:01:33 -04:00
parent 2408aa886d
commit 423bc5c6c9
1 changed files with 3 additions and 2 deletions

View File

@ -144,7 +144,8 @@ class TrainerIO(object):
filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)
# give model a chance to do something on hpc_save
self.model.on_hpc_save()
model = self.model.module if type(self.model) is LightningDataParallel else self.model
model.on_hpc_save()
# request what to save from the model
checkpoint_dict = self.dump_checkpoint()
@ -168,7 +169,7 @@ class TrainerIO(object):
model.load_model_specific(checkpoint)
# call model hook
self.model.on_hpc_load()
model.on_hpc_load()
def max_ckpt_in_folder(self, path):
files = os.listdir(path)