diable val and test shuffling (#1600)
* diable val and test shuffling * diable val and test shuffling * diable val and test shuffling * diable val and test shuffling * log * condition * shuffle * refactor Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
This commit is contained in:
parent
791ba91dec
commit
b620d86c54
|
@ -46,6 +46,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Updated semantic segmentation example with custom u-net and logging ([#1371](https://github.com/PyTorchLightning/pytorch-lightning/pull/1371))
|
||||
|
||||
- Diabled val and test shuffling ([#1600](https://github.com/PyTorchLightning/pytorch-lightning/pull/1600))
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
||||
|
|
|
@ -1349,7 +1349,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
|||
loader = torch.utils.data.DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=self.hparams.batch_size,
|
||||
shuffle=True
|
||||
shuffle=False
|
||||
)
|
||||
|
||||
return loader
|
||||
|
@ -1394,7 +1394,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
|||
loader = torch.utils.data.DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=self.hparams.batch_size,
|
||||
shuffle=True
|
||||
shuffle=False
|
||||
)
|
||||
|
||||
return loader
|
||||
|
|
|
@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
|
|||
from typing import Union, List, Tuple, Callable
|
||||
|
||||
import torch.distributed as torch_distrib
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from pytorch_lightning.core import LightningModule
|
||||
|
@ -195,8 +195,7 @@ class TrainerDataLoadingMixin(ABC):
|
|||
self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
|
||||
self.val_check_batch = max(1, self.val_check_batch)
|
||||
|
||||
def _reset_eval_dataloader(self, model: LightningModule,
|
||||
mode: str) -> Tuple[int, List[DataLoader]]:
|
||||
def _reset_eval_dataloader(self, model: LightningModule, mode: str) -> Tuple[int, List[DataLoader]]:
|
||||
"""Generic method to reset a dataloader for evaluation.
|
||||
|
||||
Args:
|
||||
|
@ -211,6 +210,13 @@ class TrainerDataLoadingMixin(ABC):
|
|||
if not isinstance(dataloaders, list):
|
||||
dataloaders = [dataloaders]
|
||||
|
||||
# shuffling in val and test set is bad practice
|
||||
for loader in dataloaders:
|
||||
if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):
|
||||
raise MisconfigurationException(
|
||||
f'Your {mode}_dataloader has shuffle=True, it is best practice to turn'
|
||||
' this off for validation and test dataloaders.')
|
||||
|
||||
# add samplers
|
||||
dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl]
|
||||
|
||||
|
|
|
@ -10,13 +10,14 @@ from tests.base.eval_model_test_epoch_ends import TestEpochEndVariations
|
|||
from tests.base.eval_model_test_steps import TestStepVariations
|
||||
from tests.base.eval_model_train_dataloaders import TrainDataloaderVariations
|
||||
from tests.base.eval_model_train_steps import TrainingStepVariations
|
||||
from tests.base.eval_model_utils import ModelTemplateUtils
|
||||
from tests.base.eval_model_utils import ModelTemplateUtils, ModelTemplateData
|
||||
from tests.base.eval_model_valid_dataloaders import ValDataloaderVariations
|
||||
from tests.base.eval_model_valid_epoch_ends import ValidationEpochEndVariations
|
||||
from tests.base.eval_model_valid_steps import ValidationStepVariations
|
||||
|
||||
|
||||
class EvalModelTemplate(
|
||||
ModelTemplateData,
|
||||
ModelTemplateUtils,
|
||||
TrainingStepVariations,
|
||||
ValidationStepVariations,
|
||||
|
|
|
@ -3,7 +3,8 @@ from torch.utils.data import DataLoader
|
|||
from tests.base.datasets import TrialMNIST
|
||||
|
||||
|
||||
class ModelTemplateUtils:
|
||||
class ModelTemplateData:
|
||||
hparams: ...
|
||||
|
||||
def dataloader(self, train):
|
||||
dataset = TrialMNIST(root=self.hparams.data_root, train=train, download=True)
|
||||
|
@ -11,10 +12,14 @@ class ModelTemplateUtils:
|
|||
loader = DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=self.hparams.batch_size,
|
||||
shuffle=True
|
||||
# test and valid shall not be shuffled
|
||||
shuffle=train,
|
||||
)
|
||||
return loader
|
||||
|
||||
|
||||
class ModelTemplateUtils:
|
||||
|
||||
def get_output_metric(self, output, name):
|
||||
if isinstance(output, dict):
|
||||
val = output[name]
|
||||
|
|
|
@ -149,7 +149,7 @@ class TestModelBase(LightningModule):
|
|||
loader = DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True
|
||||
shuffle=train
|
||||
)
|
||||
|
||||
return loader
|
||||
|
|
Loading…
Reference in New Issue