diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 2655809c4a..c9c4d8504e 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -244,6 +244,36 @@ class Trainer(TrainerIO):
             '''
             raise ModuleNotFoundError(msg)
 
+        # restore training and model
+        self.restore_state_if_existing_checkpoint()
+
+    def restore_state_if_existing_checkpoint(self):
+        # restore trainer state and model if there is a weight for this experiment
+        last_epoch = -1
+        last_ckpt_name = None
+
+        # nothing to restore on a fresh experiment (folder not created yet)
+        if not os.path.exists(self.checkpoint_callback.filepath):
+            return
+
+        # find last epoch
+        checkpoints = os.listdir(self.checkpoint_callback.filepath)
+        for name in checkpoints:
+            if '.ckpt' in name:
+                epoch = name.split('epoch_')[1]
+                # compare epochs as ints, not strings ('10' < '9' lexicographically)
+                epoch = int(re.sub('[^0-9]', '', epoch))
+
+                if epoch > last_epoch:
+                    last_epoch = epoch
+                    last_ckpt_name = name
+
+        # restore last checkpoint (skip when the folder holds no .ckpt file yet)
+        if last_ckpt_name is not None:
+            last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name)
+            self.restore(last_ckpt_path, self.on_gpu)
+            print(f'model and trainer restored from checkpoint: {last_ckpt_path}')
+
     @property
     def data_parallel(self):
         return self.use_dp or self.use_ddp
diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py
index 0bde0943f4..d9c2e3a5c7 100644
--- a/pytorch_lightning/root_module/model_saving.py
+++ b/pytorch_lightning/root_module/model_saving.py
@@ -58,6 +58,27 @@ class TrainerIO(object):
         # do the actual save
         torch.save(checkpoint, filepath)
 
+    def restore(self, checkpoint_path, on_gpu):
+
+        if on_gpu:
+            checkpoint = torch.load(checkpoint_path)
+        else:
+            # map storages onto CPU when no GPU is available
+            checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+
+        # load training state (affects trainer only)
+        self.restore_training_state(checkpoint)
+
+        # load model state
+        model = self.__get_model()
+
+        # load the state_dict on the model automatically
+        model.load_state_dict(checkpoint['state_dict'])
+
+        # call model hook
+        # NOTE(review): reuses the HPC load hook for a generic restore -- confirm intended
+        model.on_hpc_load(checkpoint)
+
     def dump_checkpoint(self):
         checkpoint = {
@@ -198,15 +225,15 @@ class TrainerIO(object):
         # call model hook
         model.on_hpc_load(checkpoint)
 
-    def max_ckpt_in_folder(self, path):
+    def max_ckpt_in_folder(self, path, name_key='ckpt_'):
         files = os.listdir(path)
-        files = [x for x in files if 'ckpt_' in x]
+        files = [x for x in files if name_key in x]
         if len(files) == 0:
             return 0
 
         ckpt_vs = []
         for name in files:
-            name = name.split('ckpt_')[-1]
+            name = name.split(name_key)[-1]
             name = re.sub('[^0-9]', '', name)
             ckpt_vs.append(int(name))