fixed hpc save, load. cleaned api

William Falcon 2019-07-26 22:04:27 -04:00
parent 4148c36abd
commit 265411572f
1 changed file with 16 additions and 21 deletions


@@ -7,7 +7,7 @@ from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistr
 class ModelIO(object):
 
-    def load_model_specific(self, checkpoint):
+    def on_load_checkpoint(self, checkpoint):
         """
         Do something with the checkpoint
         Gives model a chance to load something before state_dict is restored
@@ -16,25 +16,24 @@ class ModelIO(object):
         """
         pass
 
-    def get_save_dict(self):
+    def on_save_checkpoint(self, checkpoint):
         """
-        Return specific things for the model
-        Called before trainer requests the state_dict
-        :return:
+        Give the model a chance to add something to the checkpoint.
+        state_dict is already there
         """
         pass
 
     # -------------------------
     # OPTIONAL HOOKS
     # -------------------------
-    def on_hpc_save(self):
+    def on_hpc_save(self, checkpoint):
         """
         Hook to do whatever you need right before Slurm manager saves the model
         :return:
         """
         pass
 
-    def on_hpc_load(self):
+    def on_hpc_load(self, checkpoint):
         """
         Hook to do whatever you need right before Slurm manager loads the model
         :return:
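
For illustration, a minimal sketch of how a user model might implement the renamed hooks. The MyModel class and its current_epoch attribute are hypothetical, not part of this diff, but the hook signatures and call order follow the docstrings above:

    import torch.nn as nn

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(10, 2)
            self.current_epoch = 0  # hypothetical extra state worth saving

        def on_save_checkpoint(self, checkpoint):
            # state_dict is already in the checkpoint; only add extras
            checkpoint['current_epoch'] = self.current_epoch

        def on_load_checkpoint(self, checkpoint):
            # called before state_dict is restored, so extras are ready first
            self.current_epoch = checkpoint.get('current_epoch', 0)
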
@@ -78,13 +77,13 @@ class TrainerIO(object):
         checkpoint['optimizer_states'] = optimizer_states
 
-        # request what to save from the model
-        model = self.__get_model()
-        checkpoint_dict = model.get_save_dict()
-        checkpoint.update(checkpoint_dict)
-
         # add the state_dict from the model
-        checkpoint['state_dict'] = checkpoint_dict
+        model = self.__get_model()
+        checkpoint['state_dict'] = model.get_state_dict
+
+        # give the model a chance to add a few things
+        model.on_save_checkpoint(checkpoint)
 
         return checkpoint
 
     # --------------------
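
The new ordering matters: the trainer writes its own keys (optimizer states, state_dict) first and only then hands the full dict to on_save_checkpoint, so model additions land last and cannot be clobbered. A standalone sketch of that flow, assuming simplified placeholder names rather than the trainer's real attributes:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    checkpoint = {
        'optimizer_states': [optimizer.state_dict()],
        'state_dict': model.state_dict(),   # trainer-owned keys go in first
    }
    # give the model a chance to add a few things (plain modules have no hook)
    if hasattr(model, 'on_save_checkpoint'):
        model.on_save_checkpoint(checkpoint)
    torch.save(checkpoint, '/tmp/checkpoint.ckpt')  # placeholder path
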
@@ -153,13 +152,12 @@ class TrainerIO(object):
         # give model a chance to do something on hpc_save
         model = self.__get_model()
-        model.on_hpc_save()
+        checkpoint = self.dump_checkpoint()
 
-        # request what to save from the model
-        checkpoint_dict = self.dump_checkpoint()
+        model.on_hpc_save(checkpoint)
 
         # do the actual save
-        torch.save(checkpoint_dict, filepath)
+        torch.save(checkpoint, filepath)
 
         return filepath
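
The fix here is again ordering: on_hpc_save() previously ran before the checkpoint existed, so the hook could not touch it; now the dict is built first and passed in. A runnable sketch of the corrected pattern, using an illustrative HPCModel class and hpc_flag key that are not part of the diff:

    import torch
    import torch.nn as nn

    class HPCModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 2)

        def on_hpc_save(self, checkpoint):
            # the hook now receives the finished dict and may mutate it
            checkpoint['hpc_flag'] = True  # hypothetical extra key

    model = HPCModel()
    checkpoint = {'state_dict': model.state_dict()}  # checkpoint built first
    model.on_hpc_save(checkpoint)                    # then the hook runs
    torch.save(checkpoint, '/tmp/hpc_ckpt_1.ckpt')   # placeholder path
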
@@ -177,14 +175,11 @@ class TrainerIO(object):
         # load model state
         model = self.__get_model()
 
-        # give model a chance to load something
-        model.load_model_specific(checkpoint)
-
         # load the state_dict on the model automatically
         model.load_state_dict(checkpoint['state_dict'])
 
         # call model hook
-        model.on_hpc_load()
+        model.on_hpc_load(checkpoint)
 
     def max_ckpt_in_folder(self, path):
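
On the load side, the separate load_model_specific call is gone: load_state_dict restores the weights and the single on_hpc_load hook then receives the full checkpoint. An illustrative counterpart sketch, reusing the same hypothetical class and key names as the save example above:

    import torch
    import torch.nn as nn

    class HPCModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 2)

        def on_hpc_load(self, checkpoint):
            # the hook now receives the checkpoint dict directly
            self.restored_flag = checkpoint.get('hpc_flag', False)

    model = HPCModel()
    checkpoint = {'state_dict': model.state_dict(), 'hpc_flag': True}
    model.load_state_dict(checkpoint['state_dict'])  # weights restored first
    model.on_hpc_load(checkpoint)                    # then the single hook
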