testing hpc save load

This commit is contained in:
William Falcon 2019-07-24 18:01:33 -04:00
parent 2408aa886d
commit 423bc5c6c9
1 changed files with 3 additions and 2 deletions

View File

@ -144,7 +144,8 @@ class TrainerIO(object):
filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number) filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)
# give model a chance to do something on hpc_save # give model a chance to do something on hpc_save
self.model.on_hpc_save() model = self.model.module if type(self.model) is LightningDataParallel else self.model
model.on_hpc_save()
# request what to save from the model # request what to save from the model
checkpoint_dict = self.dump_checkpoint() checkpoint_dict = self.dump_checkpoint()
@ -168,7 +169,7 @@ class TrainerIO(object):
model.load_model_specific(checkpoint) model.load_model_specific(checkpoint)
# call model hook # call model hook
self.model.on_hpc_load() model.on_hpc_load()
def max_ckpt_in_folder(self, path): def max_ckpt_in_folder(self, path):
files = os.listdir(path) files = os.listdir(path)