Allowed custom `BatchSampler`s when instantiated in `*_dataloader` hook (#13640)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
commit 95f5f170f5
parent c58d351e01
@@ -348,6 +348,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Improved support for custom `DataLoader`s when instantiated in `*_dataloader` hook ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))

+- Allowed custom `BatchSampler`s when instantiated in `*_dataloader` hook ([#13640](https://github.com/PyTorchLightning/pytorch-lightning/pull/13640))
+
 - Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))
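For orientation, this is the user-facing pattern the change enables: a custom `BatchSampler` subclass may now be instantiated inside a `*_dataloader` hook, and Lightning captures its constructor arguments so the instance can be re-instantiated later (for example, to inject a distributed sampler). A minimal sketch; `MyBatchSampler` and `MyModule` are hypothetical:

```python
from pytorch_lightning import LightningModule
from torch.utils.data import BatchSampler, DataLoader, RandomSampler


class MyBatchSampler(BatchSampler):
    # Keeping `sampler` and `drop_last` exposed lets Lightning re-instantiate
    # this class with modified values (see the error messages in the diff below).
    def __init__(self, sampler, extra_arg, drop_last=False):
        self.extra_arg = extra_arg
        super().__init__(sampler, batch_size=32, drop_last=drop_last)


class MyModule(LightningModule):
    def train_dataloader(self):
        dataset = range(64)
        batch_sampler = MyBatchSampler(RandomSampler(dataset), extra_arg="foo")
        return DataLoader(dataset, batch_sampler=batch_sampler)
```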
@@ -22,7 +22,7 @@ import torch
 import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import BatchSampler, DataLoader, DistributedSampler

 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
@@ -35,7 +35,7 @@ from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, move_data_to_device
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
 from pytorch_lightning.utilities.data import (
     _auto_add_worker_init_fn,
-    _replace_dataloader_init_method,
+    _replace_init_method,
     _update_dataloader,
     has_iterable_dataset,
 )
@@ -403,7 +403,9 @@ class LightningLite(ABC):

     def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
         self._strategy.setup_environment()
-        with self._strategy.model_sharded_context(), _replace_dataloader_init_method():
+        with self._strategy.model_sharded_context(), _replace_init_method(DataLoader, "dataset"), _replace_init_method(
+            BatchSampler
+        ):
             return run_method(*args, **kwargs)

     def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
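Conceptually, the two nested context managers patch `DataLoader.__init__` and `BatchSampler.__init__` (including those of all their subclasses) for the duration of the user's `run` method, so that constructor arguments can be recorded. A stripped-down sketch of the patching idea, independent of the Lightning internals shown further down:

```python
import functools
from contextlib import contextmanager


@contextmanager
def _capture_init_args(base_cls):
    # Sketch: record constructor args on every instance created inside the block.
    original_init = base_cls.__init__

    @functools.wraps(original_init)
    def wrapper(self, *args, **kwargs):
        self._captured = (args, kwargs)
        original_init(self, *args, **kwargs)

    base_cls.__init__ = wrapper
    try:
        yield
    finally:
        base_cls.__init__ = original_init
```

The real `_replace_init_method` below additionally walks every existing subclass and records parameter names and default values, which is what makes faithful re-instantiation possible.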
@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
 from typing import Any, Callable, Collection, List, Optional, Tuple, Union
 from weakref import proxy

-from torch.utils.data import DataLoader, Sampler, SequentialSampler
+from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler

 import pytorch_lightning as pl
@@ -31,7 +31,7 @@ from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic
 from pytorch_lightning.utilities.data import (
     _auto_add_worker_init_fn,
     _is_dataloader_shuffled,
-    _replace_dataloader_init_method,
+    _replace_init_method,
     _update_dataloader,
     has_iterable_dataset,
     has_len_all_ranks,
@@ -424,7 +424,7 @@ class DataConnector:
         """
         source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source")

-        with _replace_dataloader_init_method():
+        with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler):
            # under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as
            # attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning
            dataloader = source.dataloader()
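The observable effect on the object returned by the hook, sketched with a hypothetical `DataLoader` subclass (the `__pl_saved_*` attribute names come from the tests at the end of this diff):

```python
from torch.utils.data import BatchSampler, DataLoader

from pytorch_lightning.utilities.data import _replace_init_method


class MyDataLoader(DataLoader):
    def __init__(self, dataset, extra_arg, **kwargs):
        self.extra_arg = extra_arg
        super().__init__(dataset, **kwargs)


with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler):
    dl = MyDataLoader(range(10), "foo")

# The captured call signature allows `_update_dataloader` to rebuild the instance later.
assert dl.__pl_saved_args == (range(10), "foo")
assert dl.__pl_saved_arg_names == ("dataset", "extra_arg")
assert dl.__dataset == range(10)
```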
@@ -16,15 +16,7 @@ from dataclasses import dataclass, field
 from functools import partial, wraps
 from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union

-from torch.utils.data import (
-    BatchSampler,
-    Dataset,
-    DistributedSampler,
-    get_worker_info,
-    RandomSampler,
-    Sampler,
-    SequentialSampler,
-)
+from torch.utils.data import Dataset, DistributedSampler, get_worker_info, RandomSampler, Sampler, SequentialSampler
 from torch.utils.data.dataloader import (
     _BaseDataLoaderIter,
     _MultiProcessingDataLoaderIter,
@@ -757,10 +749,6 @@ def _validate_map_dataset(dataloader: DataLoader) -> None:
     if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS:
         raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")

-    batch_sampler = getattr(dataloader, "batch_sampler", None)
-    if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
-        raise TypeError("Fault-tolerance supports only a `BatchSampler`.")
-
     if type(sampler) is DistributedSampler and sampler.shuffle:
         raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.")
     elif type(sampler) is RandomSampler:
@@ -14,6 +14,7 @@
 import functools
 import inspect
 import os
+from collections import OrderedDict
 from contextlib import contextmanager
 from dataclasses import fields
 from functools import partial
@@ -220,11 +221,11 @@ def _get_dataloader_init_args_and_kwargs(
     if not isinstance(dataloader, DataLoader):
         raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

-    was_wrapped = hasattr(dataloader, "__pl_dl_args")
+    was_wrapped = hasattr(dataloader, "__pl_saved_args")
     if was_wrapped:
-        dl_args = dataloader.__pl_dl_args
-        dl_kwargs = dataloader.__pl_dl_kwargs
-        arg_names = dataloader.__pl_dl_arg_names
+        dl_args = dataloader.__pl_saved_args
+        dl_kwargs = dataloader.__pl_saved_kwargs
+        arg_names = dataloader.__pl_saved_arg_names
         original_dataset = dataloader.__dataset  # we have this saved from _wrap_init
     else:
         # get the dataloader instance attributes
@@ -323,6 +324,9 @@ def _dataloader_init_kwargs_resolve_sampler(
     If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`, so
     Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped into a
     `FastForwardSampler`.
+
+    If there are multiple devices in IPU mode, it is necessary to disallow a `BatchSampler` that wasn't instantiated
+    automatically, since `poptorch.DataLoader` will try to increase the batch_size.
     """
     fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
     batch_sampler = getattr(dataloader, "batch_sampler")
@@ -341,11 +345,59 @@
                     "when running on multiple IPU devices."
                 )
         elif type(batch_sampler) is not BatchSampler or is_predicting:
-            batch_sampler = type(batch_sampler)(
-                sampler,
-                batch_size=batch_sampler.batch_size,
-                drop_last=(False if is_predicting else batch_sampler.drop_last),
-            )
+            batch_sampler_cls = type(batch_sampler)
+            if hasattr(batch_sampler, "__pl_saved_args"):
+                args = batch_sampler.__pl_saved_args
+                kwargs = batch_sampler.__pl_saved_kwargs
+                default_kwargs = batch_sampler.__pl_saved_default_kwargs
+                arg_names = batch_sampler.__pl_saved_arg_names
+
+                if is_predicting:
+                    success, args, kwargs = _replace_value_in_saved_args(
+                        "drop_last", False, args, kwargs, default_kwargs, arg_names
+                    )
+                    if not success:
+                        rank_zero_warn(
+                            f"Trying to inject `drop_last=False` into batch sampler since you are predicting, however "
+                            f"it seems the class `{batch_sampler_cls.__qualname__}` does not support it. "
+                            "Your predictions might be incomplete. To mitigate this, expose `drop_last` in "
+                            "the `__init__` method of your custom class."
+                        )
+
+                success, args, kwargs = _replace_value_in_saved_args(
+                    "sampler", sampler, args, kwargs, default_kwargs, arg_names
+                )
+                if not success:
+                    raise TypeError(
+                        "Trying to inject a modified sampler into the batch sampler; however, it seems the class "
+                        f"`{batch_sampler_cls.__qualname__}` does not have an argument called `sampler`. To mitigate "
+                        "this, expose an argument `sampler` in the `__init__` method of your custom class."
+                    )
+
+                batch_sampler = batch_sampler_cls(*args, **kwargs)
+            else:
+                try:
+                    batch_sampler = batch_sampler_cls(
+                        sampler,
+                        batch_size=batch_sampler.batch_size,
+                        drop_last=(False if is_predicting else batch_sampler.drop_last),
+                    )
+                except TypeError as e:
+                    import re
+
+                    match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(e))
+                    if not match:
+                        # an unexpected `TypeError`, continue failure
+                        raise
+
+                    # There could either be too few or too many arguments. Customizing the message based on this
+                    # doesn't make much sense since our MisconfigurationException is going to be raised from the
+                    # original one.
+                    raise MisconfigurationException(
+                        "We tried to re-instantiate your custom batch sampler and failed. "
+                        "To mitigate this, either follow the API of `BatchSampler` or instantiate "
+                        "your custom batch sampler inside `*_dataloader` hooks of your module."
+                    ) from e

         if is_predicting:
             batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
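The mitigation paths named in the messages above, in short: a batch sampler whose `__init__` exposes both `sampler` and `drop_last` supports both injections; hiding only `drop_last` triggers the `rank_zero_warn` branch, while hiding `sampler` hits the `TypeError`. A compliant class might look like this (a sketch; `MyBatchSampler` is hypothetical):

```python
from torch.utils.data import BatchSampler


class MyBatchSampler(BatchSampler):
    # `sampler` is required so a modified (e.g. distributed) sampler can be
    # injected; `drop_last` allows `drop_last=False` to be forced when predicting.
    def __init__(self, sampler, extra_arg, drop_last=False):
        self.extra_arg = extra_arg
        super().__init__(sampler, batch_size=10, drop_last=drop_last)
```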
@@ -368,39 +420,73 @@
     return {"sampler": sampler, "shuffle": False, "batch_sampler": None}


+def _replace_value_in_saved_args(
+    replace_key: str,
+    replace_value: Any,
+    args: Tuple[Any, ...],
+    kwargs: Dict[str, Any],
+    default_kwargs: Dict[str, Any],
+    arg_names: Tuple[str, ...],
+) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]:
+    """Tries to replace an argument value in a saved list of args and kwargs.
+
+    Returns a tuple indicating success of the operation and modified saved args and kwargs.
+    """
+
+    if replace_key in arg_names:
+        replace_index = arg_names.index(replace_key)
+        args = args[:replace_index] + (replace_value,) + args[replace_index + 1 :]
+        return True, args, kwargs
+    elif replace_key in kwargs or replace_key in default_kwargs:
+        kwargs[replace_key] = replace_value
+        return True, args, kwargs
+
+    return False, args, kwargs
+
+
 def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
     if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
         dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)


-def _wrap_dataloader_init(init: Callable) -> Callable:
-    """Wraps the ``__init__`` method of :class:`~torch.utils.data.DataLoader` in order to enable re-instantiation
-    of custom subclasses."""
+def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable:
+    """Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
+    :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""

     @functools.wraps(init)
-    def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
+    def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
         # We need to inspect `init`, as inspecting `obj.__init__`
         # can lead to inspecting the wrong function with multiple inheritance
         params = inspect.signature(init).parameters
-        param_names = tuple(
-            param.name
+
+        parameters_defaults = OrderedDict(
+            (param.name, param.default)
             for param in params.values()
             if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
         )
-        param_names = param_names[: len(args)]

-        if not hasattr(obj, "__pl_dl_args"):
-            obj.__pl_dl_args = args
-            obj.__pl_dl_kwargs = kwargs
-            obj.__pl_dl_arg_names = param_names
+        param_names = tuple(parameters_defaults)[: len(args)]
+
+        default_kwargs = {
+            name: value
+            for name, value in parameters_defaults.items()
+            if name not in kwargs and name not in param_names and value != inspect.Parameter.empty
+        }
+
+        if not hasattr(obj, "__pl_saved_args"):
+            obj.__pl_saved_args = args
+            obj.__pl_saved_kwargs = kwargs
+            obj.__pl_saved_arg_names = param_names
+            obj.__pl_saved_default_kwargs = default_kwargs

-        # We want to use the latest possible value for dataset argument (i.e. ideally what gets passed to DataLoader)
-        if "dataset" in param_names:
-            setattr(obj, "__dataset", args[param_names.index("dataset")])
-        elif "dataset" in kwargs:
-            setattr(obj, "__dataset", kwargs["dataset"])
+        # We want to use the latest possible value for the explicit argument (i.e. ideally what gets passed to the
+        # base class) so that we can be sure that it will not get changed anymore.
+        # That is why we are setting this in every `__init__`
+        if store_explicit_arg is not None:
+            if store_explicit_arg in param_names:
+                setattr(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
+            elif store_explicit_arg in kwargs:
+                setattr(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])

         init(obj, *args, **kwargs)
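A quick illustration of the helper's contract (a sketch; the values mirror the parametrized `test_replace_value_in_args` near the end of this diff):

```python
from pytorch_lightning.utilities.data import _replace_value_in_saved_args

# Positional replacement: the key appears in `arg_names`, so `args` is rebuilt.
ok, args, kwargs = _replace_value_in_saved_args("b", False, (1, 2, 3), {}, {}, ("a", "b", "c"))
assert (ok, args, kwargs) == (True, (1, False, 3), {})

# Default-only replacement: the key exists only in `default_kwargs`, so it is
# promoted into `kwargs` for the re-instantiation call.
ok, args, kwargs = _replace_value_in_saved_args("e", 2, (1, 2, 3), {"a": 1}, {"e": 5}, ("b", "c", "d"))
assert (ok, args, kwargs) == (True, (1, 2, 3), {"a": 1, "e": 2})

# Unknown key: nothing to replace; the caller decides whether to warn or raise.
ok, _, _ = _replace_value_in_saved_args("missing", 0, (), {}, {}, ())
assert not ok
```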
@@ -422,15 +508,17 @@ def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:


 @contextmanager
-def _replace_dataloader_init_method() -> Generator[None, None, None]:
-    """This context manager is used to add support for re-instantiation of custom (subclasses) of
-    :class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
-    classes = _get_all_subclasses(DataLoader) | {DataLoader}
+def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
+    """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`.
+
+    It patches the ``__init__`` method.
+    """
+    classes = _get_all_subclasses(base_cls) | {base_cls}
     wrapped = set()
     for cls in classes:
         if cls.__init__ not in wrapped:
             cls._old_init = cls.__init__
-            cls.__init__ = _wrap_dataloader_init(cls.__init__)
+            cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)
             wrapped.add(cls.__init__)
     yield
     for cls in classes:
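Usage of the generalized context manager, sketched (hypothetical subclass; the attribute checks mirror `test_custom_batch_sampler` below). Note how `drop_last` lands in the saved default kwargs because it was not passed explicitly:

```python
from torch.utils.data import BatchSampler, RandomSampler

from pytorch_lightning.utilities.data import _replace_init_method


class MyBatchSampler(BatchSampler):
    def __init__(self, sampler, extra_arg, drop_last=True):
        self.extra_arg = extra_arg
        super().__init__(sampler, 10, drop_last)


sampler = RandomSampler(range(10))
with _replace_init_method(BatchSampler):
    batch_sampler = MyBatchSampler(sampler, "random_str")

assert batch_sampler.__pl_saved_args == (sampler, "random_str")
assert batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg")
assert batch_sampler.__pl_saved_default_kwargs == {"drop_last": True}
```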
@@ -475,13 +563,13 @@ def _apply_fault_tolerant_automatic_capture_dataset_wrapper(


 def _is_dataloader_shuffled(dataloader: object) -> bool:
-    if hasattr(dataloader, "__pl_dl_kwargs"):
+    if hasattr(dataloader, "__pl_saved_kwargs"):
         # this attribute is not part of PyTorch's DataLoader, but could have been set by
-        # our `_replace_dataloader_init_method` context manager
-        if "shuffle" in dataloader.__pl_dl_kwargs:
-            return dataloader.__pl_dl_kwargs["shuffle"]
-        if "shuffle" in dataloader.__pl_dl_arg_names:
-            return dataloader.__pl_dl_args[dataloader.__pl_dl_arg_names.index("shuffle")]
+        # our `_replace_init_method` context manager
+        if "shuffle" in dataloader.__pl_saved_kwargs:
+            return dataloader.__pl_saved_kwargs["shuffle"]
+        if "shuffle" in dataloader.__pl_saved_arg_names:
+            return dataloader.__pl_saved_args[dataloader.__pl_saved_arg_names.index("shuffle")]
     if isinstance(dataloader.dataset, IterableDataset):
         # shuffling is useless with iterable datasets
         return False
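Why the captured kwargs matter here: `shuffle=True` never survives on the `DataLoader` instance itself (PyTorch converts it into a `RandomSampler` and discards the flag), so the saved constructor arguments are the only reliable record of it. A sketch:

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _is_dataloader_shuffled, _replace_init_method

with _replace_init_method(DataLoader, "dataset"):
    loader = DataLoader(range(10), shuffle=True)

# Detected from `__pl_saved_kwargs`, not from any attribute PyTorch keeps around.
assert _is_dataloader_shuffled(loader)
```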
@@ -177,16 +177,17 @@ def test_setup_dataloaders_return_type():
     assert lite_dataloader1.dataset is dataset1


-@mock.patch("pytorch_lightning.lite.lite._replace_dataloader_init_method")
+@mock.patch("pytorch_lightning.lite.lite._replace_init_method")
 def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
     """Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method."""

     class Lite(LightningLite):
         def run(self):
-            ctx_manager().__enter__.assert_called_once()
+            # One for BatchSampler, another for DataLoader
+            assert ctx_manager().__enter__.call_count == 2

     Lite().run()
-    ctx_manager().__exit__.assert_called_once()
+    assert ctx_manager().__exit__.call_count == 2


 def test_setup_dataloaders_raises_for_unknown_custom_args():
@@ -34,7 +34,6 @@ from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler, SequentialSampler
 from torch.utils.data._utils.worker import _generate_state, get_worker_info
 from torch.utils.data.dataloader import DataLoader, default_collate
 from torch.utils.data.dataset import Dataset, IterableDataset
-from torch.utils.data.sampler import Sampler

 import tests_pytorch.helpers.utils as tutils
 from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer
@@ -1177,15 +1176,6 @@ def test_validate_fault_tolerant(tmpdir):
     with pytest.raises(TypeError, match="RandomSampler"):
         _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)

-    class CustomBatchSampler(BatchSampler):
-        pass
-
-    sampler = Sampler(data())
-    batch_sampler = CustomBatchSampler(sampler, 2, False)
-    dl = DataLoader(data(), batch_sampler=batch_sampler)
-    with pytest.raises(TypeError, match="BatchSampler"):
-        _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)
-
     class CustomIterable(IterableDataset):
         pass
@@ -3,15 +3,17 @@ from dataclasses import dataclass
 import pytest
 import torch
 from torch import Tensor
-from torch.utils.data import BatchSampler, DataLoader, SequentialSampler
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler

 from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
+from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.data import (
     _dataloader_init_kwargs_resolve_sampler,
     _get_dataloader_init_args_and_kwargs,
-    _replace_dataloader_init_method,
+    _replace_init_method,
+    _replace_value_in_saved_args,
     _update_dataloader,
     extract_batch_size,
     get_len,
@@ -145,7 +147,7 @@ def test_update_dataloader_typerror_custom_exception():
     with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"):
         _update_dataloader(dataloader, dataloader.sampler)

-    with _replace_dataloader_init_method():
+    with _replace_init_method(DataLoader, "dataset"):
         dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
         new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
         assert isinstance(new_dataloader, BadStandaloneGoodHookImpl)
@@ -296,13 +298,14 @@ class ChangingDataLoader(DataLoader):
         pytest.param(ChangingDataLoader, (range(5),), dict(), ("dataset",), list(range(10)), dict(), id="test9"),
     ],
 )
-def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, checked_values):
-    with _replace_dataloader_init_method():
+def test_replace_init_method_dataloader(cls, args, kwargs, arg_names, dataset, checked_values):
+    with _replace_init_method(DataLoader, "dataset"):
         dataloader = cls(*args, **kwargs)

-    assert dataloader.__pl_dl_args == args
-    assert dataloader.__pl_dl_kwargs == kwargs
-    assert dataloader.__pl_dl_arg_names == arg_names
+    assert dataloader.__pl_saved_args == args
+    assert dataloader.__pl_saved_kwargs == kwargs
+    assert dataloader.__pl_saved_arg_names == arg_names
+    assert dataloader.__pl_saved_default_kwargs == {}
     assert dataloader.__dataset == dataset

     assert dataloader.dataset == dataset
@@ -312,14 +315,15 @@ def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, checked_values):
         if isinstance(dataloader_value, torch.Tensor):
             assert dataloader_value is value
         else:
-            assert getattr(dataloader, key) == value
+            assert dataloader_value == value

     dataloader = _update_dataloader(dataloader, dataloader.sampler)

     assert isinstance(dataloader, cls)
-    assert not hasattr(dataloader, "__pl_dl_kwargs")
-    assert not hasattr(dataloader, "__pl_dl_arg_names")
-    assert not hasattr(dataloader, "__pl_dl_args")
+    assert not hasattr(dataloader, "__pl_saved_kwargs")
+    assert not hasattr(dataloader, "__pl_saved_arg_names")
+    assert not hasattr(dataloader, "__pl_saved_args")
+    assert not hasattr(dataloader, "__pl_saved_default_kwargs")
     assert not hasattr(dataloader, "__dataset")

     assert dataloader.dataset == dataset
@@ -329,7 +333,168 @@ def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, checked_values):
         if isinstance(dataloader_value, torch.Tensor):
             assert dataloader_value is value
         else:
-            assert getattr(dataloader, key) == value
+            assert dataloader_value == value
+
+
+def test_replace_init_method_extra_kwargs():
+    class LoaderSubclass(DataLoader):
+        def __init__(self, dataset, *args, batch_size=10, **kwargs):
+            super().__init__(dataset, *args, batch_size=batch_size, **kwargs)
+
+    with _replace_init_method(DataLoader, "dataset"):
+        dataloader = LoaderSubclass(range(10))
+
+    assert dataloader.__pl_saved_args == (range(10),)
+    assert dataloader.__pl_saved_kwargs == {}
+    assert dataloader.__pl_saved_arg_names == ("dataset",)
+    assert dataloader.__pl_saved_default_kwargs == {"batch_size": 10}
+    assert dataloader.__dataset == range(10)
+
+
+@pytest.mark.parametrize("predicting", [True, False])
+def test_custom_batch_sampler(predicting):
+    """Tests that a custom `BatchSampler`, which takes all the arguments required to properly re-instantiate the
+    class, is invoked properly.
+
+    It also asserts that after re-instantiation the `__init__` wrapper is not present anymore, and therefore the
+    `__pl_saved_{args,arg_names,kwargs}` attributes are not set.
+    """
+
+    class MyBatchSampler(BatchSampler):
+        # Custom batch sampler with an extra argument and a default value
+        def __init__(self, sampler, extra_arg, drop_last=True):
+            self.extra_arg = extra_arg
+            super().__init__(sampler, 10, drop_last)
+
+    sampler = RandomSampler(range(10))
+    with _replace_init_method(BatchSampler):
+        # instantiate within the `_replace_init_method` context manager, simulating `*_dataloader` hooks
+        batch_sampler = MyBatchSampler(sampler, "random_str")
+
+    dataloader = DataLoader(range(10), batch_sampler=batch_sampler)
+
+    # assert that the passed information got saved
+    assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str")
+    assert dataloader.batch_sampler.__pl_saved_kwargs == {}
+    assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg")
+    assert dataloader.batch_sampler.__pl_saved_default_kwargs == {"drop_last": True}
+
+    # update the dataloader, which is what happens when the dataloaders are accessed.
+    # This should not fail, but would have failed before support for custom args was added.
+    dataloader = _update_dataloader(
+        dataloader, dataloader.sampler, mode=RunningStage.PREDICTING if predicting else None
+    )
+
+    # assert that the `__init__` method is not replaced anymore and everything is instantiated with the correct types
+    batch_sampler = dataloader.batch_sampler
+
+    if predicting:
+        assert isinstance(batch_sampler, IndexBatchSamplerWrapper)
+        batch_sampler = batch_sampler._sampler
+
+    assert isinstance(batch_sampler, MyBatchSampler)
+    assert batch_sampler.drop_last == (not predicting)
+
+    assert batch_sampler.extra_arg == "random_str"
+    assert not hasattr(batch_sampler, "__pl_saved_kwargs")
+    assert not hasattr(batch_sampler, "__pl_saved_arg_names")
+    assert not hasattr(batch_sampler, "__pl_saved_args")
+    assert not hasattr(batch_sampler, "__pl_saved_default_kwargs")
+
+
+def test_custom_batch_sampler_no_drop_last():
+    """Tests that an appropriate warning is raised when the custom `BatchSampler` does not support `drop_last` and
+    we want to reset it."""
+
+    class MyBatchSampler(BatchSampler):
+        # Custom batch sampler with an extra argument, but without `drop_last`
+        def __init__(self, sampler, extra_arg):
+            self.extra_arg = extra_arg
+            super().__init__(sampler, 10, False)
+
+    sampler = RandomSampler(range(10))
+    with _replace_init_method(BatchSampler):
+        # instantiate within the `_replace_init_method` context manager, simulating `*_dataloader` hooks
+        batch_sampler = MyBatchSampler(sampler, "random_str")
+
+    dataloader = DataLoader(range(10), batch_sampler=batch_sampler)
+
+    # assert that the passed information got saved
+    assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str")
+    assert dataloader.batch_sampler.__pl_saved_kwargs == {}
+    assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg")
+    assert dataloader.batch_sampler.__pl_saved_default_kwargs == {}
+
+    # assert that the warning is raised
+    with pytest.warns(UserWarning, match="drop_last=False"):
+        dataloader = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING)
+
+
+def test_custom_batch_sampler_no_sampler():
+    """Tests that an appropriate error is raised when the custom `BatchSampler` does not support the `sampler`
+    argument."""
+
+    class MyBatchSampler(BatchSampler):
+        # Custom batch sampler, without a `sampler` argument
+        def __init__(self, extra_arg):
+            self.extra_arg = extra_arg
+            super().__init__(RandomSampler(range(10)), 10, False)
+
+    with _replace_init_method(BatchSampler):
+        # instantiate within the `_replace_init_method` context manager, simulating `*_dataloader` hooks
+        batch_sampler = MyBatchSampler("random_str")
+    dataloader = DataLoader(range(10), batch_sampler=batch_sampler)
+
+    # assert that the passed information got saved
+    assert dataloader.batch_sampler.__pl_saved_args == ("random_str",)
+    assert dataloader.batch_sampler.__pl_saved_kwargs == {}
+    assert dataloader.batch_sampler.__pl_saved_arg_names == ("extra_arg",)
+    assert dataloader.batch_sampler.__pl_saved_default_kwargs == {}
+
+    # assert that the error is raised
+    with pytest.raises(TypeError, match="sampler into the batch sampler"):
+        dataloader = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING)
+
+
+@pytest.mark.parametrize(
+    [
+        "args",
+        "kwargs",
+        "default_kwargs",
+        "arg_names",
+        "replace_key",
+        "replace_value",
+        "expected_status",
+        "expected_args",
+        "expected_kwargs",
+    ],
+    [
+        pytest.param((), {}, {}, [], "a", 1, False, (), {}, id="empty"),
+        pytest.param((1,), {}, {}, ["a"], "a", 2, True, (2,), {}, id="simple1"),
+        pytest.param((1, 2, 3), {}, {}, ["a", "b", "c"], "b", False, True, (1, False, 3), {}, id="simple2"),
+        pytest.param((1, 2, 3), {"a": 1}, {}, ["b", "c", "d"], "a", 2, True, (1, 2, 3), {"a": 2}, id="simple_kwargs"),
+        pytest.param(
+            (1, 2, 3),
+            {"a": 1},
+            {"e": 5},
+            ["b", "c", "d"],
+            "e",
+            2,
+            True,
+            (1, 2, 3),
+            {"a": 1, "e": 2},
+            id="default_kwargs",
+        ),
+    ],
+)
+def test_replace_value_in_args(
+    args, kwargs, default_kwargs, arg_names, replace_key, replace_value, expected_status, expected_args, expected_kwargs
+):
+    assert _replace_value_in_saved_args(replace_key, replace_value, args, kwargs, default_kwargs, arg_names) == (
+        expected_status,
+        expected_args,
+        expected_kwargs,
+    )


 def test_dataloader_disallow_batch_sampler():