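# File: lightning/tests/utilities/test_data.py
#
# Page-header residue from the web code viewer this file was copied from
# (95 lines, 3.0 KiB, Python; "Raw / Normal View / History" buttons).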

import pytest
import torch
from torch.utils.data.dataloader import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.data import (
extract_batch_size,
get_len,
has_iterable_dataset,
has_len,
has_len_all_ranks,
warning_cache,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
def test_extract_batch_size():
    """Tests the behavior of extracting the batch size.

    The cases below expect ``extract_batch_size`` to report the length of a
    string batch, or the first dimension of the first tensor encountered in a
    (possibly nested) collection. When the collection holds tensors whose
    first dimensions disagree, a ``UserWarning`` naming the chosen size is
    expected as well.
    """
    # Function-scope import so this test does not depend on a file-level import.
    import warnings

    def _check_warning_not_raised(data, expected):
        """Assert ``data`` yields ``expected`` without emitting any warning."""
        # `pytest.warns(None)` is deprecated and no longer accepted by recent
        # pytest releases; record warnings with the standard library instead.
        with warnings.catch_warnings(record=True) as record:
            warnings.simplefilter("always")
            assert extract_batch_size(data) == expected
        assert len(record) == 0

    def _check_warning_raised(data, expected):
        """Assert ``data`` yields ``expected`` and emits the ambiguity warning."""
        with pytest.warns(UserWarning, match=f"Trying to infer the `batch_size` .* we found is {expected}."):
            # Use the `data` argument, not the enclosing `batch` variable, so
            # the helper checks exactly what it was handed.
            assert extract_batch_size(data) == expected
        # Reset the cache so the next case can observe the warning again.
        warning_cache.clear()

    # --- Unambiguous batches: no warning expected ---------------------------
    batch = "test string"
    _check_warning_not_raised(batch, 11)

    batch = torch.zeros(11, 10, 9, 8)
    _check_warning_not_raised(batch, 11)

    batch = {"test": torch.zeros(11, 10)}
    _check_warning_not_raised(batch, 11)

    batch = [torch.zeros(11, 10)]
    _check_warning_not_raised(batch, 11)

    batch = {"test": [{"test": [torch.zeros(11, 10)]}]}
    _check_warning_not_raised(batch, 11)

    # --- Ambiguous batches: first size found wins, with a warning -----------
    batch = {"test": [{"test": [torch.zeros(11, 10), torch.zeros(10, 10)]}]}
    _check_warning_raised(batch, 11)

    batch = {"test": [{"test": [torch.zeros(10, 10), torch.zeros(11, 10)]}]}
    _check_warning_raised(batch, 10)

    batch = [{"test": torch.zeros(10, 10), "test_1": torch.zeros(11, 10)}]
    _check_warning_raised(batch, 10)
def test_has_iterable_dataset():
    """Check which dataloaders ``has_iterable_dataset`` flags as iterable.

    A loader over ``RandomIterableDataset`` must be detected. Loaders over
    ``RandomDataset`` must not be, even for a subclass that defines
    ``__iter__``.
    """
    iterable_loader = DataLoader(RandomIterableDataset(1, 1))
    assert has_iterable_dataset(iterable_loader)

    map_style_loader = DataLoader(RandomDataset(1, 1))
    assert not has_iterable_dataset(map_style_loader)

    class MockDatasetWithoutIterableDataset(RandomDataset):
        """A ``RandomDataset`` subclass that merely happens to define ``__iter__``."""

        def __iter__(self):
            yield 1
            return self

    mock_loader = DataLoader(MockDatasetWithoutIterableDataset(1, 1))
    assert not has_iterable_dataset(mock_loader)
def test_has_len():
    """Check ``has_len`` on sized, empty and iterable-style dataloaders.

    The sized loader must be reported as having a length, the zero-length
    loader must raise ``ValueError``, and the loader over
    ``RandomIterableDataset`` must be reported as having no length.
    """
    sized_loader = DataLoader(RandomDataset(1, 1))
    assert has_len(sized_loader)

    zero_length_message = "`Dataloader` returned 0 length."
    with pytest.raises(ValueError, match=zero_length_message):
        # The loader is built inside the context, as in the original test.
        assert has_len(DataLoader(RandomDataset(0, 0)))

    iterable_loader = DataLoader(RandomIterableDataset(1, 1))
    assert not has_len(iterable_loader)
def test_get_len():
    """Check ``get_len`` on a sized loader and on an iterable-style loader.

    The sized loader must yield ``1``; the loader over
    ``RandomIterableDataset`` must yield a float equal to infinity.
    """
    sized_loader = DataLoader(RandomDataset(1, 1))
    assert get_len(sized_loader) == 1

    iterable_loader = DataLoader(RandomIterableDataset(1, 1))
    unsized_length = get_len(iterable_loader)

    assert isinstance(unsized_length, float)
    assert unsized_length == float("inf")
def test_has_len_all_rank():
    """Check ``has_len_all_ranks`` on an empty and a non-empty dataloader.

    The empty loader must raise ``MisconfigurationException``; the loader with
    one sample must be reported as having a length.
    """
    trainer = Trainer(fast_dev_run=True)
    model = BoringModel()

    zero_total_message = "Total length of `Dataloader` across ranks is zero."
    with pytest.raises(MisconfigurationException, match=zero_total_message):
        # The loader is built inside the context, as in the original test.
        assert not has_len_all_ranks(
            DataLoader(RandomDataset(0, 0)),
            trainer.training_type_plugin,
            model,
        )

    non_empty_loader = DataLoader(RandomDataset(1, 1))
    assert has_len_all_ranks(non_empty_loader, trainer.training_type_plugin, model)