From 04c794ca72f45dd492e49baff773c2b8e274b9ed Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 17 Jun 2020 08:03:28 -0400 Subject: [PATCH] [WIP] Rename overfit_pct to overfit_batches (and fix) and val_percent_check and test_percent_check (and fix) (#2213) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fixed percent check for val/test * fixed percent check for val/test * fixed percent check for val/test * fixed percent check for val/test * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * overfit_pct now uses train loaders for val and test and does not shuffle * add on fit_start on fit_end hooks * add on fit_start on fit_end hooks * add on fit_start on fit_end hooks Co-authored-by: Adrian Wälchli Co-authored-by: Rohit Gupta Co-authored-by: Jirka Borovec --- CHANGELOG.md | 5 + docs/source/debugging.rst | 11 +- docs/source/fast_training.rst | 12 +- pytorch_lightning/callbacks/progress.py | 2 + pytorch_lightning/trainer/__init__.py | 133 +++++++++-------- pytorch_lightning/trainer/data_loading.py | 142 +++++++++++++------ pytorch_lightning/trainer/evaluation_loop.py | 32 +++-- pytorch_lightning/trainer/trainer.py | 64 +++++++-- pytorch_lightning/trainer/training_loop.py | 4 +- tests/callbacks/test_callbacks.py | 10 +- tests/callbacks/test_lr.py | 8 +- tests/callbacks/test_progress_bar.py | 4 +- tests/loggers/test_all.py | 2 +- tests/loggers/test_base.py | 2 +- tests/models/test_amp.py | 2 +- tests/models/test_cpu.py | 32 ++--- tests/models/test_gpu.py | 8 +- tests/models/test_hooks.py | 2 +- tests/models/test_horovod.py | 10 +- tests/models/test_hparams.py | 8 +- tests/models/test_restore.py | 8 +- tests/test_deprecated.py | 4 +- tests/trainer/test_dataloaders.py | 26 ++-- tests/trainer/test_optimizers.py | 10 +- tests/trainer/test_trainer.py | 18 +-- tests/trainer/test_trainer_tricks.py | 83 ++++++++++- 26 files changed, 425 insertions(+), 217 deletions(-) diff --git 
a/CHANGELOG.md b/CHANGELOG.md index 7294bef613..fd9e72d36d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added overfit_batches, limit_xxx_batches flags (overfit now uses training set for all three) ([#2213](https://github.com/PyTorchLightning/pytorch-lightning/pull/2213)) +- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877)) +- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327)) +- Added Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488)) - Added metrics * Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877)) * Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327)) @@ -54,6 +58,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated `overfit_pct`, `val_percent_check`, `test_percent_check` ([#2213](https://github.com/PyTorchLightning/pytorch-lightning/pull/2213)) - Deprecated `ModelCheckpoint`'s attributes `best` and `kth_best_model` ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799)) - Dropped official support/testing for older PyTorch versions <1.3 ([#1917](https://github.com/PyTorchLightning/pytorch-lightning/pull/1917)) diff --git a/docs/source/debugging.rst b/docs/source/debugging.rst index 741f94c524..06807eeb17 100644 --- a/docs/source/debugging.rst +++ b/docs/source/debugging.rst @@ -48,12 +48,19 @@ Make model overfit on subset of data A good debugging technique is to take a tiny portion of your data (say 2 samples per class), and try to get your model to overfit. If it can't, it's a sign it won't work with large datasets. -(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.overfit_pct` +(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.overfit_batches` argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`) .. testcode:: - trainer = Trainer(overfit_pct=0.01) + # use only 1% of training data (and use the same training Dataloader (with shuffle off) in val and test) + trainer = Trainer(overfit_batches=0.01) + + # or overfit a number of batches + trainer = Trainer(overfit_batches=10) + +With this flag, the train, val, and test sets will all be the same train set. We will also replace the sampler +in the training set to turn off shuffle for you. Print a summary of your LightningModule --------------------------------------- diff --git a/docs/source/fast_training.rst b/docs/source/fast_training.rst index 208838f58b..9b44eb0955 100644 --- a/docs/source/fast_training.rst +++ b/docs/source/fast_training.rst @@ -56,17 +56,17 @@ If you don't want to check 100% of the training/validation/test set (for debuggi # DEFAULT trainer = Trainer( train_percent_check=1.0, - val_percent_check=1.0, - test_percent_check=1.0 + limit_val_batches=1.0, + limit_test_batches=1.0 ) # check 10%, 20%, 30% only, respectively for training, validation and test set trainer = Trainer( train_percent_check=0.1, - val_percent_check=0.2, - test_percent_check=0.3 + limit_val_batches=0.2, + limit_test_batches=0.3 ) -..
note:: ``train_percent_check``, ``val_percent_check`` and ``test_percent_check`` will be overwritten by ``overfit_pct`` if ``overfit_pct`` > 0. ``val_percent_check`` will be ignored if ``fast_dev_run=True``. +.. note:: ``train_percent_check``, ``limit_val_batches`` and ``limit_test_batches`` will be overwritten by ``overfit_batches`` if ``overfit_batches`` > 0. ``limit_val_batches`` will be ignored if ``fast_dev_run=True``. -.. note:: If you set ``val_percent_check=0``, validation will be disabled. +.. note:: If you set ``limit_val_batches=0``, validation will be disabled. diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index fb3dafb6e6..df4152dd36 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -98,6 +98,7 @@ class ProgressBarBase(Callback): elif not self.trainer.disable_validation: is_val_epoch = trainer.current_epoch % trainer.check_val_every_n_epoch == 0 total_val_batches = trainer.num_val_batches if is_val_epoch else 0 + total_val_batches = sum(total_val_batches) return total_val_batches @property @@ -111,6 +112,7 @@ class ProgressBarBase(Callback): total_test_batches = len(self.trainer.test_dataloaders) else: total_test_batches = self.trainer.num_test_batches + total_test_batches = sum(total_test_batches) return total_test_batches def disable(self): diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 357613b068..c011052224 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -433,6 +433,40 @@ Example:: # default used by the Trainer trainer = Trainer(gradient_clip_val=0.0) + +limit_test_batches +^^^^^^^^^^^^^^^^^^ + +How much of test dataset to check. + +Example:: + + # default used by the Trainer + trainer = Trainer(limit_test_batches=1.0) + + # run through only 25% of the test set each epoch + trainer = Trainer(limit_test_batches=0.25) + + # run for only 10 batches + trainer = Trainer(limit_test_batches=10) + +limit_val_batches +^^^^^^^^^^^^^^^^^ + +How much of validation dataset to check. +Useful when debugging or testing something that happens at the end of an epoch. + +Example:: + + # default used by the Trainer + trainer = Trainer(limit_val_batches=1.0) + + # run through only 25% of the validation set each epoch + trainer = Trainer(limit_val_batches=0.25) + + # run for only 10 batches + trainer = Trainer(limit_val_batches=10) + log_gpu_memory ^^^^^^^^^^^^^^ Options: @@ -652,29 +686,28 @@ Example:: overfit_pct ^^^^^^^^^^^ -Uses this much data of all datasets (training, validation, test). + +.. warning:: .. deprecated:: 0.8.0 + + Use `overfit_batches` instead. Will be removed in 1.0.0. + +overfit_batches +^^^^^^^^^^^^^^^ +Uses this much data of the training set. It will use the same training set for validation and testing. +If the training dataloader has shuffle=True, Lightning will automatically disable it. + Useful for quickly debugging or trying to overfit on purpose.
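The value is interpreted the same way as ``limit_val_batches``/``limit_test_batches``: a float selects a fraction of the available training batches, while an int selects an absolute number of batches. A minimal sketch of that interpretation (``resolve_num_batches`` is a hypothetical helper for illustration, not the Lightning internals)::

    def resolve_num_batches(limit, total_batches):
        # float -> fraction of the available batches, int -> absolute batch count
        if isinstance(limit, float):
            return int(total_batches * limit)
        return limit

    resolve_num_batches(0.01, 500)  # -> 5 batches
    resolve_num_batches(10, 500)    # -> 10 batches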
Example:: # default used by the Trainer - trainer = Trainer(overfit_pct=0.0) + trainer = Trainer(overfit_batches=0.0) - # use only 1% of the train, test, val datasets - trainer = Trainer(overfit_pct=0.01) - - # equivalent: - trainer = Trainer( - train_percent_check=0.01, - val_percent_check=0.01, - test_percent_check=0.01 - ) - -See Also: - - `train_percent_check`_ - - `val_percent_check`_ - - `test_percent_check`_ + # use only 1% of the train set (and use the train set for val and test) + trainer = Trainer(overfit_batches=0.01) + # overfit on 10 of the same batches + trainer = Trainer(overfit_batches=10) precision ^^^^^^^^^ @@ -829,39 +862,7 @@ show_progress_bar test_percent_check ^^^^^^^^^^^^^^^^^^ -How much of test dataset to check. - -Example:: - - # default used by the Trainer - trainer = Trainer(test_percent_check=1.0) - - # run through only 25% of the test set each epoch - trainer = Trainer(test_percent_check=0.25) - -val_check_interval -^^^^^^^^^^^^^^^^^^ - -How often within one training epoch to check the validation set. -Can specify as float or int. - -- use (float) to check within a training epoch -- use (int) to check every n steps (batches) - -.. code-block:: python - - # default used by the Trainer - trainer = Trainer(val_check_interval=1.0) - -Example:: - - # check validation set 4 times during a training epoch - trainer = Trainer(val_check_interval=0.25) - - # check validation set every 1000 training batches - # use this when using iterableDataset and your dataset has no length - # (ie: production cases with streaming data) - trainer = Trainer(val_check_interval=1000) +.. warning:: deprecated in v0.8.0 please use `limit_test_batches`. Will remove in 1.0.0 track_grad_norm ^^^^^^^^^^^^^^^ @@ -955,20 +956,36 @@ override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`: # do your own splitting on the batch return splits +val_check_interval +^^^^^^^^^^^^^^^^^^ + +How often within one training epoch to check the validation set. +Can specify as float or int. + +- use (float) to check within a training epoch +- use (int) to check every n steps (batches) + +.. code-block:: python + + # default used by the Trainer + trainer = Trainer(val_check_interval=1.0) + +Example:: + + # check validation set 4 times during a training epoch + trainer = Trainer(val_check_interval=0.25) + + # check validation set every 1000 training batches + # use this when using iterableDataset and your dataset has no length + # (ie: production cases with streaming data) + trainer = Trainer(val_check_interval=1000) + val_percent_check ^^^^^^^^^^^^^^^^^ -How much of validation dataset to check. -Useful when debugging or testing something that happens at the end of an epoch. +.. warning:: deprecated in v0.8.0 please use `limit_val_batches`. Will remove in 1.0.0 -Example:: - - # default used by the Trainer - trainer = Trainer(val_percent_check=1.0) - - # run through only 25% of the validation set each epoch - trainer = Trainer(val_percent_check=0.25) weights_save_path ^^^^^^^^^^^^^^^^^ diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index d66ed38a56..28388cd2ae 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -70,12 +70,12 @@ class TrainerDataLoadingMixin(ABC): num_training_batches: Union[int, float] val_check_batch: ... 
val_dataloaders: List[DataLoader] - num_val_batches: Union[int, float] + num_val_batches: List[Union[int, float]] test_dataloaders: List[DataLoader] - num_test_batches: Union[int, float] + num_test_batches: List[Union[int, float]] train_percent_check: float - val_percent_check: float - test_percent_check: float + limit_val_batches: float + limit_test_batches: float replace_sampler_ddp: bool num_nodes: int num_processes: int @@ -85,11 +85,16 @@ class TrainerDataLoadingMixin(ABC): def is_overridden(self, *args): """Warning: this is just empty shell for code implemented in other class.""" - def _percent_range_check(self, name: str) -> None: + def _limit_eval_batches_check(self, name: str) -> None: value = getattr(self, name) + + # ints are fine + if isinstance(value, int): + return + msg = f'`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}.' if name == 'val_check_interval': - msg += ' If you want to disable validation set `val_percent_check` to 0.0 instead.' + msg += ' If you want to disable validation set `limit_val_batches` to 0.0 instead.' if not 0. <= value <= 1.: raise ValueError(msg) @@ -139,15 +144,21 @@ class TrainerDataLoadingMixin(ABC): ' distributed training. Either remove the sampler from your DataLoader or set' ' `replace_sampler_ddp`=False if you want to use your custom sampler.') - skip_keys = ['sampler', 'batch_sampler', 'dataset_kind'] + # replace with distributed sampler + sampler = self._get_distributed_sampler(dataloader) + dataloader = self.replace_sampler(dataloader, sampler) - dl_args = { - k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys - } + return dataloader - dl_args['sampler'] = self._get_distributed_sampler(dataloader) - dataloader = type(dataloader)(**dl_args) + def replace_sampler(self, dataloader, sampler): + skip_keys = ['sampler', 'batch_sampler', 'dataset_kind'] + dl_args = { + k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys + } + + dl_args['sampler'] = sampler + dataloader = type(dataloader)(**dl_args) return dataloader def _get_distributed_sampler(self, dataloader): @@ -182,14 +193,17 @@ class TrainerDataLoadingMixin(ABC): self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True) self._worker_check(self.train_dataloader, 'train dataloader') - self._percent_range_check('train_percent_check') + self._limit_eval_batches_check('train_percent_check') if not _has_len(self.train_dataloader): self.num_training_batches = float('inf') else: # try getting the length - self.num_training_batches = len(self.train_dataloader) - self.num_training_batches = int(self.num_training_batches * self.train_percent_check) + if isinstance(self.train_percent_check, float): + self.num_training_batches = len(self.train_dataloader) + self.num_training_batches = int(self.num_training_batches * self.train_percent_check) + else: + self.num_training_batches = self.train_percent_check # determine when to check validation # if int passed in, val checks that often @@ -200,7 +214,7 @@ class TrainerDataLoadingMixin(ABC): raise ValueError( f'`val_check_interval` ({self.val_check_interval}) must be less than or equal ' f'to the number of the training batches ({self.num_training_batches}). 
' - 'If you want to disable validation set `val_percent_check` to 0.0 instead.') + 'If you want to disable validation set `limit_val_batches` to 0.0 instead.') else: if not _has_len(self.train_dataloader): if self.val_check_interval == 1.0: @@ -212,12 +226,14 @@ class TrainerDataLoadingMixin(ABC): ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies' ' checking validation every k training batches.') else: - self._percent_range_check('val_check_interval') + self._limit_eval_batches_check('val_check_interval') self.val_check_batch = int(self.num_training_batches * self.val_check_interval) self.val_check_batch = max(1, self.val_check_batch) - def _reset_eval_dataloader(self, model: LightningModule, mode: str) -> Tuple[Union[int, float], List[DataLoader]]: + def _reset_eval_dataloader(self, + model: LightningModule, + mode: str) -> Tuple[List[Union[int, float]], List[DataLoader]]: """Generic method to reset a dataloader for evaluation. Args: @@ -227,17 +243,32 @@ class TrainerDataLoadingMixin(ABC): Returns: Tuple (num_batches, dataloaders) """ - dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader')) + # use the training loader as val and test when overfitting + if self.overfit_batches > 0: + dataloaders = self.request_dataloader(getattr(model, 'train_dataloader')) + else: + dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader')) if not isinstance(dataloaders, list): dataloaders = [dataloaders] - # shuffling in val and test set is bad practice - for loader in dataloaders: + for loader_i in range(len(dataloaders)): + loader = dataloaders[loader_i] + + # shuffling in val and test set is bad practice if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler): - rank_zero_warn( - f'Your {mode}_dataloader has shuffle=True, it is best practice to turn' - ' this off for validation and test dataloaders.') + + # when overfitting, the dataloader should not have sampler + if self.overfit_batches > 0: + m = 'You requested to overfit but enabled training Dataloader shuffling. 
' \ + 'we are turning it off for you' + rank_zero_warn(m) + dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset)) + + else: + rank_zero_warn( + f'Your {mode}_dataloader has shuffle=True, it is best practice to turn' + ' this off for validation and test dataloaders.') if any([dl is None for dl in dataloaders]): rank_zero_warn("One of given dataloaders is None and it will be skipped.") @@ -245,29 +276,47 @@ class TrainerDataLoadingMixin(ABC): # add samplers dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl is not None] - num_batches = 0 + loader_num_batches = [] # determine number of batches # datasets could be none, 1 or 2+ if len(dataloaders) != 0: for i, dataloader in enumerate(dataloaders): + num_batches = 0 self._worker_check(dataloader, f'{mode} dataloader {i}') if not _has_len(dataloader): num_batches = float('inf') - percent_check = getattr(self, f'{mode}_percent_check') + # percent or num_steps + limit_eval_batches = getattr(self, f'limit_{mode}_batches') - if num_batches != float('inf'): - self._percent_range_check(f'{mode}_percent_check') + if num_batches != float('inf'): + self._limit_eval_batches_check(f'limit_{mode}_batches') - num_batches = sum(len(dataloader) for dataloader in dataloaders) - num_batches = int(num_batches * percent_check) - elif percent_check not in (0.0, 1.0): - raise MisconfigurationException( - 'When using an infinite DataLoader (e.g. with an IterableDataset' - f' or when DataLoader does not implement `__len__`) for `{mode}_dataloader`,' - f' `Trainer({mode}_percent_check)` must be `0.0` or `1.0`.') - return num_batches, dataloaders + num_batches = len(dataloader) + + # limit num batches either as a percent or num steps + if isinstance(limit_eval_batches, float): + num_batches = int(num_batches * limit_eval_batches) + else: + num_batches = limit_eval_batches + + elif limit_eval_batches not in (0.0, 1.0): + raise MisconfigurationException( + 'When using an infinite DataLoader (e.g. with an IterableDataset' + f' or when DataLoader does not implement `__len__`) for `limit_{mode}_batches`,' + f' `Trainer(limit_{mode}_batches)` must be `0.0` or `1.0`.') + + if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float): + min_pct = 1.0 / len(dataloader) + m = f'you requested to check {limit_eval_batches} of the {mode} dataloader but ' \ + f'{limit_eval_batches}*{num_batches} = 0. Please increase the limit_{mode}_batches.' \ + f'Try at least limit_{mode}_batches={min_pct}' + raise MisconfigurationException(m) + + loader_num_batches.append(num_batches) + + return loader_num_batches, dataloaders def reset_val_dataloader(self, model: LightningModule) -> None: """Resets the validation dataloader and determines the number of batches. 
@@ -316,18 +365,19 @@ class TrainerDataLoadingMixin(ABC): return dataloader - def determine_data_use_amount(self, train_percent_check: float, val_percent_check: float, - test_percent_check: float, overfit_pct: float) -> None: + def determine_data_use_amount(self, train_percent_check: float, limit_val_batches: Union[int, float], + limit_test_batches: Union[int, float], overfit_batches: float) -> None: """Use less data for debugging purposes """ self.train_percent_check = train_percent_check - self.val_percent_check = val_percent_check - self.test_percent_check = test_percent_check - if overfit_pct > 0: - if overfit_pct > 1: + self.limit_val_batches = limit_val_batches + self.limit_test_batches = limit_test_batches + if overfit_batches > 0: + if isinstance(overfit_batches, float) and overfit_batches > 1: raise ValueError( - f'`overfit_pct` must be not greater than 1.0, but got {overfit_pct:.3f}.') + f'`overfit_batches` when used as a percentage must ' + f'be not 0.0 < x < 1.0 but got {overfit_batches:.3f}.') - self.train_percent_check = overfit_pct - self.val_percent_check = overfit_pct - self.test_percent_check = overfit_pct + self.train_percent_check = overfit_batches + self.limit_val_batches = overfit_batches + self.limit_test_batches = overfit_batches diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index b6c490d6eb..230538ed89 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -24,30 +24,30 @@ Set how much of the validation set to check If you don't want to check 100% of the validation set (for debugging or if it's huge), set this flag -val_percent_check will be overwritten by overfit_pct if `overfit_pct > 0` +limit_val_batches will be overwritten by overfit_batches if `overfit_batches > 0` .. code-block:: python # DEFAULT - trainer = Trainer(val_percent_check=1.0) + trainer = Trainer(limit_val_batches=1.0) # check 10% only - trainer = Trainer(val_percent_check=0.1) + trainer = Trainer(limit_val_batches=0.1) Set how much of the test set to check ------------------------------------- If you don't want to check 100% of the test set (for debugging or if it's huge), set this flag -test_percent_check will be overwritten by overfit_pct if `overfit_pct > 0` +limit_test_batches will be overwritten by overfit_batches if `overfit_batches > 0` .. code-block:: python # DEFAULT - trainer = Trainer(test_percent_check=1.0) + trainer = Trainer(limit_test_batches=1.0) # check 10% only - trainer = Trainer(test_percent_check=0.1) + trainer = Trainer(limit_test_batches=0.1) Set validation check frequency within 1 training epoch ------------------------------------------------------ @@ -124,7 +124,7 @@ In this second case, the options you pass to trainer will be used when running from abc import ABC, abstractmethod from pprint import pprint -from typing import Callable, Optional +from typing import Callable, Optional, List import torch from torch.utils.data import DataLoader @@ -162,7 +162,7 @@ class TrainerEvaluationLoopMixin(ABC): single_gpu: bool data_parallel_device_ids: ... model: LightningModule - num_test_batches: int + num_test_batches: List[int] num_val_batches: int fast_dev_run: ... process_output: ... 
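Since ``num_val_batches`` and ``num_test_batches`` are now lists, each evaluation dataloader carries its own batch cap, and the evaluation loop changed below indexes that cap per dataloader. A rough sketch of the per-dataloader capping idea (hypothetical function and names, not the actual Lightning loop)::

    def run_evaluation(dataloaders, max_batches):
        # max_batches holds one cap per dataloader; entries may be ints or float('inf')
        outputs = []
        for dl_idx, dataloader in enumerate(dataloaders):
            dl_max_batches = max_batches[dl_idx]
            for batch_idx, batch in enumerate(dataloader):
                if batch_idx >= dl_max_batches:
                    break
                outputs.append((dl_idx, batch))
        return outputs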
@@ -222,13 +222,13 @@ class TrainerEvaluationLoopMixin(ABC): def reset_val_dataloader(self, *args): """Warning: this is just empty shell for code implemented in other class.""" - def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_mode: bool = False): + def _evaluate(self, model: LightningModule, dataloaders, max_batches: List[int], test_mode: bool = False): """Run evaluation code. Args: model: PT model dataloaders: list of PT dataloaders - max_batches: Scalar + max_batches: List of scalars test_mode: """ # enable eval mode @@ -254,12 +254,15 @@ class TrainerEvaluationLoopMixin(ABC): dataloader = xla_pl.ParallelLoader(dataloader, [device]) dataloader = dataloader.per_device_loader(device) + # each dataloader has a max num batches + dl_max_batches = max_batches[dataloader_idx] + for batch_idx, batch in enumerate(dataloader): if batch is None: continue # stop short when on fast_dev_run (sets max_batch=1) - if batch_idx >= max_batches: + if batch_idx >= dl_max_batches: break # callbacks @@ -359,7 +362,7 @@ class TrainerEvaluationLoopMixin(ABC): # cap max batches to 1 when using fast_dev_run if self.fast_dev_run: - max_batches = 1 + max_batches = [1] # Validation/Test begin callbacks if test_mode: @@ -367,6 +370,11 @@ class TrainerEvaluationLoopMixin(ABC): else: self.on_validation_start() + # enable disabling validation step with limit_val_batches = 0 + should_skip = sum(max_batches) == 0 + if should_skip: + return + # run evaluation eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode) _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 07465a0994..2fef8c4440 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -93,7 +93,7 @@ class Trainer( tpu_cores: Optional[Union[List[int], int]] = None, log_gpu_memory: Optional[str] = None, progress_bar_refresh_rate: int = 1, - overfit_pct: float = 0.0, + overfit_batches: Union[int, float] = 0.0, track_grad_norm: Union[int, float, str] = -1, check_val_every_n_epoch: int = 1, fast_dev_run: bool = False, @@ -103,8 +103,8 @@ class Trainer( max_steps: Optional[int] = None, min_steps: Optional[int] = None, train_percent_check: float = 1.0, - val_percent_check: float = 1.0, - test_percent_check: float = 1.0, + limit_val_batches: Union[int, float] = 1.0, + limit_test_batches: Union[int, float] = 1.0, val_check_interval: float = 1.0, log_save_interval: int = 100, row_log_interval: int = 50, @@ -129,6 +129,9 @@ class Trainer( num_tpu_cores: Optional[int] = None, # backward compatible, todo: remove in v0.9.0 use_amp=None, # backward compatible, todo: remove in v0.9.0 show_progress_bar=None, # backward compatible, todo: remove in v0.9.0 + val_percent_check: float = 1.0, # backward compatible, todo: remove in v1.0.0 + test_percent_check: float = 1.0, # backward compatible, todo: remove in v1.0.0 + overfit_pct: float = 0.0 # backward compatible, todo: remove in v1.0.0 ): r""" @@ -185,7 +188,12 @@ class Trainer( progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar. Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`. - overfit_pct: How much of training-, validation-, and test dataset to check. + overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). + + overfit_pct: + .. warning:: .. 
deprecated:: 0.8.0 + + Use `overfit_batches` instead. Will be removed in 1.0.0. track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm. @@ -215,9 +223,19 @@ class Trainer( train_percent_check: How much of training dataset to check. - val_percent_check: How much of validation dataset to check. + limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches) - test_percent_check: How much of test dataset to check. + limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches) + + val_percent_check: + .. warning:: .. deprecated:: 0.8.0 + + Use `limit_val_batches` instead. Will be removed in 1.0.0. + + test_percent_check: + .. warning:: .. deprecated:: 0.8.0 + + Use `limit_test_batches` instead. Will be removed in 1.0.0. val_check_interval: How often within one training epoch to check the validation set @@ -383,9 +401,9 @@ class Trainer( self.batch_idx = 0 self.progress_bar_metrics = {} self.callback_metrics = {} - self.num_val_batches = 0 + self.num_val_batches = [0] self.num_training_batches = 0 - self.num_test_batches = 0 + self.num_test_batches = [0] self.train_dataloader = None self.test_dataloaders = None self.val_dataloaders = None @@ -468,9 +486,27 @@ class Trainer( self.row_log_interval = row_log_interval # how much of the data to use - self.overfit_pct = overfit_pct - self.determine_data_use_amount(train_percent_check, val_percent_check, - test_percent_check, overfit_pct) + # TODO: remove in 1.0.0 + if overfit_pct > 0: + overfit_batches = overfit_pct + + # convert floats to ints + overfit_batches = int(overfit_batches) if overfit_batches > 1.0 else overfit_batches + self.overfit_batches = overfit_batches + + # TODO: remove in 1.0.0 + if val_percent_check < 1.0: + limit_val_batches = val_percent_check + + if test_percent_check < 1.0: + limit_test_batches = test_percent_check + + limit_test_batches = int(limit_test_batches) if limit_test_batches > 1.0 else limit_test_batches + limit_val_batches = int(limit_val_batches) if limit_val_batches > 1.0 else limit_val_batches + + # TODO: convert train_percent_check to limit_train_batches + self.determine_data_use_amount(train_percent_check, limit_val_batches, + limit_test_batches, overfit_batches) # AMP init # These are the only lines needed after v0.8.0 @@ -984,7 +1020,7 @@ class Trainer( return # check if we should run validation during training - self.disable_validation = not (self.is_overridden('validation_step') and self.val_percent_check > 0) \ + self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \ and not self.fast_dev_run # run tiny validation (if validation defined) @@ -996,9 +1032,11 @@ class Trainer( ref_model.on_sanity_check_start() self.on_sanity_check_start() + num_loaders = len(self.val_dataloaders) + max_batches = [self.num_sanity_val_steps] * num_loaders eval_results = self._evaluate(model, self.val_dataloaders, - self.num_sanity_val_steps, + max_batches, False) _, _, _, callback_metrics, _ = self.process_output(eval_results) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ed98192eaa..6cc2ac55c5 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -68,7 +68,7 @@ Set how much of the training set to check If you don't want to check 100% of the training set (for debugging or if it's huge), set this flag.
-train_percent_check will be overwritten by overfit_pct if `overfit_pct > 0` +train_percent_check will be overwritten by overfit_batches if `overfit_batches > 0` .. code-block:: python @@ -202,7 +202,7 @@ class TrainerTrainLoopMixin(ABC): check_val_every_n_epoch: ... num_training_batches: int val_check_batch: ... - num_val_batches: int + num_val_batches: List[int] disable_validation: bool fast_dev_run: ... accumulation_scheduler: ... diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index ce32de54a2..1a8ba1fb0c 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -24,7 +24,7 @@ def test_early_stopping_functionality(tmpdir): trainer = Trainer( default_root_dir=tmpdir, early_stop_callback=True, - overfit_pct=0.20, + overfit_batches=0.20, max_epochs=20, ) result = trainer.fit(model) @@ -152,7 +152,7 @@ def test_trainer_callback_system(tmpdir): trainer_options = dict( callbacks=[test_callback], max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.2, progress_bar_refresh_rate=0, ) @@ -258,7 +258,7 @@ def test_early_stopping_no_val_step(tmpdir): trainer = Trainer( default_root_dir=tmpdir, early_stop_callback=stopping, - overfit_pct=0.20, + overfit_batches=0.20, max_epochs=2, ) result = trainer.fit(model) @@ -292,7 +292,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=checkpoint, - overfit_pct=0.20, + overfit_batches=0.20, max_epochs=2 ) trainer.fit(model) @@ -313,7 +313,7 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected): trainer = Trainer( default_root_dir=tmpdir, - overfit_pct=0.2, + overfit_batches=0.2, max_epochs=2, logger=logger ) diff --git a/tests/callbacks/test_lr.py b/tests/callbacks/test_lr.py index 80e7b3ca5c..bd5c77d117 100644 --- a/tests/callbacks/test_lr.py +++ b/tests/callbacks/test_lr.py @@ -17,7 +17,7 @@ def test_lr_logger_single_lr(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.5, callbacks=[lr_logger] ) @@ -40,7 +40,7 @@ def test_lr_logger_no_lr(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.5, callbacks=[lr_logger] ) @@ -61,7 +61,7 @@ def test_lr_logger_multi_lrs(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.5, callbacks=[lr_logger] ) @@ -88,7 +88,7 @@ def test_lr_logger_param_groups(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.5, callbacks=[lr_logger] ) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index e0dffce1f3..847027d5c4 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -20,7 +20,7 @@ def test_progress_bar_on(callbacks, refresh_rate): callbacks=callbacks, progress_bar_refresh_rate=refresh_rate, max_epochs=1, - overfit_pct=0.2, + overfit_batches=0.2, ) progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)] @@ -61,7 +61,7 @@ def test_progress_bar_totals(): trainer = Trainer( progress_bar_refresh_rate=1, - val_percent_check=1.0, + limit_val_batches=1.0, max_epochs=1, ) bar = trainer.progress_bar_callback diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 
f8a8fead41..ec2bf25e2b 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -53,7 +53,7 @@ def test_loggers_fit_test(tmpdir, monkeypatch, logger_class): max_epochs=1, logger=logger, train_percent_check=0.2, - val_percent_check=0.5, + limit_val_batches=0.5, fast_dev_run=True, ) trainer.fit(model) diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index 5046991d7c..29d7ee3e45 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -145,7 +145,7 @@ def test_adding_step_key(tmpdir): max_epochs=3, default_root_dir=tmpdir, train_percent_check=0.001, - val_percent_check=0.01, + limit_val_batches=0.1, num_sanity_val_steps=0, ) trainer.logger.log_metrics = _log_metrics_decorator( diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index bf33df9190..0c4af3898f 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -99,7 +99,7 @@ def test_cpu_model_with_amp(tmpdir): progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.4, - val_percent_check=0.4, + limit_val_batches=0.4, precision=16 ) diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index c31e157a71..1d19e609fc 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -26,7 +26,7 @@ def test_cpu_slurm_save_load(tmpdir): max_epochs=1, logger=logger, train_percent_check=0.2, - val_percent_check=0.2, + limit_val_batches=0.2, checkpoint_callback=ModelCheckpoint(tmpdir) ) result = trainer.fit(model) @@ -90,10 +90,10 @@ def test_early_stopping_cpu_model(tmpdir): early_stop_callback=stopping, max_epochs=2, gradient_clip_val=1.0, - overfit_pct=0.20, + overfit_batches=0.20, track_grad_norm=2, train_percent_check=0.1, - val_percent_check=0.1, + limit_val_batches=0.1, ) model = EvalModelTemplate() @@ -119,7 +119,7 @@ def test_multi_cpu_model_ddp(tmpdir): progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.4, - val_percent_check=0.2, + limit_val_batches=0.2, gpus=None, num_processes=2, distributed_backend='ddp_cpu' @@ -137,7 +137,7 @@ def test_lbfgs_cpu_model(tmpdir): progress_bar_refresh_rate=0, weights_summary='top', train_percent_check=0.2, - val_percent_check=0.2, + limit_val_batches=0.2, ) hparams = EvalModelTemplate.get_default_hparams() @@ -154,10 +154,10 @@ def test_default_logger_callbacks_cpu_model(tmpdir): default_root_dir=tmpdir, max_epochs=1, gradient_clip_val=1.0, - overfit_pct=0.20, + overfit_batches=0.20, progress_bar_refresh_rate=0, train_percent_check=0.01, - val_percent_check=0.01, + limit_val_batches=0.01, ) model = EvalModelTemplate() @@ -184,8 +184,8 @@ def test_running_test_after_fitting(tmpdir): progress_bar_refresh_rate=0, max_epochs=2, train_percent_check=0.4, - val_percent_check=0.2, - test_percent_check=0.2, + limit_val_batches=0.2, + limit_test_batches=0.2, checkpoint_callback=checkpoint, logger=logger ) @@ -214,8 +214,8 @@ def test_running_test_no_val(tmpdir): progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.4, - val_percent_check=0.2, - test_percent_check=0.2, + limit_val_batches=0.2, + limit_test_batches=0.2, checkpoint_callback=checkpoint, logger=logger, early_stop_callback=False @@ -238,7 +238,7 @@ def test_simple_cpu(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.1, ) result = trainer.fit(model) @@ -254,7 +254,7 @@ def test_cpu_model(tmpdir): progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.4, - val_percent_check=0.4 + limit_val_batches=0.4 ) model = 
EvalModelTemplate() @@ -267,13 +267,13 @@ def test_all_features_cpu_model(tmpdir): trainer_options = dict( default_root_dir=tmpdir, gradient_clip_val=1.0, - overfit_pct=0.20, + overfit_batches=0.20, track_grad_norm=2, progress_bar_refresh_rate=0, accumulate_grad_batches=2, max_epochs=1, train_percent_check=0.4, - val_percent_check=0.4 + limit_val_batches=0.4 ) model = EvalModelTemplate() @@ -342,7 +342,7 @@ def test_tbptt_cpu_model(tmpdir): default_root_dir=tmpdir, max_epochs=1, truncated_bptt_steps=truncated_bptt_steps, - val_percent_check=0, + limit_val_batches=0, weights_summary=None, early_stop_callback=False ) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 3fd44265ad..a752d29a78 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -22,7 +22,7 @@ def test_single_gpu_model(tmpdir, gpus): progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.1, - val_percent_check=0.1, + limit_val_batches=0.1, gpus=gpus ) @@ -41,7 +41,7 @@ def test_multi_gpu_model(tmpdir, backend): default_root_dir=tmpdir, max_epochs=1, train_percent_check=0.4, - val_percent_check=0.2, + limit_val_batches=0.2, gpus=[0, 1], distributed_backend=backend, ) @@ -66,7 +66,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir): progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.1, - val_percent_check=0.1, + limit_val_batches=0.1, gpus=[0, 1], distributed_backend='ddp') @@ -88,7 +88,7 @@ def test_multi_gpu_none_backend(tmpdir): progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.1, - val_percent_check=0.1, + limit_val_batches=0.1, gpus='-1' ) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index d4cfefab29..307f9381ca 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -57,7 +57,7 @@ def test_training_epoch_end_metrics_collection(tmpdir): trainer = Trainer( max_epochs=num_epochs, default_root_dir=tmpdir, - overfit_pct=0.1, + overfit_batches=0.1, ) result = trainer.fit(model) assert result == 1 diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 4e5fe0ef81..b2e5062267 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -62,7 +62,7 @@ def test_horovod_cpu(tmpdir): progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.4, - val_percent_check=0.2, + limit_val_batches=0.2, distributed_backend='horovod', deterministic=True, ) @@ -79,7 +79,7 @@ def test_horovod_cpu_implicit(tmpdir): progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.4, - val_percent_check=0.2, + limit_val_batches=0.2, deterministic=True, ) _run_horovod(trainer_options) @@ -97,7 +97,7 @@ def test_horovod_multi_gpu(tmpdir): progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.4, - val_percent_check=0.2, + limit_val_batches=0.2, gpus=1, deterministic=True, distributed_backend='horovod' @@ -132,7 +132,7 @@ def test_horovod_transfer_batch_to_gpu(tmpdir): progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.4, - val_percent_check=0.2, + limit_val_batches=0.2, gpus=1, deterministic=True, distributed_backend='horovod' @@ -150,7 +150,7 @@ def test_horovod_multi_optimizer(tmpdir): progress_bar_refresh_rate=0, max_epochs=1, train_percent_check=0.4, - val_percent_check=0.2, + limit_val_batches=0.2, deterministic=True, distributed_backend='horovod' ) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 38ad43ccf7..b802f894f4 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -39,7 +39,7 @@ def 
_run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False): assert model.hparams.test_arg == 14 # verify we can train - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_pct=0.5) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=0.5) trainer.fit(model) # make sure the raw checkpoint saved the properties @@ -156,7 +156,7 @@ def test_explicit_missing_args_hparams(tmpdir): assert model.hparams.test_arg == 14 # verify we can train - trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5) trainer.fit(model) # make sure the raw checkpoint saved the properties @@ -266,7 +266,7 @@ def test_collect_init_arguments(tmpdir, cls): assert isinstance(model.hparams.my_loss, torch.nn.CosineEmbeddingLoss) # verify that the checkpoint saved the correct values - trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5) trainer.fit(model) raw_checkpoint_path = _raw_checkpoint_path(trainer) @@ -349,7 +349,7 @@ def test_collect_init_arguments_with_local_vars(cls): # assert model.hparams.my_arg == 42 # # # verify that the checkpoint saved the correct values -# trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5) +# trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5) # trainer.fit(model) # # # verify that model loads correctly diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index c2219548da..3878d8867c 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -32,7 +32,7 @@ def test_running_test_pretrained_model_distrib(tmpdir, backend): progress_bar_refresh_rate=0, max_epochs=2, train_percent_check=0.4, - val_percent_check=0.2, + limit_val_batches=0.2, checkpoint_callback=checkpoint, logger=logger, gpus=[0, 1], @@ -80,7 +80,7 @@ def test_running_test_pretrained_model_cpu(tmpdir): progress_bar_refresh_rate=0, max_epochs=3, train_percent_check=0.4, - val_percent_check=0.2, + limit_val_batches=0.2, checkpoint_callback=checkpoint, logger=logger ) @@ -111,7 +111,7 @@ def test_load_model_from_checkpoint(tmpdir): progress_bar_refresh_rate=0, max_epochs=2, train_percent_check=0.4, - val_percent_check=0.2, + limit_val_batches=0.2, checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1), default_root_dir=tmpdir, ) @@ -188,7 +188,7 @@ def test_dp_resume(tmpdir): trainer_options['logger'] = new_logger trainer_options['checkpoint_callback'] = ModelCheckpoint(tmpdir) trainer_options['train_percent_check'] = 0.5 - trainer_options['val_percent_check'] = 0.2 + trainer_options['limit_val_batches'] = 0.2 trainer_options['max_epochs'] = 1 new_trainer = Trainer(**trainer_options) diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 6e108cb47c..4fc3e4747a 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -103,7 +103,7 @@ def test_tbd_remove_in_v1_0_0_model_hooks(): with pytest.deprecated_call(match='v1.0'): trainer = Trainer(logger=False) # TODO: why `dataloder` is required if it is not used - result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1) + result = trainer._evaluate(model, dataloaders=[[None]], max_batches=[1]) assert result == {'val_loss': torch.tensor(0.6)} model = ModelVer0_7(hparams) @@ -116,5 +116,5 @@ def test_tbd_remove_in_v1_0_0_model_hooks(): with pytest.deprecated_call(match='v1.0'): trainer = Trainer(logger=False) # TODO: why 
`dataloder` is required if it is not used - result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1) + result = trainer._evaluate(model, dataloaders=[[None]], max_batches=[1]) assert result == {'val_loss': torch.tensor(0.7)} diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index b700197b9c..e92bb54840 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -80,7 +80,7 @@ def test_multiple_val_dataloader(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=1.0, ) result = trainer.fit(model) @@ -116,7 +116,7 @@ def test_multiple_test_dataloader(tmpdir, ckpt_path): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.2 ) trainer.fit(model) @@ -144,7 +144,7 @@ def test_train_dataloader_passed_to_fit(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.2 ) fit_options = dict(train_dataloader=model.dataloader(train=True)) @@ -161,7 +161,7 @@ def test_train_val_dataloaders_passed_to_fit(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.2 ) fit_options = dict(train_dataloader=model.dataloader(train=True), @@ -183,7 +183,7 @@ def test_all_dataloaders_passed_to_fit(tmpdir, ckpt_path): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.2 ) fit_options = dict(train_dataloader=model.dataloader(train=True), @@ -216,7 +216,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.2 ) fit_options = dict(train_dataloader=model.dataloader(train=True), @@ -245,7 +245,7 @@ def test_mixing_of_dataloader_options(tmpdir, ckpt_path): trainer_options = dict( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.2 ) @@ -286,7 +286,7 @@ def test_val_inf_dataloader_error(tmpdir): model = EvalModelTemplate() model.val_dataloader = model.val_dataloader__infinite - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_percent_check=0.5) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5) with pytest.raises(MisconfigurationException, match='infinite DataLoader'): trainer.fit(model) @@ -298,7 +298,7 @@ def test_test_inf_dataloader_error(tmpdir): model = EvalModelTemplate() model.test_dataloader = model.test_dataloader__infinite - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, test_percent_check=0.5) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5) with pytest.raises(MisconfigurationException, match='infinite DataLoader'): trainer.test(model) @@ -354,8 +354,8 @@ def test_error_on_zero_len_dataloader(tmpdir): default_root_dir=tmpdir, max_epochs=1, train_percent_check=0.1, - val_percent_check=0.1, - test_percent_check=0.1 + limit_val_batches=0.1, + limit_test_batches=0.1 ) trainer.fit(model) @@ -371,7 +371,7 @@ def test_warning_with_few_workers(tmpdir, ckpt_path): trainer_options = dict( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.2 ) @@ -489,7 +489,7 @@ def test_batch_size_smaller_than_num_gpus(): 
trainer = Trainer( max_epochs=1, train_percent_check=0.1, - val_percent_check=0, + limit_val_batches=0, gpus=num_gpus, ) diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py index be1f429c85..d81f38ba8b 100644 --- a/tests/trainer/test_optimizers.py +++ b/tests/trainer/test_optimizers.py @@ -16,7 +16,7 @@ def test_optimizer_with_scheduling(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.2 ) results = trainer.fit(model) @@ -47,7 +47,7 @@ def test_multi_optimizer_with_scheduling(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.2 ) results = trainer.fit(model) @@ -82,7 +82,7 @@ def test_multi_optimizer_with_scheduling_stepping(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.2 ) results = trainer.fit(model) @@ -121,7 +121,7 @@ def test_reduce_lr_on_plateau_scheduling(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.2 ) results = trainer.fit(model) @@ -211,7 +211,7 @@ def test_none_optimizer(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.2 ) result = trainer.fit(model) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 424e4c7b94..04cb7164ee 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -163,7 +163,7 @@ def test_gradient_accumulation_scheduling(tmpdir): trainer = Trainer(accumulate_grad_batches=schedule, train_percent_check=0.1, - val_percent_check=0.1, + limit_val_batches=0.1, max_epochs=2, default_root_dir=tmpdir) @@ -368,7 +368,7 @@ def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_serve progress_bar_refresh_rate=0, max_epochs=2, train_percent_check=0.65, - val_percent_check=1, + limit_val_batches=1, checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1), default_root_dir=tmpdir, early_stop_callback=False, @@ -588,7 +588,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k): def test_disabled_validation(): - """Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`.""" + """Verify that `limit_val_batches=0` disables the validation loop unless `fast_dev_run=True`.""" class CurrentModel(EvalModelTemplate): @@ -610,22 +610,22 @@ def test_disabled_validation(): progress_bar_refresh_rate=0, max_epochs=2, train_percent_check=0.4, - val_percent_check=0.0, + limit_val_batches=0.0, fast_dev_run=False, ) trainer = Trainer(**trainer_options) result = trainer.fit(model) - # check that val_percent_check=0 turns off validation + # check that limit_val_batches=0 turns off validation assert result == 1, 'training failed to complete' assert trainer.current_epoch == 2 assert not model.validation_step_invoked, \ - '`validation_step` should not run when `val_percent_check=0`' + '`validation_step` should not run when `limit_val_batches=0`' assert not model.validation_epoch_end_invoked, \ - '`validation_epoch_end` should not run when `val_percent_check=0`' + '`validation_epoch_end` should not run when `limit_val_batches=0`' - # check that val_percent_check has no influence when fast_dev_run is turned on + # check that limit_val_batches has no influence when fast_dev_run is turned on model = 
CurrentModel(**hparams) trainer_options.update(fast_dev_run=True) trainer = Trainer(**trainer_options) @@ -722,7 +722,7 @@ def test_trainer_interrupted_flag(tmpdir): trainer = Trainer( callbacks=[interrupt_callback, handle_interrupt_callback], max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.2, progress_bar_refresh_rate=0, logger=False, diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 973ed32e7c..9762d39f18 100755 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -5,6 +5,87 @@ import tests.base.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate +from torch.utils.data import RandomSampler, SequentialSampler, DataLoader + + +def test_overfit(tmpdir): + + # ------------------------------------------------------ + # Make sure shuffle is correct across loaders initially + # ------------------------------------------------------ + model = EvalModelTemplate() + model.train_dataloader() + + # original train loader which should be replaced in all methods + train_loader = model.train_dataloader() + + # make sure the val and tests are not shuffled + assert isinstance(train_loader.sampler, RandomSampler) + assert isinstance(model.val_dataloader().sampler, SequentialSampler) + assert isinstance(model.test_dataloader().sampler, SequentialSampler) + + # ------------------------------------------------------ + # get the training loader and batch + # ------------------------------------------------------ + train_loader = DataLoader(model.train_dataloader().dataset, shuffle=False) + full_train_samples = len(train_loader) + num_train_samples = int(0.11 * full_train_samples) + + (xa, ya) = next(iter(train_loader)) + + # ------------------------------------------------------ + # set VAL and Test loaders + # ------------------------------------------------------ + val_loader = DataLoader(model.val_dataloader().dataset, shuffle=False) + test_loader = DataLoader(model.test_dataloader().dataset, shuffle=False) + + # set the model loaders + model.train_dataloader = lambda: train_loader + model.val_dataloader = lambda: val_loader + model.test_dataloader = lambda: test_loader + + # ------------------------------------------------------ + # run tests for both val and test + # ------------------------------------------------------ + for split in ['val', 'test']: + + # ------------------------------------------------------ + # test overfit_batches as percent + # ------------------------------------------------------ + loader_num_batches, dataloaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(model, split) + assert loader_num_batches[0] == num_train_samples + + # make sure we turned off shuffle for the user + assert isinstance(dataloaders[0].sampler, SequentialSampler) + + # make sure the loaders are the same + (xb, yb) = next(iter(dataloaders[0])) + assert torch.eq(xa, xb).all() + assert torch.eq(ya, yb).all() + + # ------------------------------------------------------ + # test overfit_batches as int + # ------------------------------------------------------ + loader_num_batches, dataloaders = Trainer(overfit_batches=1)._reset_eval_dataloader(model, split) + assert loader_num_batches[0] == 1 + loader_num_batches, dataloaders = Trainer(overfit_batches=5)._reset_eval_dataloader(model, split) + assert loader_num_batches[0] == 5 + + # 
------------------------------------------------------ + # test limit_xxx_batches as percent AND int + # ------------------------------------------------------ + if split == 'val': + loader_num_batches, dataloaders = Trainer(limit_val_batches=0.1)._reset_eval_dataloader(model, split) + assert loader_num_batches[0] == int(0.1 * len(val_loader)) + + loader_num_batches, dataloaders = Trainer(limit_val_batches=10)._reset_eval_dataloader(model, split) + assert loader_num_batches[0] == 10 + else: + loader_num_batches, dataloaders = Trainer(limit_test_batches=0.1)._reset_eval_dataloader(model, split) + assert loader_num_batches[0] == int(0.1 * len(test_loader)) + + loader_num_batches, dataloaders = Trainer(limit_test_batches=10)._reset_eval_dataloader(model, split) + assert loader_num_batches[0] == 10 def test_model_reset_correctly(tmpdir): @@ -120,7 +201,7 @@ def test_error_on_dataloader_passed_to_fit(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_percent_check=0.1, + limit_val_batches=0.1, train_percent_check=0.2, auto_scale_batch_size='power' )
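For reference, a short usage sketch of the renamed flags introduced by this change (the specific values are illustrative)::

    from pytorch_lightning import Trainer

    # overfit on a fixed number of training batches; val and test reuse the
    # (unshuffled) training loader
    trainer = Trainer(overfit_batches=10)

    # limit evaluation instead: a float is a fraction of each dataloader,
    # an int is an absolute number of batches
    trainer = Trainer(limit_val_batches=0.25, limit_test_batches=50)

    # the deprecated overfit_pct, val_percent_check and test_percent_check
    # arguments keep working until 1.0.0 and map onto the flags above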