Warn user when IterableDataset has __len__ defined (#2437)

* add warning when checking len

* added test

* changelog

* pep

* do not show warning below 1.4

* try version parse

* comments

* xfail

* Update requirements/base.txt

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/trainer/data_loading.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* version

Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Adrian Wälchli 2020-07-01 13:53:19 +02:00 committed by GitHub
parent 325852c6df
commit 927f305f7e
4 changed files with 60 additions and 17 deletions

CHANGELOG.md

@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added reduce ddp results on eval ([#2434](https://github.com/PyTorchLightning/pytorch-lightning/pull/2434))
+- Added a warning when an `IterableDataset` has `__len__` defined ([#2437](https://github.com/PyTorchLightning/pytorch-lightning/pull/2437))
 
 ### Changed
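
To make the changelog entry concrete, here is a minimal, self-contained sketch (my own illustration, not code from this commit) of the pitfall the new warning targets: an IterableDataset that also defines __len__. With num_workers > 0, every DataLoader worker replays the whole iterator unless __iter__ shards by worker id, so an epoch silently yields more samples than len() promises. The class name and sizes below are made up.

from torch.utils.data import DataLoader, IterableDataset


class RangeStream(IterableDataset):
    """Illustration only: an iterable dataset that also claims a length."""

    def __iter__(self):
        # naive: every DataLoader worker replays the full range
        return iter(range(8))

    def __len__(self):
        # defining __len__ makes the stream look finite to consumers
        return 8


if __name__ == "__main__":
    loader = DataLoader(RangeStream(), batch_size=4, num_workers=2)
    seen = [int(x) for batch in loader for x in batch]
    print(len(seen))  # 16 samples per "epoch", not the 8 that len() promised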

pytorch_lightning/trainer/data_loading.py

@@ -1,8 +1,10 @@
+import multiprocessing
 import platform
 from abc import ABC, abstractmethod
+from distutils.version import LooseVersion
 from typing import Union, List, Tuple, Callable, Optional
 
-import multiprocessing
 import torch
 import torch.distributed as torch_distrib
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -41,19 +43,33 @@ else:
     HOROVOD_AVAILABLE = True
 
 
+def _has_iterable_dataset(dataloader: DataLoader):
+    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
+        and isinstance(dataloader.dataset, IterableDataset)
+
+
 def _has_len(dataloader: DataLoader) -> bool:
     """ Checks if a given Dataloader has __len__ method implemented i.e. if
-    it is a finite dataloader or infinite dataloader """
+    it is a finite dataloader or infinite dataloader. """
+
     try:
         # try getting the length
         if len(dataloader) == 0:
             raise ValueError('`Dataloader` returned 0 length.'
                              ' Please make sure that your Dataloader at least returns 1 batch')
-        return True
+        has_len = True
     except TypeError:
-        return False
+        has_len = False
     except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
-        return False
+        has_len = False
+
+    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
+        rank_zero_warn(
+            'Your `IterableDataset` has `__len__` defined.'
+            ' In combination with multi-processing data loading (e.g. batch size > 1),'
+            ' this can lead to unintended side effects since the samples will be duplicated.'
+        )
+
+    return has_len
 
 
 class TrainerDataLoadingMixin(ABC):
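
For readers skimming the hunk above, the following standalone sketch mirrors the new _has_iterable_dataset / _has_len pair without importing pytorch_lightning. The names and the use of warnings.warn (standing in for rank_zero_warn) are my own simplifications, not the library code itself.

import warnings
from distutils.version import LooseVersion

import torch
from torch.utils.data import DataLoader, IterableDataset


def has_iterable_dataset(dataloader: DataLoader) -> bool:
    # True when the loader wraps an iterable-style (possibly infinite) dataset
    return hasattr(dataloader, 'dataset') and isinstance(dataloader.dataset, IterableDataset)


def has_len(dataloader: DataLoader) -> bool:
    # mirrors the hunk: probe len(), treat TypeError/NotImplementedError as "infinite"
    try:
        if len(dataloader) == 0:
            raise ValueError('`Dataloader` returned 0 length.')
        length_known = True
    except (TypeError, NotImplementedError):
        length_known = False

    # the new check: warn when an iterable dataset still reports a length (PyTorch >= 1.4)
    if length_known and has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
        warnings.warn('Your `IterableDataset` has `__len__` defined.')
    return length_known


class Stream(IterableDataset):
    def __iter__(self):
        return iter(range(4))


print(has_len(DataLoader(range(4))))  # True  -> finite, map-style data
print(has_len(DataLoader(Stream())))  # False -> no __len__, treated as infinite
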
@@ -128,12 +144,9 @@ class TrainerDataLoadingMixin(ABC):
     def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
         # don't do anything if it's not a dataloader
-        # don't manipulate iterable datasets
         is_dataloader = isinstance(dataloader, DataLoader)
-        is_iterable_ds = False
-        if ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset'):
-            is_iterable_ds = isinstance(dataloader.dataset, IterableDataset)
+        # don't manipulate iterable datasets
+        is_iterable_ds = _has_iterable_dataset(dataloader)
         if not is_dataloader or is_iterable_ds:
             return dataloader
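
The guard above now funnels through _has_iterable_dataset. The reason iterable datasets are returned untouched: a sampler addresses a map-style dataset through __getitem__, which an IterableDataset cannot honour, so the loader is handed back as-is and any sharding has to happen inside the dataset's own __iter__. The helper below is my own simplification of that idea, not Lightning's implementation (which swaps in a DistributedSampler when running distributed).

from torch.utils.data import DataLoader, IterableDataset, SequentialSampler


def maybe_add_sampler(dataloader: DataLoader) -> DataLoader:
    # not a DataLoader, or index-based sampling is impossible: hand it back untouched
    if not isinstance(dataloader, DataLoader) or isinstance(dataloader.dataset, IterableDataset):
        return dataloader
    # rebuild the loader with an explicit sampler (a real trainer would pick a
    # DistributedSampler here; SequentialSampler keeps the sketch simple)
    return DataLoader(dataloader.dataset, batch_size=dataloader.batch_size,
                      sampler=SequentialSampler(dataloader.dataset))


loader = maybe_add_sampler(DataLoader(range(10), batch_size=5))
print(type(loader.sampler).__name__)  # SequentialSampler
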
@@ -285,11 +298,7 @@ class TrainerDataLoadingMixin(ABC):
         # datasets could be none, 1 or 2+
         if len(dataloaders) != 0:
             for i, dataloader in enumerate(dataloaders):
-                try:
-                    num_batches = len(dataloader)
-                except (TypeError, NotImplementedError):
-                    num_batches = float('inf')
+                num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
                 self._worker_check(dataloader, f'{mode} dataloader {i}')
 
                 # percent or num_steps
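
The one-liner above swaps the try/except for a float('inf') sentinel when a loader has no length. A small illustration (my own example, not code from this commit) of why the sentinel is convenient for the "percent or num_steps" logic that follows: finite and infinite loaders share one code path, and only fractional limits require a finite length.

def limited_batches(num_batches, limit):
    # limit is either an absolute number of batches (int) or a fraction of the loader (float)
    if isinstance(limit, int) or limit == 0.0:
        return min(num_batches, limit)   # well defined even when num_batches is float('inf')
    if num_batches != float('inf'):
        return int(num_batches * limit)  # a fraction of a finite loader
    raise ValueError('A fraction of an infinite dataloader is undefined; pass an int limit.')


print(limited_batches(100, 0.25))         # 25
print(limited_batches(float('inf'), 50))  # 50
try:
    limited_batches(float('inf'), 0.25)
except ValueError as err:
    print(err)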

requirements/base.txt

@@ -6,4 +6,4 @@ tensorboard>=1.14
 future>=0.17.1 # required for builtins in setup.py
 # pyyaml>=3.13
 PyYAML>=5.1 # OmegaConf requirement
-tqdm>=4.41.0
+tqdm>=4.41.0

tests/trainer/test_dataloaders.py

@@ -2,11 +2,13 @@ import platform
 import pytest
 import torch
+from packaging.version import parse
 from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataset import Subset
+from torch.utils.data.dataset import Subset, IterableDataset
 
 import tests.base.develop_pipelines as tpipes
 from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.data_loading import _has_len, _has_iterable_dataset
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
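
A side note (my own example, not part of the diff) on the new packaging.version.parse import above: it is the robust way to gate a test on the installed torch version, because it understands pre-release and local version tags that plain string comparison, and sometimes LooseVersion, get wrong.

import torch
from packaging.version import parse

# '1.4.0a0' is a pre-release and must compare as older than '1.4.0';
# naive string comparison would put it after the release instead.
assert parse("1.4.0a0") < parse("1.4.0")
OLD_TORCH = parse(torch.__version__) < parse("1.4.0")
print(f"torch {torch.__version__}: pre-1.4 behaviour expected = {OLD_TORCH}")
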
@@ -487,6 +489,36 @@ def test_warning_with_few_workers(tmpdir, ckpt_path):
     trainer.test(**test_options)
 
 
+@pytest.mark.xfail(
+    parse(torch.__version__) < parse("1.4.0"),
+    reason="IterableDataset with __len__ before 1.4 raises",
+)
+def test_warning_with_iterable_dataset_and_len(tmpdir):
+    """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
+    model = EvalModelTemplate()
+    original_dataset = model.train_dataloader().dataset
+
+    class IterableWithLen(IterableDataset):
+
+        def __iter__(self):
+            return iter(original_dataset)
+
+        def __len__(self):
+            return len(original_dataset)
+
+    dataloader = DataLoader(IterableWithLen(), batch_size=16)
+    assert _has_len(dataloader)
+    assert _has_iterable_dataset(dataloader)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=3,
+    )
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.test(model, test_dataloaders=[dataloader])
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
 def test_dataloader_reinit_for_subclass():
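
For completeness, a hedged sketch (not part of the commit) of the pattern that avoids the new warning altogether: keep the dataset iterable-only, without __len__, and shard work per DataLoader worker inside __iter__. The class name and sizes are made up for illustration.

from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShardedStream(IterableDataset):
    def __init__(self, limit: int = 8):
        self.limit = limit

    def __iter__(self):
        info = get_worker_info()
        if info is None:  # single-process data loading
            return iter(range(self.limit))
        # each worker takes a strided slice, so nothing is duplicated across workers
        return iter(range(info.id, self.limit, info.num_workers))


if __name__ == "__main__":
    loader = DataLoader(ShardedStream(), batch_size=4, num_workers=2)
    print(sorted(int(x) for batch in loader for x in batch))  # [0, 1, 2, 3, 4, 5, 6, 7]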