Continue Jeremy's early stopping PR #1504 (#2391)

* add state_dict for early stopping

* move best attr after monitor_op defined

* improve early stopping and model checkpoint callbacks

* fix formatting

* fix attr init order

* clean up setting of default_root_dir attr

* logger needs default root dir set first

* reorg trainer init

* remove direct references to checkpoint callback

* more fixes

* more bugfixes

* run callbacks at epoch end

* update tests to use on epoch end

* PR cleanup

* address failing tests

* refactor for homogeneity

* fix merge conflict

* separate tests

* tests for early stopping bug regressions

* small fixes

* revert model checkpoint change

* typo fix

* fix tests

* update train loop

* cannot pass an int as default_save_path

* refactor log message

* fix test case

* appease the linter

* fix some doctests

* move config to callback

* fixes from rebase

* fixes from rebase

* chlog

* docs

* reformat

* formatting

* fix

* fix

* fixes from rebase

* add new test for patience

* Update pytorch_lightning/callbacks/model_checkpoint.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/callbacks/model_checkpoint.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/callbacks/test_early_stopping.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* fix formatting

* remove enable_early_stop attribute

* fix test with new epoch indexing

* fix progress bar totals

* fix off by one error (see #2289) epoch starts at 0 now

* added missing imports

* fix hpc_save folderpath

* fix formatting

* fix tests

* small fixes from a rebase

* fix

* tmpdir

* tmpdir

* tmpdir

* wandb

* fix merge conflict

* add back evaluation after training

* test_resume_early_stopping_from_checkpoint TODO

* undo the horovod check

* update changelog

* remove a duplicate test from merge error

* try fix dp_resume test

* add the logger fix from master

* try remove default_root_dir

* try mocking numpy

* try import numpy in docs test

* fix wandb test

* pep 8 fix

* skip if no amp

* dont mock when doctesting

* install extra

* fix the resume ES test

* undo conf.py changes

* revert remove comet pickle from test

* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update weights_loading.rst

* Update weights_loading.rst

* Update weights_loading.rst

* renamed flag

* renamed flag

* revert the None check in logger experiment name/version

* add the old comments

* _experiment

* test chckpointing on DDP

* skip the ddp test on windows

* cloudpickle

* renamed flag

* renamed flag

* parentheses for clarity

* apply suggestion max epochs

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: Jeremy Jordan <jtjordan@ncsu.edu>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
Authored by Adrian Wälchli on 2020-06-29 03:36:46 +02:00; committed by GitHub
parent 1e16681693
commit 25ee51bc57
32 changed files with 532 additions and 230 deletions


@ -42,6 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed loading model with kwargs ([#2387](https://github.com/PyTorchLightning/pytorch-lightning/pull/2387))
- Fixed several issues with early stopping and checkpoint callbacks ([#1504](https://github.com/PyTorchLightning/pytorch-lightning/pull/1504), [#2391](https://github.com/PyTorchLightning/pytorch-lightning/pull/2391))
- Fixed loading past checkpoints from v0.7.x ([#2405](https://github.com/PyTorchLightning/pytorch-lightning/pull/2405))
- Fixed loading model without arguments ([#2403](https://github.com/PyTorchLightning/pytorch-lightning/pull/2403))


@ -48,6 +48,19 @@ We successfully extended functionality without polluting our super clean
----------------
Best Practices
==============
1. Callbacks should be isolated in their functionality. Your callback should not rely on the
behavior of other callbacks in order to work properly.
2. Do not manually call methods from the callback. The callbacks are designed to be
invoked at specific times during training. Directly calling methods (eg. `on_validation_end`)
is strongly discouraged.
3. Whenever possible, your callbacks should not depend on the order in which they are executed.
---------
.. automodule:: pytorch_lightning.callbacks.base
:noindex:
:exclude-members:
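
For readers who want to see the best-practices list above in action, here is a minimal, self-contained callback sketch; the class name and the printed messages are illustrative only and are not part of this PR:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

class PrintingCallback(Callback):
    """Toy callback: isolated, order-independent, and never called manually."""

    def on_train_start(self, trainer, pl_module):
        print('Training is starting')

    def on_train_end(self, trainer, pl_module):
        print('Training is ending')

# passed through the generic `callbacks` argument, as the list above recommends
trainer = Trainer(callbacks=[PrintingCallback()])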


@ -92,6 +92,7 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.
.. testcode::
from pytorch_lightning.loggers import NeptuneLogger
neptune_logger = NeptuneLogger(
api_key='ANONYMOUS', # replace with your own
project_name='shared/pytorch-lightning-integration',
@ -193,7 +194,7 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.
.. testcode::
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger()
wandb_logger = WandbLogger(offline=True)
trainer = Trainer(logger=wandb_logger)
The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your


@ -29,7 +29,7 @@ Automatic saving
Checkpointing is enabled by default to the current working directory.
To change the checkpoint path pass in:
.. testcode::
.. code-block:: python
trainer = Trainer(default_root_dir='/your/path/to/save/checkpoints')


@ -5,6 +5,7 @@ Early Stopping
Monitor a validation metric and stop training when it stops improving.
"""
from copy import deepcopy
import numpy as np
import torch
@ -58,7 +59,7 @@ class EarlyStopping(Callback):
self.verbose = verbose
self.strict = strict
self.min_delta = min_delta
self.wait = 0
self.wait_count = 0
self.stopped_epoch = 0
self.mode = mode
@ -76,12 +77,17 @@ class EarlyStopping(Callback):
log.info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.')
self.min_delta *= 1 if self.monitor_op == torch.gt else -1
self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
def _validate_condition_metric(self, logs):
"""
Checks that the condition metric for early stopping is good
:param logs:
:return:
Args:
logs: callback metrics from validation output
Return:
True if specified metric is available
"""
monitor_val = logs.get(self.monitor)
error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
@ -103,39 +109,48 @@ class EarlyStopping(Callback):
def monitor_op(self):
return self.mode_dict[self.mode]
def on_train_start(self, trainer, pl_module):
# Allow instances to be re-used
self.wait = 0
self.stopped_epoch = 0
self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf
def state_dict(self):
return {
'wait_count': self.wait_count,
'stopped_epoch': self.stopped_epoch,
'best_score': self.best_score,
'patience': self.patience
}
def load_state_dict(self, state_dict):
state_dict = deepcopy(state_dict)
self.wait_count = state_dict['wait_count']
self.stopped_epoch = state_dict['stopped_epoch']
self.best_score = state_dict['best_score']
self.patience = state_dict['patience']
def on_sanity_check_end(self, trainer, pl_module):
logs = trainer.callback_metrics
self._validate_condition_metric(logs)
def on_validation_end(self, trainer, pl_module):
return self._run_early_stopping_check(trainer, pl_module)
self._run_early_stopping_check(trainer, pl_module)
def _run_early_stopping_check(self, trainer, pl_module):
logs = trainer.callback_metrics
stop_training = False
if not self._validate_condition_metric(logs):
return stop_training
return # short circuit if metric not present
current = logs.get(self.monitor)
if not isinstance(current, torch.Tensor):
current = torch.tensor(current)
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
if self.monitor_op(current - self.min_delta, self.best_score):
self.best_score = current
self.wait_count = 0
else:
self.wait += 1
if self.wait >= self.patience:
self.wait_count += 1
if self.wait_count >= self.patience:
self.stopped_epoch = trainer.current_epoch
stop_training = True
self.on_train_end(trainer, pl_module)
return stop_training
trainer.should_stop = True
def on_train_end(self, trainer, pl_module):
if self.stopped_epoch > 0 and self.verbose > 0:
rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)
log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping triggered.')
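
The hunk above gives EarlyStopping a proper state_dict/load_state_dict pair and renames wait to wait_count. A minimal sketch of the round trip (the patience values are arbitrary):

from pytorch_lightning.callbacks import EarlyStopping

es = EarlyStopping(monitor='val_loss', patience=3)
state = es.state_dict()           # {'wait_count': ..., 'stopped_epoch': ..., 'best_score': ..., 'patience': 3}

restored = EarlyStopping(monitor='val_loss', patience=5)
restored.load_state_dict(state)   # deep-copies the dict and restores all four fields
assert restored.patience == 3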


@ -226,6 +226,41 @@ class ModelCheckpoint(Callback):
filepath = os.path.join(self.dirpath, self.prefix + filename + str_ver + '.ckpt')
return filepath
def on_train_start(self, trainer, pl_module):
"""
Determine model checkpoint save directory at runtime. References attributes from the
Trainer's logger to determine where to save checkpoints.
"""
if self.dirpath is not None:
return # short circuit
self.filename = '{epoch}'
if trainer.logger is not None and trainer.logger.experiment is not None:
# weights_save_path overrides anything
if getattr(trainer, 'weights_save_path', None) is not None:
save_dir = trainer.weights_save_path
else:
save_dir = (getattr(trainer.logger, 'save_dir', None)
or getattr(trainer.logger, '_save_dir', None)
or trainer.default_root_dir)
version = trainer.logger.version if isinstance(
trainer.logger.version, str) else f'version_{trainer.logger.version}'
ckpt_path = os.path.join(
save_dir,
trainer.logger.name,
version,
"checkpoints"
)
else:
ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")
self.dirpath = ckpt_path
os.makedirs(self.dirpath, exist_ok=True)
trainer.ckpt_path = ckpt_path
trainer.weights_save_path = ckpt_path
@rank_zero_only
def on_validation_end(self, trainer, pl_module):
# only run on main process
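
With the on_train_start hook above, a ModelCheckpoint created with filepath=None resolves its directory only once training starts. A sketch of the resulting layout, assuming the default TensorBoardLogger (the exact path is inferred from the code above, not guaranteed):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(filepath=None, save_top_k=1)   # dirpath stays None for now
trainer = Trainer(checkpoint_callback=ckpt, max_epochs=1)
# After trainer.fit(model), on_train_start has set:
#   ckpt.dirpath == <save_dir>/<logger.name>/version_<logger.version>/checkpoints
# and mirrored it to trainer.ckpt_path and trainer.weights_save_path.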


@ -131,12 +131,12 @@ class WandbLogger(LightningLoggerBase):
self.experiment.log({'global_step': step, **metrics} if step is not None else metrics)
@property
def name(self) -> str:
def name(self) -> Optional[str]:
# don't create an experiment if we don't have one
name = self._experiment.project_name() if self._experiment else None
return name
@property
def version(self) -> str:
def version(self) -> Optional[str]:
# don't create an experiment if we don't have one
return self._experiment.id if self._experiment else None
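
The Optional return types above mean that reading name or version no longer forces a wandb run to be created. A small sketch (assumes the wandb package is installed; offline=True avoids any network access):

from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(offline=True)
# neither property triggers wandb.init(); both stay None until `logger.experiment` is accessed
assert logger.name is None
assert logger.version is None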


@ -32,79 +32,47 @@ class TrainerCallbackConfigMixin(ABC):
def is_overridden(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""
def configure_checkpoint_callback(self):
def configure_checkpoint_callback(self, checkpoint_callback):
"""
Weight path set in this priority:
Checkpoint_callback's path (if passed in).
User provided weights_saved_path
Otherwise use os.getcwd()
"""
ckpt_path = self.default_root_dir
if self.checkpoint_callback:
# init a default one
if self.logger is not None and self.logger.experiment is not None:
save_dir = (getattr(self.logger, 'save_dir', None) or
getattr(self.logger, '_save_dir', None) or
self.default_root_dir)
# weights_save_path overrides anything
if self.weights_save_path is not None:
save_dir = self.weights_save_path
version = self.logger.version if isinstance(
self.logger.version, str) else f'version_{self.logger.version}'
ckpt_path = os.path.join(save_dir, self.logger.name, version, "checkpoints")
else:
ckpt_path = os.path.join(self.default_root_dir, "checkpoints")
if checkpoint_callback is True:
# when no val step is defined, use 'loss' otherwise 'val_loss'
train_step_only = not self.is_overridden('validation_step')
monitor_key = 'loss' if train_step_only else 'val_loss'
checkpoint_callback = ModelCheckpoint(
filepath=None,
monitor=monitor_key
)
elif checkpoint_callback is False:
checkpoint_callback = None
if self.checkpoint_callback is True:
os.makedirs(ckpt_path, exist_ok=True)
self.checkpoint_callback = ModelCheckpoint(
filepath=ckpt_path,
monitor=monitor_key
)
# If user specified None in filepath, override with runtime default
elif isinstance(self.checkpoint_callback, ModelCheckpoint) \
and self.checkpoint_callback.dirpath is None:
self.checkpoint_callback.dirpath = ckpt_path
self.checkpoint_callback.filename = '{epoch}'
os.makedirs(self.checkpoint_callback.dirpath, exist_ok=True)
elif self.checkpoint_callback is False:
self.checkpoint_callback = None
self.ckpt_path = ckpt_path
if self.checkpoint_callback:
# set the path for the callbacks
self.checkpoint_callback.save_function = self.save_checkpoint
# if checkpoint callback used, then override the weights path
self.weights_save_path = self.checkpoint_callback.dirpath
if checkpoint_callback:
checkpoint_callback.save_function = self.save_checkpoint
# if weights_save_path is still none here, set to current working dir
if self.weights_save_path is None:
self.weights_save_path = self.default_root_dir
return checkpoint_callback
def configure_early_stopping(self, early_stop_callback):
if early_stop_callback is True or None:
self.early_stop_callback = EarlyStopping(
early_stop_callback = EarlyStopping(
monitor='val_loss',
patience=3,
strict=True,
verbose=True,
mode='min'
)
self.enable_early_stop = True
elif not early_stop_callback:
self.early_stop_callback = None
self.enable_early_stop = False
early_stop_callback = None
else:
self.early_stop_callback = early_stop_callback
self.enable_early_stop = True
early_stop_callback = early_stop_callback
return early_stop_callback
def configure_progress_bar(self, refresh_rate=1, process_position=0):
progress_bars = [c for c in self.callbacks if isinstance(c, ProgressBarBase)]
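
configure_early_stopping now returns the callback instead of mutating trainer attributes, which makes the flag semantics easy to check. A hedged sketch of the three cases (the asserts reflect the code above for this version only):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# True -> a default EarlyStopping(monitor='val_loss', patience=3) is created
assert isinstance(Trainer(early_stop_callback=True).early_stop_callback, EarlyStopping)

# False -> no early stopping is configured
assert Trainer(early_stop_callback=False).early_stop_callback is None

# a user-provided instance is passed through unchanged
custom = EarlyStopping(monitor='val_loss', patience=10)
assert Trainer(early_stop_callback=custom).early_stop_callback is custom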


@ -172,7 +172,6 @@ class TrainerDDPMixin(ABC):
num_gpu_nodes: int
gpus: List[int]
logger: Union[LightningLoggerBase, bool]
checkpoint_callback: Union[ModelCheckpoint, bool]
data_parallel_device_ids: ...
distributed_backend: Optional[str]
amp_level: str


@ -163,7 +163,6 @@ class TrainerLRFinderMixin(ABC):
# Disable standard checkpoint & early stopping
self.checkpoint_callback = False
self.early_stop_callback = None
self.enable_early_stop = False
# Required for saving the model
self.optimizers, self.schedulers = [], [],
@ -215,7 +214,6 @@ class TrainerLRFinderMixin(ABC):
'max_steps': self.max_steps,
'checkpoint_callback': self.checkpoint_callback,
'early_stop_callback': self.early_stop_callback,
'enable_early_stop': self.enable_early_stop,
'configure_optimizers': model.configure_optimizers,
}
@ -226,7 +224,6 @@ class TrainerLRFinderMixin(ABC):
self.max_steps = self.__dumped_params['max_steps']
self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
self.early_stop_callback = self.__dumped_params['early_stop_callback']
self.enable_early_stop = self.__dumped_params['enable_early_stop']
model.configure_optimizers = self.__dumped_params['configure_optimizers']
del self.__dumped_params


@ -328,9 +328,60 @@ class Trainer(
if 'LOCAL_RANK' in os.environ:
rank_zero_only.rank = os.environ['LOCAL_RANK']
# Init callbacks
# training bookkeeping
self.total_batch_idx = 0
self.running_loss = TensorRunningAccum(window_length=20)
self.batch_idx = 0
self.progress_bar_metrics = {}
self.callback_metrics = {}
self.num_training_batches = 0
self.num_val_batches = []
self.num_test_batches = []
self.train_dataloader = None
self.test_dataloaders = None
self.val_dataloaders = None
# training state
self.model = None
self.testing = False
self.disable_validation = False
self.prepare_data_per_node = prepare_data_per_node
self.lr_schedulers = []
self.optimizers = None
self.optimizer_frequencies = []
self.global_step = 0
self.current_epoch = 0
self.interrupted = False
self.should_stop = False
# set default save path if user didn't provide one
if default_root_dir is None:
default_root_dir = os.getcwd()
self.default_root_dir = default_root_dir
self.configure_logger(logger)
# init callbacks
self.callbacks = callbacks or []
# configure early stop callback
# creates a default one if none passed in
early_stop_callback = self.configure_early_stopping(early_stop_callback)
if early_stop_callback:
self.callbacks.append(early_stop_callback)
# configure checkpoint callback
# it is important that this is the last callback to run
# pass through the required args to figure out defaults
self.weights_save_path = weights_save_path
checkpoint_callback = self.configure_checkpoint_callback(checkpoint_callback)
if checkpoint_callback:
self.callbacks.append(checkpoint_callback)
# TODO refactor codebase (tests) to not directly reach into these callbacks
self.checkpoint_callback = checkpoint_callback
self.early_stop_callback = early_stop_callback
self.on_init_start()
# benchmarking
@ -399,52 +450,11 @@ class Trainer(
rank_zero_info('Running in fast_dev_run mode: will run a full train,'
' val and test loop using a single batch')
# set default save path if user didn't provide one
self.default_root_dir = default_root_dir
if self.default_root_dir is None:
self.default_root_dir = os.getcwd()
# training bookkeeping
self.total_batch_idx = 0
self.running_loss = TensorRunningAccum(window_length=20)
self.batch_idx = 0
self.progress_bar_metrics = {}
self.callback_metrics = {}
self.num_val_batches = [0]
self.num_training_batches = 0
self.num_test_batches = [0]
self.train_dataloader = None
self.test_dataloaders = None
self.val_dataloaders = None
# training state
self.model = None
self.testing = False
self.disable_validation = False
self.lr_schedulers = []
self.optimizers = None
self.optimizer_frequencies = []
self.global_step = 0
self.current_epoch = 0
self.interrupted = False
# configure logger
self.configure_logger(logger)
# configure profiler
if profiler is True:
profiler = SimpleProfiler()
self.profiler = profiler or PassThroughProfiler()
# configure early stop callback
# creates a default one if none passed in
self.configure_early_stopping(early_stop_callback)
# configure checkpoint callback
self.checkpoint_callback = checkpoint_callback
self.weights_save_path = weights_save_path
# accumulated grads
self.accumulate_grad_batches = accumulate_grad_batches
self.configure_accumulated_gradients(accumulate_grad_batches)
@ -1045,9 +1055,6 @@ class Trainer(
# if cluster resets state, the model will update with the saved weights
self.model = model
# set up checkpoint callback
self.configure_checkpoint_callback()
# restore training and model before hpc call
self.restore_weights(model)
@ -1078,13 +1085,10 @@ class Trainer(
max_batches,
False)
_, _, _, callback_metrics, _ = self.process_output(eval_results)
self.callback_metrics = callback_metrics
self.on_sanity_check_end()
# verify that early stop has conditioned on a metric that exists
if self.enable_early_stop:
self.early_stop_callback._validate_condition_metric(callback_metrics)
# clear cache before training
if self.on_gpu and self.root_gpu is not None:
# use context because of:
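
After the init reorg above, the default early stopping and checkpoint callbacks are appended to the generic callbacks list, and stopping is signalled through the new should_stop flag. A small sketch of the observable result (constructing a bare Trainer only; nothing is trained):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

trainer = Trainer(early_stop_callback=True, checkpoint_callback=True, max_epochs=1)
assert any(isinstance(c, EarlyStopping) for c in trainer.callbacks)
assert any(isinstance(c, ModelCheckpoint) for c in trainer.callbacks)
assert trainer.should_stop is False   # EarlyStopping now requests a stop by setting this flag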


@ -95,6 +95,7 @@ import torch.distributed as torch_distrib
import pytorch_lightning
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import (
LightningDistributedDataParallel,
@ -328,26 +329,32 @@ class TrainerIOMixin(ABC):
}
if not weights_only:
if self.checkpoint_callback:
# TODO support more generic way for callbacks to persist a state_dict in a checkpoint
checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
early_stopping_callbacks = [c for c in self.callbacks if isinstance(c, EarlyStopping)]
if checkpoint_callbacks:
# we add the official checkpoint callback to the end of the list
# extra user provided callbacks will not be persisted yet
checkpoint['checkpoint_callback_best_model_score'] = self.checkpoint_callback.best_model_score
checkpoint['checkpoint_callback_best_model_path'] = self.checkpoint_callback.best_model_path
if self.early_stop_callback:
checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait
checkpoint['early_stop_callback_patience'] = self.early_stop_callback.patience
if early_stopping_callbacks and checkpoint_callbacks:
# we add the official early stopping callback to the end of the list
# extra user provided callbacks will not be persisted yet
checkpoint['early_stop_callback_state_dict'] = early_stopping_callbacks[-1].state_dict()
# save optimizers
optimizer_states = []
for i, optimizer in enumerate(self.optimizers):
optimizer_states.append(optimizer.state_dict())
checkpoint['optimizer_states'] = optimizer_states
# save lr schedulers
lr_schedulers = []
for scheduler in self.lr_schedulers:
lr_schedulers.append(scheduler['scheduler'].state_dict())
checkpoint['lr_schedulers'] = lr_schedulers
# save native amp scaling
@ -405,21 +412,25 @@ class TrainerIOMixin(ABC):
' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
)
if self.checkpoint_callback:
# TODO support more generic way for callbacks to load callback state_dicts
checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
early_stopping_callbacks = [c for c in self.callbacks if isinstance(c, EarlyStopping)]
if checkpoint_callbacks:
if 'checkpoint_callback_best_model_score' in checkpoint:
self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best_model_score']
checkpoint_callbacks[-1].best_model_score = checkpoint['checkpoint_callback_best_model_score']
else:
# Old naming until version 0.7.6
rank_zero_warn(
'Loading a checkpoint created with an old version of Lightning; '
'this will not be supported in the future.'
)
self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best']
self.checkpoint_callback.best_model_path = checkpoint['checkpoint_callback_best_model_path']
checkpoint_callbacks[-1].best_model_score = checkpoint['checkpoint_callback_best']
checkpoint_callbacks[-1].best_model_path = checkpoint['checkpoint_callback_best_model_path']
if self.early_stop_callback:
self.early_stop_callback.wait = checkpoint['early_stop_callback_wait']
self.early_stop_callback.patience = checkpoint['early_stop_callback_patience']
if early_stopping_callbacks:
state = checkpoint['early_stop_callback_state_dict']
early_stopping_callbacks[-1].load_state_dict(state)
self.global_step = checkpoint['global_step']
self.current_epoch = checkpoint['epoch']
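
The save/restore changes above persist the official callbacks' state under dedicated checkpoint keys. A sketch of what a saved checkpoint now contains ('last.ckpt' is a hypothetical file name; the keys are present only when the corresponding callbacks were configured):

import torch

checkpoint = torch.load('last.ckpt', map_location='cpu')
checkpoint.get('checkpoint_callback_best_model_score')   # best monitored value
checkpoint.get('checkpoint_callback_best_model_path')    # path of the best checkpoint
checkpoint.get('early_stop_callback_state_dict')          # {'wait_count', 'stopped_epoch', 'best_score', 'patience'}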


@ -144,8 +144,7 @@ in your model.
"""
import atexit
import signal
import subprocess
from abc import ABC, abstractmethod
from typing import Callable
from typing import Union, List
@ -157,6 +156,7 @@ import torch.distributed as torch_distrib
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.trainer.supporters import TensorRunningAccum
@ -164,7 +164,6 @@ from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.memory import recursive_detach
import subprocess
try:
from apex import amp
@ -212,7 +211,6 @@ class TrainerTrainLoopMixin(ABC):
fast_dev_run: ...
accumulation_scheduler: ...
lr_schedulers: ...
enable_early_stop: ...
early_stop_callback: ...
callback_metrics: ...
logger: Union[LightningLoggerBase, bool]
@ -239,7 +237,6 @@ class TrainerTrainLoopMixin(ABC):
max_steps: int
min_steps: int
total_batch_idx: int
checkpoint_callback: ...
terminate_on_nan: bool
tpu_id: int
interactive_ddp_procs: ...
@ -264,7 +261,7 @@ class TrainerTrainLoopMixin(ABC):
"""Warning: this is just empty shell for code implemented in other class."""
@abstractmethod
def run_evaluation(self, *args):
def run_evaluation(self, *args, **kwargs):
"""Warning: this is just empty shell for code implemented in other class."""
@abstractmethod
@ -340,9 +337,6 @@ class TrainerTrainLoopMixin(ABC):
with self.profiler.profile('on_train_start'):
# callbacks
self.on_train_start()
# initialize early stop callback
if self.early_stop_callback is not None:
self.early_stop_callback.on_train_start(self, self.get_model())
# model hooks
model.on_train_start()
@ -375,7 +369,7 @@ class TrainerTrainLoopMixin(ABC):
# -----------------
self.run_training_epoch()
if self.max_steps and self.max_steps == self.global_step:
if self.max_steps and self.max_steps <= self.global_step:
self.run_training_teardown()
return
@ -386,19 +380,14 @@ class TrainerTrainLoopMixin(ABC):
met_min_epochs = epoch >= self.min_epochs - 1
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
# TODO wrap this logic into the callback
# DO NOT DELETE
# early stopping as a (new Callback) class doesn't yet work because we have to know these
# trainer flags including the current epoch stuff
# all of this needs to go into the early stopping to clean up better
if self.enable_early_stop:
if self.should_stop:
if (met_min_epochs and met_min_steps) or self.fast_dev_run:
should_stop = self.early_stop_callback.on_validation_end(self, self.get_model())
# stop training
stop = should_stop and met_min_epochs
if stop:
self.run_training_teardown()
return
self.run_training_teardown()
return
else:
log.info('Trainer was signaled to stop but required minimum epochs'
f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
' not been met. Training will continue...')
self.run_training_teardown()
@ -444,6 +433,7 @@ class TrainerTrainLoopMixin(ABC):
# bookkeeping
epoch_output = []
should_check_val = False
# run epoch
for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
@ -470,22 +460,24 @@ class TrainerTrainLoopMixin(ABC):
self.update_train_loop_lr_schedulers()
# when returning -1 from train_step, we end epoch early
early_stop_epoch = batch_output.signal == -1
self.should_stop = batch_output.signal == -1
# -----------------------------------------
# VALIDATE IF NEEDED + CHECKPOINT CALLBACK
# -----------------------------------------
should_check_val = self.check_validation_in_train_loop(batch_idx, early_stop_epoch, is_last_batch)
should_check_val = self.should_check_val(batch_idx, is_last_batch)
if self.fast_dev_run or should_check_val:
self.run_evaluation(test_mode=False)
# -----------------------------------------
# SAVE LOGGERS (ie: Tensorboard, etc...)
# -----------------------------------------
self.save_loggers_in_training_loop(batch_idx, early_stop_epoch)
self.save_loggers_in_training_loop(batch_idx)
# -----------------------------------------
# SAVE METRICS TO LOGGERS
# -----------------------------------------
self.save_train_loop_metrics_to_loggers(batch_idx, early_stop_epoch, batch_output)
self.save_train_loop_metrics_to_loggers(batch_idx, batch_output)
# progress global step according to grads progress
self.increment_accumulated_grad_global_step()
@ -497,7 +489,7 @@ class TrainerTrainLoopMixin(ABC):
# end epoch early
# stop when the flag is changed or we've gone past the amount
# requested in the batches
if early_stop_epoch or self.fast_dev_run:
if self.fast_dev_run or self.should_stop:
break
# let ddp devices catch up when using horovod
@ -506,13 +498,19 @@ class TrainerTrainLoopMixin(ABC):
# process epoch outputs
self.run_training_epoch_end(epoch_output)
# when no val loop is present or fast-dev-run still need to call checkpoints
if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
self.call_checkpoint_callback()
# checkpoint callback
self.check_checkpoint_callback(should_check_val)
# epoch end hook
self.run_on_epoch_end_hook(model)
def check_checkpoint_callback(self, should_check_val):
# when no val loop is present or fast-dev-run still need to call checkpoints
# TODO bake this logic into the checkpoint callback
if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
[c.on_validation_end(self, self.get_model()) for c in checkpoint_callbacks]
def update_train_loop_lr_schedulers(self):
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
# update lr
@ -553,33 +551,28 @@ class TrainerTrainLoopMixin(ABC):
self.global_step += 1
self.total_batch_idx += 1
def save_train_loop_metrics_to_loggers(self, batch_idx, early_stop_epoch, batch_output):
def save_train_loop_metrics_to_loggers(self, batch_idx, batch_output):
# when metrics should be logged
should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
should_log_metrics = batch_idx % self.row_log_interval == 0 or self.should_stop
if should_log_metrics or self.fast_dev_run:
# logs user requested information to logger
self.log_metrics(batch_output.batch_log_metrics, batch_output.grad_norm_dic)
def save_loggers_in_training_loop(self, batch_idx, early_stop_epoch):
def save_loggers_in_training_loop(self, batch_idx):
# when loggers should save to disk
should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or self.should_stop
if should_save_log or self.fast_dev_run:
if self.is_global_zero and self.logger is not None:
self.logger.save()
def check_validation_in_train_loop(self, batch_idx, early_stop_epoch, is_last_batch):
def should_check_val(self, batch_idx, is_last_batch):
# decide if we should run validation
is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
can_check_val = not self.disable_validation and can_check_epoch
should_check_val = is_val_check_batch or early_stop_epoch
should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
should_check_val = can_check_val and should_check_val
# if we need to run validation, then also call the checkpoint callback
if self.fast_dev_run or should_check_val:
self.run_evaluation(test_mode=self.testing)
self.call_checkpoint_callback()
should_check_val = is_val_check_batch or self.should_stop
is_last_batch_for_infinite_dataset = (is_last_batch and self.val_check_batch == float('inf'))
should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset)
return should_check_val
@ -984,10 +977,6 @@ class TrainerTrainLoopMixin(ABC):
else:
lr_scheduler['scheduler'].step()
def call_checkpoint_callback(self):
if self.checkpoint_callback is not None:
self.checkpoint_callback.on_validation_end(self, self.get_model())
def _with_is_last(iterable):
"""Pass through values from the given iterable with an added boolean indicating if this is the last item.


@ -188,7 +188,6 @@ class TrainerTrainingTricksMixin(ABC):
'callbacks': self.callbacks,
'checkpoint_callback': self.checkpoint_callback,
'early_stop_callback': self.early_stop_callback,
'enable_early_stop': self.enable_early_stop,
'auto_scale_batch_size': self.auto_scale_batch_size,
'limit_train_batches': self.limit_train_batches,
'model': self.model,
@ -202,7 +201,6 @@ class TrainerTrainingTricksMixin(ABC):
self.callbacks = [] # not needed before full run
self.checkpoint_callback = False # required for saving
self.early_stop_callback = None
self.enable_early_stop = False
self.limit_train_batches = 1.0
self.optimizers, self.schedulers = [], [] # required for saving
self.model = model # required for saving
@ -215,7 +213,6 @@ class TrainerTrainingTricksMixin(ABC):
self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
self.auto_scale_batch_size = self.__dumped_params['auto_scale_batch_size']
self.early_stop_callback = self.__dumped_params['early_stop_callback']
self.enable_early_stop = self.__dumped_params['enable_early_stop']
self.limit_train_batches = self.__dumped_params['limit_train_batches']
self.model = self.__dumped_params['model']
del self.__dumped_params


@ -160,6 +160,7 @@ def test_trainer_callback_system(tmpdir):
test_callback = TestCallback()
trainer_options = dict(
default_root_dir=tmpdir,
callbacks=[test_callback],
max_epochs=1,
limit_val_batches=0.1,


@ -0,0 +1,133 @@
import pickle
import cloudpickle
import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from tests.base import EvalModelTemplate
def test_resume_early_stopping_from_checkpoint(tmpdir):
"""
Prevent regressions to bugs:
https://github.com/PyTorchLightning/pytorch-lightning/issues/1464
https://github.com/PyTorchLightning/pytorch-lightning/issues/1463
"""
class EarlyStoppingTestStore(EarlyStopping):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# cache the state for each epoch
self.saved_states = []
def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
self.saved_states.append(self.state_dict().copy())
class EarlyStoppingTestRestore(EarlyStopping):
def __init__(self, expected_state):
super().__init__()
self.expected_state = expected_state
def on_train_start(self, trainer, pl_module):
assert self.state_dict() == self.expected_state
model = EvalModelTemplate()
checkpoint_callback = ModelCheckpoint(save_top_k=1)
early_stop_callback = EarlyStoppingTestStore()
trainer = Trainer(
default_root_dir=tmpdir,
checkpoint_callback=checkpoint_callback,
early_stop_callback=early_stop_callback,
max_epochs=4,
)
trainer.fit(model)
checkpoint_filepath = checkpoint_callback.kth_best_model
# ensure state is persisted properly
checkpoint = torch.load(checkpoint_filepath)
# the checkpoint saves "epoch + 1"
early_stop_callback_state = early_stop_callback.saved_states[checkpoint['epoch'] - 1]
assert 4 == len(early_stop_callback.saved_states)
assert checkpoint['early_stop_callback_state_dict'] == early_stop_callback_state
# ensure state is reloaded properly (assertion in the callback)
early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state)
new_trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
resume_from_checkpoint=checkpoint_filepath,
early_stop_callback=early_stop_callback,
)
new_trainer.fit(model)
def test_early_stopping_no_extraneous_invocations(tmpdir):
"""Test to ensure that callback methods aren't being invoked outside of the callback handler."""
class EarlyStoppingTestInvocations(EarlyStopping):
def __init__(self, expected_count):
super().__init__()
self.count = 0
self.expected_count = expected_count
def on_validation_end(self, trainer, pl_module):
self.count += 1
def on_train_end(self, trainer, pl_module):
assert self.count == self.expected_count
model = EvalModelTemplate()
expected_count = 4
early_stop_callback = EarlyStoppingTestInvocations(expected_count)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=early_stop_callback,
val_check_interval=1.0,
max_epochs=expected_count,
)
trainer.fit(model)
@pytest.mark.parametrize('loss_values, patience, expected_stop_epoch', [
([6, 5, 5, 5, 5, 5], 3, 4),
([6, 5, 4, 4, 3, 3], 1, 3),
([6, 5, 6, 5, 5, 5], 3, 4),
])
def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_epoch):
"""Test to ensure that early stopping is not triggered before patience is exhausted."""
class ModelOverrideValidationReturn(EvalModelTemplate):
validation_return_values = torch.Tensor(loss_values)
count = 0
def validation_epoch_end(self, outputs):
loss = self.validation_return_values[self.count]
self.count += 1
return {"test_val_loss": loss}
model = ModelOverrideValidationReturn()
early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=early_stop_callback,
val_check_interval=1.0,
num_sanity_val_steps=0,
max_epochs=10,
)
trainer.fit(model)
assert trainer.current_epoch == expected_stop_epoch
def test_pickling(tmpdir):
early_stopping = EarlyStopping()
early_stopping_pickled = pickle.dumps(early_stopping)
early_stopping_loaded = pickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)
early_stopping_pickled = cloudpickle.dumps(early_stopping)
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)


@ -0,0 +1,107 @@
import os
import pickle
import platform
from pathlib import Path
import cloudpickle
import pytest
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import EvalModelTemplate
@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
""" Test that None in checkpoint callback is valid and that chkp_path is set correctly """
tutils.reset_seed()
model = EvalModelTemplate()
checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k)
trainer = Trainer(
default_root_dir=tmpdir,
checkpoint_callback=checkpoint,
overfit_pct=0.20,
max_epochs=(save_top_k + 2),
)
trainer.fit(model)
# These should be different if the dirpath has be overridden
assert trainer.ckpt_path != trainer.default_root_dir
@pytest.mark.parametrize(
'logger_version,expected',
[(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
)
def test_model_checkpoint_path(tmpdir, logger_version, expected):
"""Test that "version_" prefix is only added when logger's version is an integer"""
tutils.reset_seed()
model = EvalModelTemplate()
logger = TensorBoardLogger(str(tmpdir), version=logger_version)
trainer = Trainer(
default_root_dir=tmpdir,
overfit_pct=0.2,
max_epochs=5,
logger=logger,
)
trainer.fit(model)
ckpt_version = Path(trainer.ckpt_path).parent.name
assert ckpt_version == expected
def test_pickling(tmpdir):
ckpt = ModelCheckpoint(tmpdir)
ckpt_pickled = pickle.dumps(ckpt)
ckpt_loaded = pickle.loads(ckpt_pickled)
assert vars(ckpt) == vars(ckpt_loaded)
ckpt_pickled = cloudpickle.dumps(ckpt)
ckpt_loaded = cloudpickle.loads(ckpt_pickled)
assert vars(ckpt) == vars(ckpt_loaded)
class ModelCheckpointTestInvocations(ModelCheckpoint):
# this class has to be defined outside the test function, otherwise we get pickle error
# due to the way ddp process is launched
def __init__(self, expected_count, *args, **kwargs):
super().__init__(*args, **kwargs)
self.count = 0
self.expected_count = expected_count
def _save_model(self, filepath):
# make sure we don't save twice
assert not os.path.isfile(filepath)
self.count += 1
super()._save_model(filepath)
def on_train_end(self, trainer, pl_module):
super().on_train_end(trainer, pl_module)
# on rank 0 we expect the saved files and on all others no saves
assert (trainer.global_rank == 0 and self.count == self.expected_count) \
or (trainer.global_rank > 0 and self.count == 0)
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
def test_model_checkpoint_no_extraneous_invocations(tmpdir):
"""Test to ensure that the model callback saves the checkpoints only once in distributed mode."""
model = EvalModelTemplate()
num_epochs = 4
model_checkpoint = ModelCheckpointTestInvocations(expected_count=num_epochs, save_top_k=-1)
trainer = Trainer(
distributed_backend='ddp_cpu',
num_processes=2,
default_root_dir=tmpdir,
early_stop_callback=False,
checkpoint_callback=model_checkpoint,
max_epochs=num_epochs,
)
result = trainer.fit(model)
assert 1 == result


@ -13,10 +13,11 @@ from tests.base import EvalModelTemplate
([ProgressBar(refresh_rate=2)], 0),
([ProgressBar(refresh_rate=2)], 1),
])
def test_progress_bar_on(callbacks, refresh_rate):
def test_progress_bar_on(tmpdir, callbacks, refresh_rate):
"""Test different ways the progress bar can be turned on."""
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=callbacks,
progress_bar_refresh_rate=refresh_rate,
max_epochs=1,
@ -34,10 +35,11 @@ def test_progress_bar_on(callbacks, refresh_rate):
([], False),
([ModelCheckpoint('../trainer')], 0),
])
def test_progress_bar_off(callbacks, refresh_rate):
def test_progress_bar_off(tmpdir, callbacks, refresh_rate):
"""Test different ways the progress bar can be turned off."""
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=callbacks,
progress_bar_refresh_rate=refresh_rate,
)
@ -54,12 +56,13 @@ def test_progress_bar_misconfiguration():
Trainer(callbacks=callbacks)
def test_progress_bar_totals():
def test_progress_bar_totals(tmpdir):
"""Test that the progress finishes with the correct total steps processed."""
model = EvalModelTemplate()
trainer = Trainer(
default_root_dir=tmpdir,
progress_bar_refresh_rate=1,
limit_val_batches=1.0,
max_epochs=1,
@ -105,10 +108,11 @@ def test_progress_bar_totals():
assert bar.test_batch_idx == k
def test_progress_bar_fast_dev_run():
def test_progress_bar_fast_dev_run(tmpdir):
model = EvalModelTemplate()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
)
@ -136,7 +140,7 @@ def test_progress_bar_fast_dev_run():
@pytest.mark.parametrize('refresh_rate', [0, 1, 50])
def test_progress_bar_progress_refresh(refresh_rate):
def test_progress_bar_progress_refresh(tmpdir, refresh_rate):
"""Test that the three progress bars get correctly updated when using different refresh rates."""
model = EvalModelTemplate()
@ -172,6 +176,7 @@ def test_progress_bar_progress_refresh(refresh_rate):
progress_bar = CurrentProgressBar(refresh_rate=refresh_rate)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[progress_bar],
progress_bar_refresh_rate=101, # should not matter if custom callback provided
limit_train_batches=1.0,


@ -108,7 +108,11 @@ def test_multiple_loggers_pickle(tmpdir):
logger1 = CustomLogger()
logger2 = CustomLogger()
trainer = Trainer(max_epochs=1, logger=[logger1, logger2])
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
logger=[logger1, logger2],
)
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({"acc": 1.0}, 0)


@ -1,12 +1,12 @@
import os
import pickle
from unittest.mock import patch
from unittest import mock
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
@patch('pytorch_lightning.loggers.wandb.wandb')
@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_logger(wandb):
"""Verify that basic functionality of wandb logger works.
Wandb doesn't work well with pytest so we have to mock it out here."""
@ -29,8 +29,8 @@ def test_wandb_logger(wandb):
assert logger.version == wandb.init().id
@patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_pickle(wandb):
@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_pickle(wandb, tmpdir):
"""Verify that pickling trainer with wandb logger works.
Wandb doesn't work well with pytest so we have to mock it out here.
@ -38,11 +38,18 @@ def test_wandb_pickle(wandb):
class Experiment:
id = 'the_id'
def project_name(self):
return 'the_project_name'
wandb.init.return_value = Experiment()
logger = WandbLogger(id='the_id', offline=True)
trainer = Trainer(max_epochs=1, logger=logger)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
logger=logger,
)
# Access the experiment to ensure it's created
assert trainer.logger.experiment, 'missing experiment'
pkl_bytes = pickle.dumps(trainer)


@ -21,7 +21,7 @@ def test_amp_single_gpu(tmpdir, backend):
max_epochs=1,
gpus=1,
distributed_backend=backend,
precision=16
precision=16,
)
model = EvalModelTemplate()
@ -100,6 +100,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
# fit model
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
gpus=[0],
distributed_backend='ddp',


@ -24,6 +24,7 @@ def test_cpu_slurm_save_load(tmpdir):
# fit model
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
logger=logger,
limit_train_batches=0.2,
@ -54,13 +55,14 @@ def test_cpu_slurm_save_load(tmpdir):
# test HPC saving
# simulate snapshot on slurm
saved_filepath = trainer.hpc_save(tmpdir, logger)
saved_filepath = trainer.hpc_save(trainer.weights_save_path, logger)
assert os.path.exists(saved_filepath)
# new logger file to get meta
logger = tutils.get_default_logger(tmpdir, version=version)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
logger=logger,
checkpoint_callback=ModelCheckpoint(tmpdir),
@ -212,6 +214,7 @@ def test_running_test_no_val(tmpdir):
# fit model
trainer = Trainer(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=1,
limit_train_batches=0.4,


@ -84,6 +84,7 @@ def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
logger = OnlyMetricsListLogger()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=3,
logger=logger,
track_grad_norm=norm_type,


@ -8,7 +8,7 @@ from tests.base import EvalModelTemplate
@pytest.mark.parametrize('max_steps', [1, 2, 3])
def test_on_before_zero_grad_called(max_steps):
def test_on_before_zero_grad_called(tmpdir, max_steps):
class CurrentTestModel(EvalModelTemplate):
on_before_zero_grad_called = 0
@ -19,7 +19,9 @@ def test_on_before_zero_grad_called(max_steps):
model = CurrentTestModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=max_steps,
max_epochs=2,
num_sanity_val_steps=5,
)
assert 0 == model.on_before_zero_grad_called


@ -154,6 +154,7 @@ def test_dp_resume(tmpdir):
max_epochs=1,
gpus=2,
distributed_backend='dp',
default_root_dir=tmpdir,
)
# get logger


@ -199,7 +199,7 @@ def test_all_dataloaders_passed_to_fit(tmpdir, ckpt_path):
default_root_dir=tmpdir,
max_epochs=1,
limit_val_batches=0.1,
limit_train_batches=0.2
limit_train_batches=0.2,
)
fit_options = dict(train_dataloader=model.dataloader(train=True),
val_dataloaders=model.dataloader(train=False))
@ -401,7 +401,7 @@ def test_inf_train_dataloader(tmpdir, check_interval):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_check_interval=check_interval
val_check_interval=check_interval,
)
result = trainer.fit(model)
# verify training completed
@ -440,7 +440,7 @@ def test_error_on_zero_len_dataloader(tmpdir):
max_epochs=1,
limit_train_batches=0.1,
limit_val_batches=0.1,
limit_test_batches=0.1
limit_test_batches=0.1,
)
trainer.fit(model)
@ -534,7 +534,7 @@ def test_dataloader_reinit_for_subclass():
@pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
def test_batch_size_smaller_than_num_gpus():
def test_batch_size_smaller_than_num_gpus(tmpdir):
# we need at least 3 gpus for this test
num_gpus = 3
batch_size = 3
@ -572,6 +572,7 @@ def test_batch_size_smaller_than_num_gpus():
model = CurrentTestModel(**hparams)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=0.1,
limit_val_batches=0,


@ -57,7 +57,7 @@ def test_trainer_reset_correctly(tmpdir):
changed_attributes = ['callbacks', 'logger', 'max_steps', 'auto_lr_find',
'early_stop_callback', 'accumulate_grad_batches',
'enable_early_stop', 'checkpoint_callback']
'checkpoint_callback']
attributes_before = {}
for ca in changed_attributes:
attributes_before[ca] = getattr(trainer, ca)


@ -36,9 +36,10 @@ def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
logger = tutils.get_default_logger(tmpdir)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
logger=logger,
checkpoint_callback=ModelCheckpoint(tmpdir)
checkpoint_callback=ModelCheckpoint(tmpdir),
)
# fit model
result = trainer.fit(model)
@ -77,9 +78,10 @@ def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
# fit model
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
logger=logger,
checkpoint_callback=ModelCheckpoint(tmpdir)
checkpoint_callback=ModelCheckpoint(tmpdir),
)
result = trainer.fit(model)
@ -297,8 +299,9 @@ def test_model_checkpoint_only_weights(tmpdir):
model = EvalModelTemplate()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
checkpoint_callback=ModelCheckpoint(tmpdir, save_weights_only=True)
checkpoint_callback=ModelCheckpoint(tmpdir, save_weights_only=True),
)
# fit model
result = trainer.fit(model)
@ -469,7 +472,7 @@ def test_trainer_min_steps_and_epochs(tmpdir):
early_stop_callback=EarlyStopping(monitor='val_loss', min_delta=1.0),
val_check_interval=2,
min_epochs=1,
max_epochs=2
max_epochs=7
)
# define less min steps than 1 epoch
@ -592,7 +595,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
assert loaded_checkpoint_path == ckpt_path
def test_disabled_validation():
def test_disabled_validation(tmpdir):
"""Verify that `limit_val_batches=0` disables the validation loop unless `fast_dev_run=True`."""
class CurrentModel(EvalModelTemplate):
@ -612,6 +615,7 @@ def test_disabled_validation():
model = CurrentModel(**hparams)
trainer_options = dict(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=2,
limit_train_batches=0.4,
@ -664,7 +668,7 @@ def test_nan_loss_detection(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=(model.test_batch_inf_loss + 1),
terminate_on_nan=True
terminate_on_nan=True,
)
with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'):
@ -689,7 +693,7 @@ def test_nan_params_detection(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=(model.test_batch_nan + 1),
terminate_on_nan=True
terminate_on_nan=True,
)
with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'):
@ -757,7 +761,7 @@ def test_gradient_clipping(tmpdir):
max_steps=1,
max_epochs=1,
gradient_clip_val=1.0,
default_root_dir=tmpdir
default_root_dir=tmpdir,
)
# for the test
@ -944,7 +948,7 @@ def test_trainer_omegaconf(trainer_params):
def test_trainer_pickle(tmpdir):
trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir
default_root_dir=tmpdir,
)
pickle.dumps(trainer)
cloudpickle.dumps(trainer)


@ -11,10 +11,10 @@ import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
@mock.patch('argparse.ArgumentParser.parse_args',
return_value=Namespace(**Trainer.default_attributes()))
def test_default_args(tmpdir):
@mock.patch('argparse.ArgumentParser.parse_args')
def test_default_args(mock_argparse, tmpdir):
"""Tests default argument parser for Trainer"""
mock_argparse.return_value = Namespace(**Trainer.default_attributes())
# logger file to get meta
logger = tutils.get_default_logger(tmpdir)


@ -2,7 +2,7 @@ from pytorch_lightning import Trainer
from tests.base.deterministic_model import DeterministicModel
def test_trainingstep_dict(tmpdir):
def test_training_step_dict(tmpdir):
"""
Tests that only training_step can be used
"""
@ -10,7 +10,11 @@ def test_trainingstep_dict(tmpdir):
model.training_step = model.training_step_dict_return
model.val_dataloader = None
trainer = Trainer(fast_dev_run=True, weights_summary=None)
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
weights_summary=None,
)
trainer.fit(model)
# make sure correct steps were called
@ -74,6 +78,7 @@ def test_full_training_loop_dict(tmpdir):
model.val_dataloader = None
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
weights_summary=None,
)
@ -112,10 +117,7 @@ def test_train_step_epoch_end(tmpdir):
model.training_epoch_end = model.training_epoch_end_dict
model.val_dataloader = None
trainer = Trainer(
max_epochs=1,
weights_summary=None,
)
trainer = Trainer(max_epochs=1, weights_summary=None)
trainer.fit(model)
# make sure correct steps were called


@ -118,7 +118,7 @@ def test_model_reset_correctly(tmpdir):
# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1
max_epochs=1,
)
before_state_dict = model.state_dict()
@ -141,7 +141,7 @@ def test_trainer_reset_correctly(tmpdir):
# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1
max_epochs=1,
)
changed_attributes = ['max_steps',
@ -150,7 +150,6 @@ def test_trainer_reset_correctly(tmpdir):
'callbacks',
'checkpoint_callback',
'early_stop_callback',
'enable_early_stop',
'limit_train_batches']
attributes_before = {}
@ -224,7 +223,7 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
max_epochs=1,
limit_val_batches=0.1,
limit_train_batches=0.2,
auto_scale_batch_size='power'
auto_scale_batch_size='power',
)
fit_options = dict(train_dataloader=model.dataloader(train=True))