fixed hpc save, load. cleaned apu
parent 4148c36abd
commit 265411572f

@@ -7,7 +7,7 @@ from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistr

 class ModelIO(object):

-    def load_model_specific(self, checkpoint):
+    def on_load_checkpoint(self, checkpoint):
         """
-        Do something with the checkpoint
+        Gives model a chance to load something before state_dict is restored
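
To read the renamed hook in context: a minimal sketch of a subclass using it (the `MyModel` class and `custom_state` key are illustrative assumptions, not part of this commit):

```python
class MyModel(ModelIO):
    def on_load_checkpoint(self, checkpoint):
        # runs before load_state_dict restores the weights, so any
        # extra entries in the checkpoint dict can be read back first
        self.custom_state = checkpoint.get('custom_state')
```
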
@@ -16,25 +16,24 @@ class ModelIO(object):
         """
         pass

-    def get_save_dict(self):
+    def on_save_checkpoint(self, checkpoint):
         """
-        Return specific things for the model
-        Called before trainer requests the state_dict
-        :return:
+        Give the model a chance to add something to the checkpoint.
+        state_dict is already there
         """
         pass

     # -------------------------
     # OPTIONAL HOOKS
     # -------------------------
-    def on_hpc_save(self):
+    def on_hpc_save(self, checkpoint):
         """
         Hook to do whatever you need right before Slurm manager saves the model
         :return:
         """
         pass

-    def on_hpc_load(self):
+    def on_hpc_load(self, checkpoint):
         """
         Hook to do whatever you need right before Slurm manager loads the model
         :return:
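
With this hunk applied, all three remaining hooks receive the checkpoint dict. A hedged sketch of a subclass exercising them (again, `custom_state` and `saved_on_hpc` are illustrative names):

```python
class MyModel(ModelIO):
    def on_save_checkpoint(self, checkpoint):
        # state_dict is already in the dict; just add extras
        checkpoint['custom_state'] = self.custom_state

    def on_hpc_save(self, checkpoint):
        # last chance to amend the dict before the Slurm-triggered save
        checkpoint['saved_on_hpc'] = True

    def on_hpc_load(self, checkpoint):
        # inspect the dict right after a Slurm-triggered load
        self.custom_state = checkpoint.get('custom_state')
```
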
@@ -78,13 +77,13 @@ class TrainerIO(object):
         checkpoint['optimizer_states'] = optimizer_states

-        # request what to save from the model
-        model = self.__get_model()
-        checkpoint_dict = model.get_save_dict()
-        checkpoint.update(checkpoint_dict)

         # add the state_dict from the model
-        checkpoint['state_dict'] = checkpoint_dict
+        model = self.__get_model()
+        checkpoint['state_dict'] = model.get_state_dict
+
+        # give the model a chance to add a few things
+        model.on_save_checkpoint(checkpoint)

         return checkpoint

     # --------------------
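
One caveat in the new save path: `model.get_state_dict` is referenced without being called, so the checkpoint would hold a method object rather than the weights (torch's `nn.Module` defines `state_dict()`, not `get_state_dict`). The presumably intended line, as a sketch:

```python
# presumably intended: call state_dict() so the actual tensors are stored
checkpoint['state_dict'] = model.state_dict()
```
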
@@ -153,13 +152,12 @@ class TrainerIO(object):

         # give model a chance to do something on hpc_save
         model = self.__get_model()
-        model.on_hpc_save()
-
-        # request what to save from the model
-        checkpoint_dict = self.dump_checkpoint()
+        checkpoint = self.dump_checkpoint()
+
+        model.on_hpc_save(checkpoint)

         # do the actual save
-        torch.save(checkpoint_dict, filepath)
+        torch.save(checkpoint, filepath)

         return filepath
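
The reordering here matters: `dump_checkpoint()` now runs before the hook, so `on_hpc_save` sees the fully populated dict instead of being called on an empty one. A standalone sketch of the resulting flow (the function name and default path are illustrative):

```python
import torch

def hpc_save_flow(trainer, model, filepath='hpc_ckpt_1.ckpt'):
    # build the complete checkpoint dict first ...
    checkpoint = trainer.dump_checkpoint()
    # ... then let the model see (and extend) everything already in it
    model.on_hpc_save(checkpoint)
    torch.save(checkpoint, filepath)
    return filepath
```
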
@@ -177,14 +175,11 @@ class TrainerIO(object):
         # load model state
         model = self.__get_model()

-        # give model a chance to load something
-        model.load_model_specific(checkpoint)
-
         # load the state_dict on the model automatically
         model.load_state_dict(checkpoint['state_dict'])

         # call model hook
-        model.on_hpc_load()
+        model.on_hpc_load(checkpoint)


     def max_ckpt_in_folder(self, path):
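
And the matching load side: the trainer restores the weights itself, then hands the whole dict to `on_hpc_load`. A hedged sketch of that flow (the function name and `map_location` choice are assumptions):

```python
import torch

def hpc_load_flow(model, filepath):
    # read the dict written by hpc_save
    checkpoint = torch.load(filepath, map_location='cpu')
    # restore weights automatically, then let the model react to the dict
    model.load_state_dict(checkpoint['state_dict'])
    model.on_hpc_load(checkpoint)
```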