Warn user when IterableDataset has __len__ defined (#2437)
* add warning when checking len
* added test
* changelog
* pep
* do not show warning below 1.4
* try version parse
* comments
* xfail
* Update requirements/base.txt
* Update pytorch_lightning/trainer/data_loading.py
* version

Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
parent 325852c6df
commit 927f305f7e
CHANGELOG.md
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added reduce ddp results on eval ([#2434](https://github.com/PyTorchLightning/pytorch-lightning/pull/2434))
 
+- Added a warning when an `IterableDataset` has `__len__` defined ([#2437](https://github.com/PyTorchLightning/pytorch-lightning/pull/2437))
+
 ### Changed
 
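Editor's note: for context, the failure mode this warning guards against is sample duplication. With multi-process loading, each DataLoader worker re-iterates a naive IterableDataset from the start, while a defined __len__ still reports the single-copy size. A minimal sketch (the RangeStream class is illustrative only, not part of this commit):

import torch
from torch.utils.data import DataLoader, IterableDataset

class RangeStream(IterableDataset):
    """Illustrative iterable dataset with no per-worker sharding."""

    def __iter__(self):
        # Every worker process runs this from scratch, yielding all samples.
        return iter(range(4))

    def __len__(self):
        # Legal, but reports the single-copy size, not what the loader yields.
        return 4

if __name__ == '__main__':
    loader = DataLoader(RangeStream(), num_workers=2, batch_size=2)
    samples = [int(x) for batch in loader for x in batch]
    print(sorted(samples))  # [0, 0, 1, 1, 2, 2, 3, 3] -- every sample twice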
pytorch_lightning/trainer/data_loading.py
@@ -1,8 +1,10 @@
+import multiprocessing
 import platform
 from abc import ABC, abstractmethod
+from distutils.version import LooseVersion
 from typing import Union, List, Tuple, Callable, Optional
-import multiprocessing
 
 import torch
 import torch.distributed as torch_distrib
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+from torch.utils.data.distributed import DistributedSampler
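Editor's note: the new LooseVersion import supports the version gate used in the hunk below. A quick illustration of the comparison it enables:

from distutils.version import LooseVersion
import torch

# True on any PyTorch release from 1.4.0 onward, including local builds
# such as '1.5.0+cu101': comparison stops at the first differing field.
print(LooseVersion(torch.__version__) >= LooseVersion('1.4.0'))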
@@ -41,19 +43,33 @@ else:
     HOROVOD_AVAILABLE = True
 
 
+def _has_iterable_dataset(dataloader: DataLoader):
+    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
+        and isinstance(dataloader.dataset, IterableDataset)
+
+
 def _has_len(dataloader: DataLoader) -> bool:
     """ Checks if a given Dataloader has __len__ method implemented i.e. if
-    it is a finite dataloader or infinite dataloader """
+    it is a finite dataloader or infinite dataloader. """
+
     try:
         # try getting the length
         if len(dataloader) == 0:
             raise ValueError('`Dataloader` returned 0 length.'
                              ' Please make sure that your Dataloader at least returns 1 batch')
-        return True
+        has_len = True
     except TypeError:
-        return False
+        has_len = False
     except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
-        return False
+        has_len = False
+
+    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
+        rank_zero_warn(
+            'Your `IterableDataset` has `__len__` defined.'
+            ' In combination with multi-processing data loading (e.g. batch size > 1),'
+            ' this can lead to unintended side effects since the samples will be duplicated.'
+        )
+    return has_len
 
 
 class TrainerDataLoadingMixin(ABC):
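Editor's note: to make the helper's contract concrete, a hedged usage sketch — the datasets here are stand-ins, not code from the PR:

import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset

from pytorch_lightning.trainer.data_loading import _has_len

# Map-style dataset: len() succeeds, so _has_len returns True
# (and would raise ValueError if the loader yielded zero batches).
finite = DataLoader(TensorDataset(torch.zeros(8, 1)), batch_size=4)
assert _has_len(finite)

class Stream(IterableDataset):
    """Endless stream with no __len__ defined."""

    def __iter__(self):
        while True:
            yield torch.zeros(1)

# len() raises TypeError here, which _has_len converts to False.
assert not _has_len(DataLoader(Stream()))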
@@ -128,12 +144,9 @@ class TrainerDataLoadingMixin(ABC):
     def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
 
         # don't do anything if it's not a dataloader
+        # don't manipulate iterable datasets
         is_dataloader = isinstance(dataloader, DataLoader)
-
-        is_iterable_ds = False
-        if ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset'):
-            is_iterable_ds = isinstance(dataloader.dataset, IterableDataset)
+        is_iterable_ds = _has_iterable_dataset(dataloader)
 
         if not is_dataloader or is_iterable_ds:
             return dataloader
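Editor's note: the refactor folds the inline isinstance checks into _has_iterable_dataset. The guard it implements reduces to something like this sketch (should_add_sampler is a hypothetical name, used only for illustration):

from torch.utils.data import DataLoader, IterableDataset

def should_add_sampler(loader) -> bool:
    # Hypothetical condensation of the guard in auto_add_sampler:
    # only genuine map-style DataLoaders get a sampler injected.
    is_dataloader = isinstance(loader, DataLoader)
    is_iterable_ds = isinstance(getattr(loader, 'dataset', None), IterableDataset)
    return is_dataloader and not is_iterable_ds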
@@ -285,11 +298,7 @@ class TrainerDataLoadingMixin(ABC):
         # datasets could be none, 1 or 2+
         if len(dataloaders) != 0:
             for i, dataloader in enumerate(dataloaders):
-                try:
-                    num_batches = len(dataloader)
-                except (TypeError, NotImplementedError):
-                    num_batches = float('inf')
-
+                num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
                 self._worker_check(dataloader, f'{mode} dataloader {i}')
 
                 # percent or num_steps
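Editor's note: the conditional expression replaces try/except control flow, but downstream code must still treat the float('inf') sentinel specially, e.g. when applying a fractional limit. A hedged sketch, not code from the PR (limited_batches is hypothetical):

import math

def limited_batches(num_batches: float, limit: float) -> float:
    # A fractional limit only applies to finite counts; an infinite
    # dataloader keeps the float('inf') sentinel.
    if math.isinf(num_batches):
        return num_batches
    return int(num_batches * limit)

assert limited_batches(100, 0.25) == 25
assert math.isinf(limited_batches(float('inf'), 0.25))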
requirements/base.txt
@@ -6,4 +6,4 @@ tensorboard>=1.14
 future>=0.17.1 # required for builtins in setup.py
 # pyyaml>=3.13
 PyYAML>=5.1 # OmegaConf requirement
-tqdm>=4.41.0
+tqdm>=4.41.0
tests/trainer/test_dataloaders.py
@@ -2,11 +2,13 @@ import platform
 
 import pytest
 import torch
+from packaging.version import parse
 from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataset import Subset
+from torch.utils.data.dataset import Subset, IterableDataset
 
 import tests.base.develop_pipelines as tpipes
 from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.data_loading import _has_len, _has_iterable_dataset
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 
@@ -487,6 +489,36 @@ def test_warning_with_few_workers(tmpdir, ckpt_path):
     trainer.test(**test_options)
 
 
+@pytest.mark.xfail(
+    parse(torch.__version__) < parse("1.4.0"),
+    reason="IterableDataset with __len__ before 1.4 raises",
+)
+def test_warning_with_iterable_dataset_and_len(tmpdir):
+    """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
+    model = EvalModelTemplate()
+    original_dataset = model.train_dataloader().dataset
+
+    class IterableWithLen(IterableDataset):
+
+        def __iter__(self):
+            return iter(original_dataset)
+
+        def __len__(self):
+            return len(original_dataset)
+
+    dataloader = DataLoader(IterableWithLen(), batch_size=16)
+    assert _has_len(dataloader)
+    assert _has_iterable_dataset(dataloader)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=3,
+    )
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.test(model, test_dataloaders=[dataloader])
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
 def test_dataloader_reinit_for_subclass():
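Editor's note: one detail worth noting in the new test is that pytest.warns(match=...) applies a regex search to the warning text, so the backtick-laden message matches literally. A standalone illustration:

import warnings

import pytest

# match is a regular expression searched within the warning message;
# backticks have no special meaning in regex, so they match themselves.
with pytest.warns(UserWarning, match='has `__len__` defined'):
    warnings.warn('Your `IterableDataset` has `__len__` defined.', UserWarning)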