[WIP] Rename overfit_pct to overfit_batches, and val_percent_check/test_percent_check to limit_val_batches/limit_test_batches (and fix) (#2213)
* fixed percent check for val/test
* overfit_pct now uses train loaders for val and test and does not shuffle
* add on fit_start on fit_end hooks

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent 97dfd3a80a
commit 04c794ca72
@ -21,6 +21,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added overfit_batches, limit_xxx_batches flags (overfit now uses training set for all three) ([#2213](https://github.com/PyTorchLightning/pytorch-lightning/pull/2213))
- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
- Added Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488))
- Added metrics
  * Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
  * Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))

@ -54,6 +58,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated `overfit_pct`, `val_percent_check`, `test_percent_check` ([#2213](https://github.com/PyTorchLightning/pytorch-lightning/pull/2213))
- Deprecated `ModelCheckpoint`'s attributes `best` and `kth_best_model` ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))
- Dropped official support/testing for older PyTorch versions <1.3 ([#1917](https://github.com/PyTorchLightning/pytorch-lightning/pull/1917))
@ -48,12 +48,19 @@ Make model overfit on subset of data

A good debugging technique is to take a tiny portion of your data (say 2 samples per class),
and try to get your model to overfit. If it can't, it's a sign it won't work with large datasets.

(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.overfit_pct`
(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.overfit_batches`
argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)

.. testcode::

    trainer = Trainer(overfit_pct=0.01)
    # use only 1% of training data (and use the same training dataloader (with shuffle off) in val and test)
    trainer = Trainer(overfit_batches=0.01)

    # or overfit a number of batches
    trainer = Trainer(overfit_batches=10)

With this flag, the train, val, and test sets will all be the same train set. We will also replace the sampler
in the training set to turn off shuffling for you.
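For context, the sampler swap mirrors the `replace_sampler` helper added in this PR — a minimal sketch
(the standalone function here is illustrative, not the library API):

.. code-block:: python

    from torch.utils.data import DataLoader, SequentialSampler

    def swap_to_sequential(dataloader: DataLoader) -> DataLoader:
        # Rebuild the DataLoader with a SequentialSampler so every epoch
        # iterates the same batches in the same order (shuffle off).
        skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
        dl_args = {k: v for k, v in dataloader.__dict__.items()
                   if not k.startswith('_') and k not in skip_keys}
        dl_args['sampler'] = SequentialSampler(dataloader.dataset)
        return type(dataloader)(**dl_args)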

Print a summary of your LightningModule
---------------------------------------
@ -56,17 +56,17 @@ If you don't want to check 100% of the training/validation/test set (for debuggi

    # DEFAULT
    trainer = Trainer(
        train_percent_check=1.0,
        val_percent_check=1.0,
        test_percent_check=1.0
        limit_val_batches=1.0,
        limit_test_batches=1.0
    )

    # check 10%, 20%, 30% only, respectively for training, validation and test set
    trainer = Trainer(
        train_percent_check=0.1,
        val_percent_check=0.2,
        test_percent_check=0.3
        limit_val_batches=0.2,
        limit_test_batches=0.3
    )

.. note:: ``train_percent_check``, ``val_percent_check`` and ``test_percent_check`` will be overwritten by ``overfit_pct`` if ``overfit_pct`` > 0. ``val_percent_check`` will be ignored if ``fast_dev_run=True``.
.. note:: ``train_percent_check``, ``limit_val_batches`` and ``limit_test_batches`` will be overwritten by ``overfit_batches`` if ``overfit_batches`` > 0. ``limit_val_batches`` will be ignored if ``fast_dev_run=True``.

.. note:: If you set ``val_percent_check=0``, validation will be disabled.
.. note:: If you set ``limit_val_batches=0``, validation will be disabled.
@ -98,6 +98,7 @@ class ProgressBarBase(Callback):
        elif not self.trainer.disable_validation:
            is_val_epoch = trainer.current_epoch % trainer.check_val_every_n_epoch == 0
            total_val_batches = trainer.num_val_batches if is_val_epoch else 0
            total_val_batches = sum(total_val_batches)
        return total_val_batches

    @property

@ -111,6 +112,7 @@ class ProgressBarBase(Callback):
            total_test_batches = len(self.trainer.test_dataloaders)
        else:
            total_test_batches = self.trainer.num_test_batches
        total_test_batches = sum(total_test_batches)
        return total_test_batches

    def disable(self):
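Both totals above change because `num_val_batches` / `num_test_batches` become per-dataloader
lists in this PR; a tiny illustration (values invented):

.. code-block:: python

    num_val_batches = [120, 80]  # one entry per validation dataloader
    total_val_batches = sum(num_val_batches)
    assert total_val_batches == 200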
@ -433,6 +433,40 @@ Example::

    # default used by the Trainer
    trainer = Trainer(gradient_clip_val=0.0)


limit_test_batches
^^^^^^^^^^^^^^^^^^

How much of the test dataset to check.

Example::

    # default used by the Trainer
    trainer = Trainer(limit_test_batches=1.0)

    # run through only 25% of the test set each epoch
    trainer = Trainer(limit_test_batches=0.25)

    # run for only 10 batches
    trainer = Trainer(limit_test_batches=10)

limit_val_batches
^^^^^^^^^^^^^^^^^

How much of the validation dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.

Example::

    # default used by the Trainer
    trainer = Trainer(limit_val_batches=1.0)

    # run through only 25% of the validation set each epoch
    trainer = Trainer(limit_val_batches=0.25)

    # run for only 10 batches
    trainer = Trainer(limit_val_batches=10)

log_gpu_memory
^^^^^^^^^^^^^^
Options:
@ -652,29 +686,28 @@ Example::

overfit_pct
^^^^^^^^^^^
Uses this much data of all datasets (training, validation, test).

.. warning:: .. deprecated:: 0.8.0

    Use `overfit_batches`. Will be removed in 1.0.0.

overfit_batches
^^^^^^^^^^^^^^^
Uses this much data of the training set. It will use the same training set for validation and testing.
If the training dataloader has shuffle=True, Lightning will automatically disable it.

Useful for quickly debugging or trying to overfit on purpose.

Example::

    # default used by the Trainer
    trainer = Trainer(overfit_pct=0.0)
    trainer = Trainer(overfit_batches=0.0)

    # use only 1% of the train, test, val datasets
    trainer = Trainer(overfit_pct=0.01)

    # equivalent:
    trainer = Trainer(
        train_percent_check=0.01,
        val_percent_check=0.01,
        test_percent_check=0.01
    )

    See Also:
        - `train_percent_check`_
        - `val_percent_check`_
        - `test_percent_check`_
    # use only 1% of the train set (and use the train set for val and test)
    trainer = Trainer(overfit_batches=0.01)

    # overfit on 10 of the same batches
    trainer = Trainer(overfit_batches=10)

precision
^^^^^^^^^
@ -829,39 +862,7 @@ show_progress_bar

test_percent_check
^^^^^^^^^^^^^^^^^^

How much of the test dataset to check.

Example::

    # default used by the Trainer
    trainer = Trainer(test_percent_check=1.0)

    # run through only 25% of the test set each epoch
    trainer = Trainer(test_percent_check=0.25)

val_check_interval
^^^^^^^^^^^^^^^^^^

How often within one training epoch to check the validation set.
Can specify as float or int.

- use (float) to check within a training epoch
- use (int) to check every n steps (batches)

.. code-block:: python

    # default used by the Trainer
    trainer = Trainer(val_check_interval=1.0)

Example::

    # check validation set 4 times during a training epoch
    trainer = Trainer(val_check_interval=0.25)

    # check validation set every 1000 training batches
    # use this when using iterableDataset and your dataset has no length
    # (ie: production cases with streaming data)
    trainer = Trainer(val_check_interval=1000)

.. warning:: Deprecated in v0.8.0. Please use `limit_test_batches`. Will be removed in 1.0.0.

track_grad_norm
^^^^^^^^^^^^^^^
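Taken together, the deprecations in these docs amount to a simple rename; a migration sketch
(values illustrative):

.. code-block:: python

    # before (deprecated in v0.8.0)
    trainer = Trainer(val_percent_check=0.25, test_percent_check=0.25, overfit_pct=0.01)

    # after
    trainer = Trainer(limit_val_batches=0.25, limit_test_batches=0.25, overfit_batches=0.01)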
@ -955,20 +956,36 @@ override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`:

    # do your own splitting on the batch
    return splits

val_check_interval
^^^^^^^^^^^^^^^^^^

How often within one training epoch to check the validation set.
Can specify as float or int.

- use (float) to check within a training epoch
- use (int) to check every n steps (batches)

.. code-block:: python

    # default used by the Trainer
    trainer = Trainer(val_check_interval=1.0)

Example::

    # check validation set 4 times during a training epoch
    trainer = Trainer(val_check_interval=0.25)

    # check validation set every 1000 training batches
    # use this when using iterableDataset and your dataset has no length
    # (ie: production cases with streaming data)
    trainer = Trainer(val_check_interval=1000)


val_percent_check
^^^^^^^^^^^^^^^^^

How much of the validation dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.

.. warning:: Deprecated in v0.8.0. Please use `limit_val_batches`. Will be removed in 1.0.0.

Example::

    # default used by the Trainer
    trainer = Trainer(val_percent_check=1.0)

    # run through only 25% of the validation set each epoch
    trainer = Trainer(val_percent_check=0.25)

weights_save_path
^^^^^^^^^^^^^^^^^
@ -70,12 +70,12 @@ class TrainerDataLoadingMixin(ABC):
    num_training_batches: Union[int, float]
    val_check_batch: ...
    val_dataloaders: List[DataLoader]
    num_val_batches: Union[int, float]
    num_val_batches: List[Union[int, float]]
    test_dataloaders: List[DataLoader]
    num_test_batches: Union[int, float]
    num_test_batches: List[Union[int, float]]
    train_percent_check: float
    val_percent_check: float
    test_percent_check: float
    limit_val_batches: float
    limit_test_batches: float
    replace_sampler_ddp: bool
    num_nodes: int
    num_processes: int
@ -85,11 +85,16 @@ class TrainerDataLoadingMixin(ABC):
    def is_overridden(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    def _percent_range_check(self, name: str) -> None:
    def _limit_eval_batches_check(self, name: str) -> None:
        value = getattr(self, name)

        # ints are fine
        if isinstance(value, int):
            return

        msg = f'`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}.'
        if name == 'val_check_interval':
            msg += ' If you want to disable validation set `val_percent_check` to 0.0 instead.'
            msg += ' If you want to disable validation set `limit_val_batches` to 0.0 instead.'

        if not 0. <= value <= 1.:
            raise ValueError(msg)
@ -139,15 +144,21 @@ class TrainerDataLoadingMixin(ABC):
                    ' distributed training. Either remove the sampler from your DataLoader or set'
                    ' `replace_sampler_ddp`=False if you want to use your custom sampler.')

            skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
            # replace with distributed sampler
            sampler = self._get_distributed_sampler(dataloader)
            dataloader = self.replace_sampler(dataloader, sampler)

            dl_args = {
                k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
            }
        return dataloader

            dl_args['sampler'] = self._get_distributed_sampler(dataloader)
            dataloader = type(dataloader)(**dl_args)
    def replace_sampler(self, dataloader, sampler):
        skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']

        dl_args = {
            k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
        }

        dl_args['sampler'] = sampler
        dataloader = type(dataloader)(**dl_args)
        return dataloader

    def _get_distributed_sampler(self, dataloader):
@ -182,14 +193,17 @@ class TrainerDataLoadingMixin(ABC):
        self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

        self._worker_check(self.train_dataloader, 'train dataloader')
        self._percent_range_check('train_percent_check')
        self._limit_eval_batches_check('train_percent_check')

        if not _has_len(self.train_dataloader):
            self.num_training_batches = float('inf')
        else:
            # try getting the length
            self.num_training_batches = len(self.train_dataloader)
            self.num_training_batches = int(self.num_training_batches * self.train_percent_check)
            if isinstance(self.train_percent_check, float):
                self.num_training_batches = len(self.train_dataloader)
                self.num_training_batches = int(self.num_training_batches * self.train_percent_check)
            else:
                self.num_training_batches = self.train_percent_check

        # determine when to check validation
        # if int passed in, val checks that often
@ -200,7 +214,7 @@ class TrainerDataLoadingMixin(ABC):
                raise ValueError(
                    f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                    f'to the number of the training batches ({self.num_training_batches}). '
                    'If you want to disable validation set `val_percent_check` to 0.0 instead.')
                    'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
        else:
            if not _has_len(self.train_dataloader):
                if self.val_check_interval == 1.0:
@ -212,12 +226,14 @@ class TrainerDataLoadingMixin(ABC):
                        ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                        ' checking validation every k training batches.')
            else:
                self._percent_range_check('val_check_interval')
                self._limit_eval_batches_check('val_check_interval')

                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
                self.val_check_batch = max(1, self.val_check_batch)

    def _reset_eval_dataloader(self, model: LightningModule, mode: str) -> Tuple[Union[int, float], List[DataLoader]]:
    def _reset_eval_dataloader(self,
                               model: LightningModule,
                               mode: str) -> Tuple[List[Union[int, float]], List[DataLoader]]:
        """Generic method to reset a dataloader for evaluation.

        Args:
@ -227,17 +243,32 @@ class TrainerDataLoadingMixin(ABC):
        Returns:
            Tuple (num_batches, dataloaders)
        """
        dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader'))
        # use the training loader as val and test when overfitting
        if self.overfit_batches > 0:
            dataloaders = self.request_dataloader(getattr(model, 'train_dataloader'))
        else:
            dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader'))

        if not isinstance(dataloaders, list):
            dataloaders = [dataloaders]

        # shuffling in val and test set is bad practice
        for loader in dataloaders:
        for loader_i in range(len(dataloaders)):
            loader = dataloaders[loader_i]

            # shuffling in val and test set is bad practice
            if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):
                rank_zero_warn(
                    f'Your {mode}_dataloader has shuffle=True, it is best practice to turn'
                    ' this off for validation and test dataloaders.')

                # when overfitting, the dataloader should not have sampler
                if self.overfit_batches > 0:
                    m = 'You requested to overfit but enabled training dataloader shuffling. ' \
                        'We are turning it off for you.'
                    rank_zero_warn(m)
                    dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))

                else:
                    rank_zero_warn(
                        f'Your {mode}_dataloader has shuffle=True, it is best practice to turn'
                        ' this off for validation and test dataloaders.')

        if any([dl is None for dl in dataloaders]):
            rank_zero_warn("One of given dataloaders is None and it will be skipped.")
@ -245,29 +276,47 @@ class TrainerDataLoadingMixin(ABC):
        # add samplers
        dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl is not None]

        num_batches = 0
        loader_num_batches = []

        # determine number of batches
        # datasets could be none, 1 or 2+
        if len(dataloaders) != 0:
            for i, dataloader in enumerate(dataloaders):
                num_batches = 0
                self._worker_check(dataloader, f'{mode} dataloader {i}')
                if not _has_len(dataloader):
                    num_batches = float('inf')

                percent_check = getattr(self, f'{mode}_percent_check')
                # percent or num_steps
                limit_eval_batches = getattr(self, f'limit_{mode}_batches')

            if num_batches != float('inf'):
                self._percent_range_check(f'{mode}_percent_check')
                if num_batches != float('inf'):
                    self._limit_eval_batches_check(f'limit_{mode}_batches')

                num_batches = sum(len(dataloader) for dataloader in dataloaders)
                num_batches = int(num_batches * percent_check)
            elif percent_check not in (0.0, 1.0):
                raise MisconfigurationException(
                    'When using an infinite DataLoader (e.g. with an IterableDataset'
                    f' or when DataLoader does not implement `__len__`) for `{mode}_dataloader`,'
                    f' `Trainer({mode}_percent_check)` must be `0.0` or `1.0`.')
        return num_batches, dataloaders
                    num_batches = len(dataloader)

                    # limit num batches either as a percent or num steps
                    if isinstance(limit_eval_batches, float):
                        num_batches = int(num_batches * limit_eval_batches)
                    else:
                        num_batches = limit_eval_batches

                elif limit_eval_batches not in (0.0, 1.0):
                    raise MisconfigurationException(
                        'When using an infinite DataLoader (e.g. with an IterableDataset'
                        f' or when DataLoader does not implement `__len__`) for `limit_{mode}_batches`,'
                        f' `Trainer(limit_{mode}_batches)` must be `0.0` or `1.0`.')

                if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                    min_pct = 1.0 / len(dataloader)
                    m = f'You requested to check {limit_eval_batches} of the {mode} dataloader but ' \
                        f'{limit_eval_batches}*{num_batches} = 0. Please increase the limit_{mode}_batches. ' \
                        f'Try at least limit_{mode}_batches={min_pct}.'
                    raise MisconfigurationException(m)

                loader_num_batches.append(num_batches)

        return loader_num_batches, dataloaders

    def reset_val_dataloader(self, model: LightningModule) -> None:
        """Resets the validation dataloader and determines the number of batches.
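The float-vs-int semantics implemented above reduce to a small rule; a hedged sketch
(`_num_batches` is a hypothetical helper for illustration, not the library API, and the
range/infinite-loader validation is elided):

.. code-block:: python

    from typing import Union

    def _num_batches(total_batches: int, limit: Union[int, float]) -> int:
        # float => fraction of the dataloader, int => absolute number of batches
        if isinstance(limit, float):
            return int(total_batches * limit)
        return limit

    assert _num_batches(100, 0.25) == 25  # percent of the loader
    assert _num_batches(100, 10) == 10    # fixed number of batches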
@ -316,18 +365,19 @@ class TrainerDataLoadingMixin(ABC):

        return dataloader

    def determine_data_use_amount(self, train_percent_check: float, val_percent_check: float,
                                  test_percent_check: float, overfit_pct: float) -> None:
    def determine_data_use_amount(self, train_percent_check: float, limit_val_batches: Union[int, float],
                                  limit_test_batches: Union[int, float], overfit_batches: float) -> None:
        """Use less data for debugging purposes
        """
        self.train_percent_check = train_percent_check
        self.val_percent_check = val_percent_check
        self.test_percent_check = test_percent_check
        if overfit_pct > 0:
            if overfit_pct > 1:
        self.limit_val_batches = limit_val_batches
        self.limit_test_batches = limit_test_batches
        if overfit_batches > 0:
            if isinstance(overfit_batches, float) and overfit_batches > 1:
                raise ValueError(
                    f'`overfit_pct` must be not greater than 1.0, but got {overfit_pct:.3f}.')
                    f'`overfit_batches`, when given as a percentage, must '
                    f'lie in the range 0.0 < x <= 1.0, but got {overfit_batches:.3f}.')

            self.train_percent_check = overfit_pct
            self.val_percent_check = overfit_pct
            self.test_percent_check = overfit_pct
            self.train_percent_check = overfit_batches
            self.limit_val_batches = overfit_batches
            self.limit_test_batches = overfit_batches
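To make the precedence concrete: a positive `overfit_batches` overwrites the individual limits.
A hedged sketch based on `determine_data_use_amount` above (exact Trainer wiring may differ):

.. code-block:: python

    trainer = Trainer(overfit_batches=0.05, limit_val_batches=0.5)
    # overfit_batches > 0 wins: both eval limits now mirror it
    assert trainer.limit_val_batches == 0.05
    assert trainer.limit_test_batches == 0.05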
@ -24,30 +24,30 @@ Set how much of the validation set to check

If you don't want to check 100% of the validation set (for debugging or if it's huge), set this flag

val_percent_check will be overwritten by overfit_pct if `overfit_pct > 0`
limit_val_batches will be overwritten by overfit_batches if `overfit_batches > 0`

.. code-block:: python

    # DEFAULT
    trainer = Trainer(val_percent_check=1.0)
    trainer = Trainer(limit_val_batches=1.0)

    # check 10% only
    trainer = Trainer(val_percent_check=0.1)
    trainer = Trainer(limit_val_batches=0.1)

Set how much of the test set to check
-------------------------------------

If you don't want to check 100% of the test set (for debugging or if it's huge), set this flag

test_percent_check will be overwritten by overfit_pct if `overfit_pct > 0`
limit_test_batches will be overwritten by overfit_batches if `overfit_batches > 0`

.. code-block:: python

    # DEFAULT
    trainer = Trainer(test_percent_check=1.0)
    trainer = Trainer(limit_test_batches=1.0)

    # check 10% only
    trainer = Trainer(test_percent_check=0.1)
    trainer = Trainer(limit_test_batches=0.1)

Set validation check frequency within 1 training epoch
------------------------------------------------------
@ -124,7 +124,7 @@ In this second case, the options you pass to trainer will be used when running

from abc import ABC, abstractmethod
from pprint import pprint
from typing import Callable, Optional
from typing import Callable, Optional, List

import torch
from torch.utils.data import DataLoader
@ -162,7 +162,7 @@ class TrainerEvaluationLoopMixin(ABC):
    single_gpu: bool
    data_parallel_device_ids: ...
    model: LightningModule
    num_test_batches: int
    num_test_batches: List[int]
    num_val_batches: int
    fast_dev_run: ...
    process_output: ...
@ -222,13 +222,13 @@ class TrainerEvaluationLoopMixin(ABC):
    def reset_val_dataloader(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_mode: bool = False):
    def _evaluate(self, model: LightningModule, dataloaders, max_batches: List[int], test_mode: bool = False):
        """Run evaluation code.

        Args:
            model: PT model
            dataloaders: list of PT dataloaders
            max_batches: Scalar
            max_batches: List of scalars
            test_mode:
        """
        # enable eval mode
@ -254,12 +254,15 @@ class TrainerEvaluationLoopMixin(ABC):
                dataloader = xla_pl.ParallelLoader(dataloader, [device])
                dataloader = dataloader.per_device_loader(device)

            # each dataloader has a max num batches
            dl_max_batches = max_batches[dataloader_idx]

            for batch_idx, batch in enumerate(dataloader):
                if batch is None:
                    continue

                # stop short when on fast_dev_run (sets max_batch=1)
                if batch_idx >= max_batches:
                if batch_idx >= dl_max_batches:
                    break

                # callbacks
@ -359,7 +362,7 @@ class TrainerEvaluationLoopMixin(ABC):

        # cap max batches to 1 when using fast_dev_run
        if self.fast_dev_run:
            max_batches = 1
            max_batches = [1]

        # Validation/Test begin callbacks
        if test_mode:
@ -367,6 +370,11 @@ class TrainerEvaluationLoopMixin(ABC):
        else:
            self.on_validation_start()

        # enable disabling validation step with limit_val_batches = 0
        should_skip = sum(max_batches) == 0
        if should_skip:
            return

        # run evaluation
        eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)
        _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results)
@ -93,7 +93,7 @@ class Trainer(
        tpu_cores: Optional[Union[List[int], int]] = None,
        log_gpu_memory: Optional[str] = None,
        progress_bar_refresh_rate: int = 1,
        overfit_pct: float = 0.0,
        overfit_batches: Union[int, float] = 0.0,
        track_grad_norm: Union[int, float, str] = -1,
        check_val_every_n_epoch: int = 1,
        fast_dev_run: bool = False,
@ -103,8 +103,8 @@ class Trainer(
        max_steps: Optional[int] = None,
        min_steps: Optional[int] = None,
        train_percent_check: float = 1.0,
        val_percent_check: float = 1.0,
        test_percent_check: float = 1.0,
        limit_val_batches: Union[int, float] = 1.0,
        limit_test_batches: Union[int, float] = 1.0,
        val_check_interval: float = 1.0,
        log_save_interval: int = 100,
        row_log_interval: int = 50,
@ -129,6 +129,9 @@ class Trainer(
        num_tpu_cores: Optional[int] = None,  # backward compatible, todo: remove in v0.9.0
        use_amp=None,  # backward compatible, todo: remove in v0.9.0
        show_progress_bar=None,  # backward compatible, todo: remove in v0.9.0
        val_percent_check: float = 1.0,  # backward compatible, todo: remove in v1.0.0
        test_percent_check: float = 1.0,  # backward compatible, todo: remove in v1.0.0
        overfit_pct: float = 0.0  # backward compatible, todo: remove in v1.0.0
    ):
        r"""
@ -185,7 +188,12 @@ class Trainer(
        progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
            Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.

        overfit_pct: How much of training-, validation-, and test dataset to check.
        overfit_batches: Overfit a percent of training data (float) or a set number of batches (int).

        overfit_pct:
            .. warning:: .. deprecated:: 0.8.0

                Use `overfit_batches` instead. Will be removed in 1.0.0.

        track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.
@ -215,9 +223,19 @@ class Trainer(

        train_percent_check: How much of training dataset to check.

        val_percent_check: How much of validation dataset to check.
        limit_val_batches: How much of the validation dataset to check (float = percent, int = num_batches)

        test_percent_check: How much of test dataset to check.
        limit_test_batches: How much of the test dataset to check (float = percent, int = num_batches)

        val_percent_check:
            .. warning:: .. deprecated:: 0.8.0

                Use `limit_val_batches` instead. Will be removed in 1.0.0.

        test_percent_check:
            .. warning:: .. deprecated:: 0.8.0

                Use `limit_test_batches` instead. Will be removed in 1.0.0.

        val_check_interval: How often within one training epoch to check the validation set
@ -383,9 +401,9 @@ class Trainer(
        self.batch_idx = 0
        self.progress_bar_metrics = {}
        self.callback_metrics = {}
        self.num_val_batches = 0
        self.num_val_batches = [0]
        self.num_training_batches = 0
        self.num_test_batches = 0
        self.num_test_batches = [0]
        self.train_dataloader = None
        self.test_dataloaders = None
        self.val_dataloaders = None
@ -468,9 +486,27 @@ class Trainer(
        self.row_log_interval = row_log_interval

        # how much of the data to use
        self.overfit_pct = overfit_pct
        self.determine_data_use_amount(train_percent_check, val_percent_check,
                                       test_percent_check, overfit_pct)
        # TODO: remove in 1.0.0
        if overfit_pct > 0:
            overfit_batches = overfit_pct

        # convert floats to ints
        overfit_batches = int(overfit_batches) if overfit_batches > 1.0 else overfit_batches
        self.overfit_batches = overfit_batches

        # TODO: remove in 1.0.0
        if val_percent_check < 1.0:
            limit_val_batches = val_percent_check

        if test_percent_check < 1.0:
            limit_test_batches = test_percent_check

        limit_test_batches = int(limit_test_batches) if limit_test_batches > 1.0 else limit_test_batches
        limit_val_batches = int(limit_val_batches) if limit_val_batches > 1.0 else limit_val_batches

        # TODO: convert train_percent_check to limit_train_batches
        self.determine_data_use_amount(train_percent_check, limit_val_batches,
                                       limit_test_batches, overfit_batches)

        # AMP init
        # These are the only lines needed after v0.8.0
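The float-to-int coercion above means a value greater than 1.0 is treated as a batch count.
For example (assuming a hypothetical 100-batch train loader):

.. code-block:: python

    Trainer(overfit_batches=0.1)   # overfit on 10% of the training batches
    Trainer(overfit_batches=10)    # overfit on exactly 10 batches
    Trainer(overfit_batches=10.0)  # floats > 1.0 are converted to int(10)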
@ -984,7 +1020,7 @@ class Trainer(
            return

        # check if we should run validation during training
        self.disable_validation = not (self.is_overridden('validation_step') and self.val_percent_check > 0) \
        self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \
            and not self.fast_dev_run

        # run tiny validation (if validation defined)
@ -996,9 +1032,11 @@ class Trainer(
            ref_model.on_sanity_check_start()
            self.on_sanity_check_start()

            num_loaders = len(self.val_dataloaders)
            max_batches = [self.num_sanity_val_steps] * num_loaders
            eval_results = self._evaluate(model,
                                          self.val_dataloaders,
                                          self.num_sanity_val_steps,
                                          max_batches,
                                          False)
            _, _, _, callback_metrics, _ = self.process_output(eval_results)
@ -68,7 +68,7 @@ Set how much of the training set to check

If you don't want to check 100% of the training set (for debugging or if it's huge), set this flag.

train_percent_check will be overwritten by overfit_pct if `overfit_pct > 0`
train_percent_check will be overwritten by overfit_batches if `overfit_batches > 0`

.. code-block:: python
@ -202,7 +202,7 @@ class TrainerTrainLoopMixin(ABC):
    check_val_every_n_epoch: ...
    num_training_batches: int
    val_check_batch: ...
    num_val_batches: int
    num_val_batches: List[int]
    disable_validation: bool
    fast_dev_run: ...
    accumulation_scheduler: ...
@ -24,7 +24,7 @@ def test_early_stopping_functionality(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=True,
        overfit_pct=0.20,
        overfit_batches=0.20,
        max_epochs=20,
    )
    result = trainer.fit(model)

@ -152,7 +152,7 @@ def test_trainer_callback_system(tmpdir):
    trainer_options = dict(
        callbacks=[test_callback],
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.2,
        progress_bar_refresh_rate=0,
    )

@ -258,7 +258,7 @@ def test_early_stopping_no_val_step(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=stopping,
        overfit_pct=0.20,
        overfit_batches=0.20,
        max_epochs=2,
    )
    result = trainer.fit(model)

@ -292,7 +292,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):

    trainer = Trainer(default_root_dir=tmpdir,
                      checkpoint_callback=checkpoint,
                      overfit_pct=0.20,
                      overfit_batches=0.20,
                      max_epochs=2
                      )
    trainer.fit(model)

@ -313,7 +313,7 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):

    trainer = Trainer(
        default_root_dir=tmpdir,
        overfit_pct=0.2,
        overfit_batches=0.2,
        max_epochs=2,
        logger=logger
    )
@ -17,7 +17,7 @@ def test_lr_logger_single_lr(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.5,
        callbacks=[lr_logger]
    )

@ -40,7 +40,7 @@ def test_lr_logger_no_lr(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.5,
        callbacks=[lr_logger]
    )

@ -61,7 +61,7 @@ def test_lr_logger_multi_lrs(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.5,
        callbacks=[lr_logger]
    )

@ -88,7 +88,7 @@ def test_lr_logger_param_groups(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.5,
        callbacks=[lr_logger]
    )
@ -20,7 +20,7 @@ def test_progress_bar_on(callbacks, refresh_rate):
        callbacks=callbacks,
        progress_bar_refresh_rate=refresh_rate,
        max_epochs=1,
        overfit_pct=0.2,
        overfit_batches=0.2,
    )

    progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]

@ -61,7 +61,7 @@ def test_progress_bar_totals():

    trainer = Trainer(
        progress_bar_refresh_rate=1,
        val_percent_check=1.0,
        limit_val_batches=1.0,
        max_epochs=1,
    )
    bar = trainer.progress_bar_callback
@ -53,7 +53,7 @@ def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
        max_epochs=1,
        logger=logger,
        train_percent_check=0.2,
        val_percent_check=0.5,
        limit_val_batches=0.5,
        fast_dev_run=True,
    )
    trainer.fit(model)

@ -145,7 +145,7 @@ def test_adding_step_key(tmpdir):
        max_epochs=3,
        default_root_dir=tmpdir,
        train_percent_check=0.001,
        val_percent_check=0.01,
        limit_val_batches=0.1,
        num_sanity_val_steps=0,
    )
    trainer.logger.log_metrics = _log_metrics_decorator(
@ -99,7 +99,7 @@ def test_cpu_model_with_amp(tmpdir):
        progress_bar_refresh_rate=0,
        max_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.4,
        limit_val_batches=0.4,
        precision=16
    )
@ -26,7 +26,7 @@ def test_cpu_slurm_save_load(tmpdir):
        max_epochs=1,
        logger=logger,
        train_percent_check=0.2,
        val_percent_check=0.2,
        limit_val_batches=0.2,
        checkpoint_callback=ModelCheckpoint(tmpdir)
    )
    result = trainer.fit(model)

@ -90,10 +90,10 @@ def test_early_stopping_cpu_model(tmpdir):
        early_stop_callback=stopping,
        max_epochs=2,
        gradient_clip_val=1.0,
        overfit_pct=0.20,
        overfit_batches=0.20,
        track_grad_norm=2,
        train_percent_check=0.1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
    )

    model = EvalModelTemplate()

@ -119,7 +119,7 @@ def test_multi_cpu_model_ddp(tmpdir):
        progress_bar_refresh_rate=0,
        max_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.2,
        limit_val_batches=0.2,
        gpus=None,
        num_processes=2,
        distributed_backend='ddp_cpu'

@ -137,7 +137,7 @@ def test_lbfgs_cpu_model(tmpdir):
        progress_bar_refresh_rate=0,
        weights_summary='top',
        train_percent_check=0.2,
        val_percent_check=0.2,
        limit_val_batches=0.2,
    )

    hparams = EvalModelTemplate.get_default_hparams()

@ -154,10 +154,10 @@ def test_default_logger_callbacks_cpu_model(tmpdir):
        default_root_dir=tmpdir,
        max_epochs=1,
        gradient_clip_val=1.0,
        overfit_pct=0.20,
        overfit_batches=0.20,
        progress_bar_refresh_rate=0,
        train_percent_check=0.01,
        val_percent_check=0.01,
        limit_val_batches=0.01,
    )

    model = EvalModelTemplate()

@ -184,8 +184,8 @@ def test_running_test_after_fitting(tmpdir):
        progress_bar_refresh_rate=0,
        max_epochs=2,
        train_percent_check=0.4,
        val_percent_check=0.2,
        test_percent_check=0.2,
        limit_val_batches=0.2,
        limit_test_batches=0.2,
        checkpoint_callback=checkpoint,
        logger=logger
    )

@ -214,8 +214,8 @@ def test_running_test_no_val(tmpdir):
        progress_bar_refresh_rate=0,
        max_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.2,
        test_percent_check=0.2,
        limit_val_batches=0.2,
        limit_test_batches=0.2,
        checkpoint_callback=checkpoint,
        logger=logger,
        early_stop_callback=False

@ -238,7 +238,7 @@ def test_simple_cpu(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.1,
    )
    result = trainer.fit(model)

@ -254,7 +254,7 @@ def test_cpu_model(tmpdir):
        progress_bar_refresh_rate=0,
        max_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.4
        limit_val_batches=0.4
    )

    model = EvalModelTemplate()

@ -267,13 +267,13 @@ def test_all_features_cpu_model(tmpdir):
    trainer_options = dict(
        default_root_dir=tmpdir,
        gradient_clip_val=1.0,
        overfit_pct=0.20,
        overfit_batches=0.20,
        track_grad_norm=2,
        progress_bar_refresh_rate=0,
        accumulate_grad_batches=2,
        max_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.4
        limit_val_batches=0.4
    )

    model = EvalModelTemplate()

@ -342,7 +342,7 @@ def test_tbptt_cpu_model(tmpdir):
        default_root_dir=tmpdir,
        max_epochs=1,
        truncated_bptt_steps=truncated_bptt_steps,
        val_percent_check=0,
        limit_val_batches=0,
        weights_summary=None,
        early_stop_callback=False
    )
@ -22,7 +22,7 @@ def test_single_gpu_model(tmpdir, gpus):
        progress_bar_refresh_rate=0,
        max_epochs=1,
        train_percent_check=0.1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        gpus=gpus
    )

@ -41,7 +41,7 @@ def test_multi_gpu_model(tmpdir, backend):
        default_root_dir=tmpdir,
        max_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.2,
        limit_val_batches=0.2,
        gpus=[0, 1],
        distributed_backend=backend,
    )

@ -66,7 +66,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
        progress_bar_refresh_rate=0,
        max_epochs=1,
        train_percent_check=0.1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        gpus=[0, 1],
        distributed_backend='ddp')

@ -88,7 +88,7 @@ def test_multi_gpu_none_backend(tmpdir):
        progress_bar_refresh_rate=0,
        max_epochs=1,
        train_percent_check=0.1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        gpus='-1'
    )
@ -57,7 +57,7 @@ def test_training_epoch_end_metrics_collection(tmpdir):
    trainer = Trainer(
        max_epochs=num_epochs,
        default_root_dir=tmpdir,
        overfit_pct=0.1,
        overfit_batches=0.1,
    )
    result = trainer.fit(model)
    assert result == 1
@ -62,7 +62,7 @@ def test_horovod_cpu(tmpdir):
        progress_bar_refresh_rate=0,
        max_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.2,
        limit_val_batches=0.2,
        distributed_backend='horovod',
        deterministic=True,
    )

@ -79,7 +79,7 @@ def test_horovod_cpu_implicit(tmpdir):
        progress_bar_refresh_rate=0,
        max_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.2,
        limit_val_batches=0.2,
        deterministic=True,
    )
    _run_horovod(trainer_options)

@ -97,7 +97,7 @@ def test_horovod_multi_gpu(tmpdir):
        progress_bar_refresh_rate=0,
        max_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.2,
        limit_val_batches=0.2,
        gpus=1,
        deterministic=True,
        distributed_backend='horovod'

@ -132,7 +132,7 @@ def test_horovod_transfer_batch_to_gpu(tmpdir):
        progress_bar_refresh_rate=0,
        max_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.2,
        limit_val_batches=0.2,
        gpus=1,
        deterministic=True,
        distributed_backend='horovod'

@ -150,7 +150,7 @@ def test_horovod_multi_optimizer(tmpdir):
        progress_bar_refresh_rate=0,
        max_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.2,
        limit_val_batches=0.2,
        deterministic=True,
        distributed_backend='horovod'
    )
@ -39,7 +39,7 @@ def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False):
    assert model.hparams.test_arg == 14

    # verify we can train
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_pct=0.5)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=0.5)
    trainer.fit(model)

    # make sure the raw checkpoint saved the properties

@ -156,7 +156,7 @@ def test_explicit_missing_args_hparams(tmpdir):
    assert model.hparams.test_arg == 14

    # verify we can train
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
    trainer.fit(model)

    # make sure the raw checkpoint saved the properties

@ -266,7 +266,7 @@ def test_collect_init_arguments(tmpdir, cls):
    assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss)

    # verify that the checkpoint saved the correct values
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
    trainer.fit(model)
    raw_checkpoint_path = _raw_checkpoint_path(trainer)

@ -349,7 +349,7 @@ def test_collect_init_arguments_with_local_vars(cls):
    # assert model.hparams.my_arg == 42
    #
    # # verify that the checkpoint saved the correct values
    # trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
    # trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
    # trainer.fit(model)
    #
    # # verify that model loads correctly
@ -32,7 +32,7 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend):
        progress_bar_refresh_rate=0,
        max_epochs=2,
        train_percent_check=0.4,
        val_percent_check=0.2,
        limit_val_batches=0.2,
        checkpoint_callback=checkpoint,
        logger=logger,
        gpus=[0, 1],

@ -80,7 +80,7 @@ def test_running_test_pretrained_model_cpu(tmpdir):
        progress_bar_refresh_rate=0,
        max_epochs=3,
        train_percent_check=0.4,
        val_percent_check=0.2,
        limit_val_batches=0.2,
        checkpoint_callback=checkpoint,
        logger=logger
    )

@ -111,7 +111,7 @@ def test_load_model_from_checkpoint(tmpdir):
        progress_bar_refresh_rate=0,
        max_epochs=2,
        train_percent_check=0.4,
        val_percent_check=0.2,
        limit_val_batches=0.2,
        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
        default_root_dir=tmpdir,
    )

@ -188,7 +188,7 @@ def test_dp_resume(tmpdir):
    trainer_options['logger'] = new_logger
    trainer_options['checkpoint_callback'] = ModelCheckpoint(tmpdir)
    trainer_options['train_percent_check'] = 0.5
    trainer_options['val_percent_check'] = 0.2
    trainer_options['limit_val_batches'] = 0.2
    trainer_options['max_epochs'] = 1
    new_trainer = Trainer(**trainer_options)
@ -103,7 +103,7 @@ def test_tbd_remove_in_v1_0_0_model_hooks():
    with pytest.deprecated_call(match='v1.0'):
        trainer = Trainer(logger=False)
        # TODO: why `dataloder` is required if it is not used
        result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
        result = trainer._evaluate(model, dataloaders=[[None]], max_batches=[1])
        assert result == {'val_loss': torch.tensor(0.6)}

    model = ModelVer0_7(hparams)

@ -116,5 +116,5 @@ def test_tbd_remove_in_v1_0_0_model_hooks():
    with pytest.deprecated_call(match='v1.0'):
        trainer = Trainer(logger=False)
        # TODO: why `dataloder` is required if it is not used
        result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
        result = trainer._evaluate(model, dataloaders=[[None]], max_batches=[1])
        assert result == {'val_loss': torch.tensor(0.7)}
@ -80,7 +80,7 @@ def test_multiple_val_dataloader(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=1.0,
    )
    result = trainer.fit(model)

@ -116,7 +116,7 @@ def test_multiple_test_dataloader(tmpdir, ckpt_path):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.2
    )
    trainer.fit(model)

@ -144,7 +144,7 @@ def test_train_dataloader_passed_to_fit(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.2
    )
    fit_options = dict(train_dataloader=model.dataloader(train=True))

@ -161,7 +161,7 @@ def test_train_val_dataloaders_passed_to_fit(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.2
    )
    fit_options = dict(train_dataloader=model.dataloader(train=True),

@ -183,7 +183,7 @@ def test_all_dataloaders_passed_to_fit(tmpdir, ckpt_path):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.2
    )
    fit_options = dict(train_dataloader=model.dataloader(train=True),

@ -216,7 +216,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.2
    )
    fit_options = dict(train_dataloader=model.dataloader(train=True),

@ -245,7 +245,7 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path):
    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.2
    )

@ -286,7 +286,7 @@ def test_val_inf_dataloader_error(tmpdir):
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__infinite

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_percent_check=0.5)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
        trainer.fit(model)

@ -298,7 +298,7 @@ def test_test_inf_dataloader_error(tmpdir):
    model = EvalModelTemplate()
    model.test_dataloader = model.test_dataloader__infinite

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, test_percent_check=0.5)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)

    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
        trainer.test(model)

@ -354,8 +354,8 @@ def test_error_on_zero_len_dataloader(tmpdir):
        default_root_dir=tmpdir,
        max_epochs=1,
        train_percent_check=0.1,
        val_percent_check=0.1,
        test_percent_check=0.1
        limit_val_batches=0.1,
        limit_test_batches=0.1
    )
    trainer.fit(model)

@ -371,7 +371,7 @@ def test_warning_with_few_workers(tmpdir, ckpt_path):
    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.2
    )

@ -489,7 +489,7 @@ def test_batch_size_smaller_than_num_gpus():
    trainer = Trainer(
        max_epochs=1,
        train_percent_check=0.1,
        val_percent_check=0,
        limit_val_batches=0,
        gpus=num_gpus,
    )
@ -16,7 +16,7 @@ def test_optimizer_with_scheduling(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.2
    )
    results = trainer.fit(model)

@ -47,7 +47,7 @@ def test_multi_optimizer_with_scheduling(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.2
    )
    results = trainer.fit(model)

@ -82,7 +82,7 @@ def test_multi_optimizer_with_scheduling_stepping(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.2
    )
    results = trainer.fit(model)

@ -121,7 +121,7 @@ def test_reduce_lr_on_plateau_scheduling(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.2
    )
    results = trainer.fit(model)

@ -211,7 +211,7 @@ def test_none_optimizer(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.2
    )
    result = trainer.fit(model)
@ -163,7 +163,7 @@ def test_gradient_accumulation_scheduling(tmpdir):

    trainer = Trainer(accumulate_grad_batches=schedule,
                      train_percent_check=0.1,
                      val_percent_check=0.1,
                      limit_val_batches=0.1,
                      max_epochs=2,
                      default_root_dir=tmpdir)

@ -368,7 +368,7 @@ def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_serve
        progress_bar_refresh_rate=0,
        max_epochs=2,
        train_percent_check=0.65,
        val_percent_check=1,
        limit_val_batches=1,
        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
        default_root_dir=tmpdir,
        early_stop_callback=False,

@ -588,7 +588,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):


def test_disabled_validation():
    """Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`."""
    """Verify that `limit_val_batches=0` disables the validation loop unless `fast_dev_run=True`."""

    class CurrentModel(EvalModelTemplate):

@ -610,22 +610,22 @@ def test_disabled_validation():
        progress_bar_refresh_rate=0,
        max_epochs=2,
        train_percent_check=0.4,
        val_percent_check=0.0,
        limit_val_batches=0.0,
        fast_dev_run=False,
    )

    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # check that val_percent_check=0 turns off validation
    # check that limit_val_batches=0 turns off validation
    assert result == 1, 'training failed to complete'
    assert trainer.current_epoch == 2
    assert not model.validation_step_invoked, \
        '`validation_step` should not run when `val_percent_check=0`'
        '`validation_step` should not run when `limit_val_batches=0`'
    assert not model.validation_epoch_end_invoked, \
        '`validation_epoch_end` should not run when `val_percent_check=0`'
        '`validation_epoch_end` should not run when `limit_val_batches=0`'

    # check that val_percent_check has no influence when fast_dev_run is turned on
    # check that limit_val_batches has no influence when fast_dev_run is turned on
    model = CurrentModel(**hparams)
    trainer_options.update(fast_dev_run=True)
    trainer = Trainer(**trainer_options)

@ -722,7 +722,7 @@ def test_trainer_interrupted_flag(tmpdir):
    trainer = Trainer(
        callbacks=[interrupt_callback, handle_interrupt_callback],
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.2,
        progress_bar_refresh_rate=0,
        logger=False,
@ -5,6 +5,87 @@ import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from torch.utils.data import RandomSampler, SequentialSampler, DataLoader


def test_overfit(tmpdir):

    # ------------------------------------------------------
    # Make sure shuffle is correct across loaders initially
    # ------------------------------------------------------
    model = EvalModelTemplate()
    model.train_dataloader()

    # original train loader which should be replaced in all methods
    train_loader = model.train_dataloader()

    # make sure the val and test sets are not shuffled
    assert isinstance(train_loader.sampler, RandomSampler)
    assert isinstance(model.val_dataloader().sampler, SequentialSampler)
    assert isinstance(model.test_dataloader().sampler, SequentialSampler)

    # ------------------------------------------------------
    # get the training loader and batch
    # ------------------------------------------------------
    train_loader = DataLoader(model.train_dataloader().dataset, shuffle=False)
    full_train_samples = len(train_loader)
    num_train_samples = int(0.11 * full_train_samples)

    (xa, ya) = next(iter(train_loader))

    # ------------------------------------------------------
    # set VAL and Test loaders
    # ------------------------------------------------------
    val_loader = DataLoader(model.val_dataloader().dataset, shuffle=False)
    test_loader = DataLoader(model.test_dataloader().dataset, shuffle=False)

    # set the model loaders
    model.train_dataloader = lambda: train_loader
    model.val_dataloader = lambda: val_loader
    model.test_dataloader = lambda: test_loader

    # ------------------------------------------------------
    # run tests for both val and test
    # ------------------------------------------------------
    for split in ['val', 'test']:

        # ------------------------------------------------------
        # test overfit_batches as percent
        # ------------------------------------------------------
        loader_num_batches, dataloaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(model, split)
        assert loader_num_batches[0] == num_train_samples

        # make sure we turned off shuffle for the user
        assert isinstance(dataloaders[0].sampler, SequentialSampler)

        # make sure the loaders are the same
        (xb, yb) = next(iter(dataloaders[0]))
        assert torch.eq(xa, xb).all()
        assert torch.eq(ya, yb).all()

        # ------------------------------------------------------
        # test overfit_batches as int
        # ------------------------------------------------------
        loader_num_batches, dataloaders = Trainer(overfit_batches=1)._reset_eval_dataloader(model, split)
        assert loader_num_batches[0] == 1
        loader_num_batches, dataloaders = Trainer(overfit_batches=5)._reset_eval_dataloader(model, split)
        assert loader_num_batches[0] == 5

        # ------------------------------------------------------
        # test limit_xxx_batches as percent AND int
        # ------------------------------------------------------
        if split == 'val':
            loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(model, split)
            assert loader_num_batches[0] == int(0.1 * len(val_loader))

            loader_num_batches, dataloaders = Trainer(limit_val_batches=10)._reset_eval_dataloader(model, split)
            assert loader_num_batches[0] == 10
        else:
            loader_num_batches, dataloaders = Trainer(limit_test_batches=0.1)._reset_eval_dataloader(model, split)
            assert loader_num_batches[0] == int(0.1 * len(test_loader))

            loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(model, split)
            assert loader_num_batches[0] == 10


def test_model_reset_correctly(tmpdir):
@ -120,7 +201,7 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        limit_val_batches=0.1,
        train_percent_check=0.2,
        auto_scale_batch_size='power'
    )