Follow up of #2892 (#3202)

* Follow up of #2892

* typo

* iterabledataset
Rohit Gupta 2020-08-28 00:58:29 +05:30 committed by GitHub
parent 40eaa2143e
commit 85cd558a3f
3 changed files with 56 additions and 44 deletions

pytorch_lightning/trainer/data_loading.py

@@ -15,24 +15,18 @@
import multiprocessing
import platform
from abc import ABC, abstractmethod
-from distutils.version import LooseVersion
from typing import Union, List, Tuple, Callable, Optional
-import torch
import torch.distributed as torch_distrib
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.debugging import InternalDebugger
-try:
-    from torch.utils.data import IterableDataset
-    ITERABLE_DATASET_EXISTS = True
-except ImportError:
-    ITERABLE_DATASET_EXISTS = False
try:
from apex import amp
@@ -56,35 +50,6 @@ else:
HOROVOD_AVAILABLE = True
-def _has_iterable_dataset(dataloader: DataLoader):
-    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
-        and isinstance(dataloader.dataset, IterableDataset)
-def _has_len(dataloader: DataLoader) -> bool:
-    """ Checks if a given Dataloader has __len__ method implemented i.e. if
-    it is a finite dataloader or infinite dataloader. """
-    try:
-        # try getting the length
-        if len(dataloader) == 0:
-            raise ValueError('`Dataloader` returned 0 length.'
-                             ' Please make sure that your Dataloader at least returns 1 batch')
-        has_len = True
-    except TypeError:
-        has_len = False
-    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
-        has_len = False
-    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
-        rank_zero_warn(
-            'Your `IterableDataset` has `__len__` defined.'
-            ' In combination with multi-processing data loading (e.g. batch size > 1),'
-            ' this can lead to unintended side effects since the samples will be duplicated.'
-        )
-    return has_len
class TrainerDataLoadingMixin(ABC):
# this is just a summary on variables used in this abstract class,
@@ -147,7 +112,7 @@ class TrainerDataLoadingMixin(ABC):
# don't do anything if it's not a dataloader
is_dataloader = isinstance(dataloader, DataLoader)
# don't manipulate iterable datasets
-is_iterable_ds = _has_iterable_dataset(dataloader)
+is_iterable_ds = has_iterable_dataset(dataloader)
if not is_dataloader or is_iterable_ds:
return dataloader
@@ -214,7 +179,7 @@ class TrainerDataLoadingMixin(ABC):
# automatically add samplers
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
-self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf')
+self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
self._worker_check(self.train_dataloader, 'train dataloader')
if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
@@ -238,7 +203,7 @@ class TrainerDataLoadingMixin(ABC):
f'to the number of the training batches ({self.num_training_batches}). '
'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
else:
-if not has_len(self.train_dataloader):
+if not has_len(self.train_dataloader):
if self.val_check_interval == 1.0:
self.val_check_batch = float('inf')
else:
@@ -305,7 +270,7 @@ class TrainerDataLoadingMixin(ABC):
# datasets could be none, 1 or 2+
if len(dataloaders) != 0:
for i, dataloader in enumerate(dataloaders):
-num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
+num_batches = len(dataloader) if has_len(dataloader) else float('inf')
self._worker_check(dataloader, f'{mode} dataloader {i}')
# percent or num_steps
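
In practice, the `has_len` fallback above means a train loader without `__len__` (for example one wrapping an `IterableDataset`) gives `num_training_batches = float('inf')`, so validation has to be scheduled by batch count rather than as a fraction of an epoch. A minimal sketch of such a setup, assuming an `IterableDataset`-backed train loader; the dataset and the values below are illustrative and not part of this commit:

# Sketch: scheduling validation when has_len() is False and the epoch length is unknown.
import torch
from torch.utils.data import DataLoader, IterableDataset
from pytorch_lightning import Trainer


class Stream(IterableDataset):
    """Illustrative infinite stream with no __len__."""
    def __iter__(self):
        while True:
            yield torch.rand(32), torch.randint(0, 2, (1,))


train_loader = DataLoader(Stream(), batch_size=16)

trainer = Trainer(
    max_steps=200,           # an unbounded loader needs an explicit stopping point
    val_check_interval=50,   # given as an int: validate every 50 training batches
    limit_val_batches=10,    # cap validation in case the val loader is also unbounded
)
# trainer.fit(model, train_dataloader=train_loader) would then validate every 50 training batches.
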

pytorch_lightning/utilities/data.py

@@ -0,0 +1,47 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distutils.version import LooseVersion
import torch
from torch.utils.data import DataLoader, IterableDataset
from pytorch_lightning.utilities import rank_zero_warn
def has_iterable_dataset(dataloader: DataLoader):
    return hasattr(dataloader, 'dataset') and isinstance(dataloader.dataset, IterableDataset)


def has_len(dataloader: DataLoader) -> bool:
    """ Checks if a given Dataloader has __len__ method implemented i.e. if
    it is a finite dataloader or infinite dataloader. """

    try:
        # try getting the length
        if len(dataloader) == 0:
            raise ValueError('`Dataloader` returned 0 length.'
                             ' Please make sure that your Dataloader at least returns 1 batch')
        has_len = True
    except TypeError:
        has_len = False
    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
        has_len = False

    if has_len and has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
        rank_zero_warn(
            'Your `IterableDataset` has `__len__` defined.'
            ' In combination with multi-processing data loading (e.g. batch size > 1),'
            ' this can lead to unintended side effects since the samples will be duplicated.'
        )
    return has_len
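
A short usage sketch of the two relocated helpers, assuming only the module shown above; the loaders below are illustrative and not part of the commit:

import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset

from pytorch_lightning.utilities.data import has_iterable_dataset, has_len


class Stream(IterableDataset):
    """Illustrative infinite stream with no __len__."""
    def __iter__(self):
        while True:
            yield torch.rand(2)


finite_loader = DataLoader(TensorDataset(torch.rand(8, 2)), batch_size=4)
stream_loader = DataLoader(Stream(), batch_size=4)

assert has_len(finite_loader) and not has_iterable_dataset(finite_loader)   # map-style dataset: 2 batches
assert not has_len(stream_loader) and has_iterable_dataset(stream_loader)   # len() raises TypeError -> False

A loader whose `IterableDataset` does define `__len__` passes both checks and, on torch >= 1.4, triggers the `rank_zero_warn` message above.
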

tests/trainer/test_dataloaders.py

@@ -11,7 +11,7 @@ from torch.utils.data.distributed import DistributedSampler
import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer, Callback
-from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len
+from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
@@ -624,7 +624,7 @@ def test_warning_with_few_workers(mock, tmpdir, ckpt_path):
reason="IterableDataset with __len__ before 1.4 raises",
)
def test_warning_with_iterable_dataset_and_len(tmpdir):
""" Tests that a warning messages is shown when an IterableDataset defines `__len__`. """
""" Tests that a warning message is shown when an IterableDataset defines `__len__`. """
model = EvalModelTemplate()
original_dataset = model.train_dataloader().dataset
@@ -637,8 +637,8 @@ def test_warning_with_iterable_dataset_and_len(tmpdir):
return len(original_dataset)
dataloader = DataLoader(IterableWithLen(), batch_size=16)
-assert _has_len(dataloader)
-assert _has_iterable_dataset(dataloader)
+assert has_len(dataloader)
+assert has_iterable_dataset(dataloader)
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=3,
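
The hunk ends inside the `Trainer(...)` call, so the rest of the test is not shown here. For reference, a self-contained sketch of asserting the relocated warning directly, assuming only the helpers defined above plus pytest; the dataset is a stand-in mirroring the test's `IterableWithLen` and not copied from the repository:

import pytest
import torch
from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning.utilities.data import has_iterable_dataset, has_len


class IterableWithLen(IterableDataset):
    """Stand-in iterable dataset that (problematically) defines __len__."""
    def __iter__(self):
        return iter(torch.rand(64, 2))

    def __len__(self):
        return 64


def test_len_warning_for_iterable_dataset():
    dataloader = DataLoader(IterableWithLen(), batch_size=16)
    # has_len() itself emits the rank_zero_warn about __len__ on an IterableDataset (torch >= 1.4)
    with pytest.warns(UserWarning, match='has `__len__` defined'):
        assert has_len(dataloader)
    assert has_iterable_dataset(dataloader)
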