lightning/pytorch_lightning/trainer/training_io.py

468 lines
15 KiB
Python

"""
Lightning can automate saving and loading checkpoints
=====================================================
Checkpointing is enabled by default to the current working directory.
To change the checkpoint path pass in::
Trainer(default_root_dir='/your/path/to/save/checkpoints')
To modify the behavior of checkpointing pass in your own callback.
.. code-block:: python
from pytorch_lightning.callbacks import ModelCheckpoint
# DEFAULTS used by the Trainer
checkpoint_callback = ModelCheckpoint(
filepath=os.getcwd(),
save_top_k=1,
verbose=True,
monitor='val_loss',
mode='min',
prefix=''
)
trainer = Trainer(checkpoint_callback=checkpoint_callback)
Restoring training session
--------------------------
You might want to not only load a model but also continue training it. Use this method to
restore the trainer state as well. This will continue from the epoch and global step you last left off.
However, the dataloaders will start from the first batch again (if you shuffled it shouldn't matter).
Lightning will restore the session if you pass a logger with the same version and there's a saved checkpoint.
.. code-block:: python
from pytorch_lightning import Trainer
trainer = Trainer(
resume_from_checkpoint=PATH
)
# this fit call loads model weights and trainer state
# the trainer continues seamlessly from where you left off
# without having to do anything else.
trainer.fit(model)
The trainer restores:
- global_step
- current_epoch
- All optimizers
- All lr_schedulers
- Model weights
You can even change the logic of your model as long as the weights and "architecture" of
the system aren't different. If you add a layer, for instance, it might not work.
At a rough level, here's what happens inside Trainer :py:mod:`pytorch_lightning.base_module.model_saving.py`:
.. code-block:: python
self.global_step = checkpoint['global_step']
self.current_epoch = checkpoint['epoch']
# restore the optimizers
optimizer_states = checkpoint['optimizer_states']
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
optimizer.load_state_dict(opt_state)
# restore the lr schedulers
lr_schedulers = checkpoint['lr_schedulers']
for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
scheduler['scheduler'].load_state_dict(lrs_state)
# uses the model you passed into trainer
model.load_state_dict(checkpoint['state_dict'])
"""
import os
import re
import signal
from abc import ABC
from argparse import Namespace
from subprocess import call
from typing import Union
import torch
import torch.distributed as torch_distrib
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import (
LightningDistributedDataParallel,
LightningDataParallel,
)
from pytorch_lightning.utilities import rank_zero_warn
# torch_xla is an optional dependency: assume it is present and flip the flag
# if any of its modules cannot be imported, so TPU-only code can be guarded.
XLA_AVAILABLE = True
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:
    XLA_AVAILABLE = False
class TrainerIOMixin(ABC):
# Mixin holding the Trainer's checkpoint save/restore logic and the SLURM
# signal handling used to save and requeue a job.
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
model: LightningModule  # may be wrapped in a DP/DDP wrapper; see get_model
on_gpu: bool  # gates CUDA cache clearing and moving the model to GPU on restore
root_gpu: ...  # passed to ``.cuda(...)``; ``None`` skips the device move
resume_from_checkpoint: ...  # checkpoint path handed to ``restore``; ``None`` disables it
use_ddp: bool  # when set, restore_weights waits on a distributed barrier
use_ddp2: bool  # same barrier behaviour as ``use_ddp``
checkpoint_callback: ...  # ``None``/``False`` mean disabled; its ``best`` is checkpointed
proc_rank: int  # only rank 0 writes checkpoints and handles SIGUSR1
weights_save_path: str  # folder scanned for / written with ``hpc_ckpt_*`` files
logger: Union[LightningLoggerBase, bool]  # ``save()``/``close()`` are called on the HPC path
early_stop_callback: ...  # its ``wait`` and ``patience`` are checkpointed
lr_schedulers: ...  # iterable of dicts, each with a ``'scheduler'`` entry
optimizers: ...  # iterable of objects with ``state_dict``/``load_state_dict``
on_tpu: bool  # with XLA available, restore_weights performs an XLA rendezvous
num_training_batches: int  # used for the mid-epoch resume warning
accumulate_grad_batches: int  # ``None`` is treated as 1 in the mid-epoch check
def get_model(self):
    """Return the bare model, unwrapping a DP/DDP wrapper if one is present.

    Returns:
        ``self.model.module`` when ``self.model`` is a
        ``LightningDistributedDataParallel`` or ``LightningDataParallel``
        instance, otherwise ``self.model`` itself.
    """
    wrapper_types = (LightningDistributedDataParallel, LightningDataParallel)
    if isinstance(self.model, wrapper_types):
        return self.model.module
    return self.model
