Allowed custom `BatchSampler`s when instantiated in `*_dataloader` hook (#13640)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
otaj 2022-07-27 17:32:50 +02:00 committed by GitHub
parent c58d351e01
commit 95f5f170f5
8 changed files with 317 additions and 81 deletions

View File

@@ -348,6 +348,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Improved support for custom `DataLoader`s when instantiated in `*_dataloader` hook ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))
- Allowed custom `BatchSampler`s when instantiated in `*_dataloader` hook ([#13640](https://github.com/PyTorchLightning/pytorch-lightning/pull/13640))
- Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))
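
As a rough usage sketch of the new changelog entry above: a custom `BatchSampler` subclass created inside a `*_dataloader` hook can now be re-instantiated by Lightning. The `ConfiguredBatchSampler` and `LitModel` names, the dataset, and the sizes below are illustrative, not taken from the PR; only the relevant hook is shown.

```python
from torch.utils.data import BatchSampler, DataLoader, RandomSampler

import pytorch_lightning as pl


class ConfiguredBatchSampler(BatchSampler):
    """Custom batch sampler with an extra constructor argument.

    It still exposes `sampler` and `drop_last`, which lets Lightning re-instantiate it later,
    e.g. to inject a `DistributedSampler` or to force `drop_last=False` for prediction.
    """

    def __init__(self, sampler, extra_arg, batch_size=8, drop_last=False):
        self.extra_arg = extra_arg
        super().__init__(sampler, batch_size, drop_last)


class LitModel(pl.LightningModule):
    def train_dataloader(self):
        dataset = range(64)
        batch_sampler = ConfiguredBatchSampler(RandomSampler(dataset), extra_arg="anything")
        return DataLoader(dataset, batch_sampler=batch_sampler)
```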

View File

@@ -22,7 +22,7 @@ import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
@@ -35,7 +35,7 @@ from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, move_da
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_replace_dataloader_init_method,
_replace_init_method,
_update_dataloader,
has_iterable_dataset,
)
@@ -403,7 +403,9 @@ class LightningLite(ABC):
def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
self._strategy.setup_environment()
with self._strategy.model_sharded_context(), _replace_dataloader_init_method():
with self._strategy.model_sharded_context(), _replace_init_method(DataLoader, "dataset"), _replace_init_method(
BatchSampler
):
return run_method(*args, **kwargs)
def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:

View File

@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
from typing import Any, Callable, Collection, List, Optional, Tuple, Union
from weakref import proxy
from torch.utils.data import DataLoader, Sampler, SequentialSampler
from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
import pytorch_lightning as pl
@@ -31,7 +31,7 @@ from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_au
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_is_dataloader_shuffled,
_replace_dataloader_init_method,
_replace_init_method,
_update_dataloader,
has_iterable_dataset,
has_len_all_ranks,
@@ -424,7 +424,7 @@ class DataConnector:
"""
source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source")
with _replace_dataloader_init_method():
with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler):
# under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as
# attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning
dataloader = source.dataloader()
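
A minimal sketch of the capture described in the comment above, using the attribute names this diff introduces (`__pl_saved_args`, `__pl_saved_kwargs`, `__pl_saved_arg_names`); the dataloader and its arguments are made up for illustration.

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _replace_init_method, _update_dataloader

with _replace_init_method(DataLoader, "dataset"):
    # simulates what happens while a `*_dataloader` hook runs under the context manager
    dataloader = DataLoader(range(10), batch_size=2)

# the wrapped `__init__` saved everything needed to rebuild the instance later
assert dataloader.__pl_saved_args == (range(10),)
assert dataloader.__pl_saved_kwargs == {"batch_size": 2}
assert dataloader.__pl_saved_arg_names == ("dataset",)

# Lightning can later re-instantiate the loader, e.g. to swap in a distributed sampler
new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
assert isinstance(new_dataloader, DataLoader)
```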

View File

@@ -16,15 +16,7 @@ from dataclasses import dataclass, field
from functools import partial, wraps
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
from torch.utils.data import (
BatchSampler,
Dataset,
DistributedSampler,
get_worker_info,
RandomSampler,
Sampler,
SequentialSampler,
)
from torch.utils.data import Dataset, DistributedSampler, get_worker_info, RandomSampler, Sampler, SequentialSampler
from torch.utils.data.dataloader import (
_BaseDataLoaderIter,
_MultiProcessingDataLoaderIter,
@@ -757,10 +749,6 @@ def _validate_map_dataset(dataloader: DataLoader) -> None:
if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS:
raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")
batch_sampler = getattr(dataloader, "batch_sampler", None)
if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
raise TypeError("Fault-tolerance supports only a `BatchSampler`.")
if type(sampler) is DistributedSampler and sampler.shuffle:
raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.")
elif type(sampler) is RandomSampler:

View File

@@ -14,6 +14,7 @@
import functools
import inspect
import os
from collections import OrderedDict
from contextlib import contextmanager
from dataclasses import fields
from functools import partial
@@ -220,11 +221,11 @@ def _get_dataloader_init_args_and_kwargs(
if not isinstance(dataloader, DataLoader):
raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")
was_wrapped = hasattr(dataloader, "__pl_dl_args")
was_wrapped = hasattr(dataloader, "__pl_saved_args")
if was_wrapped:
dl_args = dataloader.__pl_dl_args
dl_kwargs = dataloader.__pl_dl_kwargs
arg_names = dataloader.__pl_dl_arg_names
dl_args = dataloader.__pl_saved_args
dl_kwargs = dataloader.__pl_saved_kwargs
arg_names = dataloader.__pl_saved_arg_names
original_dataset = dataloader.__dataset # we have this saved from _wrap_init
else:
# get the dataloader instance attributes
@@ -323,6 +324,9 @@ def _dataloader_init_kwargs_resolve_sampler(
If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`, so
Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped into a
`FastForwardSampler`.
If there are multiple devices in IPU mode, it is necessary to disallow a `BatchSampler` that was not instantiated
automatically, since `poptorch.DataLoader` will try to increase the `batch_size`.
"""
fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
batch_sampler = getattr(dataloader, "batch_sampler")
@@ -341,11 +345,59 @@ def _dataloader_init_kwargs_resolve_sampler(
"when running on multiple IPU devices."
)
elif type(batch_sampler) is not BatchSampler or is_predicting:
batch_sampler = type(batch_sampler)(
sampler,
batch_size=batch_sampler.batch_size,
drop_last=(False if is_predicting else batch_sampler.drop_last),
)
batch_sampler_cls = type(batch_sampler)
if hasattr(batch_sampler, "__pl_saved_args"):
args = batch_sampler.__pl_saved_args
kwargs = batch_sampler.__pl_saved_kwargs
default_kwargs = batch_sampler.__pl_saved_default_kwargs
arg_names = batch_sampler.__pl_saved_arg_names
if is_predicting:
success, args, kwargs = _replace_value_in_saved_args(
"drop_last", False, args, kwargs, default_kwargs, arg_names
)
if not success:
rank_zero_warn(
f"Trying to inject `drop_last=False` into batch sampler since you are predicting, however "
f"it seems the class `{batch_sampler_cls.__qualname__}` does not support it. "
"Your predictions might be incomplete. To mitigate this, expose `drop_last` in "
"the `__init__` method of your custom class."
)
success, args, kwargs = _replace_value_in_saved_args(
"sampler", sampler, args, kwargs, default_kwargs, arg_names
)
if not success:
raise TypeError(
"Trying to inject a modified sampler into the batch sampler; however, it seems the class "
f"`{batch_sampler_cls.__qualname__}` does not have an argument called `sampler`. To mitigate "
"this, expose an argument `sampler` in the `__init__` method of your custom class."
)
batch_sampler = batch_sampler_cls(*args, **kwargs)
else:
try:
batch_sampler = batch_sampler_cls(
sampler,
batch_size=batch_sampler.batch_size,
drop_last=(False if is_predicting else batch_sampler.drop_last),
)
except TypeError as e:
import re
match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(e))
if not match:
# an unexpected `TypeError`, continue failure
raise
# There could either be too few or too many arguments. Customizing the message based on this doesn't
# make much sense since our MisconfigurationException is going to be raised from the original one.
raise MisconfigurationException(
"We tried to re-instantiate your custom batch sampler and failed. "
"To mitigate this, either follow the API of `BatchSampler` or instantiate "
"your custom batch sampler inside `*_dataloader` hooks of your module."
) from e
if is_predicting:
batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
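
A hedged sketch of the prediction branch documented above: a stock `BatchSampler` is rebuilt with `drop_last=False` and wrapped in `IndexBatchSamplerWrapper` so indices can be tracked. The dataset, batch size, and `SequentialSampler` below are illustrative.

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.data import _update_dataloader

dataset = range(10)
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=True)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

predict_dataloader = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING)
wrapper = predict_dataloader.batch_sampler
assert isinstance(wrapper, IndexBatchSamplerWrapper)
# the re-instantiated batch sampler no longer drops the last, smaller batch
assert wrapper._sampler.drop_last is False
```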
@@ -368,39 +420,73 @@ def _dataloader_init_kwargs_resolve_sampler(
return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
def _replace_value_in_saved_args(
replace_key: str,
replace_value: Any,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
default_kwargs: Dict[str, Any],
arg_names: Tuple[str, ...],
) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]:
"""Tries to replace an argument value in a saved list of args and kwargs.
Returns a tuple indicating success of the operation and modified saved args and kwargs
"""
if replace_key in arg_names:
replace_index = arg_names.index(replace_key)
args = args[:replace_index] + (replace_value,) + args[replace_index + 1 :]
return True, args, kwargs
elif replace_key in kwargs or replace_key in default_kwargs:
kwargs[replace_key] = replace_value
return True, args, kwargs
return False, args, kwargs
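
Two worked calls of the helper above, with made-up values, showing the positional path and the default-kwargs path.

```python
from pytorch_lightning.utilities.data import _replace_value_in_saved_args

# `drop_last` was captured as the third positional argument -> replaced in place
assert _replace_value_in_saved_args(
    "drop_last", False, ("some_sampler", 10, True), {}, {}, ("sampler", "batch_size", "drop_last")
) == (True, ("some_sampler", 10, False), {})

# `drop_last` was never passed explicitly but has a default -> injected as a keyword argument
assert _replace_value_in_saved_args(
    "drop_last", False, ("some_sampler", 10), {}, {"drop_last": True}, ("sampler", "batch_size")
) == (True, ("some_sampler", 10), {"drop_last": False})
```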
def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)
def _wrap_dataloader_init(init: Callable) -> Callable:
"""Wraps the ``__init__`` method of :class:`~torch.utils.data.DataLoader` in order to enable re-instantiation
of custom subclasses."""
def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable:
"""Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
:class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
@functools.wraps(init)
def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
# We need to inspect `init`, as inspecting `obj.__init__`
# can lead to inspecting the wrong function with multiple inheritance
params = inspect.signature(init).parameters
param_names = tuple(
param.name
parameters_defaults = OrderedDict(
(param.name, param.default)
for param in params.values()
if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
)
param_names = param_names[: len(args)]
if not hasattr(obj, "__pl_dl_args"):
obj.__pl_dl_args = args
obj.__pl_dl_kwargs = kwargs
obj.__pl_dl_arg_names = param_names
param_names = tuple(parameters_defaults)[: len(args)]
# We want to use the latest possible value for dataset argument (i.e. ideally what gets passed to DataLoader)
default_kwargs = {
name: value
for name, value in parameters_defaults.items()
if name not in kwargs and name not in param_names and value != inspect.Parameter.empty
}
if not hasattr(obj, "__pl_saved_args"):
obj.__pl_saved_args = args
obj.__pl_saved_kwargs = kwargs
obj.__pl_saved_arg_names = param_names
obj.__pl_saved_default_kwargs = default_kwargs
# We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class)
# so that we can be sure, that it will not get changed anymore.
# That is why we are setting this in every `__init__`
if "dataset" in param_names:
setattr(obj, "__dataset", args[param_names.index("dataset")])
elif "dataset" in kwargs:
setattr(obj, "__dataset", kwargs["dataset"])
if store_explicit_arg is not None:
if store_explicit_arg in param_names:
setattr(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
elif store_explicit_arg in kwargs:
setattr(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])
init(obj, *args, **kwargs)
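
A sketch of the "latest possible value" behaviour noted in the comments above: the wrapper runs for every `__init__` in the chain, so `__dataset` ends up holding whatever finally reaches `DataLoader.__init__`, while `__pl_saved_args` keeps what the caller originally passed. The `DoublingLoader` class is hypothetical.

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _replace_init_method


class DoublingLoader(DataLoader):
    # hypothetical subclass that hands a transformed dataset to the base class
    def __init__(self, dataset):
        super().__init__(list(dataset) * 2)


with _replace_init_method(DataLoader, "dataset"):
    dl = DoublingLoader(range(3))

# saved once, by the outermost (subclass) `__init__`
assert dl.__pl_saved_args == (range(3),)
# overwritten in every wrapped `__init__`, so it holds the value passed to `DataLoader` itself
assert dl.__dataset == [0, 1, 2, 0, 1, 2]
```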
@@ -422,15 +508,17 @@ def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
@contextmanager
def _replace_dataloader_init_method() -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of
:class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
classes = _get_all_subclasses(DataLoader) | {DataLoader}
def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`.
It patches the ``__init__`` method.
"""
classes = _get_all_subclasses(base_cls) | {base_cls}
wrapped = set()
for cls in classes:
if cls.__init__ not in wrapped:
cls._old_init = cls.__init__
cls.__init__ = _wrap_dataloader_init(cls.__init__)
cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)
wrapped.add(cls.__init__)
yield
for cls in classes:
@@ -475,13 +563,13 @@ def _apply_fault_tolerant_automatic_capture_dataset_wrapper(
def _is_dataloader_shuffled(dataloader: object) -> bool:
if hasattr(dataloader, "__pl_dl_kwargs"):
if hasattr(dataloader, "__pl_saved_kwargs"):
# this attribute is not part of PyTorch's DataLoader, but could have been set by
# our `_replace_dataloader_init_method` context manager
if "shuffle" in dataloader.__pl_dl_kwargs:
return dataloader.__pl_dl_kwargs["shuffle"]
if "shuffle" in dataloader.__pl_dl_arg_names:
return dataloader.__pl_dl_args[dataloader.__pl_dl_arg_names.index("shuffle")]
# our `_replace_init_method` context manager
if "shuffle" in dataloader.__pl_saved_kwargs:
return dataloader.__pl_saved_kwargs["shuffle"]
if "shuffle" in dataloader.__pl_saved_arg_names:
return dataloader.__pl_saved_args[dataloader.__pl_saved_arg_names.index("shuffle")]
if isinstance(dataloader.dataset, IterableDataset):
# shuffling is useless with iterable datasets
return False
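
A small usage sketch of the `_replace_init_method` context manager together with the shuffle check above. The assertion about restoring `__init__` assumes the restore loop (cut off at the end of the earlier hunk) puts `_old_init` back, matching the behaviour of the previous `_replace_dataloader_init_method`.

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _is_dataloader_shuffled, _replace_init_method

original_init = DataLoader.__init__

with _replace_init_method(DataLoader, "dataset"):
    assert DataLoader.__init__ is not original_init  # patched for the base class and all current subclasses
    dl = DataLoader(range(10), shuffle=True)

assert DataLoader.__init__ is original_init  # restored once the context manager exits
# the saved kwargs survive on the instance, so the shuffle flag can still be read back
assert _is_dataloader_shuffled(dl)
```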

View File

@@ -177,16 +177,17 @@ def test_setup_dataloaders_return_type():
assert lite_dataloader1.dataset is dataset1
@mock.patch("pytorch_lightning.lite.lite._replace_dataloader_init_method")
@mock.patch("pytorch_lightning.lite.lite._replace_init_method")
def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
"""Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method."""
class Lite(LightningLite):
def run(self):
ctx_manager().__enter__.assert_called_once()
# One for BatchSampler, another for DataLoader
assert ctx_manager().__enter__.call_count == 2
Lite().run()
ctx_manager().__exit__.assert_called_once()
assert ctx_manager().__exit__.call_count == 2
def test_setup_dataloaders_raises_for_unknown_custom_args():

View File

@@ -34,7 +34,6 @@ from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler, Se
from torch.utils.data._utils.worker import _generate_state, get_worker_info
from torch.utils.data.dataloader import DataLoader, default_collate
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.sampler import Sampler
import tests_pytorch.helpers.utils as tutils
from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer
@@ -1177,15 +1176,6 @@ def test_validate_fault_tolerant(tmpdir):
with pytest.raises(TypeError, match="RandomSampler"):
_validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)
class CustomBatchSampler(BatchSampler):
pass
sampler = Sampler(data())
batch_sampler = CustomBatchSampler(sampler, 2, False)
dl = DataLoader(data(), batch_sampler=batch_sampler)
with pytest.raises(TypeError, match="BatchSampler"):
_validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)
class CustomIterable(IterableDataset):
pass

View File

@@ -3,15 +3,17 @@ from dataclasses import dataclass
import pytest
import torch
from torch import Tensor
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.data import (
_dataloader_init_kwargs_resolve_sampler,
_get_dataloader_init_args_and_kwargs,
_replace_dataloader_init_method,
_replace_init_method,
_replace_value_in_saved_args,
_update_dataloader,
extract_batch_size,
get_len,
@@ -145,7 +147,7 @@ def test_update_dataloader_typerror_custom_exception():
with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"):
_update_dataloader(dataloader, dataloader.sampler)
with _replace_dataloader_init_method():
with _replace_init_method(DataLoader, "dataset"):
dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
assert isinstance(new_dataloader, BadStandaloneGoodHookImpl)
@@ -296,13 +298,14 @@ class ChangingDataLoader(DataLoader):
pytest.param(ChangingDataLoader, (range(5),), dict(), ("dataset",), list(range(10)), dict(), id="test9"),
],
)
def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, checked_values):
with _replace_dataloader_init_method():
def test_replace_init_method_dataloader(cls, args, kwargs, arg_names, dataset, checked_values):
with _replace_init_method(DataLoader, "dataset"):
dataloader = cls(*args, **kwargs)
assert dataloader.__pl_dl_args == args
assert dataloader.__pl_dl_kwargs == kwargs
assert dataloader.__pl_dl_arg_names == arg_names
assert dataloader.__pl_saved_args == args
assert dataloader.__pl_saved_kwargs == kwargs
assert dataloader.__pl_saved_arg_names == arg_names
assert dataloader.__pl_saved_default_kwargs == {}
assert dataloader.__dataset == dataset
assert dataloader.dataset == dataset
@@ -312,14 +315,15 @@ def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, c
if isinstance(dataloader_value, torch.Tensor):
assert dataloader_value is value
else:
assert getattr(dataloader, key) == value
assert dataloader_value == value
dataloader = _update_dataloader(dataloader, dataloader.sampler)
assert isinstance(dataloader, cls)
assert not hasattr(dataloader, "__pl_dl_kwargs")
assert not hasattr(dataloader, "__pl_dl_arg_names")
assert not hasattr(dataloader, "__pl_dl_args")
assert not hasattr(dataloader, "__pl_saved_kwargs")
assert not hasattr(dataloader, "__pl_saved_arg_names")
assert not hasattr(dataloader, "__pl_saved_args")
assert not hasattr(dataloader, "__pl_saved_default_kwargs")
assert not hasattr(dataloader, "__dataset")
assert dataloader.dataset == dataset
@@ -329,7 +333,168 @@ def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, c
if isinstance(dataloader_value, torch.Tensor):
assert dataloader_value is value
else:
assert getattr(dataloader, key) == value
assert dataloader_value == value
def test_replace_init_method_extra_kwargs():
class LoaderSubclass(DataLoader):
def __init__(self, dataset, *args, batch_size=10, **kwargs):
super().__init__(dataset, *args, batch_size=batch_size, **kwargs)
with _replace_init_method(DataLoader, "dataset"):
dataloader = LoaderSubclass(range(10))
assert dataloader.__pl_saved_args == (range(10),)
assert dataloader.__pl_saved_kwargs == {}
assert dataloader.__pl_saved_arg_names == ("dataset",)
assert dataloader.__pl_saved_default_kwargs == {"batch_size": 10}
assert dataloader.__dataset == range(10)
@pytest.mark.parametrize("predicting", [True, False])
def test_custom_batch_sampler(predicting):
"""This test asserts that a custom `BatchSampler`, with all the arguments required to re-instantiate the class
properly, is invoked properly.
It also asserts that, during re-instantiation, the wrapper around the `__init__` method is no longer present and
therefore the `__pl_saved_{args,arg_names,kwargs}` attributes are not set.
"""
class MyBatchSampler(BatchSampler):
# Custom Batch sampler with extra argument and default value
def __init__(self, sampler, extra_arg, drop_last=True):
self.extra_arg = extra_arg
super().__init__(sampler, 10, drop_last)
sampler = RandomSampler(range(10))
with _replace_init_method(BatchSampler):
# instantiate within `_replace_init_method` context manager, simulating `*_dataloader` hooks
batch_sampler = MyBatchSampler(sampler, "random_str")
dataloader = DataLoader(range(10), batch_sampler=batch_sampler)
# assert that passed information got saved
assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str")
assert dataloader.batch_sampler.__pl_saved_kwargs == {}
assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg")
assert dataloader.batch_sampler.__pl_saved_default_kwargs == {"drop_last": True}
# Update the dataloader, as happens when the dataloaders are accessed during the run.
# This should not fail; before custom arguments were supported, it would have.
dataloader = _update_dataloader(
dataloader, dataloader.sampler, mode=RunningStage.PREDICTING if predicting else None
)
# Assert the `__init__` method is not replaced anymore and everything is instantiated to correct types
batch_sampler = dataloader.batch_sampler
if predicting:
assert isinstance(batch_sampler, IndexBatchSamplerWrapper)
batch_sampler = batch_sampler._sampler
assert isinstance(batch_sampler, MyBatchSampler)
assert batch_sampler.drop_last == (not predicting)
assert batch_sampler.extra_arg == "random_str"
assert not hasattr(batch_sampler, "__pl_saved_kwargs")
assert not hasattr(batch_sampler, "__pl_saved_arg_names")
assert not hasattr(batch_sampler, "__pl_saved_args")
assert not hasattr(batch_sampler, "__pl_saved_default_kwargs")
def test_custom_batch_sampler_no_drop_last():
"""Tests whether an appropriate warning is raised when the custom `BatchSampler` does not support `drop_last` and
we want to reset it."""
class MyBatchSampler(BatchSampler):
# Custom batch sampler with extra argument, but without `drop_last`
def __init__(self, sampler, extra_arg):
self.extra_arg = extra_arg
super().__init__(sampler, 10, False)
sampler = RandomSampler(range(10))
with _replace_init_method(BatchSampler):
# instantiate within `_replace_init_method` context manager, simulating `*_dataloader` hooks
batch_sampler = MyBatchSampler(sampler, "random_str")
dataloader = DataLoader(range(10), batch_sampler=batch_sampler)
# assert that passed information got saved
assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str")
assert dataloader.batch_sampler.__pl_saved_kwargs == {}
assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg")
assert dataloader.batch_sampler.__pl_saved_default_kwargs == {}
# Assert that warning is raised
with pytest.warns(UserWarning, match="drop_last=False"):
dataloader = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING)
def test_custom_batch_sampler_no_sampler():
"""Tests whether an appropriate error is raised when the custom `BatchSampler` does not accept a `sampler`
argument."""
class MyBatchSampler(BatchSampler):
# Custom batch sampler, without sampler argument.
def __init__(self, extra_arg):
self.extra_arg = extra_arg
super().__init__(RandomSampler(range(10)), 10, False)
with _replace_init_method(BatchSampler):
# instantiate within `_replace_init_method` context manager, simulating `*_dataloader` hooks
batch_sampler = MyBatchSampler("random_str")
dataloader = DataLoader(range(10), batch_sampler=batch_sampler)
# assert that passed information got saved
assert dataloader.batch_sampler.__pl_saved_args == ("random_str",)
assert dataloader.batch_sampler.__pl_saved_kwargs == {}
assert dataloader.batch_sampler.__pl_saved_arg_names == ("extra_arg",)
assert dataloader.batch_sampler.__pl_saved_default_kwargs == {}
# Assert that error is raised
with pytest.raises(TypeError, match="sampler into the batch sampler"):
dataloader = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING)
@pytest.mark.parametrize(
[
"args",
"kwargs",
"default_kwargs",
"arg_names",
"replace_key",
"replace_value",
"expected_status",
"expected_args",
"expected_kwargs",
],
[
pytest.param((), {}, {}, [], "a", 1, False, (), {}, id="empty"),
pytest.param((1,), {}, {}, ["a"], "a", 2, True, (2,), {}, id="simple1"),
pytest.param((1, 2, 3), {}, {}, ["a", "b", "c"], "b", False, True, (1, False, 3), {}, id="simple2"),
pytest.param((1, 2, 3), {"a": 1}, {}, ["b", "c", "d"], "a", 2, True, (1, 2, 3), {"a": 2}, id="simple_kwargs"),
pytest.param(
(1, 2, 3),
{"a": 1},
{"e": 5},
["b", "c", "d"],
"e",
2,
True,
(1, 2, 3),
{"a": 1, "e": 2},
id="default_kwargs",
),
],
)
def test_replace_value_in_args(
args, kwargs, default_kwargs, arg_names, replace_key, replace_value, expected_status, expected_args, expected_kwargs
):
assert _replace_value_in_saved_args(replace_key, replace_value, args, kwargs, default_kwargs, arg_names) == (
expected_status,
expected_args,
expected_kwargs,
)
def test_dataloader_disallow_batch_sampler():