"""Tests for the dataloader helpers in ``pytorch_lightning.utilities.data``."""
import warnings

import pytest
import torch
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.data import (
    extract_batch_size,
    get_len,
    has_iterable_dataset,
    has_len,
    has_len_all_ranks,
    warning_cache,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset


def test_extract_batch_size():
    """Tests the behavior of extracting the batch size.

    Unambiguous batches (a string, a single tensor, or nested containers holding a
    single tensor) must return the size without warning. Batches whose tensors
    disagree on their first dimension must emit a ``UserWarning`` that names the
    size that was picked.
    """

    def _check_warning_not_raised(data, expected):
        # ``pytest.warns(None)`` is deprecated in pytest 7 and removed in pytest 8;
        # record every warning explicitly and assert that none was emitted.
        with warnings.catch_warnings(record=True) as record:
            warnings.simplefilter("always")
            assert extract_batch_size(data) == expected
        assert len(record) == 0

    def _check_warning_raised(data, expected):
        # Check the argument we were given (``data``), not the enclosing ``batch``.
        with pytest.warns(UserWarning, match=f"Trying to infer the `batch_size` .* we found is {expected}."):
            assert extract_batch_size(data) == expected
        # Clear the shared ``warning_cache`` so the next ambiguous batch can warn
        # again — presumably the cache suppresses repeated messages; confirm in
        # ``pytorch_lightning.utilities.data``.
        warning_cache.clear()

    batch = "test string"
    _check_warning_not_raised(batch, 11)

    batch = torch.zeros(11, 10, 9, 8)
    _check_warning_not_raised(batch, 11)

    batch = {"test": torch.zeros(11, 10)}
    _check_warning_not_raised(batch, 11)

    batch = [torch.zeros(11, 10)]
    _check_warning_not_raised(batch, 11)

    batch = {"test": [{"test": [torch.zeros(11, 10)]}]}
    _check_warning_not_raised(batch, 11)

    # Ambiguous batches: the first tensor encountered determines the reported size.
    batch = {"test": [{"test": [torch.zeros(11, 10), torch.zeros(10, 10)]}]}
    _check_warning_raised(batch, 11)

    batch = {"test": [{"test": [torch.zeros(10, 10), torch.zeros(11, 10)]}]}
    _check_warning_raised(batch, 10)

    batch = [{"test": torch.zeros(10, 10), "test_1": torch.zeros(11, 10)}]
    _check_warning_raised(batch, 10)


def test_has_iterable_dataset():
    """Only dataloaders over an iterable-style dataset are reported as iterable.

    A map-style dataset that merely defines ``__iter__`` must not count.
    """
    assert has_iterable_dataset(DataLoader(RandomIterableDataset(1, 1)))

    assert not has_iterable_dataset(DataLoader(RandomDataset(1, 1)))

    class MockDatasetWithoutIterableDataset(RandomDataset):
        def __iter__(self):
            yield 1

    assert not has_iterable_dataset(DataLoader(MockDatasetWithoutIterableDataset(1, 1)))


def test_has_len():
    """A sized dataloader has a length; an empty one raises; an iterable one has none."""
    assert has_len(DataLoader(RandomDataset(1, 1)))

    with pytest.raises(ValueError, match="`Dataloader` returned 0 length."):
        has_len(DataLoader(RandomDataset(0, 0)))

    assert not has_len(DataLoader(RandomIterableDataset(1, 1)))


def test_get_len():
    """``get_len`` returns ``len`` for sized dataloaders and ``float('inf')`` otherwise."""
    assert get_len(DataLoader(RandomDataset(1, 1))) == 1

    value = get_len(DataLoader(RandomIterableDataset(1, 1)))

    assert isinstance(value, float)
    assert value == float("inf")


def test_has_len_all_rank():
    """A dataloader whose total length across ranks is zero is a misconfiguration."""
    trainer = Trainer(fast_dev_run=True)
    model = BoringModel()

    with pytest.raises(MisconfigurationException, match="Total length of `Dataloader` across ranks is zero."):
        has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.training_type_plugin, model)

    assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model)