# --------------------
# CHECK-POINTING
# --------------------
def restore_weights(self, model: LightningModule):
    """Restore weights and trainer state, then synchronise all processes.

    Sources are tried in this order:

    1. HPC weights found by ``restore_hpc_weights_if_needed``.
    2. If none were restored, the checkpoint at ``resume_from_checkpoint``.
    3. Otherwise nothing is restored.

    Args:
        model: Forwarded to ``restore_hpc_weights_if_needed``.
    """
    def _release_cuda_cache():
        # free cached CUDA memory around the (potentially large) restore
        if self.on_gpu:
            torch.cuda.empty_cache()

    _release_cuda_cache()

    # if the script was called from an hpc resubmit, load those weights
    restored_from_hpc = self.restore_hpc_weights_if_needed(model)

    _release_cuda_cache()

    if not restored_from_hpc and self.resume_from_checkpoint is not None:
        self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu)

    # wait for all processes to finish restoring before continuing
    if self.use_ddp or self.use_ddp2:
        torch_distrib.barrier()

    if self.on_tpu and XLA_AVAILABLE:
        torch_xla.core.xla_model.rendezvous("pl.TrainerIOMixin.restore_weights")

    _release_cuda_cache()
# --------------------
# HPC SIGNAL HANDLING
# --------------------
def register_slurm_signal_handlers(self):
    """Install the SIGUSR1/SIGTERM handlers when running as a SLURM batch job.

    The job counts as a SLURM job when the ``SLURM_JOB_NAME`` environment
    variable is set; a job named ``'bash'`` is treated as an interactive
    session and gets no handlers.
    """
    # a plain lookup with a default replaces the former catch-all
    # ``except Exception: pass`` around ``os.environ[...]``
    job_name = os.environ.get('SLURM_JOB_NAME')
    on_slurm = job_name is not None and job_name != 'bash'

    if on_slurm:
        log.info('Set SLURM handle signals.')
        signal.signal(signal.SIGUSR1, self.sig_handler)
        signal.signal(signal.SIGTERM, self.term_handler)
def sig_handler(self, signum, frame):  # pragma: no-cover
    """SIGUSR1 handler: save an HPC checkpoint and ask SLURM to requeue the job.

    Only the process with ``proc_rank == 0`` acts; every other rank returns
    immediately.

    Args:
        signum: Signal number (unused).
        frame: Current stack frame (unused).
    """
    if self.proc_rank != 0:
        return

    # persist weights before the job is requeued
    log.info('handling SIGUSR1')
    self.hpc_save(self.weights_save_path, self.logger)

    # build the requeue command for the current job
    job_id = os.environ['SLURM_JOB_ID']
    cmd = f'scontrol requeue {job_id}'

    log.info(f'requeing job {job_id}...')
    exit_code = call(cmd, shell=True)

    # report the outcome of the requeue request
    if exit_code == 0:
        log.info(f'requeued exp {job_id}')
    else:
        log.info('requeue failed...')

    # close experiment to avoid issues
    self.logger.close()
def term_handler(self, signum, frame):
    """SIGTERM handler that only logs a message; it takes no other action.

    Args:
        signum: Signal number (unused).
        frame: Current stack frame (unused).
    """
    log.info("bypassing sigterm")
# --------------------
# MODEL SAVE CHECKPOINT
# --------------------
def _atomic_save(self, checkpoint, filepath: str):
"""Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
This will create a temporary checkpoint with a suffix of ``.part``, then copy it to the final location once
saving is finished.
Args:
checkpoint: The object to save.
Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save``
accepts.
filepath: The path to which the checkpoint will be saved.
This points to the file that the checkpoint will be stored in.
"""
tmp_path = str(filepath) + ".part"
torch.save(checkpoint, tmp_path)
os.replace(tmp_path, filepath)
def save_checkpoint(self, filepath):
    """Build a checkpoint and write it to ``filepath`` from rank 0 only.

    The checkpoint dict is built on every process, but only the process
    with ``proc_rank == 0`` writes it. If the write raises
    ``AttributeError``, the ``'hparams'`` entry is dropped and the save is
    retried once.

    Args:
        filepath: Destination path of the checkpoint file.
    """
    state = self.dump_checkpoint()

    if self.proc_rank != 0:
        return

    try:
        self._atomic_save(state, filepath)
    except AttributeError:
        # retry without hyperparameters (presumably they failed to serialise)
        state.pop('hparams', None)
        self._atomic_save(state, filepath)
