"""Tests for the dataloader/sampler helpers in ``lightning.fabric.utilities.data``."""

import contextlib
import os
import random
from unittest import mock
from unittest.mock import Mock

import lightning.fabric
import numpy as np
import pytest
import torch
from lightning.fabric.utilities.data import (
    AttributeDict,
    _get_dataloader_init_args_and_kwargs,
    _replace_dunder_methods,
    _replace_value_in_saved_args,
    _set_sampler_epoch,
    _update_dataloader,
    _WrapAttrTag,
    has_iterable_dataset,
    has_len,
    suggested_max_num_workers,
)
from lightning.fabric.utilities.exceptions import MisconfigurationException
from lightning_utilities.test.warning import no_warning_call
from torch import Tensor
from torch.utils.data import BatchSampler, DataLoader, RandomSampler

from tests_fabric.helpers.datasets import RandomDataset, RandomIterableDataset


def test_has_iterable_dataset():
    """`has_iterable_dataset` is False for a map-style dataset even when it defines `__iter__`."""
    assert has_iterable_dataset(DataLoader(RandomIterableDataset(1, 1)))

    assert not has_iterable_dataset(DataLoader(RandomDataset(1, 1)))

    class MockDatasetWithoutIterableDataset(RandomDataset):
        def __iter__(self):
            yield 1
            return self

    assert not has_iterable_dataset(DataLoader(MockDatasetWithoutIterableDataset(1, 1)))


def test_has_len():
    """`has_len` is True for sized loaders (warning on zero length) and False for iterable datasets."""
    assert has_len(DataLoader(RandomDataset(1, 1)))

    with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."):
        assert has_len(DataLoader(RandomDataset(0, 0)))

    assert not has_len(DataLoader(RandomIterableDataset(1, 1)))


def test_replace_dunder_methods_multiple_loaders_without_init():
    """Patch a random hierarchy of `DataLoader` subclasses that do not define their own `__init__`.

    For a class that inherits from a class we are patching, but does not define its own `__init__` (the method we
    wrap), `hasattr(cls, "__old__init__")` can be True because of the parent class. The attribute cannot be deleted
    from the child, however, because the parent class owns it. The resulting error occurred only sometimes, because it
    depended on the order in which the set of patched classes was iterated.

    This test reproduces that situation by generating a sufficient number of dummy classes that do not define
    `__init__` and are children of `DataLoader`. It checks that a) the `_replace_dunder_methods` context manager exits
    cleanly, and b) the check for the presence of `__old__init__` works as expected.

    """
    classes = [DataLoader]
    for i in range(100):
        # Each new class derives from a randomly chosen earlier one, building a random inheritance tree.
        classes.append(type(f"DataLoader_{i}", (random.choice(classes),), {}))

    before = {cls: cls.__init__ for cls in classes}

    with _replace_dunder_methods(DataLoader, "dataset"):
        for cls in classes[1:]:  # First one is `DataLoader`
            # Only visible through inheritance: the child must not own the saved original.
            assert "__old__init__" not in cls.__dict__
            assert hasattr(cls, "__old__init__")

        assert "__old__init__" in DataLoader.__dict__
        assert hasattr(DataLoader, "__old__init__")

    for cls in classes:
        assert before[cls] == cls.__init__


class MyBaseDataLoader(DataLoader):
    """A subclass that adds nothing; used as an intermediate base class."""


class DataLoaderSubclass1(DataLoader):
    """Takes an extra leading `attribute1` argument stored as `at1`."""

    def __init__(self, attribute1, *args, **kwargs):
        self.at1 = attribute1
        super().__init__(*args, **kwargs)


class DataLoaderSubclass2(DataLoaderSubclass1):
    """Passes the derived value `attribute2 + "-2"` to its parent as `attribute1`."""

    def __init__(self, attribute2, *args, **kwargs):
        self.at2 = attribute2
        super().__init__(attribute2 + "-2", *args, **kwargs)


class MyDataLoader(MyBaseDataLoader):
    """Uses `range(data.size(0))` as its dataset, derived from the tensor argument."""

    def __init__(self, data: Tensor, *args, **kwargs):
        self.data = data
        super().__init__(range(data.size(0)), *args, **kwargs)


test3_data = torch.randn((10, 20))


