# lightning/tests/tests_fabric/utilities/test_data.py

import contextlib
import os
import random
from unittest import mock
from unittest.mock import Mock

import lightning.fabric
import numpy as np
import pytest
import torch
from lightning.fabric.utilities.data import (
    AttributeDict,
    _get_dataloader_init_args_and_kwargs,
    _replace_dunder_methods,
    _replace_value_in_saved_args,
    _set_sampler_epoch,
    _update_dataloader,
    _WrapAttrTag,
    has_iterable_dataset,
    has_len,
    suggested_max_num_workers,
)
from lightning.fabric.utilities.exceptions import MisconfigurationException
from lightning_utilities.test.warning import no_warning_call
from torch import Tensor
from torch.utils.data import BatchSampler, DataLoader, RandomSampler

from tests_fabric.helpers.datasets import RandomDataset, RandomIterableDataset


def test_has_iterable_dataset():
    assert has_iterable_dataset(DataLoader(RandomIterableDataset(1, 1)))

    assert not has_iterable_dataset(DataLoader(RandomDataset(1, 1)))

    class MockDatasetWithoutIterableDataset(RandomDataset):
        def __iter__(self):
            yield 1
            return self

    assert not has_iterable_dataset(DataLoader(MockDatasetWithoutIterableDataset(1, 1)))


def test_has_len():
    assert has_len(DataLoader(RandomDataset(1, 1)))

    with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."):
        assert has_len(DataLoader(RandomDataset(0, 0)))

    assert not has_len(DataLoader(RandomIterableDataset(1, 1)))


def test_replace_dunder_methods_multiple_loaders_without_init():
    """When a class inherits from a class we are patching but does not define its own `__init__` (the method we
    wrap), `hasattr(cls, "__old__init__")` can be True because of the parent class, yet the attribute is impossible
    to delete on the subclass because it is owned by the parent. Moreover, the error occurred only sometimes, since
    it depends on the order in which we iterate over the set of classes being patched.

    This test simulates the behavior by generating a sufficient number of dummy `DataLoader` subclasses that do not
    define `__init__`. We test that a) the `_replace_dunder_methods` context manager exits cleanly, and b) the
    mechanism checking for the presence of `__old__init__` works as expected.
    """
    classes = [DataLoader]
    for i in range(100):
        classes.append(type(f"DataLoader_{i}", (random.choice(classes),), {}))

    before = {cls: cls.__init__ for cls in classes}

    with _replace_dunder_methods(DataLoader, "dataset"):
        for cls in classes[1:]:  # First one is `DataLoader`
            assert "__old__init__" not in cls.__dict__
            assert hasattr(cls, "__old__init__")

        assert "__old__init__" in DataLoader.__dict__
        assert hasattr(DataLoader, "__old__init__")
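
    # on exit, every class in the hierarchy must get its original `__init__` back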
    for cls in classes:
        assert before[cls] == cls.__init__


class MyBaseDataLoader(DataLoader):
    pass


class DataLoaderSubclass1(DataLoader):
    def __init__(self, attribute1, *args, **kwargs):
        self.at1 = attribute1
        super().__init__(*args, **kwargs)


class DataLoaderSubclass2(DataLoaderSubclass1):
    def __init__(self, attribute2, *args, **kwargs):
        self.at2 = attribute2
        super().__init__(attribute2 + "-2", *args, **kwargs)


class MyDataLoader(MyBaseDataLoader):
    def __init__(self, data: Tensor, *args, **kwargs):
        self.data = data
        super().__init__(range(data.size(0)), *args, **kwargs)


test3_data = torch.randn((10, 20))


class PoptorchDataLoader(DataLoader):
    def __init__(self, options, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._options = options

    @property
    def options(self):
        return self._options


class IncompleteDataLoader(DataLoader):
    def __init__(self, dataset, batch_size, **kwargs):
        batch_size = max(batch_size - 5, 0)
        super().__init__(dataset, batch_size=batch_size, **kwargs)


class WeirdDataLoader1(DataLoader):
    def __init__(self, arg1, arg2, **kwargs):
        self.arg1 = arg1
        super().__init__(arg2, **kwargs)


class WeirdDataLoader2(DataLoader):
    def __init__(self, data_part1, data_part2, **kwargs):
        data = list(data_part1) + list(data_part2)
        super().__init__(data, **kwargs)


class NoneDataLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


class ChangingDataLoader(DataLoader):
    def __init__(self, dataset, **kwargs):
        super().__init__(list(dataset) + list(range(5, 10)), **kwargs)


@pytest.mark.parametrize(
    ("cls", "args", "kwargs", "arg_names", "dataset", "checked_values"),
    [
        pytest.param(
            DataLoaderSubclass1,
            ("attribute1",),
            {"dataset": range(4), "batch_size": 2},
            ("attribute1",),
            range(4),
            {"batch_size": 2, "at1": "attribute1"},
            id="test1",
        ),
        pytest.param(
            DataLoaderSubclass2,
            ("attribute2",),
            {"dataset": range(4), "batch_size": 2},
            ("attribute2",),
            range(4),
            {"batch_size": 2, "at1": "attribute2-2", "at2": "attribute2"},
            id="test2",
        ),
        pytest.param(
            MyDataLoader,
            (test3_data,),
            {"batch_size": 2},
            ("data",),
            range(10),
            {"batch_size": 2, "data": test3_data},
            id="test3",
        ),
        pytest.param(PoptorchDataLoader, (123, [1]), {}, ("options",), [1], {"options": 123}, id="test4"),
        pytest.param(
            IncompleteDataLoader,
            (range(10),),
            {"batch_size": 10},
            ("dataset",),
            range(10),
            {"batch_size": 5},
            id="test5",
        ),
        pytest.param(
            WeirdDataLoader1,
            (10, range(10)),
            {"batch_size": 10},
            ("arg1", "arg2"),
            range(10),
            {"arg1": 10, "batch_size": 10},
            id="test6",
        ),
        pytest.param(
            WeirdDataLoader2,
            (range(10), range(10, 20)),
            {"batch_size": 10},
            ("data_part1", "data_part2"),
            list(range(20)),
            {"batch_size": 10},
            id="test7",
        ),
        pytest.param(NoneDataLoader, (None,), {}, (), None, {}, id="test8"),
        pytest.param(ChangingDataLoader, (range(5),), {}, ("dataset",), list(range(10)), {}, id="test9"),
    ],
)
def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset, checked_values):
    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = cls(*args, **kwargs)
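
    # the wrapped `__init__` recorded the constructor call on the instance; the record survives the context manager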
    assert dataloader.__pl_saved_args == args
    assert dataloader.__pl_saved_kwargs == kwargs
    assert dataloader.__pl_saved_arg_names == arg_names
    assert dataloader.__pl_saved_default_kwargs == {}
    assert dataloader.__dataset == dataset

    assert dataloader.dataset == dataset

    for key, value in checked_values.items():
        dataloader_value = getattr(dataloader, key)
        if isinstance(dataloader_value, Tensor):
            assert dataloader_value is value
        else:
            assert dataloader_value == value
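
    # `_update_dataloader` reinstantiates with the original `__init__`, so the bookkeeping attributes disappear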
    dataloader = _update_dataloader(dataloader, dataloader.sampler)

    assert isinstance(dataloader, cls)
    assert not hasattr(dataloader, "__pl_saved_kwargs")
    assert not hasattr(dataloader, "__pl_saved_arg_names")
    assert not hasattr(dataloader, "__pl_saved_args")
    assert not hasattr(dataloader, "__pl_saved_default_kwargs")
    assert not hasattr(dataloader, "__dataset")

    assert dataloader.dataset == dataset

    for key, value in checked_values.items():
        dataloader_value = getattr(dataloader, key)
        if isinstance(dataloader_value, Tensor):
            assert dataloader_value is value
        else:
            assert dataloader_value == value


def test_replace_dunder_methods_extra_kwargs():
    class LoaderSubclass(DataLoader):
        def __init__(self, dataset, *args, batch_size=10, **kwargs):
            super().__init__(dataset, *args, batch_size=batch_size, **kwargs)

    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = LoaderSubclass(range(10))

    assert dataloader.__pl_saved_args == (range(10),)
    assert dataloader.__pl_saved_kwargs == {}
    assert dataloader.__pl_saved_arg_names == ("dataset",)
    assert dataloader.__pl_saved_default_kwargs == {"batch_size": 10}
    assert dataloader.__dataset == range(10)


def test_replace_dunder_methods_attrs():
    """This test checks that all attribute set and delete operations performed within `_replace_dunder_methods` are
    correctly preserved even after reinstantiation.

    It also includes a custom `__setattr__`.
    """

    class Loader(DataLoader):
        def __setattr__(self, attr, val):
            if attr == "custom_arg":
                val = val + 2
            super().__setattr__(attr, val)

    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = Loader(range(10))
        dataloader.custom_arg = 5
        dataloader.my_arg = 10
        dataloader.another_arg = 100
        del dataloader.dataset
        with contextlib.suppress(AttributeError):
            del dataloader.abc_arg

    assert dataloader.__pl_saved_args == (range(10),)
    assert dataloader.__pl_saved_kwargs == {}
    assert dataloader.__pl_saved_arg_names == ("dataset",)
    assert dataloader.__dataset == range(10)

    assert dataloader.custom_arg == 7
    assert dataloader.my_arg == 10
    assert dataloader.another_arg == 100
    assert not hasattr(dataloader, "dataset")
    assert dataloader.__pl_attrs_record == [
        (("custom_arg", 5), _WrapAttrTag.SET),
        (("my_arg", 10), _WrapAttrTag.SET),
        (("another_arg", 100), _WrapAttrTag.SET),
        (("dataset",), _WrapAttrTag.DEL),
    ]
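
    # the recorded attribute operations are replayed on the reinstantiated dataloader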
    dataloader = _update_dataloader(dataloader, dataloader.sampler)

    assert dataloader.custom_arg == 7
    assert dataloader.my_arg == 10
    assert dataloader.another_arg == 100
    assert not hasattr(dataloader, "dataset")


def test_replace_dunder_methods_restore_methods():
    """This test checks whether all dunder methods are restored to their original versions."""

    class Init(DataLoader):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

    class SetAttr(DataLoader):
        def __setattr__(self, *args):
            return super().__setattr__(*args)

    class DelAttr(DataLoader):
        def __delattr__(self, *args):
            return super().__delattr__(*args)

    class InitAndSetAttr(Init, SetAttr):
        pass

    class InitAndDelAttr(Init, DelAttr):
        pass

    class SetAttrAndDelAttr(SetAttr, DelAttr):
        pass

    class AllDunder(Init, SetAttr, DelAttr):
        pass

    before = {}
    for cls in (Init, SetAttr, DelAttr, InitAndSetAttr, InitAndDelAttr, SetAttrAndDelAttr, AllDunder):
        before[cls] = {"init": cls.__init__, "setattr": cls.__setattr__, "delattr": cls.__delattr__}

    with _replace_dunder_methods(DataLoader, "dataset"):
        pass

    for cls in (Init, SetAttr, DelAttr, InitAndSetAttr, InitAndDelAttr, SetAttrAndDelAttr, AllDunder):
        assert before[cls] == {"init": cls.__init__, "setattr": cls.__setattr__, "delattr": cls.__delattr__}


@pytest.mark.parametrize(
    (
        "args",
        "kwargs",
        "default_kwargs",
        "arg_names",
        "replace_key",
        "replace_value",
        "expected_status",
        "expected_args",
        "expected_kwargs",
    ),
    [
        pytest.param((), {}, {}, [], "a", 1, False, (), {}, id="empty"),
        pytest.param((1,), {}, {}, ["a"], "a", 2, True, (2,), {}, id="simple1"),
        pytest.param((1, 2, 3), {}, {}, ["a", "b", "c"], "b", False, True, (1, False, 3), {}, id="simple2"),
        pytest.param((1, 2, 3), {"a": 1}, {}, ["b", "c", "d"], "a", 2, True, (1, 2, 3), {"a": 2}, id="simple_kwargs"),
        pytest.param(
            (1, 2, 3),
            {"a": 1},
            {"e": 5},
            ["b", "c", "d"],
            "e",
            2,
            True,
            (1, 2, 3),
            {"a": 1, "e": 2},
            id="default_kwargs",
        ),
    ],
)
def test_replace_value_in_args(
    args, kwargs, default_kwargs, arg_names, replace_key, replace_value, expected_status, expected_args, expected_kwargs
):
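    # `_replace_value_in_saved_args` returns a `(found, args, kwargs)` triple: whether `replace_key` was
    # located among the saved arguments, plus the updated positional and keyword arguments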
    assert _replace_value_in_saved_args(replace_key, replace_value, args, kwargs, default_kwargs, arg_names) == (
        expected_status,
        expected_args,
        expected_kwargs,
    )


def test_update_dataloader_typerror_custom_exception():
    class BadStandaloneGoodHookImpl(DataLoader):
        def __init__(self, foo, *args, **kwargs):
            self.foo = foo
            # positional conflict with `dataset`
            super().__init__(foo, *args, **kwargs)

    dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
    with pytest.raises(MisconfigurationException, match="implementation has an error.*`dataset`"):
        _update_dataloader(dataloader, dataloader.sampler)

    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
        new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert isinstance(new_dataloader, BadStandaloneGoodHookImpl)

    class BadImpl(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            self.randomize = randomize
            # keyword conflict with `shuffle`
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = BadImpl(False, [])
    with pytest.raises(MisconfigurationException, match="implementation has an error.*`shuffle`"):
        _update_dataloader(dataloader, dataloader.sampler)

    class GoodImpl(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            # fixed implementation, kwargs are filtered
            self.randomize = randomize or kwargs.pop("shuffle", False)
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = GoodImpl(False, [])
    new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert isinstance(new_dataloader, GoodImpl)


def test_custom_torch_batch_sampler():
    """This test asserts that a custom `BatchSampler`, carrying all the arguments required to properly reinstantiate
    the class, is reinstantiated properly.

    It also asserts that during reinstantiation the wrapper around `__init__` is no longer present, so the
    `__pl_saved_{args,arg_names,kwargs}` attributes are not set again.
    """

    class MyBatchSampler(BatchSampler):
        # Custom batch sampler with an extra argument and a default value
        def __init__(self, sampler, extra_arg, drop_last=True):
            self.extra_arg = extra_arg
            super().__init__(sampler, 10, drop_last)

    sampler = RandomSampler(range(10))
    with _replace_dunder_methods(BatchSampler):
        # instantiate within the `_replace_dunder_methods` context manager, simulating the `*_dataloader` hooks
        batch_sampler = MyBatchSampler(sampler, "random_str")
        dataloader = DataLoader(range(10), batch_sampler=batch_sampler)

    # assert that the passed information got saved
    assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str")
    assert dataloader.batch_sampler.__pl_saved_kwargs == {}
    assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg")
    assert dataloader.batch_sampler.__pl_saved_default_kwargs == {"drop_last": True}

    # update the dataloader, which is what happens when the dataloaders are accessed.
    # This should not fail, and would fail before support for custom args.
    dataloader = _update_dataloader(dataloader, dataloader.sampler)

    # assert that the `__init__` method is not replaced anymore and everything is instantiated with the correct types
    batch_sampler = dataloader.batch_sampler
    assert isinstance(batch_sampler, MyBatchSampler)
    assert batch_sampler.extra_arg == "random_str"
    assert not hasattr(batch_sampler, "__pl_saved_kwargs")
    assert not hasattr(batch_sampler, "__pl_saved_arg_names")
    assert not hasattr(batch_sampler, "__pl_saved_args")
    assert not hasattr(batch_sampler, "__pl_saved_default_kwargs")


def test_custom_torch_batch_sampler_doppelganger():
    """Test we can reinstantiate a sampler that mimics PyTorch's BatchSampler even if it does not inherit from it.

    This is only possible if that sampler accepts the `batch_size` and `drop_last` arguments, and stores them as
    attributes.
    """

    class BatchSamplerDoppelganger:
        """A batch sampler that mimics `torch.utils.data.BatchSampler` but does not inherit from it."""

        def __init__(self, sampler, batch_size, drop_last):
            self.sampler = sampler
            self.batch_size = batch_size
            self.drop_last = drop_last

        def __iter__(self):
            while True:
                yield [0, 1, 2, 3]

        def __len__(self) -> int:
            return 4

    batch_sampler = BatchSamplerDoppelganger(sampler=Mock(), batch_size=2, drop_last=True)
    dataloader = DataLoader(range(100), batch_sampler=batch_sampler)
    new_sampler = Mock()
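
    # `_update_dataloader` should recognize the BatchSampler-like signature (sampler, batch_size, drop_last)
    # and rebuild the doppelganger around the new sampler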
    dataloader = _update_dataloader(dataloader, sampler=new_sampler)

    batch_sampler = dataloader.batch_sampler
    assert isinstance(batch_sampler, BatchSamplerDoppelganger)
    assert batch_sampler.sampler == new_sampler


def test_custom_batch_sampler():
    """Test that a custom (non-PyTorch) batch sampler requires the user to set `use_distributed_sampler=False`."""

    class CustomBatchSampler:  # not inheriting from `BatchSampler`
        def __iter__(self):
            while True:
                yield [0, 1, 2, 3]

    batch_sampler = CustomBatchSampler()
    dataloader = DataLoader(range(100), batch_sampler=batch_sampler)

    with pytest.raises(TypeError, match=r"can't inject a \(distributed\) sampler into your batch sampler"):
        _ = _update_dataloader(dataloader, sampler=Mock())


def test_custom_batch_sampler_no_sampler():
    """Tests whether an appropriate error is raised when the custom `BatchSampler` does not support the `sampler`
    argument."""

    class MyBatchSampler(BatchSampler):
        # Custom batch sampler without a `sampler` argument
        def __init__(self, extra_arg):
            self.extra_arg = extra_arg
            super().__init__(RandomSampler(range(10)), 10, False)

    with _replace_dunder_methods(BatchSampler):
        # instantiate within the `_replace_dunder_methods` context manager, simulating the `*_dataloader` hooks
        batch_sampler = MyBatchSampler("random_str")
        dataloader = DataLoader(range(10), batch_sampler=batch_sampler)

    # assert that the passed information got saved
    assert dataloader.batch_sampler.__pl_saved_args == ("random_str",)
    assert dataloader.batch_sampler.__pl_saved_kwargs == {}
    assert dataloader.batch_sampler.__pl_saved_arg_names == ("extra_arg",)
    assert dataloader.batch_sampler.__pl_saved_default_kwargs == {}

    # assert that an error is raised, since a new sampler cannot be injected without a `sampler` argument
    with pytest.raises(TypeError, match="sampler into the batch sampler"):
        _ = _update_dataloader(dataloader, dataloader.sampler)


def test_dataloader_kwargs_replacement_with_iterable_dataset():
    """Test that DataLoader kwargs are not replaced when using an iterable dataset."""
    dataset = RandomIterableDataset(7, 100)
    dataloader = DataLoader(dataset, batch_size=32)
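
    # an iterable dataset does not work with samplers, so the rewritten kwargs keep both sampler slots as None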
    _, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler)
    assert dl_kwargs["sampler"] is None
    assert dl_kwargs["batch_sampler"] is None
    assert dl_kwargs["batch_size"] is dataloader.batch_size
    assert dl_kwargs["dataset"] is dataloader.dataset
    assert dl_kwargs["collate_fn"] is dataloader.collate_fn


def test_dataloader_kwargs_replacement_with_array_default_comparison():
    """Test that the comparison of attributes and default argument values works with arrays (truth value ambiguous).

    Regression test for issue #15408.
    """
    dataset = RandomDataset(5, 100)

    class ArrayAttributeDataloader(DataLoader):
        def __init__(self, indices=None, **kwargs):
            super().__init__(dataset)
            self.indices = np.random.rand(2, 2)  # an attribute we can't compare with ==

    dataloader = ArrayAttributeDataloader(dataset)
    _, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler)
    assert dl_kwargs["indices"] is dataloader.indices


def test_set_sampler_epoch():
    # No samplers
    dataloader = Mock()
    dataloader.sampler = None
    dataloader.batch_sampler = None
    _set_sampler_epoch(dataloader, 55)

    # `set_epoch` not callable
    dataloader = Mock()
    dataloader.sampler.set_epoch = None
    dataloader.batch_sampler.set_epoch = None
    _set_sampler_epoch(dataloader, 55)

    # `set_epoch` callable
    dataloader = Mock()
    _set_sampler_epoch(dataloader, 55)
    dataloader.sampler.set_epoch.assert_called_once_with(55)
    dataloader.batch_sampler.sampler.set_epoch.assert_called_once_with(55)


@pytest.mark.parametrize(
    ("cpu_count", "local_world_size", "expected"),
    [
        (0, 1, 1),
        (1, 1, 1),
        (2, 1, 2 - 1),
        (1, 2, 1),
        (2, 2, 1),
        (3, 2, 1),
        (4, 2, 2 - 1),
        (4, 3, 1),
        (4, 1, 4 - 1),
    ],
)
@pytest.mark.parametrize(
    "affinity",
    [
        False,
        pytest.param(
            True,
            marks=pytest.mark.skipif(
                not hasattr(os, "sched_getaffinity"), reason="OS does not support restricting CPU cores"
            ),
        ),
    ],
)
@mock.patch("lightning.fabric.utilities.data.os.cpu_count")
def test_suggested_max_num_workers(cpu_count_mock, affinity, cpu_count, local_world_size, expected, monkeypatch):
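    # the expected values above follow max(1, cpu_count // local_world_size - 1):
    # one worker fewer than each process's share of the CPUs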
    if affinity:
        monkeypatch.setattr(lightning.fabric.utilities.data.os, "sched_getaffinity", lambda _: list(range(cpu_count)))
    else:
        monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
    cpu_count_mock.return_value = cpu_count

    assert suggested_max_num_workers(local_world_size) == expected
@pytest.mark.parametrize("invalid", [-1, 0])
def test_suggested_max_num_workers_input_validation(invalid):
with pytest.raises(ValueError, match="should be >= 1"):
suggested_max_num_workers(invalid)
@pytest.mark.parametrize("cpu_count", [1, 2, 3])
@pytest.mark.parametrize("local_world_size", [1, 2, 3])
def test_suggested_max_num_workers_not_triggering_torch_warning(local_world_size, cpu_count, monkeypatch):
"""Test that our suggestion for num workers doesn't trigger a warning in the DataLoader for too many workers."""
monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
monkeypatch.delattr(torch.utils.data.dataloader.os, "sched_getaffinity", raising=False)
monkeypatch.setattr(lightning.fabric.utilities.data.os, "cpu_count", lambda: cpu_count)
monkeypatch.setattr(torch.utils.data.dataloader.os, "cpu_count", lambda: cpu_count)
# The dataloader runs a check in `DataLoader.check_worker_number_rationality`
with pytest.warns(UserWarning, match="This DataLoader will create"):
DataLoader(range(2), num_workers=(cpu_count + 1))
with no_warning_call():
DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size))


def test_state():
    # init via dict
    inputs = {"key1": 1, "key2": "abc"}
    state = AttributeDict(inputs)
    for key, value in inputs.items():
        assert getattr(state, key) == value

    # init via kwargs
    inputs = {"key1": 1, "key2": "abc"}
    state = AttributeDict(**inputs)
    for key, value in inputs.items():
        assert getattr(state, key) == value

    # update via dict
    state = AttributeDict()
    state.update({"key1": 1})
    assert state.key1 == 1

    # update via setter
    state = AttributeDict({"key1": 1})
    state.key1 = 123
    assert state.key1 == 123

    with pytest.raises(AttributeError, match="has no attribute 'key3'"):
        _ = state.key3

    # delete attribute
    del state.key1
    assert "key1" not in state
    with pytest.raises(KeyError):
        del state.key3