lightning/tests/tests_pytorch/utilities/test_data.py

302 lines
12 KiB
Python

from dataclasses import dataclass
import numpy as np
import pytest
import torch
from torch import Tensor
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from lightning_lite.utilities.data import _replace_dunder_methods
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.data import (
_dataloader_init_kwargs_resolve_sampler,
_get_dataloader_init_args_and_kwargs,
_update_dataloader,
extract_batch_size,
get_len,
has_len_all_ranks,
warning_cache,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.utils import no_warning_call
def test_extract_batch_size():
    """Check `extract_batch_size` on unambiguous, ambiguous and unsupported batches.

    Three groups of inputs are exercised:

    * batches for which the "Trying to infer the `batch_size`" warning must NOT be emitted,
    * batches for which that warning must be emitted together with the reported size,
    * batches for which a ``MisconfigurationException`` must be raised.
    """

    def _check_warning_not_raised(data, expected):
        """Assert that ``data`` yields ``expected`` without the inference warning."""
        with no_warning_call(match="Trying to infer the `batch_size`"):
            assert extract_batch_size(data) == expected

    def _check_warning_raised(data, expected):
        """Assert that ``data`` yields ``expected`` and the inference warning mentions it."""
        with pytest.warns(UserWarning, match=f"Trying to infer the `batch_size` .* we found is {expected}."):
            # Use the argument itself; reading the enclosing `batch` variable would make the
            # helper silently depend on how the caller happened to name its input.
            assert extract_batch_size(data) == expected
        # Presumably the cache de-duplicates warnings; clear it so the next check can observe one again.
        warning_cache.clear()

    def _check_error_raised(data):
        """Assert that no batch size can be inferred from ``data``."""
        with pytest.raises(MisconfigurationException, match="We could not infer the batch_size"):
            # Same as above: check the value that was passed in, not a stale outer variable.
            extract_batch_size(data)

    @dataclass
    class CustomDataclass:
        a: Tensor
        b: Tensor

    # Warning not raised
    batch = torch.zeros(11, 10, 9, 8)
    _check_warning_not_raised(batch, 11)

    batch = {"test": torch.zeros(11, 10)}
    _check_warning_not_raised(batch, 11)

    batch = [torch.zeros(11, 10)]
    _check_warning_not_raised(batch, 11)

    batch = CustomDataclass(torch.zeros(11, 10), torch.zeros(11, 10))
    _check_warning_not_raised(batch, 11)

    batch = {"test": [{"test": [torch.zeros(11, 10)]}]}
    _check_warning_not_raised(batch, 11)

    # Warning raised
    batch = {"a": [torch.tensor(1), torch.tensor(2)], "b": torch.tensor([1, 2, 3, 4])}
    _check_warning_raised(batch, 1)

    batch = CustomDataclass(torch.zeros(11, 10), torch.zeros(1))
    _check_warning_raised(batch, 11)

    batch = {"test": [{"test": [torch.zeros(11, 10), torch.zeros(10, 10)]}]}
    _check_warning_raised(batch, 11)

    batch = {"test": [{"test": [torch.zeros(10, 10), torch.zeros(11, 10)]}]}
    _check_warning_raised(batch, 10)

    batch = [{"test": torch.zeros(10, 10), "test_1": torch.zeros(11, 10)}]
    _check_warning_raised(batch, 10)

    # Error raised
    batch = "test string"
    _check_error_raised(batch)

    data = {"test": ["some text"] * 7}
    _check_error_raised(data)

    class CustomBatch:
        def __init__(self):
            self.x = torch.randn(7, 2)

    data = CustomBatch()
    _check_error_raised(data)
def test_get_len():
    """`get_len` gives 1 for a `RandomDataset(1, 1)` loader and float infinity for the iterable variant."""
    map_style_loader = DataLoader(RandomDataset(1, 1))
    assert get_len(map_style_loader) == 1

    iterable_loader = DataLoader(RandomIterableDataset(1, 1))
    length = get_len(iterable_loader)
    # The result must be a float, and specifically positive infinity.
    assert isinstance(length, float)
    assert length == float("inf")
def test_has_len_all_rank():
    """`has_len_all_ranks` is truthy for both loaders and warns when the total length is zero."""
    trainer = Trainer(fast_dev_run=True)
    model = BoringModel()
    strategy = trainer.strategy

    # A loader built on `RandomDataset(0, 0)` must trigger the zero-length warning.
    with pytest.warns(UserWarning, match="Total length of `DataLoader` across ranks is zero."):
        empty_result = has_len_all_ranks(DataLoader(RandomDataset(0, 0)), strategy, model)
        assert empty_result

    regular_result = has_len_all_ranks(DataLoader(RandomDataset(1, 1)), strategy, model)
    assert regular_result
def test_update_dataloader_typerror_custom_exception():
    """Check `_update_dataloader` on `DataLoader` subclasses whose `__init__` clashes with base arguments.

    Loaders with a positional (`dataset`) or keyword (`shuffle`) conflict must raise a
    ``MisconfigurationException`` naming the offending argument. The positional-conflict loader can still be
    re-created when it was built inside ``_replace_dunder_methods``, and so can the loader that filters its
    keyword arguments.
    """

    class BadStandaloneGoodHookImpl(DataLoader):
        def __init__(self, foo, *args, **kwargs):
            self.foo = foo
            # positional conflict with `dataset`
            super().__init__(foo, *args, **kwargs)

    loader = BadStandaloneGoodHookImpl([1, 2, 3])
    with pytest.raises(MisconfigurationException, match="implementation has an error.*`dataset`"):
        _update_dataloader(loader, loader.sampler)

    # Built inside the context manager, the very same class can be re-instantiated.
    with _replace_dunder_methods(DataLoader, "dataset"):
        loader = BadStandaloneGoodHookImpl([1, 2, 3])
    rebuilt = _update_dataloader(loader, loader.sampler)
    assert isinstance(rebuilt, BadStandaloneGoodHookImpl)

    class BadImpl(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            self.randomize = randomize
            # keyword conflict with `shuffle`
            super().__init__(*args, shuffle=randomize, **kwargs)

    loader = BadImpl(False, [])
    with pytest.raises(MisconfigurationException, match="implementation has an error.*`shuffle`"):
        _update_dataloader(loader, loader.sampler)

    class GoodImpl(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            # fixed implementation, kwargs are filtered
            # NOTE(review): because of `or` short-circuiting, `shuffle` is only popped when `randomize` is falsy.
            self.randomize = randomize or kwargs.pop("shuffle", False)
            super().__init__(*args, shuffle=randomize, **kwargs)

    loader = GoodImpl(False, [])
    rebuilt = _update_dataloader(loader, loader.sampler)
    assert isinstance(rebuilt, GoodImpl)
@pytest.mark.parametrize("predicting", [True, False])
def test_custom_batch_sampler(predicting):
    """Verify that a custom `BatchSampler` carrying every argument needed for re-instantiation is rebuilt
    correctly by `_update_dataloader`.

    The test also verifies that the rebuilt sampler no longer carries the `__pl_saved_{args,arg_names,kwargs}`
    attributes, i.e. that the `__init__` wrapper was not active during re-instantiation.
    """

    class MyBatchSampler(BatchSampler):
        # Custom Batch sampler with extra argument and default value
        def __init__(self, sampler, extra_arg, drop_last=True):
            self.extra_arg = extra_arg
            super().__init__(sampler, 10, drop_last)

    base_sampler = RandomSampler(range(10))
    with _replace_dunder_methods(BatchSampler):
        # instantiate within `_replace_dunder_method` context manager, simulating `*_dataloader` hooks
        custom_sampler = MyBatchSampler(base_sampler, "random_str")

    loader = DataLoader(range(10), batch_sampler=custom_sampler)

    # The constructor call must have been recorded on the batch sampler.
    assert loader.batch_sampler.__pl_saved_args == (base_sampler, "random_str")
    assert loader.batch_sampler.__pl_saved_kwargs == {}
    assert loader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg")
    assert loader.batch_sampler.__pl_saved_default_kwargs == {"drop_last": True}

    # updating dataloader, what happens on access of the dataloaders.
    # This should not fail, and would fail before support for custom args.
    if predicting:
        stage = RunningStage.PREDICTING
    else:
        stage = None
    loader = _update_dataloader(loader, loader.sampler, mode=stage)

    # Assert the `__init__` method is not replaced anymore and everything is instantiated to correct types
    rebuilt_sampler = loader.batch_sampler

    if predicting:
        # When predicting, the custom sampler is expected to sit inside an index-tracking wrapper.
        assert isinstance(rebuilt_sampler, IndexBatchSamplerWrapper)
        rebuilt_sampler = rebuilt_sampler._sampler

    assert isinstance(rebuilt_sampler, MyBatchSampler)
    assert rebuilt_sampler.drop_last == (not predicting)
    assert rebuilt_sampler.extra_arg == "random_str"

    for saved_attr in (
        "__pl_saved_kwargs",
        "__pl_saved_arg_names",
        "__pl_saved_args",
        "__pl_saved_default_kwargs",
    ):
        assert not hasattr(rebuilt_sampler, saved_attr)
def test_custom_batch_sampler_no_drop_last():
    """Verify that a warning is emitted when `drop_last` has to be reset on a custom `BatchSampler` whose
    constructor does not accept it."""

    class MyBatchSampler(BatchSampler):
        # Custom batch sampler with extra argument, but without `drop_last`
        def __init__(self, sampler, extra_arg):
            self.extra_arg = extra_arg
            super().__init__(sampler, 10, False)

    base_sampler = RandomSampler(range(10))
    with _replace_dunder_methods(BatchSampler):
        # instantiate within `_replace_dunder_method` context manager, simulating `*_dataloader` hooks
        custom_sampler = MyBatchSampler(base_sampler, "random_str")

    loader = DataLoader(range(10), batch_sampler=custom_sampler)

    # The constructor call must have been recorded; there are no defaulted keyword arguments here.
    assert loader.batch_sampler.__pl_saved_args == (base_sampler, "random_str")
    assert loader.batch_sampler.__pl_saved_kwargs == {}
    assert loader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg")
    assert loader.batch_sampler.__pl_saved_default_kwargs == {}

    # Re-creating the loader for prediction must warn about `drop_last=False`.
    with pytest.warns(UserWarning, match="drop_last=False"):
        loader = _update_dataloader(loader, loader.sampler, mode=RunningStage.PREDICTING)
def test_custom_batch_sampler_no_sampler():
    """Verify that a `TypeError` is raised when the custom `BatchSampler` constructor takes no sampler
    argument."""

    class MyBatchSampler(BatchSampler):
        # Custom batch sampler, without sampler argument.
        def __init__(self, extra_arg):
            self.extra_arg = extra_arg
            super().__init__(RandomSampler(range(10)), 10, False)

    with _replace_dunder_methods(BatchSampler):
        # instantiate within `_replace_dunder_method` context manager, simulating `*_dataloader` hooks
        custom_sampler = MyBatchSampler("random_str")

    loader = DataLoader(range(10), batch_sampler=custom_sampler)

    # Only the single positional argument must have been recorded.
    assert loader.batch_sampler.__pl_saved_args == ("random_str",)
    assert loader.batch_sampler.__pl_saved_kwargs == {}
    assert loader.batch_sampler.__pl_saved_arg_names == ("extra_arg",)
    assert loader.batch_sampler.__pl_saved_default_kwargs == {}

    # Re-creating the loader for prediction must fail with the dedicated error.
    with pytest.raises(TypeError, match="sampler into the batch sampler"):
        loader = _update_dataloader(loader, loader.sampler, mode=RunningStage.PREDICTING)
def test_dataloader_disallow_batch_sampler():
    """With `disallow_batch_sampler=True`, only a user-supplied batch sampler is rejected."""
    plain_dataset = RandomDataset(5, 100)
    plain_loader = DataLoader(plain_dataset, batch_size=10)

    # This should not raise
    _dataloader_init_kwargs_resolve_sampler(plain_loader, plain_loader.sampler, disallow_batch_sampler=True)

    custom_dataset = RandomDataset(5, 100)
    explicit_batch_sampler = BatchSampler(SequentialSampler(custom_dataset), batch_size=10, drop_last=False)
    custom_loader = DataLoader(custom_dataset, batch_sampler=explicit_batch_sampler)

    # this should raise - using batch sampler, that was not automatically instantiated by DataLoader
    with pytest.raises(MisconfigurationException, match="when running on multiple IPU devices"):
        _dataloader_init_kwargs_resolve_sampler(custom_loader, custom_loader.sampler, disallow_batch_sampler=True)
@pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING])
def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
    """Check that the `DataLoader` kwargs are left untouched for an iterable dataset."""
    iterable_dataset = RandomIterableDataset(7, 100)
    loader = DataLoader(iterable_dataset, batch_size=32)

    _, resolved_kwargs = _get_dataloader_init_args_and_kwargs(loader, loader.sampler, mode=mode)

    # No sampler of either kind may be injected ...
    assert resolved_kwargs["sampler"] is None
    assert resolved_kwargs["batch_sampler"] is None
    # ... and the remaining entries must be the loader's own objects.
    assert resolved_kwargs["batch_size"] is loader.batch_size
    assert resolved_kwargs["dataset"] is loader.dataset
    assert resolved_kwargs["collate_fn"] is loader.collate_fn
def test_dataloader_kwargs_replacement_with_array_default_comparison():
    """Check that comparing attributes against default argument values copes with arrays, whose truth value
    is ambiguous.

    Regression test for issue #15408.
    """
    dataset = RandomDataset(5, 100)

    class ArrayAttributeDataloader(DataLoader):
        def __init__(self, indices=None, **kwargs):
            super().__init__(dataset)
            # an attribute we can't compare with ==
            self.indices = np.random.rand(2, 2)

    loader = ArrayAttributeDataloader(dataset)
    _, resolved_kwargs = _get_dataloader_init_args_and_kwargs(loader, loader.sampler)

    # The very same array object must be forwarded as the `indices` keyword argument.
    assert resolved_kwargs["indices"] is loader.indices