Handle the state_dict automatically and remove the previous way the model was loaded during HPC
This commit is contained in:
parent
ff1ed9db7e
commit
e2c7fa44b7
|
@ -9,17 +9,19 @@ class ModelIO(object):
|
|||
def load_model_specific(self, checkpoint):
|
||||
"""
|
||||
Do something with the checkpoint
|
||||
Gives model a chance to load something before state_dict is restored
|
||||
:param checkpoint:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
pass
|
||||
|
||||
def get_save_dict(self):
|
||||
"""
|
||||
Return specific things for the model
|
||||
Called before trainer requests the state_dict
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
pass
|
||||
|
||||
# -------------------------
|
||||
# OPTIONAL HOOKS
|
||||
|
@ -80,6 +82,7 @@ class TrainerIO(object):
|
|||
checkpoint_dict = model.get_save_dict()
|
||||
|
||||
# merge trainer and model saving items
|
||||
checkpoint['state_dict'] = checkpoint_dict
|
||||
checkpoint.update(checkpoint_dict)
|
||||
return checkpoint
|
||||
|
||||
|
@ -167,13 +170,18 @@ class TrainerIO(object):
|
|||
else:
|
||||
checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
|
||||
|
||||
# load training state
|
||||
# load training state (affects trainer only)
|
||||
self.restore_training_state(checkpoint)
|
||||
|
||||
# load model state
|
||||
model = self.__get_model()
|
||||
|
||||
# give model a chance to load something
|
||||
model.load_model_specific(checkpoint)
|
||||
|
||||
# load the state_dict on the model automatically
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
|
||||
# call model hook
|
||||
model.on_hpc_load()
|
||||
|
||||
|
|
|
@ -110,9 +110,12 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
|
|||
|
||||
model = cls(hparams)
|
||||
|
||||
# allow model to load
|
||||
# give model a chance to load something
|
||||
model.load_model_specific(checkpoint)
|
||||
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
||||
|
||||
# load the state_dict on the model automatically
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
|
||||
return model
|
||||
|
||||
def summarize(self):
|
||||
|
|
Loading…
Reference in New Issue