""" Callbacks ========= Callbacks supported by Lightning """ import logging import os import shutil import warnings import numpy as np from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel class Callback(object): r"""Abstract base class used to build new callbacks. """ def __init__(self): self.validation_data = None self.model = None def set_params(self, params): self.params = params def set_model(self, model): if type(model) is LightningDistributedDataParallel: model = model.module self.model = model def on_epoch_begin(self, epoch, logs=None): """ called when the epoch begins Args: epoch (int): current epoch logs (dict): key-value pairs of quantities to monitor Example: on_epoch_begin(epoch=2, logs={'val_loss': 0.2}) """ pass def on_epoch_end(self, epoch, logs=None): pass def on_batch_begin(self, batch, logs=None): """ called when the batch starts. Args: batch (Tensor): current batch tensor logs (dict): key-value pairs of quantities to monitor """ pass def on_batch_end(self, batch, logs=None): pass def on_train_begin(self, logs=None): pass def on_train_end(self, logs=None): pass class EarlyStopping(Callback): r""" Stop training when a monitored quantity has stopped improving. Args: monitor (str): quantity to be monitored. min_delta (float): minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta, will count as no improvement. patience (int): number of epochs with no improvement after which training will be stopped. verbose (bool): verbosity mode. mode (str): one of {auto, min, max}. In `min` mode, training will stop when the quantity monitored has stopped decreasing; in `max` mode it will stop when the quantity monitored has stopped increasing; in `auto` mode, the direction is automatically inferred from the name of the monitored quantity. Example:: from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping early_stopping = EarlyStopping('val_loss') Trainer(early_stop_callback=early_stopping) """ def __init__(self, monitor='val_loss', min_delta=0.0, patience=0, verbose=0, mode='auto'): super(EarlyStopping, self).__init__() self.monitor = monitor self.patience = patience self.verbose = verbose self.min_delta = min_delta self.wait = 0 self.stopped_epoch = 0 if mode not in ['auto', 'min', 'max']: logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.') mode = 'auto' if mode == 'min': self.monitor_op = np.less elif mode == 'max': self.monitor_op = np.greater else: if 'acc' in self.monitor: self.monitor_op = np.greater else: self.monitor_op = np.less if self.monitor_op == np.greater: self.min_delta *= 1 else: self.min_delta *= -1 self.on_train_begin() def on_train_begin(self, logs=None): # Allow instances to be re-used self.wait = 0 self.stopped_epoch = 0 self.best = np.Inf if self.monitor_op == np.less else -np.Inf def on_epoch_end(self, epoch, logs=None): current = logs.get(self.monitor) stop_training = False if current is None: warnings.warn( f'Early stopping conditioned on metric `{self.monitor}`' f' which is not available. 


class EarlyStopping(Callback):
    r"""Stop training when a monitored quantity has stopped improving.

    Args:
        monitor (str): quantity to be monitored.
        min_delta (float): minimum change in the monitored quantity to qualify as an
            improvement, i.e. an absolute change of less than `min_delta` counts as
            no improvement.
        patience (int): number of epochs with no improvement after which training
            will be stopped.
        verbose (bool): verbosity mode.
        mode (str): one of {auto, min, max}. In `min` mode, training will stop when
            the quantity monitored has stopped decreasing; in `max` mode it will stop
            when the quantity monitored has stopped increasing; in `auto` mode, the
            direction is automatically inferred from the name of the monitored quantity.

    Example::

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import EarlyStopping

        early_stopping = EarlyStopping('val_loss')
        Trainer(early_stop_callback=early_stopping)
    """

    def __init__(self, monitor='val_loss',
                 min_delta=0.0, patience=0, verbose=0, mode='auto'):
        super(EarlyStopping, self).__init__()

        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_epoch = 0

        if mode not in ['auto', 'min', 'max']:
            logging.info(f'EarlyStopping mode {mode} is unknown, falling back to auto mode.')
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            # auto mode: infer the direction from the metric name
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        # keep min_delta positive for increasing metrics and flip its sign for
        # decreasing ones, so `current - min_delta` must beat `best` by at least min_delta
        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

        self.on_train_begin()

    def on_train_begin(self, logs=None):
        # allow instances to be re-used
        self.wait = 0
        self.stopped_epoch = 0
        self.best = np.inf if self.monitor_op == np.less else -np.inf

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        current = logs.get(self.monitor)
        stop_training = False
        if current is None:
            warnings.warn(
                f'Early stopping conditioned on metric `{self.monitor}`'
                f' which is not available.'
                f' Available metrics are: {",".join(list(logs.keys()))}',
                RuntimeWarning)
            stop_training = True
            return stop_training

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                stop_training = True
                self.on_train_end()

        return stop_training

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0 and self.verbose > 0:
            logging.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
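

# A minimal sketch, not called anywhere: in `auto` mode EarlyStopping infers the
# direction from the metric name, so monitoring 'val_acc' uses np.greater and
# training stops once accuracy no longer improves by at least `min_delta`.
# The function name and the metric values below are made up for illustration.
def _example_early_stopping_on_accuracy():
    early_stop = EarlyStopping(monitor='val_acc', min_delta=0.01, patience=1, verbose=True)
    for epoch, acc in enumerate([0.60, 0.72, 0.78, 0.781, 0.779]):
        # on_epoch_end returns True once patience is exhausted
        if early_stop.on_epoch_end(epoch, logs={'val_acc': acc}):
            break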


class ModelCheckpoint(Callback):
    r"""Save the model after every epoch.

    Args:
        filepath (str): directory in which to save the checkpoint files. The directory
            is created if needed, and checkpoints are written to it as
            `{prefix}_ckpt_epoch_{epoch}.ckpt`.
        monitor (str): quantity to monitor.
        verbose (bool): verbosity mode, 0 or 1.
        save_top_k (int): if `save_top_k == k`, the best k models according to the
            quantity monitored will be saved. If `save_top_k == 0`, no models are saved.
            If `save_top_k == -1`, all models are saved. Please note that the monitors
            are checked every `period` epochs. If `save_top_k >= 2` and the callback is
            called multiple times inside an epoch, the name of the saved file will be
            appended with a version count starting with `v0`.
        mode (str): one of {auto, min, max}. If `save_top_k != 0`, the decision to
            overwrite the current save file is made based on either the maximization or
            the minimization of the monitored quantity. For `val_acc`, this should be
            `max`, for `val_loss` this should be `min`, etc. In `auto` mode, the
            direction is automatically inferred from the name of the monitored quantity.
        save_weights_only (bool): if True, only the model's weights are saved,
            otherwise the full model is saved.
        period (int): interval (number of epochs) between checkpoints.
        prefix (str): string prepended to the checkpoint file names.

    Example::

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import ModelCheckpoint

        checkpoint_callback = ModelCheckpoint(filepath='my_path')
        Trainer(checkpoint_callback=checkpoint_callback)

        # saves checkpoints to my_path whenever 'val_loss' has a new min
    """

    def __init__(self, filepath, monitor='val_loss', verbose=0,
                 save_top_k=1, save_weights_only=False,
                 mode='auto', period=1, prefix=''):
        super(ModelCheckpoint, self).__init__()
        if (
            save_top_k
            and os.path.isdir(filepath)
            and len(os.listdir(filepath)) > 0
        ):
            warnings.warn(
                f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                " All files in this directory will be deleted when a checkpoint is saved!"
            )

        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        os.makedirs(filepath, exist_ok=True)
        self.save_top_k = save_top_k
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_check = 0
        self.prefix = prefix
        self.best_k_models = {}  # {filename: monitor value}
        self.kth_best_model = ''
        self.best = 0
        # the actual writing of the checkpoint is delegated to this function,
        # which the Trainer assigns before any checkpoint is saved
        self.save_function = None

        if mode not in ['auto', 'min', 'max']:
            warnings.warn(
                f'ModelCheckpoint mode {mode} is unknown, '
                'falling back to auto mode.', RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.kth_value = np.inf
            self.mode = 'min'
        elif mode == 'max':
            self.monitor_op = np.greater
            self.kth_value = -np.inf
            self.mode = 'max'
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.kth_value = -np.inf
                self.mode = 'max'
            else:
                self.monitor_op = np.less
                self.kth_value = np.inf
                self.mode = 'min'

    def _del_model(self, filepath):
        dirpath = os.path.dirname(filepath)

        # make paths
        os.makedirs(dirpath, exist_ok=True)

        # checkpoints may be directories or single files
        try:
            shutil.rmtree(filepath)
        except OSError:
            os.remove(filepath)

    def _save_model(self, filepath):
        dirpath = os.path.dirname(filepath)

        # make paths
        os.makedirs(dirpath, exist_ok=True)

        # delegate the saving to the model
        self.save_function(filepath)

    def check_monitor_top_k(self, current):
        less_than_k_models = len(self.best_k_models.keys()) < self.save_top_k
        if less_than_k_models:
            return True
        return self.monitor_op(current, self.best_k_models[self.kth_best_model])

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_check += 1

        if self.save_top_k == 0:
            # no models are saved
            return
        if self.epochs_since_last_check >= self.period:
            self.epochs_since_last_check = 0
            filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
            version_cnt = 0
            while os.path.isfile(filepath):
                # this epoch was checkpointed before, append a version count
                filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
                version_cnt += 1

            if self.save_top_k != -1:
                current = logs.get(self.monitor)

                if current is None:
                    warnings.warn(
                        f'Can save best model only with {self.monitor} available,'
                        ' skipping.', RuntimeWarning)
                else:
                    if self.check_monitor_top_k(current):

                        # remove the current kth best model
                        if len(self.best_k_models.keys()) == self.save_top_k:
                            delpath = self.kth_best_model
                            self.best_k_models.pop(self.kth_best_model)
                            self._del_model(delpath)

                        self.best_k_models[filepath] = current
                        if len(self.best_k_models.keys()) == self.save_top_k:
                            # monitor dict has reached k elements
                            if self.mode == 'min':
                                self.kth_best_model = max(self.best_k_models, key=self.best_k_models.get)
                            else:
                                self.kth_best_model = min(self.best_k_models, key=self.best_k_models.get)
                            self.kth_value = self.best_k_models[self.kth_best_model]

                        if self.mode == 'min':
                            self.best = min(self.best_k_models.values())
                        else:
                            self.best = max(self.best_k_models.values())
                        if self.verbose > 0:
                            logging.info(
                                f'\nEpoch {epoch:05d}: {self.monitor} reached'
                                f' {current:0.5f} (best {self.best:0.5f}), saving model to'
                                f' {filepath} as top {self.save_top_k}')
                        self._save_model(filepath)

                    else:
                        if self.verbose > 0:
                            logging.info(
                                f'\nEpoch {epoch:05d}: {self.monitor}'
                                f' was not in top {self.save_top_k}')

            else:
                if self.verbose > 0:
                    logging.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
                self._save_model(filepath)
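

# A minimal sketch of driving ModelCheckpoint outside the Trainer. The Trainer
# normally assigns ``save_function``; here a plain ``torch.save`` wrapper stands
# in for it. The function name, directory, and metric values are hypothetical.
def _example_checkpoint_usage(model, tmp_dir='/tmp/ckpts'):
    import torch

    checkpoint = ModelCheckpoint(filepath=tmp_dir, monitor='val_loss', save_top_k=2)
    checkpoint.save_function = lambda path: torch.save(model.state_dict(), path)

    for epoch, val_loss in enumerate([0.9, 0.7, 0.8, 0.5]):
        # only the checkpoints of the two best epochs (by val_loss) are kept on disk
        checkpoint.on_epoch_end(epoch, logs={'val_loss': val_loss})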


class GradientAccumulationScheduler(Callback):
    r"""Change gradient accumulation factor according to a schedule.

    Args:
        scheduling (dict): scheduling in format {epoch: accumulation_factor}

    Example::

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import GradientAccumulationScheduler

        # at epoch 5, start accumulating every 2 batches
        accumulator = GradientAccumulationScheduler(scheduling={5: 2})
        Trainer(accumulate_grad_batches=accumulator)
    """

    def __init__(self, scheduling: dict):
        super(GradientAccumulationScheduler, self).__init__()

        if scheduling == {}:  # empty dict error
            raise TypeError("Empty dict cannot be interpreted correctly")

        for key in scheduling:
            if not isinstance(key, int) or not isinstance(scheduling[key], int):
                raise TypeError("All epochs and accumulation factors must be integers")

        minimal_epoch = min(scheduling.keys())
        if minimal_epoch < 1:
            msg = f"Epochs are indexed from 1, epoch {minimal_epoch} cannot be interpreted correctly"
            raise IndexError(msg)
        elif minimal_epoch != 1:  # if the user didn't define an accumulation factor for the first epoch
            scheduling.update({1: 1})

        self.scheduling = scheduling
        self.epochs = sorted(scheduling.keys())

    def on_epoch_begin(self, epoch, trainer):
        epoch += 1  # indexing epochs from 1
        # walk the schedule backwards and apply the factor of the latest epoch reached
        for i in reversed(range(len(self.epochs))):
            if epoch >= self.epochs[i]:
                trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
                break


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    # quick smoke test for EarlyStopping on a fabricated loss curve
    early_stop = EarlyStopping(min_delta=0.9, patience=2, verbose=True)
    losses = [10, 9, 8, 8, 6, 4.3, 5, 4.4, 2.8, 2.5]
    for i, loss in enumerate(losses):
        should_stop = early_stop.on_epoch_end(i, logs={'val_loss': loss})
        logging.info(loss)
        if should_stop:
            break
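
    # A second smoke test, added as an illustrative sketch: drive
    # GradientAccumulationScheduler with a stand-in "trainer" object (a plain
    # SimpleNamespace, not a real Trainer) to show how the schedule back-fills
    # {1: 1} and updates trainer.accumulate_grad_batches as epochs advance.
    from types import SimpleNamespace

    accumulator = GradientAccumulationScheduler(scheduling={5: 2, 10: 4})
    fake_trainer = SimpleNamespace(accumulate_grad_batches=1)
    for epoch in range(12):
        accumulator.on_epoch_begin(epoch, fake_trainer)
        logging.info(f'epoch {epoch + 1}: accumulate_grad_batches='
                     f'{fake_trainer.accumulate_grad_batches}')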