auto state-dict and remove the way the model is loaded during hpc

William Falcon 2019-07-26 21:37:06 -04:00
parent ff1ed9db7e
commit e2c7fa44b7
2 changed files with 16 additions and 5 deletions

@@ -9,17 +9,19 @@ class ModelIO(object):
     def load_model_specific(self, checkpoint):
         """
-        Do something with the checkpoint
+        Gives model a chance to load something before state_dict is restored
         :param checkpoint:
         :return:
         """
-        raise NotImplementedError
+        pass
 
     def get_save_dict(self):
         """
-        Return specific things for the model
+        Called before trainer requests the state_dict
         :return:
         """
-        raise NotImplementedError
+        pass
 
     # -------------------------
     # OPTIONAL HOOKS
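
With raise NotImplementedError swapped for pass, both hooks are now optional no-ops: a LightningModule only overrides them when it has extra items to persist beyond the automatically handled state_dict. A minimal sketch of such an override (the subclass, the ptl import alias, and the 'ema' key are illustrative assumptions, not part of this commit):

import pytorch_lightning as ptl  # import path assumed for this era of the library


class MyModel(ptl.LightningModule):  # hypothetical user module

    def get_save_dict(self):
        # extra model-specific items; the trainer merges them into its checkpoint
        return {'ema': {k: v.clone() for k, v in self.state_dict().items()}}

    def load_model_specific(self, checkpoint):
        # called before the trainer restores state_dict automatically
        self.ema = checkpoint.get('ema')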
@@ -80,6 +82,7 @@ class TrainerIO(object):
         checkpoint_dict = model.get_save_dict()
 
         # merge trainer and model saving items
-        checkpoint['state_dict'] = checkpoint_dict
+        checkpoint.update(checkpoint_dict)
 
         return checkpoint
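
The saving side changes from nesting the model's dict under a single key to merging it into the trainer's checkpoint at the top level, which is the layout the restore path below expects. A small sketch of the difference (the keys shown are illustrative, not the real dump_checkpoint contents):

# illustrative trainer-side items; the real dump_checkpoint stores more fields
checkpoint = {'epoch': 3, 'global_step': 1200}

# illustrative model-side dict, as get_save_dict might return it
checkpoint_dict = {'state_dict': {'layer.weight': 0.0}, 'hparams': {'lr': 0.02}}

# old line: checkpoint['state_dict'] = checkpoint_dict
# -> the weights would sit at checkpoint['state_dict']['state_dict']

# new line: merge at the top level
checkpoint.update(checkpoint_dict)
assert 'layer.weight' in checkpoint['state_dict']  # what load_state_dict needs later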
@@ -167,13 +170,18 @@ class TrainerIO(object):
         else:
             checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
 
-        # load training state
+        # load training state (affects trainer only)
         self.restore_training_state(checkpoint)
 
         # load model state
         model = self.__get_model()
+
+        # give model a chance to load something
+        model.load_model_specific(checkpoint)
+
+        # load the state_dict on the model automatically
+        model.load_state_dict(checkpoint['state_dict'])
 
         # call model hook
         model.on_hpc_load()
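
As an aside, the map_location lambda in the context above deserializes every tensor onto CPU, so a checkpoint written on a GPU node can still be restored on a machine without that GPU. A stand-alone sketch (the file name is a placeholder):

import torch

path = 'example.ckpt'  # placeholder file name
torch.save({'state_dict': {'w': torch.zeros(2, 2)}}, path)

# default: tensors are restored onto the device they were saved from,
# which fails if the checkpoint holds CUDA tensors and the node has no GPU
# checkpoint = torch.load(path)

# returning the storage unchanged keeps every tensor on CPU
checkpoint = torch.load(path, map_location=lambda storage, loc: storage)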

@@ -110,9 +110,12 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         model = cls(hparams)
 
-        # allow model to load
+        # give model a chance to load something
         model.load_model_specific(checkpoint)
-        model.load_state_dict(checkpoint['state_dict'], strict=False)
+
+        # load the state_dict on the model automatically
+        model.load_state_dict(checkpoint['state_dict'])
 
         return model
 
     def summarize(self):
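
Dropping strict=False here also means that mismatches between the checkpoint and the module now surface instead of being skipped silently. A small stand-alone illustration of the two modes:

from torch import nn

saved = nn.Linear(4, 2).state_dict()        # keys: 'weight', 'bias'
target = nn.Sequential(nn.Linear(4, 2))     # keys: '0.weight', '0.bias'

# strict=False (the old call in this hunk) skips every mismatched key,
# so the "restored" model silently keeps its random initialization
target.load_state_dict(saved, strict=False)

# strict=True, the default used after this commit, raises instead
try:
    target.load_state_dict(saved)
except RuntimeError as err:
    print(err)  # lists the missing and unexpected keys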