class PoptorchDataLoader(DataLoader):
    """Takes an extra leading `options` argument exposed through a read-only property."""

    def __init__(self, options, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._options = options

    @property
    def options(self):
        return self._options


class IncompleteDataLoader(DataLoader):
    """Lowers `batch_size` by 5 (clamped at 0) before forwarding it, so the attribute differs from the argument."""

    def __init__(self, dataset, batch_size, **kwargs):
        batch_size = max(batch_size - 5, 0)
        super().__init__(dataset, batch_size=batch_size, **kwargs)


class WeirdDataLoader1(DataLoader):
    """Receives its dataset under the non-standard parameter name `arg2`."""

    def __init__(self, arg1, arg2, **kwargs):
        self.arg1 = arg1
        super().__init__(arg2, **kwargs)


class WeirdDataLoader2(DataLoader):
    """Concatenates two separate arguments into the dataset it passes to `DataLoader`."""

    def __init__(self, data_part1, data_part2, **kwargs):
        data = list(data_part1) + list(data_part2)
        super().__init__(data, **kwargs)


class NoneDataLoader(DataLoader):
    """Forwards all arguments unchanged."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


class ChangingDataLoader(DataLoader):
    """Appends `range(5, 10)` to the dataset it was given."""

    def __init__(self, dataset, **kwargs):
        super().__init__(list(dataset) + list(range(5, 10)), **kwargs)


def _assert_checked_values(dataloader, checked_values):
    """Assert that each attribute named in ``checked_values`` on ``dataloader`` holds the expected value.

    Tensor attributes are compared by identity (``is``), all others by ``==``.

    """
    for key, value in checked_values.items():
        dataloader_value = getattr(dataloader, key)
        if isinstance(dataloader_value, Tensor):
            assert dataloader_value is value
        else:
            assert dataloader_value == value


@pytest.mark.parametrize(
    ("cls", "args", "kwargs", "arg_names", "dataset", "checked_values"),
    [
        pytest.param(
            DataLoaderSubclass1,
            ("attribute1",),
            {"dataset": range(4), "batch_size": 2},
            ("attribute1",),
            range(4),
            {"batch_size": 2, "at1": "attribute1"},
            id="test1",
        ),
        pytest.param(
            DataLoaderSubclass2,
            ("attribute2",),
            {"dataset": range(4), "batch_size": 2},
            ("attribute2",),
            range(4),
            {"batch_size": 2, "at1": "attribute2-2", "at2": "attribute2"},
            id="test2",
        ),
        pytest.param(
            MyDataLoader,
            (test3_data,),
            {"batch_size": 2},
            ("data",),
            range(10),
            {"batch_size": 2, "data": test3_data},
            id="test3",
        ),
        pytest.param(PoptorchDataLoader, (123, [1]), {}, ("options",), [1], {"options": 123}, id="test4"),
        pytest.param(
            IncompleteDataLoader,
            (range(10),),
            {"batch_size": 10},
            ("dataset",),
            range(10),
            {"batch_size": 5},
            id="test5",
        ),
        pytest.param(
            WeirdDataLoader1,
            (10, range(10)),
            {"batch_size": 10},
            ("arg1", "arg2"),
            range(10),
            {"arg1": 10, "batch_size": 10},
            id="test6",
        ),
        pytest.param(
            WeirdDataLoader2,
            (range(10), range(10, 20)),
            {"batch_size": 10},
            ("data_part1", "data_part2"),
            list(range(20)),
            {"batch_size": 10},
            id="test7",
        ),
        pytest.param(NoneDataLoader, (None,), {}, (), None, {}, id="test8"),
        pytest.param(ChangingDataLoader, (range(5),), {}, ("dataset",), list(range(10)), {}, id="test9"),
    ],
)
def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset, checked_values):
    """Loaders built under `_replace_dunder_methods` record their init arguments, and `_update_dataloader` rebuilds
    them as the same class with the same attributes but without the recorded metadata."""
    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = cls(*args, **kwargs)

    assert dataloader.__pl_saved_args == args
    assert dataloader.__pl_saved_kwargs == kwargs
    assert dataloader.__pl_saved_arg_names == arg_names
    assert dataloader.__pl_saved_default_kwargs == {}
    assert dataloader.__dataset == dataset

    assert dataloader.dataset == dataset
    _assert_checked_values(dataloader, checked_values)

    dataloader = _update_dataloader(dataloader, dataloader.sampler)

    assert isinstance(dataloader, cls)
    # The reinstantiated object was created outside the patch, so no metadata is recorded on it.
    assert not hasattr(dataloader, "__pl_saved_kwargs")
    assert not hasattr(dataloader, "__pl_saved_arg_names")
    assert not hasattr(dataloader, "__pl_saved_args")
    assert not hasattr(dataloader, "__pl_saved_default_kwargs")
    assert not hasattr(dataloader, "__dataset")

    assert dataloader.dataset == dataset
    _assert_checked_values(dataloader, checked_values)


def test_replace_dunder_methods_extra_kwargs():
    """Keyword defaults the caller did not pass (`batch_size=10`) are recorded in `__pl_saved_default_kwargs`."""

    class LoaderSubclass(DataLoader):
        def __init__(self, dataset, *args, batch_size=10, **kwargs):
            super().__init__(dataset, *args, batch_size=batch_size, **kwargs)

    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = LoaderSubclass(range(10))

    assert dataloader.__pl_saved_args == (range(10),)
    assert dataloader.__pl_saved_kwargs == {}
    assert dataloader.__pl_saved_arg_names == ("dataset",)
    assert dataloader.__pl_saved_default_kwargs == {"batch_size": 10}
    assert dataloader.__dataset == range(10)


def test_replace_dunder_methods_attrs():
    """Check that attribute sets and deletes performed within `_replace_dunder_methods` are recorded and replayed
    correctly after reinstantiation.

    The loader also defines a custom `__setattr__`, whose transformation must be applied on replay as well.

    """

    class Loader(DataLoader):
        def __setattr__(self, attr, val):
            if attr == "custom_arg":
                val = val + 2
            super().__setattr__(attr, val)

    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = Loader(range(10))
        dataloader.custom_arg = 5
        dataloader.my_arg = 10
        dataloader.another_arg = 100
        del dataloader.dataset
        # A failed delete must not be recorded.
        with contextlib.suppress(AttributeError):
            del dataloader.abc_arg

    assert dataloader.__pl_saved_args == (range(10),)
    assert dataloader.__pl_saved_kwargs == {}
    assert dataloader.__pl_saved_arg_names == ("dataset",)
    assert dataloader.__dataset == range(10)
    assert dataloader.custom_arg == 7
    assert dataloader.my_arg == 10
    assert dataloader.another_arg == 100
    assert not hasattr(dataloader, "dataset")
    # The record keeps the raw values passed in (5, not 7); the custom `__setattr__` re-applies the +2.
    assert dataloader.__pl_attrs_record == [
        (("custom_arg", 5), _WrapAttrTag.SET),
        (("my_arg", 10), _WrapAttrTag.SET),
        (("another_arg", 100), _WrapAttrTag.SET),
        (("dataset",), _WrapAttrTag.DEL),
    ]

    dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert dataloader.custom_arg == 7
    assert dataloader.my_arg == 10
    assert dataloader.another_arg == 100
    assert not hasattr(dataloader, "dataset")


def test_replace_dunder_methods_restore_methods():
    """Check that all dunder methods are restored to their original versions after the context manager exits."""

    class Init(DataLoader):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

    class SetAttr(DataLoader):
        def __setattr__(self, *args):
            return super().__setattr__(*args)

    class DelAttr(DataLoader):
        def __delattr__(self, *args):
            return super().__delattr__(*args)

    class InitAndSetAttr(Init, SetAttr):
        pass

    class InitAndDelAttr(Init, DelAttr):
        pass

    class SetAttrAndDelAttr(SetAttr, DelAttr):
        pass

    class AllDunder(Init, SetAttr, DelAttr):
        pass

    classes = (Init, SetAttr, DelAttr, InitAndSetAttr, InitAndDelAttr, SetAttrAndDelAttr, AllDunder)

    before = {}
    for cls in classes:
        before[cls] = {"init": cls.__init__, "setattr": cls.__setattr__, "delattr": cls.__delattr__}

    with _replace_dunder_methods(DataLoader, "dataset"):
        pass

    for cls in classes:
        assert before[cls] == {"init": cls.__init__, "setattr": cls.__setattr__, "delattr": cls.__delattr__}


@pytest.mark.parametrize(
    (
        "args",
        "kwargs",
        "default_kwargs",
        "arg_names",
        "replace_key",
        "replace_value",
        "expected_status",
        "expected_args",
        "expected_kwargs",
    ),
    [
        pytest.param((), {}, {}, [], "a", 1, False, (), {}, id="empty"),
        pytest.param((1,), {}, {}, ["a"], "a", 2, True, (2,), {}, id="simple1"),
        pytest.param((1, 2, 3), {}, {}, ["a", "b", "c"], "b", False, True, (1, False, 3), {}, id="simple2"),
        pytest.param((1, 2, 3), {"a": 1}, {}, ["b", "c", "d"], "a", 2, True, (1, 2, 3), {"a": 2}, id="simple_kwargs"),
        pytest.param(
            (1, 2, 3),
            {"a": 1},
            {"e": 5},
            ["b", "c", "d"],
            "e",
            2,
            True,
            (1, 2, 3),
            {"a": 1, "e": 2},
            id="default_kwargs",
        ),
    ],
)
def test_replace_value_in_args(
    args, kwargs, default_kwargs, arg_names, replace_key, replace_value, expected_status, expected_args, expected_kwargs
):
    """`_replace_value_in_saved_args` replaces a named value among positional args, kwargs or default kwargs and
    reports whether anything was replaced."""
    assert _replace_value_in_saved_args(replace_key, replace_value, args, kwargs, default_kwargs, arg_names) == (
        expected_status,
        expected_args,
        expected_kwargs,
    )


def test_update_dataloader_typerror_custom_exception():
    """A subclass whose `__init__` conflicts with a `DataLoader` argument raises `MisconfigurationException` on
    reinstantiation, unless it was created under `_replace_dunder_methods` or it filters the conflicting kwarg."""

    class BadStandaloneGoodHookImpl(DataLoader):
        def __init__(self, foo, *args, **kwargs):
            self.foo = foo  # positional conflict with `dataset`
            super().__init__(foo, *args, **kwargs)

    dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
    with pytest.raises(MisconfigurationException, match="implementation has an error.*`dataset`"):
        _update_dataloader(dataloader, dataloader.sampler)

    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
    new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert isinstance(new_dataloader, BadStandaloneGoodHookImpl)

    class BadImpl(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            self.randomize = randomize  # keyword conflict with `shuffle`
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = BadImpl(False, [])
    with pytest.raises(MisconfigurationException, match="implementation has an error.*`shuffle`"):
        _update_dataloader(dataloader, dataloader.sampler)

    class GoodImpl(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            # fixed implementation, kwargs are filtered
            self.randomize = randomize or kwargs.pop("shuffle", False)
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = GoodImpl(False, [])
    new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert isinstance(new_dataloader, GoodImpl)


def test_custom_torch_batch_sampler():
    """Check that a custom `BatchSampler` whose extra arguments were recorded can be reinstantiated correctly.

    Also check that during reinstantiation the `__init__` wrapper is no longer present, so the
    `__pl_saved_{args,arg_names,kwargs}` attributes are not set on the new instance.

    """

    class MyBatchSampler(BatchSampler):
        # Custom Batch sampler with extra argument and default value
        def __init__(self, sampler, extra_arg, drop_last=True):
            self.extra_arg = extra_arg
            super().__init__(sampler, 10, drop_last)

    sampler = RandomSampler(range(10))
    with _replace_dunder_methods(BatchSampler):
        # instantiate within `_replace_dunder_methods` context manager, simulating `*_dataloader` hooks
        batch_sampler = MyBatchSampler(sampler, "random_str")

    dataloader = DataLoader(range(10), batch_sampler=batch_sampler)

    # assert that passed information got saved
    assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str")
    assert dataloader.batch_sampler.__pl_saved_kwargs == {}
    assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg")
    assert dataloader.batch_sampler.__pl_saved_default_kwargs == {"drop_last": True}

    # updating dataloader, what happens on access of the dataloaders.
    # This should not fail, and would fail before support for custom args.
    dataloader = _update_dataloader(dataloader, dataloader.sampler)

    # Assert the `__init__` method is not replaced anymore and everything is instantiated to correct types
    batch_sampler = dataloader.batch_sampler

    assert isinstance(batch_sampler, MyBatchSampler)

    assert batch_sampler.extra_arg == "random_str"
    assert not hasattr(batch_sampler, "__pl_saved_kwargs")
    assert not hasattr(batch_sampler, "__pl_saved_arg_names")
    assert not hasattr(batch_sampler, "__pl_saved_args")
    assert not hasattr(batch_sampler, "__pl_saved_default_kwargs")


def test_custom_torch_batch_sampler_doppelganger():
    """Test we can reinstantiate a sampler that mimics PyTorch's BatchSampler even if it does not inherit from it.

    This is only possible if that sampler accepts the `batch_size` and `drop_last` arguments, and stores them as
    attributes.

    """

    class BatchSamplerDoppelganger:
        """A batch sampler that mimics `torch.utils.data.BatchSampler` but does not inherit from it."""

        def __init__(self, sampler, batch_size, drop_last):
            self.sampler = sampler
            self.batch_size = batch_size
            self.drop_last = drop_last

        def __iter__(self):
            while True:
                yield [0, 1, 2, 3]

        def __len__(self) -> int:
            return 4

    batch_sampler = BatchSamplerDoppelganger(sampler=Mock(), batch_size=2, drop_last=True)
    dataloader = DataLoader(range(100), batch_sampler=batch_sampler)
    new_sampler = Mock()
    dataloader = _update_dataloader(dataloader, sampler=new_sampler)

    batch_sampler = dataloader.batch_sampler
    assert isinstance(batch_sampler, BatchSamplerDoppelganger)
    assert batch_sampler.sampler == new_sampler


def test_custom_batch_sampler():
    """Test that a custom (non-PyTorch) batch sampler requires the user to set `use_distributed_sampler=False`."""

    class CustomBatchSampler:  # not inheriting from `BatchSampler`
        def __iter__(self):
            while True:
                yield [0, 1, 2, 3]

    batch_sampler = CustomBatchSampler()
    dataloader = DataLoader(range(100), batch_sampler=batch_sampler)
    with pytest.raises(TypeError, match=r"can't inject a \(distributed\) sampler into your batch sampler"):
        _ = _update_dataloader(dataloader, sampler=Mock())


def test_custom_batch_sampler_no_sampler():
    """Check that an appropriate error is raised when a custom `BatchSampler` does not accept a sampler argument."""

    class MyBatchSampler(BatchSampler):
        # Custom batch sampler, without sampler argument.
        def __init__(self, extra_arg):
            self.extra_arg = extra_arg
            super().__init__(RandomSampler(range(10)), 10, False)

    with _replace_dunder_methods(BatchSampler):
        # instantiate within `_replace_dunder_methods` context manager, simulating `*_dataloader` hooks
        batch_sampler = MyBatchSampler("random_str")
    dataloader = DataLoader(range(10), batch_sampler=batch_sampler)

    # assert that passed information got saved
    assert dataloader.batch_sampler.__pl_saved_args == ("random_str",)
    assert dataloader.batch_sampler.__pl_saved_kwargs == {}
    assert dataloader.batch_sampler.__pl_saved_arg_names == ("extra_arg",)
    assert dataloader.batch_sampler.__pl_saved_default_kwargs == {}

    # Assert that error is raised
    with pytest.raises(TypeError, match="sampler into the batch sampler"):
        _ = _update_dataloader(dataloader, dataloader.sampler)


def test_dataloader_kwargs_replacement_with_iterable_dataset():
    """Test that DataLoader kwargs are not replaced when using Iterable Dataset."""
    dataset = RandomIterableDataset(7, 100)
    dataloader = DataLoader(dataset, batch_size=32)
    _, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler)
    assert dl_kwargs["sampler"] is None
    assert dl_kwargs["batch_sampler"] is None
    # `batch_size` is an int: compare by value, not identity.
    assert dl_kwargs["batch_size"] == dataloader.batch_size
    assert dl_kwargs["dataset"] is dataloader.dataset
    assert dl_kwargs["collate_fn"] is dataloader.collate_fn


def test_dataloader_kwargs_replacement_with_array_default_comparison():
    """Test that the comparison of attributes and default argument values works with arrays (truth value ambiguous).

    Regression test for issue #15408.

    """
    dataset = RandomDataset(5, 100)

    class ArrayAttributeDataloader(DataLoader):
        def __init__(self, indices=None, **kwargs):
            super().__init__(dataset)
            self.indices = np.random.rand(2, 2)  # an attribute we can't compare with ==

    dataloader = ArrayAttributeDataloader(dataset)
    _, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler)
    assert dl_kwargs["indices"] is dataloader.indices


def test_set_sampler_epoch():
    """`_set_sampler_epoch` tolerates missing samplers and non-callable `set_epoch`, and otherwise calls `set_epoch`
    on both the sampler and the batch sampler's sampler."""
    # No samplers
    dataloader = Mock()
    dataloader.sampler = None
    dataloader.batch_sampler = None
    _set_sampler_epoch(dataloader, 55)

    # set_epoch not callable
    dataloader = Mock()
    dataloader.sampler.set_epoch = None
    dataloader.batch_sampler.set_epoch = None
    _set_sampler_epoch(dataloader, 55)

    # set_epoch callable
    dataloader = Mock()
    _set_sampler_epoch(dataloader, 55)
    dataloader.sampler.set_epoch.assert_called_once_with(55)
    dataloader.batch_sampler.sampler.set_epoch.assert_called_once_with(55)


@pytest.mark.parametrize(
    ("cpu_count", "local_world_size", "expected"),
    [
        (0, 1, 1),
        (1, 1, 1),
        (2, 1, 2 - 1),
        (1, 2, 1),
        (2, 2, 1),
        (3, 2, 1),
        (4, 2, 2 - 1),
        (4, 3, 1),
        (4, 1, 4 - 1),
    ],
)
@pytest.mark.parametrize(
    "affinity",
    [
        False,
        pytest.param(
            True,
            marks=pytest.mark.skipif(
                not hasattr(os, "sched_getaffinity"), reason="OS does not support restricting CPU cores"
            ),
        ),
    ],
)
@mock.patch("lightning.fabric.utilities.data.os.cpu_count")
def test_suggested_max_num_workers(cpu_count_mock, affinity, cpu_count, local_world_size, expected, monkeypatch):
    """Check the suggested worker count for mocked CPU counts, with and without `os.sched_getaffinity` available."""
    if affinity:
        monkeypatch.setattr(lightning.fabric.utilities.data.os, "sched_getaffinity", lambda _: list(range(cpu_count)))
    else:
        monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
        cpu_count_mock.return_value = cpu_count

    assert suggested_max_num_workers(local_world_size) == expected


@pytest.mark.parametrize("invalid", [-1, 0])
def test_suggested_max_num_workers_input_validation(invalid):
    """A `local_world_size` below 1 is rejected with a `ValueError`."""
    with pytest.raises(ValueError, match="should be >= 1"):
        suggested_max_num_workers(invalid)


@pytest.mark.parametrize("cpu_count", [1, 2, 3])
@pytest.mark.parametrize("local_world_size", [1, 2, 3])
def test_suggested_max_num_workers_not_triggering_torch_warning(local_world_size, cpu_count, monkeypatch):
    """Test that our suggestion for num workers doesn't trigger a warning in the DataLoader for too many workers."""
    monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
    monkeypatch.delattr(torch.utils.data.dataloader.os, "sched_getaffinity", raising=False)
    monkeypatch.setattr(lightning.fabric.utilities.data.os, "cpu_count", lambda: cpu_count)
    monkeypatch.setattr(torch.utils.data.dataloader.os, "cpu_count", lambda: cpu_count)

    # The dataloader runs a check in `DataLoader.check_worker_number_rationality`
    with pytest.warns(UserWarning, match="This DataLoader will create"):
        DataLoader(range(2), num_workers=(cpu_count + 1))
    with no_warning_call():
        DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size))


def test_state():
    """`AttributeDict` supports dict/kwargs construction, attribute-style get/set/delete and the matching errors."""
    # init via dict
    inputs = {"key1": 1, "key2": "abc"}
    state = AttributeDict(inputs)
    for key, value in inputs.items():
        assert getattr(state, key) == value

    # init via kwargs
    inputs = {"key1": 1, "key2": "abc"}
    state = AttributeDict(**inputs)
    for key, value in inputs.items():
        assert getattr(state, key) == value

    # update via dict
    state = AttributeDict()
    state.update({"key1": 1})
    assert state.key1 == 1

    # update via setter
    state = AttributeDict({"key1": 1})
    state.key1 = 123
    assert state.key1 == 123

    with pytest.raises(AttributeError, match="has no attribute 'key3'"):
        _ = state.key3

    # delete attribute
    del state.key1
    assert "key1" not in state
    with pytest.raises(KeyError):
        del state.key3