def restore(self, checkpoint_path: str, on_gpu: bool):
    """Load model weights and trainer state from a checkpoint file.

    Besides the model ``state_dict``, ``restore_training_state`` restores
    the epoch, callbacks, schedulers and optimizers.

    Args:
        checkpoint_path: Path of the checkpoint to load.
        on_gpu: When true, the model is moved to ``self.root_gpu`` after
            its weights are loaded.
    """
    def _keep_storage(storage, location):
        # leave each storage as deserialised, i.e. load on CPU first
        return storage

    state = torch.load(checkpoint_path, map_location=_keep_storage)

    # load the weights into the (unwrapped) model
    net = self.get_model()
    net.load_state_dict(state['state_dict'])
    if on_gpu:
        net.cuda(self.root_gpu)

    # load training state (affects trainer only)
    self.restore_training_state(state)
def dump_checkpoint(self):
    """Collect trainer and model state into a checkpoint dictionary.

    The dictionary contains:

    - ``'epoch'`` and ``'global_step'`` (the current values plus one),
    - ``'checkpoint_callback_best'`` when a checkpoint callback is set,
    - ``'early_stop_callback_wait'`` / ``'early_stop_callback_patience'``
      when an early-stop callback is set,
    - ``'optimizer_states'`` and ``'lr_schedulers'`` state dicts,
    - ``'state_dict'`` of the unwrapped model,
    - ``'hparams'`` / ``'hparams_type'`` when the model has ``hparams``.

    The model's ``on_save_checkpoint`` hook is called last so it can add
    or modify entries.

    Returns:
        The checkpoint dictionary.
    """
    checkpoint = {
        'epoch': self.current_epoch + 1,
        'global_step': self.global_step + 1,
    }

    if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
        checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best

    # guard on the early-stop callback itself: it may be ``False`` (disabled)
    # independently of the checkpoint callback
    if self.early_stop_callback is not None and self.early_stop_callback is not False:
        checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait
        checkpoint['early_stop_callback_patience'] = self.early_stop_callback.patience

    # save optimizers
    checkpoint['optimizer_states'] = [optimizer.state_dict() for optimizer in self.optimizers]

    # save lr schedulers
    checkpoint['lr_schedulers'] = [
        scheduler['scheduler'].state_dict() for scheduler in self.lr_schedulers
    ]

    # add the hparams and state_dict from the model
    model = self.get_model()
    checkpoint['state_dict'] = model.state_dict()

    if hasattr(model, "hparams"):
        # Namespace hparams are stored as a plain dict plus a type marker
        is_namespace = isinstance(model.hparams, Namespace)
        checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams
        checkpoint['hparams_type'] = 'namespace' if is_namespace else 'dict'
    else:
        rank_zero_warn(
            "Did not find hyperparameters at model hparams. Saving checkpoint without hyperparameters."
        )

    # give the model a chance to add a few things
    model.on_save_checkpoint(checkpoint)

    return checkpoint
# --------------------
# HPC IO
# --------------------
def restore_hpc_weights_if_needed(self, model: LightningModule):
    """Load HPC weights when ``weights_save_path`` contains an HPC checkpoint.

    Any directory entry whose name contains ``'hpc_ckpt'`` counts as an HPC
    checkpoint and triggers ``hpc_load`` on the folder.

    Args:
        model: Not used by this method.

    Returns:
        ``True`` if HPC weights were loaded, ``False`` otherwise.
    """
    folder = self.weights_save_path
    if not os.path.exists(folder):
        return False

    has_hpc_ckpt = any('hpc_ckpt' in entry for entry in os.listdir(folder))
    if not has_hpc_ckpt:
        return False

    # hpc weights exist, so restore the model from them
    self.hpc_load(folder, self.on_gpu)
    return True
