# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import os
from contextlib import contextmanager
from functools import partial
from itertools import chain
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union

import torch
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler

import pytorch_lightning as pl
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import pl_worker_init_function
from pytorch_lightning.utilities.warnings import WarningCache

BType = Union[torch.Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]

warning_cache = WarningCache()


def _extract_batch_size(batch: BType) -> Generator[int, None, None]:
    if isinstance(batch, torch.Tensor):
        if batch.ndim == 0:
            yield 1
        else:
            yield batch.size(0)
    elif isinstance(batch, (Iterable, Mapping)) and not isinstance(batch, str):
        if isinstance(batch, Mapping):
            batch = batch.values()

        for sample in batch:
            yield from _extract_batch_size(sample)
    else:
        yield None


def extract_batch_size(batch: BType) -> int:
    """Unpack a batch to find a ``torch.Tensor``.

    Returns:
        ``len(tensor)`` of the first tensor found, or ``1`` if the batch only contains scalar tensors.

    Raises:
        MisconfigurationException: If no batch size could be inferred from the batch.
    """
    error_msg = (
        "We could not infer the batch_size from the batch. Either simplify its structure"
        " or provide the batch_size as `self.log(..., batch_size=batch_size)`."
    )
    batch_size = None
    try:
        for bs in _extract_batch_size(batch):
            if batch_size is None:
                batch_size = bs
            elif batch_size != bs:
                warning_cache.warn(
                    "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
                    f" found is {batch_size}. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`."
                )
                break
    except RecursionError:
        raise RecursionError(error_msg)

    if batch_size is None:
        raise MisconfigurationException(error_msg)

    return batch_size


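# Illustrative usage sketch (not part of the original module): `extract_batch_size`
# walks nested mappings/iterables and reports the leading dimension of the first
# tensor it finds. The batch structure and shapes below are assumptions made up
# for demonstration purposes only.
def _example_extract_batch_size() -> None:
    batch = {"inputs": torch.zeros(32, 3), "targets": torch.zeros(32)}
    # both tensors agree on the leading dimension, so no ambiguity warning is emitted
    assert extract_batch_size(batch) == 32

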
def has_iterable_dataset(dataloader: DataLoader) -> bool:
    return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)


def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
    """Checks whether a given DataLoader has ``__len__`` implemented, i.e. whether it is a finite or an infinite
    dataloader.

    Raises:
        ValueError:
            If the length of the DataLoader is 0, as at least one batch is required.
    """
    try:
        # try getting the length
        if len(dataloader) == 0:
            raise ValueError("`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch")
        has_len = True
    except TypeError:
        has_len = False
    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
        has_len = False

    if has_len and has_iterable_dataset(dataloader):
        rank_zero_warn(
            "Your `IterableDataset` has `__len__` defined."
            " In combination with multi-process data loading (when num_workers > 1),"
            " `__len__` could be inaccurate if each worker is not configured independently"
            " to avoid having duplicate data."
        )
    return has_len


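# Illustrative usage sketch (not part of the original module): a regular map-style
# dataset exposes `__len__`, so `has_len` returns True. The dataset used here is an
# assumption made up for demonstration purposes only.
def _example_has_len() -> None:
    from torch.utils.data import TensorDataset

    dataset = TensorDataset(torch.zeros(8, 2))
    assert has_len(DataLoader(dataset, batch_size=4))

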
def has_len_all_ranks(
    dataloader: DataLoader,
    training_type: "pl.TrainingTypePlugin",
    model: Union["pl.LightningModule", "pl.LightningDataModule"],
) -> bool:
    """Checks whether a given DataLoader has ``__len__`` implemented, i.e. whether it is a finite or an infinite
    dataloader, taking all ranks into account.

    Raises:
        MisconfigurationException:
            If the total length of the DataLoader across ranks is 0, or if the local rank's DataLoader has zero
            length while other ranks do not and the model does not explicitly allow it.
    """
    try:
        total_length = training_type.reduce(torch.tensor(len(dataloader)).to(model.device), reduce_op="sum")
        local_length = len(dataloader)

        if total_length == 0:
            raise MisconfigurationException(
                "Total length of `Dataloader` across ranks is zero. Please make sure that it returns at least 1 batch."
            )
        if total_length > 0 and local_length == 0:
            if model.allow_zero_length_dataloader_with_multiple_devices:
                rank_zero_warn(
                    "Total length of `Dataloader` across ranks is nonzero, but local rank has zero length."
                    " Please be cautious of uneven batch length."
                )
                has_len = False
            else:
                raise MisconfigurationException(
                    "`Dataloader` within local rank has zero length. Please make sure that it returns at least 1 batch."
                )
        else:
            has_len = True

    except TypeError:
        has_len = False
    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
        has_len = False

    if has_len and has_iterable_dataset(dataloader):
        rank_zero_warn(
            "Your `IterableDataset` has `__len__` defined."
            " In combination with multi-process data loading (when num_workers > 1),"
            " `__len__` could be inaccurate if each worker is not configured independently"
            " to avoid having duplicate data."
        )
    return has_len


def get_len(dataloader: DataLoader) -> Union[int, float]:
    """Return the length of the given DataLoader.

    If the ``__len__`` method is not implemented, return ``float('inf')``.
    """
    if has_len(dataloader):
        return len(dataloader)

    return float("inf")


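# Illustrative usage sketch (not part of the original module): an `IterableDataset`
# without `__len__` makes `get_len` fall back to infinity. The toy dataset below is
# an assumption made up for demonstration purposes only.
def _example_get_len() -> None:
    class _InfiniteDataset(IterableDataset):
        def __iter__(self):
            while True:
                yield torch.zeros(1)

    assert get_len(DataLoader(_InfiniteDataset())) == float("inf")

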
def _update_dataloader(dataloader: DataLoader, sampler: Sampler, mode: Optional[RunningStage] = None) -> DataLoader:
    dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
    dl_cls = type(dataloader)
    try:
        dataloader = dl_cls(**dl_kwargs)
    except TypeError as e:
        # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
        # `__init__` arguments map to one `DataLoader.__init__` argument
        import re

        match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
        if not match:
            # an unexpected `TypeError`, re-raise it
            raise
        argument = match.groups()[0]
        message = (
            f"The {dl_cls.__name__} `DataLoader` implementation has an error where more than one `__init__` argument"
            f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
            f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
            f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
        )
        raise MisconfigurationException(message) from e
    return dataloader


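# Illustrative usage sketch (not part of the original module): re-instantiating a
# vanilla `DataLoader` with a different sampler while keeping its remaining settings.
# The dataset and batch size below are assumptions made up for demonstration purposes.
def _example_update_dataloader() -> None:
    from torch.utils.data import SequentialSampler, TensorDataset

    dataset = TensorDataset(torch.arange(10).float())
    dataloader = DataLoader(dataset, batch_size=5, shuffle=True)
    # the rebuilt loader keeps the batch size but uses the provided sampler
    new_dataloader = _update_dataloader(dataloader, SequentialSampler(dataset))
    assert isinstance(new_dataloader.sampler, SequentialSampler)
    assert new_dataloader.batch_size == 5

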
def _get_dataloader_init_kwargs(
    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
    if not isinstance(dataloader, DataLoader):
        raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

    # get the dataloader instance attributes
    attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
    # not part of `vars`
    attrs["multiprocessing_context"] = dataloader.multiprocessing_context

    # get the dataloader instance `__init__` parameters
    params = dict(inspect.signature(dataloader.__init__).parameters)
    has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
    if has_variadic_kwargs:
        # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
        params.update(inspect.signature(DataLoader.__init__).parameters)
        del params["self"]

    # keep only the params whose default is different from the current attr value
    non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
    # add `dataset` as it might have been replaced with `*args`
    non_defaults.add("dataset")

    # kwargs to re-construct the dataloader
    dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
    dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode))

    required_args = {
        p.name
        for p in params.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) and p.default is p.empty and p.name not in dl_kwargs
    }
    # the dataloader has required args which we could not extract from the existing attributes
    if required_args:
        required_args = sorted(required_args)
        dataloader_cls_name = dataloader.__class__.__name__
        raise MisconfigurationException(
            f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
            f"The missing attributes are {required_args}. "
            f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
            "manually add the `DistributedSampler` as: "
            f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
        )

    if not has_variadic_kwargs:
        # the dataloader signature does not allow keyword arguments that need to be passed
        missing_kwargs = dl_kwargs.keys() - params.keys()
        if missing_kwargs:
            missing_kwargs = sorted(missing_kwargs)
            dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
                "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                f"The missing arguments are {missing_kwargs}. "
                f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
                "manually add the `DistributedSampler` as: "
                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
            )

    if isinstance(dl_kwargs["dataset"], IterableDataset):
        dl_kwargs["batch_sampler"] = None
        dl_kwargs["sampler"] = None

    if _FaultTolerantMode.detect_current_mode().is_automatic:
        dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs)

    return dl_kwargs


def _dataloader_init_kwargs_resolve_sampler(
    dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
    """Handles the sampler and batch_sampler arguments associated with a DataLoader for its re-instantiation.

    If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`, so
    Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped into a
    `FastForwardSampler`.
    """
    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
    batch_sampler = getattr(dataloader, "batch_sampler")
    is_predicting = mode == RunningStage.PREDICTING
    # re-instantiate the batch sampler if its type differs from the PyTorch default or if we are predicting
    if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
        batch_sampler = type(batch_sampler)(
            sampler,
            batch_size=batch_sampler.batch_size,
            drop_last=(False if is_predicting else batch_sampler.drop_last),
        )
        if is_predicting:
            batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

        if fault_tolerant_mode.is_automatic:
            fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
            fast_forward_sampler.setup(dataloader_batch_size=1)

        return {
            "sampler": None,
            "shuffle": False,
            "batch_sampler": batch_sampler,
            "batch_size": 1,
            "drop_last": False,
        }

    if fault_tolerant_mode.is_automatic:
        fast_forward_sampler = sampler = FastForwardSampler(sampler)
        fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

    return {"sampler": sampler, "shuffle": False, "batch_sampler": None}


def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
    if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
        dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)


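# Illustrative usage sketch (not part of the original module): the worker init
# function is only attached when `PL_SEED_WORKERS=1` is set, which is typically
# exported by `seed_everything(workers=True)`. Setting the environment variable
# directly, as below, is an assumption made for demonstration purposes only and
# has a process-wide side effect.
def _example_auto_add_worker_init_fn() -> None:
    from torch.utils.data import TensorDataset

    os.environ["PL_SEED_WORKERS"] = "1"
    dataloader = DataLoader(TensorDataset(torch.zeros(4, 1)))
    _auto_add_worker_init_fn(dataloader, rank=0)
    assert dataloader.worker_init_fn is not None

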
def _wrap_init(init: Callable) -> Callable:
    """Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of
    :class:`~torch.utils.data.DataLoader`."""

    @functools.wraps(init)
    def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
        params = dict(inspect.signature(obj.__init__).parameters)
        params.pop("args", None)
        params.pop("kwargs", None)
        for arg_name, arg_value in chain(zip(params, args), kwargs.items()):
            setattr(obj, arg_name, arg_value)
        init(obj, *args, **kwargs)

    return wrapper


# https://stackoverflow.com/a/63851681/9201239
def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
    """Returns the set of all classes that inherit directly or indirectly from the given class."""
    subclasses = set()

    def recurse(cl: Type[Any]) -> None:
        for subclass in cl.__subclasses__():
            subclasses.add(subclass)
            recurse(subclass)

    recurse(cls)
    return subclasses


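# Illustrative usage sketch (not part of the original module): direct and indirect
# subclasses are both collected. The toy hierarchy below is an assumption made up
# for demonstration purposes only.
def _example_get_all_subclasses() -> None:
    class Base:
        ...

    class Child(Base):
        ...

    class Grandchild(Child):
        ...

    assert _get_all_subclasses(Base) == {Child, Grandchild}

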
@contextmanager
def _replace_dataloader_init_method() -> Generator[None, None, None]:
    """This context manager is used to add support for re-instantiation of custom (subclasses) of
    :class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
    subclasses = _get_all_subclasses(DataLoader)
    for subclass in subclasses:
        subclass._old_init = subclass.__init__
        subclass.__init__ = _wrap_init(subclass.__init__)
    yield
    for subclass in subclasses:
        subclass.__init__ = subclass._old_init
        del subclass._old_init


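# Illustrative usage sketch (not part of the original module): inside the context
# manager, custom `__init__` arguments are recorded as instance attributes so that
# `_update_dataloader` can later rebuild the subclass with them. The subclass and
# its `my_arg` parameter below are assumptions made up for demonstration purposes.
def _example_replace_dataloader_init_method() -> None:
    from torch.utils.data import SequentialSampler, TensorDataset

    class CustomDataLoader(DataLoader):
        def __init__(self, dataset, my_arg: int = 0, **kwargs):
            # note: `my_arg` is intentionally not stored by this `__init__`
            super().__init__(dataset, **kwargs)

    dataset = TensorDataset(torch.zeros(4, 1))
    with _replace_dataloader_init_method():
        dataloader = CustomDataLoader(dataset, my_arg=7, batch_size=2)
    # the patched `__init__` stored the call arguments on the instance
    assert dataloader.my_arg == 7
    # which allows re-instantiation of the custom subclass with a new sampler
    new_dataloader = _update_dataloader(dataloader, SequentialSampler(dataset))
    assert isinstance(new_dataloader, CustomDataLoader)
    assert new_dataloader.batch_size == 2

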
def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) -> Dict:
    dataset = dl_kwargs["dataset"]
    if isinstance(dataset, IterableDataset):
        # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
        dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dataset)
    elif get_len(dataset) != float("inf"):
        dl_kwargs["dataset"] = CaptureMapDataset(dataset=dataset)
    else:
        raise MisconfigurationException("This shouldn't happen, please open an issue on the Lightning GitHub repository.")
    return dl_kwargs