Improve support for custom `DataLoader`s when instantiated in `*_dataloader` hook (#12981)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
otaj 2022-06-22 01:53:24 +02:00 committed by GitHub
parent bbc51d16a3
commit 2e9cd72add
5 changed files with 303 additions and 127 deletions

CHANGELOG.md

@@ -106,9 +106,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Raise an error if there are insufficient training batches when using a float value of `limit_train_batches` ([#12885](https://github.com/PyTorchLightning/pytorch-lightning/pull/12885))
- The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604))
- `DataLoader` instantiated inside a `*_dataloader` hook will not set the passed arguments as attributes anymore ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))
- The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604))
### Deprecated
@@ -229,6 +231,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed
- Improved support for custom `DataLoader`s when instantiated in `*_dataloader` hook ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))
- Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))
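
The two new changelog entries above describe the same mechanism from both sides: arguments passed to a `DataLoader` created inside a `*_dataloader` hook are now recorded rather than force-set as instance attributes, and that record is what lets Lightning re-instantiate custom loaders, for example to inject a distributed sampler. A minimal sketch of the pattern using the boring demo classes; `ScaledDataLoader` and `MyModel` are hypothetical names, not part of this commit:

from torch.utils.data import DataLoader

from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset


class ScaledDataLoader(DataLoader):
    # hypothetical custom loader with an extra constructor argument
    def __init__(self, dataset, scale: float, **kwargs):
        self.scale = scale  # the subclass stores what it needs; Lightning no longer sets `scale` for it
        super().__init__(dataset, **kwargs)


class MyModel(BoringModel):
    # only the hook matters for this example
    def train_dataloader(self):
        # created inside the hook, so the patched `DataLoader.__init__` records the call
        # arguments and the loader can later be rebuilt, e.g. with a `DistributedSampler`
        return ScaledDataLoader(RandomDataset(32, 64), scale=0.5, batch_size=8)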

pytorch_lightning/strategies/ipu.py

@@ -29,7 +29,7 @@ from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.data import _get_dataloader_init_kwargs
from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -228,9 +228,9 @@ class IPUStrategy(ParallelStrategy):
# the user is returning the `poptorch.DataLoader` directly, don't change anything.
return dataloader
dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler)
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler)
opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
dataloader = poptorch.DataLoader(opts, **dl_kwargs)
dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs)
return dataloader
def _handle_gradient_accumulation_steps(self) -> None:

pytorch_lightning/utilities/data.py

@@ -17,12 +17,19 @@ import os
from contextlib import contextmanager
from dataclasses import fields
from functools import partial
from itertools import chain
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Tuple, Type, Union
import torch
from torch import Tensor
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler, SequentialSampler
from torch.utils.data import (
BatchSampler,
DataLoader,
Dataset,
IterableDataset,
RandomSampler,
Sampler,
SequentialSampler,
)
import pytorch_lightning as pl
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
@@ -179,10 +186,10 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
def _update_dataloader(
dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
) -> DataLoader:
dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode=mode)
dl_cls = type(dataloader)
try:
dataloader = dl_cls(**dl_kwargs)
dataloader = dl_cls(*dl_args, **dl_kwargs)
except TypeError as e:
# improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
# `__init__` arguments map to one `DataLoader.__init__` argument
@@ -198,38 +205,62 @@ def _update_dataloader(
f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
" This argument was automatically passed to your DataLoader by PyTorch Lightning."
)
raise MisconfigurationException(message) from e
return dataloader
def _get_dataloader_init_kwargs(
def _get_dataloader_init_args_and_kwargs(
dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
) -> Tuple[Tuple[Any], Dict[str, Any]]:
if not isinstance(dataloader, DataLoader):
raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")
# get the dataloader instance attributes
attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
# not part of `vars`
attrs["multiprocessing_context"] = dataloader.multiprocessing_context
was_wrapped = hasattr(dataloader, "__pl_dl_args")
if was_wrapped:
dl_args = dataloader.__pl_dl_args
dl_kwargs = dataloader.__pl_dl_kwargs
arg_names = dataloader.__pl_dl_arg_names
original_dataset = dataloader.__dataset # we have this saved from _wrap_init
else:
# get the dataloader instance attributes
attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
# We cannot be 100% sure the class sets dataset argument. Let's set it to None to be safe
# and hope we can get it from the instance attributes
original_dataset = None
# not part of `vars`
attrs["multiprocessing_context"] = dataloader.multiprocessing_context
arg_names = ()
# get the dataloader instance `__init__` parameters
params = dict(inspect.signature(dataloader.__init__).parameters)
has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
if has_variadic_kwargs:
# if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
params.update(inspect.signature(DataLoader.__init__).parameters)
del params["self"]
# keep only the params whose default is different to the current attr value
non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
# add `dataset` as it might have been replaced with `*args`
non_defaults.add("dataset")
if was_wrapped:
# if the dataloader was wrapped in a hook, only take arguments with default values
# and assume user passes their kwargs correctly
params.update(
{k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty}
)
else:
params.update(inspect.signature(DataLoader.__init__).parameters)
params.pop("self", None)
# kwargs to re-construct the dataloader
dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
if isinstance(dl_kwargs["dataset"], IterableDataset):
if not was_wrapped:
# keep only the params whose default is different to the current attr value
non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
# add `dataset` as it might have been replaced with `*args`
non_defaults.add("dataset")
# kwargs to re-construct the dataloader
dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
dl_args = ()
dataset = dl_kwargs.get("dataset", original_dataset)
if isinstance(dataset, IterableDataset):
dl_kwargs["batch_sampler"] = None
dl_kwargs["sampler"] = None
else:
@@ -238,40 +269,43 @@ def _get_dataloader_init_kwargs(
required_args = {
p.name
for p in params.values()
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) and p.default is p.empty and p.name not in dl_kwargs
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
and p.default is p.empty
and p.name not in dl_kwargs
and p.name not in arg_names
}
# the dataloader has required args which we could not extract from the existing attributes
if required_args:
required_args = sorted(required_args)
dataloader_cls_name = dataloader.__class__.__name__
missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args)
raise MisconfigurationException(
f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
f"The missing attributes are {required_args}. "
f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
"`*_dataloader` hook of your module, we will do this for you."
f" Otherwise, define {missing_args_message} inside your `__init__`."
)
if not has_variadic_kwargs:
# the dataloader signature does not allow keyword arguments that need to be passed
missing_kwargs = dl_kwargs.keys() - params.keys()
missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
if missing_kwargs:
missing_kwargs = sorted(missing_kwargs)
dataloader_cls_name = dataloader.__class__.__name__
raise MisconfigurationException(
f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
f"The missing arguments are {missing_kwargs}. "
f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
"add the `__init__` arguments or allow passing `**kwargs`"
)
if _FaultTolerantMode.detect_current_mode().is_automatic:
dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs)
dl_args, dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(
was_wrapped, arg_names, dl_args, dl_kwargs
)
return dl_kwargs
return dl_args, dl_kwargs
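
The helper's return type changed from a kwargs dict to an `(args, kwargs)` pair, so positional arguments captured from a wrapped `__init__` survive the round trip. A small sketch of the new shape for a plain, unwrapped loader; this is an internal, underscore-prefixed helper and is called here only for illustration:

from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs

loader = DataLoader(range(10), batch_size=2)
# nothing is captured positionally for a plain DataLoader, so the args tuple is empty
# and everything needed to rebuild the loader (dataset, batch size, resolved sampler, ...)
# comes back in the kwargs dict
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(loader, loader.sampler)
assert dl_args == ()
rebuilt = type(loader)(*dl_args, **dl_kwargs)
assert rebuilt.batch_size == 2
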
def _dataloader_init_kwargs_resolve_sampler(
@@ -321,30 +355,35 @@ def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)
def _wrap_init(init: Callable) -> Callable:
"""Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of
:class:`~torch.utils.data.DataLoader`."""
def _wrap_dataloader_init(init: Callable) -> Callable:
"""Wraps the ``__init__`` method of :class:`~torch.utils.data.DataLoader` in order to enable re-instantiation
of custom subclasses."""
@functools.wraps(init)
def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
# We need to inspect `init`, as inspecting `obj.__init__`
# can lead to inspecting the wrong function with multiple inheritance
params = inspect.signature(init).parameters
param_names = [
param_names = tuple(
param.name
for param in params.values()
if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
]
)
param_names = param_names[: len(args)]
if not hasattr(obj, "__pl_dl_args"):
obj.__pl_dl_args = args
obj.__pl_dl_kwargs = kwargs
obj.__pl_dl_arg_names = param_names
# We want to use the latest possible value for dataset argument (i.e. ideally what gets passed to DataLoader)
# so that we can be sure, that it will not get changed anymore.
# That is why we are setting this in every `__init__`
if "dataset" in param_names:
setattr(obj, "__dataset", args[param_names.index("dataset")])
elif "dataset" in kwargs:
setattr(obj, "__dataset", kwargs["dataset"])
cls = type(obj)
for arg_name, arg_value in chain(zip(param_names, args), kwargs.items()):
if hasattr(cls, arg_name) and getattr(cls, arg_name).fset is None:
# the class defines a read-only (no setter) property of this name. it's likely that the implementation
# will set `self._arg_name = arg_value` in `__init__` which is the attribute returned by the `arg_name`
# property so we are fine skipping in that case
continue
setattr(obj, arg_name, arg_value)
init(obj, *args, **kwargs)
return wrapper
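
What the wrapper records is easiest to see on an instance: the outermost call arguments end up in `__pl_dl_args`, `__pl_dl_kwargs` and `__pl_dl_arg_names`, while `__dataset` keeps the value that ultimately reaches `DataLoader.__init__`, per the "latest possible value" comment above. A sketch, assuming the wrapper is activated through the `_replace_dataloader_init_method` context manager defined below; `TensorKeyedLoader` is a hypothetical subclass:

import torch
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _replace_dataloader_init_method


class TensorKeyedLoader(DataLoader):
    def __init__(self, data: torch.Tensor, *args, **kwargs):
        self.data = data
        # forwards a different dataset object to the parent class
        super().__init__(range(data.size(0)), *args, **kwargs)


with _replace_dataloader_init_method():
    loader = TensorKeyedLoader(torch.randn(10, 20), batch_size=2)

assert loader.__pl_dl_arg_names == ("data",)  # names matched to the positional args
assert loader.__pl_dl_kwargs == {"batch_size": 2}  # keyword arguments of the outermost call
assert loader.__dataset == range(10)  # the dataset that actually reached `DataLoader.__init__`
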
@@ -368,33 +407,63 @@ def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
def _replace_dataloader_init_method() -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of
:class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
subclasses = _get_all_subclasses(DataLoader)
for subclass in subclasses:
subclass._old_init = subclass.__init__
subclass.__init__ = _wrap_init(subclass.__init__)
classes = _get_all_subclasses(DataLoader) | {DataLoader}
wrapped = set()
for cls in classes:
if cls.__init__ not in wrapped:
cls._old_init = cls.__init__
cls.__init__ = _wrap_dataloader_init(cls.__init__)
wrapped.add(cls.__init__)
yield
for subclass in subclasses:
subclass.__init__ = subclass._old_init
del subclass._old_init
for cls in classes:
if hasattr(cls, "_old_init"):
cls.__init__ = cls._old_init
del cls._old_init
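
End to end, the context manager plus `_update_dataloader` is what allows Lightning to rebuild a custom loader whose constructor arguments never became attributes. A sketch under that assumption; `ShiftedDataLoader` is a hypothetical subclass:

from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _replace_dataloader_init_method, _update_dataloader


class ShiftedDataLoader(DataLoader):
    # `shift` is intentionally never stored as an instance attribute
    def __init__(self, dataset, shift: int, **kwargs):
        super().__init__([x + shift for x in dataset], **kwargs)


with _replace_dataloader_init_method():
    loader = ShiftedDataLoader(range(4), shift=10, batch_size=2)

# re-instantiation (used internally when injecting samplers, e.g. for DDP) succeeds
# because `shift` is read back from the recorded call arguments
new_loader = _update_dataloader(loader, loader.sampler)
assert isinstance(new_loader, ShiftedDataLoader)
assert list(new_loader.dataset) == [10, 11, 12, 13]
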
def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) -> Dict:
dataset = dl_kwargs["dataset"]
def _wrap_with_capture_dataset(dataset: Dataset) -> Dataset:
if isinstance(dataset, IterableDataset):
# wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dataset)
elif get_len(dataset) != float("inf"):
dl_kwargs["dataset"] = CaptureMapDataset(dataset=dataset)
return CaptureIterableDataset(dataset=dataset)
if get_len(dataset) != float("inf"):
return CaptureMapDataset(dataset=dataset)
raise RuntimeError("This shouldn't happen, please open an issue on Lightning Github repository.")
def _apply_fault_tolerant_automatic_capture_dataset_wrapper(
was_wrapped: bool, arg_names: Tuple[str, ...], dl_args: Tuple[Any, ...], dl_kwargs: Dict[str, Any]
) -> Tuple[Tuple[str, ...], Dict[str, Any]]:
if "dataset" in dl_kwargs:
dl_kwargs["dataset"] = _wrap_with_capture_dataset(dl_kwargs["dataset"])
elif "dataset" in arg_names:
dataset_idx = arg_names.index("dataset")
dataset = _wrap_with_capture_dataset(dl_args[dataset_idx])
dl_args = dl_args[:dataset_idx] + (dataset,) + dl_args[dataset_idx + 1 :]
else:
raise MisconfigurationException("This shouldn't happen, please open an issue on Lightning Github repository.")
return dl_kwargs
if was_wrapped:
avoid_message = (
" To avoid this, either pass `DataLoader(dataset=your_dataset)` or the positional dataset argument"
" `DataLoader(your_dataset, ...)`."
)
else:
avoid_message = " To avoid this, define `self.dataset = dataset` inside your DataLoader's `__init__`."
raise MisconfigurationException(
"You enabled automatic Fault Tolerant mode, but we were not able to replace your dataset"
" with Fault Tolerant wrapper, because you have a custom DataLoader." + avoid_message
)
return dl_args, dl_kwargs
def _is_dataloader_shuffled(dataloader: object) -> bool:
if hasattr(dataloader, "shuffle"):
if hasattr(dataloader, "__pl_dl_kwargs"):
# this attribute is not part of PyTorch's DataLoader, but could have been set by
# our `_replace_dataloader_init_method` context manager
return dataloader.shuffle
if "shuffle" in dataloader.__pl_dl_kwargs:
return dataloader.__pl_dl_kwargs["shuffle"]
if "shuffle" in dataloader.__pl_dl_arg_names:
return dataloader.__pl_dl_args[dataloader.__pl_dl_arg_names.index("shuffle")]
if isinstance(dataloader.dataset, IterableDataset):
# shuffling is useless with iterable datasets
return False
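
With the recorded call arguments available, shuffling can be detected without relying on a `shuffle` attribute that PyTorch's `DataLoader` never keeps. A short sketch; the unshuffled case also relies on the sampler-based fallback that continues past the lines shown in this hunk:

from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _is_dataloader_shuffled, _replace_dataloader_init_method

with _replace_dataloader_init_method():
    shuffled = DataLoader(range(8), shuffle=True)
    unshuffled = DataLoader(range(8))

# `shuffle` is read back from `__pl_dl_kwargs` for the wrapped instance
assert _is_dataloader_shuffled(shuffled)
assert not _is_dataloader_shuffled(unshuffled)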

tests/lite/test_lite.py

@@ -201,7 +201,7 @@ def test_setup_dataloaders_raises_for_unknown_custom_args():
with pytest.raises(
MisconfigurationException,
match=(
r"Trying to inject `DistributedSampler` into the `CustomDataLoader` instance.*"
r"Trying to inject custom `Sampler` into the `CustomDataLoader` instance.*"
r"The missing attributes are \['new_arg'\]"
),
):

tests/utilities/test_data.py

@@ -9,7 +9,7 @@ from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.data import (
_get_dataloader_init_kwargs,
_get_dataloader_init_args_and_kwargs,
_replace_dataloader_init_method,
_update_dataloader,
extract_batch_size,
@@ -134,23 +134,28 @@ def test_has_len_all_rank():
def test_update_dataloader_typerror_custom_exception():
class BadImpl(DataLoader):
class BadStandaloneGoodHookImpl(DataLoader):
def __init__(self, foo, *args, **kwargs):
self.foo = foo
# positional conflict with `dataset`
super().__init__(foo, *args, **kwargs)
dataloader = BadImpl([1, 2, 3])
dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"):
_update_dataloader(dataloader, dataloader.sampler)
class BadImpl2(DataLoader):
with _replace_dataloader_init_method():
dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
assert isinstance(new_dataloader, BadStandaloneGoodHookImpl)
class BadImpl(DataLoader):
def __init__(self, randomize, *args, **kwargs):
self.randomize = randomize
# keyword conflict with `shuffle`
super().__init__(*args, shuffle=randomize, **kwargs)
dataloader = BadImpl2(False, [])
dataloader = BadImpl(False, [])
with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`shuffle`"):
_update_dataloader(dataloader, dataloader.sampler)
@@ -165,69 +170,165 @@ def test_update_dataloader_typerror_custom_exception():
assert isinstance(new_dataloader, GoodImpl)
def test_replace_dataloader_init_method():
"""Test that context manager intercepts arguments passed to custom subclasses of torch.utils.DataLoader and
sets them as attributes."""
class DataLoaderSubclass1(DataLoader):
def __init__(self, attribute1, *args, **kwargs):
self.at1 = attribute1
super().__init__(*args, **kwargs)
class DataLoaderSubclass1(DataLoader):
def __init__(self, attribute1, *args, **kwargs):
# intentionally not setting this attribute, calling super with different args
# self.attribute1 = attribute1
super().__init__(*args, **kwargs)
class DataLoaderSubclass2(DataLoaderSubclass1):
def __init__(self, attribute2, *args, **kwargs):
# intentionally not setting this attribute, calling super with different args
# self.attribute2 = attribute2
super().__init__(attribute2 + "-2", *args, **kwargs)
class DataLoaderSubclass2(DataLoaderSubclass1):
def __init__(self, attribute2, *args, **kwargs):
self.at2 = attribute2
super().__init__(attribute2 + "-2", *args, **kwargs)
class MyBaseDataLoader(DataLoader):
pass
class MyDataLoader(MyBaseDataLoader):
def __init__(self, data: torch.Tensor, *args, **kwargs):
self.data = data
super().__init__(range(data.size(0)), *args, **kwargs)
test3_data = torch.randn((10, 20))
class PoptorchDataLoader(DataLoader):
def __init__(self, options, *args, **kwargs):
super().__init__(*args, **kwargs)
self._options = options
@property
def options(self):
return self._options
class IncompleteDataLoader(DataLoader):
def __init__(self, dataset, batch_size, **kwargs):
batch_size = max(batch_size - 5, 0)
super().__init__(dataset, batch_size=batch_size, **kwargs)
class WeirdDataLoader1(DataLoader):
def __init__(self, arg1, arg2, **kwargs):
self.arg1 = arg1
super().__init__(arg2, **kwargs)
class WeirdDataLoader2(DataLoader):
def __init__(self, data_part1, data_part2, **kwargs):
data = list(data_part1) + list(data_part2)
super().__init__(data, **kwargs)
class NoneDataLoader(DataLoader):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class ChangingDataLoader(DataLoader):
def __init__(self, dataset, **kwargs):
super().__init__(list(dataset) + list(range(5, 10)), **kwargs)
@pytest.mark.parametrize(
["cls", "args", "kwargs", "arg_names", "dataset", "checked_values"],
[
pytest.param(
DataLoaderSubclass1,
("attribute1",),
dict(dataset=range(4), batch_size=2),
("attribute1",),
range(4),
dict(batch_size=2, at1="attribute1"),
id="test1",
),
pytest.param(
DataLoaderSubclass2,
("attribute2",),
dict(dataset=range(4), batch_size=2),
("attribute2",),
range(4),
dict(batch_size=2, at1="attribute2-2", at2="attribute2"),
id="test2",
),
pytest.param(
MyDataLoader,
(test3_data,),
dict(batch_size=2),
("data",),
range(10),
dict(batch_size=2, data=test3_data),
id="test3",
),
pytest.param(PoptorchDataLoader, (123, [1]), dict(), ("options",), [1], dict(options=123), id="test4"),
pytest.param(
IncompleteDataLoader,
(range(10),),
dict(batch_size=10),
("dataset",),
range(10),
dict(batch_size=5),
id="test5",
),
pytest.param(
WeirdDataLoader1,
(10, range(10)),
dict(batch_size=10),
("arg1", "arg2"),
range(10),
dict(arg1=10, batch_size=10),
id="test6",
),
pytest.param(
WeirdDataLoader2,
(range(10), range(10, 20)),
dict(batch_size=10),
("data_part1", "data_part2"),
list(range(20)),
dict(batch_size=10),
id="test7",
),
pytest.param(NoneDataLoader, (None,), dict(), (), None, dict(), id="test8"),
pytest.param(ChangingDataLoader, (range(5),), dict(), ("dataset",), list(range(10)), dict(), id="test9"),
],
)
def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, checked_values):
with _replace_dataloader_init_method():
dataloader = DataLoaderSubclass1("attribute1", dataset=range(4), batch_size=2)
dataloader = cls(*args, **kwargs)
assert dataloader.attribute1 == "attribute1"
assert dataloader.__pl_dl_args == args
assert dataloader.__pl_dl_kwargs == kwargs
assert dataloader.__pl_dl_arg_names == arg_names
assert dataloader.__dataset == dataset
with _replace_dataloader_init_method():
dataloader = DataLoaderSubclass2("attribute2", dataset=range(4), batch_size=2)
assert dataloader.dataset == dataset
assert dataloader.attribute1 == "attribute2-2"
assert dataloader.attribute2 == "attribute2"
for key, value in checked_values.items():
dataloader_value = getattr(dataloader, key)
if isinstance(dataloader_value, torch.Tensor):
assert dataloader_value is value
else:
assert getattr(dataloader, key) == value
# Failing test case from issue 12564
class MyBaseDataLoader(DataLoader):
pass
dataloader = _update_dataloader(dataloader, dataloader.sampler)
class MyDataLoader(MyBaseDataLoader):
def __init__(self, data: torch.Tensor, *args, **kwargs):
self.data = data
super().__init__(range(data.size(0)), *args, **kwargs)
assert isinstance(dataloader, cls)
assert not hasattr(dataloader, "__pl_dl_kwargs")
assert not hasattr(dataloader, "__pl_dl_arg_names")
assert not hasattr(dataloader, "__pl_dl_args")
assert not hasattr(dataloader, "__dataset")
data = torch.randn((10, 20))
assert dataloader.dataset == dataset
with _replace_dataloader_init_method():
dataloader = MyDataLoader(data, batch_size=2)
assert dataloader.data is data
assert dataloader.dataset == range(10)
# `poptorch.DataLoader` uses this pattern, simulate it
class PoptorchDataLoader(DataLoader):
def __init__(self, options, *args, **kwargs):
super().__init__(*args, **kwargs)
self._options = options
@property
def options(self):
return self._options
# this read-only property pattern is fine
dataloader = PoptorchDataLoader(123, [1])
assert dataloader.options == 123
# still works with the init replacement
with _replace_dataloader_init_method():
dataloader = PoptorchDataLoader(123, [1])
assert dataloader.options == 123
for key, value in checked_values.items():
dataloader_value = getattr(dataloader, key)
if isinstance(dataloader_value, torch.Tensor):
assert dataloader_value is value
else:
assert getattr(dataloader, key) == value
@pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
@@ -235,7 +336,7 @@ def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
"""Test that DataLoader kwargs are not replaced when using Iterable Dataset."""
dataset = RandomIterableDataset(7, 100)
dataloader = DataLoader(dataset, batch_size=32)
dl_kwargs = _get_dataloader_init_kwargs(dataloader, dataloader.sampler, mode=mode)
_, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler, mode=mode)
assert dl_kwargs["sampler"] is None
assert dl_kwargs["batch_sampler"] is None
assert dl_kwargs["batch_size"] is dataloader.batch_size