Split callbacks (#849)
* add .vscode to .gitignore
* Split callbacks into individual files + add a property to Callback for easy trainer instance access
* formatting
* Add a conda env file for quick and easy env setup to develop on PL
* Address comments
* add fix to kth_best_model
* add some typing to callbacks
* fix typo
* add autopep8 config to pyproject.toml
* format again
* format
* fix toml
* fix toml again
* consistent max line length in all config files
* remove conda env file
* Update pytorch_lightning/callbacks/early_stopping.py
  Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>
* Update pytorch_lightning/callbacks/model_checkpoint.py
  Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>
* docstring
* Update pytorch_lightning/callbacks/model_checkpoint.py
  Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>
* Update pytorch_lightning/callbacks/model_checkpoint.py
  Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>
* fix logic error
* format
* simplify if/else
* format
* fix linting issue in changelog
* edit changelog about new callback mechanism
* fix remaining formatting issue in CHANGELOG
* remove lambda function because it's not compatible with pickle (used during ddp)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent da2f11a9c4
commit 89d5772f55
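The last bullet above ("remove lambda function because it's not compatible with pickle") points at a real constraint: lambdas cannot be pickled, and callback state has to survive pickling when the Trainer spawns DDP worker processes. A minimal standalone illustration of that constraint (not code from this PR); it is also why the callbacks below keep the picklable NumPy ufuncs `np.less`/`np.greater` as their comparison operators:

    import pickle

    import numpy as np

    # Named, importable callables (regular functions, NumPy ufuncs) pickle fine,
    # so objects holding them can be shipped to DDP worker processes.
    pickle.dumps(np.less)

    # A lambda stored on an object cannot be pickled by reference, so anything
    # holding it breaks as soon as it has to be serialized.
    monitor_op = lambda a, b: a < b
    try:
        pickle.dumps(monitor_op)
    except (pickle.PicklingError, AttributeError) as err:
        print(f"not picklable: {err}")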
@@ -5,7 +5,7 @@ model_weights/
app/models/
pip-wheel-metadata/
lightning_logs/

.vscode/

# Test-tube
test_tube_logs/

@@ -0,0 +1,2 @@
MD013: false  # line length
MD024: false  # headers with the same names

@@ -5,7 +5,7 @@ scanner:
    linter: pycodestyle  # Other option is flake8

pycodestyle:  # Same as scanner.linter value. Other option is flake8
-    max-line-length: 100  # Default is 79 in PEP 8
+    max-line-length: 120  # Default is 79 in PEP 8
    ignore:  # Errors and warnings to ignore
        - W504  # line break after binary operator
        - E402  # module level import not at top of file

186  CHANGELOG.md
@@ -1,10 +1,13 @@
|
|||
# Changelog
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- Updated governance docs
|
||||
- Added a check to ensure that the metric used for early stopping exists before training commences ([#542](https://github.com/PyTorchLightning/pytorch-lightning/pull/542))
|
||||
- Added `optimizer_idx` argument to `backward` hook ([#733](https://github.com/PyTorchLightning/pytorch-lightning/pull/733))
|
||||
|
@@ -15,24 +18,35 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added `train_dataloader`, `val_dataloader` and `test_dataloader` arguments to `Trainer.fit()`, for alternative data parsing ([#759](https://github.com/PyTorchLightning/pytorch-lightning/pull/759))
|
||||
- Added Tensor Processing Unit (TPU) support ([#868](https://github.com/PyTorchLightning/pytorch-lightning/pull/868))
|
||||
- Added semantic segmentation example ([#751](https://github.com/PyTorchLightning/pytorch-lightning/pull/751),[#876](https://github.com/PyTorchLightning/pytorch-lightning/pull/876))
|
||||
- Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849))
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed default TQDM to use `tqdm.auto` for prettier outputs in IPython notebooks ([#752](https://github.com/PyTorchLightning/pytorch-lightning/pull/752))
|
||||
- Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767))
|
||||
- Moved the default `tqdm_dict` definition from Trainer to `LightningModule`, so it can be overridden by the user ([#749](https://github.com/PyTorchLightning/pytorch-lightning/pull/749))
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- Removed dependency on pandas ([#736](https://github.com/PyTorchLightning/pytorch-lightning/pull/736))
|
||||
- Removed dependency on torchvision ([#797](https://github.com/PyTorchLightning/pytorch-lightning/pull/797))
|
||||
- Removed dependency on scikit-learn ([#801](https://github.com/PyTorchLightning/pytorch-lightning/pull/801))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a bug where early stopping `on_end_epoch` would be called inconsistently when `check_val_every_n_epoch == 0` ([#743](https://github.com/PyTorchLightning/pytorch-lightning/pull/743))
|
||||
- Fixed a bug where the model checkpointer didn't write to the same directory as the logger ([#771](https://github.com/PyTorchLightning/pytorch-lightning/pull/771))
|
||||
- Fixed a bug where the `TensorBoardLogger` class would create an additional empty log file during fitting ([#777](https://github.com/PyTorchLightning/pytorch-lightning/pull/777))
|
||||
- Fixed a bug where `global_step` was advanced incorrectly when using `accumulate_grad_batches > 1` ([#832](https://github.com/PyTorchLightning/pytorch-lightning/pull/832))
|
||||
|
||||
## [0.6.0] - 2020-01-21
|
||||
|
||||
### Added
|
||||
|
||||
- Added support for resuming from a specific checkpoint via `resume_from_checkpoint` argument ([#516](https://github.com/PyTorchLightning/pytorch-lightning/pull/516))
|
||||
- Added support for `ReduceLROnPlateau` scheduler ([#320](https://github.com/PyTorchLightning/pytorch-lightning/pull/320))
|
||||
- Added support for Apex mode `O2` in conjunction with Data Parallel ([#493](https://github.com/PyTorchLightning/pytorch-lightning/pull/493))
|
||||
|
@@ -44,19 +58,27 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added option to disable validation by setting `val_percent_check=0` ([#649](https://github.com/PyTorchLightning/pytorch-lightning/pull/649))
|
||||
- Added `NeptuneLogger` class ([#648](https://github.com/PyTorchLightning/pytorch-lightning/pull/648))
|
||||
- Added `WandbLogger` class ([#627](https://github.com/PyTorchLightning/pytorch-lightning/pull/627))
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed the default progress bar to print to stdout instead of stderr ([#531](https://github.com/PyTorchLightning/pytorch-lightning/pull/531))
|
||||
- Renamed `step_idx` to `step`, `epoch_idx` to `epoch`, `max_num_epochs` to `max_epochs` and `min_num_epochs` to `min_epochs` ([#589](https://github.com/PyTorchLightning/pytorch-lightning/pull/589))
|
||||
- Renamed `total_batch_nb` to `total_batches`, `nb_val_batches` to `num_val_batches`, `nb_training_batches` to `num_training_batches`, `max_nb_epochs` to `max_epochs`, `min_nb_epochs` to `min_epochs`, and `nb_test_batches` to `num_test_batches` ([#567](https://github.com/PyTorchLightning/pytorch-lightning/pull/567))
|
||||
- Changed gradient logging to use parameter names instead of indexes ([#660](https://github.com/PyTorchLightning/pytorch-lightning/pull/660))
|
||||
- Changed the default logger to `TensorBoardLogger` ([#609](https://github.com/PyTorchLightning/pytorch-lightning/pull/609))
|
||||
- Changed the directory for tensorboard logging to be the same as model checkpointing ([#706](https://github.com/PyTorchLightning/pytorch-lightning/pull/706))
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Deprecated `max_nb_epochs` and `min_nb_epochs` ([#567](https://github.com/PyTorchLightning/pytorch-lightning/pull/567))
|
||||
- Deprecated the `on_sanity_check_start` hook in `ModelHooks` ([#598](https://github.com/PyTorchLightning/pytorch-lightning/pull/598))
|
||||
|
||||
### Removed
|
||||
|
||||
- Removed the `save_best_only` argument from `ModelCheckpoint`, use `save_top_k=1` instead ([#128](https://github.com/PyTorchLightning/pytorch-lightning/pull/128))
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a bug which occurred when using Adagrad with CUDA ([#554](https://github.com/PyTorchLightning/pytorch-lightning/pull/554))
|
||||
- Fixed a bug where training would be on the GPU despite setting `gpus=0` or `gpus=[]` ([#561](https://github.com/PyTorchLightning/pytorch-lightning/pull/561))
|
||||
- Fixed an error with `print_nan_gradients` when some parameters do not require gradient ([#579](https://github.com/PyTorchLightning/pytorch-lightning/pull/579))
|
||||
|
@@ -79,18 +101,22 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed a bug where `on_train_end` was not called when early stopping ([#723](https://github.com/PyTorchLightning/pytorch-lightning/pull/723))
|
||||
|
||||
## [0.5.3] - 2019-11-06
|
||||
|
||||
### Added
|
||||
|
||||
- Added option to disable default logger, checkpointer, and early stopping by passing `logger=False`, `checkpoint_callback=False` and `early_stop_callback=False` respectively
|
||||
- Added `CometLogger` for use with Comet.ml
|
||||
- Added `val_check_interval` argument to `Trainer` allowing validation to be performed at every given number of batches
|
||||
- Added functionality to save and load hyperparameters using the standard checkpoint mechanism
|
||||
- Added call to `torch.cuda.empty_cache` before training starts
|
||||
- Added option for user to override the call to `backward`
|
||||
- Added support for truncated backprop through time via the `truncated_bptt_steps` argument in `Trainer`
|
||||
- Added option to operate on all outputs from `training_step` in DDP2
|
||||
- Added a hook for modifying DDP init
|
||||
- Added a hook for modifying Apex
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed experiment version to be padded with zeros (e.g. `/dir/version_9` becomes `/dir/version_0009`)
|
||||
- Changed callback metrics to include any metrics given in logs or progress bar
|
||||
- Changed the default for `save_best_only` in `ModelCheckpoint` to `True`
|
||||
|
@@ -100,11 +126,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Changed weights restore to first attempt HPC weights before restoring normally, preventing both weights being restored and running out of memory
|
||||
- Changed progress bar functionality to add multiple progress bars for train/val/test
|
||||
- Changed calls to `print` to use `logging` instead
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Deprecated `tng_dataloader`
|
||||
|
||||
### Removed
|
||||
|
||||
- None
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue where the number of batches was off by one during training
|
||||
- Fixed a bug that occurred when setting a checkpoint callback and `early_stop_callback=False`
|
||||
- Fixed an error when importing CometLogger
|
||||
|
@@ -115,23 +147,35 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed a bug where checkpointing would sometimes erase the current directory
|
||||
|
||||
## [0.5.2] - 2019-10-10
|
||||
|
||||
### Added
|
||||
|
||||
- Added `weights_summary` argument to `Trainer` to be set to `full` (full summary), `top` (just top level modules) or other
|
||||
- Added `tags` argument to `MLFlowLogger`
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed default for `amp_level` to `O1`
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- Removed the `print_weights_summary` argument from `Trainer`
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a bug where logs were not written properly
|
||||
- Fixed a bug where `logger.finalize` wasn't called after training is complete
|
||||
- Fixed callback metric errors in DDP
|
||||
- Fixed a bug where `TestTubeLogger` didn't log to the correct directory
|
||||
|
||||
## [0.5.1] - 2019-10-05
|
||||
|
||||
### Added
|
||||
|
||||
- Added the `LightningLoggerBase` class for experiment loggers
|
||||
- Added `MLFlowLogger` for logging with `mlflow`
|
||||
- Added `TestTubeLogger` for logging with `test_tube`
|
||||
|
@@ -139,84 +183,132 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added support for optimisers which require a closure (e.g. LBFGS)
|
||||
- Added automatic `MASTER_PORT` default for DDP when not set manually
|
||||
- Added new GPU memory logging options `'min_max'` (log only the min/max utilization) and `'all'` (log all the GPU memory)
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed schedulers to always be called with the current epoch
|
||||
- Changed `test_tube` to an optional dependency
|
||||
- Changed data loaders to internally use a getter instead of a python property
|
||||
- Disabled auto GPU loading when restoring weights to prevent out of memory errors
|
||||
- Changed logging, early stopping and checkpointing to occur by default
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- None
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a bug with samplers that do not specify `set_epoch`
|
||||
- Fixed a bug when using the `MLFlowLogger` with unsupported data types, this will now raise a warning
|
||||
- Fixed a bug where gradient norms were always zero using `track_grad_norm`
|
||||
- Fixed a bug which causes a crash when logging memory
|
||||
|
||||
## [0.5.0] - 2019-09-26
|
||||
|
||||
### Added
|
||||
|
||||
- None
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed `data_batch` argument to `batch` throughout
|
||||
- Changed `batch_i` argument to `batch_idx` throughout
|
||||
- Changed `tng_dataloader` method to `train_dataloader`
|
||||
- Changed `on_tng_metrics` method to `on_training_metrics`
|
||||
- Changed `gradient_clip` argument to `gradient_clip_val`
|
||||
- Changed `add_log_row_interval` to `row_log_interval`
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- None
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a bug with tensorboard logging in multi-gpu setup
|
||||
|
||||
## [0.4.9] - 2019-09-16
|
||||
|
||||
### Added
|
||||
|
||||
- Added the flag `log_gpu_memory` to `Trainer` to deactivate logging of GPU
|
||||
memory utilization
|
||||
- Added SLURM resubmit functionality (port from test-tube)
|
||||
- Added optional weight_save_path to trainer to remove the need for a checkpoint_callback when using cluster training
|
||||
- Added option to use single gpu per node with `DistributedDataParallel`
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed functionality of `validation_end` and `test_end` with multiple dataloaders to be given all of the dataloaders at once rather than in separate calls
|
||||
- Changed print_nan_grads to only print the parameter value and gradients when they contain NaN
|
||||
- Changed gpu API to take integers as well (e.g. `gpus=2` instead of `gpus=[0, 1]`)
|
||||
- All models now loaded on to CPU to avoid device and out of memory issues in PyTorch
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- None
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a bug where data types that implement `.to` but not `.cuda` would not be properly moved onto the GPU
|
||||
- Fixed a bug where data would not be re-shuffled every epoch when using a `DistributedSampler`
|
||||
|
||||
## [0.4.8] - 2019-08-31
|
||||
|
||||
### Added
|
||||
|
||||
- Added `test_step` and `test_end` methods, used when `Trainer.test` is called
|
||||
- Added `GradientAccumulationScheduler` callback which can be used to schedule changes to the number of accumulation batches
|
||||
- Added option to skip the validation sanity check by setting `nb_sanity_val_steps = 0`
|
||||
|
||||
### Changed
|
||||
|
||||
- None
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- None
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a bug when setting `nb_sanity_val_steps = 0`
|
||||
|
||||
## [0.4.7] - 2019-08-24
|
||||
|
||||
### Added
|
||||
|
||||
- None
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed the default `val_check_interval` to `1.0`
|
||||
- Changed defaults for `nb_val_batches`, `nb_tng_batches` and `nb_test_batches` to 0
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- None
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a bug where the full validation set was used despite setting `val_percent_check`
|
||||
- Fixed a bug where an `Exception` was thrown when using a data set containing a single batch
|
||||
- Fixed a bug where an `Exception` was thrown if no `val_dataloader` was given
|
||||
|
@@ -226,114 +318,204 @@ memory utilization
|
|||
- Fixed a bug where `AttributeError` could be suppressed by the `Trainer`
|
||||
|
||||
## [0.4.6] - 2019-08-15
|
||||
|
||||
### Added
|
||||
|
||||
- Added support for data to be given as a `dict` or `list` with a single gpu
|
||||
- Added support for `configure_optimizers` to return a single optimizer, two lists (optimizers and schedulers), or a single list
|
||||
|
||||
### Changed
|
||||
|
||||
- None
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- None
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a bug where returning just an optimizer list (i.e. without schedulers) from `configure_optimizers` would throw an `Exception`
|
||||
|
||||
## [0.4.5] - 2019-08-13
|
||||
|
||||
### Added
|
||||
|
||||
- Added `optimizer_step` method that can be overridden to change the standard optimizer behaviour
|
||||
|
||||
### Changed
|
||||
|
||||
- None
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- None
|
||||
|
||||
### Fixed
|
||||
|
||||
- None
|
||||
|
||||
## [0.4.4] - 2019-08-12
|
||||
|
||||
### Added
|
||||
|
||||
- Added support for multiple validation dataloaders
|
||||
- Added support for latest test-tube logger (optimised for `torch==1.2.0`)
|
||||
|
||||
### Changed
|
||||
|
||||
- `validation_step` and `val_dataloader` are now optional
|
||||
- `lr_scheduler` is now activated after epoch
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- None
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a bug where a warning would show when using `lr_scheduler` in `torch>1.1.0`
|
||||
- Fixed a bug where an `Exception` would be thrown if using `torch.DistributedDataParallel` without using a `DistributedSampler`, this now throws a `Warning` instead
|
||||
|
||||
## [0.4.3] - 2019-08-10
|
||||
|
||||
### Added
|
||||
|
||||
- None
|
||||
|
||||
### Changed
|
||||
|
||||
- None
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- None
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a bug where accumulate gradients would scale the loss incorrectly
|
||||
|
||||
## [0.4.2] - 2019-08-08
|
||||
|
||||
### Added
|
||||
|
||||
- None
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed install requirement to `torch==1.2.0`
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- None
|
||||
|
||||
### Fixed
|
||||
|
||||
- None
|
||||
|
||||
## [0.4.1] - 2019-08-08
|
||||
|
||||
### Added
|
||||
|
||||
- None
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed install requirement to `torch==1.1.0`
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- None
|
||||
|
||||
### Fixed
|
||||
|
||||
- None
|
||||
|
||||
## [0.4.0] - 2019-08-08
|
||||
|
||||
### Added
|
||||
|
||||
- Added 16-bit support for a single GPU
|
||||
- Added support for training continuation (preserves epoch, global step etc.)
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed `training_step` and `validation_step`, outputs will no longer be automatically reduced
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- Removed need for `Experiment` object in `Trainer`
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed issues with reducing outputs from generative models (such as images and text)
|
||||
|
||||
## [0.3.6.1] - 2019-07-27
|
||||
|
||||
### Added
|
||||
|
||||
- None
|
||||
|
||||
### Changed
|
||||
|
||||
- None
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- None
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a bug where `Experiment` object was not process safe, potentially causing logs to be overwritten
|
||||
|
||||
## [0.3.6] - 2019-07-25
|
||||
|
||||
### Added
|
||||
|
||||
- Added a decorator to do lazy data loading internally
|
||||
|
||||
### Changed
|
||||
|
||||
- None
|
||||
|
||||
### Deprecated
|
||||
|
||||
- None
|
||||
|
||||
### Removed
|
||||
|
||||
- None
|
||||
|
||||
### Fixed
|
||||
|
||||
- None
|
||||
|
|
|
@@ -3,3 +3,7 @@ requires = [
    "setuptools",
    "wheel",
]

[tool.autopep8]
max_line_length = 120
ignore = ["W504", "E402", "E731", "C40", "E741", "F40", "F841"]

@@ -1,6 +1,11 @@
-from .pt_callbacks import EarlyStopping, ModelCheckpoint, GradientAccumulationScheduler
+from .base import Callback
+from .early_stopping import EarlyStopping
+from .model_checkpoint import ModelCheckpoint
+from .gradient_accumulation_scheduler import GradientAccumulationScheduler


__all__ = [
    'Callback',
    'EarlyStopping',
    'ModelCheckpoint',
    'GradientAccumulationScheduler',
]
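Because `__init__.py` re-exports every class, the split is invisible to user code: the public import path stays `pytorch_lightning.callbacks` even though the implementations now live in separate modules. A small usage sketch assembled from the `Example::` blocks in the new files below (argument values are illustrative):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

    early_stopping = EarlyStopping(monitor='val_loss', patience=3)
    checkpoint_callback = ModelCheckpoint(filepath='my_path', save_top_k=1)

    trainer = Trainer(early_stop_callback=early_stopping,
                      checkpoint_callback=checkpoint_callback)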
@@ -0,0 +1,68 @@
"""
Callbacks
=========

Callbacks supported by Lightning
"""

import abc


_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"


class Callback(abc.ABC):
    """Abstract base class used to build new callbacks."""

    def __init__(self):
        self._trainer = None

    @property
    def trainer(self):
        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
        return self._trainer

    def set_trainer(self, trainer):
        """Make a link to the trainer, so different things like `trainer.current_epoch`,
        `trainer.batch_idx`, `trainer.global_step` can be used."""
        self._trainer = trainer

    def on_epoch_begin(self):
        """Called when the epoch begins."""
        pass

    def on_epoch_end(self):
        """Called when the epoch ends."""
        pass

    def on_batch_begin(self):
        """Called when the training batch begins."""
        pass

    def on_batch_end(self):
        """Called when the training batch ends."""
        pass

    def on_train_begin(self):
        """Called when the train begins."""
        pass

    def on_train_end(self):
        """Called when the train ends."""
        pass

    def on_validation_begin(self):
        """Called when the validation loop begins."""
        pass

    def on_validation_end(self):
        """Called when the validation loop ends."""
        pass

    def on_test_begin(self):
        """Called when the test begins."""
        pass

    def on_test_end(self):
        """Called when the test ends."""
        pass
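The new `trainer` property is the "easy trainer instance access" mentioned in the PR description: the Trainer is expected to call `set_trainer(self)` before dispatching hooks, and the assert turns a forgotten link into a clear error instead of an `AttributeError` on `None`. A minimal sketch of a user-defined callback relying on it (hypothetical subclass, not part of this diff):

    from pytorch_lightning.callbacks import Callback


    class PrintingCallback(Callback):
        """Toy callback that reads state off the linked trainer."""

        def on_epoch_begin(self):
            # `self.trainer` asserts that set_trainer() has already been called.
            print(f"starting epoch {self.trainer.current_epoch}")

        def on_train_end(self):
            print(f"finished at global step {self.trainer.global_step}")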
@@ -0,0 +1,114 @@
import logging as log
import warnings

import numpy as np

from .base import Callback


class EarlyStopping(Callback):
    r"""
    Stop training when a monitored quantity has stopped improving.

    Args:
        monitor (str): quantity to be monitored. Default: ``'val_loss'``.
        min_delta (float): minimum change in the monitored quantity
            to qualify as an improvement, i.e. an absolute
            change of less than `min_delta`, will count as no
            improvement. Default: ``0``.
        patience (int): number of epochs with no improvement
            after which training will be stopped. Default: ``0``.
        verbose (bool): verbosity mode. Default: ``False``.
        mode (str): one of {auto, min, max}. In `min` mode,
            training will stop when the quantity
            monitored has stopped decreasing; in `max`
            mode it will stop when the quantity
            monitored has stopped increasing; in `auto`
            mode, the direction is automatically inferred
            from the name of the monitored quantity. Default: ``'auto'``.
        strict (bool): whether to crash the training if `monitor` is
            not found in the metrics. Default: ``True``.

    Example::

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import EarlyStopping

        early_stopping = EarlyStopping('val_loss')
        Trainer(early_stop_callback=early_stopping)
    """

    def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 0,
                 verbose: bool = False, mode: str = 'auto', strict: bool = True):
        super().__init__()

        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.strict = strict
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_epoch = 0

        mode_dict = {
            'min': np.less,
            'max': np.greater,
            'auto': np.greater if 'acc' in self.monitor else np.less
        }

        if mode not in mode_dict:
            if self.verbose > 0:
                log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
            mode = 'auto'

        self.monitor_op = mode_dict[mode]
        self.min_delta *= 1 if self.monitor_op == np.greater else -1

        self.on_train_begin()

    def check_metrics(self, logs):
        monitor_val = logs.get(self.monitor)
        error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
                     f' which is not available. Available metrics are:'
                     f' `{"`, `".join(list(logs.keys()))}`')

        if monitor_val is None:
            if self.strict:
                raise RuntimeError(error_msg)
            if self.verbose > 0:
                warnings.warn(error_msg, RuntimeWarning)

            return False

        return True

    def on_train_begin(self):
        # Allow instances to be re-used
        self.wait = 0
        self.stopped_epoch = 0
        self.best = np.Inf if self.monitor_op == np.less else -np.Inf

    def on_epoch_end(self):
        logs = self.trainer.callback_metrics
        stop_training = False
        if not self.check_metrics(logs):
            return stop_training

        current = logs.get(self.monitor)
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = self.trainer.current_epoch
                stop_training = True
                self.on_train_end()

        return stop_training

    def on_train_end(self):
        if self.stopped_epoch > 0 and self.verbose > 0:
            warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
                          ' but will start from "0" in v0.8.0.', DeprecationWarning)
            log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
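The whole improvement test above is the single comparison `monitor_op(current - min_delta, best)`, with `min_delta` given a sign in `__init__` so the same expression works for both directions. A worked sketch of the 'min' case outside the class (illustrative numbers, not from the PR):

    import numpy as np

    # 'min' mode (e.g. monitoring val_loss): monitor_op is np.less and
    # `self.min_delta *= -1` makes the delta negative, so
    #     monitor_op(current - min_delta, best)
    # only passes when current < best - |min_delta|, i.e. a real improvement.
    monitor_op = np.less
    min_delta = -0.01
    best = 0.50

    print(monitor_op(0.48 - min_delta, best))   # True: improvement, wait counter resets
    print(monitor_op(0.495 - min_delta, best))  # False: wait += 1, patience may run out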
@@ -0,0 +1,55 @@
import warnings

from .base import Callback


class GradientAccumulationScheduler(Callback):
    r"""
    Change gradient accumulation factor according to scheduling.

    Args:
        scheduling (dict): scheduling in format {epoch: accumulation_factor}
        .. warning:: Epochs indexing starts from "1" until v0.6.x, but will start from "0" in
            v0.8.0.

    Example::

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import GradientAccumulationScheduler

        # at epoch 5 start accumulating every 2 batches
        accumulator = GradientAccumulationScheduler(scheduling={5: 2})
        Trainer(accumulate_grad_batches=accumulator)
    """

    def __init__(self, scheduling: dict):
        super().__init__()

        if not scheduling:  # empty dict error
            raise TypeError("Empty dict cannot be interpreted correctly")

        for key in scheduling:
            if not isinstance(key, int) or not isinstance(scheduling[key], int):
                raise TypeError("All epochs and accumulation factors must be integers")

        minimal_epoch = min(scheduling.keys())
        warnings.warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,'
                      ' but will start from "0" in v0.8.0.', DeprecationWarning)
        if minimal_epoch < 1:
            msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correctly"
            raise IndexError(msg)
        if minimal_epoch != 1:  # if user didn't define first epoch accumulation factor
            scheduling.update({1: 1})

        self.scheduling = scheduling
        self.epochs = sorted(scheduling.keys())

    def on_epoch_begin(self):
        trainer = self.trainer
        # indexing epochs from 1 (until v0.6.x)
        # In v0.8.0, ` + 1` should be removed.
        epoch = trainer.current_epoch + 1
        for i in reversed(range(len(self.epochs))):
            if epoch >= self.epochs[i]:
                trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
                break
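`on_epoch_begin` resolves the accumulation factor by scanning the sorted epochs from the largest threshold down and applying the first one the current (1-indexed) epoch has reached. The same lookup in isolation (illustrative scheduling dict, mirroring the loop above):

    scheduling = {1: 1, 5: 2, 10: 8}   # epoch -> accumulation factor
    epochs = sorted(scheduling.keys())

    def factor_for(epoch):
        # largest threshold <= epoch wins, exactly like on_epoch_begin
        for e in reversed(epochs):
            if epoch >= e:
                return scheduling[e]

    print([factor_for(e) for e in range(1, 12)])
    # -> [1, 1, 1, 1, 2, 2, 2, 2, 2, 8, 8]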
@@ -0,0 +1,181 @@
import os
import shutil
import logging as log
import warnings

import numpy as np

from .base import Callback


class ModelCheckpoint(Callback):
    r"""
    Save the model after every epoch.

    Args:
        filepath: path to save the model file.
            Can contain named formatting options to be auto-filled.

            Example::

                # save epoch and val_loss in name
                ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
                # saves file like: /path/epoch_2-val_loss_0.2.hdf5
        monitor (str): quantity to monitor.
        verbose (bool): verbosity mode, False or True.
        save_top_k (int): if `save_top_k == k`,
            the best k models according to
            the quantity monitored will be saved.
            if ``save_top_k == 0``, no models are saved.
            if ``save_top_k == -1``, all models are saved.
            Please note that the monitors are checked every `period` epochs.
            if ``save_top_k >= 2`` and the callback is called multiple
            times inside an epoch, the name of the saved file will be
            appended with a version count starting with `v0`.
        mode (str): one of {auto, min, max}.
            If ``save_top_k != 0``, the decision
            to overwrite the current save file is made
            based on either the maximization or the
            minimization of the monitored quantity. For `val_acc`,
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
        save_weights_only (bool): if True, then only the model's weights will be
            saved (`model.save_weights(filepath)`), else the full model
            is saved (`model.save(filepath)`).
        period (int): Interval (number of epochs) between checkpoints.

    Example::

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import ModelCheckpoint

        # saves checkpoints to my_path whenever 'val_loss' has a new min
        checkpoint_callback = ModelCheckpoint(filepath='my_path')
        Trainer(checkpoint_callback=checkpoint_callback)
    """

    def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
                 save_top_k: int = 1, save_weights_only: bool = False,
                 mode: str = 'auto', period: int = 1, prefix: str = ''):
        super().__init__()
        if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
            warnings.warn(
                f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0. "
                "All files in this directory will be deleted when a checkpoint is saved!"
            )

        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        os.makedirs(filepath, exist_ok=True)
        self.save_top_k = save_top_k
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_check = 0
        self.prefix = prefix
        self.best_k_models = {}
        # {filename: monitor}
        self.kth_best_model = ''
        self.best = 0
        self.save_function = None

        mode_dict = {
            'min': (np.less, np.Inf, 'min'),
            'max': (np.greater, -np.Inf, 'max'),
            'auto': (np.greater, -np.Inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
            else (np.less, np.Inf, 'min'),
        }

        if mode not in mode_dict:
            warnings.warn(
                f'ModelCheckpoint mode {mode} is unknown, '
                'fallback to auto mode.', RuntimeWarning)
            mode = 'auto'

        self.monitor_op, self.kth_value, self.mode = mode_dict[mode]

    def _del_model(self, filepath):
        try:
            shutil.rmtree(filepath)
        except OSError:
            os.remove(filepath)

    def _save_model(self, filepath):
        # make paths
        os.makedirs(os.path.dirname(filepath), exist_ok=True)

        # delegate the saving to the model
        if self.save_function is not None:
            self.save_function(filepath)
        else:
            raise ValueError(".save_function() not set")

    def check_monitor_top_k(self, current):
        less_than_k_models = len(self.best_k_models) < self.save_top_k
        if less_than_k_models:
            return True
        return self.monitor_op(current, self.best_k_models[self.kth_best_model])

    def on_validation_end(self):
        logs = self.trainer.callback_metrics
        epoch = self.trainer.current_epoch
        self.epochs_since_last_check += 1

        if self.save_top_k == 0:
            # no models are saved
            return
        if self.epochs_since_last_check >= self.period:
            self.epochs_since_last_check = 0
            filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
            version_cnt = 0
            while os.path.isfile(filepath):
                # this epoch called before
                filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
                version_cnt += 1

            if self.save_top_k != -1:
                current = logs.get(self.monitor)

                if current is None:
                    warnings.warn(
                        f'Can save best model only with {self.monitor} available,'
                        ' skipping.', RuntimeWarning)
                else:
                    if self.check_monitor_top_k(current):
                        self._do_check_save(filepath, current, epoch)
                    else:
                        if self.verbose > 0:
                            log.info(
                                f'\nEpoch {epoch:05d}: {self.monitor}'
                                f' was not in top {self.save_top_k}')

            else:
                if self.verbose > 0:
                    log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
                self._save_model(filepath)

    def _do_check_save(self, filepath, current, epoch):
        # remove kth
        if len(self.best_k_models) == self.save_top_k:
            delpath = self.kth_best_model
            self.best_k_models.pop(self.kth_best_model)
            self._del_model(delpath)

        self.best_k_models[filepath] = current
        if len(self.best_k_models) == self.save_top_k:
            # monitor dict has reached k elements
            _op = max if self.mode == 'min' else min
            self.kth_best_model = _op(self.best_k_models,
                                      key=self.best_k_models.get)
            self.kth_value = self.best_k_models[self.kth_best_model]

        _op = min if self.mode == 'min' else max
        self.best = _op(self.best_k_models.values())

        if self.verbose > 0:
            log.info(
                f'\nEpoch {epoch:05d}: {self.monitor} reached'
                f' {current:0.5f} (best {self.best:0.5f}), saving model to'
                f' {filepath} as top {self.save_top_k}')
        self._save_model(filepath)
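`ModelCheckpoint` no longer writes files itself: `_save_model` delegates to `self.save_function`, which the Trainer assigns (the `ValueError` guards the unset case), and `_do_check_save` keeps the top-k bookkeeping. A hedged sketch that drives that bookkeeping directly with a stand-in save function; calling the private helpers like this is only for illustration, normally `on_validation_end` does it:

    from pytorch_lightning.callbacks import ModelCheckpoint

    ckpt = ModelCheckpoint(filepath='checkpoints', monitor='val_loss',
                           save_top_k=2, mode='min', prefix='demo')

    def fake_save(path):
        # stand-in for the Trainer-provided save routine
        print(f"would write checkpoint to {path}")

    ckpt.save_function = fake_save

    ckpt._do_check_save('checkpoints/epoch_1.ckpt', 0.42, 1)
    ckpt._do_check_save('checkpoints/epoch_2.ckpt', 0.37, 2)
    print(ckpt.best, ckpt.kth_best_model)   # 0.37 checkpoints/epoch_1.ckpt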
@@ -1,431 +0,0 @@
|
|||
"""
|
||||
Callbacks
|
||||
=========
|
||||
|
||||
Callbacks supported by Lightning
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import logging as log
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Callback(object):
|
||||
"""Abstract base class used to build new callbacks."""
|
||||
|
||||
def __init__(self):
|
||||
self._trainer = None
|
||||
|
||||
def set_trainer(self, trainer):
|
||||
"""Make a link to the trainer, so different things like `trainer.current_epoch`,
|
||||
`trainer.batch_idx`, `trainer.global_step` can be used."""
|
||||
self._trainer = trainer
|
||||
|
||||
def on_epoch_begin(self):
|
||||
"""Called when the epoch begins."""
|
||||
pass
|
||||
|
||||
def on_epoch_end(self):
|
||||
"""Called when the epoch ends."""
|
||||
pass
|
||||
|
||||
def on_batch_begin(self):
|
||||
"""Called when the training batch begins."""
|
||||
pass
|
||||
|
||||
def on_batch_end(self):
|
||||
"""Called when the training batch ends."""
|
||||
pass
|
||||
|
||||
def on_train_begin(self):
|
||||
"""Called when the train begins."""
|
||||
pass
|
||||
|
||||
def on_train_end(self):
|
||||
"""Called when the train ends."""
|
||||
pass
|
||||
|
||||
def on_validation_begin(self):
|
||||
"""Called when the validation loop begins."""
|
||||
pass
|
||||
|
||||
def on_validation_end(self):
|
||||
"""Called when the validation loop ends."""
|
||||
pass
|
||||
|
||||
def on_test_begin(self):
|
||||
"""Called when the test begins."""
|
||||
pass
|
||||
|
||||
def on_test_end(self):
|
||||
"""Called when the test ends."""
|
||||
pass
|
||||
|
||||
|
||||
_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"
|
||||
|
||||
|
||||
class EarlyStopping(Callback):
|
||||
r"""
|
||||
Stop training when a monitored quantity has stopped improving.
|
||||
|
||||
Args:
|
||||
monitor (str): quantity to be monitored. Default: ``'val_loss'``.
|
||||
min_delta (float): minimum change in the monitored quantity
|
||||
to qualify as an improvement, i.e. an absolute
|
||||
change of less than `min_delta`, will count as no
|
||||
improvement. Default: ``0``.
|
||||
patience (int): number of epochs with no improvement
|
||||
after which training will be stopped. Default: ``0``.
|
||||
verbose (bool): verbosity mode. Default: ``0``.
|
||||
mode (str): one of {auto, min, max}. In `min` mode,
|
||||
training will stop when the quantity
|
||||
monitored has stopped decreasing; in `max`
|
||||
mode it will stop when the quantity
|
||||
monitored has stopped increasing; in `auto`
|
||||
mode, the direction is automatically inferred
|
||||
from the name of the monitored quantity. Default: ``'auto'``.
|
||||
strict (bool): whether to crash the training if `monitor` is
|
||||
not found in the metrics. Default: ``True``.
|
||||
|
||||
Example::
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
|
||||
early_stopping = EarlyStopping('val_loss')
|
||||
Trainer(early_stop_callback=early_stopping)
|
||||
"""
|
||||
|
||||
def __init__(self, monitor='val_loss',
|
||||
min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
|
||||
super(EarlyStopping, self).__init__()
|
||||
|
||||
self.monitor = monitor
|
||||
self.patience = patience
|
||||
self.verbose = verbose
|
||||
self.strict = strict
|
||||
self.min_delta = min_delta
|
||||
self.wait = 0
|
||||
self.stopped_epoch = 0
|
||||
|
||||
if mode not in ['auto', 'min', 'max']:
|
||||
if self.verbose > 0:
|
||||
log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
|
||||
mode = 'auto'
|
||||
|
||||
if mode == 'min':
|
||||
self.monitor_op = np.less
|
||||
elif mode == 'max':
|
||||
self.monitor_op = np.greater
|
||||
else:
|
||||
if 'acc' in self.monitor:
|
||||
self.monitor_op = np.greater
|
||||
else:
|
||||
self.monitor_op = np.less
|
||||
|
||||
if self.monitor_op == np.greater:
|
||||
self.min_delta *= 1
|
||||
else:
|
||||
self.min_delta *= -1
|
||||
|
||||
self.on_train_begin()
|
||||
|
||||
def check_metrics(self, logs):
|
||||
monitor_val = logs.get(self.monitor)
|
||||
error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
|
||||
f' which is not available. Available metrics are:'
|
||||
f' `{"`, `".join(list(logs.keys()))}`')
|
||||
|
||||
if monitor_val is None:
|
||||
if self.strict:
|
||||
raise RuntimeError(error_msg)
|
||||
if self.verbose > 0:
|
||||
warnings.warn(error_msg, RuntimeWarning)
|
||||
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def on_train_begin(self):
|
||||
# Allow instances to be re-used
|
||||
self.wait = 0
|
||||
self.stopped_epoch = 0
|
||||
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
|
||||
|
||||
def on_epoch_end(self):
|
||||
assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
|
||||
|
||||
logs = self._trainer.callback_metrics
|
||||
stop_training = False
|
||||
if not self.check_metrics(logs):
|
||||
return stop_training
|
||||
|
||||
current = logs.get(self.monitor)
|
||||
if self.monitor_op(current - self.min_delta, self.best):
|
||||
self.best = current
|
||||
self.wait = 0
|
||||
else:
|
||||
self.wait += 1
|
||||
if self.wait >= self.patience:
|
||||
self.stopped_epoch = self._trainer.current_epoch
|
||||
stop_training = True
|
||||
self.on_train_end()
|
||||
|
||||
return stop_training
|
||||
|
||||
def on_train_end(self):
|
||||
if self.stopped_epoch > 0 and self.verbose > 0:
|
||||
warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
|
||||
' but will start from "0" in v0.8.0.', DeprecationWarning)
|
||||
log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
|
||||
|
||||
|
||||
class ModelCheckpoint(Callback):
|
||||
r"""
|
||||
|
||||
Save the model after every epoch.
|
||||
|
||||
Args:
|
||||
filepath (str): path to save the model file.
|
||||
Can contain named formatting options to be auto-filled.
|
||||
|
||||
Example::
|
||||
|
||||
# save epoch and val_loss in name
|
||||
ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
|
||||
# saves file like: /path/epoch_2-val_loss_0.2.hdf5
|
||||
monitor (str): quantity to monitor.
|
||||
verbose (bool): verbosity mode, 0 or 1.
|
||||
save_top_k (int): if `save_top_k == k`,
|
||||
the best k models according to
|
||||
the quantity monitored will be saved.
|
||||
if `save_top_k == 0`, no models are saved.
|
||||
if `save_top_k == -1`, all models are saved.
|
||||
Please note that the monitors are checked every `period` epochs.
|
||||
if `save_top_k >= 2` and the callback is called multiple
|
||||
times inside an epoch, the name of the saved file will be
|
||||
appended with a version count starting with `v0`.
|
||||
mode (str): one of {auto, min, max}.
|
||||
If `save_top_k != 0`, the decision
|
||||
to overwrite the current save file is made
|
||||
based on either the maximization or the
|
||||
minimization of the monitored quantity. For `val_acc`,
|
||||
this should be `max`, for `val_loss` this should
|
||||
be `min`, etc. In `auto` mode, the direction is
|
||||
automatically inferred from the name of the monitored quantity.
|
||||
save_weights_only (bool): if True, then only the model's weights will be
|
||||
saved (`model.save_weights(filepath)`), else the full model
|
||||
is saved (`model.save(filepath)`).
|
||||
period (int): Interval (number of epochs) between checkpoints.
|
||||
|
||||
Example::
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(filepath='my_path')
|
||||
Trainer(checkpoint_callback=checkpoint_callback)
|
||||
|
||||
# saves checkpoints to my_path whenever 'val_loss' has a new min
|
||||
"""
|
||||
|
||||
def __init__(self, filepath, monitor='val_loss', verbose=0,
|
||||
save_top_k=1, save_weights_only=False,
|
||||
mode='auto', period=1, prefix=''):
|
||||
super(ModelCheckpoint, self).__init__()
|
||||
if (
|
||||
save_top_k and
|
||||
os.path.isdir(filepath) and
|
||||
len(os.listdir(filepath)) > 0
|
||||
):
|
||||
warnings.warn(
|
||||
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
|
||||
"All files in this directory will be deleted when a checkpoint is saved!"
|
||||
)
|
||||
|
||||
self.monitor = monitor
|
||||
self.verbose = verbose
|
||||
self.filepath = filepath
|
||||
os.makedirs(filepath, exist_ok=True)
|
||||
self.save_top_k = save_top_k
|
||||
self.save_weights_only = save_weights_only
|
||||
self.period = period
|
||||
self.epochs_since_last_check = 0
|
||||
self.prefix = prefix
|
||||
self.best_k_models = {}
|
||||
# {filename: monitor}
|
||||
self.kth_best_model = ''
|
||||
self.best = 0
|
||||
|
||||
if mode not in ['auto', 'min', 'max']:
|
||||
warnings.warn(
|
||||
f'ModelCheckpoint mode {mode} is unknown, '
|
||||
'fallback to auto mode.', RuntimeWarning)
|
||||
mode = 'auto'
|
||||
|
||||
if mode == 'min':
|
||||
self.monitor_op = np.less
|
||||
self.kth_value = np.Inf
|
||||
self.mode = 'min'
|
||||
elif mode == 'max':
|
||||
self.monitor_op = np.greater
|
||||
self.kth_value = -np.Inf
|
||||
self.mode = 'max'
|
||||
else:
|
||||
if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
|
||||
self.monitor_op = np.greater
|
||||
self.kth_value = -np.Inf
|
||||
self.mode = 'max'
|
||||
else:
|
||||
self.monitor_op = np.less
|
||||
self.kth_value = np.Inf
|
||||
self.mode = 'min'
|
||||
|
||||
def _del_model(self, filepath):
|
||||
dirpath = os.path.dirname(filepath)
|
||||
|
||||
# make paths
|
||||
os.makedirs(dirpath, exist_ok=True)
|
||||
|
||||
try:
|
||||
shutil.rmtree(filepath)
|
||||
except OSError:
|
||||
os.remove(filepath)
|
||||
|
||||
def _save_model(self, filepath):
|
||||
dirpath = os.path.dirname(filepath)
|
||||
|
||||
# make paths
|
||||
os.makedirs(dirpath, exist_ok=True)
|
||||
|
||||
# delegate the saving to the model
|
||||
self.save_function(filepath)
|
||||
|
||||
def check_monitor_top_k(self, current):
|
||||
less_than_k_models = len(self.best_k_models.keys()) < self.save_top_k
|
||||
if less_than_k_models:
|
||||
return True
|
||||
return self.monitor_op(current, self.best_k_models[self.kth_best_model])
|
||||
|
||||
def on_validation_end(self):
|
||||
assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
|
||||
|
||||
logs = self._trainer.callback_metrics
|
||||
epoch = self._trainer.current_epoch
|
||||
self.epochs_since_last_check += 1
|
||||
|
||||
if self.save_top_k == 0:
|
||||
# no models are saved
|
||||
return
|
||||
if self.epochs_since_last_check >= self.period:
|
||||
self.epochs_since_last_check = 0
|
||||
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
|
||||
version_cnt = 0
|
||||
while os.path.isfile(filepath):
|
||||
# this epoch called before
|
||||
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
|
||||
version_cnt += 1
|
||||
|
||||
if self.save_top_k != -1:
|
||||
current = logs.get(self.monitor)
|
||||
|
||||
if current is None:
|
||||
warnings.warn(
|
||||
f'Can save best model only with {self.monitor} available,'
|
||||
' skipping.', RuntimeWarning)
|
||||
else:
|
||||
if self.check_monitor_top_k(current):
|
||||
|
||||
# remove kth
|
||||
if len(self.best_k_models.keys()) == self.save_top_k:
|
||||
delpath = self.kth_best_model
|
||||
self.best_k_models.pop(self.kth_best_model)
|
||||
self._del_model(delpath)
|
||||
|
||||
self.best_k_models[filepath] = current
|
||||
if len(self.best_k_models.keys()) == self.save_top_k:
|
||||
# monitor dict has reached k elements
|
||||
if self.mode == 'min':
|
||||
self.kth_best_model = max(self.best_k_models, key=self.best_k_models.get)
|
||||
else:
|
||||
self.kth_best_model = min(self.best_k_models, key=self.best_k_models.get)
|
||||
self.kth_value = self.best_k_models[self.kth_best_model]
|
||||
|
||||
if self.mode == 'min':
|
||||
self.best = min(self.best_k_models.values())
|
||||
else:
|
||||
self.best = max(self.best_k_models.values())
|
||||
if self.verbose > 0:
|
||||
log.info(
|
||||
f'\nEpoch {epoch:05d}: {self.monitor} reached'
|
||||
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
|
||||
f' {filepath} as top {self.save_top_k}')
|
||||
self._save_model(filepath)
|
||||
|
||||
else:
|
||||
if self.verbose > 0:
|
||||
log.info(
|
||||
f'\nEpoch {epoch:05d}: {self.monitor}'
|
||||
f' was not in top {self.save_top_k}')
|
||||
|
||||
else:
|
||||
if self.verbose > 0:
|
||||
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
|
||||
self._save_model(filepath)
|
||||
|
||||
|
||||
class GradientAccumulationScheduler(Callback):
|
||||
r"""
|
||||
Change gradient accumulation factor according to scheduling.
|
||||
|
||||
Args:
|
||||
scheduling (dict): scheduling in format {epoch: accumulation_factor}
|
||||
.. warning:: Epochs indexing starts from "1" until v0.6.x, but will start from "0" in v0.8.0.
|
||||
|
||||
Example::
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import GradientAccumulationScheduler
|
||||
|
||||
# at epoch 5 start accumulating every 2 batches
|
||||
accumulator = GradientAccumulationScheduler(scheduling={5: 2})
|
||||
Trainer(accumulate_grad_batches=accumulator)
|
||||
"""
|
||||
|
||||
def __init__(self, scheduling: dict):
|
||||
super().__init__()
|
||||
|
||||
if scheduling == {}: # empty dict error
|
||||
raise TypeError("Empty dict cannot be interpreted correct")
|
||||
|
||||
for key in scheduling.keys():
|
||||
if not isinstance(key, int) or not isinstance(scheduling[key], int):
|
||||
raise TypeError("All epoches and accumulation factor must be integers")
|
||||
|
||||
minimal_epoch = min(scheduling.keys())
|
||||
warnings.warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,'
|
||||
' but will start from "0" in v0.8.0.', DeprecationWarning)
|
||||
if minimal_epoch < 1:
|
||||
msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
|
||||
raise IndexError(msg)
|
||||
if minimal_epoch != 1:  # if user didn't define first epoch accumulation factor
|
||||
scheduling.update({1: 1})
|
||||
|
||||
self.scheduling = scheduling
|
||||
self.epochs = sorted(scheduling.keys())
|
||||
|
||||
def on_epoch_begin(self):
|
||||
assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
|
||||
|
||||
trainer = self._trainer
|
||||
# indexing epochs from 1 (until v0.6.x)
|
||||
# In v0.8.0, ` + 1` should be removed.
|
||||
epoch = trainer.current_epoch + 1
|
||||
for i in reversed(range(len(self.epochs))):
|
||||
if epoch >= self.epochs[i]:
|
||||
trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
|
||||
break
|