""" Lightning can automate saving and loading checkpoints ===================================================== Checkpointing is enabled by default to the current working directory. To change the checkpoint path pass in:: Trainer(default_save_path='/your/path/to/save/checkpoints') To modify the behavior of checkpointing pass in your own callback. .. code-block:: python from pytorch_lightning.callbacks import ModelCheckpoint # DEFAULTS used by the Trainer checkpoint_callback = ModelCheckpoint( filepath=os.getcwd(), save_top_k=1, verbose=True, monitor='val_loss', mode='min', prefix='' ) trainer = Trainer(checkpoint_callback=checkpoint_callback) Restoring training session -------------------------- You might want to not only load a model but also continue training it. Use this method to restore the trainer state as well. This will continue from the epoch and global step you last left off. However, the dataloaders will start from the first batch again (if you shuffled it shouldn't matter). Lightning will restore the session if you pass a logger with the same version and there's a saved checkpoint. .. code-block:: python from pytorch_lightning import Trainer trainer = Trainer( resume_from_checkpoint=PATH ) # this fit call loads model weights and trainer state # the trainer continues seamlessly from where you left off # without having to do anything else. trainer.fit(model) The trainer restores: - global_step - current_epoch - All optimizers - All lr_schedulers - Model weights You can even change the logic of your model as long as the weights and "architecture" of the system isn't different. If you add a layer, for instance, it might not work. At a rough level, here's what happens inside Trainer :py:mod:`pytorch_lightning.base_module.model_saving.py`: .. code-block:: python self.global_step = checkpoint['global_step'] self.current_epoch = checkpoint['epoch'] # restore the optimizers optimizer_states = checkpoint['optimizer_states'] for optimizer, opt_state in zip(self.optimizers, optimizer_states): optimizer.load_state_dict(opt_state) # restore the lr schedulers lr_schedulers = checkpoint['lr_schedulers'] for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers): scheduler['scheduler'].load_state_dict(lrs_state) # uses the model you passed into trainer model.load_state_dict(checkpoint['state_dict']) """ import os import re import signal from abc import ABC from argparse import Namespace from subprocess import call from typing import Union import torch import torch.distributed as torch_distrib from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.overrides.data_parallel import ( LightningDistributedDataParallel, LightningDataParallel, ) from pytorch_lightning.utilities import rank_zero_warn try: import torch_xla import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_multiprocessing as xmp except ImportError: XLA_AVAILABLE = False else: XLA_AVAILABLE = True class TrainerIOMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class model: LightningModule on_gpu: bool root_gpu: ... resume_from_checkpoint: ... use_ddp: bool use_ddp2: bool checkpoint_callback: ... proc_rank: int weights_save_path: str logger: Union[LightningLoggerBase, bool] early_stop_callback: ... lr_schedulers: ... optimizers: ... on_tpu: bool num_training_batches: int accumulate_grad_batches: int def get_model(self): is_dp_module = isinstance(self.model, (LightningDistributedDataParallel, LightningDataParallel)) model = self.model.module if is_dp_module else self.model return model # -------------------- # CHECK-POINTING # -------------------- def restore_weights(self, model: LightningModule): """ We attempt to restore weights in this order: 1. HPC weights. 2. if no HPC weights restore checkpoint_path weights 3. otherwise don't restore weights """ # clear cache before restore if self.on_gpu: torch.cuda.empty_cache() # if script called from hpc resubmit, load weights did_restore_hpc_weights = self.restore_hpc_weights_if_needed(model) # clear cache after restore if self.on_gpu: torch.cuda.empty_cache() if not did_restore_hpc_weights: if self.resume_from_checkpoint is not None: self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu) # wait for all models to restore weights if self.use_ddp or self.use_ddp2: # wait for all processes to catch up torch_distrib.barrier() # wait for all models to restore weights if self.on_tpu and XLA_AVAILABLE: # wait for all processes to catch up torch_xla.core.xla_model.rendezvous("pl.TrainerIOMixin.restore_weights") # clear cache after restore if self.on_gpu: torch.cuda.empty_cache() # -------------------- # HPC SIGNAL HANDLING # -------------------- def register_slurm_signal_handlers(self): # see if we're using slurm (not interactive) on_slurm = False try: job_name = os.environ['SLURM_JOB_NAME'] if job_name != 'bash': on_slurm = True except Exception as e: pass if on_slurm: log.info('Set SLURM handle signals.') signal.signal(signal.SIGUSR1, self.sig_handler) signal.signal(signal.SIGTERM, self.term_handler) def sig_handler(self, signum, frame): # pragma: no-cover if self.proc_rank == 0: # save weights log.info('handling SIGUSR1') self.hpc_save(self.weights_save_path, self.logger) # find job id job_id = os.environ['SLURM_JOB_ID'] cmd = 'scontrol requeue {}'.format(job_id) # requeue job log.info(f'requeing job {job_id}...') result = call(cmd, shell=True) # print result text if result == 0: log.info(f'requeued exp {job_id}') else: log.info('requeue failed...') # close experiment to avoid issues self.logger.close() def term_handler(self, signum, frame): # save log.info("bypassing sigterm") # -------------------- # MODEL SAVE CHECKPOINT # -------------------- def _atomic_save(self, checkpoint, filepath: str): """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints. This will create a temporary checkpoint with a suffix of ``.part``, then copy it to the final location once saving is finished. Args: checkpoint: The object to save. Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save`` accepts. filepath: The path to which the checkpoint will be saved. This points to the file that the checkpoint will be stored in. """ tmp_path = str(filepath) + ".part" torch.save(checkpoint, tmp_path) os.replace(tmp_path, filepath) def save_checkpoint(self, filepath): checkpoint = self.dump_checkpoint() if self.proc_rank == 0: # do the actual save try: self._atomic_save(checkpoint, filepath) except AttributeError: if 'hparams' in checkpoint: del checkpoint['hparams'] self._atomic_save(checkpoint, filepath) def restore(self, checkpoint_path: str, on_gpu: bool): """ Restore training state from checkpoint. Also restores all training state like: - epoch - callbacks - schedulers - optimizer """ # if on_gpu: # checkpoint = torch.load(checkpoint_path) # else: # load on CPU first checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) # load model state model = self.get_model() # load the state_dict on the model automatically model.load_state_dict(checkpoint['state_dict']) if on_gpu: model.cuda(self.root_gpu) # load training state (affects trainer only) self.restore_training_state(checkpoint) def dump_checkpoint(self): checkpoint = { 'epoch': self.current_epoch + 1, 'global_step': self.global_step + 1, } if self.checkpoint_callback is not None and self.checkpoint_callback is not False: checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best if self.early_stop_callback is not None and self.checkpoint_callback is not False: checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait checkpoint['early_stop_callback_patience'] = self.early_stop_callback.patience # save optimizers optimizer_states = [] for i, optimizer in enumerate(self.optimizers): optimizer_states.append(optimizer.state_dict()) checkpoint['optimizer_states'] = optimizer_states # save lr schedulers lr_schedulers = [] for scheduler in self.lr_schedulers: lr_schedulers.append(scheduler['scheduler'].state_dict()) checkpoint['lr_schedulers'] = lr_schedulers # add the hparams and state_dict from the model model = self.get_model() checkpoint['state_dict'] = model.state_dict() if hasattr(model, "hparams"): is_namespace = isinstance(model.hparams, Namespace) checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams checkpoint['hparams_type'] = 'namespace' if is_namespace else 'dict' else: rank_zero_warn( "Did not find hyperparameters at model hparams. Saving checkpoint without hyperparameters." ) # give the model a chance to add a few things model.on_save_checkpoint(checkpoint) return checkpoint # -------------------- # HPC IO # -------------------- def restore_hpc_weights_if_needed(self, model: LightningModule): """If there is a set of hpc weights, use as signal to restore model.""" did_restore = False # look for hpc weights folderpath = self.weights_save_path if os.path.exists(folderpath): files = os.listdir(folderpath) hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x] # if hpc weights exist restore model if len(hpc_weight_paths) > 0: self.hpc_load(folderpath, self.on_gpu) did_restore = True return did_restore def restore_training_state(self, checkpoint): """ Restore trainer state. Model will get its change to update :param checkpoint: :return: """ if self.checkpoint_callback is not None and self.checkpoint_callback is not False: self.checkpoint_callback.best = checkpoint['checkpoint_callback_best'] if self.early_stop_callback is not None and self.early_stop_callback is not False: self.early_stop_callback.wait = checkpoint['early_stop_callback_wait'] self.early_stop_callback.patience = checkpoint['early_stop_callback_patience'] self.global_step = checkpoint['global_step'] self.current_epoch = checkpoint['epoch'] # Division deals with global step stepping once per accumulated batch # Inequality deals with different global step for odd vs even num_training_batches n_accum = 1 if self.accumulate_grad_batches is None else self.accumulate_grad_batches expected_steps = self.num_training_batches / n_accum if self.num_training_batches != 0 and self.global_step % expected_steps > 1: rank_zero_warn( "You're resuming from a checkpoint that ended mid-epoch. " "This can cause unreliable results if further training is done, " "consider using an end of epoch checkpoint. " ) # restore the optimizers optimizer_states = checkpoint['optimizer_states'] for optimizer, opt_state in zip(self.optimizers, optimizer_states): optimizer.load_state_dict(opt_state) # move optimizer to GPU 1 weight at a time # avoids OOM if self.root_gpu is not None: for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.cuda(self.root_gpu) # restore the lr schedulers lr_schedulers = checkpoint['lr_schedulers'] for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers): scheduler['scheduler'].load_state_dict(lrs_state) # ---------------------------------- # PRIVATE OPS # ---------------------------------- def hpc_save(self, folderpath: str, logger): # make sure the checkpoint folder exists os.makedirs(folderpath, exist_ok=True) # save logger to make sure we get all the metrics logger.save() ckpt_number = self.max_ckpt_in_folder(folderpath) + 1 if not os.path.exists(folderpath): os.makedirs(folderpath, exist_ok=True) filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt') # give model a chance to do something on hpc_save model = self.get_model() checkpoint = self.dump_checkpoint() model.on_hpc_save(checkpoint) # do the actual save # TODO: fix for anything with multiprocess DP, DDP, DDP2 try: self._atomic_save(checkpoint, filepath) except AttributeError: if 'hparams' in checkpoint: del checkpoint['hparams'] self._atomic_save(checkpoint, filepath) return filepath def hpc_load(self, folderpath, on_gpu): filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath)) # load on CPU first checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage) # load model state model = self.get_model() # load the state_dict on the model automatically model.load_state_dict(checkpoint['state_dict']) if self.root_gpu is not None: model.cuda(self.root_gpu) # load training state (affects trainer only) self.restore_training_state(checkpoint) # call model hook model.on_hpc_load(checkpoint) log.info(f'restored hpc model from: {filepath}') def max_ckpt_in_folder(self, path, name_key='ckpt_'): files = os.listdir(path) files = [x for x in files if name_key in x] if len(files) == 0: return 0 ckpt_vs = [] for name in files: name = name.split(name_key)[-1] name = re.sub('[^0-9]', '', name) ckpt_vs.append(int(name)) return max(ckpt_vs)