diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py index e9224e2ba9..03cb5864a8 100644 --- a/pytorch_lightning/root_module/model_saving.py +++ b/pytorch_lightning/root_module/model_saving.py @@ -7,7 +7,7 @@ from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistr class ModelIO(object): - def load_model_specific(self, checkpoint): + def on_load_checkpoint(self, checkpoint): """ Do something with the checkpoint Gives model a chance to load something before state_dict is restored @@ -16,25 +16,24 @@ class ModelIO(object): """ pass - def get_save_dict(self): + def on_save_checkpoint(self, checkpoint): """ - Return specific things for the model - Called before trainer requests the state_dict - :return: + Give the model a chance to add something to the checkpoint. + state_dict is already there """ pass # ------------------------- # OPTIONAL HOOKS # ------------------------- - def on_hpc_save(self): + def on_hpc_save(self, checkpoint): """ Hook to do whatever you need right before Slurm manager saves the model :return: """ pass - def on_hpc_load(self): + def on_hpc_load(self, checkpoint): """ Hook to do whatever you need right before Slurm manager loads the model :return: @@ -78,13 +77,13 @@ class TrainerIO(object): checkpoint['optimizer_states'] = optimizer_states - # request what to save from the model - model = self.__get_model() - checkpoint_dict = model.get_save_dict() - checkpoint.update(checkpoint_dict) - # add the state_dict from the model - checkpoint['state_dict'] = checkpoint_dict + model = self.__get_model() + checkpoint['state_dict'] = model.get_state_dict + + # give the model a chance to add a few things + model.on_save_checkpoint(checkpoint) + return checkpoint # -------------------- @@ -153,13 +152,12 @@ class TrainerIO(object): # give model a chance to do something on hpc_save model = self.__get_model() - model.on_hpc_save() + checkpoint = self.dump_checkpoint() - # request what to save from the model - checkpoint_dict = self.dump_checkpoint() + model.on_hpc_save(checkpoint) # do the actual save - torch.save(checkpoint_dict, filepath) + torch.save(checkpoint, filepath) return filepath @@ -177,14 +175,11 @@ class TrainerIO(object): # load model state model = self.__get_model() - # give model a chance to load something - model.load_model_specific(checkpoint) - # load the state_dict on the model automatically model.load_state_dict(checkpoint['state_dict']) # call model hook - model.on_hpc_load() + model.on_hpc_load(checkpoint) def max_ckpt_in_folder(self, path):