from dataclasses import dataclass
from unittest.mock import Mock

import numpy as np
import pytest
import torch
from lightning.fabric.utilities.data import _replace_dunder_methods
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import RandomDataset, RandomIterableDataset
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.data import (
    _get_dataloader_init_args_and_kwargs,
    _update_dataloader,
    extract_batch_size,
    has_len_all_ranks,
    warning_cache,
)
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning_utilities.test.warning import no_warning_call
from torch import Tensor
from torch.utils.data import BatchSampler, DataLoader, RandomSampler


def test_extract_batch_size():
    """Tests the behavior of extracting the batch size."""

    def _check_warning_not_raised(data, expected):
        with no_warning_call(match="Trying to infer the `batch_size`"):
            assert extract_batch_size(data) == expected

    def _check_warning_raised(data, expected):
        with pytest.warns(UserWarning, match=f"Trying to infer the `batch_size` .* we found is {expected}."):
            assert extract_batch_size(data) == expected
        warning_cache.clear()

    def _check_error_raised(data):
        with pytest.raises(MisconfigurationException, match="We could not infer the batch_size"):
            extract_batch_size(data)

    @dataclass
    class CustomDataclass:
        a: Tensor
        b: Tensor

    # Warning not raised
    batch = torch.zeros(11, 10, 9, 8)
    _check_warning_not_raised(batch, 11)

    batch = {"test": torch.zeros(11, 10)}
    _check_warning_not_raised(batch, 11)

    batch = [torch.zeros(11, 10)]
    _check_warning_not_raised(batch, 11)

    batch = CustomDataclass(torch.zeros(11, 10), torch.zeros(11, 10))
    _check_warning_not_raised(batch, 11)

    batch = {"test": [{"test": [torch.zeros(11, 10)]}]}
    _check_warning_not_raised(batch, 11)

    # Warning raised
    batch = {"a": [torch.tensor(1), torch.tensor(2)], "b": torch.tensor([1, 2, 3, 4])}
    _check_warning_raised(batch, 1)

    batch = CustomDataclass(torch.zeros(11, 10), torch.zeros(1))
    _check_warning_raised(batch, 11)

    batch = {"test": [{"test": [torch.zeros(11, 10), torch.zeros(10, 10)]}]}
    _check_warning_raised(batch, 11)

    batch = {"test": [{"test": [torch.zeros(10, 10), torch.zeros(11, 10)]}]}
    _check_warning_raised(batch, 10)

    batch = [{"test": torch.zeros(10, 10), "test_1": torch.zeros(11, 10)}]
    _check_warning_raised(batch, 10)

    # Error raised
    batch = "test string"
    _check_error_raised(batch)

    data = {"test": ["some text"] * 7}
    _check_error_raised(data)

    class CustomBatch:
        def __init__(self):
            self.x = torch.randn(7, 2)

    data = CustomBatch()
    _check_error_raised(data)


def test_has_len_all_rank():
    trainer = Trainer(fast_dev_run=True)

    with pytest.warns(UserWarning, match="Total length of `DataLoader` across ranks is zero."):
        assert has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.strategy)

    assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.strategy)


def test_update_dataloader_typerror_custom_exception():
    class BadStandaloneGoodHookImpl(DataLoader):
        def __init__(self, foo, *args, **kwargs):
            self.foo = foo
            # positional conflict with `dataset`
            super().__init__(foo, *args, **kwargs)

    dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
    with pytest.raises(MisconfigurationException, match="implementation has an error.*`dataset`"):
        _update_dataloader(dataloader, dataloader.sampler)
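
    # Presumably, `_replace_dunder_methods` patches `DataLoader.__init__` so the
    # call arguments are captured on the instance; with those saved, the same
    # re-instantiation succeeds even though `foo` shadows the `dataset` argument.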

    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
    new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert isinstance(new_dataloader, BadStandaloneGoodHookImpl)

    class BadImpl(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            self.randomize = randomize
            # keyword conflict with `shuffle`
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = BadImpl(False, [])
    with pytest.raises(MisconfigurationException, match="implementation has an error.*`shuffle`"):
        _update_dataloader(dataloader, dataloader.sampler)

    class GoodImpl(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            # fixed implementation, kwargs are filtered
            self.randomize = randomize or kwargs.pop("shuffle", False)
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = GoodImpl(False, [])
    new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert isinstance(new_dataloader, GoodImpl)


@pytest.mark.parametrize("predicting", [True, False])
def test_custom_torch_batch_sampler(predicting):
    """This test asserts that a custom `BatchSampler`, with all the arguments required to properly reinstantiate the
    class, is invoked correctly.

    It also asserts that during reinstantiation the wrapper of the `__init__` method is no longer present, so the
    `__pl_saved_{args,arg_names,kwargs}` attributes are not set.
    """

    class MyBatchSampler(BatchSampler):
        # Custom batch sampler with an extra argument and a default value
        def __init__(self, sampler, extra_arg, drop_last=True):
            self.extra_arg = extra_arg
            super().__init__(sampler, 10, drop_last)

    sampler = RandomSampler(range(10))
    with _replace_dunder_methods(BatchSampler):
        # instantiate within the `_replace_dunder_methods` context manager, simulating `*_dataloader` hooks
        batch_sampler = MyBatchSampler(sampler, "random_str")

    dataloader = DataLoader(range(10), batch_sampler=batch_sampler)

    # assert that the passed information was saved
    assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str")
    assert dataloader.batch_sampler.__pl_saved_kwargs == {}
    assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg")
    assert dataloader.batch_sampler.__pl_saved_default_kwargs == {"drop_last": True}

    # update the dataloader, which is what happens when the dataloaders are accessed.
    # This should not fail, and it would fail before support for custom args was added.
    dataloader = _update_dataloader(
        dataloader, dataloader.sampler, mode=(RunningStage.PREDICTING if predicting else None)
    )

    # Assert that the `__init__` method is not wrapped anymore and everything is instantiated with the correct types
    batch_sampler = dataloader.batch_sampler

    if predicting:
        assert isinstance(batch_sampler, _IndexBatchSamplerWrapper)
        batch_sampler = batch_sampler._batch_sampler

    assert isinstance(batch_sampler, MyBatchSampler)
    assert batch_sampler.drop_last == (not predicting)
    assert batch_sampler.extra_arg == "random_str"
    assert not hasattr(batch_sampler, "__pl_saved_kwargs")
    assert not hasattr(batch_sampler, "__pl_saved_arg_names")
    assert not hasattr(batch_sampler, "__pl_saved_args")
    assert not hasattr(batch_sampler, "__pl_saved_default_kwargs")


@pytest.mark.parametrize("predicting", [True, False])
def test_custom_torch_batch_sampler_doppelganger(predicting):
    """Test that we can reinstantiate a sampler that mimics PyTorch's `BatchSampler` even if it does not inherit from
    it.

    This is only possible if that sampler accepts the `batch_size` and `drop_last` arguments, and stores them as
    attributes.
    """

    class BatchSamplerDoppelganger:
        """A batch sampler that mimics `torch.utils.data.BatchSampler` but does not inherit from it."""

        def __init__(self, sampler, batch_size, drop_last):
            self.sampler = sampler
            self.batch_size = batch_size
            self.drop_last = drop_last

        def __iter__(self):
            while True:
                yield [0, 1, 2, 3]

        def __len__(self) -> int:
            return 4

    batch_sampler = BatchSamplerDoppelganger(sampler=Mock(), batch_size=2, drop_last=True)
    dataloader = DataLoader(range(100), batch_sampler=batch_sampler)
    new_sampler = Mock()
    dataloader = _update_dataloader(
        dataloader, sampler=new_sampler, mode=(RunningStage.PREDICTING if predicting else None)
    )

    batch_sampler = dataloader.batch_sampler

    if predicting:
        assert isinstance(batch_sampler, _IndexBatchSamplerWrapper)
        batch_sampler = batch_sampler._batch_sampler

    assert isinstance(batch_sampler, BatchSamplerDoppelganger)
    assert batch_sampler.sampler == new_sampler
    assert batch_sampler.drop_last == (not predicting)


def test_custom_batch_sampler():
    """Test that a custom (non-PyTorch) batch sampler requires the user to set `use_distributed_sampler=False`."""

    class CustomBatchSampler:  # not inheriting from `BatchSampler`
        def __iter__(self):
            while True:
                yield [0, 1, 2, 3]

    batch_sampler = CustomBatchSampler()
    dataloader = DataLoader(range(100), batch_sampler=batch_sampler)
    with pytest.raises(TypeError, match=r"can't inject a \(distributed\) sampler into your batch sampler"):
        _ = _update_dataloader(dataloader, sampler=Mock())


def test_custom_batch_sampler_no_drop_last():
    """Tests whether an appropriate warning is raised when the custom `BatchSampler` does not support `drop_last` and
    we want to reset it."""

    class MyBatchSampler(BatchSampler):
        # Custom batch sampler with an extra argument, but without `drop_last`
        def __init__(self, sampler, extra_arg):
            self.extra_arg = extra_arg
            super().__init__(sampler, 10, False)

    sampler = RandomSampler(range(10))
    with _replace_dunder_methods(BatchSampler):
        # instantiate within the `_replace_dunder_methods` context manager, simulating `*_dataloader` hooks
        batch_sampler = MyBatchSampler(sampler, "random_str")

    dataloader = DataLoader(range(10), batch_sampler=batch_sampler)

    # assert that the passed information was saved
    assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str")
    assert dataloader.batch_sampler.__pl_saved_kwargs == {}
    assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg")
    assert dataloader.batch_sampler.__pl_saved_default_kwargs == {}

    # Assert that the warning is raised
    with pytest.warns(UserWarning, match="drop_last=False"):
        _ = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING)


def test_custom_batch_sampler_no_sampler():
    """Tests whether an appropriate error is raised when the custom `BatchSampler` does not support the `sampler`
    argument."""

    class MyBatchSampler(BatchSampler):
        # Custom batch sampler without a `sampler` argument.
        def __init__(self, extra_arg):
            self.extra_arg = extra_arg
            super().__init__(RandomSampler(range(10)), 10, False)

    with _replace_dunder_methods(BatchSampler):
        # instantiate within the `_replace_dunder_methods` context manager, simulating `*_dataloader` hooks
        batch_sampler = MyBatchSampler("random_str")
        dataloader = DataLoader(range(10), batch_sampler=batch_sampler)

    # assert that the passed information was saved
    assert dataloader.batch_sampler.__pl_saved_args == ("random_str",)
    assert dataloader.batch_sampler.__pl_saved_kwargs == {}
    assert dataloader.batch_sampler.__pl_saved_arg_names == ("extra_arg",)
    assert dataloader.batch_sampler.__pl_saved_default_kwargs == {}

    # Assert that the error is raised
    with pytest.raises(TypeError, match="sampler into the batch sampler"):
        _ = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING)


@pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
    """Test that `DataLoader` kwargs are not replaced when using an `IterableDataset`."""
    dataset = RandomIterableDataset(7, 100)
    dataloader = DataLoader(dataset, batch_size=32)
    _, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler, mode=mode)
    assert dl_kwargs["sampler"] is None
    assert dl_kwargs["batch_sampler"] is None
    assert dl_kwargs["batch_size"] is dataloader.batch_size
    assert dl_kwargs["dataset"] is dataloader.dataset
    assert dl_kwargs["collate_fn"] is dataloader.collate_fn


def test_dataloader_kwargs_replacement_with_array_default_comparison():
    """Test that the comparison of attributes and default argument values works with arrays (truth value ambiguous).

    Regression test for issue #15408.
    """
    dataset = RandomDataset(5, 100)

    class ArrayAttributeDataloader(DataLoader):
        def __init__(self, indices=None, **kwargs):
            super().__init__(dataset)
            self.indices = np.random.rand(2, 2)  # an attribute we can't compare with ==

    dataloader = ArrayAttributeDataloader(dataset)
    dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler)
    assert dl_kwargs["indices"] is dataloader.indices