lightning/pytorch_lightning/trainer/callback_config.py

127 lines
4.9 KiB
Python
Raw Normal View History

import os
from abc import ABC, abstractmethod
Progress bar callback (#1450) * squash and rebase sanity check hooks sanity check callback hook finish moved core progress bar functionality into callback wip remove duplicate merge clean up imports docs sanity check progress bar main sanity move callback calls init progrss bar callback configuration and docs changelog rate decorator pass process_position disable on rank > 0 position index is_enabled remove decorator refactor init tqdm bars callback method ordering cannot reset when disabled sequence -> list default values fix has no attr _time() move on_val_end to proper place fix the pickle issue update warning properties check for None remove old comment switch order pull out non-tqdm functionality into base class documentation for the base class docs fix refresh rate issue in validation restrict type hint of trainer arg more docs update trainer docs rst docs fix lines too long fix test add missing type hints fix typo move docstring to __init__ solves doctest failures remove doctest :(( can't fix the pickle error fix example simplify by saving trainer reference fix docs errors move docstring initial value multiple val checks per epoch simpler handling of inf dataset sizes update inf docs renamed training_tqdm_dict rename get_tqdm_dict rename occurences of tqdm update changelog fix doctest fix formatting errors added callback tests progress bar on off test more tests for progress bar weird test fix? add ignored property disable default progress bar in LR finder change enable/disable behavior trying doctest in CI again undo doctest pickle error undo doctest pickle error :(( remove progress_bar_callback Trainer arg and fix tests restore progress bar after auto lr find update docs fix rebase fix wrong negation * fix fast dev run total * more thorough testing * remove old args * fix merge * fix merge * separate tests * type hint total batches * reduce if Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * is_disabled Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * is_enabled Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * rename enabled/disabled * move deprecated api * remove duplicated test from merge * fix rename is_disabled * newline * test also testprogress for fast dev run Co-authored-by: J. Borovec <jirka.borovec@seznam.cz> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-04-24 00:46:18 +00:00
from typing import Union, List
Progress bar callback (#1450) * squash and rebase sanity check hooks sanity check callback hook finish moved core progress bar functionality into callback wip remove duplicate merge clean up imports docs sanity check progress bar main sanity move callback calls init progrss bar callback configuration and docs changelog rate decorator pass process_position disable on rank > 0 position index is_enabled remove decorator refactor init tqdm bars callback method ordering cannot reset when disabled sequence -> list default values fix has no attr _time() move on_val_end to proper place fix the pickle issue update warning properties check for None remove old comment switch order pull out non-tqdm functionality into base class documentation for the base class docs fix refresh rate issue in validation restrict type hint of trainer arg more docs update trainer docs rst docs fix lines too long fix test add missing type hints fix typo move docstring to __init__ solves doctest failures remove doctest :(( can't fix the pickle error fix example simplify by saving trainer reference fix docs errors move docstring initial value multiple val checks per epoch simpler handling of inf dataset sizes update inf docs renamed training_tqdm_dict rename get_tqdm_dict rename occurences of tqdm update changelog fix doctest fix formatting errors added callback tests progress bar on off test more tests for progress bar weird test fix? add ignored property disable default progress bar in LR finder change enable/disable behavior trying doctest in CI again undo doctest pickle error undo doctest pickle error :(( remove progress_bar_callback Trainer arg and fix tests restore progress bar after auto lr find update docs fix rebase fix wrong negation * fix fast dev run total * more thorough testing * remove old args * fix merge * fix merge * separate tests * type hint total batches * reduce if Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * is_disabled Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * is_enabled Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * rename enabled/disabled * move deprecated api * remove duplicated test from merge * fix rename is_disabled * newline * test also testprogress for fast dev run Co-authored-by: J. Borovec <jirka.borovec@seznam.cz> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-04-24 00:46:18 +00:00
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
from pytorch_lightning.loggers import LightningLoggerBase
Progress bar callback (#1450) * squash and rebase sanity check hooks sanity check callback hook finish moved core progress bar functionality into callback wip remove duplicate merge clean up imports docs sanity check progress bar main sanity move callback calls init progrss bar callback configuration and docs changelog rate decorator pass process_position disable on rank > 0 position index is_enabled remove decorator refactor init tqdm bars callback method ordering cannot reset when disabled sequence -> list default values fix has no attr _time() move on_val_end to proper place fix the pickle issue update warning properties check for None remove old comment switch order pull out non-tqdm functionality into base class documentation for the base class docs fix refresh rate issue in validation restrict type hint of trainer arg more docs update trainer docs rst docs fix lines too long fix test add missing type hints fix typo move docstring to __init__ solves doctest failures remove doctest :(( can't fix the pickle error fix example simplify by saving trainer reference fix docs errors move docstring initial value multiple val checks per epoch simpler handling of inf dataset sizes update inf docs renamed training_tqdm_dict rename get_tqdm_dict rename occurences of tqdm update changelog fix doctest fix formatting errors added callback tests progress bar on off test more tests for progress bar weird test fix? add ignored property disable default progress bar in LR finder change enable/disable behavior trying doctest in CI again undo doctest pickle error undo doctest pickle error :(( remove progress_bar_callback Trainer arg and fix tests restore progress bar after auto lr find update docs fix rebase fix wrong negation * fix fast dev run total * more thorough testing * remove old args * fix merge * fix merge * separate tests * type hint total batches * reduce if Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * is_disabled Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * is_enabled Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * rename enabled/disabled * move deprecated api * remove duplicated test from merge * fix rename is_disabled * newline * test also testprogress for fast dev run Co-authored-by: J. Borovec <jirka.borovec@seznam.cz> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-04-24 00:46:18 +00:00
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class TrainerCallbackConfigMixin(ABC):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
Progress bar callback (#1450) * squash and rebase sanity check hooks sanity check callback hook finish moved core progress bar functionality into callback wip remove duplicate merge clean up imports docs sanity check progress bar main sanity move callback calls init progrss bar callback configuration and docs changelog rate decorator pass process_position disable on rank > 0 position index is_enabled remove decorator refactor init tqdm bars callback method ordering cannot reset when disabled sequence -> list default values fix has no attr _time() move on_val_end to proper place fix the pickle issue update warning properties check for None remove old comment switch order pull out non-tqdm functionality into base class documentation for the base class docs fix refresh rate issue in validation restrict type hint of trainer arg more docs update trainer docs rst docs fix lines too long fix test add missing type hints fix typo move docstring to __init__ solves doctest failures remove doctest :(( can't fix the pickle error fix example simplify by saving trainer reference fix docs errors move docstring initial value multiple val checks per epoch simpler handling of inf dataset sizes update inf docs renamed training_tqdm_dict rename get_tqdm_dict rename occurences of tqdm update changelog fix doctest fix formatting errors added callback tests progress bar on off test more tests for progress bar weird test fix? add ignored property disable default progress bar in LR finder change enable/disable behavior trying doctest in CI again undo doctest pickle error undo doctest pickle error :(( remove progress_bar_callback Trainer arg and fix tests restore progress bar after auto lr find update docs fix rebase fix wrong negation * fix fast dev run total * more thorough testing * remove old args * fix merge * fix merge * separate tests * type hint total batches * reduce if Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * is_disabled Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * is_enabled Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * rename enabled/disabled * move deprecated api * remove duplicated test from merge * fix rename is_disabled * newline * test also testprogress for fast dev run Co-authored-by: J. Borovec <jirka.borovec@seznam.cz> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-04-24 00:46:18 +00:00
callbacks: List[Callback]
default_root_dir: str
logger: Union[LightningLoggerBase, bool]
weights_save_path: str
ckpt_path: str
checkpoint_callback: ModelCheckpoint
Progress bar callback (#1450) * squash and rebase sanity check hooks sanity check callback hook finish moved core progress bar functionality into callback wip remove duplicate merge clean up imports docs sanity check progress bar main sanity move callback calls init progrss bar callback configuration and docs changelog rate decorator pass process_position disable on rank > 0 position index is_enabled remove decorator refactor init tqdm bars callback method ordering cannot reset when disabled sequence -> list default values fix has no attr _time() move on_val_end to proper place fix the pickle issue update warning properties check for None remove old comment switch order pull out non-tqdm functionality into base class documentation for the base class docs fix refresh rate issue in validation restrict type hint of trainer arg more docs update trainer docs rst docs fix lines too long fix test add missing type hints fix typo move docstring to __init__ solves doctest failures remove doctest :(( can't fix the pickle error fix example simplify by saving trainer reference fix docs errors move docstring initial value multiple val checks per epoch simpler handling of inf dataset sizes update inf docs renamed training_tqdm_dict rename get_tqdm_dict rename occurences of tqdm update changelog fix doctest fix formatting errors added callback tests progress bar on off test more tests for progress bar weird test fix? add ignored property disable default progress bar in LR finder change enable/disable behavior trying doctest in CI again undo doctest pickle error undo doctest pickle error :(( remove progress_bar_callback Trainer arg and fix tests restore progress bar after auto lr find update docs fix rebase fix wrong negation * fix fast dev run total * more thorough testing * remove old args * fix merge * fix merge * separate tests * type hint total batches * reduce if Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * is_disabled Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * is_enabled Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * rename enabled/disabled * move deprecated api * remove duplicated test from merge * fix rename is_disabled * newline * test also testprogress for fast dev run Co-authored-by: J. Borovec <jirka.borovec@seznam.cz> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-04-24 00:46:18 +00:00
progress_bar_refresh_rate: int
process_position: int
@property
@abstractmethod
def slurm_job_id(self) -> int:
"""Warning: this is just empty shell for code implemented in other class."""
@abstractmethod
def save_checkpoint(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""
def configure_checkpoint_callback(self):
"""
Weight path set in this priority:
Checkpoint_callback's path (if passed in).
User provided weights_saved_path
Otherwise use os.getcwd()
"""
ckpt_path = self.default_root_dir
if self.checkpoint_callback:
# init a default one
if self.logger is not None:
save_dir = (getattr(self.logger, 'save_dir', None) or
getattr(self.logger, '_save_dir', None) or
self.default_root_dir)
# weights_save_path overrides anything
if self.weights_save_path is not None:
save_dir = self.weights_save_path
ckpt_path = os.path.join(
save_dir,
self.logger.name,
f'version_{self.logger.version}',
"checkpoints"
)
else:
ckpt_path = os.path.join(self.default_root_dir, "checkpoints")
proper checkpoint implementation (#1043) * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * enabled early stopping/checkpooiunt even without val step * name formatting * version * testing * add test * fix test * Update model_checkpoint.py * doctests * pylint * tests * debug * debug * enabled early stopping/checkpooiunt even without val step * fix MNIST download (#1044) * fix MNIST download * simple * name formatting * version * testing * add test * fix test * doctests * tests * debug * debug * rebased 1041 * rebased 1041 * tests * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 * rebased 1041 Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-03-05 04:02:19 +00:00
# when no val step is defined, use 'loss' otherwise 'val_loss'
train_step_only = not self.is_overriden('validation_step')
monitor_key = 'loss' if train_step_only else 'val_loss'
if self.checkpoint_callback is True:
os.makedirs(ckpt_path, exist_ok=True)
self.checkpoint_callback = ModelCheckpoint(
filepath=ckpt_path,
monitor=monitor_key
)
# If user specified None in filepath, override with runtime default
elif isinstance(self.checkpoint_callback, ModelCheckpoint) \
and self.checkpoint_callback.dirpath is None:
self.checkpoint_callback.dirpath = ckpt_path
self.checkpoint_callback.filename = '{epoch}'
os.makedirs(self.checkpoint_callback.dirpath, exist_ok=True)
elif self.checkpoint_callback is False:
self.checkpoint_callback = None
self.ckpt_path = ckpt_path
if self.checkpoint_callback:
# set the path for the callbacks
self.checkpoint_callback.save_function = self.save_checkpoint
# if checkpoint callback used, then override the weights path
self.weights_save_path = self.checkpoint_callback.dirpath
# if weights_save_path is still none here, set to current working dir
if self.weights_save_path is None:
self.weights_save_path = self.default_root_dir
2020-01-26 14:42:57 +00:00
def configure_early_stopping(self, early_stop_callback):
if early_stop_callback is True or None:
self.early_stop_callback = EarlyStopping(
monitor='val_loss',
patience=3,
strict=True,
verbose=True,
mode='min'
)
self.enable_early_stop = True
elif not early_stop_callback:
self.early_stop_callback = None
self.enable_early_stop = False
else:
self.early_stop_callback = early_stop_callback
self.enable_early_stop = True
Progress bar callback (#1450) * squash and rebase sanity check hooks sanity check callback hook finish moved core progress bar functionality into callback wip remove duplicate merge clean up imports docs sanity check progress bar main sanity move callback calls init progrss bar callback configuration and docs changelog rate decorator pass process_position disable on rank > 0 position index is_enabled remove decorator refactor init tqdm bars callback method ordering cannot reset when disabled sequence -> list default values fix has no attr _time() move on_val_end to proper place fix the pickle issue update warning properties check for None remove old comment switch order pull out non-tqdm functionality into base class documentation for the base class docs fix refresh rate issue in validation restrict type hint of trainer arg more docs update trainer docs rst docs fix lines too long fix test add missing type hints fix typo move docstring to __init__ solves doctest failures remove doctest :(( can't fix the pickle error fix example simplify by saving trainer reference fix docs errors move docstring initial value multiple val checks per epoch simpler handling of inf dataset sizes update inf docs renamed training_tqdm_dict rename get_tqdm_dict rename occurences of tqdm update changelog fix doctest fix formatting errors added callback tests progress bar on off test more tests for progress bar weird test fix? add ignored property disable default progress bar in LR finder change enable/disable behavior trying doctest in CI again undo doctest pickle error undo doctest pickle error :(( remove progress_bar_callback Trainer arg and fix tests restore progress bar after auto lr find update docs fix rebase fix wrong negation * fix fast dev run total * more thorough testing * remove old args * fix merge * fix merge * separate tests * type hint total batches * reduce if Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * is_disabled Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * is_enabled Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * rename enabled/disabled * move deprecated api * remove duplicated test from merge * fix rename is_disabled * newline * test also testprogress for fast dev run Co-authored-by: J. Borovec <jirka.borovec@seznam.cz> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-04-24 00:46:18 +00:00
def configure_progress_bar(self):
progress_bars = [c for c in self.callbacks if isinstance(c, ProgressBarBase)]
if len(progress_bars) > 1:
raise MisconfigurationException(
'You added multiple progress bar callbacks to the Trainer, but currently only one'
' progress bar is supported.'
)
elif len(progress_bars) == 1:
self.progress_bar_callback = progress_bars[0]
elif self.progress_bar_refresh_rate > 0:
self.progress_bar_callback = ProgressBar(
refresh_rate=self.progress_bar_refresh_rate,
process_position=self.process_position,
)
self.callbacks.append(self.progress_bar_callback)
else:
self.progress_bar_callback = None