Improve support for custom `DataLoader`s when instantiated in `*_dataloader` hook (#12981)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>

parent bbc51d16a3
commit 2e9cd72add
@@ -106,9 +106,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Raise an error if there are insufficient training batches when using a float value of `limit_train_batches` ([#12885](https://github.com/PyTorchLightning/pytorch-lightning/pull/12885))

+- `DataLoader` instantiated inside a `*_dataloader` hook will not set the passed arguments as attributes anymore ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))
+
 - The `WandbLogger` will now use the run name in the logs folder if it is provided, and otherwise the project name ([#12604](https://github.com/PyTorchLightning/pytorch-lightning/pull/12604))

 ### Deprecated
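The changelog entry above changes what happens to constructor arguments when a loader is created inside a `*_dataloader` hook. A rough sketch of the new expectation follows; `MyDataLoader` and `LitModel` are hypothetical names used only for illustration. The arguments are still recorded so Lightning can re-create the loader later (for example, to inject a distributed sampler), but they are no longer copied onto the instance as attributes, so a subclass should set any attribute it needs itself.

```python
from torch.utils.data import DataLoader

import pytorch_lightning as pl


class MyDataLoader(DataLoader):
    def __init__(self, my_arg, *args, **kwargs):
        # set the attribute explicitly: the hook wrapper no longer does `self.my_arg = my_arg` for you
        self.my_arg = my_arg
        super().__init__(*args, **kwargs)


class LitModel(pl.LightningModule):
    def train_dataloader(self):
        # instantiating here lets Lightning record the call so the loader can be re-instantiated later
        return MyDataLoader("my-value", dataset=range(8), batch_size=2)
```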
@@ -229,6 +231,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

+- Improved support for custom `DataLoader`s when instantiated in `*_dataloader` hook ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))
+
 - Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))
@@ -29,7 +29,7 @@ from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-from pytorch_lightning.utilities.data import _get_dataloader_init_kwargs
+from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs
 from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -228,9 +228,9 @@ class IPUStrategy(ParallelStrategy):
             # the user is returning the `poptorch.DataLoader` directly, don't change anything.
             return dataloader

-        dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler)
+        dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler)
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
-        dataloader = poptorch.DataLoader(opts, **dl_kwargs)
+        dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs)
         return dataloader

     def _handle_gradient_accumulation_steps(self) -> None:
@@ -17,12 +17,19 @@ import os
 from contextlib import contextmanager
 from dataclasses import fields
 from functools import partial
 from itertools import chain
-from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Tuple, Type, Union

 import torch
 from torch import Tensor
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler, SequentialSampler
+from torch.utils.data import (
+    BatchSampler,
+    DataLoader,
+    Dataset,
+    IterableDataset,
+    RandomSampler,
+    Sampler,
+    SequentialSampler,
+)

 import pytorch_lightning as pl
 from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
@@ -179,10 +186,10 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
 def _update_dataloader(
     dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
 ) -> DataLoader:
-    dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
+    dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode=mode)
     dl_cls = type(dataloader)
     try:
-        dataloader = dl_cls(**dl_kwargs)
+        dataloader = dl_cls(*dl_args, **dl_kwargs)
     except TypeError as e:
         # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
         # `__init__` arguments map to one `DataLoader.__init__` argument
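For reference, a minimal sketch of how the renamed helpers are exercised, mirroring the tests added further down in this commit. `CustomLoader` is a hypothetical subclass; the `_replace_dataloader_init_method` context manager stands in for what Lightning does around the `*_dataloader` hooks.

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _replace_dataloader_init_method, _update_dataloader


class CustomLoader(DataLoader):
    def __init__(self, extra, dataset, **kwargs):
        self.extra = extra
        super().__init__(dataset, **kwargs)


with _replace_dataloader_init_method():
    # the patched `__init__` records the original positional and keyword arguments
    loader = CustomLoader("spam", range(10), batch_size=2)

# re-instantiation now replays the recorded positional args alongside the kwargs
new_loader = _update_dataloader(loader, loader.sampler)
assert isinstance(new_loader, CustomLoader)
```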
@@ -198,38 +205,62 @@ def _update_dataloader(
             f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
             f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
             f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
+            " This argument was automatically passed to your DataLoader by PyTorch Lightning."
         )
         raise MisconfigurationException(message) from e
     return dataloader


-def _get_dataloader_init_kwargs(
+def _get_dataloader_init_args_and_kwargs(
     dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
-) -> Dict[str, Any]:
+) -> Tuple[Tuple[Any], Dict[str, Any]]:
     if not isinstance(dataloader, DataLoader):
         raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

-    # get the dataloader instance attributes
-    attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
-    # not part of `vars`
-    attrs["multiprocessing_context"] = dataloader.multiprocessing_context
+    was_wrapped = hasattr(dataloader, "__pl_dl_args")
+    if was_wrapped:
+        dl_args = dataloader.__pl_dl_args
+        dl_kwargs = dataloader.__pl_dl_kwargs
+        arg_names = dataloader.__pl_dl_arg_names
+        original_dataset = dataloader.__dataset  # we have this saved from _wrap_init
+    else:
+        # get the dataloader instance attributes
+        attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
+        # We cannot be 100% sure the class sets dataset argument. Let's set it to None to be safe
+        # and hope we can get it from the instance attributes
+        original_dataset = None
+        # not part of `vars`
+        attrs["multiprocessing_context"] = dataloader.multiprocessing_context
+        arg_names = ()

     # get the dataloader instance `__init__` parameters
     params = dict(inspect.signature(dataloader.__init__).parameters)
     has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
     if has_variadic_kwargs:
         # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
-        params.update(inspect.signature(DataLoader.__init__).parameters)
-        del params["self"]
+        if was_wrapped:
+            # if the dataloader was wrapped in a hook, only take arguments with default values
+            # and assume user passes their kwargs correctly
+            params.update(
+                {k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty}
+            )
+        else:
+            params.update(inspect.signature(DataLoader.__init__).parameters)
+        params.pop("self", None)

-    # keep only the params whose default is different to the current attr value
-    non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
-    # add `dataset` as it might have been replaced with `*args`
-    non_defaults.add("dataset")
-    # kwargs to re-construct the dataloader
-    dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
-    if isinstance(dl_kwargs["dataset"], IterableDataset):
+    if not was_wrapped:
+        # keep only the params whose default is different to the current attr value
+        non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
+        # add `dataset` as it might have been replaced with `*args`
+        non_defaults.add("dataset")
+        # kwargs to re-construct the dataloader
+        dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
+        dl_args = ()
+
+    dataset = dl_kwargs.get("dataset", original_dataset)
+    if isinstance(dataset, IterableDataset):
         dl_kwargs["batch_sampler"] = None
         dl_kwargs["sampler"] = None
     else:
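To make the two branches above concrete, here is a hedged sketch of what the helper returns for a plain loader versus one created while the `__init__` patch was active. Additional keys (sampler, batch sampler) are filled in by code outside this hunk, so only the positional/keyword split is checked here.

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import (
    _get_dataloader_init_args_and_kwargs,
    _replace_dataloader_init_method,
)

# not wrapped: everything is reconstructed from instance attributes and returned as kwargs
plain = DataLoader(range(10), batch_size=2)
args, kwargs = _get_dataloader_init_args_and_kwargs(plain, plain.sampler)
assert args == ()
assert kwargs["batch_size"] == 2

# wrapped: the recorded call is replayed, so the positional `dataset` stays positional
with _replace_dataloader_init_method():
    wrapped = DataLoader(range(10), batch_size=2)
args, kwargs = _get_dataloader_init_args_and_kwargs(wrapped, wrapped.sampler)
assert args == (range(10),)
assert kwargs["batch_size"] == 2
```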
@@ -238,40 +269,43 @@
     required_args = {
         p.name
         for p in params.values()
-        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) and p.default is p.empty and p.name not in dl_kwargs
+        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
+        and p.default is p.empty
+        and p.name not in dl_kwargs
+        and p.name not in arg_names
     }
     # the dataloader has required args which we could not extract from the existing attributes
     if required_args:
         required_args = sorted(required_args)
         dataloader_cls_name = dataloader.__class__.__name__
+        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args)
         raise MisconfigurationException(
-            f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
+            f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
             "This would fail as some of the `__init__` arguments are not available as instance attributes. "
-            f"The missing attributes are {required_args}. "
-            f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
-            "manually add the `DistributedSampler` as: "
-            f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
+            f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
+            "`*_dataloader` hook of your module, we will do this for you."
+            f" Otherwise, define {missing_args_message} inside your `__init__`."
         )

     if not has_variadic_kwargs:
         # the dataloader signature does not allow keyword arguments that need to be passed
-        missing_kwargs = dl_kwargs.keys() - params.keys()
+        missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
         if missing_kwargs:
             missing_kwargs = sorted(missing_kwargs)
             dataloader_cls_name = dataloader.__class__.__name__
             raise MisconfigurationException(
-                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
+                f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
                 "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
-                f"The missing arguments are {missing_kwargs}. "
-                f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
-                "manually add the `DistributedSampler` as: "
-                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
+                f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
+                "add the `__init__` arguments or allow passing `**kwargs`"
             )

     if _FaultTolerantMode.detect_current_mode().is_automatic:
-        dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs)
+        dl_args, dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(
+            was_wrapped, arg_names, dl_args, dl_kwargs
+        )

-    return dl_kwargs
+    return dl_args, dl_kwargs


 def _dataloader_init_kwargs_resolve_sampler(
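The reworded error above is raised when a standalone loader cannot be rebuilt because one of its `__init__` arguments was never stored on the instance. A hedged sketch of that failure mode, with a hypothetical `ArgsLoader` shaped like the `CustomDataLoader` used in the updated test below:

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _update_dataloader
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class ArgsLoader(DataLoader):
    def __init__(self, new_arg, *args, **kwargs):
        # `new_arg` is consumed but never stored, so it cannot be recovered later
        super().__init__(*args, **kwargs)


loader = ArgsLoader("value", range(10), batch_size=2)
try:
    _update_dataloader(loader, loader.sampler)
except MisconfigurationException as err:
    # "Trying to inject custom `Sampler` into the `ArgsLoader` instance. ... The missing attributes are ['new_arg']. ..."
    print(err)
```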
@@ -321,30 +355,35 @@ def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
     dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)


-def _wrap_init(init: Callable) -> Callable:
-    """Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of
-    :class:`~torch.utils.data.DataLoader`."""
+def _wrap_dataloader_init(init: Callable) -> Callable:
+    """Wraps the ``__init__`` method of :class:`~torch.utils.data.DataLoader` in order to enable re-instantiation
+    of custom subclasses."""

     @functools.wraps(init)
     def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
         # We need to inspect `init`, as inspecting `obj.__init__`
         # can lead to inspecting the wrong function with multiple inheritance
         params = inspect.signature(init).parameters

-        param_names = [
+        param_names = tuple(
             param.name
             for param in params.values()
             if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
-        ]
+        )
+        param_names = param_names[: len(args)]

+        if not hasattr(obj, "__pl_dl_args"):
+            obj.__pl_dl_args = args
+            obj.__pl_dl_kwargs = kwargs
+            obj.__pl_dl_arg_names = param_names

+        # We want to use the latest possible value for dataset argument (i.e. ideally what gets passed to DataLoader)
+        # so that we can be sure, that it will not get changed anymore.
+        # That is why we are setting this in every `__init__`
+        if "dataset" in param_names:
+            setattr(obj, "__dataset", args[param_names.index("dataset")])
+        elif "dataset" in kwargs:
+            setattr(obj, "__dataset", kwargs["dataset"])

-        cls = type(obj)
-        for arg_name, arg_value in chain(zip(param_names, args), kwargs.items()):
-            if hasattr(cls, arg_name) and getattr(cls, arg_name).fset is None:
-                # the class defines a read-only (no setter) property of this name. it's likely that the implementation
-                # will set `self._arg_name = arg_value` in `__init__` which is the attribute returned by the `arg_name`
-                # property so we are fine skipping in that case
-                continue
-            setattr(obj, arg_name, arg_value)
         init(obj, *args, **kwargs)

     return wrapper
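A short sketch of what the wrapped `__init__` now records when the patch is active; the attribute names match the assertions in the new parametrized test further down.

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _replace_dataloader_init_method

with _replace_dataloader_init_method():
    loader = DataLoader(range(10), batch_size=2)

# the call is recorded instead of copying the arguments onto the instance
assert loader.__pl_dl_args == (range(10),)
assert loader.__pl_dl_kwargs == {"batch_size": 2}
assert loader.__pl_dl_arg_names == ("dataset",)
assert loader.__dataset == range(10)
```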
@@ -368,33 +407,63 @@ def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
 def _replace_dataloader_init_method() -> Generator[None, None, None]:
     """This context manager is used to add support for re-instantiation of custom (subclasses) of
     :class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
-    subclasses = _get_all_subclasses(DataLoader)
-    for subclass in subclasses:
-        subclass._old_init = subclass.__init__
-        subclass.__init__ = _wrap_init(subclass.__init__)
+    classes = _get_all_subclasses(DataLoader) | {DataLoader}
+    wrapped = set()
+    for cls in classes:
+        if cls.__init__ not in wrapped:
+            cls._old_init = cls.__init__
+            cls.__init__ = _wrap_dataloader_init(cls.__init__)
+            wrapped.add(cls.__init__)
     yield
-    for subclass in subclasses:
-        subclass.__init__ = subclass._old_init
-        del subclass._old_init
+    for cls in classes:
+        if hasattr(cls, "_old_init"):
+            cls.__init__ = cls._old_init
+            del cls._old_init


-def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) -> Dict:
-    dataset = dl_kwargs["dataset"]
+def _wrap_with_capture_dataset(dataset: Dataset) -> Dataset:
     if isinstance(dataset, IterableDataset):
         # wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
-        dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dataset)
-    elif get_len(dataset) != float("inf"):
-        dl_kwargs["dataset"] = CaptureMapDataset(dataset=dataset)
-    else:
-        raise MisconfigurationException("This shouldn't happen, please open an issue on Lightning Github repository.")
-    return dl_kwargs
+        return CaptureIterableDataset(dataset=dataset)
+    if get_len(dataset) != float("inf"):
+        return CaptureMapDataset(dataset=dataset)
+    raise RuntimeError("This shouldn't happen, please open an issue on Lightning Github repository.")
+
+
+def _apply_fault_tolerant_automatic_capture_dataset_wrapper(
+    was_wrapped: bool, arg_names: Tuple[str, ...], dl_args: Tuple[Any, ...], dl_kwargs: Dict[str, Any]
+) -> Tuple[Tuple[str, ...], Dict[str, Any]]:
+    if "dataset" in dl_kwargs:
+        dl_kwargs["dataset"] = _wrap_with_capture_dataset(dl_kwargs["dataset"])
+    elif "dataset" in arg_names:
+        dataset_idx = arg_names.index("dataset")
+        dataset = _wrap_with_capture_dataset(dl_args[dataset_idx])
+        dl_args = dl_args[:dataset_idx] + (dataset,) + dl_args[dataset_idx + 1 :]
+    else:
+        if was_wrapped:
+            avoid_message = (
+                " To avoid this, either pass `DataLoader(dataset=your_dataset)` or the positional dataset argument"
+                " `DataLoader(your_dataset, ...)`."
+            )
+        else:
+            avoid_message = " To avoid this, define `self.dataset = dataset` inside your DataLoader's `__init__`."
+
+        raise MisconfigurationException(
+            "You enabled automatic Fault Tolerant mode, but we were not able to replace your dataset"
+            " with Fault Tolerant wrapper, because you have a custom DataLoader." + avoid_message
+        )
+
+    return dl_args, dl_kwargs


 def _is_dataloader_shuffled(dataloader: object) -> bool:
-    if hasattr(dataloader, "shuffle"):
+    if hasattr(dataloader, "__pl_dl_kwargs"):
         # this attribute is not part of PyTorch's DataLoader, but could have been set by
         # our `_replace_dataloader_init_method` context manager
-        return dataloader.shuffle
+        if "shuffle" in dataloader.__pl_dl_kwargs:
+            return dataloader.__pl_dl_kwargs["shuffle"]
+        if "shuffle" in dataloader.__pl_dl_arg_names:
+            return dataloader.__pl_dl_args[dataloader.__pl_dl_arg_names.index("shuffle")]
     if isinstance(dataloader.dataset, IterableDataset):
         # shuffling is useless with iterable datasets
         return False
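And a small sketch of how `_is_dataloader_shuffled` now consults the recorded call instead of expecting a `shuffle` attribute on the instance:

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _is_dataloader_shuffled, _replace_dataloader_init_method

with _replace_dataloader_init_method():
    loader = DataLoader(range(10), shuffle=True)

# `shuffle` was passed as a keyword, so it is found in the recorded `__pl_dl_kwargs`
assert _is_dataloader_shuffled(loader)
```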
@@ -201,7 +201,7 @@ def test_setup_dataloaders_raises_for_unknown_custom_args():
     with pytest.raises(
         MisconfigurationException,
         match=(
-            r"Trying to inject `DistributedSampler` into the `CustomDataLoader` instance.*"
+            r"Trying to inject custom `Sampler` into the `CustomDataLoader` instance.*"
             r"The missing attributes are \['new_arg'\]"
         ),
     ):
@@ -9,7 +9,7 @@ from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.data import (
-    _get_dataloader_init_kwargs,
+    _get_dataloader_init_args_and_kwargs,
     _replace_dataloader_init_method,
     _update_dataloader,
     extract_batch_size,
@@ -134,23 +134,28 @@ def test_has_len_all_rank():
 def test_update_dataloader_typerror_custom_exception():
-    class BadImpl(DataLoader):
+    class BadStandaloneGoodHookImpl(DataLoader):
         def __init__(self, foo, *args, **kwargs):
             self.foo = foo
             # positional conflict with `dataset`
             super().__init__(foo, *args, **kwargs)

-    dataloader = BadImpl([1, 2, 3])
+    dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
     with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"):
         _update_dataloader(dataloader, dataloader.sampler)

-    class BadImpl2(DataLoader):
+    with _replace_dataloader_init_method():
+        dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
+        new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
+    assert isinstance(new_dataloader, BadStandaloneGoodHookImpl)
+
+    class BadImpl(DataLoader):
         def __init__(self, randomize, *args, **kwargs):
             self.randomize = randomize
             # keyword conflict with `shuffle`
             super().__init__(*args, shuffle=randomize, **kwargs)

-    dataloader = BadImpl2(False, [])
+    dataloader = BadImpl(False, [])
     with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`shuffle`"):
         _update_dataloader(dataloader, dataloader.sampler)
@@ -165,69 +170,165 @@ def test_update_dataloader_typerror_custom_exception():
     assert isinstance(new_dataloader, GoodImpl)


-def test_replace_dataloader_init_method():
-    """Test that context manager intercepts arguments passed to custom subclasses of torch.utils.DataLoader and
-    sets them as attributes."""
-
-    class DataLoaderSubclass1(DataLoader):
-        def __init__(self, attribute1, *args, **kwargs):
-            # intentionally not setting this attribute, calling super with different args
-            # self.attribute1 = attribute1
-            super().__init__(*args, **kwargs)
-
-    class DataLoaderSubclass2(DataLoaderSubclass1):
-        def __init__(self, attribute2, *args, **kwargs):
-            # intentionally not setting this attribute, calling super with different args
-            # self.attribute2 = attribute2
-            super().__init__(attribute2 + "-2", *args, **kwargs)
+class DataLoaderSubclass1(DataLoader):
+    def __init__(self, attribute1, *args, **kwargs):
+        self.at1 = attribute1
+        super().__init__(*args, **kwargs)
+
+
+class DataLoaderSubclass2(DataLoaderSubclass1):
+    def __init__(self, attribute2, *args, **kwargs):
+        self.at2 = attribute2
+        super().__init__(attribute2 + "-2", *args, **kwargs)
+
+
+class MyBaseDataLoader(DataLoader):
+    pass
+
+
+class MyDataLoader(MyBaseDataLoader):
+    def __init__(self, data: torch.Tensor, *args, **kwargs):
+        self.data = data
+        super().__init__(range(data.size(0)), *args, **kwargs)
+
+
+test3_data = torch.randn((10, 20))
+
+
+class PoptorchDataLoader(DataLoader):
+    def __init__(self, options, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._options = options
+
+    @property
+    def options(self):
+        return self._options
+
+
+class IncompleteDataLoader(DataLoader):
+    def __init__(self, dataset, batch_size, **kwargs):
+        batch_size = max(batch_size - 5, 0)
+        super().__init__(dataset, batch_size=batch_size, **kwargs)
+
+
+class WeirdDataLoader1(DataLoader):
+    def __init__(self, arg1, arg2, **kwargs):
+        self.arg1 = arg1
+        super().__init__(arg2, **kwargs)
+
+
+class WeirdDataLoader2(DataLoader):
+    def __init__(self, data_part1, data_part2, **kwargs):
+        data = list(data_part1) + list(data_part2)
+        super().__init__(data, **kwargs)
+
+
+class NoneDataLoader(DataLoader):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+class ChangingDataLoader(DataLoader):
+    def __init__(self, dataset, **kwargs):
+        super().__init__(list(dataset) + list(range(5, 10)), **kwargs)
+
+
+@pytest.mark.parametrize(
+    ["cls", "args", "kwargs", "arg_names", "dataset", "checked_values"],
+    [
+        pytest.param(
+            DataLoaderSubclass1,
+            ("attribute1",),
+            dict(dataset=range(4), batch_size=2),
+            ("attribute1",),
+            range(4),
+            dict(batch_size=2, at1="attribute1"),
+            id="test1",
+        ),
+        pytest.param(
+            DataLoaderSubclass2,
+            ("attribute2",),
+            dict(dataset=range(4), batch_size=2),
+            ("attribute2",),
+            range(4),
+            dict(batch_size=2, at1="attribute2-2", at2="attribute2"),
+            id="test2",
+        ),
+        pytest.param(
+            MyDataLoader,
+            (test3_data,),
+            dict(batch_size=2),
+            ("data",),
+            range(10),
+            dict(batch_size=2, data=test3_data),
+            id="test3",
+        ),
+        pytest.param(PoptorchDataLoader, (123, [1]), dict(), ("options",), [1], dict(options=123), id="test4"),
+        pytest.param(
+            IncompleteDataLoader,
+            (range(10),),
+            dict(batch_size=10),
+            ("dataset",),
+            range(10),
+            dict(batch_size=5),
+            id="test5",
+        ),
+        pytest.param(
+            WeirdDataLoader1,
+            (10, range(10)),
+            dict(batch_size=10),
+            ("arg1", "arg2"),
+            range(10),
+            dict(arg1=10, batch_size=10),
+            id="test6",
+        ),
+        pytest.param(
+            WeirdDataLoader2,
+            (range(10), range(10, 20)),
+            dict(batch_size=10),
+            ("data_part1", "data_part2"),
+            list(range(20)),
+            dict(batch_size=10),
+            id="test7",
+        ),
+        pytest.param(NoneDataLoader, (None,), dict(), (), None, dict(), id="test8"),
+        pytest.param(ChangingDataLoader, (range(5),), dict(), ("dataset",), list(range(10)), dict(), id="test9"),
+    ],
+)
+def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, checked_values):
     with _replace_dataloader_init_method():
-        dataloader = DataLoaderSubclass1("attribute1", dataset=range(4), batch_size=2)
+        dataloader = cls(*args, **kwargs)

-    assert dataloader.attribute1 == "attribute1"
+    assert dataloader.__pl_dl_args == args
+    assert dataloader.__pl_dl_kwargs == kwargs
+    assert dataloader.__pl_dl_arg_names == arg_names
+    assert dataloader.__dataset == dataset

-    with _replace_dataloader_init_method():
-        dataloader = DataLoaderSubclass2("attribute2", dataset=range(4), batch_size=2)
+    assert dataloader.dataset == dataset

-    assert dataloader.attribute1 == "attribute2-2"
-    assert dataloader.attribute2 == "attribute2"
+    for key, value in checked_values.items():
+        dataloader_value = getattr(dataloader, key)
+        if isinstance(dataloader_value, torch.Tensor):
+            assert dataloader_value is value
+        else:
+            assert getattr(dataloader, key) == value

-    # Failing test case from issue 12564
-    class MyBaseDataLoader(DataLoader):
-        pass
+    dataloader = _update_dataloader(dataloader, dataloader.sampler)

-    class MyDataLoader(MyBaseDataLoader):
-        def __init__(self, data: torch.Tensor, *args, **kwargs):
-            self.data = data
-            super().__init__(range(data.size(0)), *args, **kwargs)
+    assert isinstance(dataloader, cls)
+    assert not hasattr(dataloader, "__pl_dl_kwargs")
+    assert not hasattr(dataloader, "__pl_dl_arg_names")
+    assert not hasattr(dataloader, "__pl_dl_args")
+    assert not hasattr(dataloader, "__dataset")

-    data = torch.randn((10, 20))
+    assert dataloader.dataset == dataset

-    with _replace_dataloader_init_method():
-        dataloader = MyDataLoader(data, batch_size=2)
-
-    assert dataloader.data is data
-    assert dataloader.dataset == range(10)
-
-    # `poptorch.DataLoader` uses this pattern, simulate it
-    class PoptorchDataLoader(DataLoader):
-        def __init__(self, options, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self._options = options
-
-        @property
-        def options(self):
-            return self._options
-
-    # this read-only property pattern is fine
-    dataloader = PoptorchDataLoader(123, [1])
-    assert dataloader.options == 123
-
-    # still works with the init replacement
-    with _replace_dataloader_init_method():
-        dataloader = PoptorchDataLoader(123, [1])
-
-    assert dataloader.options == 123
+    for key, value in checked_values.items():
+        dataloader_value = getattr(dataloader, key)
+        if isinstance(dataloader_value, torch.Tensor):
+            assert dataloader_value is value
+        else:
+            assert getattr(dataloader, key) == value


 @pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
@@ -235,7 +336,7 @@ def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
     """Test that DataLoader kwargs are not replaced when using Iterable Dataset."""
     dataset = RandomIterableDataset(7, 100)
     dataloader = DataLoader(dataset, batch_size=32)
-    dl_kwargs = _get_dataloader_init_kwargs(dataloader, dataloader.sampler, mode=mode)
+    _, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler, mode=mode)
     assert dl_kwargs["sampler"] is None
     assert dl_kwargs["batch_sampler"] is None
     assert dl_kwargs["batch_size"] is dataloader.batch_size