testing hpc save load
This commit is contained in:
parent
2408aa886d
commit
423bc5c6c9
|
@ -144,7 +144,8 @@ class TrainerIO(object):
|
||||||
filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)
|
filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)
|
||||||
|
|
||||||
# give model a chance to do something on hpc_save
|
# give model a chance to do something on hpc_save
|
||||||
self.model.on_hpc_save()
|
model = self.model.module if type(self.model) is LightningDataParallel else self.model
|
||||||
|
model.on_hpc_save()
|
||||||
|
|
||||||
# request what to save from the model
|
# request what to save from the model
|
||||||
checkpoint_dict = self.dump_checkpoint()
|
checkpoint_dict = self.dump_checkpoint()
|
||||||
|
@ -168,7 +169,7 @@ class TrainerIO(object):
|
||||||
model.load_model_specific(checkpoint)
|
model.load_model_specific(checkpoint)
|
||||||
|
|
||||||
# call model hook
|
# call model hook
|
||||||
self.model.on_hpc_load()
|
model.on_hpc_load()
|
||||||
|
|
||||||
def max_ckpt_in_folder(self, path):
|
def max_ckpt_in_folder(self, path):
|
||||||
files = os.listdir(path)
|
files = os.listdir(path)
|
||||||
|
|
Loading…
Reference in New Issue