diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py
index 2537785175..ec55d5f62f 100644
--- a/pytorch_lightning/root_module/model_saving.py
+++ b/pytorch_lightning/root_module/model_saving.py
@@ -9,17 +9,19 @@ class ModelIO(object):
     def load_model_specific(self, checkpoint):
         """
         Do something with the checkpoint
+        Gives model a chance to load something before state_dict is restored
         :param checkpoint:
         :return:
         """
-        raise NotImplementedError
+        pass
 
     def get_save_dict(self):
         """
         Return specific things for the model
+        Called before trainer requests the state_dict
         :return:
         """
-        raise NotImplementedError
+        pass
 
     # -------------------------
     # OPTIONAL HOOKS
@@ -80,6 +82,10 @@ class TrainerIO(object):
         checkpoint_dict = model.get_save_dict()
 
         # merge trainer and model saving items
-        checkpoint.update(checkpoint_dict)
+        # weights come from the module itself; restore reads checkpoint['state_dict']
+        checkpoint['state_dict'] = model.state_dict()
+        # get_save_dict() is an optional hook and returns None when not overridden
+        if checkpoint_dict is not None:
+            checkpoint.update(checkpoint_dict)
         return checkpoint
 
@@ -167,13 +173,18 @@ class TrainerIO(object):
         else:
             checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
 
-        # load training state
+        # load training state (affects trainer only)
         self.restore_training_state(checkpoint)
 
         # load model state
         model = self.__get_model()
+
+        # give model a chance to load something
         model.load_model_specific(checkpoint)
 
+        # load the state_dict on the model automatically
+        model.load_state_dict(checkpoint['state_dict'])
+
         # call model hook
         model.on_hpc_load()
 
diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py
index d6d740f039..a25f7f6f85 100644
--- a/pytorch_lightning/root_module/root_module.py
+++ b/pytorch_lightning/root_module/root_module.py
@@ -110,9 +110,12 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
 
         model = cls(hparams)
 
-        # allow model to load
+        # give model a chance to load something
         model.load_model_specific(checkpoint)
-        model.load_state_dict(checkpoint['state_dict'], strict=False)
+
+        # load the state_dict on the model automatically
+        model.load_state_dict(checkpoint['state_dict'])
+
         return model
 
     def summarize(self):