[WIP] Rename overfit_pct to overfit_batches (and fix) and val_percent_check and test_percent_check (and fix) (#2213)

* fixed percent check for val/test

* overfit_pct now uses train loaders for val and test and does not shuffle

* add on fit_start on fit_end hooks

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
William Falcon 2020-06-17 08:03:28 -04:00 committed by GitHub
parent 97dfd3a80a
commit 04c794ca72
26 changed files with 425 additions and 217 deletions

View File

@ -21,6 +21,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added overfit_batches, limit_xxx_batches flags (overfit now uses training set for all three) ([#2213](https://github.com/PyTorchLightning/pytorch-lightning/pull/2213))
- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
- Added Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488))
- Added metrics
* Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
* Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
@ -54,6 +58,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated
- Deprecated `overfit_pct`, `val_percent_check`, `test_percent_check` ([#2213](https://github.com/PyTorchLightning/pytorch-lightning/pull/2213))
- Deprecated `ModelCheckpoint`'s attributes `best` and `kth_best_model` ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))
- Dropped official support/testing for older PyTorch versions <1.3 ([#1917](https://github.com/PyTorchLightning/pytorch-lightning/pull/1917))

View File

@ -48,12 +48,19 @@ Make model overfit on subset of data
A good debugging technique is to take a tiny portion of your data (say 2 samples per class),
and try to get your model to overfit. If it can't, it's a sign it won't work with large datasets.
(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.overfit_pct`
(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.overfit_batches`
argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
.. testcode::
trainer = Trainer(overfit_pct=0.01)
# use only 1% of the training data (and reuse the same training DataLoader, with shuffle off, for val and test)
trainer = Trainer(overfit_batches=0.01)
# or overfit on a set number of batches
trainer = Trainer(overfit_batches=10)
With this flag, the train, val, and test sets will all be the same train set. We will also replace the sampler
in the training set to turn off shuffle for you.
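A rough, plain-PyTorch sketch of what this means under the hood (illustrative only, not Lightning's actual implementation): the shuffled training loader is rebuilt with a SequentialSampler so the val and test loops iterate the same ordered subset every epoch::

    import torch
    from torch.utils.data import DataLoader, TensorDataset, SequentialSampler

    dataset = TensorDataset(torch.randn(100, 8), torch.randn(100, 1))
    train_loader = DataLoader(dataset, batch_size=10, shuffle=True)

    # what overfit_batches does internally: same data, shuffling turned off
    eval_loader = DataLoader(dataset, batch_size=train_loader.batch_size,
                             sampler=SequentialSampler(dataset))

    assert isinstance(eval_loader.sampler, SequentialSampler)
    first_batch = next(iter(eval_loader))  # deterministic, repeatable batch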
Print a summary of your LightningModule
---------------------------------------

View File

@ -56,17 +56,17 @@ If you don't want to check 100% of the training/validation/test set (for debuggi
# DEFAULT
trainer = Trainer(
train_percent_check=1.0,
val_percent_check=1.0,
test_percent_check=1.0
limit_val_batches=1.0,
limit_test_batches=1.0
)
# check 10%, 20%, 30% only, respectively for training, validation and test set
trainer = Trainer(
train_percent_check=0.1,
val_percent_check=0.2,
test_percent_check=0.3
limit_val_batches=0.2,
limit_test_batches=0.3
)
.. note:: ``train_percent_check``, ``val_percent_check`` and ``test_percent_check`` will be overwritten by ``overfit_pct`` if ``overfit_pct`` > 0. ``val_percent_check`` will be ignored if ``fast_dev_run=True``.
.. note:: ``train_percent_check``, ``limit_val_batches`` and ``limit_test_batches`` will be overwritten by ``overfit_batches`` if ``overfit_batches`` > 0. ``limit_val_batches`` will be ignored if ``fast_dev_run=True``.
.. note:: If you set ``val_percent_check=0``, validation will be disabled.
.. note:: If you set ``limit_val_batches=0``, validation will be disabled.

View File

@ -98,6 +98,7 @@ class ProgressBarBase(Callback):
elif not self.trainer.disable_validation:
is_val_epoch = trainer.current_epoch % trainer.check_val_every_n_epoch == 0
total_val_batches = trainer.num_val_batches if is_val_epoch else 0
total_val_batches = sum(total_val_batches)
return total_val_batches
@property
@ -111,6 +112,7 @@ class ProgressBarBase(Callback):
total_test_batches = len(self.trainer.test_dataloaders)
else:
total_test_batches = self.trainer.num_test_batches
total_test_batches = sum(total_test_batches)
return total_test_batches
def disable(self):
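Since ``num_val_batches`` and ``num_test_batches`` are now lists with one entry per dataloader, the progress bar totals them with ``sum()``. A tiny illustration (the values are made up)::

    num_val_batches = [50, 20]   # e.g. two validation dataloaders
    is_val_epoch = True
    total_val_batches = sum(num_val_batches) if is_val_epoch else 0
    assert total_val_batches == 70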

View File

@ -433,6 +433,40 @@ Example::
# default used by the Trainer
trainer = Trainer(gradient_clip_val=0.0)
limit_test_batches
^^^^^^^^^^^^^^^^^^
How much of test dataset to check.
Example::
# default used by the Trainer
trainer = Trainer(limit_test_batches=1.0)
# run through only 25% of the test set each epoch
trainer = Trainer(limit_test_batches=0.25)
# run for only 10 batches
trainer = Trainer(limit_test_batches=10)
limit_val_batches
^^^^^^^^^^^^^^^^^
How much of validation dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.
Example::
# default used by the Trainer
trainer = Trainer(limit_val_batches=1.0)
# run through only 25% of the validation set each epoch
trainer = Trainer(limit_val_batches=0.25)
# run for only 10 batches
trainer = Trainer(limit_val_batches=10)
log_gpu_memory
^^^^^^^^^^^^^^
Options:
@ -652,29 +686,28 @@ Example::
overfit_pct
^^^^^^^^^^^
Uses this much data of all datasets (training, validation, test).
.. warning:: .. deprecated:: 0.8.0.
Use `overfit_batches` instead. Will be removed in 1.0.0.
overfit_batches
^^^^^^^^^^^^^^^
Uses this much data of the training set. It will use the same training set for validation and testing.
If the training dataloader has shuffle=True, Lightning will automatically disable it.
Useful for quickly debugging or trying to overfit on purpose.
Example::
# default used by the Trainer
trainer = Trainer(overfit_pct=0.0)
trainer = Trainer(overfit_batches=0.0)
# use only 1% of the train, test, val datasets
trainer = Trainer(overfit_pct=0.01)
# equivalent:
trainer = Trainer(
train_percent_check=0.01,
val_percent_check=0.01,
test_percent_check=0.01
)
See Also:
- `train_percent_check`_
- `val_percent_check`_
- `test_percent_check`_
# use only 1% of the train set (and use the train set for val and test)
trainer = Trainer(overfit_batches=0.01)
# overfit on 10 of the same batches
trainer = Trainer(overfit_batches=10)
precision
^^^^^^^^^
@ -829,39 +862,7 @@ show_progress_bar
test_percent_check
^^^^^^^^^^^^^^^^^^
How much of test dataset to check.
Example::
# default used by the Trainer
trainer = Trainer(test_percent_check=1.0)
# run through only 25% of the test set each epoch
trainer = Trainer(test_percent_check=0.25)
val_check_interval
^^^^^^^^^^^^^^^^^^
How often within one training epoch to check the validation set.
Can specify as float or int.
- use (float) to check within a training epoch
- use (int) to check every n steps (batches)
.. code-block:: python
# default used by the Trainer
trainer = Trainer(val_check_interval=1.0)
Example::
# check validation set 4 times during a training epoch
trainer = Trainer(val_check_interval=0.25)
# check validation set every 1000 training batches
# use this when using iterableDataset and your dataset has no length
# (ie: production cases with streaming data)
trainer = Trainer(val_check_interval=1000)
.. warning:: Deprecated in v0.8.0. Please use `limit_test_batches` instead. Will be removed in 1.0.0.
track_grad_norm
^^^^^^^^^^^^^^^
@ -955,20 +956,36 @@ override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`:
# do your own splitting on the batch
return splits
val_check_interval
^^^^^^^^^^^^^^^^^^
How often within one training epoch to check the validation set.
Can specify as float or int.
- use (float) to check within a training epoch
- use (int) to check every n steps (batches)
.. code-block:: python
# default used by the Trainer
trainer = Trainer(val_check_interval=1.0)
Example::
# check validation set 4 times during a training epoch
trainer = Trainer(val_check_interval=0.25)
# check validation set every 1000 training batches
# use this when using iterableDataset and your dataset has no length
# (ie: production cases with streaming data)
trainer = Trainer(val_check_interval=1000)
val_percent_check
^^^^^^^^^^^^^^^^^
How much of validation dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.
.. warning:: Deprecated in v0.8.0. Please use `limit_val_batches` instead. Will be removed in 1.0.0.
Example::
# default used by the Trainer
trainer = Trainer(val_percent_check=1.0)
# run through only 25% of the validation set each epoch
trainer = Trainer(val_percent_check=0.25)
weights_save_path
^^^^^^^^^^^^^^^^^

View File

@ -70,12 +70,12 @@ class TrainerDataLoadingMixin(ABC):
num_training_batches: Union[int, float]
val_check_batch: ...
val_dataloaders: List[DataLoader]
num_val_batches: Union[int, float]
num_val_batches: List[Union[int, float]]
test_dataloaders: List[DataLoader]
num_test_batches: Union[int, float]
num_test_batches: List[Union[int, float]]
train_percent_check: float
val_percent_check: float
test_percent_check: float
limit_val_batches: float
limit_test_batches: float
replace_sampler_ddp: bool
num_nodes: int
num_processes: int
@ -85,11 +85,16 @@ class TrainerDataLoadingMixin(ABC):
def is_overridden(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""
def _percent_range_check(self, name: str) -> None:
def _limit_eval_batches_check(self, name: str) -> None:
value = getattr(self, name)
# ints are fine
if isinstance(value, int):
return
msg = f'`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}.'
if name == 'val_check_interval':
msg += ' If you want to disable validation set `val_percent_check` to 0.0 instead.'
msg += ' If you want to disable validation set `limit_val_batches` to 0.0 instead.'
if not 0. <= value <= 1.:
raise ValueError(msg)
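A standalone sketch of the rule the renamed ``_limit_eval_batches_check`` enforces (the helper name below is illustrative): integers are taken as an absolute batch count and pass through, while floats must be a fraction in [0.0, 1.0]::

    def check_limit(name, value):
        if isinstance(value, int):
            return  # an int means "run exactly this many batches"
        if not 0.0 <= value <= 1.0:
            raise ValueError(f'`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}.')

    check_limit('limit_val_batches', 0.25)   # ok: 25% of the dataloader
    check_limit('limit_val_batches', 10)     # ok: exactly 10 batches
    # check_limit('limit_val_batches', 1.5)  # would raise ValueError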
@ -139,15 +144,21 @@ class TrainerDataLoadingMixin(ABC):
' distributed training. Either remove the sampler from your DataLoader or set'
' `replace_sampler_ddp`=False if you want to use your custom sampler.')
skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
# replace with distributed sampler
sampler = self._get_distributed_sampler(dataloader)
dataloader = self.replace_sampler(dataloader, sampler)
dl_args = {
k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
}
return dataloader
dl_args['sampler'] = self._get_distributed_sampler(dataloader)
dataloader = type(dataloader)(**dl_args)
def replace_sampler(self, dataloader, sampler):
skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
dl_args = {
k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
}
dl_args['sampler'] = sampler
dataloader = type(dataloader)(**dl_args)
return dataloader
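The new ``replace_sampler`` helper rebuilds a DataLoader around a different sampler. A self-contained sketch of the same idea (it assumes the loader was built with standard keyword arguments; newer PyTorch versions may store extra constructor attributes)::

    import torch
    from torch.utils.data import DataLoader, TensorDataset, SequentialSampler

    def replace_sampler(dataloader, sampler):
        # copy the loader's public constructor args, drop the sampler-related ones,
        # and rebuild the same loader type with the new sampler
        skip_keys = ('sampler', 'batch_sampler', 'dataset_kind')
        dl_args = {k: v for k, v in dataloader.__dict__.items()
                   if not k.startswith('_') and k not in skip_keys}
        dl_args['sampler'] = sampler
        return type(dataloader)(**dl_args)

    ds = TensorDataset(torch.randn(20, 4))
    shuffled = DataLoader(ds, batch_size=5, shuffle=True)
    ordered = replace_sampler(shuffled, SequentialSampler(ds))
    assert isinstance(ordered.sampler, SequentialSampler)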
def _get_distributed_sampler(self, dataloader):
@ -182,14 +193,17 @@ class TrainerDataLoadingMixin(ABC):
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
self._worker_check(self.train_dataloader, 'train dataloader')
self._percent_range_check('train_percent_check')
self._limit_eval_batches_check('train_percent_check')
if not _has_len(self.train_dataloader):
self.num_training_batches = float('inf')
else:
# try getting the length
self.num_training_batches = len(self.train_dataloader)
self.num_training_batches = int(self.num_training_batches * self.train_percent_check)
if isinstance(self.train_percent_check, float):
self.num_training_batches = len(self.train_dataloader)
self.num_training_batches = int(self.num_training_batches * self.train_percent_check)
else:
self.num_training_batches = self.train_percent_check
# determine when to check validation
# if int passed in, val checks that often
@ -200,7 +214,7 @@ class TrainerDataLoadingMixin(ABC):
raise ValueError(
f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
f'to the number of the training batches ({self.num_training_batches}). '
'If you want to disable validation set `val_percent_check` to 0.0 instead.')
'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
else:
if not _has_len(self.train_dataloader):
if self.val_check_interval == 1.0:
@ -212,12 +226,14 @@ class TrainerDataLoadingMixin(ABC):
' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
' checking validation every k training batches.')
else:
self._percent_range_check('val_check_interval')
self._limit_eval_batches_check('val_check_interval')
self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)
def _reset_eval_dataloader(self, model: LightningModule, mode: str) -> Tuple[Union[int, float], List[DataLoader]]:
def _reset_eval_dataloader(self,
model: LightningModule,
mode: str) -> Tuple[List[Union[int, float]], List[DataLoader]]:
"""Generic method to reset a dataloader for evaluation.
Args:
@ -227,17 +243,32 @@ class TrainerDataLoadingMixin(ABC):
Returns:
Tuple (num_batches, dataloaders)
"""
dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader'))
# use the training loader as val and test when overfitting
if self.overfit_batches > 0:
dataloaders = self.request_dataloader(getattr(model, 'train_dataloader'))
else:
dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader'))
if not isinstance(dataloaders, list):
dataloaders = [dataloaders]
# shuffling in val and test set is bad practice
for loader in dataloaders:
for loader_i in range(len(dataloaders)):
loader = dataloaders[loader_i]
# shuffling in val and test set is bad practice
if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):
rank_zero_warn(
f'Your {mode}_dataloader has shuffle=True, it is best practice to turn'
' this off for validation and test dataloaders.')
# when overfitting, the dataloader should not have sampler
if self.overfit_batches > 0:
m = 'You requested to overfit but enabled training dataloader shuffling.' \
' We are turning it off for you.'
rank_zero_warn(m)
dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))
else:
rank_zero_warn(
f'Your {mode}_dataloader has shuffle=True, it is best practice to turn'
' this off for validation and test dataloaders.')
if any([dl is None for dl in dataloaders]):
rank_zero_warn("One of given dataloaders is None and it will be skipped.")
@ -245,29 +276,47 @@ class TrainerDataLoadingMixin(ABC):
# add samplers
dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl is not None]
num_batches = 0
loader_num_batches = []
# determine number of batches
# datasets could be none, 1 or 2+
if len(dataloaders) != 0:
for i, dataloader in enumerate(dataloaders):
num_batches = 0
self._worker_check(dataloader, f'{mode} dataloader {i}')
if not _has_len(dataloader):
num_batches = float('inf')
percent_check = getattr(self, f'{mode}_percent_check')
# percent or num_steps
limit_eval_batches = getattr(self, f'limit_{mode}_batches')
if num_batches != float('inf'):
self._percent_range_check(f'{mode}_percent_check')
if num_batches != float('inf'):
self._limit_eval_batches_check(f'limit_{mode}_batches')
num_batches = sum(len(dataloader) for dataloader in dataloaders)
num_batches = int(num_batches * percent_check)
elif percent_check not in (0.0, 1.0):
raise MisconfigurationException(
'When using an infinite DataLoader (e.g. with an IterableDataset'
f' or when DataLoader does not implement `__len__`) for `{mode}_dataloader`,'
f' `Trainer({mode}_percent_check)` must be `0.0` or `1.0`.')
return num_batches, dataloaders
num_batches = len(dataloader)
# limit num batches either as a percent or num steps
if isinstance(limit_eval_batches, float):
num_batches = int(num_batches * limit_eval_batches)
else:
num_batches = limit_eval_batches
elif limit_eval_batches not in (0.0, 1.0):
raise MisconfigurationException(
'When using an infinite DataLoader (e.g. with an IterableDataset'
f' or when DataLoader does not implement `__len__`) for `limit_{mode}_batches`,'
f' `Trainer(limit_{mode}_batches)` must be `0.0` or `1.0`.')
if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
min_pct = 1.0 / len(dataloader)
m = f'You requested to check {limit_eval_batches} of the {mode} dataloader but ' \
f'{limit_eval_batches} * {len(dataloader)} yields 0 batches. Please increase limit_{mode}_batches. ' \
f'Try at least limit_{mode}_batches={min_pct}'
raise MisconfigurationException(m)
loader_num_batches.append(num_batches)
return loader_num_batches, dataloaders
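The per-dataloader limit logic above boils down to this: a float limit is a fraction of ``len(dataloader)``, an int limit is an absolute number of batches, and a fraction that rounds down to zero batches is rejected. A standalone sketch (the helper name is illustrative)::

    def limited_num_batches(dataloader_len, limit):
        if isinstance(limit, float):
            num_batches = int(dataloader_len * limit)
            if num_batches == 0 and limit > 0.0:
                raise ValueError(f'{limit} of {dataloader_len} batches rounds down to 0;'
                                 f' try at least {1.0 / dataloader_len}')
            return num_batches
        return limit  # int: run exactly this many batches

    assert limited_num_batches(100, 0.25) == 25
    assert limited_num_batches(100, 10) == 10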
def reset_val_dataloader(self, model: LightningModule) -> None:
"""Resets the validation dataloader and determines the number of batches.
@ -316,18 +365,19 @@ class TrainerDataLoadingMixin(ABC):
return dataloader
def determine_data_use_amount(self, train_percent_check: float, val_percent_check: float,
test_percent_check: float, overfit_pct: float) -> None:
def determine_data_use_amount(self, train_percent_check: float, limit_val_batches: Union[int, float],
limit_test_batches: Union[int, float], overfit_batches: float) -> None:
"""Use less data for debugging purposes
"""
self.train_percent_check = train_percent_check
self.val_percent_check = val_percent_check
self.test_percent_check = test_percent_check
if overfit_pct > 0:
if overfit_pct > 1:
self.limit_val_batches = limit_val_batches
self.limit_test_batches = limit_test_batches
if overfit_batches > 0:
if isinstance(overfit_batches, float) and overfit_batches > 1:
raise ValueError(
f'`overfit_pct` must be not greater than 1.0, but got {overfit_pct:.3f}.')
f'`overfit_batches` when used as a percentage must '
f'be in the range 0.0 < x <= 1.0, but got {overfit_batches:.3f}.')
self.train_percent_check = overfit_pct
self.val_percent_check = overfit_pct
self.test_percent_check = overfit_pct
self.train_percent_check = overfit_batches
self.limit_val_batches = overfit_batches
self.limit_test_batches = overfit_batches
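In short, a non-zero ``overfit_batches`` overrides the individual limits, and a float above 1.0 is rejected. Illustrative values only::

    train_percent_check, limit_val_batches, limit_test_batches = 1.0, 0.2, 0.3
    overfit_batches = 0.05

    if overfit_batches > 0:
        if isinstance(overfit_batches, float) and overfit_batches > 1:
            raise ValueError('`overfit_batches` as a percentage must be <= 1.0')
        train_percent_check = limit_val_batches = limit_test_batches = overfit_batches

    assert (train_percent_check, limit_val_batches, limit_test_batches) == (0.05, 0.05, 0.05)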

View File

@ -24,30 +24,30 @@ Set how much of the validation set to check
If you don't want to check 100% of the validation set (for debugging or if it's huge), set this flag
val_percent_check will be overwritten by overfit_pct if `overfit_pct > 0`
limit_val_batches will be overwritten by overfit_batches if `overfit_batches > 0`
.. code-block:: python
# DEFAULT
trainer = Trainer(val_percent_check=1.0)
trainer = Trainer(limit_val_batches=1.0)
# check 10% only
trainer = Trainer(val_percent_check=0.1)
trainer = Trainer(limit_val_batches=0.1)
Set how much of the test set to check
-------------------------------------
If you don't want to check 100% of the test set (for debugging or if it's huge), set this flag
test_percent_check will be overwritten by overfit_pct if `overfit_pct > 0`
limit_test_batches will be overwritten by overfit_batches if `overfit_batches > 0`
.. code-block:: python
# DEFAULT
trainer = Trainer(test_percent_check=1.0)
trainer = Trainer(limit_test_batches=1.0)
# check 10% only
trainer = Trainer(test_percent_check=0.1)
trainer = Trainer(limit_test_batches=0.1)
Set validation check frequency within 1 training epoch
------------------------------------------------------
@ -124,7 +124,7 @@ In this second case, the options you pass to trainer will be used when running
from abc import ABC, abstractmethod
from pprint import pprint
from typing import Callable, Optional
from typing import Callable, Optional, List
import torch
from torch.utils.data import DataLoader
@ -162,7 +162,7 @@ class TrainerEvaluationLoopMixin(ABC):
single_gpu: bool
data_parallel_device_ids: ...
model: LightningModule
num_test_batches: int
num_test_batches: List[int]
num_val_batches: int
fast_dev_run: ...
process_output: ...
@ -222,13 +222,13 @@ class TrainerEvaluationLoopMixin(ABC):
def reset_val_dataloader(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""
def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_mode: bool = False):
def _evaluate(self, model: LightningModule, dataloaders, max_batches: List[int], test_mode: bool = False):
"""Run evaluation code.
Args:
model: PT model
dataloaders: list of PT dataloaders
max_batches: Scalar
max_batches: List of scalars
test_mode:
"""
# enable eval mode
@ -254,12 +254,15 @@ class TrainerEvaluationLoopMixin(ABC):
dataloader = xla_pl.ParallelLoader(dataloader, [device])
dataloader = dataloader.per_device_loader(device)
# each dataloader has a max num batches
dl_max_batches = max_batches[dataloader_idx]
for batch_idx, batch in enumerate(dataloader):
if batch is None:
continue
# stop short when on fast_dev_run (sets max_batch=1)
if batch_idx >= max_batches:
if batch_idx >= dl_max_batches:
break
# callbacks
@ -359,7 +362,7 @@ class TrainerEvaluationLoopMixin(ABC):
# cap max batches to 1 when using fast_dev_run
if self.fast_dev_run:
max_batches = 1
max_batches = [1]
# Validation/Test begin callbacks
if test_mode:
@ -367,6 +370,11 @@ class TrainerEvaluationLoopMixin(ABC):
else:
self.on_validation_start()
# enable disabling validation step with limit_val_batches = 0
should_skip = sum(max_batches) == 0
if should_skip:
return
# run evaluation
eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results)
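The evaluation loop now receives one batch cap per dataloader and skips evaluation entirely when every cap is zero (which is how ``limit_val_batches=0`` disables validation). A minimal sketch of that control flow (not the actual loop)::

    def run_eval(dataloaders, max_batches):
        if sum(max_batches) == 0:          # e.g. limit_val_batches=0
            return None
        seen = []
        for dl_idx, dataloader in enumerate(dataloaders):
            dl_max = max_batches[dl_idx]   # each dataloader has its own cap
            for batch_idx, batch in enumerate(dataloader):
                if batch_idx >= dl_max:
                    break
                seen.append((dl_idx, batch_idx))
        return seen

    assert run_eval([range(5), range(5)], max_batches=[2, 3]) == \
        [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]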

View File

@ -93,7 +93,7 @@ class Trainer(
tpu_cores: Optional[Union[List[int], int]] = None,
log_gpu_memory: Optional[str] = None,
progress_bar_refresh_rate: int = 1,
overfit_pct: float = 0.0,
overfit_batches: Union[int, float] = 0.0,
track_grad_norm: Union[int, float, str] = -1,
check_val_every_n_epoch: int = 1,
fast_dev_run: bool = False,
@ -103,8 +103,8 @@ class Trainer(
max_steps: Optional[int] = None,
min_steps: Optional[int] = None,
train_percent_check: float = 1.0,
val_percent_check: float = 1.0,
test_percent_check: float = 1.0,
limit_val_batches: Union[int, float] = 1.0,
limit_test_batches: Union[int, float] = 1.0,
val_check_interval: float = 1.0,
log_save_interval: int = 100,
row_log_interval: int = 50,
@ -129,6 +129,9 @@ class Trainer(
num_tpu_cores: Optional[int] = None, # backward compatible, todo: remove in v0.9.0
use_amp=None, # backward compatible, todo: remove in v0.9.0
show_progress_bar=None, # backward compatible, todo: remove in v0.9.0
val_percent_check: float = 1.0, # backward compatible, todo: remove in v1.0.0
test_percent_check: float = 1.0, # backward compatible, todo: remove in v1.0.0
overfit_pct: float = 0.0 # backward compatible, todo: remove in v1.0.0
):
r"""
@ -185,7 +188,12 @@ class Trainer(
progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.
overfit_pct: How much of training-, validation-, and test dataset to check.
overfit_batches: Overfit a percent of training data (float) or a set number of batches (int).
overfit_pct:
.. warning:: .. deprecated:: 0.8.0
Use `overfit_batches` instead. Will be removed in 1.0.0.
track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.
@ -215,9 +223,19 @@ class Trainer(
train_percent_check: How much of training dataset to check.
val_percent_check: How much of validation dataset to check.
limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)
test_percent_check: How much of test dataset to check.
limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches)
val_percent_check:
.. warning:: .. deprecated:: 0.8.0
Use `limit_val_batches` instead. Will be removed in 1.0.0.
test_percent_check:
.. warning:: .. deprecated:: 0.8.0
Use `limit_test_batches` instead. Will be removed in 1.0.0.
val_check_interval: How often within one training epoch to check the validation set
@ -383,9 +401,9 @@ class Trainer(
self.batch_idx = 0
self.progress_bar_metrics = {}
self.callback_metrics = {}
self.num_val_batches = 0
self.num_val_batches = [0]
self.num_training_batches = 0
self.num_test_batches = 0
self.num_test_batches = [0]
self.train_dataloader = None
self.test_dataloaders = None
self.val_dataloaders = None
@ -468,9 +486,27 @@ class Trainer(
self.row_log_interval = row_log_interval
# how much of the data to use
self.overfit_pct = overfit_pct
self.determine_data_use_amount(train_percent_check, val_percent_check,
test_percent_check, overfit_pct)
# TODO: remove in 1.0.0
if overfit_pct > 0:
overfit_batches = overfit_pct
# convert floats to ints
overfit_batches = int(overfit_batches) if overfit_batches > 1.0 else overfit_batches
self.overfit_batches = overfit_batches
# TODO: remove in 1.0.0
if val_percent_check < 1.0:
limit_val_batches = val_percent_check
if test_percent_check < 1.0:
limit_test_batches = test_percent_check
limit_test_batches = int(limit_test_batches) if limit_test_batches > 1.0 else limit_test_batches
limit_val_batches = int(limit_val_batches) if limit_val_batches > 1.0 else limit_val_batches
# TODO: convert train_percent_check to limit_train_batches
self.determine_data_use_amount(train_percent_check, limit_val_batches,
limit_test_batches, overfit_batches)
# AMP init
# These are the only lines needed after v0.8.0
@ -984,7 +1020,7 @@ class Trainer(
return
# check if we should run validation during training
self.disable_validation = not (self.is_overridden('validation_step') and self.val_percent_check > 0) \
self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \
and not self.fast_dev_run
# run tiny validation (if validation defined)
@ -996,9 +1032,11 @@ class Trainer(
ref_model.on_sanity_check_start()
self.on_sanity_check_start()
num_loaders = len(self.val_dataloaders)
max_batches = [self.num_sanity_val_steps] * num_loaders
eval_results = self._evaluate(model,
self.val_dataloaders,
self.num_sanity_val_steps,
max_batches,
False)
_, _, _, callback_metrics, _ = self.process_output(eval_results)
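The backward-compatibility block earlier in this file's diff maps the deprecated flags onto the new ones and converts values above 1.0 to integer batch counts. A standalone sketch of that mapping (it mirrors the logic; it is not the Trainer itself)::

    def resolve_limits(overfit_pct=0.0, overfit_batches=0.0,
                       val_percent_check=1.0, limit_val_batches=1.0,
                       test_percent_check=1.0, limit_test_batches=1.0):
        if overfit_pct > 0:  # deprecated flag wins only if it was set
            overfit_batches = overfit_pct
        overfit_batches = int(overfit_batches) if overfit_batches > 1.0 else overfit_batches

        if val_percent_check < 1.0:
            limit_val_batches = val_percent_check
        if test_percent_check < 1.0:
            limit_test_batches = test_percent_check
        limit_val_batches = int(limit_val_batches) if limit_val_batches > 1.0 else limit_val_batches
        limit_test_batches = int(limit_test_batches) if limit_test_batches > 1.0 else limit_test_batches
        return overfit_batches, limit_val_batches, limit_test_batches

    assert resolve_limits(overfit_pct=0.1) == (0.1, 1.0, 1.0)
    assert resolve_limits(val_percent_check=0.2, limit_test_batches=25) == (0.0, 0.2, 25)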

View File

@ -68,7 +68,7 @@ Set how much of the training set to check
If you don't want to check 100% of the training set (for debugging or if it's huge), set this flag.
train_percent_check will be overwritten by overfit_pct if `overfit_pct > 0`
train_percent_check will be overwritten by overfit_batches if `overfit_batches > 0`
.. code-block:: python
@ -202,7 +202,7 @@ class TrainerTrainLoopMixin(ABC):
check_val_every_n_epoch: ...
num_training_batches: int
val_check_batch: ...
num_val_batches: int
num_val_batches: List[int]
disable_validation: bool
fast_dev_run: ...
accumulation_scheduler: ...

View File

@ -24,7 +24,7 @@ def test_early_stopping_functionality(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=True,
overfit_pct=0.20,
overfit_batches=0.20,
max_epochs=20,
)
result = trainer.fit(model)
@ -152,7 +152,7 @@ def test_trainer_callback_system(tmpdir):
trainer_options = dict(
callbacks=[test_callback],
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.2,
progress_bar_refresh_rate=0,
)
@ -258,7 +258,7 @@ def test_early_stopping_no_val_step(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=stopping,
overfit_pct=0.20,
overfit_batches=0.20,
max_epochs=2,
)
result = trainer.fit(model)
@ -292,7 +292,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
trainer = Trainer(default_root_dir=tmpdir,
checkpoint_callback=checkpoint,
overfit_pct=0.20,
overfit_batches=0.20,
max_epochs=2
)
trainer.fit(model)
@ -313,7 +313,7 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):
trainer = Trainer(
default_root_dir=tmpdir,
overfit_pct=0.2,
overfit_batches=0.2,
max_epochs=2,
logger=logger
)

View File

@ -17,7 +17,7 @@ def test_lr_logger_single_lr(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.5,
callbacks=[lr_logger]
)
@ -40,7 +40,7 @@ def test_lr_logger_no_lr(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.5,
callbacks=[lr_logger]
)
@ -61,7 +61,7 @@ def test_lr_logger_multi_lrs(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.5,
callbacks=[lr_logger]
)
@ -88,7 +88,7 @@ def test_lr_logger_param_groups(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.5,
callbacks=[lr_logger]
)

View File

@ -20,7 +20,7 @@ def test_progress_bar_on(callbacks, refresh_rate):
callbacks=callbacks,
progress_bar_refresh_rate=refresh_rate,
max_epochs=1,
overfit_pct=0.2,
overfit_batches=0.2,
)
progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]
@ -61,7 +61,7 @@ def test_progress_bar_totals():
trainer = Trainer(
progress_bar_refresh_rate=1,
val_percent_check=1.0,
limit_val_batches=1.0,
max_epochs=1,
)
bar = trainer.progress_bar_callback

View File

@ -53,7 +53,7 @@ def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
max_epochs=1,
logger=logger,
train_percent_check=0.2,
val_percent_check=0.5,
limit_val_batches=0.5,
fast_dev_run=True,
)
trainer.fit(model)

View File

@ -145,7 +145,7 @@ def test_adding_step_key(tmpdir):
max_epochs=3,
default_root_dir=tmpdir,
train_percent_check=0.001,
val_percent_check=0.01,
limit_val_batches=0.1,
num_sanity_val_steps=0,
)
trainer.logger.log_metrics = _log_metrics_decorator(

View File

@ -99,7 +99,7 @@ def test_cpu_model_with_amp(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.4,
limit_val_batches=0.4,
precision=16
)

View File

@ -26,7 +26,7 @@ def test_cpu_slurm_save_load(tmpdir):
max_epochs=1,
logger=logger,
train_percent_check=0.2,
val_percent_check=0.2,
limit_val_batches=0.2,
checkpoint_callback=ModelCheckpoint(tmpdir)
)
result = trainer.fit(model)
@ -90,10 +90,10 @@ def test_early_stopping_cpu_model(tmpdir):
early_stop_callback=stopping,
max_epochs=2,
gradient_clip_val=1.0,
overfit_pct=0.20,
overfit_batches=0.20,
track_grad_norm=2,
train_percent_check=0.1,
val_percent_check=0.1,
limit_val_batches=0.1,
)
model = EvalModelTemplate()
@ -119,7 +119,7 @@ def test_multi_cpu_model_ddp(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
limit_val_batches=0.2,
gpus=None,
num_processes=2,
distributed_backend='ddp_cpu'
@ -137,7 +137,7 @@ def test_lbfgs_cpu_model(tmpdir):
progress_bar_refresh_rate=0,
weights_summary='top',
train_percent_check=0.2,
val_percent_check=0.2,
limit_val_batches=0.2,
)
hparams = EvalModelTemplate.get_default_hparams()
@ -154,10 +154,10 @@ def test_default_logger_callbacks_cpu_model(tmpdir):
default_root_dir=tmpdir,
max_epochs=1,
gradient_clip_val=1.0,
overfit_pct=0.20,
overfit_batches=0.20,
progress_bar_refresh_rate=0,
train_percent_check=0.01,
val_percent_check=0.01,
limit_val_batches=0.01,
)
model = EvalModelTemplate()
@ -184,8 +184,8 @@ def test_running_test_after_fitting(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=2,
train_percent_check=0.4,
val_percent_check=0.2,
test_percent_check=0.2,
limit_val_batches=0.2,
limit_test_batches=0.2,
checkpoint_callback=checkpoint,
logger=logger
)
@ -214,8 +214,8 @@ def test_running_test_no_val(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
test_percent_check=0.2,
limit_val_batches=0.2,
limit_test_batches=0.2,
checkpoint_callback=checkpoint,
logger=logger,
early_stop_callback=False
@ -238,7 +238,7 @@ def test_simple_cpu(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.1,
)
result = trainer.fit(model)
@ -254,7 +254,7 @@ def test_cpu_model(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.4
limit_val_batches=0.4
)
model = EvalModelTemplate()
@ -267,13 +267,13 @@ def test_all_features_cpu_model(tmpdir):
trainer_options = dict(
default_root_dir=tmpdir,
gradient_clip_val=1.0,
overfit_pct=0.20,
overfit_batches=0.20,
track_grad_norm=2,
progress_bar_refresh_rate=0,
accumulate_grad_batches=2,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.4
limit_val_batches=0.4
)
model = EvalModelTemplate()
@ -342,7 +342,7 @@ def test_tbptt_cpu_model(tmpdir):
default_root_dir=tmpdir,
max_epochs=1,
truncated_bptt_steps=truncated_bptt_steps,
val_percent_check=0,
limit_val_batches=0,
weights_summary=None,
early_stop_callback=False
)

View File

@ -22,7 +22,7 @@ def test_single_gpu_model(tmpdir, gpus):
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.1,
val_percent_check=0.1,
limit_val_batches=0.1,
gpus=gpus
)
@ -41,7 +41,7 @@ def test_multi_gpu_model(tmpdir, backend):
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
limit_val_batches=0.2,
gpus=[0, 1],
distributed_backend=backend,
)
@ -66,7 +66,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.1,
val_percent_check=0.1,
limit_val_batches=0.1,
gpus=[0, 1],
distributed_backend='ddp')
@ -88,7 +88,7 @@ def test_multi_gpu_none_backend(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.1,
val_percent_check=0.1,
limit_val_batches=0.1,
gpus='-1'
)

View File

@ -57,7 +57,7 @@ def test_training_epoch_end_metrics_collection(tmpdir):
trainer = Trainer(
max_epochs=num_epochs,
default_root_dir=tmpdir,
overfit_pct=0.1,
overfit_batches=0.1,
)
result = trainer.fit(model)
assert result == 1

View File

@ -62,7 +62,7 @@ def test_horovod_cpu(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
limit_val_batches=0.2,
distributed_backend='horovod',
deterministic=True,
)
@ -79,7 +79,7 @@ def test_horovod_cpu_implicit(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
limit_val_batches=0.2,
deterministic=True,
)
_run_horovod(trainer_options)
@ -97,7 +97,7 @@ def test_horovod_multi_gpu(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
limit_val_batches=0.2,
gpus=1,
deterministic=True,
distributed_backend='horovod'
@ -132,7 +132,7 @@ def test_horovod_transfer_batch_to_gpu(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
limit_val_batches=0.2,
gpus=1,
deterministic=True,
distributed_backend='horovod'
@ -150,7 +150,7 @@ def test_horovod_multi_optimizer(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
limit_val_batches=0.2,
deterministic=True,
distributed_backend='horovod'
)

View File

@ -39,7 +39,7 @@ def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False):
assert model.hparams.test_arg == 14
# verify we can train
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_pct=0.5)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=0.5)
trainer.fit(model)
# make sure the raw checkpoint saved the properties
@ -156,7 +156,7 @@ def test_explicit_missing_args_hparams(tmpdir):
assert model.hparams.test_arg == 14
# verify we can train
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
trainer.fit(model)
# make sure the raw checkpoint saved the properties
@ -266,7 +266,7 @@ def test_collect_init_arguments(tmpdir, cls):
assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss)
# verify that the checkpoint saved the correct values
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
trainer.fit(model)
raw_checkpoint_path = _raw_checkpoint_path(trainer)
@ -349,7 +349,7 @@ def test_collect_init_arguments_with_local_vars(cls):
# assert model.hparams.my_arg == 42
#
# # verify that the checkpoint saved the correct values
# trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
# trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
# trainer.fit(model)
#
# # verify that model loads correctly

View File

@ -32,7 +32,7 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend):
progress_bar_refresh_rate=0,
max_epochs=2,
train_percent_check=0.4,
val_percent_check=0.2,
limit_val_batches=0.2,
checkpoint_callback=checkpoint,
logger=logger,
gpus=[0, 1],
@ -80,7 +80,7 @@ def test_running_test_pretrained_model_cpu(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=3,
train_percent_check=0.4,
val_percent_check=0.2,
limit_val_batches=0.2,
checkpoint_callback=checkpoint,
logger=logger
)
@ -111,7 +111,7 @@ def test_load_model_from_checkpoint(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=2,
train_percent_check=0.4,
val_percent_check=0.2,
limit_val_batches=0.2,
checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
default_root_dir=tmpdir,
)
@ -188,7 +188,7 @@ def test_dp_resume(tmpdir):
trainer_options['logger'] = new_logger
trainer_options['checkpoint_callback'] = ModelCheckpoint(tmpdir)
trainer_options['train_percent_check'] = 0.5
trainer_options['val_percent_check'] = 0.2
trainer_options['limit_val_batches'] = 0.2
trainer_options['max_epochs'] = 1
new_trainer = Trainer(**trainer_options)

View File

@ -103,7 +103,7 @@ def test_tbd_remove_in_v1_0_0_model_hooks():
with pytest.deprecated_call(match='v1.0'):
trainer = Trainer(logger=False)
# TODO: why `dataloader` is required if it is not used
result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
result = trainer._evaluate(model, dataloaders=[[None]], max_batches=[1])
assert result == {'val_loss': torch.tensor(0.6)}
model = ModelVer0_7(hparams)
@ -116,5 +116,5 @@ def test_tbd_remove_in_v1_0_0_model_hooks():
with pytest.deprecated_call(match='v1.0'):
trainer = Trainer(logger=False)
# TODO: why `dataloader` is required if it is not used
result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
result = trainer._evaluate(model, dataloaders=[[None]], max_batches=[1])
assert result == {'val_loss': torch.tensor(0.7)}

View File

@ -80,7 +80,7 @@ def test_multiple_val_dataloader(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=1.0,
)
result = trainer.fit(model)
@ -116,7 +116,7 @@ def test_multiple_test_dataloader(tmpdir, ckpt_path):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.2
)
trainer.fit(model)
@ -144,7 +144,7 @@ def test_train_dataloader_passed_to_fit(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.2
)
fit_options = dict(train_dataloader=model.dataloader(train=True))
@ -161,7 +161,7 @@ def test_train_val_dataloaders_passed_to_fit(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.2
)
fit_options = dict(train_dataloader=model.dataloader(train=True),
@ -183,7 +183,7 @@ def test_all_dataloaders_passed_to_fit(tmpdir, ckpt_path):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.2
)
fit_options = dict(train_dataloader=model.dataloader(train=True),
@ -216,7 +216,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.2
)
fit_options = dict(train_dataloader=model.dataloader(train=True),
@ -245,7 +245,7 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.2
)
@ -286,7 +286,7 @@ def test_val_inf_dataloader_error(tmpdir):
model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__infinite
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_percent_check=0.5)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)
with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
trainer.fit(model)
@ -298,7 +298,7 @@ def test_test_inf_dataloader_error(tmpdir):
model = EvalModelTemplate()
model.test_dataloader = model.test_dataloader__infinite
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, test_percent_check=0.5)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)
with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
trainer.test(model)
@ -354,8 +354,8 @@ def test_error_on_zero_len_dataloader(tmpdir):
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.1,
val_percent_check=0.1,
test_percent_check=0.1
limit_val_batches=0.1,
limit_test_batches=0.1
)
trainer.fit(model)
@ -371,7 +371,7 @@ def test_warning_with_few_workers(tmpdir, ckpt_path):
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.2
)
@ -489,7 +489,7 @@ def test_batch_size_smaller_than_num_gpus():
trainer = Trainer(
max_epochs=1,
train_percent_check=0.1,
val_percent_check=0,
limit_val_batches=0,
gpus=num_gpus,
)

View File

@ -16,7 +16,7 @@ def test_optimizer_with_scheduling(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.2
)
results = trainer.fit(model)
@ -47,7 +47,7 @@ def test_multi_optimizer_with_scheduling(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.2
)
results = trainer.fit(model)
@ -82,7 +82,7 @@ def test_multi_optimizer_with_scheduling_stepping(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.2
)
results = trainer.fit(model)
@ -121,7 +121,7 @@ def test_reduce_lr_on_plateau_scheduling(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.2
)
results = trainer.fit(model)
@ -211,7 +211,7 @@ def test_none_optimizer(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.2
)
result = trainer.fit(model)

View File

@ -163,7 +163,7 @@ def test_gradient_accumulation_scheduling(tmpdir):
trainer = Trainer(accumulate_grad_batches=schedule,
train_percent_check=0.1,
val_percent_check=0.1,
limit_val_batches=0.1,
max_epochs=2,
default_root_dir=tmpdir)
@ -368,7 +368,7 @@ def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_serve
progress_bar_refresh_rate=0,
max_epochs=2,
train_percent_check=0.65,
val_percent_check=1,
limit_val_batches=1,
checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
default_root_dir=tmpdir,
early_stop_callback=False,
@ -588,7 +588,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
def test_disabled_validation():
"""Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`."""
"""Verify that `limit_val_batches=0` disables the validation loop unless `fast_dev_run=True`."""
class CurrentModel(EvalModelTemplate):
@ -610,22 +610,22 @@ def test_disabled_validation():
progress_bar_refresh_rate=0,
max_epochs=2,
train_percent_check=0.4,
val_percent_check=0.0,
limit_val_batches=0.0,
fast_dev_run=False,
)
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
# check that val_percent_check=0 turns off validation
# check that limit_val_batches=0 turns off validation
assert result == 1, 'training failed to complete'
assert trainer.current_epoch == 2
assert not model.validation_step_invoked, \
'`validation_step` should not run when `val_percent_check=0`'
'`validation_step` should not run when `limit_val_batches=0`'
assert not model.validation_epoch_end_invoked, \
'`validation_epoch_end` should not run when `val_percent_check=0`'
'`validation_epoch_end` should not run when `limit_val_batches=0`'
# check that val_percent_check has no influence when fast_dev_run is turned on
# check that limit_val_batches has no influence when fast_dev_run is turned on
model = CurrentModel(**hparams)
trainer_options.update(fast_dev_run=True)
trainer = Trainer(**trainer_options)
@ -722,7 +722,7 @@ def test_trainer_interrupted_flag(tmpdir):
trainer = Trainer(
callbacks=[interrupt_callback, handle_interrupt_callback],
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.2,
progress_bar_refresh_rate=0,
logger=False,

View File

@ -5,6 +5,87 @@ import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from torch.utils.data import RandomSampler, SequentialSampler, DataLoader
def test_overfit(tmpdir):
# ------------------------------------------------------
# Make sure shuffle is correct across loaders initially
# ------------------------------------------------------
model = EvalModelTemplate()
model.train_dataloader()
# original train loader which should be replaced in all methods
train_loader = model.train_dataloader()
# make sure the val and tests are not shuffled
assert isinstance(train_loader.sampler, RandomSampler)
assert isinstance(model.val_dataloader().sampler, SequentialSampler)
assert isinstance(model.test_dataloader().sampler, SequentialSampler)
# ------------------------------------------------------
# get the training loader and batch
# ------------------------------------------------------
train_loader = DataLoader(model.train_dataloader().dataset, shuffle=False)
full_train_samples = len(train_loader)
num_train_samples = int(0.11 * full_train_samples)
(xa, ya) = next(iter(train_loader))
# ------------------------------------------------------
# set VAL and Test loaders
# ------------------------------------------------------
val_loader = DataLoader(model.val_dataloader().dataset, shuffle=False)
test_loader = DataLoader(model.test_dataloader().dataset, shuffle=False)
# set the model loaders
model.train_dataloader = lambda: train_loader
model.val_dataloader = lambda: val_loader
model.test_dataloader = lambda: test_loader
# ------------------------------------------------------
# run tests for both val and test
# ------------------------------------------------------
for split in ['val', 'test']:
# ------------------------------------------------------
# test overfit_batches as percent
# ------------------------------------------------------
loader_num_batches, dataloaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(model, split)
assert loader_num_batches[0] == num_train_samples
# make sure we turned off shuffle for the user
assert isinstance(dataloaders[0].sampler, SequentialSampler)
# make sure the loaders are the same
(xb, yb) = next(iter(dataloaders[0]))
assert torch.eq(xa, xb).all()
assert torch.eq(ya, yb).all()
# ------------------------------------------------------
# test overfit_batches as int
# ------------------------------------------------------
loader_num_batches, dataloaders = Trainer(overfit_batches=1)._reset_eval_dataloader(model, split)
assert loader_num_batches[0] == 1
loader_num_batches, dataloaders = Trainer(overfit_batches=5)._reset_eval_dataloader(model, split)
assert loader_num_batches[0] == 5
# ------------------------------------------------------
# test limit_xxx_batches as percent AND int
# ------------------------------------------------------
if split == 'val':
loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(model, split)
assert loader_num_batches[0] == int(0.1 * len(val_loader))
loader_num_batches, dataloaders = Trainer(limit_val_batches=10)._reset_eval_dataloader(model, split)
assert loader_num_batches[0] == 10
else:
loader_num_batches, dataloaders = Trainer(limit_test_batches=0.1)._reset_eval_dataloader(model, split)
assert loader_num_batches[0] == int(0.1 * len(test_loader))
loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(model, split)
assert loader_num_batches[0] == 10
def test_model_reset_correctly(tmpdir):
@ -120,7 +201,7 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
limit_val_batches=0.1,
train_percent_check=0.2,
auto_scale_batch_size='power'
)