update checkpoint docs (#1016)

* update checkpoint docs

* fix tests

* fix tests

* formatting

* typing

* filename

* fix tests

* fixing tests

* fixing tests

* fixing tests

* unique name

* fixing

* fixing

* Update model_checkpoint.py

Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
Jirka Borovec 2020-03-03 21:16:57 +01:00 committed by GitHub
parent d1c0f1270d
commit 64de57b09e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 94 additions and 77 deletions

View File

@ -6,7 +6,7 @@ Save the model as often as requested.
"""
import os
import shutil
import glob
import logging as log
import warnings
@ -20,17 +20,19 @@ class ModelCheckpoint(Callback):
Save the model after every epoch.
Args:
filepath: path to save the model file.
dirpath: path to save the model file.
Can contain named formatting options to be auto-filled.
Example::
# save epoch and val_loss in name
ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
# saves file like: /path/epoch_2-val_loss_0.2.hdf5
monitor (str): quantity to monitor.
verbose (bool): verbosity mode, False or True.
save_top_k (int): if `save_top_k == k`,
# saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
# if such model already exits, the file will be: /my/path/here/sample-mnist-v0_epoch=02_val_loss=0.32.ckpt
monitor: quantity to monitor.
verbose: verbosity mode, False or True.
save_top_k: if `save_top_k == k`,
the best k models according to
the quantity monitored will be saved.
if ``save_top_k == 0``, no models are saved.
@ -39,7 +41,7 @@ class ModelCheckpoint(Callback):
if ``save_top_k >= 2`` and the callback is called multiple
times inside an epoch, the name of the saved file will be
appended with a version count starting with `v0`.
mode (str): one of {auto, min, max}.
mode: one of {auto, min, max}.
If ``save_top_k != 0``, the decision
to overwrite the current save file is made
based on either the maximization or the
@ -47,35 +49,46 @@ class ModelCheckpoint(Callback):
this should be `max`, for `val_loss` this should
be `min`, etc. In `auto` mode, the direction is
automatically inferred from the name of the monitored quantity.
save_weights_only (bool): if True, then only the model's weights will be
save_weights_only: if True, then only the model's weights will be
saved (`model.save_weights(filepath)`), else the full model
is saved (`model.save(filepath)`).
period (int): Interval (number of epochs) between checkpoints.
period: Interval (number of epochs) between checkpoints.
prefix: String name for particular model
Example::
Example:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
# saves checkpoints to my_path whenever 'val_loss' has a new min
checkpoint_callback = ModelCheckpoint(filepath='my_path')
checkpoint_callback = ModelCheckpoint('my_path')
Trainer(checkpoint_callback=checkpoint_callback)
"""
#: checkpoint extension
EXTENSION = '.ckpt'
def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
save_top_k: int = 1, save_weights_only: bool = False,
mode: str = 'auto', period: int = 1, prefix: str = ''):
def __init__(
self,
dirpath: str,
monitor: str = 'val_loss',
verbose: bool = False,
save_top_k: int = 1,
save_weights_only: bool = False,
mode: str = 'auto',
period: int = 1,
prefix: str = ''
):
super().__init__()
if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
if save_top_k and os.path.isdir(dirpath) and len(os.listdir(dirpath)) > 0:
warnings.warn(
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
f"Checkpoint directory {dirpath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
)
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
os.makedirs(filepath, exist_ok=True)
self.dirpath = dirpath
os.makedirs(dirpath, exist_ok=True)
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
@ -87,6 +100,14 @@ class ModelCheckpoint(Callback):
self.best = 0
self.save_function = None
# this create unique prefix if the give already exists
existing_checkpoints = sorted(glob.glob(os.path.join(self.dirpath, '*' + self.EXTENSION)))
existing_names = set(os.path.basename(ckpt).split('_epoch=')[0] for ckpt in existing_checkpoints)
version_cnt = 0
while self.prefix in existing_names:
self.prefix = f'{prefix}-v{version_cnt}'
version_cnt += 1
mode_dict = {
'min': (np.less, np.Inf, 'min'),
'max': (np.greater, -np.Inf, 'max'),
@ -102,15 +123,13 @@ class ModelCheckpoint(Callback):
self.monitor_op, self.kth_value, self.mode = mode_dict[mode]
def _del_model(self, filepath):
try:
shutil.rmtree(filepath)
except OSError:
os.remove(filepath)
def _del_model(self, filepath: str) -> None:
# shutil.rmtree(filepath)
os.remove(filepath)
def _save_model(self, filepath):
def _save_model(self, filepath: str) -> None:
# make paths
os.makedirs(os.path.dirname(filepath), exist_ok=True)
os.makedirs(self.dirpath, exist_ok=True)
# delegate the saving to the model
if self.save_function is not None:
@ -118,13 +137,20 @@ class ModelCheckpoint(Callback):
else:
raise ValueError(".save_function() not set")
def check_monitor_top_k(self, current):
def check_monitor_top_k(self, current: float) -> bool:
less_than_k_models = len(self.best_k_models) < self.save_top_k
if less_than_k_models:
return True
return self.monitor_op(current, self.best_k_models[self.kth_best_model])
def on_validation_end(self, trainer, pl_module):
def _get_available_filepath(self, current: float, epoch: int) -> str:
current_str = f'{current:.2f}' if current else 'NaN'
fname = f'{self.prefix}_epoch={epoch}_{self.monitor}={current_str}'
filepath = os.path.join(self.dirpath, fname + self.EXTENSION)
assert not os.path.isfile(filepath)
return filepath
def on_validation_end(self, trainer, pl_module) -> None:
# only run on main process
if trainer.proc_rank != 0:
return
@ -138,35 +164,27 @@ class ModelCheckpoint(Callback):
return
if self.epochs_since_last_check >= self.period:
self.epochs_since_last_check = 0
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
version_cnt = 0
while os.path.isfile(filepath):
# this epoch called before
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
version_cnt += 1
current = logs.get(self.monitor)
filepath = self._get_available_filepath(current, epoch)
if self.save_top_k != -1:
current = logs.get(self.monitor)
if current is None:
warnings.warn(
f'Can save best model only with {self.monitor} available,'
' skipping.', RuntimeWarning)
warnings.warn(f'Can save best model only with {self.monitor} available,'
' skipping.', RuntimeWarning)
else:
if self.check_monitor_top_k(current):
self._do_check_save(filepath, current, epoch)
else:
if self.verbose > 0:
log.info(
f'\nEpoch {epoch:05d}: {self.monitor}'
f' was not in top {self.save_top_k}')
log.info('Epoch %05d: %s was not in top %i', epoch, self.monitor, self.save_top_k)
else:
if self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
log.info('Epoch %05d: saving model to %s', epoch, filepath)
self._save_model(filepath)
def _do_check_save(self, filepath, current, epoch):
def _do_check_save(self, filepath: str, current: float, epoch: int) -> None:
# remove kth
if len(self.best_k_models) == self.save_top_k:
delpath = self.kth_best_model
@ -185,8 +203,6 @@ class ModelCheckpoint(Callback):
self.best = _op(self.best_k_models.values())
if self.verbose > 0:
log.info(
f'\nEpoch {epoch:05d}: {self.monitor} reached'
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
f' {filepath} as top {self.save_top_k}')
log.info('Epoch {epoch:05d}: %s reached %0.5f (best %0.5f), saving model to %s as top %i',
epoch, self.monitor, current, self.best, filepath, self.save_top_k)
self._save_model(filepath)

View File

@ -50,7 +50,7 @@ class TrainerCallbackConfigMixin(ABC):
self.ckpt_path = ckpt_path
self.checkpoint_callback = ModelCheckpoint(
filepath=ckpt_path
dirpath=ckpt_path
)
elif self.checkpoint_callback is False:
self.checkpoint_callback = None
@ -62,7 +62,7 @@ class TrainerCallbackConfigMixin(ABC):
self.checkpoint_callback.save_function = self.save_checkpoint
# if checkpoint callback used, then override the weights path
self.weights_save_path = self.checkpoint_callback.filepath
self.weights_save_path = self.checkpoint_callback.dirpath
# if weights_save_path is still none here, set to current working dir
if self.weights_save_path is None:

View File

@ -810,7 +810,7 @@ class Trainer(TrainerIOMixin,
self.amp_level = amp_level
self.precision = precision
assert self.precision == 32 or self.precision == 16, 'only 32 or 16 bit precision supported'
assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
if self.precision == 16 and num_tpu_cores is None:
use_amp = True

View File

@ -32,7 +32,7 @@ def run_model_test_no_loggers(trainer_options, model, min_acc=0.50):
# test model loading
pretrained_model = load_model(trainer.logger,
trainer.checkpoint_callback.filepath,
trainer.checkpoint_callback.dirpath,
path_expt=trainer_options.get('default_save_path'))
# test new model accuracy
@ -70,7 +70,7 @@ def run_model_test(trainer_options, model, on_gpu=True):
assert result == 1, 'amp + ddp model failed to complete'
# test model loading
pretrained_model = load_model(logger, trainer.checkpoint_callback.filepath)
pretrained_model = load_model(logger, trainer.checkpoint_callback.dirpath)
# test new model accuracy
test_loaders = model.test_dataloader()

View File

@ -1,3 +1,4 @@
import glob
import logging as log
import os
@ -52,7 +53,7 @@ def test_running_test_pretrained_model_ddp(tmpdir):
# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = tutils.load_model(logger,
trainer.checkpoint_callback.filepath,
trainer.checkpoint_callback.dirpath,
module_class=LightningTestModel)
# run test set
@ -96,7 +97,7 @@ def test_running_test_pretrained_model(tmpdir):
# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = tutils.load_model(
logger, trainer.checkpoint_callback.filepath, module_class=LightningTestModel
logger, trainer.checkpoint_callback.dirpath, module_class=LightningTestModel
)
new_trainer = Trainer(**trainer_options)
@ -132,9 +133,7 @@ def test_load_model_from_checkpoint(tmpdir):
assert result == 1, 'training failed to complete'
# load last checkpoint
last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt")
if not os.path.isfile(last_checkpoint):
last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt")
last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
pretrained_model = LightningTestModel.load_from_checkpoint(last_checkpoint)
# test that hparams loaded correctly
@ -186,7 +185,7 @@ def test_running_test_pretrained_model_dp(tmpdir):
# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = tutils.load_model(logger,
trainer.checkpoint_callback.filepath,
trainer.checkpoint_callback.dirpath,
module_class=LightningTestModel)
new_trainer = Trainer(**trainer_options)
@ -346,7 +345,7 @@ def test_load_model_with_missing_hparams(tmpdir):
model = LightningTestModelWithoutHyperparametersArg()
trainer.fit(model)
last_checkpoint = os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt")
last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
# try to load a checkpoint that has hparams but model is missing hparams arg
with pytest.raises(MisconfigurationException, match=r".*__init__ is missing the argument 'hparams'.*"):

