2021-08-10 06:39:00 +00:00
|
|
|
import pytest
|
2021-07-13 11:35:10 +00:00
|
|
|
import torch
|
2021-08-10 06:39:00 +00:00
|
|
|
from torch.utils.data.dataloader import DataLoader
|
2021-07-13 11:35:10 +00:00
|
|
|
|
2021-11-02 17:22:58 +00:00
|
|
|
from pytorch_lightning import Trainer
|
|
|
|
from pytorch_lightning.utilities.data import (
|
|
|
|
extract_batch_size,
|
|
|
|
get_len,
|
|
|
|
has_iterable_dataset,
|
|
|
|
has_len,
|
|
|
|
has_len_all_ranks,
|
|
|
|
warning_cache,
|
|
|
|
)
|
|
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
2021-11-19 16:48:26 +00:00
|
|
|
from tests.deprecated_api import no_warning_call
|
2021-11-02 17:22:58 +00:00
|
|
|
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
|
2021-07-13 11:35:10 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_extract_batch_size():
    """Tests the behavior of extracting the batch size.

    Covers three outcomes of ``extract_batch_size``:
      * unambiguous batches -> size returned, no inference warning;
      * ambiguous batches (several differently-sized tensors) -> size of the first
        tensor found returned, together with a ``UserWarning``;
      * batches with no usable tensor -> ``MisconfigurationException``.
    """

    def _check_warning_not_raised(data, expected):
        # The batch size is unambiguous here, so no inference warning may be emitted.
        with no_warning_call(match="Trying to infer the `batch_size`"):
            assert extract_batch_size(data) == expected

    def _check_warning_raised(data, expected):
        # Ambiguous batch: a warning naming the inferred size is expected.
        # Uses ``data`` (not the enclosing ``batch``) so the helper checks what it is given.
        with pytest.warns(UserWarning, match=f"Trying to infer the `batch_size` .* we found is {expected}."):
            assert extract_batch_size(data) == expected
        # NOTE(review): ``warning_cache`` presumably de-duplicates warnings; clearing it lets
        # the next ambiguous case warn again.
        warning_cache.clear()

    def _check_error_raised(data):
        # Uses ``data`` (not the enclosing ``batch``); previously the dict and CustomBatch
        # cases below silently re-tested the "test string" batch instead of themselves.
        with pytest.raises(MisconfigurationException, match="We could not infer the batch_size"):
            extract_batch_size(data)

    # Warning not raised
    batch = torch.zeros(11, 10, 9, 8)
    _check_warning_not_raised(batch, 11)

    batch = {"test": torch.zeros(11, 10)}
    _check_warning_not_raised(batch, 11)

    batch = [torch.zeros(11, 10)]
    _check_warning_not_raised(batch, 11)

    batch = {"test": [{"test": [torch.zeros(11, 10)]}]}
    _check_warning_not_raised(batch, 11)

    # Warning raised
    batch = {"a": [torch.tensor(1), torch.tensor(2)], "b": torch.tensor([1, 2, 3, 4])}
    _check_warning_raised(batch, 1)

    batch = {"test": [{"test": [torch.zeros(11, 10), torch.zeros(10, 10)]}]}
    _check_warning_raised(batch, 11)

    batch = {"test": [{"test": [torch.zeros(10, 10), torch.zeros(11, 10)]}]}
    _check_warning_raised(batch, 10)

    batch = [{"test": torch.zeros(10, 10), "test_1": torch.zeros(11, 10)}]
    _check_warning_raised(batch, 10)

    # Error raised
    batch = "test string"
    _check_error_raised(batch)

    data = {"test": ["some text"] * 7}
    _check_error_raised(data)

    class CustomBatch:
        # An arbitrary object holding a tensor attribute; not a tensor or a collection.
        def __init__(self):
            self.x = torch.randn(7, 2)

    data = CustomBatch()
    _check_error_raised(data)
|
|
|
|
|
2021-08-10 06:39:00 +00:00
|
|
|
|
|
|
|
def test_has_iterable_dataset():
    """Check ``has_iterable_dataset`` for iterable, map-style and map-style-with-``__iter__`` datasets."""
    iterable_loader = DataLoader(RandomIterableDataset(1, 1))
    assert has_iterable_dataset(iterable_loader)

    map_style_loader = DataLoader(RandomDataset(1, 1))
    assert not has_iterable_dataset(map_style_loader)

    class MockDatasetWithoutIterableDataset(RandomDataset):
        # Subclasses RandomDataset (reported as not iterable above) but adds an
        # ``__iter__``; defining ``__iter__`` alone must not make it count as iterable.
        def __iter__(self):
            yield 1
            return self

    duck_typed_loader = DataLoader(MockDatasetWithoutIterableDataset(1, 1))
    assert not has_iterable_dataset(duck_typed_loader)
|
|
|
|
|
|
|
|
|
|
|
|
def test_has_len():
    """Exercise ``has_len`` on a sized loader, an empty loader and an iterable-style loader."""
    sized_loader = DataLoader(RandomDataset(1, 1))
    assert has_len(sized_loader)

    # An empty map-style dataset is rejected with a ValueError rather than reported as sized.
    # The loader is built inside the context so the checked region matches the original test.
    with pytest.raises(ValueError, match="`Dataloader` returned 0 length."):
        assert has_len(DataLoader(RandomDataset(0, 0)))

    unsized_loader = DataLoader(RandomIterableDataset(1, 1))
    assert not has_len(unsized_loader)
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_len():
    """``get_len`` yields the real length for sized loaders and ``float('inf')`` for iterable ones."""
    sized_length = get_len(DataLoader(RandomDataset(1, 1)))
    assert sized_length == 1

    iterable_length = get_len(DataLoader(RandomIterableDataset(1, 1)))
    # The infinite length must be a float (int has no infinity) and equal to +inf.
    assert isinstance(iterable_length, float)
    assert iterable_length == float("inf")
|
2021-11-02 17:22:58 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_has_len_all_rank():
    """``has_len_all_ranks`` raises when the summed length is zero and is truthy otherwise."""
    trainer = Trainer(fast_dev_run=True)
    model = BoringModel()

    empty_loader = DataLoader(RandomDataset(0, 0))
    with pytest.raises(MisconfigurationException, match="Total length of `Dataloader` across ranks is zero."):
        assert not has_len_all_ranks(empty_loader, trainer.training_type_plugin, model)

    non_empty_loader = DataLoader(RandomDataset(1, 1))
    assert has_len_all_ranks(non_empty_loader, trainer.training_type_plugin, model)
|