* Follow up of #2892
* typo
* iterabledataset
This commit is contained in: parent 40eaa2143e, commit 85cd558a3f
pytorch_lightning/trainer/data_loading.py

@@ -15,24 +15,18 @@
 import multiprocessing
 import platform
 from abc import ABC, abstractmethod
-from distutils.version import LooseVersion
 from typing import Union, List, Tuple, Callable, Optional
 
 import torch
 import torch.distributed as torch_distrib
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.debugging import InternalDebugger
 
-try:
-    from torch.utils.data import IterableDataset
-    ITERABLE_DATASET_EXISTS = True
-except ImportError:
-    ITERABLE_DATASET_EXISTS = False
-
 try:
     from apex import amp
@@ -56,35 +50,6 @@ else:
     HOROVOD_AVAILABLE = True
 
 
-def _has_iterable_dataset(dataloader: DataLoader):
-    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
-        and isinstance(dataloader.dataset, IterableDataset)
-
-
-def _has_len(dataloader: DataLoader) -> bool:
-    """ Checks if a given Dataloader has __len__ method implemented i.e. if
-    it is a finite dataloader or infinite dataloader. """
-
-    try:
-        # try getting the length
-        if len(dataloader) == 0:
-            raise ValueError('`Dataloader` returned 0 length.'
-                             ' Please make sure that your Dataloader at least returns 1 batch')
-        has_len = True
-    except TypeError:
-        has_len = False
-    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
-        has_len = False
-
-    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
-        rank_zero_warn(
-            'Your `IterableDataset` has `__len__` defined.'
-            ' In combination with multi-processing data loading (e.g. batch size > 1),'
-            ' this can lead to unintended side effects since the samples will be duplicated.'
-        )
-    return has_len
-
-
 class TrainerDataLoadingMixin(ABC):
 
     # this is just a summary on variables used in this abstract class,
@@ -147,7 +112,7 @@ class TrainerDataLoadingMixin(ABC):
         # don't do anything if it's not a dataloader
         is_dataloader = isinstance(dataloader, DataLoader)
         # don't manipulate iterable datasets
-        is_iterable_ds = _has_iterable_dataset(dataloader)
+        is_iterable_ds = has_iterable_dataset(dataloader)
 
         if not is_dataloader or is_iterable_ds:
             return dataloader
@@ -214,7 +179,7 @@ class TrainerDataLoadingMixin(ABC):
         # automatically add samplers
         self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
 
-        self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf')
+        self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
         self._worker_check(self.train_dataloader, 'train dataloader')
 
         if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
@@ -238,7 +203,7 @@ class TrainerDataLoadingMixin(ABC):
                     f'to the number of the training batches ({self.num_training_batches}). '
                     'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
         else:
-            if not _has_len(self.train_dataloader):
+            if not has_len(self.train_dataloader):
                 if self.val_check_interval == 1.0:
                     self.val_check_batch = float('inf')
                 else:
@@ -305,7 +270,7 @@ class TrainerDataLoadingMixin(ABC):
         # datasets could be none, 1 or 2+
         if len(dataloaders) != 0:
             for i, dataloader in enumerate(dataloaders):
-                num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
+                num_batches = len(dataloader) if has_len(dataloader) else float('inf')
                 self._worker_check(dataloader, f'{mode} dataloader {i}')
 
                 # percent or num_steps
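For context on the `is_iterable_ds` guard in the `@@ -147,7 +112,7 @@` hunk above: samplers such as `DistributedSampler` index into a dataset, which iterable-style datasets do not support, so those loaders are returned untouched. A minimal standalone sketch of that guard, simplified and outside the trainer (`passthrough_if_iterable` is a made-up name, not Lightning's `auto_add_sampler`):

from torch.utils.data import DataLoader
from pytorch_lightning.utilities.data import has_iterable_dataset

def passthrough_if_iterable(dataloader):
    # mirrors `if not is_dataloader or is_iterable_ds: return dataloader` above
    if not isinstance(dataloader, DataLoader) or has_iterable_dataset(dataloader):
        return dataloader
    # ... here the real trainer would rebuild the loader with a distributed sampler
    return dataloader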
pytorch_lightning/utilities/data.py (new file)

@@ -0,0 +1,47 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from distutils.version import LooseVersion
+import torch
+from torch.utils.data import DataLoader, IterableDataset
+
+from pytorch_lightning.utilities import rank_zero_warn
+
+
+def has_iterable_dataset(dataloader: DataLoader):
+    return hasattr(dataloader, 'dataset') and isinstance(dataloader.dataset, IterableDataset)
+
+
+def has_len(dataloader: DataLoader) -> bool:
+    """ Checks if a given Dataloader has __len__ method implemented i.e. if
+    it is a finite dataloader or infinite dataloader. """
+
+    try:
+        # try getting the length
+        if len(dataloader) == 0:
+            raise ValueError('`Dataloader` returned 0 length.'
+                             ' Please make sure that your Dataloader at least returns 1 batch')
+        has_len = True
+    except TypeError:
+        has_len = False
+    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
+        has_len = False
+
+    if has_len and has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
+        rank_zero_warn(
+            'Your `IterableDataset` has `__len__` defined.'
+            ' In combination with multi-processing data loading (e.g. batch size > 1),'
+            ' this can lead to unintended side effects since the samples will be duplicated.'
+        )
+    return has_len
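A quick usage sketch of the two new helpers (illustrative only; the `Stream` dataset and the tensors below are made-up examples, not part of this commit):

import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset

from pytorch_lightning.utilities.data import has_iterable_dataset, has_len

# map-style dataset: len() works, so the loader counts as finite
finite_loader = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)
assert has_len(finite_loader) and not has_iterable_dataset(finite_loader)


class Stream(IterableDataset):
    # hypothetical unbounded stream with no __len__
    def __iter__(self):
        while True:
            yield torch.zeros(2)


# iterable-style dataset without __len__: len(loader) raises TypeError, so has_len is False
stream_loader = DataLoader(Stream(), batch_size=4)
assert has_iterable_dataset(stream_loader) and not has_len(stream_loader)

# the fallback the data_loading.py hunks above rely on
num_batches = len(stream_loader) if has_len(stream_loader) else float('inf')

When an `IterableDataset` does define `__len__` (as in the test below), both checks return True and `has_len` emits the duplication warning on PyTorch 1.4+.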
tests/trainer/test_dataloaders.py

@@ -11,7 +11,7 @@ from torch.utils.data.distributed import DistributedSampler
 
 import tests.base.develop_pipelines as tpipes
 from pytorch_lightning import Trainer, Callback
-from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len
+from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 
@@ -624,7 +624,7 @@ def test_warning_with_few_workers(mock, tmpdir, ckpt_path):
     reason="IterableDataset with __len__ before 1.4 raises",
 )
 def test_warning_with_iterable_dataset_and_len(tmpdir):
-    """ Tests that a warning messages is shown when an IterableDataset defines `__len__`. """
+    """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
     model = EvalModelTemplate()
     original_dataset = model.train_dataloader().dataset
 
@@ -637,8 +637,8 @@ def test_warning_with_iterable_dataset_and_len(tmpdir):
             return len(original_dataset)
 
     dataloader = DataLoader(IterableWithLen(), batch_size=16)
-    assert _has_len(dataloader)
-    assert _has_iterable_dataset(dataloader)
+    assert has_len(dataloader)
+    assert has_iterable_dataset(dataloader)
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_steps=3,
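The `IterableWithLen` helper is only partially visible in this hunk; a rough reconstruction of what the test presumably defines, assuming `__iter__` simply wraps the `original_dataset` taken from the model in the previous hunk (only the lines shown in the diff are verbatim):

class IterableWithLen(IterableDataset):

    def __iter__(self):
        return iter(original_dataset)  # assumed body; the diff does not show it

    def __len__(self):
        return len(original_dataset)  # shown in the hunk above

dataloader = DataLoader(IterableWithLen(), batch_size=16)
assert has_len(dataloader)               # __len__ is defined, so True
assert has_iterable_dataset(dataloader)  # and it is an IterableDataset, so the warning fires on 1.4+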