View File

@ -1,3 +1,4 @@
import glob
import math
import os
import pytest
@ -257,8 +258,12 @@ def test_model_checkpoint_options(tmp_path):
assert len(file_lists) == len(losses), "Should save all models when save_top_k=-1"
# verify correct naming
for i in range(0, len(losses)):
assert f"_ckpt_epoch_{i}.ckpt" in file_lists
for fname in {'_epoch=4_val_loss=2.50.ckpt',
'_epoch=3_val_loss=5.00.ckpt',
'_epoch=2_val_loss=2.80.ckpt',
'_epoch=1_val_loss=9.00.ckpt',
'_epoch=0_val_loss=10.00.ckpt'}:
assert fname in file_lists
save_dir = tmp_path / "2"
save_dir.mkdir()
@ -297,7 +302,7 @@ def test_model_checkpoint_options(tmp_path):
file_lists = set(os.listdir(save_dir))
assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
assert 'test_prefix_ckpt_epoch_4.ckpt' in file_lists
assert 'test_prefix_epoch=4_val_loss=2.50.ckpt' in file_lists
save_dir = tmp_path / "4"
save_dir.mkdir()
@ -320,9 +325,10 @@ def test_model_checkpoint_options(tmp_path):
file_lists = set(os.listdir(save_dir))
assert len(file_lists) == 3, 'Should save 2 model when save_top_k=2'
assert '_ckpt_epoch_4.ckpt' in file_lists
assert '_ckpt_epoch_2.ckpt' in file_lists
assert 'other_file.ckpt' in file_lists
for fname in {'_epoch=4_val_loss=2.50.ckpt',
'_epoch=2_val_loss=2.80.ckpt',
'other_file.ckpt'}:
assert fname in file_lists
save_dir = tmp_path / "5"
save_dir.mkdir()
@ -365,9 +371,10 @@ def test_model_checkpoint_options(tmp_path):
file_lists = set(os.listdir(save_dir))
assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3'
assert '_ckpt_epoch_0_v2.ckpt' in file_lists
assert '_ckpt_epoch_0_v1.ckpt' in file_lists
assert '_ckpt_epoch_0.ckpt' in file_lists
for fname in {'_epoch=0_val_loss=2.80.ckpt',
'_epoch=0_val_loss=2.50.ckpt',
'_epoch=0_val_loss=5.00.ckpt'}:
assert fname in file_lists
def test_model_freeze_unfreeze():
@ -388,7 +395,7 @@ def test_resume_from_checkpoint_epoch_restored(tmpdir):
hparams = tutils.get_hparams()
def new_model():
def _new_model():
# Create a model that tracks epochs and batches seen
model = LightningTestModel(hparams)
model.num_epochs_seen = 0
@ -406,7 +413,7 @@ def test_resume_from_checkpoint_epoch_restored(tmpdir):
model.on_batch_start = types.MethodType(increment_batch, model)
return model
model = new_model()
model = _new_model()
trainer_options = dict(
show_progress_bar=False,
@ -417,7 +424,7 @@ def test_resume_from_checkpoint_epoch_restored(tmpdir):
logger=False,
default_save_path=tmpdir,
early_stop_callback=False,
val_check_interval=0.5,
val_check_interval=1.,
)
# fit model
@ -430,15 +437,10 @@ def test_resume_from_checkpoint_epoch_restored(tmpdir):
assert model.num_batches_seen == training_batches * 2
# Other checkpoints can be uncommented if/when resuming mid-epoch is supported
checkpoints = [
# os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt"),
os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0_v0.ckpt"),
# os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt"),
os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1_v0.ckpt"),
]
checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt')))
for check in checkpoints:
next_model = new_model()
next_model = _new_model()
state = torch.load(check)
# Resume training