Split callbacks (#849)

* add .vscode to .gitignore

* Split callbacks into individual files + add a property to Callback for easy trainer instance access

* formatting

* Add a conda env file for quick and easy env setup to develop on PL

* Address comments

* add fix to kth_best_model

* add some typing to callbacks

* fix typo

* add autopep8 config to pyproject.toml

* format again

* format

* fix toml

* fix toml again

* consistent max line length in all config files

* remove conda env file

* Update pytorch_lightning/callbacks/early_stopping.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/callbacks/model_checkpoint.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* docstring

* Update pytorch_lightning/callbacks/model_checkpoint.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/callbacks/model_checkpoint.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* fix logic error

* format

* simplify if/else

* format

* fix linting issue in changelog

* edit changelog about new callback mechanism

* fix remaining formatting issue in CHANGELOG

* remove lambda function because it's not compatible with pickle (used during ddp)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Hadrien Mary 2020-02-22 21:45:34 -05:00 committed by GitHub
parent da2f11a9c4
commit 89d5772f55
11 changed files with 616 additions and 436 deletions

.gitignore

@@ -5,7 +5,7 @@ model_weights/
app/models/
pip-wheel-metadata/
lightning_logs/
.vscode/
# Test-tube
test_tube_logs/

.markdownlint.yml

@@ -0,0 +1,2 @@
MD013: false # line length
MD024: false # headers with the same names

.pep8speaks.yml

@@ -5,7 +5,7 @@ scanner:
linter: pycodestyle # Other option is flake8
pycodestyle: # Same as scanner.linter value. Other option is flake8
max-line-length: 100 # Default is 79 in PEP 8
max-line-length: 120 # Default is 79 in PEP 8
ignore: # Errors and warnings to ignore
- W504 # line break after binary operator
- E402 # module level import not at top of file

CHANGELOG.md

@@ -1,10 +1,13 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [Unreleased]
### Added
- Updated governance docs
- Added a check to ensure that the metric used for early stopping exists before training commences ([#542](https://github.com/PyTorchLightning/pytorch-lightning/pull/542))
- Added `optimizer_idx` argument to `backward` hook ([#733](https://github.com/PyTorchLightning/pytorch-lightning/pull/733))
@@ -15,24 +18,35 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `train_dataloader`, `val_dataloader` and `test_dataloader` arguments to `Trainer.fit()`, for alternative data parsing ([#759](https://github.com/PyTorchLightning/pytorch-lightning/pull/759))
- Added Tensor Processing Unit (TPU) support ([#868](https://github.com/PyTorchLightning/pytorch-lightning/pull/868))
- Added semantic segmentation example ([#751](https://github.com/PyTorchLightning/pytorch-lightning/pull/751),[#876](https://github.com/PyTorchLightning/pytorch-lightning/pull/876))
- Split callbacks into multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849))
### Changed
- Changed default TQDM to use `tqdm.auto` for prettier outputs in IPython notebooks ([#752](https://github.com/PyTorchLightning/pytorch-lightning/pull/752))
- Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767))
- Moved the default `tqdm_dict` definition from Trainer to `LightningModule`, so it can be overridden by the user ([#749](https://github.com/PyTorchLightning/pytorch-lightning/pull/749))
### Deprecated
- None
### Removed
- Removed dependency on pandas ([#736](https://github.com/PyTorchLightning/pytorch-lightning/pull/736))
- Removed dependency on torchvision ([#797](https://github.com/PyTorchLightning/pytorch-lightning/pull/797))
- Removed dependency on scikit-learn ([#801](https://github.com/PyTorchLightning/pytorch-lightning/pull/801))
### Fixed
- Fixed a bug where early stopping `on_end_epoch` would be called inconsistently when `check_val_every_n_epoch == 0` ([#743](https://github.com/PyTorchLightning/pytorch-lightning/pull/743))
- Fixed a bug where the model checkpointer didn't write to the same directory as the logger ([#771](https://github.com/PyTorchLightning/pytorch-lightning/pull/771))
- Fixed a bug where the `TensorBoardLogger` class would create an additional empty log file during fitting ([#777](https://github.com/PyTorchLightning/pytorch-lightning/pull/777))
- Fixed a bug where `global_step` was advanced incorrectly when using `accumulate_grad_batches > 1` ([#832](https://github.com/PyTorchLightning/pytorch-lightning/pull/832))
## [0.6.0] - 2020-01-21
### Added
- Added support for resuming from a specific checkpoint via `resume_from_checkpoint` argument ([#516](https://github.com/PyTorchLightning/pytorch-lightning/pull/516))
- Added support for `ReduceLROnPlateau` scheduler ([#320](https://github.com/PyTorchLightning/pytorch-lightning/pull/320))
- Added support for Apex mode `O2` in conjunction with Data Parallel ([#493](https://github.com/PyTorchLightning/pytorch-lightning/pull/493))
@@ -44,19 +58,27 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added option to disable validation by setting `val_percent_check=0` ([#649](https://github.com/PyTorchLightning/pytorch-lightning/pull/649))
- Added `NeptuneLogger` class ([#648](https://github.com/PyTorchLightning/pytorch-lightning/pull/648))
- Added `WandbLogger` class ([#627](https://github.com/PyTorchLightning/pytorch-lightning/pull/627))
### Changed
- Changed the default progress bar to print to stdout instead of stderr ([#531](https://github.com/PyTorchLightning/pytorch-lightning/pull/531))
- Renamed `step_idx` to `step`, `epoch_idx` to `epoch`, `max_num_epochs` to `max_epochs` and `min_num_epochs` to `min_epochs` ([#589](https://github.com/PyTorchLightning/pytorch-lightning/pull/589))
- Renamed `total_batch_nb` to `total_batches`, `nb_val_batches` to `num_val_batches`, `nb_training_batches` to `num_training_batches`, `max_nb_epochs` to `max_epochs`, `min_nb_epochs` to `min_epochs`, `nb_test_batches` to `num_test_batches`, and `nb_val_batches` to `num_val_batches` ([#567](https://github.com/PyTorchLightning/pytorch-lightning/pull/567))
- Changed gradient logging to use parameter names instead of indexes ([#660](https://github.com/PyTorchLightning/pytorch-lightning/pull/660))
- Changed the default logger to `TensorBoardLogger` ([#609](https://github.com/PyTorchLightning/pytorch-lightning/pull/609))
- Changed the directory for tensorboard logging to be the same as model checkpointing ([#706](https://github.com/PyTorchLightning/pytorch-lightning/pull/706))
### Deprecated
- Deprecated `max_nb_epochs` and `min_nb_epochs` ([#567](https://github.com/PyTorchLightning/pytorch-lightning/pull/567))
- Deprecated the `on_sanity_check_start` hook in `ModelHooks` ([#598](https://github.com/PyTorchLightning/pytorch-lightning/pull/598))
### Removed
- Removed the `save_best_only` argument from `ModelCheckpoint`, use `save_top_k=1` instead ([#128](https://github.com/PyTorchLightning/pytorch-lightning/pull/128))
### Fixed
- Fixed a bug which occurred when using Adagrad with CUDA ([#554](https://github.com/PyTorchLightning/pytorch-lightning/pull/554))
- Fixed a bug where training would be on the GPU despite setting `gpus=0` or `gpus=[]` ([#561](https://github.com/PyTorchLightning/pytorch-lightning/pull/561))
- Fixed an error with `print_nan_gradients` when some parameters do not require gradient ([#579](https://github.com/PyTorchLightning/pytorch-lightning/pull/579))
@@ -79,18 +101,22 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where `on_train_end` was not called when early stopping ([#723](https://github.com/PyTorchLightning/pytorch-lightning/pull/723))
## [0.5.3] - 2019-11-06
### Added
- Added option to disable default logger, checkpointer, and early stopping by passing `logger=False`, `checkpoint_callback=False` and `early_stop_callback=False` respectively
- Added `CometLogger` for use with Comet.ml
- Added `val_check_interval` argument to `Trainer` allowing validation to be performed at every given number of batches
- Added functionality to save and load hyperparameters using the standard checkpoint mechanism
- Added call to `torch.cuda.empty_cache` before training starts
- Added option for user to override the call to `backward`
- Added support for truncated backprop through time via the `truncated_bptt_steps` argument in `Trainer`
- Added option to operate on all outputs from `training_step` in DDP2
- Added a hook for modifying DDP init
- Added a hook for modifying Apex
### Changed
- Changed experiment version to be padded with zeros (e.g. `/dir/version_9` becomes `/dir/version_0009`)
- Changed callback metrics to include any metrics given in logs or progress bar
- Changed the default for `save_best_only` in `ModelCheckpoint` to `True`
@@ -100,11 +126,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed weights restore to first attempt HPC weights before restoring normally, preventing both weights being restored and running out of memory
- Changed progress bar functionality to add multiple progress bars for train/val/test
- Changed calls to `print` to use `logging` instead
### Deprecated
- Deprecated `tng_dataloader`
### Removed
- None
### Fixed
- Fixed an issue where the number of batches was off by one during training
- Fixed a bug that occurred when setting a checkpoint callback and `early_stop_callback=False`
- Fixed an error when importing CometLogger
@@ -115,23 +147,35 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where checkpointing would sometimes erase the current directory
## [0.5.2] - 2019-10-10
### Added
- Added `weights_summary` argument to `Trainer` to be set to `full` (full summary), `top` (just top level modules) or other
- Added `tags` argument to `MLFlowLogger`
### Changed
- Changed default for `amp_level` to `O1`
### Deprecated
- None
### Removed
- Removed the `print_weights_summary` argument from `Trainer`
### Fixed
- Fixed a bug where logs were not written properly
- Fixed a bug where `logger.finalize` wasn't called after training is complete
- Fixed callback metric errors in DDP
- Fixed a bug where `TestTubeLogger` didn't log to the correct directory
## [0.5.1] - 2019-10-05
### Added
- Added the `LightningLoggerBase` class for experiment loggers
- Added `MLFlowLogger` for logging with `mlflow`
- Added `TestTubeLogger` for logging with `test_tube`
@@ -139,84 +183,132 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for optimisers which require a closure (e.g. LBFGS)
- Added automatic `MASTER_PORT` default for DDP when not set manually
- Added new GPU memory logging options `'min_max'` (log only the min/max utilization) and `'all'` (log all the GPU memory)
### Changed
- Changed schedulers to always be called with the current epoch
- Changed `test_tube` to an optional dependency
- Changed data loaders to internally use a getter instead of a python property
- Disabled auto GPU loading when restoring weights to prevent out of memory errors
- Changed logging, early stopping and checkpointing to occur by default
### Deprecated
- None
### Removed
- None
### Fixed
- Fixed a bug with samplers that do not specify `set_epoch`
- Fixed a bug when using the `MLFlowLogger` with unsupported data types, this will now raise a warning
- Fixed a bug where gradient norms were always zero when using `track_grad_norm`
- Fixed a bug which causes a crash when logging memory
## [0.5.0] - 2019-09-26
### Added
- None
### Changed
- Changed `data_batch` argument to `batch` throughout
- Changed `batch_i` argument to `batch_idx` throughout
- Changed `tng_dataloader` method to `train_dataloader`
- Changed `on_tng_metrics` method to `on_training_metrics`
- Changed `gradient_clip` argument to `gradient_clip_val`
- Changed `add_log_row_interval` to `row_log_interval`
### Deprecated
- None
### Removed
- None
### Fixed
- Fixed a bug with tensorboard logging in multi-gpu setup
## [0.4.9] - 2019-09-16
### Added
- Added the flag `log_gpu_memory` to `Trainer` to deactivate logging of GPU
memory utilization
- Added SLURM resubmit functionality (port from test-tube)
- Added optional weight_save_path to trainer to remove the need for a checkpoint_callback when using cluster training
- Added option to use single gpu per node with `DistributedDataParallel`
### Changed
- Changed functionality of `validation_end` and `test_end` with multiple dataloaders to be given all of the dataloaders at once rather than in separate calls
- Changed print_nan_grads to only print the parameter value and gradients when they contain NaN
- Changed gpu API to take integers as well (e.g. `gpus=2` instead of `gpus=[0, 1]`)
- All models now loaded on to CPU to avoid device and out of memory issues in PyTorch
### Deprecated
- None
### Removed
- None
### Fixed
- Fixed a bug where data types that implement `.to` but not `.cuda` would not be properly moved onto the GPU
- Fixed a bug where data would not be re-shuffled every epoch when using a `DistributedSampler`
## [0.4.8] - 2019-08-31
### Added
- Added `test_step` and `test_end` methods, used when `Trainer.test` is called
- Added `GradientAccumulationScheduler` callback which can be used to schedule changes to the number of accumulation batches
- Added option to skip the validation sanity check by setting `nb_sanity_val_steps = 0`
### Changed
- None
### Deprecated
- None
### Removed
- None
### Fixed
- Fixed a bug when setting `nb_sanity_val_steps = 0`
## [0.4.7] - 2019-08-24
### Added
- None
### Changed
- Changed the default `val_check_interval` to `1.0`
- Changed defaults for `nb_val_batches`, `nb_tng_batches` and `nb_test_batches` to 0
### Deprecated
- None
### Removed
- None
### Fixed
- Fixed a bug where the full validation set was used despite setting `val_percent_check`
- Fixed a bug where an `Exception` was thrown when using a data set containing a single batch
- Fixed a bug where an `Exception` was thrown if no `val_dataloader` was given
@@ -226,114 +318,204 @@ memory utilization
- Fixed a bug where `AttributeError` could be suppressed by the `Trainer`
## [0.4.6] - 2019-08-15
### Added
- Added support for data to be given as a `dict` or `list` with a single gpu
- Added support for `configure_optimizers` to return a single optimizer, two list (optimizers and schedulers), or a single list
### Changed
- None
### Deprecated
- None
### Removed
- None
### Fixed
- Fixed a bug where returning just an optimizer list (i.e. without schedulers) from `configure_optimizers` would throw an `Exception`
## [0.4.5] - 2019-08-13
### Added
- Added `optimizer_step` method that can be overridden to change the standard optimizer behaviour
### Changed
- None
### Deprecated
- None
### Removed
- None
### Fixed
- None
## [0.4.4] - 2019-08-12
### Added
- Added support for multiple validation dataloaders
- Added support for latest test-tube logger (optimised for `torch==1.2.0`)
### Changed
- `validation_step` and `val_dataloader` are now optional
- `lr_scheduler` is now activated after epoch
### Deprecated
- None
### Removed
- None
### Fixed
- Fixed a bug where a warning would show when using `lr_scheduler` in `torch>1.1.0`
- Fixed a bug where an `Exception` would be thrown if using `torch.DistributedDataParallel` without using a `DistributedSampler`; this now throws a `Warning` instead
## [0.4.3] - 2019-08-10
### Added
- None
### Changed
- None
### Deprecated
- None
### Removed
- None
### Fixed
- Fixed a bug where accumulate gradients would scale the loss incorrectly
## [0.4.2] - 2019-08-08
### Added
- None
### Changed
- Changed install requirement to `torch==1.2.0`
### Deprecated
- None
### Removed
- None
### Fixed
- None
## [0.4.1] - 2019-08-08
### Added
- None
### Changed
- Changed install requirement to `torch==1.1.0`
### Deprecated
- None
### Removed
- None
### Fixed
- None
## [0.4.0] - 2019-08-08
### Added
- Added 16-bit support for a single GPU
- Added support for training continuation (preserves epoch, global step etc.)
### Changed
- Changed `training_step` and `validation_step`, outputs will no longer be automatically reduced
### Deprecated
- None
### Removed
- Removed need for `Experiment` object in `Trainer`
### Fixed
- Fixed issues with reducing outputs from generative models (such as images and text)
## [0.3.6.1] - 2019-07-27
### Added
- None
### Changed
- None
### Deprecated
- None
### Removed
- None
### Fixed
- Fixed a bug where `Experiment` object was not process safe, potentially causing logs to be overwritten
## [0.3.6] - 2019-07-25
### Added
- Added a decorator to do lazy data loading internally
### Changed
- None
### Deprecated
- None
### Removed
- None
### Fixed
- None

pyproject.toml

@@ -3,3 +3,7 @@ requires = [
"setuptools",
"wheel",
]
[tool.autopep8]
max_line_length = 120
ignore = ["W504", "E402", "E731", "C40", "E741", "F40", "F841"]

pytorch_lightning/callbacks/__init__.py

@@ -1,6 +1,11 @@
from .pt_callbacks import EarlyStopping, ModelCheckpoint, GradientAccumulationScheduler
from .base import Callback
from .early_stopping import EarlyStopping
from .model_checkpoint import ModelCheckpoint
from .gradient_accumulation_scheduler import GradientAccumulationScheduler
__all__ = [
'Callback',
'EarlyStopping',
'ModelCheckpoint',
'GradientAccumulationScheduler',
]
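
After the split, user code can keep importing every callback from the package root. A minimal sanity-check sketch (not part of this commit) using the names from the `__all__` list above:

from pytorch_lightning.callbacks import (
    Callback,
    EarlyStopping,
    GradientAccumulationScheduler,
    ModelCheckpoint,
)

# Each concrete callback in this package derives from the shared base class.
assert all(issubclass(cls, Callback)
           for cls in (EarlyStopping, GradientAccumulationScheduler, ModelCheckpoint))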

pytorch_lightning/callbacks/base.py

@@ -0,0 +1,68 @@
"""
Callbacks
=========
Callbacks supported by Lightning
"""
import abc
_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"
class Callback(abc.ABC):
"""Abstract base class used to build new callbacks."""
def __init__(self):
self._trainer = None
@property
def trainer(self):
assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
return self._trainer
def set_trainer(self, trainer):
"""Make a link to the trainer, so different things like `trainer.current_epoch`,
`trainer.batch_idx`, `trainer.global_step` can be used."""
self._trainer = trainer
def on_epoch_begin(self):
"""Called when the epoch begins."""
pass
def on_epoch_end(self):
"""Called when the epoch ends."""
pass
def on_batch_begin(self):
"""Called when the training batch begins."""
pass
def on_batch_end(self):
"""Called when the training batch ends."""
pass
def on_train_begin(self):
"""Called when the train begins."""
pass
def on_train_end(self):
"""Called when the train ends."""
pass
def on_validation_begin(self):
"""Called when the validation loop begins."""
pass
def on_validation_end(self):
"""Called when the validation loop ends."""
pass
def on_test_begin(self):
"""Called when the test begins."""
pass
def on_test_end(self):
"""Called when the test ends."""
pass
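
A minimal sketch (not part of this commit) of how a user-defined subclass might use the new `trainer` property; the `EpochPrinter` name is illustrative, and the Trainer is assumed to call `set_trainer()` on the callback before training, as the docstring above describes:

from pytorch_lightning.callbacks import Callback

class EpochPrinter(Callback):
    """Toy callback that reads state from the linked trainer."""

    def on_epoch_end(self):
        # `self.trainer` asserts that `set_trainer()` was called first.
        print(f"finished epoch {self.trainer.current_epoch}"
              f" at global step {self.trainer.global_step}")

callback = EpochPrinter()
# callback.set_trainer(trainer)  # normally done by the Trainer itself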

pytorch_lightning/callbacks/early_stopping.py

@@ -0,0 +1,114 @@
import logging as log
import warnings
import numpy as np
from .base import Callback
class EarlyStopping(Callback):
r"""
Stop training when a monitored quantity has stopped improving.
Args:
monitor (str): quantity to be monitored. Default: ``'val_loss'``.
min_delta (float): minimum change in the monitored quantity
to qualify as an improvement, i.e. an absolute
change of less than `min_delta`, will count as no
improvement. Default: ``0``.
patience (int): number of epochs with no improvement
after which training will be stopped. Default: ``0``.
verbose (bool): verbosity mode. Default: ``False``.
mode (str): one of {auto, min, max}. In `min` mode,
training will stop when the quantity
monitored has stopped decreasing; in `max`
mode it will stop when the quantity
monitored has stopped increasing; in `auto`
mode, the direction is automatically inferred
from the name of the monitored quantity. Default: ``'auto'``.
strict (bool): whether to crash the training if `monitor` is
not found in the metrics. Default: ``True``.
Example::
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
early_stopping = EarlyStopping('val_loss')
Trainer(early_stop_callback=early_stopping)
"""
def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 0,
verbose: bool = False, mode: str = 'auto', strict: bool = True):
super().__init__()
self.monitor = monitor
self.patience = patience
self.verbose = verbose
self.strict = strict
self.min_delta = min_delta
self.wait = 0
self.stopped_epoch = 0
mode_dict = {
'min': np.less,
'max': np.greater,
'auto': np.greater if 'acc' in self.monitor else np.less
}
if mode not in mode_dict:
if self.verbose > 0:
log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
mode = 'auto'
self.monitor_op = mode_dict[mode]
self.min_delta *= 1 if self.monitor_op == np.greater else -1
self.on_train_begin()
def check_metrics(self, logs):
monitor_val = logs.get(self.monitor)
error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
f' which is not available. Available metrics are:'
f' `{"`, `".join(list(logs.keys()))}`')
if monitor_val is None:
if self.strict:
raise RuntimeError(error_msg)
if self.verbose > 0:
warnings.warn(error_msg, RuntimeWarning)
return False
return True
def on_train_begin(self):
# Allow instances to be re-used
self.wait = 0
self.stopped_epoch = 0
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
def on_epoch_end(self):
logs = self.trainer.callback_metrics
stop_training = False
if not self.check_metrics(logs):
return stop_training
current = logs.get(self.monitor)
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = self.trainer.current_epoch
stop_training = True
self.on_train_end()
return stop_training
def on_train_end(self):
if self.stopped_epoch > 0 and self.verbose > 0:
warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)
log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
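
A standalone numpy sketch (illustrative metric values, no Trainer involved) of the comparison `on_epoch_end` performs when monitoring `val_loss` in `min` mode; note that `min_delta` is stored with a flipped sign in that mode, as in `__init__` above:

import numpy as np

monitor_op = np.less    # 'min' mode
min_delta = 0.01 * -1   # sign flip applied in __init__ for 'min' mode
best, wait, patience = np.Inf, 0, 2

for epoch, val_loss in enumerate([0.90, 0.85, 0.849, 0.848, 0.848]):
    if monitor_op(val_loss - min_delta, best):  # i.e. val_loss + 0.01 < best
        best, wait = val_loss, 0
    else:
        wait += 1  # stop_training becomes True once wait >= patience
    print(epoch, val_loss, best, wait)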

pytorch_lightning/callbacks/gradient_accumulation_scheduler.py

@@ -0,0 +1,55 @@
import warnings
from .base import Callback
class GradientAccumulationScheduler(Callback):
r"""
Change gradient accumulation factor according to scheduling.
Args:
scheduling (dict): scheduling in format {epoch: accumulation_factor}
.. warning:: Epochs indexing starts from "1" until v0.6.x, but will start from "0" in
v0.8.0.
Example::
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler
# at epoch 5 start accumulating every 2 batches
accumulator = GradientAccumulationScheduler(scheduling={5: 2})
Trainer(accumulate_grad_batches=accumulator)
"""
def __init__(self, scheduling: dict):
super().__init__()
if not scheduling: # empty dict error
raise TypeError("Empty dict cannot be interpreted correct")
for key in scheduling:
if not isinstance(key, int) or not isinstance(scheduling[key], int):
raise TypeError("All epoches and accumulation factor must be integers")
minimal_epoch = min(scheduling.keys())
warnings.warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)
if minimal_epoch < 1:
msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
raise IndexError(msg)
if minimal_epoch != 1:  # if user didn't define first epoch accumulation factor
scheduling.update({1: 1})
self.scheduling = scheduling
self.epochs = sorted(scheduling.keys())
def on_epoch_begin(self):
trainer = self.trainer
# indexing epochs from 1 (until v0.6.x)
# In v0.8.0, ` + 1` should be removed.
epoch = trainer.current_epoch + 1
for i in reversed(range(len(self.epochs))):
if epoch >= self.epochs[i]:
trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
break
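
A small standalone sketch (scheduling values are illustrative) of how `on_epoch_begin` resolves the accumulation factor from the dict, using the same 1-based epoch convention noted in the deprecation warning above:

# accumulate 1 batch until epoch 5, 2 batches from epoch 5, 4 from epoch 10
scheduling = {1: 1, 5: 2, 10: 4}
epochs = sorted(scheduling.keys())

def accumulation_factor(epoch: int) -> int:
    # same reversed scan as GradientAccumulationScheduler.on_epoch_begin
    for e in reversed(epochs):
        if epoch >= e:
            return scheduling[e]
    return 1

assert [accumulation_factor(e) for e in (1, 4, 5, 9, 10, 20)] == [1, 1, 2, 2, 4, 4]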

pytorch_lightning/callbacks/model_checkpoint.py

@@ -0,0 +1,181 @@
import os
import shutil
import logging as log
import warnings
import numpy as np
from .base import Callback
class ModelCheckpoint(Callback):
r"""
Save the model after every epoch.
Args:
filepath: path to save the model file.
Can contain named formatting options to be auto-filled.
Example::
# save epoch and val_loss in name
ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
# saves file like: /path/epoch_2-val_loss_0.2.hdf5
monitor (str): quantity to monitor.
verbose (bool): verbosity mode, False or True.
save_top_k (int): if `save_top_k == k`,
the best k models according to
the quantity monitored will be saved.
if ``save_top_k == 0``, no models are saved.
if ``save_top_k == -1``, all models are saved.
Please note that the monitors are checked every `period` epochs.
if ``save_top_k >= 2`` and the callback is called multiple
times inside an epoch, the name of the saved file will be
appended with a version count starting with `v0`.
mode (str): one of {auto, min, max}.
If ``save_top_k != 0``, the decision
to overwrite the current save file is made
based on either the maximization or the
minimization of the monitored quantity. For `val_acc`,
this should be `max`, for `val_loss` this should
be `min`, etc. In `auto` mode, the direction is
automatically inferred from the name of the monitored quantity.
save_weights_only (bool): if True, then only the model's weights will be
saved (`model.save_weights(filepath)`), else the full model
is saved (`model.save(filepath)`).
period (int): Interval (number of epochs) between checkpoints.
Example::
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
# saves checkpoints to my_path whenever 'val_loss' has a new min
checkpoint_callback = ModelCheckpoint(filepath='my_path')
Trainer(checkpoint_callback=checkpoint_callback)
"""
def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
save_top_k: int = 1, save_weights_only: bool = False,
mode: str = 'auto', period: int = 1, prefix: str = ''):
super().__init__()
if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
warnings.warn(
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
)
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
os.makedirs(filepath, exist_ok=True)
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.epochs_since_last_check = 0
self.prefix = prefix
self.best_k_models = {}
# {filename: monitor}
self.kth_best_model = ''
self.best = 0
self.save_function = None
mode_dict = {
'min': (np.less, np.Inf, 'min'),
'max': (np.greater, -np.Inf, 'max'),
'auto': (np.greater, -np.Inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
else (np.less, np.Inf, 'min'),
}
if mode not in mode_dict:
warnings.warn(
f'ModelCheckpoint mode {mode} is unknown, '
'fallback to auto mode.', RuntimeWarning)
mode = 'auto'
self.monitor_op, self.kth_value, self.mode = mode_dict[mode]
def _del_model(self, filepath):
try:
shutil.rmtree(filepath)
except OSError:
os.remove(filepath)
def _save_model(self, filepath):
# make paths
os.makedirs(os.path.dirname(filepath), exist_ok=True)
# delegate the saving to the model
if self.save_function is not None:
self.save_function(filepath)
else:
raise ValueError(".save_function() not set")
def check_monitor_top_k(self, current):
less_than_k_models = len(self.best_k_models) < self.save_top_k
if less_than_k_models:
return True
return self.monitor_op(current, self.best_k_models[self.kth_best_model])
def on_validation_end(self):
logs = self.trainer.callback_metrics
epoch = self.trainer.current_epoch
self.epochs_since_last_check += 1
if self.save_top_k == 0:
# no models are saved
return
if self.epochs_since_last_check >= self.period:
self.epochs_since_last_check = 0
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
version_cnt = 0
while os.path.isfile(filepath):
# this epoch called before
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
version_cnt += 1
if self.save_top_k != -1:
current = logs.get(self.monitor)
if current is None:
warnings.warn(
f'Can save best model only with {self.monitor} available,'
' skipping.', RuntimeWarning)
else:
if self.check_monitor_top_k(current):
self._do_check_save(filepath, current, epoch)
else:
if self.verbose > 0:
log.info(
f'\nEpoch {epoch:05d}: {self.monitor}'
f' was not in top {self.save_top_k}')
else:
if self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
self._save_model(filepath)
def _do_check_save(self, filepath, current, epoch):
# remove kth
if len(self.best_k_models) == self.save_top_k:
delpath = self.kth_best_model
self.best_k_models.pop(self.kth_best_model)
self._del_model(delpath)
self.best_k_models[filepath] = current
if len(self.best_k_models) == self.save_top_k:
# monitor dict has reached k elements
_op = max if self.mode == 'min' else min
self.kth_best_model = _op(self.best_k_models,
key=self.best_k_models.get)
self.kth_value = self.best_k_models[self.kth_best_model]
_op = min if self.mode == 'min' else max
self.best = _op(self.best_k_models.values())
if self.verbose > 0:
log.info(
f'\nEpoch {epoch:05d}: {self.monitor} reached'
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
f' {filepath} as top {self.save_top_k}')
self._save_model(filepath)
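
A short configuration sketch (directory, prefix and metric name are placeholders, not from this commit) showing the top-k saving behaviour described in the docstring above:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep only the 3 checkpoints with the lowest val_loss, checking every epoch.
checkpoint_callback = ModelCheckpoint(
    filepath='my/checkpoint/dir',  # placeholder directory
    monitor='val_loss',
    save_top_k=3,
    mode='min',
    period=1,
    prefix='run0',
)

trainer = Trainer(checkpoint_callback=checkpoint_callback)
# trainer.fit(model)  # `model` is a LightningModule defined elsewhere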

pytorch_lightning/callbacks/pt_callbacks.py

@@ -1,431 +0,0 @@
"""
Callbacks
=========
Callbacks supported by Lightning
"""
import os
import shutil
import logging as log
import warnings
import numpy as np
class Callback(object):
"""Abstract base class used to build new callbacks."""
def __init__(self):
self._trainer = None
def set_trainer(self, trainer):
"""Make a link to the trainer, so different things like `trainer.current_epoch`,
`trainer.batch_idx`, `trainer.global_step` can be used."""
self._trainer = trainer
def on_epoch_begin(self):
"""Called when the epoch begins."""
pass
def on_epoch_end(self):
"""Called when the epoch ends."""
pass
def on_batch_begin(self):
"""Called when the training batch begins."""
pass
def on_batch_end(self):
"""Called when the training batch ends."""
pass
def on_train_begin(self):
"""Called when the train begins."""
pass
def on_train_end(self):
"""Called when the train ends."""
pass
def on_validation_begin(self):
"""Called when the validation loop begins."""
pass
def on_validation_end(self):
"""Called when the validation loop ends."""
pass
def on_test_begin(self):
"""Called when the test begins."""
pass
def on_test_end(self):
"""Called when the test ends."""
pass
_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"
class EarlyStopping(Callback):
r"""
Stop training when a monitored quantity has stopped improving.
Args:
monitor (str): quantity to be monitored. Default: ``'val_loss'``.
min_delta (float): minimum change in the monitored quantity
to qualify as an improvement, i.e. an absolute
change of less than `min_delta`, will count as no
improvement. Default: ``0``.
patience (int): number of epochs with no improvement
after which training will be stopped. Default: ``0``.
verbose (bool): verbosity mode. Default: ``0``.
mode (str): one of {auto, min, max}. In `min` mode,
training will stop when the quantity
monitored has stopped decreasing; in `max`
mode it will stop when the quantity
monitored has stopped increasing; in `auto`
mode, the direction is automatically inferred
from the name of the monitored quantity. Default: ``'auto'``.
strict (bool): whether to crash the training if `monitor` is
not found in the metrics. Default: ``True``.
Example::
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
early_stopping = EarlyStopping('val_loss')
Trainer(early_stop_callback=early_stopping)
"""
def __init__(self, monitor='val_loss',
min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
super(EarlyStopping, self).__init__()
self.monitor = monitor
self.patience = patience
self.verbose = verbose
self.strict = strict
self.min_delta = min_delta
self.wait = 0
self.stopped_epoch = 0
if mode not in ['auto', 'min', 'max']:
if self.verbose > 0:
log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
mode = 'auto'
if mode == 'min':
self.monitor_op = np.less
elif mode == 'max':
self.monitor_op = np.greater
else:
if 'acc' in self.monitor:
self.monitor_op = np.greater
else:
self.monitor_op = np.less
if self.monitor_op == np.greater:
self.min_delta *= 1
else:
self.min_delta *= -1
self.on_train_begin()
def check_metrics(self, logs):
monitor_val = logs.get(self.monitor)
error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
f' which is not available. Available metrics are:'
f' `{"`, `".join(list(logs.keys()))}`')
if monitor_val is None:
if self.strict:
raise RuntimeError(error_msg)
if self.verbose > 0:
warnings.warn(error_msg, RuntimeWarning)
return False
return True
def on_train_begin(self):
# Allow instances to be re-used
self.wait = 0
self.stopped_epoch = 0
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
def on_epoch_end(self):
assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
logs = self._trainer.callback_metrics
stop_training = False
if not self.check_metrics(logs):
return stop_training
current = logs.get(self.monitor)
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = self._trainer.current_epoch
stop_training = True
self.on_train_end()
return stop_training
def on_train_end(self):
if self.stopped_epoch > 0 and self.verbose > 0:
warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)
log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
class ModelCheckpoint(Callback):
r"""
Save the model after every epoch.
Args:
filepath (str): path to save the model file.
Can contain named formatting options to be auto-filled.
Example::
# save epoch and val_loss in name
ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
# saves file like: /path/epoch_2-val_loss_0.2.hdf5
monitor (str): quantity to monitor.
verbose (bool): verbosity mode, 0 or 1.
save_top_k (int): if `save_top_k == k`,
the best k models according to
the quantity monitored will be saved.
if `save_top_k == 0`, no models are saved.
if `save_top_k == -1`, all models are saved.
Please note that the monitors are checked every `period` epochs.
if `save_top_k >= 2` and the callback is called multiple
times inside an epoch, the name of the saved file will be
appended with a version count starting with `v0`.
mode (str): one of {auto, min, max}.
If `save_top_k != 0`, the decision
to overwrite the current save file is made
based on either the maximization or the
minimization of the monitored quantity. For `val_acc`,
this should be `max`, for `val_loss` this should
be `min`, etc. In `auto` mode, the direction is
automatically inferred from the name of the monitored quantity.
save_weights_only (bool): if True, then only the model's weights will be
saved (`model.save_weights(filepath)`), else the full model
is saved (`model.save(filepath)`).
period (int): Interval (number of epochs) between checkpoints.
Example::
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(filepath='my_path')
Trainer(checkpoint_callback=checkpoint_callback)
# saves checkpoints to my_path whenever 'val_loss' has a new min
"""
def __init__(self, filepath, monitor='val_loss', verbose=0,
save_top_k=1, save_weights_only=False,
mode='auto', period=1, prefix=''):
super(ModelCheckpoint, self).__init__()
if (
save_top_k and
os.path.isdir(filepath) and
len(os.listdir(filepath)) > 0
):
warnings.warn(
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
)
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
os.makedirs(filepath, exist_ok=True)
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.epochs_since_last_check = 0
self.prefix = prefix
self.best_k_models = {}
# {filename: monitor}
self.kth_best_model = ''
self.best = 0
if mode not in ['auto', 'min', 'max']:
warnings.warn(
f'ModelCheckpoint mode {mode} is unknown, '
'fallback to auto mode.', RuntimeWarning)
mode = 'auto'
if mode == 'min':
self.monitor_op = np.less
self.kth_value = np.Inf
self.mode = 'min'
elif mode == 'max':
self.monitor_op = np.greater
self.kth_value = -np.Inf
self.mode = 'max'
else:
if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
self.monitor_op = np.greater
self.kth_value = -np.Inf
self.mode = 'max'
else:
self.monitor_op = np.less
self.kth_value = np.Inf
self.mode = 'min'
def _del_model(self, filepath):
dirpath = os.path.dirname(filepath)
# make paths
os.makedirs(dirpath, exist_ok=True)
try:
shutil.rmtree(filepath)
except OSError:
os.remove(filepath)
def _save_model(self, filepath):
dirpath = os.path.dirname(filepath)
# make paths
os.makedirs(dirpath, exist_ok=True)
# delegate the saving to the model
self.save_function(filepath)
def check_monitor_top_k(self, current):
less_than_k_models = len(self.best_k_models.keys()) < self.save_top_k
if less_than_k_models:
return True
return self.monitor_op(current, self.best_k_models[self.kth_best_model])
def on_validation_end(self):
assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
logs = self._trainer.callback_metrics
epoch = self._trainer.current_epoch
self.epochs_since_last_check += 1
if self.save_top_k == 0:
# no models are saved
return
if self.epochs_since_last_check >= self.period:
self.epochs_since_last_check = 0
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
version_cnt = 0
while os.path.isfile(filepath):
# this epoch called before
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
version_cnt += 1
if self.save_top_k != -1:
current = logs.get(self.monitor)
if current is None:
warnings.warn(
f'Can save best model only with {self.monitor} available,'
' skipping.', RuntimeWarning)
else:
if self.check_monitor_top_k(current):
# remove kth
if len(self.best_k_models.keys()) == self.save_top_k:
delpath = self.kth_best_model
self.best_k_models.pop(self.kth_best_model)
self._del_model(delpath)
self.best_k_models[filepath] = current
if len(self.best_k_models.keys()) == self.save_top_k:
# monitor dict has reached k elements
if self.mode == 'min':
self.kth_best_model = max(self.best_k_models, key=self.best_k_models.get)
else:
self.kth_best_model = min(self.best_k_models, key=self.best_k_models.get)
self.kth_value = self.best_k_models[self.kth_best_model]
if self.mode == 'min':
self.best = min(self.best_k_models.values())
else:
self.best = max(self.best_k_models.values())
if self.verbose > 0:
log.info(
f'\nEpoch {epoch:05d}: {self.monitor} reached'
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
f' {filepath} as top {self.save_top_k}')
self._save_model(filepath)
else:
if self.verbose > 0:
log.info(
f'\nEpoch {epoch:05d}: {self.monitor}'
f' was not in top {self.save_top_k}')
else:
if self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
self._save_model(filepath)
class GradientAccumulationScheduler(Callback):
r"""
Change gradient accumulation factor according to scheduling.
Args:
scheduling (dict): scheduling in format {epoch: accumulation_factor}
.. warning:: Epochs indexing starts from "1" until v0.6.x, but will start from "0" in v0.8.0.
Example::
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler
# at epoch 5 start accumulating every 2 batches
accumulator = GradientAccumulationScheduler(scheduling: {5: 2})
Trainer(accumulate_grad_batches=accumulator)
"""
def __init__(self, scheduling: dict):
super().__init__()
if scheduling == {}: # empty dict error
raise TypeError("Empty dict cannot be interpreted correct")
for key in scheduling.keys():
if not isinstance(key, int) or not isinstance(scheduling[key], int):
raise TypeError("All epoches and accumulation factor must be integers")
minimal_epoch = min(scheduling.keys())
warnings.warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)
if minimal_epoch < 1:
msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
raise IndexError(msg)
if minimal_epoch != 1: # if user didnt define first epoch accumulation factor
scheduling.update({1: 1})
self.scheduling = scheduling
self.epochs = sorted(scheduling.keys())
def on_epoch_begin(self):
assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
trainer = self._trainer
# indexing epochs from 1 (until v0.6.x)
# In v0.8.0, ` + 1` should be removed.
epoch = trainer.current_epoch + 1
for i in reversed(range(len(self.epochs))):
if epoch >= self.epochs[i]:
trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
break