disable val and test shuffling (#1600)

* disable val and test shuffling

* disable val and test shuffling

* disable val and test shuffling

* disable val and test shuffling

* log

* condition

* shuffle

* refactor

Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
This commit is contained in:
William Falcon 2020-04-25 16:45:20 -04:00 committed by GitHub
parent 791ba91dec
commit b620d86c54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 23 additions and 9 deletions

View File

@ -46,6 +46,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Updated semantic segmentation example with custom u-net and logging ([#1371](https://github.com/PyTorchLightning/pytorch-lightning/pull/1371))
- Disabled val and test shuffling ([#1600](https://github.com/PyTorchLightning/pytorch-lightning/pull/1600))
### Deprecated

View File

@ -1349,7 +1349,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True
shuffle=False
)
return loader
@ -1394,7 +1394,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True
shuffle=False
)
return loader

View File

@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
from typing import Union, List, Tuple, Callable
import torch.distributed as torch_distrib
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from pytorch_lightning.core import LightningModule
@ -195,8 +195,7 @@ class TrainerDataLoadingMixin(ABC):
self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)
def _reset_eval_dataloader(self, model: LightningModule,
mode: str) -> Tuple[int, List[DataLoader]]:
def _reset_eval_dataloader(self, model: LightningModule, mode: str) -> Tuple[int, List[DataLoader]]:
"""Generic method to reset a dataloader for evaluation.
Args:
@ -211,6 +210,13 @@ class TrainerDataLoadingMixin(ABC):
if not isinstance(dataloaders, list):
dataloaders = [dataloaders]
# shuffling in val and test set is bad practice
for loader in dataloaders:
if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):
raise MisconfigurationException(
f'Your {mode}_dataloader has shuffle=True, it is best practice to turn'
' this off for validation and test dataloaders.')
# add samplers
dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl]

View File

@ -10,13 +10,14 @@ from tests.base.eval_model_test_epoch_ends import TestEpochEndVariations
from tests.base.eval_model_test_steps import TestStepVariations
from tests.base.eval_model_train_dataloaders import TrainDataloaderVariations
from tests.base.eval_model_train_steps import TrainingStepVariations
from tests.base.eval_model_utils import ModelTemplateUtils
from tests.base.eval_model_utils import ModelTemplateUtils, ModelTemplateData
from tests.base.eval_model_valid_dataloaders import ValDataloaderVariations
from tests.base.eval_model_valid_epoch_ends import ValidationEpochEndVariations
from tests.base.eval_model_valid_steps import ValidationStepVariations
class EvalModelTemplate(
ModelTemplateData,
ModelTemplateUtils,
TrainingStepVariations,
ValidationStepVariations,

View File

@ -3,7 +3,8 @@ from torch.utils.data import DataLoader
from tests.base.datasets import TrialMNIST
class ModelTemplateUtils:
class ModelTemplateData:
hparams: ...
def dataloader(self, train):
dataset = TrialMNIST(root=self.hparams.data_root, train=train, download=True)
@ -11,10 +12,14 @@ class ModelTemplateUtils:
loader = DataLoader(
dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True
# test and valid shall not be shuffled
shuffle=train,
)
return loader
class ModelTemplateUtils:
def get_output_metric(self, output, name):
if isinstance(output, dict):
val = output[name]

View File

@ -149,7 +149,7 @@ class TestModelBase(LightningModule):
loader = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=True
shuffle=train
)
return loader