* add state_dict for early stopping
* move best attr after monitor_op defined
* improve early stopping and model checkpoint callbacks
* fix formatting
* fix attr init order
* clean up setting of default_root_dir attr
* logger needs default root dir set first
* reorg trainer init
* remove direct references to checkpoint callback
* more fixes
* more bugfixes
* run callbacks at epoch end
* update tests to use on epoch end
* PR cleanup
* address failing tests
* refactor for homogeneity
* fix merge conflict
* separate tests
* tests for early stopping bug regressions
* small fixes
* revert model checkpoint change
* typo fix
* fix tests
* update train loop
* cannot pass an int as default_save_path
* refactor log message
* fix test case
* appease the linter
* fix some doctests
* move config to callback
* fixes from rebase
* fixes from rebase
* chlog
* docs
* reformat
* formatting
* fix
* fix
* fixes from rebase
* add new test for patience
* Update pytorch_lightning/callbacks/model_checkpoint.py (Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>)
* Update pytorch_lightning/callbacks/model_checkpoint.py (Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>)
* Update tests/callbacks/test_early_stopping.py (Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>)
* fix formatting
* remove enable_early_stop attribute
* fix test with new epoch indexing
* fix progress bar totals
* fix off by one error (see #2289) epoch starts at 0 now
* added missing imports
* fix hpc_save folderpath
* fix formatting
* fix tests
* small fixes from a rebase
* fix
* tmpdir
* tmpdir
* tmpdir
* wandb
* fix merge conflict
* add back evaluation after training
* test_resume_early_stopping_from_checkpoint TODO
* undo the horovod check
* update changelog
* remove a duplicate test from merge error
* try fix dp_resume test
* add the logger fix from master
* try remove default_root_dir
* try mocking numpy
* try import numpy in docs test
* fix wandb test
* pep 8 fix
* skip if no amp
* dont mock when doctesting
* install extra
* fix the resume ES test
* undo conf.py changes
* revert remove comet pickle from test
* Update CHANGELOG.md (Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>)
* Update weights_loading.rst
* Update weights_loading.rst
* Update weights_loading.rst
* renamed flag
* renamed flag
* revert the None check in logger experiment name/version
* add the old comments
* _experiment
* test chckpointing on DDP
* skip the ddp test on windows
* cloudpickle
* renamed flag
* renamed flag
* parentheses for clarity
* apply suggestion max epochs (Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>)

Co-authored-by: Jeremy Jordan <jtjordan@ncsu.edu>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
parent 1e16681693
commit 25ee51bc57
@@ -42,6 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed loading model with kwargs ([#2387](https://github.com/PyTorchLightning/pytorch-lightning/pull/2387))
- Fixed several issues with early stopping and checkpoint callbacks ([#1504](https://github.com/PyTorchLightning/pytorch-lightning/pull/1504), [#2391](https://github.com/PyTorchLightning/pytorch-lightning/pull/2391))
- Fixed loading past checkpoints from v0.7.x ([#2405](https://github.com/PyTorchLightning/pytorch-lightning/pull/2405))
- Fixed loading model without arguments ([#2403](https://github.com/PyTorchLightning/pytorch-lightning/pull/2403))
@@ -48,6 +48,19 @@ We successfully extended functionality without polluting our super clean

----------------

Best Practices
==============

1. Callbacks should be isolated in their functionality. Your callback should not rely on the
   behavior of other callbacks in order to work properly.
2. Do not manually call methods from the callback. The callbacks are designed to be
   invoked at specific times during training. Directly calling methods (eg. `on_validation_end`)
   is strongly discouraged.
3. Whenever possible, your callbacks should not depend on the order in which they are executed.

---------

.. automodule:: pytorch_lightning.callbacks.base
    :noindex:
    :exclude-members:
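A minimal sketch of a callback written in the spirit of these practices (illustrative only, not part of the diff; it uses only the `Callback` base class and the `trainer.callback_metrics` dict that appear elsewhere in this commit):

from pytorch_lightning.callbacks.base import Callback


class ValLossHistory(Callback):
    """Records `val_loss` after each validation run; it neither calls nor depends on other callbacks."""

    def __init__(self):
        self.history = []

    def on_validation_end(self, trainer, pl_module):
        # read only from the trainer's metrics; never invoke another callback's hooks directly
        val_loss = trainer.callback_metrics.get('val_loss')
        if val_loss is not None:
            self.history.append(float(val_loss))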
@@ -92,6 +92,7 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.

.. testcode::

    from pytorch_lightning.loggers import NeptuneLogger

    neptune_logger = NeptuneLogger(
        api_key='ANONYMOUS',  # replace with your own
        project_name='shared/pytorch-lightning-integration',

@@ -193,7 +194,7 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.

.. testcode::

    from pytorch_lightning.loggers import WandbLogger
    wandb_logger = WandbLogger()
    wandb_logger = WandbLogger(offline=True)
    trainer = Trainer(logger=wandb_logger)

The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your
@@ -29,7 +29,7 @@ Automatic saving

Checkpointing is enabled by default to the current working directory.
To change the checkpoint path pass in:

.. testcode::
.. code-block:: python

    trainer = Trainer(default_root_dir='/your/path/to/save/checkpoints')
@@ -5,6 +5,7 @@ Early Stopping

Monitor a validation metric and stop training when it stops improving.

"""
from copy import deepcopy

import numpy as np
import torch
@@ -58,7 +59,7 @@ class EarlyStopping(Callback):

        self.verbose = verbose
        self.strict = strict
        self.min_delta = min_delta
        self.wait = 0
        self.wait_count = 0
        self.stopped_epoch = 0
        self.mode = mode
@@ -76,12 +77,17 @@ class EarlyStopping(Callback):

            log.info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.')

        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

    def _validate_condition_metric(self, logs):
        """
        Checks that the condition metric for early stopping is good
        :param logs:
        :return:

        Args:
            logs: callback metrics from validation output

        Return:
            True if specified metric is available
        """
        monitor_val = logs.get(self.monitor)
        error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
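The two added lines can be read in isolation: for a 'min' monitor the comparison op is `torch.lt` and `min_delta` is negated, so a new value only counts as an improvement when it beats the best score by at least `min_delta`, and `best_score` starts at the worst possible value for the chosen mode. A standalone sketch of that arithmetic (illustrative, not library code):

import torch

def is_improvement(current, best, mode='min', min_delta=0.0):
    # mirrors the callback: torch.gt for 'max', torch.lt for 'min'; the delta sign follows the mode
    monitor_op = {'min': torch.lt, 'max': torch.gt}[mode]
    min_delta = min_delta if monitor_op == torch.gt else -min_delta
    return monitor_op(torch.tensor(current) - min_delta, torch.tensor(best))

assert is_improvement(0.40, 0.50, mode='min', min_delta=0.05)      # improved by more than min_delta
assert not is_improvement(0.48, 0.50, mode='min', min_delta=0.05)  # improved, but not by enough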
@@ -103,39 +109,48 @@ class EarlyStopping(Callback):

    def monitor_op(self):
        return self.mode_dict[self.mode]

    def on_train_start(self, trainer, pl_module):
        # Allow instances to be re-used
        self.wait = 0
        self.stopped_epoch = 0
        self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf
    def state_dict(self):
        return {
            'wait_count': self.wait_count,
            'stopped_epoch': self.stopped_epoch,
            'best_score': self.best_score,
            'patience': self.patience
        }

    def load_state_dict(self, state_dict):
        state_dict = deepcopy(state_dict)
        self.wait_count = state_dict['wait_count']
        self.stopped_epoch = state_dict['stopped_epoch']
        self.best_score = state_dict['best_score']
        self.patience = state_dict['patience']

    def on_sanity_check_end(self, trainer, pl_module):
        logs = trainer.callback_metrics
        self._validate_condition_metric(logs)

    def on_validation_end(self, trainer, pl_module):
        return self._run_early_stopping_check(trainer, pl_module)
        self._run_early_stopping_check(trainer, pl_module)

    def _run_early_stopping_check(self, trainer, pl_module):
        logs = trainer.callback_metrics
        stop_training = False
        if not self._validate_condition_metric(logs):
            return stop_training
            return  # short circuit if metric not present

        current = logs.get(self.monitor)
        if not isinstance(current, torch.Tensor):
            current = torch.tensor(current)

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        if self.monitor_op(current - self.min_delta, self.best_score):
            self.best_score = current
            self.wait_count = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
            self.wait_count += 1
            if self.wait_count >= self.patience:
                self.stopped_epoch = trainer.current_epoch
                stop_training = True
                self.on_train_end(trainer, pl_module)

        return stop_training
                trainer.should_stop = True

    def on_train_end(self, trainer, pl_module):
        if self.stopped_epoch > 0 and self.verbose > 0:
            rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
                           ' but will start from "0" in v0.8.0.', DeprecationWarning)
            log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
            log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping triggered.')
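With `state_dict`/`load_state_dict` on the callback, early-stopping progress (wait count, stopped epoch, best score, patience) now round-trips through a plain dict; the trainer-IO changes later in this diff simply store and reload that dict. A small usage sketch (illustrative):

import torch
from pytorch_lightning.callbacks import EarlyStopping

es = EarlyStopping(monitor='val_loss', patience=3)
es.best_score = torch.tensor(0.42)
es.wait_count = 2

state = es.state_dict()
# state holds the keys shown above: wait_count, stopped_epoch, best_score, patience

restored = EarlyStopping(monitor='val_loss')
restored.load_state_dict(state)
assert restored.wait_count == 2 and restored.patience == 3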
@@ -226,6 +226,41 @@ class ModelCheckpoint(Callback):

        filepath = os.path.join(self.dirpath, self.prefix + filename + str_ver + '.ckpt')
        return filepath

    def on_train_start(self, trainer, pl_module):
        """
        Determine model checkpoint save directory at runtime. References attributes from the
        Trainer's logger to determine where to save checkpoints.
        """
        if self.dirpath is not None:
            return  # short circuit

        self.filename = '{epoch}'

        if trainer.logger is not None and trainer.logger.experiment is not None:
            # weights_save_path overrides anything
            if getattr(trainer, 'weights_save_path', None) is not None:
                save_dir = trainer.weights_save_path
            else:
                save_dir = (getattr(trainer.logger, 'save_dir', None)
                            or getattr(trainer.logger, '_save_dir', None)
                            or trainer.default_root_dir)

            version = trainer.logger.version if isinstance(
                trainer.logger.version, str) else f'version_{trainer.logger.version}'
            ckpt_path = os.path.join(
                save_dir,
                trainer.logger.name,
                version,
                "checkpoints"
            )
        else:
            ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")

        self.dirpath = ckpt_path
        os.makedirs(self.dirpath, exist_ok=True)
        trainer.ckpt_path = ckpt_path
        trainer.weights_save_path = ckpt_path

    @rank_zero_only
    def on_validation_end(self, trainer, pl_module):
        # only run on main process
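The directory resolution in `on_train_start` can be summarized as a pure function of the trainer attributes it reads. The sketch below (illustrative, not the library implementation) restates the priority order: an explicit `dirpath` wins, then `weights_save_path`, then the logger's save dir, then `default_root_dir`.

import os

def resolve_ckpt_dir(dirpath, weights_save_path, logger_save_dir, default_root_dir,
                     logger_name=None, logger_version=None, has_logger=True):
    """Sketch of the priority order used by ModelCheckpoint.on_train_start."""
    if dirpath is not None:
        return dirpath  # a user-provided path short-circuits everything
    if has_logger:
        save_dir = weights_save_path or logger_save_dir or default_root_dir
        version = logger_version if isinstance(logger_version, str) else f'version_{logger_version}'
        return os.path.join(save_dir, logger_name, version, 'checkpoints')
    return os.path.join(default_root_dir, 'checkpoints')

# e.g. with a TensorBoard-style logger named 'default', version 1:
assert resolve_ckpt_dir(None, None, '/tmp/logs', '/tmp', 'default', 1) == '/tmp/logs/default/version_1/checkpoints'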
@@ -131,12 +131,12 @@ class WandbLogger(LightningLoggerBase):

        self.experiment.log({'global_step': step, **metrics} if step is not None else metrics)

    @property
    def name(self) -> str:
    def name(self) -> Optional[str]:
        # don't create an experiment if we don't have one
        name = self._experiment.project_name() if self._experiment else None
        return name

    @property
    def version(self) -> str:
    def version(self) -> Optional[str]:
        # don't create an experiment if we don't have one
        return self._experiment.id if self._experiment else None
@@ -32,79 +32,47 @@ class TrainerCallbackConfigMixin(ABC):

    def is_overridden(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    def configure_checkpoint_callback(self):
    def configure_checkpoint_callback(self, checkpoint_callback):
        """
        Weight path set in this priority:
        Checkpoint_callback's path (if passed in).
        User provided weights_saved_path
        Otherwise use os.getcwd()
        """
        ckpt_path = self.default_root_dir
        if self.checkpoint_callback:
            # init a default one
            if self.logger is not None and self.logger.experiment is not None:
                save_dir = (getattr(self.logger, 'save_dir', None) or
                            getattr(self.logger, '_save_dir', None) or
                            self.default_root_dir)

                # weights_save_path overrides anything
                if self.weights_save_path is not None:
                    save_dir = self.weights_save_path

                version = self.logger.version if isinstance(
                    self.logger.version, str) else f'version_{self.logger.version}'
                ckpt_path = os.path.join(save_dir, self.logger.name, version, "checkpoints")
            else:
                ckpt_path = os.path.join(self.default_root_dir, "checkpoints")

        if checkpoint_callback is True:
            # when no val step is defined, use 'loss' otherwise 'val_loss'
            train_step_only = not self.is_overridden('validation_step')
            monitor_key = 'loss' if train_step_only else 'val_loss'
            checkpoint_callback = ModelCheckpoint(
                filepath=None,
                monitor=monitor_key
            )
        elif checkpoint_callback is False:
            checkpoint_callback = None

        if self.checkpoint_callback is True:
            os.makedirs(ckpt_path, exist_ok=True)
            self.checkpoint_callback = ModelCheckpoint(
                filepath=ckpt_path,
                monitor=monitor_key
            )
        # If user specified None in filepath, override with runtime default
        elif isinstance(self.checkpoint_callback, ModelCheckpoint) \
                and self.checkpoint_callback.dirpath is None:
            self.checkpoint_callback.dirpath = ckpt_path
            self.checkpoint_callback.filename = '{epoch}'
            os.makedirs(self.checkpoint_callback.dirpath, exist_ok=True)
        elif self.checkpoint_callback is False:
            self.checkpoint_callback = None

        self.ckpt_path = ckpt_path

        if self.checkpoint_callback:
            # set the path for the callbacks
            self.checkpoint_callback.save_function = self.save_checkpoint

            # if checkpoint callback used, then override the weights path
            self.weights_save_path = self.checkpoint_callback.dirpath
        if checkpoint_callback:
            checkpoint_callback.save_function = self.save_checkpoint

        # if weights_save_path is still none here, set to current working dir
        if self.weights_save_path is None:
            self.weights_save_path = self.default_root_dir

        return checkpoint_callback

    def configure_early_stopping(self, early_stop_callback):
        if early_stop_callback is True or None:
            self.early_stop_callback = EarlyStopping(
            early_stop_callback = EarlyStopping(
                monitor='val_loss',
                patience=3,
                strict=True,
                verbose=True,
                mode='min'
            )
            self.enable_early_stop = True
        elif not early_stop_callback:
            self.early_stop_callback = None
            self.enable_early_stop = False
            early_stop_callback = None
        else:
            self.early_stop_callback = early_stop_callback
            self.enable_early_stop = True
            early_stop_callback = early_stop_callback
        return early_stop_callback

    def configure_progress_bar(self, refresh_rate=1, process_position=0):
        progress_bars = [c for c in self.callbacks if isinstance(c, ProgressBarBase)]
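Both configuration helpers now return the callback (or `None`) instead of mutating trainer attributes such as the removed `enable_early_stop`. A condensed sketch of the early-stopping flag handling (illustrative; it copies the branching above verbatim, including the `is True or None` expression, which Python parses as `(x is True) or None`):

from pytorch_lightning.callbacks import EarlyStopping


def normalize_early_stop_flag(early_stop_callback):
    # sketch of configure_early_stopping's mapping:
    # True -> default callback, falsy -> disabled, an EarlyStopping instance -> used as-is
    if early_stop_callback is True or None:
        return EarlyStopping(monitor='val_loss', patience=3, strict=True, verbose=True, mode='min')
    elif not early_stop_callback:
        return None
    return early_stop_callback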
@@ -172,7 +172,6 @@ class TrainerDDPMixin(ABC):

    num_gpu_nodes: int
    gpus: List[int]
    logger: Union[LightningLoggerBase, bool]
    checkpoint_callback: Union[ModelCheckpoint, bool]
    data_parallel_device_ids: ...
    distributed_backend: Optional[str]
    amp_level: str
@@ -163,7 +163,6 @@ class TrainerLRFinderMixin(ABC):

        # Disable standard checkpoint & early stopping
        self.checkpoint_callback = False
        self.early_stop_callback = None
        self.enable_early_stop = False

        # Required for saving the model
        self.optimizers, self.schedulers = [], [],

@@ -215,7 +214,6 @@ class TrainerLRFinderMixin(ABC):

            'max_steps': self.max_steps,
            'checkpoint_callback': self.checkpoint_callback,
            'early_stop_callback': self.early_stop_callback,
            'enable_early_stop': self.enable_early_stop,
            'configure_optimizers': model.configure_optimizers,
        }

@@ -226,7 +224,6 @@ class TrainerLRFinderMixin(ABC):

        self.max_steps = self.__dumped_params['max_steps']
        self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
        self.early_stop_callback = self.__dumped_params['early_stop_callback']
        self.enable_early_stop = self.__dumped_params['enable_early_stop']
        model.configure_optimizers = self.__dumped_params['configure_optimizers']
        del self.__dumped_params
@@ -328,9 +328,60 @@ class Trainer(

        if 'LOCAL_RANK' in os.environ:
            rank_zero_only.rank = os.environ['LOCAL_RANK']

        # Init callbacks
        # training bookeeping
        self.total_batch_idx = 0
        self.running_loss = TensorRunningAccum(window_length=20)
        self.batch_idx = 0
        self.progress_bar_metrics = {}
        self.callback_metrics = {}
        self.num_training_batches = 0
        self.num_val_batches = []
        self.num_test_batches = []
        self.train_dataloader = None
        self.test_dataloaders = None
        self.val_dataloaders = None

        # training state
        self.model = None
        self.testing = False
        self.disable_validation = False
        self.prepare_data_per_node = prepare_data_per_node
        self.lr_schedulers = []
        self.optimizers = None
        self.optimizer_frequencies = []
        self.global_step = 0
        self.current_epoch = 0
        self.interrupted = False
        self.should_stop = False

        # set default save path if user didn't provide one
        if default_root_dir is None:
            default_root_dir = os.getcwd()
        self.default_root_dir = default_root_dir

        self.configure_logger(logger)

        # init callbacks
        self.callbacks = callbacks or []

        # configure early stop callback
        # creates a default one if none passed in
        early_stop_callback = self.configure_early_stopping(early_stop_callback)
        if early_stop_callback:
            self.callbacks.append(early_stop_callback)

        # configure checkpoint callback
        # it is important that this is the last callback to run
        # pass through the required args to figure out defaults
        self.weights_save_path = weights_save_path
        checkpoint_callback = self.configure_checkpoint_callback(checkpoint_callback)
        if checkpoint_callback:
            self.callbacks.append(checkpoint_callback)

        # TODO refactor codebase (tests) to not directly reach into these callbacks
        self.checkpoint_callback = checkpoint_callback
        self.early_stop_callback = early_stop_callback

        self.on_init_start()

        # benchmarking
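The reorganized `__init__` builds the callback list in a fixed order: user callbacks first, then the (possibly default) early-stopping callback, and the checkpoint callback appended last so that it is the final callback to run. A compact sketch of that assembly (illustrative, using only names visible in this hunk):

def assemble_callbacks(user_callbacks, early_stop_callback, checkpoint_callback):
    callbacks = list(user_callbacks or [])
    if early_stop_callback:        # already normalized by configure_early_stopping
        callbacks.append(early_stop_callback)
    if checkpoint_callback:        # appended last on purpose: it must be the last callback to run
        callbacks.append(checkpoint_callback)
    return callbacks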
@@ -399,52 +450,11 @@ class Trainer(

            rank_zero_info('Running in fast_dev_run mode: will run a full train,'
                           ' val and test loop using a single batch')

        # set default save path if user didn't provide one
        self.default_root_dir = default_root_dir

        if self.default_root_dir is None:
            self.default_root_dir = os.getcwd()

        # training bookeeping
        self.total_batch_idx = 0
        self.running_loss = TensorRunningAccum(window_length=20)
        self.batch_idx = 0
        self.progress_bar_metrics = {}
        self.callback_metrics = {}
        self.num_val_batches = [0]
        self.num_training_batches = 0
        self.num_test_batches = [0]
        self.train_dataloader = None
        self.test_dataloaders = None
        self.val_dataloaders = None

        # training state
        self.model = None
        self.testing = False
        self.disable_validation = False
        self.lr_schedulers = []
        self.optimizers = None
        self.optimizer_frequencies = []
        self.global_step = 0
        self.current_epoch = 0
        self.interrupted = False

        # configure logger
        self.configure_logger(logger)

        # configure profiler
        if profiler is True:
            profiler = SimpleProfiler()
        self.profiler = profiler or PassThroughProfiler()

        # configure early stop callback
        # creates a default one if none passed in
        self.configure_early_stopping(early_stop_callback)

        # configure checkpoint callback
        self.checkpoint_callback = checkpoint_callback
        self.weights_save_path = weights_save_path

        # accumulated grads
        self.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)
@@ -1045,9 +1055,6 @@ class Trainer(

        # if cluster resets state, the model will update with the saved weights
        self.model = model

        # set up checkpoint callback
        self.configure_checkpoint_callback()

        # restore training and model before hpc call
        self.restore_weights(model)
@@ -1078,13 +1085,10 @@ class Trainer(

                max_batches,
                False)
            _, _, _, callback_metrics, _ = self.process_output(eval_results)
            self.callback_metrics = callback_metrics

            self.on_sanity_check_end()

            # verify that early stop has conditioned on a metric that exists
            if self.enable_early_stop:
                self.early_stop_callback._validate_condition_metric(callback_metrics)

        # clear cache before training
        if self.on_gpu and self.root_gpu is not None:
            # use context because of:
@@ -95,6 +95,7 @@ import torch.distributed as torch_distrib

import pytorch_lightning
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import (
    LightningDistributedDataParallel,
@@ -328,26 +329,32 @@ class TrainerIOMixin(ABC):

        }

        if not weights_only:
            if self.checkpoint_callback:

            # TODO support more generic way for callbacks to persist a state_dict in a checkpoint
            checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
            early_stopping_callbacks = [c for c in self.callbacks if isinstance(c, EarlyStopping)]

            if checkpoint_callbacks:
                # we add the official checkpoint callback to the end of the list
                # extra user provided callbacks will not be persisted yet
                checkpoint['checkpoint_callback_best_model_score'] = self.checkpoint_callback.best_model_score
                checkpoint['checkpoint_callback_best_model_path'] = self.checkpoint_callback.best_model_path

            if self.early_stop_callback:
                checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait
                checkpoint['early_stop_callback_patience'] = self.early_stop_callback.patience
            if early_stopping_callbacks and checkpoint_callbacks:
                # we add the official early stopping callback to the end of the list
                # extra user provided callbacks will not be persisted yet
                checkpoint['early_stop_callback_state_dict'] = early_stopping_callbacks[-1].state_dict()

            # save optimizers
            optimizer_states = []
            for i, optimizer in enumerate(self.optimizers):
                optimizer_states.append(optimizer.state_dict())

            checkpoint['optimizer_states'] = optimizer_states

            # save lr schedulers
            lr_schedulers = []
            for scheduler in self.lr_schedulers:
                lr_schedulers.append(scheduler['scheduler'].state_dict())

            checkpoint['lr_schedulers'] = lr_schedulers

            # save native amp scaling
@@ -405,21 +412,25 @@ class TrainerIOMixin(ABC):

                ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
            )

        if self.checkpoint_callback:
        # TODO support more generic way for callbacks to load callback state_dicts
        checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
        early_stopping_callbacks = [c for c in self.callbacks if isinstance(c, EarlyStopping)]

        if checkpoint_callbacks:
            if 'checkpoint_callback_best_model_score' in checkpoint:
                self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best_model_score']
                checkpoint_callbacks[-1].best_model_score = checkpoint['checkpoint_callback_best_model_score']
            else:
                # Old naming until version 0.7.6
                rank_zero_warn(
                    'Loading a checkpoint created with an old version of Lightning; '
                    'this will not be supported in the future.'
                )
                self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best']
            self.checkpoint_callback.best_model_path = checkpoint['checkpoint_callback_best_model_path']
                checkpoint_callbacks[-1].best_model_score = checkpoint['checkpoint_callback_best']
            checkpoint_callbacks[-1].best_model_path = checkpoint['checkpoint_callback_best_model_path']

        if self.early_stop_callback:
            self.early_stop_callback.wait = checkpoint['early_stop_callback_wait']
            self.early_stop_callback.patience = checkpoint['early_stop_callback_patience']
        if early_stopping_callbacks:
            state = checkpoint['early_stop_callback_state_dict']
            early_stopping_callbacks[-1].load_state_dict(state)

        self.global_step = checkpoint['global_step']
        self.current_epoch = checkpoint['epoch']
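Taken together, the two hunks above replace the ad-hoc `early_stop_callback_wait`/`early_stop_callback_patience` keys with a single `early_stop_callback_state_dict` entry produced by the callback itself, while `ModelCheckpoint` still stores `best_model_score`/`best_model_path`. A sketch of the resulting round trip (illustrative; only the keys shown in the diff are assumed):

def persist_callback_state(checkpoint, checkpoint_callbacks, early_stopping_callbacks):
    # mirrors dump_checkpoint: the "official" callback is the last one of its kind in trainer.callbacks
    if checkpoint_callbacks:
        checkpoint['checkpoint_callback_best_model_score'] = checkpoint_callbacks[-1].best_model_score
        checkpoint['checkpoint_callback_best_model_path'] = checkpoint_callbacks[-1].best_model_path
    if early_stopping_callbacks and checkpoint_callbacks:
        checkpoint['early_stop_callback_state_dict'] = early_stopping_callbacks[-1].state_dict()
    return checkpoint


def restore_callback_state(checkpoint, early_stopping_callbacks):
    # mirrors restore_training_state for the early-stopping part
    if early_stopping_callbacks:
        early_stopping_callbacks[-1].load_state_dict(checkpoint['early_stop_callback_state_dict'])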
@@ -144,8 +144,7 @@ in your model.

"""

import atexit
import signal
import subprocess
from abc import ABC, abstractmethod
from typing import Callable
from typing import Union, List

@@ -157,6 +156,7 @@ import torch.distributed as torch_distrib

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.trainer.supporters import TensorRunningAccum

@@ -164,7 +164,6 @@ from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.memory import recursive_detach
import subprocess

try:
    from apex import amp
@@ -212,7 +211,6 @@ class TrainerTrainLoopMixin(ABC):

    fast_dev_run: ...
    accumulation_scheduler: ...
    lr_schedulers: ...
    enable_early_stop: ...
    early_stop_callback: ...
    callback_metrics: ...
    logger: Union[LightningLoggerBase, bool]

@@ -239,7 +237,6 @@ class TrainerTrainLoopMixin(ABC):

    max_steps: int
    min_steps: int
    total_batch_idx: int
    checkpoint_callback: ...
    terminate_on_nan: bool
    tpu_id: int
    interactive_ddp_procs: ...
@@ -264,7 +261,7 @@ class TrainerTrainLoopMixin(ABC):

        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def run_evaluation(self, *args):
    def run_evaluation(self, *args, **kwargs):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
@@ -340,9 +337,6 @@ class TrainerTrainLoopMixin(ABC):

        with self.profiler.profile('on_train_start'):
            # callbacks
            self.on_train_start()
            # initialize early stop callback
            if self.early_stop_callback is not None:
                self.early_stop_callback.on_train_start(self, self.get_model())
            # model hooks
            model.on_train_start()
@@ -375,7 +369,7 @@ class TrainerTrainLoopMixin(ABC):

            # -----------------
            self.run_training_epoch()

            if self.max_steps and self.max_steps == self.global_step:
            if self.max_steps and self.max_steps <= self.global_step:
                self.run_training_teardown()
                return
@@ -386,19 +380,14 @@ class TrainerTrainLoopMixin(ABC):

            met_min_epochs = epoch >= self.min_epochs - 1
            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

            # TODO wrap this logic into the callback
            # DO NOT DELETE
            # early stopping as a (new Callback) class doesn't yet work because we have to know these
            # trainer flags including the current epoch stuff
            # all of this needs to go into the early stopping to clean up better
            if self.enable_early_stop:
            if self.should_stop:
                if (met_min_epochs and met_min_steps) or self.fast_dev_run:
                    should_stop = self.early_stop_callback.on_validation_end(self, self.get_model())
                    # stop training
                    stop = should_stop and met_min_epochs
                    if stop:
                        self.run_training_teardown()
                        return
                    self.run_training_teardown()
                    return
                else:
                    log.info('Trainer was signaled to stop but required minimum epochs'
                             f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                             ' not been met. Training will continue...')

        self.run_training_teardown()
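The replacement logic is simpler than it looks: `trainer.should_stop` is now just a flag (set by `EarlyStopping` or by a `-1` training-step signal), and the loop only honours it once the minimum epochs/steps are met, unless `fast_dev_run` is on. As a tiny boolean sketch (illustrative):

def stop_now(should_stop, epoch, global_step, min_epochs, min_steps, fast_dev_run=False):
    met_min_epochs = epoch >= min_epochs - 1
    met_min_steps = global_step >= min_steps if min_steps else True
    return should_stop and ((met_min_epochs and met_min_steps) or fast_dev_run)

assert stop_now(True, epoch=4, global_step=100, min_epochs=1, min_steps=None)
assert not stop_now(True, epoch=0, global_step=10, min_epochs=3, min_steps=None)  # keeps training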
@@ -444,6 +433,7 @@ class TrainerTrainLoopMixin(ABC):

        # bookkeeping
        epoch_output = []
        should_check_val = False

        # run epoch
        for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
@@ -470,22 +460,24 @@ class TrainerTrainLoopMixin(ABC):

            self.update_train_loop_lr_schedulers()

            # when returning -1 from train_step, we end epoch early
            early_stop_epoch = batch_output.signal == -1
            self.should_stop = batch_output.signal == -1

            # -----------------------------------------
            # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
            # -----------------------------------------
            should_check_val = self.check_validation_in_train_loop(batch_idx, early_stop_epoch, is_last_batch)
            should_check_val = self.should_check_val(batch_idx, is_last_batch)
            if self.fast_dev_run or should_check_val:
                self.run_evaluation(test_mode=False)

            # -----------------------------------------
            # SAVE LOGGERS (ie: Tensorboard, etc...)
            # -----------------------------------------
            self.save_loggers_in_training_loop(batch_idx, early_stop_epoch)
            self.save_loggers_in_training_loop(batch_idx)

            # -----------------------------------------
            # SAVE METRICS TO LOGGERS
            # -----------------------------------------
            self.save_train_loop_metrics_to_loggers(batch_idx, early_stop_epoch, batch_output)
            self.save_train_loop_metrics_to_loggers(batch_idx, batch_output)

            # progress global step according to grads progress
            self.increment_accumulated_grad_global_step()
@@ -497,7 +489,7 @@ class TrainerTrainLoopMixin(ABC):

            # end epoch early
            # stop when the flag is changed or we've gone past the amount
            # requested in the batches
            if early_stop_epoch or self.fast_dev_run:
            if self.fast_dev_run or self.should_stop:
                break

        # let ddp devices catch up when using horovod
@@ -506,13 +498,19 @@ class TrainerTrainLoopMixin(ABC):

        # process epoch outputs
        self.run_training_epoch_end(epoch_output)

        # when no val loop is present or fast-dev-run still need to call checkpoints
        if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
            self.call_checkpoint_callback()
        # checkpoint callback
        self.check_checkpoint_callback(should_check_val)

        # epoch end hook
        self.run_on_epoch_end_hook(model)

    def check_checkpoint_callback(self, should_check_val):
        # when no val loop is present or fast-dev-run still need to call checkpoints
        # TODO bake this logic into the checkpoint callback
        if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
            checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
            [c.on_validation_end(self, self.get_model()) for c in checkpoint_callbacks]

    def update_train_loop_lr_schedulers(self):
        if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
            # update lr
@@ -553,33 +551,28 @@ class TrainerTrainLoopMixin(ABC):

        self.global_step += 1
        self.total_batch_idx += 1

    def save_train_loop_metrics_to_loggers(self, batch_idx, early_stop_epoch, batch_output):
    def save_train_loop_metrics_to_loggers(self, batch_idx, batch_output):
        # when metrics should be logged
        should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
        should_log_metrics = batch_idx % self.row_log_interval == 0 or self.should_stop
        if should_log_metrics or self.fast_dev_run:
            # logs user requested information to logger
            self.log_metrics(batch_output.batch_log_metrics, batch_output.grad_norm_dic)

    def save_loggers_in_training_loop(self, batch_idx, early_stop_epoch):
    def save_loggers_in_training_loop(self, batch_idx):
        # when loggers should save to disk
        should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
        should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or self.should_stop
        if should_save_log or self.fast_dev_run:
            if self.is_global_zero and self.logger is not None:
                self.logger.save()

    def check_validation_in_train_loop(self, batch_idx, early_stop_epoch, is_last_batch):
    def should_check_val(self, batch_idx, is_last_batch):
        # decide if we should run validation
        is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
        can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
        can_check_val = not self.disable_validation and can_check_epoch
        should_check_val = is_val_check_batch or early_stop_epoch
        should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
        should_check_val = can_check_val and should_check_val

        # if we need to run validation, then also call the checkpoint callback
        if self.fast_dev_run or should_check_val:
            self.run_evaluation(test_mode=self.testing)
            self.call_checkpoint_callback()
        should_check_val = is_val_check_batch or self.should_stop
        is_last_batch_for_infinite_dataset = (is_last_batch and self.val_check_batch == float('inf'))
        should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset)

        return should_check_val
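`should_check_val` is now a pure decision and no longer triggers evaluation or checkpointing itself; the caller does that. A standalone restatement of the decision (illustrative):

def should_check_val(batch_idx, is_last_batch, val_check_batch, current_epoch,
                     check_val_every_n_epoch, disable_validation, should_stop):
    is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
    can_check_epoch = (current_epoch + 1) % check_val_every_n_epoch == 0
    can_check_val = not disable_validation and can_check_epoch
    check = is_val_check_batch or should_stop
    is_last_batch_for_infinite_dataset = is_last_batch and val_check_batch == float('inf')
    return can_check_val and (check or is_last_batch_for_infinite_dataset)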
@@ -984,10 +977,6 @@ class TrainerTrainLoopMixin(ABC):

        else:
            lr_scheduler['scheduler'].step()

    def call_checkpoint_callback(self):
        if self.checkpoint_callback is not None:
            self.checkpoint_callback.on_validation_end(self, self.get_model())


def _with_is_last(iterable):
    """Pass through values from the given iterable with an added boolean indicating if this is the last item.
@@ -188,7 +188,6 @@ class TrainerTrainingTricksMixin(ABC):

            'callbacks': self.callbacks,
            'checkpoint_callback': self.checkpoint_callback,
            'early_stop_callback': self.early_stop_callback,
            'enable_early_stop': self.enable_early_stop,
            'auto_scale_batch_size': self.auto_scale_batch_size,
            'limit_train_batches': self.limit_train_batches,
            'model': self.model,

@@ -202,7 +201,6 @@ class TrainerTrainingTricksMixin(ABC):

        self.callbacks = []  # not needed before full run
        self.checkpoint_callback = False  # required for saving
        self.early_stop_callback = None
        self.enable_early_stop = False
        self.limit_train_batches = 1.0
        self.optimizers, self.schedulers = [], []  # required for saving
        self.model = model  # required for saving

@@ -215,7 +213,6 @@ class TrainerTrainingTricksMixin(ABC):

        self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
        self.auto_scale_batch_size = self.__dumped_params['auto_scale_batch_size']
        self.early_stop_callback = self.__dumped_params['early_stop_callback']
        self.enable_early_stop = self.__dumped_params['enable_early_stop']
        self.limit_train_batches = self.__dumped_params['limit_train_batches']
        self.model = self.__dumped_params['model']
        del self.__dumped_params
@@ -160,6 +160,7 @@ def test_trainer_callback_system(tmpdir):

    test_callback = TestCallback()

    trainer_options = dict(
        default_root_dir=tmpdir,
        callbacks=[test_callback],
        max_epochs=1,
        limit_val_batches=0.1,
@@ -0,0 +1,133 @@

import pickle

import cloudpickle
import pytest

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from tests.base import EvalModelTemplate


def test_resume_early_stopping_from_checkpoint(tmpdir):
    """
    Prevent regressions to bugs:
    https://github.com/PyTorchLightning/pytorch-lightning/issues/1464
    https://github.com/PyTorchLightning/pytorch-lightning/issues/1463
    """

    class EarlyStoppingTestStore(EarlyStopping):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # cache the state for each epoch
            self.saved_states = []

        def on_validation_end(self, trainer, pl_module):
            super().on_validation_end(trainer, pl_module)
            self.saved_states.append(self.state_dict().copy())

    class EarlyStoppingTestRestore(EarlyStopping):
        def __init__(self, expected_state):
            super().__init__()
            self.expected_state = expected_state

        def on_train_start(self, trainer, pl_module):
            assert self.state_dict() == self.expected_state

    model = EvalModelTemplate()
    checkpoint_callback = ModelCheckpoint(save_top_k=1)
    early_stop_callback = EarlyStoppingTestStore()
    trainer = Trainer(
        default_root_dir=tmpdir,
        checkpoint_callback=checkpoint_callback,
        early_stop_callback=early_stop_callback,
        max_epochs=4,
    )
    trainer.fit(model)

    checkpoint_filepath = checkpoint_callback.kth_best_model
    # ensure state is persisted properly
    checkpoint = torch.load(checkpoint_filepath)
    # the checkpoint saves "epoch + 1"
    early_stop_callback_state = early_stop_callback.saved_states[checkpoint['epoch'] - 1]
    assert 4 == len(early_stop_callback.saved_states)
    assert checkpoint['early_stop_callback_state_dict'] == early_stop_callback_state

    # ensure state is reloaded properly (assertion in the callback)
    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state)
    new_trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        resume_from_checkpoint=checkpoint_filepath,
        early_stop_callback=early_stop_callback,
    )
    new_trainer.fit(model)


def test_early_stopping_no_extraneous_invocations(tmpdir):
    """Test to ensure that callback methods aren't being invoked outside of the callback handler."""
    class EarlyStoppingTestInvocations(EarlyStopping):
        def __init__(self, expected_count):
            super().__init__()
            self.count = 0
            self.expected_count = expected_count

        def on_validation_end(self, trainer, pl_module):
            self.count += 1

        def on_train_end(self, trainer, pl_module):
            assert self.count == self.expected_count

    model = EvalModelTemplate()
    expected_count = 4
    early_stop_callback = EarlyStoppingTestInvocations(expected_count)
    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=early_stop_callback,
        val_check_interval=1.0,
        max_epochs=expected_count,
    )
    trainer.fit(model)


@pytest.mark.parametrize('loss_values, patience, expected_stop_epoch', [
    ([6, 5, 5, 5, 5, 5], 3, 4),
    ([6, 5, 4, 4, 3, 3], 1, 3),
    ([6, 5, 6, 5, 5, 5], 3, 4),
])
def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_epoch):
    """Test to ensure that early stopping is not triggered before patience is exhausted."""

    class ModelOverrideValidationReturn(EvalModelTemplate):
        validation_return_values = torch.Tensor(loss_values)
        count = 0

        def validation_epoch_end(self, outputs):
            loss = self.validation_return_values[self.count]
            self.count += 1
            return {"test_val_loss": loss}

    model = ModelOverrideValidationReturn()
    early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=early_stop_callback,
        val_check_interval=1.0,
        num_sanity_val_steps=0,
        max_epochs=10,
    )
    trainer.fit(model)
    assert trainer.current_epoch == expected_stop_epoch


def test_pickling(tmpdir):
    early_stopping = EarlyStopping()

    early_stopping_pickled = pickle.dumps(early_stopping)
    early_stopping_loaded = pickle.loads(early_stopping_pickled)
    assert vars(early_stopping) == vars(early_stopping_loaded)

    early_stopping_pickled = cloudpickle.dumps(early_stopping)
    early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
    assert vars(early_stopping) == vars(early_stopping_loaded)
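The parametrized cases above can be checked by hand with the new `wait_count` semantics: the counter increments on every non-improving validation and training stops at the epoch where it reaches `patience` (epochs are 0-indexed after #2289). A small sketch reproducing that arithmetic for the listed cases (illustrative):

def expected_stop_epoch(loss_values, patience):
    """Simulate EarlyStopping's wait_count bookkeeping for a 'min' monitor (sketch)."""
    best, wait_count = float('inf'), 0
    for epoch, loss in enumerate(loss_values):
        if loss < best:
            best, wait_count = loss, 0
        else:
            wait_count += 1
            if wait_count >= patience:
                return epoch
    return None

assert expected_stop_epoch([6, 5, 5, 5, 5, 5], 3) == 4
assert expected_stop_epoch([6, 5, 4, 4, 3, 3], 1) == 3
assert expected_stop_epoch([6, 5, 6, 5, 5, 5], 3) == 4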
@@ -0,0 +1,107 @@

import os
import pickle
import platform
from pathlib import Path

import cloudpickle
import pytest

import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import EvalModelTemplate


@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
    """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """
    tutils.reset_seed()
    model = EvalModelTemplate()

    checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k)

    trainer = Trainer(
        default_root_dir=tmpdir,
        checkpoint_callback=checkpoint,
        overfit_pct=0.20,
        max_epochs=(save_top_k + 2),
    )
    trainer.fit(model)

    # These should be different if the dirpath has be overridden
    assert trainer.ckpt_path != trainer.default_root_dir


@pytest.mark.parametrize(
    'logger_version,expected',
    [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
)
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    """Test that "version_" prefix is only added when logger's version is an integer"""
    tutils.reset_seed()
    model = EvalModelTemplate()
    logger = TensorBoardLogger(str(tmpdir), version=logger_version)

    trainer = Trainer(
        default_root_dir=tmpdir,
        overfit_pct=0.2,
        max_epochs=5,
        logger=logger,
    )
    trainer.fit(model)

    ckpt_version = Path(trainer.ckpt_path).parent.name
    assert ckpt_version == expected


def test_pickling(tmpdir):
    ckpt = ModelCheckpoint(tmpdir)

    ckpt_pickled = pickle.dumps(ckpt)
    ckpt_loaded = pickle.loads(ckpt_pickled)
    assert vars(ckpt) == vars(ckpt_loaded)

    ckpt_pickled = cloudpickle.dumps(ckpt)
    ckpt_loaded = cloudpickle.loads(ckpt_pickled)
    assert vars(ckpt) == vars(ckpt_loaded)


class ModelCheckpointTestInvocations(ModelCheckpoint):
    # this class has to be defined outside the test function, otherwise we get pickle error
    # due to the way ddp process is launched

    def __init__(self, expected_count, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.count = 0
        self.expected_count = expected_count

    def _save_model(self, filepath):
        # make sure we don't save twice
        assert not os.path.isfile(filepath)
        self.count += 1
        super()._save_model(filepath)

    def on_train_end(self, trainer, pl_module):
        super().on_train_end(trainer, pl_module)
        # on rank 0 we expect the saved files and on all others no saves
        assert (trainer.global_rank == 0 and self.count == self.expected_count) \
            or (trainer.global_rank > 0 and self.count == 0)


@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
def test_model_checkpoint_no_extraneous_invocations(tmpdir):
    """Test to ensure that the model callback saves the checkpoints only once in distributed mode."""
    model = EvalModelTemplate()
    num_epochs = 4
    model_checkpoint = ModelCheckpointTestInvocations(expected_count=num_epochs, save_top_k=-1)
    trainer = Trainer(
        distributed_backend='ddp_cpu',
        num_processes=2,
        default_root_dir=tmpdir,
        early_stop_callback=False,
        checkpoint_callback=model_checkpoint,
        max_epochs=num_epochs,
    )
    result = trainer.fit(model)
    assert 1 == result
@@ -13,10 +13,11 @@ from tests.base import EvalModelTemplate

    ([ProgressBar(refresh_rate=2)], 0),
    ([ProgressBar(refresh_rate=2)], 1),
])
def test_progress_bar_on(callbacks, refresh_rate):
def test_progress_bar_on(tmpdir, callbacks, refresh_rate):
    """Test different ways the progress bar can be turned on."""

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=callbacks,
        progress_bar_refresh_rate=refresh_rate,
        max_epochs=1,

@@ -34,10 +35,11 @@ def test_progress_bar_on(callbacks, refresh_rate):

    ([], False),
    ([ModelCheckpoint('../trainer')], 0),
])
def test_progress_bar_off(callbacks, refresh_rate):
def test_progress_bar_off(tmpdir, callbacks, refresh_rate):
    """Test different ways the progress bar can be turned off."""

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=callbacks,
        progress_bar_refresh_rate=refresh_rate,
    )

@@ -54,12 +56,13 @@ def test_progress_bar_misconfiguration():

        Trainer(callbacks=callbacks)


def test_progress_bar_totals():
def test_progress_bar_totals(tmpdir):
    """Test that the progress finishes with the correct total steps processed."""

    model = EvalModelTemplate()

    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=1,
        limit_val_batches=1.0,
        max_epochs=1,

@@ -105,10 +108,11 @@ def test_progress_bar_totals():

    assert bar.test_batch_idx == k


def test_progress_bar_fast_dev_run():
def test_progress_bar_fast_dev_run(tmpdir):
    model = EvalModelTemplate()

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
    )

@@ -136,7 +140,7 @@ def test_progress_bar_fast_dev_run():

@pytest.mark.parametrize('refresh_rate', [0, 1, 50])
def test_progress_bar_progress_refresh(refresh_rate):
def test_progress_bar_progress_refresh(tmpdir, refresh_rate):
    """Test that the three progress bars get correctly updated when using different refresh rates."""

    model = EvalModelTemplate()

@@ -172,6 +176,7 @@ def test_progress_bar_progress_refresh(refresh_rate):

    progress_bar = CurrentProgressBar(refresh_rate=refresh_rate)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[progress_bar],
        progress_bar_refresh_rate=101,  # should not matter if custom callback provided
        limit_train_batches=1.0,
@@ -108,7 +108,11 @@ def test_multiple_loggers_pickle(tmpdir):

    logger1 = CustomLogger()
    logger2 = CustomLogger()

    trainer = Trainer(max_epochs=1, logger=[logger1, logger2])
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=[logger1, logger2],
    )
    pkl_bytes = pickle.dumps(trainer)
    trainer2 = pickle.loads(pkl_bytes)
    trainer2.logger.log_metrics({"acc": 1.0}, 0)
@@ -1,12 +1,12 @@

import os
import pickle
from unittest.mock import patch
from unittest import mock

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger


@patch('pytorch_lightning.loggers.wandb.wandb')
@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_logger(wandb):
    """Verify that basic functionality of wandb logger works.
    Wandb doesn't work well with pytest so we have to mock it out here."""

@@ -29,8 +29,8 @@ def test_wandb_logger(wandb):

    assert logger.version == wandb.init().id


@patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_pickle(wandb):
@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_pickle(wandb, tmpdir):
    """Verify that pickling trainer with wandb logger works.

    Wandb doesn't work well with pytest so we have to mock it out here.

@@ -38,11 +38,18 @@ def test_wandb_pickle(wandb):

    class Experiment:
        id = 'the_id'

        def project_name(self):
            return 'the_project_name'

    wandb.init.return_value = Experiment()

    logger = WandbLogger(id='the_id', offline=True)

    trainer = Trainer(max_epochs=1, logger=logger)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
    )
    # Access the experiment to ensure it's created
    assert trainer.logger.experiment, 'missing experiment'
    pkl_bytes = pickle.dumps(trainer)
@@ -21,7 +21,7 @@ def test_amp_single_gpu(tmpdir, backend):

        max_epochs=1,
        gpus=1,
        distributed_backend=backend,
        precision=16
        precision=16,
    )

    model = EvalModelTemplate()

@@ -100,6 +100,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gpus=[0],
        distributed_backend='ddp',
@@ -24,6 +24,7 @@ def test_cpu_slurm_save_load(tmpdir):

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        limit_train_batches=0.2,

@@ -54,13 +55,14 @@ def test_cpu_slurm_save_load(tmpdir):

    # test HPC saving
    # simulate snapshot on slurm
    saved_filepath = trainer.hpc_save(tmpdir, logger)
    saved_filepath = trainer.hpc_save(trainer.weights_save_path, logger)
    assert os.path.exists(saved_filepath)

    # new logger file to get meta
    logger = tutils.get_default_logger(tmpdir, version=version)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),

@@ -212,6 +214,7 @@ def test_running_test_no_val(tmpdir):

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=1,
        limit_train_batches=0.4,
@@ -84,6 +84,7 @@ def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):

    logger = OnlyMetricsListLogger()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        logger=logger,
        track_grad_norm=norm_type,
@@ -8,7 +8,7 @@ from tests.base import EvalModelTemplate

@pytest.mark.parametrize('max_steps', [1, 2, 3])
def test_on_before_zero_grad_called(max_steps):
def test_on_before_zero_grad_called(tmpdir, max_steps):

    class CurrentTestModel(EvalModelTemplate):
        on_before_zero_grad_called = 0

@@ -19,7 +19,9 @@ def test_on_before_zero_grad_called(max_steps):

    model = CurrentTestModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=max_steps,
        max_epochs=2,
        num_sanity_val_steps=5,
    )
    assert 0 == model.on_before_zero_grad_called
@@ -154,6 +154,7 @@ def test_dp_resume(tmpdir):

        max_epochs=1,
        gpus=2,
        distributed_backend='dp',
        default_root_dir=tmpdir,
    )

    # get logger
@@ -199,7 +199,7 @@ def test_all_dataloaders_passed_to_fit(tmpdir, ckpt_path):

        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2
        limit_train_batches=0.2,
    )
    fit_options = dict(train_dataloader=model.dataloader(train=True),
                       val_dataloaders=model.dataloader(train=False))

@@ -401,7 +401,7 @@ def test_inf_train_dataloader(tmpdir, check_interval):

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_check_interval=check_interval
        val_check_interval=check_interval,
    )
    result = trainer.fit(model)
    # verify training completed

@@ -440,7 +440,7 @@ def test_error_on_zero_len_dataloader(tmpdir):

        max_epochs=1,
        limit_train_batches=0.1,
        limit_val_batches=0.1,
        limit_test_batches=0.1
        limit_test_batches=0.1,
    )
    trainer.fit(model)

@@ -534,7 +534,7 @@ def test_dataloader_reinit_for_subclass():

@pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
def test_batch_size_smaller_than_num_gpus():
def test_batch_size_smaller_than_num_gpus(tmpdir):
    # we need at least 3 gpus for this test
    num_gpus = 3
    batch_size = 3

@@ -572,6 +572,7 @@ def test_batch_size_smaller_than_num_gpus():

    model = CurrentTestModel(**hparams)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=0.1,
        limit_val_batches=0,
@@ -57,7 +57,7 @@ def test_trainer_reset_correctly(tmpdir):

    changed_attributes = ['callbacks', 'logger', 'max_steps', 'auto_lr_find',
                          'early_stop_callback', 'accumulate_grad_batches',
                          'enable_early_stop', 'checkpoint_callback']
                          'checkpoint_callback']
    attributes_before = {}
    for ca in changed_attributes:
        attributes_before[ca] = getattr(trainer, ca)
@@ -36,9 +36,10 @@ def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):

    logger = tutils.get_default_logger(tmpdir)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir)
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    # fit model
    result = trainer.fit(model)

@@ -77,9 +78,10 @@ def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir)
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    result = trainer.fit(model)

@@ -297,8 +299,9 @@ def test_model_checkpoint_only_weights(tmpdir):

    model = EvalModelTemplate()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        checkpoint_callback=ModelCheckpoint(tmpdir, save_weights_only=True)
        checkpoint_callback=ModelCheckpoint(tmpdir, save_weights_only=True),
    )
    # fit model
    result = trainer.fit(model)

@@ -469,7 +472,7 @@ def test_trainer_min_steps_and_epochs(tmpdir):

        early_stop_callback=EarlyStopping(monitor='val_loss', min_delta=1.0),
        val_check_interval=2,
        min_epochs=1,
        max_epochs=2
        max_epochs=7
    )

    # define less min steps than 1 epoch

@@ -592,7 +595,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):

    assert loaded_checkpoint_path == ckpt_path


def test_disabled_validation():
def test_disabled_validation(tmpdir):
    """Verify that `limit_val_batches=0` disables the validation loop unless `fast_dev_run=True`."""

    class CurrentModel(EvalModelTemplate):

@@ -612,6 +615,7 @@ def test_disabled_validation():

    model = CurrentModel(**hparams)

    trainer_options = dict(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=0.4,

@@ -664,7 +668,7 @@ def test_nan_loss_detection(tmpdir):

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=(model.test_batch_inf_loss + 1),
        terminate_on_nan=True
        terminate_on_nan=True,
    )

    with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'):

@@ -689,7 +693,7 @@ def test_nan_params_detection(tmpdir):

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=(model.test_batch_nan + 1),
        terminate_on_nan=True
        terminate_on_nan=True,
    )

    with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'):

@@ -757,7 +761,7 @@ def test_gradient_clipping(tmpdir):

        max_steps=1,
        max_epochs=1,
        gradient_clip_val=1.0,
        default_root_dir=tmpdir
        default_root_dir=tmpdir,
    )

    # for the test

@@ -944,7 +948,7 @@ def test_trainer_omegaconf(trainer_params):

def test_trainer_pickle(tmpdir):
    trainer = Trainer(
        max_epochs=1,
        default_root_dir=tmpdir
        default_root_dir=tmpdir,
    )
    pickle.dumps(trainer)
    cloudpickle.dumps(trainer)
@@ -11,10 +11,10 @@ import tests.base.develop_utils as tutils

from pytorch_lightning import Trainer


@mock.patch('argparse.ArgumentParser.parse_args',
            return_value=Namespace(**Trainer.default_attributes()))
def test_default_args(tmpdir):
@mock.patch('argparse.ArgumentParser.parse_args')
def test_default_args(mock_argparse, tmpdir):
    """Tests default argument parser for Trainer"""
    mock_argparse.return_value = Namespace(**Trainer.default_attributes())

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)
@@ -2,7 +2,7 @@ from pytorch_lightning import Trainer

from tests.base.deterministic_model import DeterministicModel


def test_trainingstep_dict(tmpdir):
def test_training_step_dict(tmpdir):
    """
    Tests that only training_step can be used
    """

@@ -10,7 +10,11 @@ def test_trainingstep_dict(tmpdir):

    model.training_step = model.training_step_dict_return
    model.val_dataloader = None

    trainer = Trainer(fast_dev_run=True, weights_summary=None)
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        weights_summary=None,
    )
    trainer.fit(model)

    # make sure correct steps were called

@@ -74,6 +78,7 @@ def test_full_training_loop_dict(tmpdir):

    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        weights_summary=None,
    )

@@ -112,10 +117,7 @@ def test_train_step_epoch_end(tmpdir):

    model.training_epoch_end = model.training_epoch_end_dict
    model.val_dataloader = None

    trainer = Trainer(
        max_epochs=1,
        weights_summary=None,
    )
    trainer = Trainer(max_epochs=1, weights_summary=None)
    trainer.fit(model)

    # make sure correct steps were called
@@ -118,7 +118,7 @@ def test_model_reset_correctly(tmpdir):

    # logger file to get meta
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1
        max_epochs=1,
    )

    before_state_dict = model.state_dict()

@@ -141,7 +141,7 @@ def test_trainer_reset_correctly(tmpdir):

    # logger file to get meta
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1
        max_epochs=1,
    )

    changed_attributes = ['max_steps',

@@ -150,7 +150,6 @@ def test_trainer_reset_correctly(tmpdir):

                          'callbacks',
                          'checkpoint_callback',
                          'early_stop_callback',
                          'enable_early_stop',
                          'limit_train_batches']

    attributes_before = {}

@@ -224,7 +223,7 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):

        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
        auto_scale_batch_size='power'
        auto_scale_batch_size='power',
    )
    fit_options = dict(train_dataloader=model.dataloader(train=True))