def restore_training_state(self, checkpoint):
    """Restore trainer state (callbacks, counters, optimizers, schedulers).

    Callback entries are restored only when both the callback is enabled
    and the checkpoint contains the corresponding key, so a checkpoint
    written without callbacks can still be resumed with callbacks enabled.

    Args:
        checkpoint: Checkpoint dictionary. Must contain ``'global_step'``,
            ``'epoch'``, ``'optimizer_states'`` and ``'lr_schedulers'``.
    """
    ckpt_callback = self.checkpoint_callback
    if ckpt_callback is not None and ckpt_callback is not False:
        if 'checkpoint_callback_best' in checkpoint:
            ckpt_callback.best = checkpoint['checkpoint_callback_best']

    early_stop = self.early_stop_callback
    if early_stop is not None and early_stop is not False:
        if 'early_stop_callback_wait' in checkpoint:
            early_stop.wait = checkpoint['early_stop_callback_wait']
        if 'early_stop_callback_patience' in checkpoint:
            early_stop.patience = checkpoint['early_stop_callback_patience']

    self.global_step = checkpoint['global_step']
    self.current_epoch = checkpoint['epoch']

    # Division deals with global step stepping once per accumulated batch
    # Inequality deals with different global step for odd vs even num_training_batches
    n_accum = 1 if self.accumulate_grad_batches is None else self.accumulate_grad_batches
    expected_steps = self.num_training_batches / n_accum
    # the ``!= 0`` test must come first: it prevents a modulo by zero
    if self.num_training_batches != 0 and self.global_step % expected_steps > 1:
        rank_zero_warn(
            "You're resuming from a checkpoint that ended mid-epoch. "
            "This can cause unreliable results if further training is done, "
            "consider using an end of epoch checkpoint. "
        )

    # restore the optimizers
    optimizer_states = checkpoint['optimizer_states']
    for optimizer, opt_state in zip(self.optimizers, optimizer_states):
        optimizer.load_state_dict(opt_state)

        # move optimizer to GPU 1 weight at a time
        # avoids OOM
        if self.root_gpu is not None:
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda(self.root_gpu)

    # restore the lr schedulers
    lr_schedulers = checkpoint['lr_schedulers']
    for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
        scheduler['scheduler'].load_state_dict(lrs_state)
# ----------------------------------
# PRIVATE OPS
# ----------------------------------
def hpc_save(self, folderpath: str, logger):
    """Write an HPC checkpoint into ``folderpath`` and return its path.

    The file is named ``hpc_ckpt_<n>.ckpt`` where ``<n>`` is one more than
    the number reported by ``max_ckpt_in_folder``. If the write raises
    ``AttributeError``, the ``'hparams'`` entry is dropped and the save is
    retried once.

    Args:
        folderpath: Directory for the checkpoint; created if missing.
        logger: Logger whose ``save()`` is called before checkpointing.

    Returns:
        Path of the checkpoint file that was written.
    """
    # make sure the checkpoint folder exists (one call is enough;
    # ``exist_ok=True`` makes it a no-op when the folder is already there)
    os.makedirs(folderpath, exist_ok=True)

    # save logger to make sure we get all the metrics
    logger.save()

    ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
    filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

    # give model a chance to do something on hpc_save
    model = self.get_model()
    checkpoint = self.dump_checkpoint()
    model.on_hpc_save(checkpoint)

    # do the actual save
    # TODO: fix for anything with multiprocess DP, DDP, DDP2
    try:
        self._atomic_save(checkpoint, filepath)
    except AttributeError:
        # retry without hyperparameters (presumably they failed to serialise)
        checkpoint.pop('hparams', None)
        self._atomic_save(checkpoint, filepath)

    return filepath
def hpc_load(self, folderpath, on_gpu):
    """Load the highest-numbered HPC checkpoint found in ``folderpath``.

    Restores the model weights and the trainer state, then calls the
    model's ``on_hpc_load`` hook with the loaded checkpoint.

    Args:
        folderpath: Directory containing ``hpc_ckpt_<n>.ckpt`` files.
        on_gpu: Not consulted here; the model is moved to GPU whenever
            ``self.root_gpu`` is not ``None``.
    """
    latest = self.max_ckpt_in_folder(folderpath)
    filepath = f'{folderpath}/hpc_ckpt_{latest}.ckpt'

    def _keep_storage(storage, location):
        # leave each storage as deserialised, i.e. load on CPU first
        return storage

    state = torch.load(filepath, map_location=_keep_storage)

    # load the weights into the (unwrapped) model
    net = self.get_model()
    net.load_state_dict(state['state_dict'])

    if self.root_gpu is not None:
        net.cuda(self.root_gpu)

    # load training state (affects trainer only)
    self.restore_training_state(state)

    # call model hook
    net.on_hpc_load(state)

    log.info(f'restored hpc model from: {filepath}')
def max_ckpt_in_folder(self, path, name_key='ckpt_'):
    """Return the highest checkpoint number among files in ``path``.

    Only file names containing ``name_key`` are considered. The number is
    built from every digit that follows the last occurrence of
    ``name_key`` in the name (``'hpc_ckpt_12.ckpt'`` -> ``12``). Matching
    names with no digits after the key (``'ckpt_best.pt'``) are ignored
    rather than raising ``ValueError``.

    Args:
        path: Directory to scan.
        name_key: Substring that marks a file as a checkpoint.

    Returns:
        The largest number found, or ``0`` when no numbered checkpoint exists.
    """
    versions = []
    for filename in os.listdir(path):
        if name_key not in filename:
            continue
        digits = re.sub(r'[^0-9]', '', filename.split(name_key)[-1])
        if digits:
            versions.append(int(digits))
    return max(versions, default=0)