182 lines
7.0 KiB
Python
182 lines
7.0 KiB
Python
|
import os
|
||
|
import shutil
|
||
|
import logging as log
|
||
|
import warnings
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from .base import Callback
|
||
|
|
||
|
|
||
|
class ModelCheckpoint(Callback):
    r"""
    Save the model after every epoch.

    Args:
        filepath: directory in which checkpoint files are written.
            The directory is created on construction if it does not exist.
            Checkpoint files are named
            ``{prefix}_ckpt_epoch_{epoch}.ckpt`` inside this directory.
        monitor (str): quantity to monitor.
        verbose (bool): verbosity mode, False or True.
        save_top_k (int): if `save_top_k == k`,
            the best k models according to
            the quantity monitored will be saved.
            if ``save_top_k == 0``, no models are saved.
            if ``save_top_k == -1``, all models are saved.
            Please note that the monitors are checked every `period` epochs.
            If a checkpoint file for the current epoch already exists
            (e.g. the callback is called multiple times inside an epoch),
            the name of the saved file will be appended with a version
            count starting with `v0`.
        mode (str): one of {auto, min, max}.
            If ``save_top_k != 0``, the decision
            to overwrite the current save file is made
            based on either the maximization or the
            minimization of the monitored quantity. For `val_acc`,
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity:
            `max` if the name contains ``acc`` or starts with ``fmeasure``,
            `min` otherwise. An unknown mode falls back to `auto` with a
            ``RuntimeWarning``.
        save_weights_only (bool): stored on the instance as
            ``save_weights_only``. This class does not read the flag itself;
            presumably the assigned ``save_function`` honours it (confirm
            against the caller).
        period (int): Interval (number of epochs) between checkpoints.
        prefix (str): string prepended to every checkpoint file name.

    Example::

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import ModelCheckpoint

        # saves checkpoints to my_path whenever 'val_loss' has a new min
        checkpoint_callback = ModelCheckpoint(filepath='my_path')
        Trainer(checkpoint_callback=checkpoint_callback)
    """

    def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
                 save_top_k: int = 1, save_weights_only: bool = False,
                 mode: str = 'auto', period: int = 1, prefix: str = ''):
        super().__init__()
        if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
            warnings.warn(
                f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                " All files in this directory will be deleted when a checkpoint is saved!"
            )

        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        os.makedirs(filepath, exist_ok=True)
        self.save_top_k = save_top_k
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_check = 0
        self.prefix = prefix
        # {filename: monitor}
        self.best_k_models = {}
        # path of the worst checkpoint currently kept (the next one to evict)
        self.kth_best_model = ''
        self.best = 0
        # must be assigned from outside before the first checkpoint is written
        self.save_function = None

        # ``np.inf`` (lowercase) is used because the ``np.Inf`` alias was
        # removed in NumPy 2.0.
        if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
            auto_mode = (np.greater, -np.inf, 'max')
        else:
            auto_mode = (np.less, np.inf, 'min')

        mode_dict = {
            'min': (np.less, np.inf, 'min'),
            'max': (np.greater, -np.inf, 'max'),
            'auto': auto_mode,
        }

        if mode not in mode_dict:
            warnings.warn(
                f'ModelCheckpoint mode {mode} is unknown, '
                'fallback to auto mode.', RuntimeWarning)
            mode = 'auto'

        self.monitor_op, self.kth_value, self.mode = mode_dict[mode]

    def _del_model(self, filepath):
        """Remove a stored checkpoint, whether it is a directory or a file.

        Args:
            filepath: path of the checkpoint to delete.

        Raises:
            OSError: if ``filepath`` can be removed neither as a directory
                tree nor as a single file (e.g. it does not exist).
        """
        try:
            shutil.rmtree(filepath)
        except OSError:
            # not a directory tree: fall back to deleting a single file
            os.remove(filepath)

    def _save_model(self, filepath):
        """Write a checkpoint to ``filepath`` through ``self.save_function``.

        Args:
            filepath: destination path; its parent directory is created
                if it does not exist.

        Raises:
            ValueError: if ``self.save_function`` has not been assigned.
        """
        # make paths
        os.makedirs(os.path.dirname(filepath), exist_ok=True)

        # delegate the saving to the model
        if self.save_function is None:
            raise ValueError(".save_function() not set")
        self.save_function(filepath)

    def check_monitor_top_k(self, current):
        """Tell whether ``current`` qualifies for the set of kept checkpoints.

        Args:
            current: latest value of the monitored quantity.

        Returns:
            True while fewer than ``save_top_k`` checkpoints are tracked;
            otherwise the result of ``self.monitor_op`` comparing ``current``
            with the value of the worst tracked checkpoint.
        """
        less_than_k_models = len(self.best_k_models) < self.save_top_k
        if less_than_k_models:
            return True
        return self.monitor_op(current, self.best_k_models[self.kth_best_model])

    def on_validation_end(self):
        """Decide whether to write a checkpoint now, and write it if so.

        Reads ``callback_metrics`` and ``current_epoch`` from ``self.trainer``
        (the attribute is not assigned in this class; presumably it is
        attached by the trainer or the ``Callback`` base - confirm).
        Nothing is saved when ``save_top_k == 0`` or when fewer than
        ``period`` calls have happened since the last check.
        """
        logs = self.trainer.callback_metrics
        epoch = self.trainer.current_epoch
        self.epochs_since_last_check += 1

        if self.save_top_k == 0:
            # no models are saved
            return
        if self.epochs_since_last_check < self.period:
            # not yet time for the next check
            return

        self.epochs_since_last_check = 0
        filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
        version_cnt = 0
        while os.path.isfile(filepath):
            # this epoch called before: pick the first unused version suffix
            filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
            version_cnt += 1

        if self.save_top_k == -1:
            # every checkpoint is kept, no ranking needed
            if self.verbose > 0:
                log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
            self._save_model(filepath)
            return

        current = logs.get(self.monitor)
        if current is None:
            warnings.warn(
                f'Can save best model only with {self.monitor} available,'
                ' skipping.', RuntimeWarning)
        elif self.check_monitor_top_k(current):
            self._do_check_save(filepath, current, epoch)
        elif self.verbose > 0:
            log.info(
                f'\nEpoch {epoch:05d}: {self.monitor}'
                f' was not in top {self.save_top_k}')

    def _do_check_save(self, filepath, current, epoch):
        """Record a qualifying checkpoint, evict the worst one, and save.

        Args:
            filepath: destination path of the new checkpoint.
            current: monitored value associated with the new checkpoint.
            epoch: epoch number, used only in the log message.
        """
        # remove kth
        if len(self.best_k_models) == self.save_top_k:
            delpath = self.kth_best_model
            self.best_k_models.pop(self.kth_best_model)
            self._del_model(delpath)

        self.best_k_models[filepath] = current
        if len(self.best_k_models) == self.save_top_k:
            # monitor dict has reached k elements: the "worst" entry is the
            # largest value in 'min' mode and the smallest one otherwise
            _op = max if self.mode == 'min' else min
            self.kth_best_model = _op(self.best_k_models,
                                      key=self.best_k_models.get)
            self.kth_value = self.best_k_models[self.kth_best_model]

        _op = min if self.mode == 'min' else max
        self.best = _op(self.best_k_models.values())

        if self.verbose > 0:
            log.info(
                f'\nEpoch {epoch:05d}: {self.monitor} reached'
                f' {current:0.5f} (best {self.best:0.5f}), saving model to'
                f' {filepath} as top {self.save_top_k}')
        self._save_model(filepath)
|