From 2e9cd72add95afe3e856d535543cefd9bfbd7d0b Mon Sep 17 00:00:00 2001
From: otaj <6065855+otaj@users.noreply.github.com>
Date: Wed, 22 Jun 2022 01:53:24 +0200
Subject: [PATCH] Improve support for custom `DataLoader`s when instantiated in `*_dataloader` hook (#12981)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
Co-authored-by: Akihiro Nitta
---
 CHANGELOG.md                                |   8 +-
 src/pytorch_lightning/strategies/ipu.py     |   6 +-
 src/pytorch_lightning/utilities/data.py     | 197 +++++++++++++------
 tests/tests_pytorch/lite/test_lite.py       |   2 +-
 tests/tests_pytorch/utilities/test_data.py  | 217 +++++++++++++++------
 5 files changed, 303 insertions(+), 127 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 463a5f41c1..5709ad7b55 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -106,9 +106,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Raise an error if there are insufficient training batches when using a float value of `limit_train_batches` ([#12885](https://github.com/PyTorchLightning/pytorch-lightning/pull/12885))
 
-- The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604))
+- `DataLoader` instantiated inside a `*_dataloader` hook will not set the passed arguments as attributes anymore ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))
 
+- The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604))
+
 
 ### Deprecated
 
@@ -229,6 +231,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+
+- Improved support for custom `DataLoader`s when instantiated in `*_dataloader` hook ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))
+
+
 - Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))
 
 
diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py
index ece4e92b94..5413756c15 100644
--- a/src/pytorch_lightning/strategies/ipu.py
+++ b/src/pytorch_lightning/strategies/ipu.py
@@ -29,7 +29,7 @@ from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-from pytorch_lightning.utilities.data import _get_dataloader_init_kwargs
+from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs
 from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -228,9 +228,9 @@ class IPUStrategy(ParallelStrategy):
             # the user is returning the `poptorch.DataLoader` directly, don't change anything.
             return dataloader
 
-        dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler)
+        dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler)
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
-        dataloader = poptorch.DataLoader(opts, **dl_kwargs)
+        dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs)
         return dataloader
 
     def _handle_gradient_accumulation_steps(self) -> None:
diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py
index b788e6147d..2de82ceff0 100644
--- a/src/pytorch_lightning/utilities/data.py
+++ b/src/pytorch_lightning/utilities/data.py
@@ -17,12 +17,19 @@ import os
 from contextlib import contextmanager
 from dataclasses import fields
 from functools import partial
-from itertools import chain
-from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Tuple, Type, Union
 
 import torch
 from torch import Tensor
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler, SequentialSampler
+from torch.utils.data import (
+    BatchSampler,
+    DataLoader,
+    Dataset,
+    IterableDataset,
+    RandomSampler,
+    Sampler,
+    SequentialSampler,
+)
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
@@ -179,10 +186,10 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
 def _update_dataloader(
     dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
 ) -> DataLoader:
-    dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
+    dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode=mode)
     dl_cls = type(dataloader)
     try:
-        dataloader = dl_cls(**dl_kwargs)
+        dataloader = dl_cls(*dl_args, **dl_kwargs)
     except TypeError as e:
         # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
         # `__init__` arguments map to one `DataLoader.__init__` argument
@@ -198,38 +205,62 @@ def _update_dataloader(
             f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
             f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
             f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
+            " This argument was automatically passed to your DataLoader by PyTorch Lightning."
         )
         raise MisconfigurationException(message) from e
     return dataloader
 
 
-def _get_dataloader_init_kwargs(
+def _get_dataloader_init_args_and_kwargs(
     dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
-) -> Dict[str, Any]:
+) -> Tuple[Tuple[Any], Dict[str, Any]]:
     if not isinstance(dataloader, DataLoader):
         raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")
 
-    # get the dataloader instance attributes
-    attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
-    # not part of `vars`
-    attrs["multiprocessing_context"] = dataloader.multiprocessing_context
+    was_wrapped = hasattr(dataloader, "__pl_dl_args")
+    if was_wrapped:
+        dl_args = dataloader.__pl_dl_args
+        dl_kwargs = dataloader.__pl_dl_kwargs
+        arg_names = dataloader.__pl_dl_arg_names
+        original_dataset = dataloader.__dataset  # we have this saved from _wrap_dataloader_init
+    else:
+        # get the dataloader instance attributes
+        attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
+        # We cannot be 100% sure the class sets the `dataset` argument. Let's set it to None to be safe
+        # and hope we can get it from the instance attributes
+        original_dataset = None
+        # not part of `vars`
+        attrs["multiprocessing_context"] = dataloader.multiprocessing_context
+        arg_names = ()
 
     # get the dataloader instance `__init__` parameters
     params = dict(inspect.signature(dataloader.__init__).parameters)
     has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
     if has_variadic_kwargs:
         # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
-        params.update(inspect.signature(DataLoader.__init__).parameters)
-        del params["self"]
-
-    # keep only the params whose default is different to the current attr value
-    non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
-    # add `dataset` as it might have been replaced with `*args`
-    non_defaults.add("dataset")
+        if was_wrapped:
+            # if the dataloader was wrapped in a hook, only take arguments with default values
+            # and assume user passes their kwargs correctly
+            params.update(
+                {k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty}
+            )
+        else:
+            params.update(inspect.signature(DataLoader.__init__).parameters)
+            params.pop("self", None)
 
-    # kwargs to re-construct the dataloader
-    dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
-    if isinstance(dl_kwargs["dataset"], IterableDataset):
+    if not was_wrapped:
+        # keep only the params whose default is different to the current attr value
+        non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
+
+        # add `dataset` as it might have been replaced with `*args`
+        non_defaults.add("dataset")
+        # kwargs to re-construct the dataloader
+        dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
+        dl_args = ()
+
+    dataset = dl_kwargs.get("dataset", original_dataset)
+    if isinstance(dataset, IterableDataset):
         dl_kwargs["batch_sampler"] = None
         dl_kwargs["sampler"] = None
     else:
@@ -238,40 +269,43 @@ def _get_dataloader_init_kwargs(
     required_args = {
         p.name
         for p in params.values()
-        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) and p.default is p.empty and p.name not in dl_kwargs
+        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
+        and p.default is p.empty
+        and p.name not in dl_kwargs
+        and p.name not in arg_names
     }
     # the dataloader has required args which we could not extract from the existing attributes
     if required_args:
         required_args = sorted(required_args)
         dataloader_cls_name = dataloader.__class__.__name__
+        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args)
         raise MisconfigurationException(
-            f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
+            f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
             "This would fail as some of the `__init__` arguments are not available as instance attributes. "
-            f"The missing attributes are {required_args}. "
-            f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
-            "manually add the `DistributedSampler` as: "
-            f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
+            f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
+            "`*_dataloader` hook of your module, we will do this for you."
+            f" Otherwise, define {missing_args_message} inside your `__init__`."
         )
 
     if not has_variadic_kwargs:
         # the dataloader signature does not allow keyword arguments that need to be passed
-        missing_kwargs = dl_kwargs.keys() - params.keys()
+        missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
         if missing_kwargs:
             missing_kwargs = sorted(missing_kwargs)
             dataloader_cls_name = dataloader.__class__.__name__
             raise MisconfigurationException(
-                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
+                f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
                 "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
-                f"The missing arguments are {missing_kwargs}. "
-                f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
-                "manually add the `DistributedSampler` as: "
-                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
+                f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
+                "add the `__init__` arguments or allow passing `**kwargs`"
             )
 
     if _FaultTolerantMode.detect_current_mode().is_automatic:
-        dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs)
+        dl_args, dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(
+            was_wrapped, arg_names, dl_args, dl_kwargs
+        )
 
-    return dl_kwargs
+    return dl_args, dl_kwargs
 
 
 def _dataloader_init_kwargs_resolve_sampler(
@@ -321,30 +355,35 @@ def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
         dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)
 
 
-def _wrap_init(init: Callable) -> Callable:
-    """Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of
-    :class:`~torch.utils.data.DataLoader`."""
+def _wrap_dataloader_init(init: Callable) -> Callable:
+    """Wraps the ``__init__`` method of :class:`~torch.utils.data.DataLoader` in order to enable re-instantiation
+    of custom subclasses."""
 
     @functools.wraps(init)
     def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
         # We need to inspect `init`, as inspecting `obj.__init__`
         # can lead to inspecting the wrong function with multiple inheritance
         params = inspect.signature(init).parameters
-
-        param_names = [
+        param_names = tuple(
             param.name
             for param in params.values()
             if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
-        ]
+        )
+        param_names = param_names[: len(args)]
+
+        if not hasattr(obj, "__pl_dl_args"):
+            obj.__pl_dl_args = args
+            obj.__pl_dl_kwargs = kwargs
+            obj.__pl_dl_arg_names = param_names
+
+        # We want to use the latest possible value for the `dataset` argument (i.e. ideally what gets passed to DataLoader)
+        # so that we can be sure that it will not get changed anymore.
+        # That is why we are setting this in every `__init__`
+        if "dataset" in param_names:
+            setattr(obj, "__dataset", args[param_names.index("dataset")])
+        elif "dataset" in kwargs:
+            setattr(obj, "__dataset", kwargs["dataset"])
 
-        cls = type(obj)
-        for arg_name, arg_value in chain(zip(param_names, args), kwargs.items()):
-            if hasattr(cls, arg_name) and getattr(cls, arg_name).fset is None:
-                # the class defines a read-only (no setter) property of this name. it's likely that the implementation
-                # will set `self._arg_name = arg_value` in `__init__` which is the attribute returned by the `arg_name`
-                # property so we are fine skipping in that case
-                continue
-            setattr(obj, arg_name, arg_value)
         init(obj, *args, **kwargs)
 
     return wrapper
@@ -368,33 +407,63 @@ def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
 def _replace_dataloader_init_method() -> Generator[None, None, None]:
     """This context manager is used to add support for re-instantiation of custom (subclasses) of
     :class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
-    subclasses = _get_all_subclasses(DataLoader)
-    for subclass in subclasses:
-        subclass._old_init = subclass.__init__
-        subclass.__init__ = _wrap_init(subclass.__init__)
+    classes = _get_all_subclasses(DataLoader) | {DataLoader}
+    wrapped = set()
+    for cls in classes:
+        if cls.__init__ not in wrapped:
+            cls._old_init = cls.__init__
+            cls.__init__ = _wrap_dataloader_init(cls.__init__)
+            wrapped.add(cls.__init__)
     yield
-    for subclass in subclasses:
-        subclass.__init__ = subclass._old_init
-        del subclass._old_init
+    for cls in classes:
+        if hasattr(cls, "_old_init"):
+            cls.__init__ = cls._old_init
+            del cls._old_init
 
 
-def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) -> Dict:
-    dataset = dl_kwargs["dataset"]
+def _wrap_with_capture_dataset(dataset: Dataset) -> Dataset:
     if isinstance(dataset, IterableDataset):
         # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
-        dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dataset)
-    elif get_len(dataset) != float("inf"):
-        dl_kwargs["dataset"] = CaptureMapDataset(dataset=dataset)
+        return CaptureIterableDataset(dataset=dataset)
+    if get_len(dataset) != float("inf"):
+        return CaptureMapDataset(dataset=dataset)
+    raise RuntimeError("This shouldn't happen, please open an issue on Lightning Github repository.")
+
+
+def _apply_fault_tolerant_automatic_capture_dataset_wrapper(
+    was_wrapped: bool, arg_names: Tuple[str, ...], dl_args: Tuple[Any, ...], dl_kwargs: Dict[str, Any]
+) -> Tuple[Tuple[str, ...], Dict[str, Any]]:
+    if "dataset" in dl_kwargs:
+        dl_kwargs["dataset"] = _wrap_with_capture_dataset(dl_kwargs["dataset"])
+    elif "dataset" in arg_names:
+        dataset_idx = arg_names.index("dataset")
+        dataset = _wrap_with_capture_dataset(dl_args[dataset_idx])
+        dl_args = dl_args[:dataset_idx] + (dataset,) + dl_args[dataset_idx + 1 :]
     else:
-        raise MisconfigurationException("This shouldn't happen, please open an issue on Lightning Github repository.")
-    return dl_kwargs
+        if was_wrapped:
+            avoid_message = (
+                " To avoid this, either pass `DataLoader(dataset=your_dataset)` or the positional dataset argument"
+                " `DataLoader(your_dataset, ...)`."
+            )
+        else:
+            avoid_message = " To avoid this, define `self.dataset = dataset` inside your DataLoader's `__init__`."
+
+        raise MisconfigurationException(
+            "You enabled automatic Fault Tolerant mode, but we were not able to replace your dataset"
+ avoid_message + ) + + return dl_args, dl_kwargs def _is_dataloader_shuffled(dataloader: object) -> bool: - if hasattr(dataloader, "shuffle"): + if hasattr(dataloader, "__pl_dl_kwargs"): # this attribute is not part of PyTorch's DataLoader, but could have been set by # our `_replace_dataloader_init_method` context manager - return dataloader.shuffle + if "shuffle" in dataloader.__pl_dl_kwargs: + return dataloader.__pl_dl_kwargs["shuffle"] + if "shuffle" in dataloader.__pl_dl_arg_names: + return dataloader.__pl_dl_args[dataloader.__pl_dl_arg_names.index("shuffle")] if isinstance(dataloader.dataset, IterableDataset): # shuffling is useless with iterable datasets return False diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_pytorch/lite/test_lite.py index 40d8f79ae7..f38ec9c294 100644 --- a/tests/tests_pytorch/lite/test_lite.py +++ b/tests/tests_pytorch/lite/test_lite.py @@ -201,7 +201,7 @@ def test_setup_dataloaders_raises_for_unknown_custom_args(): with pytest.raises( MisconfigurationException, match=( - r"Trying to inject `DistributedSampler` into the `CustomDataLoader` instance.*" + r"Trying to inject custom `Sampler` into the `CustomDataLoader` instance.*" r"The missing attributes are \['new_arg'\]" ), ): diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index 3bf709c8e2..7b1e596d50 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -9,7 +9,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.data import ( - _get_dataloader_init_kwargs, + _get_dataloader_init_args_and_kwargs, _replace_dataloader_init_method, _update_dataloader, extract_batch_size, @@ -134,23 +134,28 @@ def test_has_len_all_rank(): def test_update_dataloader_typerror_custom_exception(): - class BadImpl(DataLoader): + class BadStandaloneGoodHookImpl(DataLoader): def __init__(self, foo, *args, **kwargs): self.foo = foo # positional conflict with `dataset` super().__init__(foo, *args, **kwargs) - dataloader = BadImpl([1, 2, 3]) + dataloader = BadStandaloneGoodHookImpl([1, 2, 3]) with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"): _update_dataloader(dataloader, dataloader.sampler) - class BadImpl2(DataLoader): + with _replace_dataloader_init_method(): + dataloader = BadStandaloneGoodHookImpl([1, 2, 3]) + new_dataloader = _update_dataloader(dataloader, dataloader.sampler) + assert isinstance(new_dataloader, BadStandaloneGoodHookImpl) + + class BadImpl(DataLoader): def __init__(self, randomize, *args, **kwargs): self.randomize = randomize # keyword conflict with `shuffle` super().__init__(*args, shuffle=randomize, **kwargs) - dataloader = BadImpl2(False, []) + dataloader = BadImpl(False, []) with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`shuffle`"): _update_dataloader(dataloader, dataloader.sampler) @@ -165,69 +170,165 @@ def test_update_dataloader_typerror_custom_exception(): assert isinstance(new_dataloader, GoodImpl) -def test_replace_dataloader_init_method(): - """Test that context manager intercepts arguments passed to custom subclasses of torch.utils.DataLoader and - sets them as attributes.""" +class DataLoaderSubclass1(DataLoader): + def __init__(self, attribute1, *args, **kwargs): + self.at1 = attribute1 + super().__init__(*args, 
+        super().__init__(*args, **kwargs)
 
-    class DataLoaderSubclass1(DataLoader):
-        def __init__(self, attribute1, *args, **kwargs):
-            # intentionally not setting this attribute, calling super with different args
-            # self.attribute1 = attribute1
-            super().__init__(*args, **kwargs)
 
-    class DataLoaderSubclass2(DataLoaderSubclass1):
-        def __init__(self, attribute2, *args, **kwargs):
-            # intentionally not setting this attribute, calling super with different args
-            # self.attribute2 = attribute2
-            super().__init__(attribute2 + "-2", *args, **kwargs)
+class DataLoaderSubclass2(DataLoaderSubclass1):
+    def __init__(self, attribute2, *args, **kwargs):
+        self.at2 = attribute2
+        super().__init__(attribute2 + "-2", *args, **kwargs)
+
+
+class MyBaseDataLoader(DataLoader):
+    pass
+
+
+class MyDataLoader(MyBaseDataLoader):
+    def __init__(self, data: torch.Tensor, *args, **kwargs):
+        self.data = data
+        super().__init__(range(data.size(0)), *args, **kwargs)
+
+
+test3_data = torch.randn((10, 20))
+
+
+class PoptorchDataLoader(DataLoader):
+    def __init__(self, options, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._options = options
+
+    @property
+    def options(self):
+        return self._options
+
+
+class IncompleteDataLoader(DataLoader):
+    def __init__(self, dataset, batch_size, **kwargs):
+        batch_size = max(batch_size - 5, 0)
+        super().__init__(dataset, batch_size=batch_size, **kwargs)
+
+
+class WeirdDataLoader1(DataLoader):
+    def __init__(self, arg1, arg2, **kwargs):
+        self.arg1 = arg1
+        super().__init__(arg2, **kwargs)
+
+
+class WeirdDataLoader2(DataLoader):
+    def __init__(self, data_part1, data_part2, **kwargs):
+        data = list(data_part1) + list(data_part2)
+        super().__init__(data, **kwargs)
+
+
+class NoneDataLoader(DataLoader):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+class ChangingDataLoader(DataLoader):
+    def __init__(self, dataset, **kwargs):
+        super().__init__(list(dataset) + list(range(5, 10)), **kwargs)
+
+
+@pytest.mark.parametrize(
+    ["cls", "args", "kwargs", "arg_names", "dataset", "checked_values"],
+    [
+        pytest.param(
+            DataLoaderSubclass1,
+            ("attribute1",),
+            dict(dataset=range(4), batch_size=2),
+            ("attribute1",),
+            range(4),
+            dict(batch_size=2, at1="attribute1"),
+            id="test1",
+        ),
+        pytest.param(
+            DataLoaderSubclass2,
+            ("attribute2",),
+            dict(dataset=range(4), batch_size=2),
+            ("attribute2",),
+            range(4),
+            dict(batch_size=2, at1="attribute2-2", at2="attribute2"),
+            id="test2",
+        ),
+        pytest.param(
+            MyDataLoader,
+            (test3_data,),
+            dict(batch_size=2),
+            ("data",),
+            range(10),
+            dict(batch_size=2, data=test3_data),
+            id="test3",
+        ),
+        pytest.param(PoptorchDataLoader, (123, [1]), dict(), ("options",), [1], dict(options=123), id="test4"),
+        pytest.param(
+            IncompleteDataLoader,
+            (range(10),),
+            dict(batch_size=10),
+            ("dataset",),
+            range(10),
+            dict(batch_size=5),
+            id="test5",
+        ),
+        pytest.param(
+            WeirdDataLoader1,
+            (10, range(10)),
+            dict(batch_size=10),
+            ("arg1", "arg2"),
+            range(10),
+            dict(arg1=10, batch_size=10),
+            id="test6",
+        ),
+        pytest.param(
+            WeirdDataLoader2,
+            (range(10), range(10, 20)),
+            dict(batch_size=10),
+            ("data_part1", "data_part2"),
+            list(range(20)),
+            dict(batch_size=10),
+            id="test7",
+        ),
+        pytest.param(NoneDataLoader, (None,), dict(), (), None, dict(), id="test8"),
+        pytest.param(ChangingDataLoader, (range(5),), dict(), ("dataset",), list(range(10)), dict(), id="test9"),
+    ],
+)
+def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, checked_values):
     with _replace_dataloader_init_method():
-        dataloader = DataLoaderSubclass1("attribute1", dataset=range(4), batch_size=2)
+        dataloader = cls(*args, **kwargs)
 
-    assert dataloader.attribute1 == "attribute1"
+    assert dataloader.__pl_dl_args == args
+    assert dataloader.__pl_dl_kwargs == kwargs
+    assert dataloader.__pl_dl_arg_names == arg_names
+    assert dataloader.__dataset == dataset
 
-    with _replace_dataloader_init_method():
-        dataloader = DataLoaderSubclass2("attribute2", dataset=range(4), batch_size=2)
+    assert dataloader.dataset == dataset
 
-    assert dataloader.attribute1 == "attribute2-2"
-    assert dataloader.attribute2 == "attribute2"
+    for key, value in checked_values.items():
+        dataloader_value = getattr(dataloader, key)
+        if isinstance(dataloader_value, torch.Tensor):
+            assert dataloader_value is value
+        else:
+            assert getattr(dataloader, key) == value
 
-    # Failing test case from issue 12564
-    class MyBaseDataLoader(DataLoader):
-        pass
+    dataloader = _update_dataloader(dataloader, dataloader.sampler)
 
-    class MyDataLoader(MyBaseDataLoader):
-        def __init__(self, data: torch.Tensor, *args, **kwargs):
-            self.data = data
-            super().__init__(range(data.size(0)), *args, **kwargs)
+    assert isinstance(dataloader, cls)
+    assert not hasattr(dataloader, "__pl_dl_kwargs")
+    assert not hasattr(dataloader, "__pl_dl_arg_names")
+    assert not hasattr(dataloader, "__pl_dl_args")
+    assert not hasattr(dataloader, "__dataset")
 
-    data = torch.randn((10, 20))
+    assert dataloader.dataset == dataset
 
-    with _replace_dataloader_init_method():
-        dataloader = MyDataLoader(data, batch_size=2)
-
-    assert dataloader.data is data
-    assert dataloader.dataset == range(10)
-
-    # `poptorch.DataLoader` uses this pattern, simulate it
-    class PoptorchDataLoader(DataLoader):
-        def __init__(self, options, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self._options = options
-
-        @property
-        def options(self):
-            return self._options
-
-    # †his read-only property pattern is fine
-    dataloader = PoptorchDataLoader(123, [1])
-    assert dataloader.options == 123
-
-    # still works with the init replacement
-    with _replace_dataloader_init_method():
-        dataloader = PoptorchDataLoader(123, [1])
-
-    assert dataloader.options == 123
+    for key, value in checked_values.items():
+        dataloader_value = getattr(dataloader, key)
+        if isinstance(dataloader_value, torch.Tensor):
+            assert dataloader_value is value
+        else:
+            assert getattr(dataloader, key) == value
 
 
 @pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
@@ -235,7 +336,7 @@ def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
     """Test that DataLoader kwargs are not replaced when using Iterable Dataset."""
     dataset = RandomIterableDataset(7, 100)
    dataloader = DataLoader(dataset, batch_size=32)
-    dl_kwargs = _get_dataloader_init_kwargs(dataloader, dataloader.sampler, mode=mode)
+    _, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler, mode=mode)
     assert dl_kwargs["sampler"] is None
     assert dl_kwargs["batch_sampler"] is None
    assert dl_kwargs["batch_size"] is dataloader.batch_size
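
A minimal sketch of the workflow this patch enables, modeled on the new `MyDataLoader` test case above. It is not part of the patch itself: `CustomLoader` is a hypothetical subclass, and `_replace_dataloader_init_method` / `_update_dataloader` are the private helpers from `pytorch_lightning.utilities.data` that the Trainer applies around the `*_dataloader` hooks. Because the wrapped `__init__` now records the passed arguments (and the latest `dataset` value) instead of setting them as instance attributes, re-instantiation preserves the subclass and its constructor arguments.

import torch
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _replace_dataloader_init_method, _update_dataloader


class CustomLoader(DataLoader):
    # hypothetical subclass: the extra `data` argument is transformed before reaching `DataLoader.__init__`
    def __init__(self, data: torch.Tensor, *args, **kwargs):
        self.data = data
        super().__init__(range(data.size(0)), *args, **kwargs)


# The Trainer enters this context manager around the `*_dataloader` hooks; with this patch the patched
# `__init__` records the original args/kwargs instead of blindly assigning them as attributes.
with _replace_dataloader_init_method():
    loader = CustomLoader(torch.randn(10, 20), batch_size=2)

# Re-instantiation (e.g. when Lightning injects a sampler) keeps the subclass and its arguments intact.
new_loader = _update_dataloader(loader, loader.sampler)
assert isinstance(new_loader, CustomLoader)
assert new_loader.batch_size == 2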