From 89d5772f5549d383cbc4cf4ee602fd5e30db3def Mon Sep 17 00:00:00 2001 From: Hadrien Mary Date: Sat, 22 Feb 2020 21:45:34 -0500 Subject: [PATCH] Split callbacks (#849) * add .vscode in .gitignore * Split callbacks in individual files + add a property to Callback for easy trainer instance access * formatting * Add a conda env file for quick and easy env setup to develop on PL * Address comments * add fix to kth_best_model * add some typing to callbacks * fix typo * add autopep8 config to pyproject.toml * format again * format * fix toml * fix toml again * consistent max line length in all config files * remove conda env file * Update pytorch_lightning/callbacks/early_stopping.py Co-Authored-By: Jirka Borovec * Update pytorch_lightning/callbacks/model_checkpoint.py Co-Authored-By: Jirka Borovec * docstring * Update pytorch_lightning/callbacks/model_checkpoint.py Co-Authored-By: Jirka Borovec * Update pytorch_lightning/callbacks/model_checkpoint.py Co-Authored-By: Jirka Borovec * fix logic error * format * simplify if/else * format * fix linting issue in changelog * edit changelog about new callback mechanism * fix remaining formatting issue on CHANGELOG * remove lambda function because it's not compatible with pickle (used during ddp) Co-authored-by: Jirka Borovec --- .gitignore | 2 +- .markdownlint.yml | 2 + .pep8speaks.yml | 2 +- CHANGELOG.md | 186 +++++++- pyproject.toml | 4 + pytorch_lightning/callbacks/__init__.py | 7 +- pytorch_lightning/callbacks/base.py | 68 +++ pytorch_lightning/callbacks/early_stopping.py | 114 +++++ .../gradient_accumulation_scheduler.py | 55 +++ .../callbacks/model_checkpoint.py | 181 ++++++++ pytorch_lightning/callbacks/pt_callbacks.py | 431 ------------------ 11 files changed, 616 insertions(+), 436 deletions(-) create mode 100644 .markdownlint.yml create mode 100644 pytorch_lightning/callbacks/base.py create mode 100644 pytorch_lightning/callbacks/early_stopping.py create mode 100644 pytorch_lightning/callbacks/gradient_accumulation_scheduler.py create mode 100644 pytorch_lightning/callbacks/model_checkpoint.py delete mode 100644 pytorch_lightning/callbacks/pt_callbacks.py diff --git a/.gitignore b/.gitignore index e0a972039f..8a81bf71ea 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,7 @@ model_weights/ app/models/ pip-wheel-metadata/ lightning_logs/ - +.vscode/ # Test-tube test_tube_logs/ diff --git a/.markdownlint.yml b/.markdownlint.yml new file mode 100644 index 0000000000..bc310daa64 --- /dev/null +++ b/.markdownlint.yml @@ -0,0 +1,2 @@ +MD013: false # line length +MD024: false # headers with the same names diff --git a/.pep8speaks.yml b/.pep8speaks.yml index e33d46b159..154d5258ba 100644 --- a/.pep8speaks.yml +++ b/.pep8speaks.yml @@ -5,7 +5,7 @@ scanner: linter: pycodestyle # Other option is flake8 pycodestyle: # Same as scanner.linter value. Other option is flake8 - max-line-length: 100 # Default is 79 in PEP 8 + max-line-length: 120 # Default is 79 in PEP 8 ignore: # Errors and warnings to ignore - W504 # line break after binary operator - E402 # module level import not at top of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 4568c88ce7..b96da2d8d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,13 @@ # Changelog + All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [Unreleased] + ### Added + - Updated governance docs - Added a check to ensure that the metric used for early stopping exists before training commences ([#542](https://github.com/PyTorchLightning/pytorch-lightning/pull/542)) - Added `optimizer_idx` argument to `backward` hook ([#733](https://github.com/PyTorchLightning/pytorch-lightning/pull/733)) @@ -15,24 +18,35 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `train_dataloader`, `val_dataloader` and `test_dataloader` arguments to `Trainer.fit()`, for alternative data parsing ([#759](https://github.com/PyTorchLightning/pytorch-lightning/pull/759)) - Added Tensor Processing Unit (TPU) support ([#868](https://github.com/PyTorchLightning/pytorch-lightning/pull/868)) - Added semantic segmentation example ([#751](https://github.com/PyTorchLightning/pytorch-lightning/pull/751),[#876](https://github.com/PyTorchLightning/pytorch-lightning/pull/876)) +- Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849)) + ### Changed + - Changed default TQDM to use `tqdm.auto` for prettier outputs in IPython notebooks ([#752](https://github.com/PyTorchLightning/pytorch-lightning/pull/752)) - Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767)) - Moved the default `tqdm_dict` definition from Trainer to `LightningModule`, so it can be overridden by the user ([#749](https://github.com/PyTorchLightning/pytorch-lightning/pull/749)) + ### Deprecated + - None + ### Removed + - Removed dependency on pandas ([#736](https://github.com/PyTorchLightning/pytorch-lightning/pull/736)) - Removed dependency on torchvision ([#797](https://github.com/PyTorchLightning/pytorch-lightning/pull/797)) - Removed dependency on scikit-learn ([#801](https://github.com/PyTorchLightning/pytorch-lightning/pull/801)) + ### Fixed + - Fixed a bug where early stopping `on_end_epoch` would be called inconsistently when `check_val_every_n_epoch == 0` ([#743](https://github.com/PyTorchLightning/pytorch-lightning/pull/743)) - Fixed a bug where the model checkpointer didn't write to the same directory as the logger ([#771](https://github.com/PyTorchLightning/pytorch-lightning/pull/771)) - Fixed a bug where the `TensorBoardLogger` class would create an additional empty log file during fitting ([#777](https://github.com/PyTorchLightning/pytorch-lightning/pull/777)) - Fixed a bug where `global_step` was advanced incorrectly when using `accumulate_grad_batches > 1` ([#832](https://github.com/PyTorchLightning/pytorch-lightning/pull/832)) ## [0.6.0] - 2020-01-21 + ### Added + - Added support for resuming from a specific checkpoint via `resume_from_checkpoint` argument ([#516](https://github.com/PyTorchLightning/pytorch-lightning/pull/516)) - Added support for `ReduceLROnPlateau` scheduler ([#320](https://github.com/PyTorchLightning/pytorch-lightning/pull/320)) - Added support for Apex mode `O2` in conjunction with Data Parallel ([#493](https://github.com/PyTorchLightning/pytorch-lightning/pull/493)) @@ -44,19 +58,27 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added option to disable validation by setting `val_percent_check=0` ([#649](https://github.com/PyTorchLightning/pytorch-lightning/pull/649)) - Added `NeptuneLogger` class ([#648](https://github.com/PyTorchLightning/pytorch-lightning/pull/648)) - Added `WandbLogger` class ([#627](https://github.com/PyTorchLightning/pytorch-lightning/pull/627)) + ### Changed + - Changed the default progress bar to print to stdout instead of stderr ([#531](https://github.com/PyTorchLightning/pytorch-lightning/pull/531)) - Renamed `step_idx` to `step`, `epoch_idx` to `epoch`, `max_num_epochs` to `max_epochs` and `min_num_epochs` to `min_epochs` ([#589](https://github.com/PyTorchLightning/pytorch-lightning/pull/589)) - Renamed `total_batch_nb` to `total_batches`, `nb_val_batches` to `num_val_batches`, `nb_training_batches` to `num_training_batches`, `max_nb_epochs` to `max_epochs`, `min_nb_epochs` to `min_epochs`, `nb_test_batches` to `num_test_batches`, and `nb_val_batches` to `num_val_batches` ([#567](https://github.com/PyTorchLightning/pytorch-lightning/pull/567)) - Changed gradient logging to use parameter names instead of indexes ([#660](https://github.com/PyTorchLightning/pytorch-lightning/pull/660)) - Changed the default logger to `TensorBoardLogger` ([#609](https://github.com/PyTorchLightning/pytorch-lightning/pull/609)) - Changed the directory for tensorboard logging to be the same as model checkpointing ([#706](https://github.com/PyTorchLightning/pytorch-lightning/pull/706)) + ### Deprecated + - Deprecated `max_nb_epochs` and `min_nb_epochs` ([#567](https://github.com/PyTorchLightning/pytorch-lightning/pull/567)) - Deprecated the `on_sanity_check_start` hook in `ModelHooks` ([#598](https://github.com/PyTorchLightning/pytorch-lightning/pull/598)) + ### Removed + - Removed the `save_best_only` argument from `ModelCheckpoint`, use `save_top_k=1` instead ([#128](https://github.com/PyTorchLightning/pytorch-lightning/pull/128)) + ### Fixed + - Fixed a bug which ocurred when using Adagrad with cuda ([#554](https://github.com/PyTorchLightning/pytorch-lightning/pull/554)) - Fixed a bug where training would be on the GPU despite setting `gpus=0` or `gpus=[]` ([#561](https://github.com/PyTorchLightning/pytorch-lightning/pull/561)) - Fixed an error with `print_nan_gradients` when some parameters do not require gradient ([#579](https://github.com/PyTorchLightning/pytorch-lightning/pull/579)) @@ -79,18 +101,22 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed a bug where `on_train_end` was not called when early stopping ([#723](https://github.com/PyTorchLightning/pytorch-lightning/pull/723)) ## [0.5.3] - 2019-11-06 + ### Added + - Added option to disable default logger, checkpointer, and early stopping by passing `logger=False`, `checkpoint_callback=False` and `early_stop_callback=False` respectively - Added `CometLogger` for use with Comet.ml - Added `val_check_interval` argument to `Trainer` allowing validation to be performed at every given number of batches - Added functionality to save and load hyperparameters using the standard checkpoint mechanism - Added call to `torch.cuda.empty_cache` before training starts - Added option for user to override the call to `backward` -- Added support for truncated backprop through time via the `truncated_bptt_steps` argument in `Trainer` +- Added support for truncated backprop through time via the `truncated_bptt_steps` argument in `Trainer` - Added option to operate on all outputs from `training_step` in DDP2 - Added a hook for modifying DDP init - Added a hook for modifying Apex + ### Changed + - Changed experiment version to be padded with zeros (e.g. `/dir/version_9` becomes `/dir/version_0009`) - Changed callback metrics to include any metrics given in logs or progress bar - Changed the default for `save_best_only` in `ModelCheckpoint` to `True` @@ -100,11 +126,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed weights restore to first attempt HPC weights before restoring normally, preventing both weights being restored and running out of memory - Changed progress bar functionality to add multiple progress bars for train/val/test - Changed calls to `print` to use `logging` instead + ### Deprecated + - Deprecated `tng_dataloader` + ### Removed + - None + ### Fixed + - Fixed an issue where the number of batches was off by one during training - Fixed a bug that occurred when setting a checkpoint callback and `early_stop_callback=False` - Fixed an error when importing CometLogger @@ -115,23 +147,35 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where checkpointing would sometimes erase the current directory ## [0.5.2] - 2019-10-10 + ### Added + - Added `weights_summary` argument to `Trainer` to be set to `full` (full summary), `top` (just top level modules) or other - Added `tags` argument to `MLFlowLogger` + ### Changed + - Changed default for `amp_level` to `O1` + ### Deprecated + - None + ### Removed + - Removed the `print_weights_summary` argument from `Trainer` + ### Fixed + - Fixed a bug where logs were not written properly - Fixed a bug where `logger.finalize` wasn't called after training is complete - Fixed callback metric errors in DDP - Fixed a bug where `TestTubeLogger` didn't log to the correct directory ## [0.5.1] - 2019-10-05 + ### Added + - Added the `LightningLoggerBase` class for experiment loggers - Added `MLFlowLogger` for logging with `mlflow` - Added `TestTubeLogger` for logging with `test_tube` - Added support for optimisers which require a closure (e.g.
LBFGS) - Added automatic `MASTER_PORT` default for DDP when not set manually - Added new GPU memory logging options `'min_max'` (log only the min/max utilization) and `'all'` (log all the GPU memory) + ### Changed + - Changed schedulers to always be called with the current epoch - Changed `test_tube` to an optional dependency - Changed data loaders to internally use a getter instead of a python property - Disabled auto GPU loading when restoring weights to prevent out of memory errors - Changed logging, early stopping and checkpointing to occur by default + ### Deprecated + - None + ### Removed + - None + ### Fixed + - Fixed a bug with samplers that do not specify `set_epoch` - Fixed a bug when using the `MLFlowLogger` with unsupported data types, this will now raise a warning - Fixed a bug where gradient norms were always zero using `track_grad_norm` - Fixed a bug which causes a crash when logging memory ## [0.5.0] - 2019-09-26 + ### Added + - None + ### Changed + - Changed `data_batch` argument to `batch` throughout - Changed `batch_i` argument to `batch_idx` throughout - Changed `tng_dataloader` method to `train_dataloader` - Changed `on_tng_metrics` method to `on_training_metrics` - Changed `gradient_clip` argument to `gradient_clip_val` - Changed `add_log_row_interval` to `row_log_interval` + ### Deprecated + - None + ### Removed + - None + ### Fixed + - Fixed a bug with tensorboard logging in multi-gpu setup ## [0.4.9] - 2019-09-16 + ### Added + - Added the flag `log_gpu_memory` to `Trainer` to deactivate logging of GPU memory utilization - Added SLURM resubmit functionality (port from test-tube) - Added optional weight_save_path to trainer to remove the need for a checkpoint_callback when using cluster training - Added option to use single gpu per node with `DistributedDataParallel` + ### Changed + - Changed functionality of `validation_end` and `test_end` with multiple dataloaders to be given all of the dataloaders at once rather than in separate calls - Changed print_nan_grads to only print the parameter value and gradients when they contain NaN - Changed gpu API to take integers as well (e.g.
`gpus=2` instead of `gpus=[0, 1]`) - All models now loaded on to CPU to avoid device and out of memory issues in PyTorch + ### Deprecated + - None + ### Removed + - None + ### Fixed + - Fixed a bug where data types that implement `.to` but not `.cuda` would not be properly moved onto the GPU - Fixed a bug where data would not be re-shuffled every epoch when using a `DistributedSampler` ## [0.4.8] - 2019-08-31 + ### Added + - Added `test_step` and `test_end` methods, used when `Trainer.test` is called - Added `GradientAccumulationScheduler` callback which can be used to schedule changes to the number of accumulation batches - Added option to skip the validation sanity check by setting `nb_sanity_val_steps = 0` + ### Changed + - None + ### Deprecated + - None + ### Removed + - None + ### Fixed + - Fixed a bug when setting `nb_sanity_val_steps = 0` ## [0.4.7] - 2019-08-24 + ### Added + - None + ### Changed + - Changed the default `val_check_interval` to `1.0` - Changed defaults for `nb_val_batches`, `nb_tng_batches` and `nb_test_batches` to 0 + ### Deprecated + - None + ### Removed + - None + ### Fixed + - Fixed a bug where the full validation set was used despite setting `val_percent_check` - Fixed a bug where an `Exception` was thrown when using a data set containing a single batch - Fixed a bug where an `Exception` was thrown if no `val_dataloader` was given @@ -226,114 +318,204 @@ memory utilization - Fixed a bug where `AttributeError` could be suppressed by the `Trainer` ## [0.4.6] - 2019-08-15 + ### Added + - Added support for data to be given as a `dict` or `list` with a single gpu - Added support for `configure_optimizers` to return a single optimizer, two lists (optimizers and schedulers), or a single list + ### Changed + - None + ### Deprecated + - None + ### Removed + - None + ### Fixed + - Fixed a bug where returning just an optimizer list (i.e.
without schedulers) from `configure_optimizers` would throw an `Exception` ## [0.4.5] - 2019-08-13 + ### Added + - Added `optimizer_step` method that can be overridden to change the standard optimizer behaviour + ### Changed + - None + ### Deprecated + - None + ### Removed + - None + ### Fixed + - None ## [0.4.4] - 2019-08-12 + ### Added + - Added support for multiple validation dataloaders - Added support for latest test-tube logger (optimised for `torch==1.2.0`) + ### Changed + - `validation_step` and `val_dataloader` are now optional - `lr_scheduler` is now activated after epoch + ### Deprecated + - None + ### Removed + - None + ### Fixed + - Fixed a bug where a warning would show when using `lr_scheduler` in `torch>1.1.0` -- Fixed a bug where an `Exception` would be thrown if using `torch.DistributedDataParallel` without using a `DistributedSampler`, this now throws a `Warning` instead +- Fixed a bug where an `Exception` would be thrown if using `torch.DistributedDataParallel` without using a `DistributedSampler`, this now throws a `Warning` instead ## [0.4.3] - 2019-08-10 + ### Added + - None + ### Changed + - None + ### Deprecated + - None + ### Removed + - None + ### Fixed + - Fixed a bug where accumulate gradients would scale the loss incorrectly ## [0.4.2] - 2019-08-08 + ### Added + - None + ### Changed + - Changed install requirement to `torch==1.2.0` + ### Deprecated + - None + ### Removed + - None + ### Fixed + - None ## [0.4.1] - 2019-08-08 + ### Added + - None + ### Changed + - Changed install requirement to `torch==1.1.0` + ### Deprecated + - None + ### Removed + - None + ### Fixed + - None ## [0.4.0] - 2019-08-08 + ### Added + - Added 16-bit support for a single GPU - Added support for training continuation (preserves epoch, global step etc.)
+ ### Changed + - Changed `training_step` and `validation_step`, outputs will no longer be automatically reduced + ### Deprecated + - None + ### Removed + - Removed need for `Experiment` object in `Trainer` + ### Fixed + - Fixed issues with reducing outputs from generative models (such as images and text) ## [0.3.6.1] - 2019-07-27 + ### Added + - None + ### Changed + - None + ### Deprecated + - None + ### Removed + - None + ### Fixed + - Fixed a bug where `Experiment` object was not process safe, potentially causing logs to be overwritten ## [0.3.6] - 2019-07-25 + ### Added + - Added a decorator to do lazy data loading internally + ### Changed + - None + ### Deprecated + - None + ### Removed + - None + ### Fixed + - None diff --git a/pyproject.toml b/pyproject.toml index ee7778129d..4c3ee6e11f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,3 +3,7 @@ requires = [ "setuptools", "wheel", ] + +[tool.autopep8] +max_line_length = 120 +ignore = ["W504", "W504", "E402", "E731", "C40", "E741", "F40", "F841"] diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index 9538036563..5618797275 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -1,6 +1,11 @@ -from .pt_callbacks import EarlyStopping, ModelCheckpoint, GradientAccumulationScheduler +from .base import Callback +from .early_stopping import EarlyStopping +from .model_checkpoint import ModelCheckpoint +from .gradient_accumulation_scheduler import GradientAccumulationScheduler + __all__ = [ + 'Callback', 'EarlyStopping', 'ModelCheckpoint', 'GradientAccumulationScheduler', diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py new file mode 100644 index 0000000000..7150138d66 --- /dev/null +++ b/pytorch_lightning/callbacks/base.py @@ -0,0 +1,68 @@ +""" +Callbacks +========= + +Callbacks supported by Lightning +""" + +import abc + + +_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization" + + +class Callback(abc.ABC): + """Abstract base class used to build new callbacks.""" + + def __init__(self): + self._trainer = None + + @property + def trainer(self): + assert self._trainer is not None, _NO_TRAINER_ERROR_MSG + return self._trainer + + def set_trainer(self, trainer): + """Make a link to the trainer, so different things like `trainer.current_epoch`, + `trainer.batch_idx`, `trainer.global_step` can be used.""" + self._trainer = trainer + + def on_epoch_begin(self): + """Called when the epoch begins.""" + pass + + def on_epoch_end(self): + """Called when the epoch ends.""" + pass + + def on_batch_begin(self): + """Called when the training batch begins.""" + pass + + def on_batch_end(self): + """Called when the training batch ends.""" + pass + + def on_train_begin(self): + """Called when the train begins.""" + pass + + def on_train_end(self): + """Called when the train ends.""" + pass + + def on_validation_begin(self): + """Called when the validation loop begins.""" + pass + + def on_validation_end(self): + """Called when the validation loop ends.""" + pass + + def on_test_begin(self): + """Called when the test begins.""" + pass + + def on_test_end(self): + """Called when the test ends.""" + pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py new file mode 100644 index 0000000000..645eff2485 --- /dev/null +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -0,0 +1,114 @@ +import logging as log +import warnings + +import 
numpy as np + +from .base import Callback + + +class EarlyStopping(Callback): + r""" + Stop training when a monitored quantity has stopped improving. + + Args: + monitor (str): quantity to be monitored. Default: ``'val_loss'``. + min_delta (float): minimum change in the monitored quantity + to qualify as an improvement, i.e. an absolute + change of less than `min_delta`, will count as no + improvement. Default: ``0``. + patience (int): number of epochs with no improvement + after which training will be stopped. Default: ``0``. + verbose (bool): verbosity mode. Default: ``False``. + mode (str): one of {auto, min, max}. In `min` mode, + training will stop when the quantity + monitored has stopped decreasing; in `max` + mode it will stop when the quantity + monitored has stopped increasing; in `auto` + mode, the direction is automatically inferred + from the name of the monitored quantity. Default: ``'auto'``. + strict (bool): whether to crash the training if `monitor` is + not found in the metrics. Default: ``True``. + + Example:: + + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import EarlyStopping + + early_stopping = EarlyStopping('val_loss') + Trainer(early_stop_callback=early_stopping) + """ + + def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 0, + verbose: bool = False, mode: str = 'auto', strict: bool = True): + super().__init__() + + self.monitor = monitor + self.patience = patience + self.verbose = verbose + self.strict = strict + self.min_delta = min_delta + self.wait = 0 + self.stopped_epoch = 0 + + mode_dict = { + 'min': np.less, + 'max': np.greater, + 'auto': np.greater if 'acc' in self.monitor else np.less + } + + if mode not in mode_dict: + if self.verbose > 0: + log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.') + mode = 'auto' + + self.monitor_op = mode_dict[mode] + self.min_delta *= 1 if self.monitor_op == np.greater else -1 + + self.on_train_begin() + + def check_metrics(self, logs): + monitor_val = logs.get(self.monitor) + error_msg = (f'Early stopping conditioned on metric `{self.monitor}`' + f' which is not available. 
Available metrics are:' + f' `{"`, `".join(list(logs.keys()))}`') + + if monitor_val is None: + if self.strict: + raise RuntimeError(error_msg) + if self.verbose > 0: + warnings.warn(error_msg, RuntimeWarning) + + return False + + return True + + def on_train_begin(self): + # Allow instances to be re-used + self.wait = 0 + self.stopped_epoch = 0 + self.best = np.Inf if self.monitor_op == np.less else -np.Inf + + def on_epoch_end(self): + logs = self.trainer.callback_metrics + stop_training = False + if not self.check_metrics(logs): + return stop_training + + current = logs.get(self.monitor) + if self.monitor_op(current - self.min_delta, self.best): + self.best = current + self.wait = 0 + else: + self.wait += 1 + if self.wait >= self.patience: + self.stopped_epoch = self.trainer.current_epoch + stop_training = True + self.on_train_end() + + return stop_training + + def on_train_end(self): + if self.stopped_epoch > 0 and self.verbose > 0: + warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,' + ' but will start from "0" in v0.8.0.', DeprecationWarning) + log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping') diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py new file mode 100644 index 0000000000..f4e5ad3764 --- /dev/null +++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py @@ -0,0 +1,55 @@ +import warnings + +from .base import Callback + + +class GradientAccumulationScheduler(Callback): + r""" + Change gradient accumulation factor according to scheduling. + + Args: + scheduling (dict): scheduling in format {epoch: accumulation_factor} + .. warning:: Epochs indexing starts from "1" until v0.6.x, but will start from "0" in + v0.8.0. + + Example:: + + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import GradientAccumulationScheduler + + # at epoch 5 start accumulating every 2 batches + accumulator = GradientAccumulationScheduler(scheduling={5: 2}) + Trainer(accumulate_grad_batches=accumulator) + """ + + def __init__(self, scheduling: dict): + super().__init__() + + if not scheduling: # empty dict error + raise TypeError("Empty dict cannot be interpreted correctly") + + for key in scheduling: + if not isinstance(key, int) or not isinstance(scheduling[key], int): + raise TypeError("All epochs and accumulation factors must be integers") + + minimal_epoch = min(scheduling.keys()) + warnings.warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,' + ' but will start from "0" in v0.8.0.', DeprecationWarning) + if minimal_epoch < 1: + msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correctly" + raise IndexError(msg) + if minimal_epoch != 1: # if user didn't define first epoch accumulation factor + scheduling.update({1: 1}) + + self.scheduling = scheduling + self.epochs = sorted(scheduling.keys()) + + def on_epoch_begin(self): + trainer = self.trainer + # indexing epochs from 1 (until v0.6.x) + # In v0.8.0, ` + 1` should be removed.
+ epoch = trainer.current_epoch + 1 + for i in reversed(range(len(self.epochs))): + if epoch >= self.epochs[i]: + trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i]) + break diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py new file mode 100644 index 0000000000..a9f7d65b3d --- /dev/null +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -0,0 +1,181 @@ +import os +import shutil +import logging as log +import warnings + +import numpy as np + +from .base import Callback + + +class ModelCheckpoint(Callback): + r""" + Save the model after every epoch. + + Args: + filepath: path to save the model file. + Can contain named formatting options to be auto-filled. + + Example:: + + # save epoch and val_loss in name + ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5') + # saves file like: /path/epoch_2-val_loss_0.2.hdf5 + monitor (str): quantity to monitor. + verbose (bool): verbosity mode, False or True. + save_top_k (int): if `save_top_k == k`, + the best k models according to + the quantity monitored will be saved. + if ``save_top_k == 0``, no models are saved. + if ``save_top_k == -1``, all models are saved. + Please note that the monitors are checked every `period` epochs. + if ``save_top_k >= 2`` and the callback is called multiple + times inside an epoch, the name of the saved file will be + appended with a version count starting with `v0`. + mode (str): one of {auto, min, max}. + If ``save_top_k != 0``, the decision + to overwrite the current save file is made + based on either the maximization or the + minimization of the monitored quantity. For `val_acc`, + this should be `max`, for `val_loss` this should + be `min`, etc. In `auto` mode, the direction is + automatically inferred from the name of the monitored quantity. + save_weights_only (bool): if True, then only the model's weights will be + saved (`model.save_weights(filepath)`), else the full model + is saved (`model.save(filepath)`). + period (int): Interval (number of epochs) between checkpoints. + + Example:: + + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import ModelCheckpoint + + # saves checkpoints to my_path whenever 'val_loss' has a new min + checkpoint_callback = ModelCheckpoint(filepath='my_path') + Trainer(checkpoint_callback=checkpoint_callback) + """ + + def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False, + save_top_k: int = 1, save_weights_only: bool = False, + mode: str = 'auto', period: int = 1, prefix: str = ''): + super().__init__() + if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0: + warnings.warn( + f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0." + "All files in this directory will be deleted when a checkpoint is saved!" 
+ ) + + self.monitor = monitor + self.verbose = verbose + self.filepath = filepath + os.makedirs(filepath, exist_ok=True) + self.save_top_k = save_top_k + self.save_weights_only = save_weights_only + self.period = period + self.epochs_since_last_check = 0 + self.prefix = prefix + self.best_k_models = {} + # {filename: monitor} + self.kth_best_model = '' + self.best = 0 + self.save_function = None + + mode_dict = { + 'min': (np.less, np.Inf, 'min'), + 'max': (np.greater, -np.Inf, 'max'), + 'auto': (np.greater, -np.Inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure') + else (np.less, np.Inf, 'min'), + } + + if mode not in mode_dict: + warnings.warn( + f'ModelCheckpoint mode {mode} is unknown, ' + 'fallback to auto mode.', RuntimeWarning) + mode = 'auto' + + self.monitor_op, self.kth_value, self.mode = mode_dict[mode] + + def _del_model(self, filepath): + try: + shutil.rmtree(filepath) + except OSError: + os.remove(filepath) + + def _save_model(self, filepath): + # make paths + os.makedirs(os.path.dirname(filepath), exist_ok=True) + + # delegate the saving to the model + if self.save_function is not None: + self.save_function(filepath) + else: + raise ValueError(".save_function() not set") + + def check_monitor_top_k(self, current): + less_than_k_models = len(self.best_k_models) < self.save_top_k + if less_than_k_models: + return True + return self.monitor_op(current, self.best_k_models[self.kth_best_model]) + + def on_validation_end(self): + logs = self.trainer.callback_metrics + epoch = self.trainer.current_epoch + self.epochs_since_last_check += 1 + + if self.save_top_k == 0: + # no models are saved + return + if self.epochs_since_last_check >= self.period: + self.epochs_since_last_check = 0 + filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt' + version_cnt = 0 + while os.path.isfile(filepath): + # this epoch called before + filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt' + version_cnt += 1 + + if self.save_top_k != -1: + current = logs.get(self.monitor) + + if current is None: + warnings.warn( + f'Can save best model only with {self.monitor} available,' + ' skipping.', RuntimeWarning) + else: + if self.check_monitor_top_k(current): + self._do_check_save(filepath, current, epoch) + else: + if self.verbose > 0: + log.info( + f'\nEpoch {epoch:05d}: {self.monitor}' + f' was not in top {self.save_top_k}') + + else: + if self.verbose > 0: + log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}') + self._save_model(filepath) + + def _do_check_save(self, filepath, current, epoch): + # remove kth + if len(self.best_k_models) == self.save_top_k: + delpath = self.kth_best_model + self.best_k_models.pop(self.kth_best_model) + self._del_model(delpath) + + self.best_k_models[filepath] = current + if len(self.best_k_models) == self.save_top_k: + # monitor dict has reached k elements + _op = max if self.mode == 'min' else min + self.kth_best_model = _op(self.best_k_models, + key=self.best_k_models.get) + self.kth_value = self.best_k_models[self.kth_best_model] + + _op = min if self.mode == 'min' else max + self.best = _op(self.best_k_models.values()) + + if self.verbose > 0: + log.info( + f'\nEpoch {epoch:05d}: {self.monitor} reached' + f' {current:0.5f} (best {self.best:0.5f}), saving model to' + f' {filepath} as top {self.save_top_k}') + self._save_model(filepath) diff --git a/pytorch_lightning/callbacks/pt_callbacks.py b/pytorch_lightning/callbacks/pt_callbacks.py deleted file mode 100644 index 8fe557ceb7..0000000000 --- 
a/pytorch_lightning/callbacks/pt_callbacks.py +++ /dev/null @@ -1,431 +0,0 @@ -""" -Callbacks -========= - -Callbacks supported by Lightning -""" - -import os -import shutil -import logging as log -import warnings - -import numpy as np - - -class Callback(object): - """Abstract base class used to build new callbacks.""" - - def __init__(self): - self._trainer = None - - def set_trainer(self, trainer): - """Make a link to the trainer, so different things like `trainer.current_epoch`, - `trainer.batch_idx`, `trainer.global_step` can be used.""" - self._trainer = trainer - - def on_epoch_begin(self): - """Called when the epoch begins.""" - pass - - def on_epoch_end(self): - """Called when the epoch ends.""" - pass - - def on_batch_begin(self): - """Called when the training batch begins.""" - pass - - def on_batch_end(self): - """Called when the training batch ends.""" - pass - - def on_train_begin(self): - """Called when the train begins.""" - pass - - def on_train_end(self): - """Called when the train ends.""" - pass - - def on_validation_begin(self): - """Called when the validation loop begins.""" - pass - - def on_validation_end(self): - """Called when the validation loop ends.""" - pass - - def on_test_begin(self): - """Called when the test begins.""" - pass - - def on_test_end(self): - """Called when the test ends.""" - pass - - -_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization" - - -class EarlyStopping(Callback): - r""" - Stop training when a monitored quantity has stopped improving. - - Args: - monitor (str): quantity to be monitored. Default: ``'val_loss'``. - min_delta (float): minimum change in the monitored quantity - to qualify as an improvement, i.e. an absolute - change of less than `min_delta`, will count as no - improvement. Default: ``0``. - patience (int): number of epochs with no improvement - after which training will be stopped. Default: ``0``. - verbose (bool): verbosity mode. Default: ``0``. - mode (str): one of {auto, min, max}. In `min` mode, - training will stop when the quantity - monitored has stopped decreasing; in `max` - mode it will stop when the quantity - monitored has stopped increasing; in `auto` - mode, the direction is automatically inferred - from the name of the monitored quantity. Default: ``'auto'``. - strict (bool): whether to crash the training if `monitor` is - not found in the metrics. Default: ``True``. 
- - Example:: - - from pytorch_lightning import Trainer - from pytorch_lightning.callbacks import EarlyStopping - - early_stopping = EarlyStopping('val_loss') - Trainer(early_stop_callback=early_stopping) - """ - - def __init__(self, monitor='val_loss', - min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True): - super(EarlyStopping, self).__init__() - - self.monitor = monitor - self.patience = patience - self.verbose = verbose - self.strict = strict - self.min_delta = min_delta - self.wait = 0 - self.stopped_epoch = 0 - - if mode not in ['auto', 'min', 'max']: - if self.verbose > 0: - log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.') - mode = 'auto' - - if mode == 'min': - self.monitor_op = np.less - elif mode == 'max': - self.monitor_op = np.greater - else: - if 'acc' in self.monitor: - self.monitor_op = np.greater - else: - self.monitor_op = np.less - - if self.monitor_op == np.greater: - self.min_delta *= 1 - else: - self.min_delta *= -1 - - self.on_train_begin() - - def check_metrics(self, logs): - monitor_val = logs.get(self.monitor) - error_msg = (f'Early stopping conditioned on metric `{self.monitor}`' - f' which is not available. Available metrics are:' - f' `{"`, `".join(list(logs.keys()))}`') - - if monitor_val is None: - if self.strict: - raise RuntimeError(error_msg) - if self.verbose > 0: - warnings.warn(error_msg, RuntimeWarning) - - return False - - return True - - def on_train_begin(self): - # Allow instances to be re-used - self.wait = 0 - self.stopped_epoch = 0 - self.best = np.Inf if self.monitor_op == np.less else -np.Inf - - def on_epoch_end(self): - assert self._trainer is not None, _NO_TRAINER_ERROR_MSG - - logs = self._trainer.callback_metrics - stop_training = False - if not self.check_metrics(logs): - return stop_training - - current = logs.get(self.monitor) - if self.monitor_op(current - self.min_delta, self.best): - self.best = current - self.wait = 0 - else: - self.wait += 1 - if self.wait >= self.patience: - self.stopped_epoch = self._trainer.current_epoch - stop_training = True - self.on_train_end() - - return stop_training - - def on_train_end(self): - if self.stopped_epoch > 0 and self.verbose > 0: - warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,' - ' but will start from "0" in v0.8.0.', DeprecationWarning) - log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping') - - -class ModelCheckpoint(Callback): - r""" - - Save the model after every epoch. - - Args: - filepath (str): path to save the model file. - Can contain named formatting options to be auto-filled. - - Example:: - - # save epoch and val_loss in name - ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5') - # saves file like: /path/epoch_2-val_loss_0.2.hdf5 - monitor (str): quantity to monitor. - verbose (bool): verbosity mode, 0 or 1. - save_top_k (int): if `save_top_k == k`, - the best k models according to - the quantity monitored will be saved. - if `save_top_k == 0`, no models are saved. - if `save_top_k == -1`, all models are saved. - Please note that the monitors are checked every `period` epochs. - if `save_top_k >= 2` and the callback is called multiple - times inside an epoch, the name of the saved file will be - appended with a version count starting with `v0`. - mode (str): one of {auto, min, max}. - If `save_top_k != 0`, the decision - to overwrite the current save file is made - based on either the maximization or the - minimization of the monitored quantity. 
For `val_acc`, - this should be `max`, for `val_loss` this should - be `min`, etc. In `auto` mode, the direction is - automatically inferred from the name of the monitored quantity. - save_weights_only (bool): if True, then only the model's weights will be - saved (`model.save_weights(filepath)`), else the full model - is saved (`model.save(filepath)`). - period (int): Interval (number of epochs) between checkpoints. - - Example:: - - from pytorch_lightning import Trainer - from pytorch_lightning.callbacks import ModelCheckpoint - - checkpoint_callback = ModelCheckpoint(filepath='my_path') - Trainer(checkpoint_callback=checkpoint_callback) - - # saves checkpoints to my_path whenever 'val_loss' has a new min - """ - - def __init__(self, filepath, monitor='val_loss', verbose=0, - save_top_k=1, save_weights_only=False, - mode='auto', period=1, prefix=''): - super(ModelCheckpoint, self).__init__() - if ( - save_top_k and - os.path.isdir(filepath) and - len(os.listdir(filepath)) > 0 - ): - warnings.warn( - f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0." - "All files in this directory will be deleted when a checkpoint is saved!" - ) - - self.monitor = monitor - self.verbose = verbose - self.filepath = filepath - os.makedirs(filepath, exist_ok=True) - self.save_top_k = save_top_k - self.save_weights_only = save_weights_only - self.period = period - self.epochs_since_last_check = 0 - self.prefix = prefix - self.best_k_models = {} - # {filename: monitor} - self.kth_best_model = '' - self.best = 0 - - if mode not in ['auto', 'min', 'max']: - warnings.warn( - f'ModelCheckpoint mode {mode} is unknown, ' - 'fallback to auto mode.', RuntimeWarning) - mode = 'auto' - - if mode == 'min': - self.monitor_op = np.less - self.kth_value = np.Inf - self.mode = 'min' - elif mode == 'max': - self.monitor_op = np.greater - self.kth_value = -np.Inf - self.mode = 'max' - else: - if 'acc' in self.monitor or self.monitor.startswith('fmeasure'): - self.monitor_op = np.greater - self.kth_value = -np.Inf - self.mode = 'max' - else: - self.monitor_op = np.less - self.kth_value = np.Inf - self.mode = 'min' - - def _del_model(self, filepath): - dirpath = os.path.dirname(filepath) - - # make paths - os.makedirs(dirpath, exist_ok=True) - - try: - shutil.rmtree(filepath) - except OSError: - os.remove(filepath) - - def _save_model(self, filepath): - dirpath = os.path.dirname(filepath) - - # make paths - os.makedirs(dirpath, exist_ok=True) - - # delegate the saving to the model - self.save_function(filepath) - - def check_monitor_top_k(self, current): - less_than_k_models = len(self.best_k_models.keys()) < self.save_top_k - if less_than_k_models: - return True - return self.monitor_op(current, self.best_k_models[self.kth_best_model]) - - def on_validation_end(self): - assert self._trainer is not None, _NO_TRAINER_ERROR_MSG - - logs = self._trainer.callback_metrics - epoch = self._trainer.current_epoch - self.epochs_since_last_check += 1 - - if self.save_top_k == 0: - # no models are saved - return - if self.epochs_since_last_check >= self.period: - self.epochs_since_last_check = 0 - filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt' - version_cnt = 0 - while os.path.isfile(filepath): - # this epoch called before - filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt' - version_cnt += 1 - - if self.save_top_k != -1: - current = logs.get(self.monitor) - - if current is None: - warnings.warn( - f'Can save best model only with {self.monitor} available,' - 
' skipping.', RuntimeWarning) - else: - if self.check_monitor_top_k(current): - - # remove kth - if len(self.best_k_models.keys()) == self.save_top_k: - delpath = self.kth_best_model - self.best_k_models.pop(self.kth_best_model) - self._del_model(delpath) - - self.best_k_models[filepath] = current - if len(self.best_k_models.keys()) == self.save_top_k: - # monitor dict has reached k elements - if self.mode == 'min': - self.kth_best_model = max(self.best_k_models, key=self.best_k_models.get) - else: - self.kth_best_model = min(self.best_k_models, key=self.best_k_models.get) - self.kth_value = self.best_k_models[self.kth_best_model] - - if self.mode == 'min': - self.best = min(self.best_k_models.values()) - else: - self.best = max(self.best_k_models.values()) - if self.verbose > 0: - log.info( - f'\nEpoch {epoch:05d}: {self.monitor} reached' - f' {current:0.5f} (best {self.best:0.5f}), saving model to' - f' {filepath} as top {self.save_top_k}') - self._save_model(filepath) - - else: - if self.verbose > 0: - log.info( - f'\nEpoch {epoch:05d}: {self.monitor}' - f' was not in top {self.save_top_k}') - - else: - if self.verbose > 0: - log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}') - self._save_model(filepath) - - -class GradientAccumulationScheduler(Callback): - r""" - Change gradient accumulation factor according to scheduling. - - Args: - scheduling (dict): scheduling in format {epoch: accumulation_factor} - .. warning:: Epochs indexing starts from "1" until v0.6.x, but will start from "0" in v0.8.0. - - Example:: - - from pytorch_lightning import Trainer - from pytorch_lightning.callbacks import GradientAccumulationScheduler - - # at epoch 5 start accumulating every 2 batches - accumulator = GradientAccumulationScheduler(scheduling: {5: 2}) - Trainer(accumulate_grad_batches=accumulator) - """ - - def __init__(self, scheduling: dict): - super().__init__() - - if scheduling == {}: # empty dict error - raise TypeError("Empty dict cannot be interpreted correct") - - for key in scheduling.keys(): - if not isinstance(key, int) or not isinstance(scheduling[key], int): - raise TypeError("All epoches and accumulation factor must be integers") - - minimal_epoch = min(scheduling.keys()) - warnings.warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,' - ' but will start from "0" in v0.8.0.', DeprecationWarning) - if minimal_epoch < 1: - msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct" - raise IndexError(msg) - if minimal_epoch != 1: # if user didnt define first epoch accumulation factor - scheduling.update({1: 1}) - - self.scheduling = scheduling - self.epochs = sorted(scheduling.keys()) - - def on_epoch_begin(self): - assert self._trainer is not None, _NO_TRAINER_ERROR_MSG - - trainer = self._trainer - # indexing epochs from 1 (until v0.6.x) - # In v0.8.0, ` + 1` should be removed. - epoch = trainer.current_epoch + 1 - for i in reversed(range(len(self.epochs))): - if epoch >= self.epochs[i]: - trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i]) - break
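For reference, the snippet below sketches how the split-out callbacks fit together after this change. It is a minimal illustration assembled from the docstrings in this patch, not code taken from the PR itself: the `EpochPrinter` subclass and the commented-out `model` are hypothetical, and this patch does not add a mechanism for registering custom `Callback` subclasses with the `Trainer` (only the three built-in callbacks are wired through the existing `Trainer` arguments shown in their docstrings).

Example::

    # Minimal usage sketch for the callbacks introduced by this split, based on the
    # docstrings above. `early_stop_callback`, `checkpoint_callback` and
    # `accumulate_grad_batches` are the Trainer arguments documented in this patch.
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import (
        Callback,
        EarlyStopping,
        GradientAccumulationScheduler,
        ModelCheckpoint,
    )

    class EpochPrinter(Callback):
        """Hypothetical subclass showing the new base class and `trainer` property."""

        def on_epoch_end(self):
            # `set_trainer()` must have been called first, otherwise the `trainer`
            # property raises the ".set_trainer() should be called ..." assertion.
            print(f"finished epoch {self.trainer.current_epoch}")

    # stop when 'val_loss' has not decreased for 3 epochs
    early_stopping = EarlyStopping(monitor='val_loss', patience=3, mode='min')
    # keep only the single best checkpoint according to 'val_loss'
    checkpointer = ModelCheckpoint(filepath='my_path', monitor='val_loss', save_top_k=1)
    # from epoch 5 on, accumulate gradients over 2 batches
    accumulator = GradientAccumulationScheduler(scheduling={5: 2})

    trainer = Trainer(
        early_stop_callback=early_stopping,
        checkpoint_callback=checkpointer,
        accumulate_grad_batches=accumulator,
    )
    # trainer.fit(model)  # `model` would be a user-defined LightningModule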