diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8c61cb82b7..5f5cebf89a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -46,6 +46,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Updated semantic segmentation example with custom u-net and logging ([#1371](https://github.com/PyTorchLightning/pytorch-lightning/pull/1371))
 
+- Disabled val and test shuffling ([#1600](https://github.com/PyTorchLightning/pytorch-lightning/pull/1600))
+
 ### Deprecated
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index d943a07461..a1f3eb4e92 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1349,7 +1349,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
         loader = torch.utils.data.DataLoader(
             dataset=dataset,
             batch_size=self.hparams.batch_size,
-            shuffle=True
+            shuffle=False
         )
 
         return loader
@@ -1394,7 +1394,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
         loader = torch.utils.data.DataLoader(
             dataset=dataset,
             batch_size=self.hparams.batch_size,
-            shuffle=True
+            shuffle=False
         )
 
         return loader
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 71ef09349a..7d0b818052 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
 from typing import Union, List, Tuple, Callable
 
 import torch.distributed as torch_distrib
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 
 from pytorch_lightning.core import LightningModule
@@ -195,8 +195,7 @@ class TrainerDataLoadingMixin(ABC):
         self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
         self.val_check_batch = max(1, self.val_check_batch)
 
-    def _reset_eval_dataloader(self, model: LightningModule,
-                               mode: str) -> Tuple[int, List[DataLoader]]:
+    def _reset_eval_dataloader(self, model: LightningModule, mode: str) -> Tuple[int, List[DataLoader]]:
         """Generic method to reset a dataloader for evaluation.
 
         Args:
@@ -211,6 +210,13 @@ class TrainerDataLoadingMixin(ABC):
         if not isinstance(dataloaders, list):
             dataloaders = [dataloaders]
 
+        # shuffling in val and test set is bad practice
+        for loader in dataloaders:
+            if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):
+                raise MisconfigurationException(
+                    f'Your {mode}_dataloader has shuffle=True, it is best practice to turn'
+                    ' this off for validation and test dataloaders.')
+
         # add samplers
         dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl]
 
diff --git a/tests/base/eval_model_template.py b/tests/base/eval_model_template.py
index cce59edd2a..77d83b483b 100644
--- a/tests/base/eval_model_template.py
+++ b/tests/base/eval_model_template.py
@@ -10,13 +10,14 @@ from tests.base.eval_model_test_epoch_ends import TestEpochEndVariations
 from tests.base.eval_model_test_steps import TestStepVariations
 from tests.base.eval_model_train_dataloaders import TrainDataloaderVariations
 from tests.base.eval_model_train_steps import TrainingStepVariations
-from tests.base.eval_model_utils import ModelTemplateUtils
+from tests.base.eval_model_utils import ModelTemplateUtils, ModelTemplateData
 from tests.base.eval_model_valid_dataloaders import ValDataloaderVariations
 from tests.base.eval_model_valid_epoch_ends import ValidationEpochEndVariations
 from tests.base.eval_model_valid_steps import ValidationStepVariations
 
 
 class EvalModelTemplate(
+    ModelTemplateData,
     ModelTemplateUtils,
     TrainingStepVariations,
     ValidationStepVariations,
diff --git a/tests/base/eval_model_utils.py b/tests/base/eval_model_utils.py
index 68618d17d1..e1a40f95b8 100644
--- a/tests/base/eval_model_utils.py
+++ b/tests/base/eval_model_utils.py
@@ -3,7 +3,8 @@ from torch.utils.data import DataLoader
 from tests.base.datasets import TrialMNIST
 
 
-class ModelTemplateUtils:
+class ModelTemplateData:
+    hparams: ...
 
     def dataloader(self, train):
         dataset = TrialMNIST(root=self.hparams.data_root, train=train, download=True)
@@ -11,10 +12,14 @@ class ModelTemplateUtils:
         loader = DataLoader(
             dataset=dataset,
             batch_size=self.hparams.batch_size,
-            shuffle=True
+            # test and valid shall not be shuffled
+            shuffle=train,
         )
 
         return loader
+
+class ModelTemplateUtils:
+
     def get_output_metric(self, output, name):
         if isinstance(output, dict):
             val = output[name]
diff --git a/tests/base/models.py b/tests/base/models.py
index 66cf7ca2be..ebc6d75576 100644
--- a/tests/base/models.py
+++ b/tests/base/models.py
@@ -149,7 +149,7 @@ class TestModelBase(LightningModule):
         loader = DataLoader(
             dataset=dataset,
             batch_size=batch_size,
-            shuffle=True
+            shuffle=train
         )
 
         return loader
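
Note (not part of the patch): a minimal sketch of why the new check in _reset_eval_dataloader can key off the sampler type. Passing shuffle=True to a PyTorch DataLoader installs a RandomSampler, so inspecting loader.sampler is enough to detect a shuffled val/test loader. The TensorDataset below is only a stand-in dataset for illustration.

    # Illustration only: DataLoader(shuffle=True) attaches a RandomSampler,
    # which is exactly what the guard added above detects for val/test loaders.
    import torch
    from torch.utils.data import DataLoader, RandomSampler, TensorDataset

    dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))

    shuffled = DataLoader(dataset, batch_size=8, shuffle=True)
    ordered = DataLoader(dataset, batch_size=8, shuffle=False)

    print(isinstance(shuffled.sampler, RandomSampler))  # True  -> would raise MisconfigurationException
    print(isinstance(ordered.sampler, RandomSampler))   # False -> accepted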