update checkpoint docs (#1016)
* update checkpoint docs * fix tests * fix tests * formatting * typing * filename * fix tests * fixing tests * fixing tests * fixing tests * unique name * fixing * fixing * Update model_checkpoint.py Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
parent
d1c0f1270d
commit
64de57b09e
|
@ -6,7 +6,7 @@ Save the model as often as requested.
|
|||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import glob
|
||||
import logging as log
|
||||
import warnings
|
||||
|
||||
|
@ -20,17 +20,19 @@ class ModelCheckpoint(Callback):
|
|||
Save the model after every epoch.
|
||||
|
||||
Args:
|
||||
filepath: path to save the model file.
|
||||
dirpath: path to save the model file.
|
||||
Can contain named formatting options to be auto-filled.
|
||||
|
||||
Example::
|
||||
|
||||
# save epoch and val_loss in name
|
||||
ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
|
||||
# saves file like: /path/epoch_2-val_loss_0.2.hdf5
|
||||
monitor (str): quantity to monitor.
|
||||
verbose (bool): verbosity mode, False or True.
|
||||
save_top_k (int): if `save_top_k == k`,
|
||||
# saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
|
||||
# if such model already exits, the file will be: /my/path/here/sample-mnist-v0_epoch=02_val_loss=0.32.ckpt
|
||||
|
||||
monitor: quantity to monitor.
|
||||
verbose: verbosity mode, False or True.
|
||||
save_top_k: if `save_top_k == k`,
|
||||
the best k models according to
|
||||
the quantity monitored will be saved.
|
||||
if ``save_top_k == 0``, no models are saved.
|
||||
|
@ -39,7 +41,7 @@ class ModelCheckpoint(Callback):
|
|||
if ``save_top_k >= 2`` and the callback is called multiple
|
||||
times inside an epoch, the name of the saved file will be
|
||||
appended with a version count starting with `v0`.
|
||||
mode (str): one of {auto, min, max}.
|
||||
mode: one of {auto, min, max}.
|
||||
If ``save_top_k != 0``, the decision
|
||||
to overwrite the current save file is made
|
||||
based on either the maximization or the
|
||||
|
@ -47,35 +49,46 @@ class ModelCheckpoint(Callback):
|
|||
this should be `max`, for `val_loss` this should
|
||||
be `min`, etc. In `auto` mode, the direction is
|
||||
automatically inferred from the name of the monitored quantity.
|
||||
save_weights_only (bool): if True, then only the model's weights will be
|
||||
save_weights_only: if True, then only the model's weights will be
|
||||
saved (`model.save_weights(filepath)`), else the full model
|
||||
is saved (`model.save(filepath)`).
|
||||
period (int): Interval (number of epochs) between checkpoints.
|
||||
period: Interval (number of epochs) between checkpoints.
|
||||
prefix: String name for particular model
|
||||
|
||||
Example::
|
||||
Example:
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
|
||||
# saves checkpoints to my_path whenever 'val_loss' has a new min
|
||||
checkpoint_callback = ModelCheckpoint(filepath='my_path')
|
||||
checkpoint_callback = ModelCheckpoint('my_path')
|
||||
Trainer(checkpoint_callback=checkpoint_callback)
|
||||
"""
|
||||
#: checkpoint extension
|
||||
EXTENSION = '.ckpt'
|
||||
|
||||
def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
|
||||
save_top_k: int = 1, save_weights_only: bool = False,
|
||||
mode: str = 'auto', period: int = 1, prefix: str = ''):
|
||||
def __init__(
|
||||
self,
|
||||
dirpath: str,
|
||||
monitor: str = 'val_loss',
|
||||
verbose: bool = False,
|
||||
save_top_k: int = 1,
|
||||
save_weights_only: bool = False,
|
||||
mode: str = 'auto',
|
||||
period: int = 1,
|
||||
prefix: str = ''
|
||||
):
|
||||
super().__init__()
|
||||
if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
|
||||
if save_top_k and os.path.isdir(dirpath) and len(os.listdir(dirpath)) > 0:
|
||||
warnings.warn(
|
||||
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
|
||||
f"Checkpoint directory {dirpath} exists and is not empty with save_top_k != 0."
|
||||
"All files in this directory will be deleted when a checkpoint is saved!"
|
||||
)
|
||||
|
||||
self.monitor = monitor
|
||||
self.verbose = verbose
|
||||
self.filepath = filepath
|
||||
os.makedirs(filepath, exist_ok=True)
|
||||
self.dirpath = dirpath
|
||||
os.makedirs(dirpath, exist_ok=True)
|
||||
self.save_top_k = save_top_k
|
||||
self.save_weights_only = save_weights_only
|
||||
self.period = period
|
||||
|
@ -87,6 +100,14 @@ class ModelCheckpoint(Callback):
|
|||
self.best = 0
|
||||
self.save_function = None
|
||||
|
||||
# this create unique prefix if the give already exists
|
||||
existing_checkpoints = sorted(glob.glob(os.path.join(self.dirpath, '*' + self.EXTENSION)))
|
||||
existing_names = set(os.path.basename(ckpt).split('_epoch=')[0] for ckpt in existing_checkpoints)
|
||||
version_cnt = 0
|
||||
while self.prefix in existing_names:
|
||||
self.prefix = f'{prefix}-v{version_cnt}'
|
||||
version_cnt += 1
|
||||
|
||||
mode_dict = {
|
||||
'min': (np.less, np.Inf, 'min'),
|
||||
'max': (np.greater, -np.Inf, 'max'),
|
||||
|
@ -102,15 +123,13 @@ class ModelCheckpoint(Callback):
|
|||
|
||||
self.monitor_op, self.kth_value, self.mode = mode_dict[mode]
|
||||
|
||||
def _del_model(self, filepath):
|
||||
try:
|
||||
shutil.rmtree(filepath)
|
||||
except OSError:
|
||||
os.remove(filepath)
|
||||
def _del_model(self, filepath: str) -> None:
|
||||
# shutil.rmtree(filepath)
|
||||
os.remove(filepath)
|
||||
|
||||
def _save_model(self, filepath):
|
||||
def _save_model(self, filepath: str) -> None:
|
||||
# make paths
|
||||
os.makedirs(os.path.dirname(filepath), exist_ok=True)
|
||||
os.makedirs(self.dirpath, exist_ok=True)
|
||||
|
||||
# delegate the saving to the model
|
||||
if self.save_function is not None:
|
||||
|
@ -118,13 +137,20 @@ class ModelCheckpoint(Callback):
|
|||
else:
|
||||
raise ValueError(".save_function() not set")
|
||||
|
||||
def check_monitor_top_k(self, current):
|
||||
def check_monitor_top_k(self, current: float) -> bool:
|
||||
less_than_k_models = len(self.best_k_models) < self.save_top_k
|
||||
if less_than_k_models:
|
||||
return True
|
||||
return self.monitor_op(current, self.best_k_models[self.kth_best_model])
|
||||
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
def _get_available_filepath(self, current: float, epoch: int) -> str:
|
||||
current_str = f'{current:.2f}' if current else 'NaN'
|
||||
fname = f'{self.prefix}_epoch={epoch}_{self.monitor}={current_str}'
|
||||
filepath = os.path.join(self.dirpath, fname + self.EXTENSION)
|
||||
assert not os.path.isfile(filepath)
|
||||
return filepath
|
||||
|
||||
def on_validation_end(self, trainer, pl_module) -> None:
|
||||
# only run on main process
|
||||
if trainer.proc_rank != 0:
|
||||
return
|
||||
|
@ -138,35 +164,27 @@ class ModelCheckpoint(Callback):
|
|||
return
|
||||
if self.epochs_since_last_check >= self.period:
|
||||
self.epochs_since_last_check = 0
|
||||
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
|
||||
version_cnt = 0
|
||||
while os.path.isfile(filepath):
|
||||
# this epoch called before
|
||||
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
|
||||
version_cnt += 1
|
||||
current = logs.get(self.monitor)
|
||||
filepath = self._get_available_filepath(current, epoch)
|
||||
|
||||
if self.save_top_k != -1:
|
||||
current = logs.get(self.monitor)
|
||||
|
||||
if current is None:
|
||||
warnings.warn(
|
||||
f'Can save best model only with {self.monitor} available,'
|
||||
' skipping.', RuntimeWarning)
|
||||
warnings.warn(f'Can save best model only with {self.monitor} available,'
|
||||
' skipping.', RuntimeWarning)
|
||||
else:
|
||||
if self.check_monitor_top_k(current):
|
||||
self._do_check_save(filepath, current, epoch)
|
||||
else:
|
||||
if self.verbose > 0:
|
||||
log.info(
|
||||
f'\nEpoch {epoch:05d}: {self.monitor}'
|
||||
f' was not in top {self.save_top_k}')
|
||||
log.info('Epoch %05d: %s was not in top %i', epoch, self.monitor, self.save_top_k)
|
||||
|
||||
else:
|
||||
if self.verbose > 0:
|
||||
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
|
||||
log.info('Epoch %05d: saving model to %s', epoch, filepath)
|
||||
self._save_model(filepath)
|
||||
|
||||
def _do_check_save(self, filepath, current, epoch):
|
||||
def _do_check_save(self, filepath: str, current: float, epoch: int) -> None:
|
||||
# remove kth
|
||||
if len(self.best_k_models) == self.save_top_k:
|
||||
delpath = self.kth_best_model
|
||||
|
@ -185,8 +203,6 @@ class ModelCheckpoint(Callback):
|
|||
self.best = _op(self.best_k_models.values())
|
||||
|
||||
if self.verbose > 0:
|
||||
log.info(
|
||||
f'\nEpoch {epoch:05d}: {self.monitor} reached'
|
||||
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
|
||||
f' {filepath} as top {self.save_top_k}')
|
||||
log.info('Epoch {epoch:05d}: %s reached %0.5f (best %0.5f), saving model to %s as top %i',
|
||||
epoch, self.monitor, current, self.best, filepath, self.save_top_k)
|
||||
self._save_model(filepath)
|
||||
|
|
|
@ -50,7 +50,7 @@ class TrainerCallbackConfigMixin(ABC):
|
|||
|
||||
self.ckpt_path = ckpt_path
|
||||
self.checkpoint_callback = ModelCheckpoint(
|
||||
filepath=ckpt_path
|
||||
dirpath=ckpt_path
|
||||
)
|
||||
elif self.checkpoint_callback is False:
|
||||
self.checkpoint_callback = None
|
||||
|
@ -62,7 +62,7 @@ class TrainerCallbackConfigMixin(ABC):
|
|||
self.checkpoint_callback.save_function = self.save_checkpoint
|
||||
|
||||
# if checkpoint callback used, then override the weights path
|
||||
self.weights_save_path = self.checkpoint_callback.filepath
|
||||
self.weights_save_path = self.checkpoint_callback.dirpath
|
||||
|
||||
# if weights_save_path is still none here, set to current working dir
|
||||
if self.weights_save_path is None:
|
||||
|
|
|
@ -810,7 +810,7 @@ class Trainer(TrainerIOMixin,
|
|||
self.amp_level = amp_level
|
||||
self.precision = precision
|
||||
|
||||
assert self.precision == 32 or self.precision == 16, 'only 32 or 16 bit precision supported'
|
||||
assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
|
||||
|
||||
if self.precision == 16 and num_tpu_cores is None:
|
||||
use_amp = True
|
||||
|
|
|
@ -32,7 +32,7 @@ def run_model_test_no_loggers(trainer_options, model, min_acc=0.50):
|
|||
|
||||
# test model loading
|
||||
pretrained_model = load_model(trainer.logger,
|
||||
trainer.checkpoint_callback.filepath,
|
||||
trainer.checkpoint_callback.dirpath,
|
||||
path_expt=trainer_options.get('default_save_path'))
|
||||
|
||||
# test new model accuracy
|
||||
|
@ -70,7 +70,7 @@ def run_model_test(trainer_options, model, on_gpu=True):
|
|||
assert result == 1, 'amp + ddp model failed to complete'
|
||||
|
||||
# test model loading
|
||||
pretrained_model = load_model(logger, trainer.checkpoint_callback.filepath)
|
||||
pretrained_model = load_model(logger, trainer.checkpoint_callback.dirpath)
|
||||
|
||||
# test new model accuracy
|
||||
test_loaders = model.test_dataloader()
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import glob
|
||||
import logging as log
|
||||
import os
|
||||
|
||||
|
@ -52,7 +53,7 @@ def test_running_test_pretrained_model_ddp(tmpdir):
|
|||
# correct result and ok accuracy
|
||||
assert result == 1, 'training failed to complete'
|
||||
pretrained_model = tutils.load_model(logger,
|
||||
trainer.checkpoint_callback.filepath,
|
||||
trainer.checkpoint_callback.dirpath,
|
||||
module_class=LightningTestModel)
|
||||
|
||||
# run test set
|
||||
|
@ -96,7 +97,7 @@ def test_running_test_pretrained_model(tmpdir):
|
|||
# correct result and ok accuracy
|
||||
assert result == 1, 'training failed to complete'
|
||||
pretrained_model = tutils.load_model(
|
||||
logger, trainer.checkpoint_callback.filepath, module_class=LightningTestModel
|
||||
logger, trainer.checkpoint_callback.dirpath, module_class=LightningTestModel
|
||||
)
|
||||
|
||||
new_trainer = Trainer(**trainer_options)
|
||||
|
@ -132,9 +133,7 @@ def test_load_model_from_checkpoint(tmpdir):
|
|||
assert result == 1, 'training failed to complete'
|
||||
|
||||
# load last checkpoint
|
||||
last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt")
|
||||
if not os.path.isfile(last_checkpoint):
|
||||
last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt")
|
||||
last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
|
||||
pretrained_model = LightningTestModel.load_from_checkpoint(last_checkpoint)
|
||||
|
||||
# test that hparams loaded correctly
|
||||
|
@ -186,7 +185,7 @@ def test_running_test_pretrained_model_dp(tmpdir):
|
|||
# correct result and ok accuracy
|
||||
assert result == 1, 'training failed to complete'
|
||||
pretrained_model = tutils.load_model(logger,
|
||||
trainer.checkpoint_callback.filepath,
|
||||
trainer.checkpoint_callback.dirpath,
|
||||
module_class=LightningTestModel)
|
||||
|
||||
new_trainer = Trainer(**trainer_options)
|
||||
|
@ -346,7 +345,7 @@ def test_load_model_with_missing_hparams(tmpdir):
|
|||
|
||||
model = LightningTestModelWithoutHyperparametersArg()
|
||||
trainer.fit(model)
|
||||
last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt")
|
||||
last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
|
||||
|
||||
# try to load a checkpoint that has hparams but model is missing hparams arg
|
||||
with pytest.raises(MisconfigurationException, match=r".*__init__ is missing the argument 'hparams'.*"):
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import glob
|
||||
import math
|
||||
import os
|
||||
import pytest
|
||||
|
@ -257,8 +258,12 @@ def test_model_checkpoint_options(tmp_path):
|
|||
assert len(file_lists) == len(losses), "Should save all models when save_top_k=-1"
|
||||
|
||||
# verify correct naming
|
||||
for i in range(0, len(losses)):
|
||||
assert f"_ckpt_epoch_{i}.ckpt" in file_lists
|
||||
for fname in {'_epoch=4_val_loss=2.50.ckpt',
|
||||
'_epoch=3_val_loss=5.00.ckpt',
|
||||
'_epoch=2_val_loss=2.80.ckpt',
|
||||
'_epoch=1_val_loss=9.00.ckpt',
|
||||
'_epoch=0_val_loss=10.00.ckpt'}:
|
||||
assert fname in file_lists
|
||||
|
||||
save_dir = tmp_path / "2"
|
||||
save_dir.mkdir()
|
||||
|
@ -297,7 +302,7 @@ def test_model_checkpoint_options(tmp_path):
|
|||
file_lists = set(os.listdir(save_dir))
|
||||
|
||||
assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
|
||||
assert 'test_prefix_ckpt_epoch_4.ckpt' in file_lists
|
||||
assert 'test_prefix_epoch=4_val_loss=2.50.ckpt' in file_lists
|
||||
|
||||
save_dir = tmp_path / "4"
|
||||
save_dir.mkdir()
|
||||
|
@ -320,9 +325,10 @@ def test_model_checkpoint_options(tmp_path):
|
|||
file_lists = set(os.listdir(save_dir))
|
||||
|
||||
assert len(file_lists) == 3, 'Should save 2 model when save_top_k=2'
|
||||
assert '_ckpt_epoch_4.ckpt' in file_lists
|
||||
assert '_ckpt_epoch_2.ckpt' in file_lists
|
||||
assert 'other_file.ckpt' in file_lists
|
||||
for fname in {'_epoch=4_val_loss=2.50.ckpt',
|
||||
'_epoch=2_val_loss=2.80.ckpt',
|
||||
'other_file.ckpt'}:
|
||||
assert fname in file_lists
|
||||
|
||||
save_dir = tmp_path / "5"
|
||||
save_dir.mkdir()
|
||||
|
@ -365,9 +371,10 @@ def test_model_checkpoint_options(tmp_path):
|
|||
file_lists = set(os.listdir(save_dir))
|
||||
|
||||
assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3'
|
||||
assert '_ckpt_epoch_0_v2.ckpt' in file_lists
|
||||
assert '_ckpt_epoch_0_v1.ckpt' in file_lists
|
||||
assert '_ckpt_epoch_0.ckpt' in file_lists
|
||||
for fname in {'_epoch=0_val_loss=2.80.ckpt',
|
||||
'_epoch=0_val_loss=2.50.ckpt',
|
||||
'_epoch=0_val_loss=5.00.ckpt'}:
|
||||
assert fname in file_lists
|
||||
|
||||
|
||||
def test_model_freeze_unfreeze():
|
||||
|
@ -388,7 +395,7 @@ def test_resume_from_checkpoint_epoch_restored(tmpdir):
|
|||
|
||||
hparams = tutils.get_hparams()
|
||||
|
||||
def new_model():
|
||||
def _new_model():
|
||||
# Create a model that tracks epochs and batches seen
|
||||
model = LightningTestModel(hparams)
|
||||
model.num_epochs_seen = 0
|
||||
|
@ -406,7 +413,7 @@ def test_resume_from_checkpoint_epoch_restored(tmpdir):
|
|||
model.on_batch_start = types.MethodType(increment_batch, model)
|
||||
return model
|
||||
|
||||
model = new_model()
|
||||
model = _new_model()
|
||||
|
||||
trainer_options = dict(
|
||||
show_progress_bar=False,
|
||||
|
@ -417,7 +424,7 @@ def test_resume_from_checkpoint_epoch_restored(tmpdir):
|
|||
logger=False,
|
||||
default_save_path=tmpdir,
|
||||
early_stop_callback=False,
|
||||
val_check_interval=0.5,
|
||||
val_check_interval=1.,
|
||||
)
|
||||
|
||||
# fit model
|
||||
|
@ -430,15 +437,10 @@ def test_resume_from_checkpoint_epoch_restored(tmpdir):
|
|||
assert model.num_batches_seen == training_batches * 2
|
||||
|
||||
# Other checkpoints can be uncommented if/when resuming mid-epoch is supported
|
||||
checkpoints = [
|
||||
# os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt"),
|
||||
os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0_v0.ckpt"),
|
||||
# os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt"),
|
||||
os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1_v0.ckpt"),
|
||||
]
|
||||
checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt')))
|
||||
|
||||
for check in checkpoints:
|
||||
next_model = new_model()
|
||||
next_model = _new_model()
|
||||
state = torch.load(check)
|
||||
|
||||
# Resume training
|
||||
|
|
Loading…
Reference in New Issue