# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections.abc import Sequence
from unittest import mock

import pytest
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import Sampler, SequentialSampler

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.supporters import (
    _nested_calc_num_data,
    CombinedDataset,
    CombinedLoader,
    CombinedLoaderIterator,
    CycleIterator,
    TensorRunningAccum,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_tensor_running_accum_reset():
    """Test that `reset` restores all attributes to their initialization state."""

    window_length = 10

    accum = TensorRunningAccum(window_length=window_length)
    assert accum.last() is None
    assert accum.mean() is None

    accum.append(torch.tensor(1.5))
    assert accum.last() == torch.tensor(1.5)
    assert accum.mean() == torch.tensor(1.5)

    accum.reset()
    assert accum.window_length == window_length
    assert accum.memory is None
    assert accum.current_idx == 0
    assert accum.last_idx is None
    assert not accum.rotated


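# A minimal sketch of the windowing behaviour, assuming (as the `rotated` flag
# checked above suggests) that `TensorRunningAccum` keeps a circular buffer of
# `window_length` values and that `mean()` averages only what is currently in
# the window:
def test_tensor_running_accum_window_small_example():
    accum = TensorRunningAccum(window_length=2)
    for value in (1.0, 2.0, 3.0):
        accum.append(torch.tensor(value))
    assert accum.last() == torch.tensor(3.0)
    # only the last two values (2.0 and 3.0) should still be in the window
    assert accum.mean() == torch.tensor(2.5)

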
def test_cycle_iterator():
    """Test the cycling function of `CycleIterator`."""

    iterator = CycleIterator(range(100), 1000)
    assert len(iterator) == 1000
    for idx, item in enumerate(iterator):
        assert item < 100

    # the loop produced exactly `length` items
    assert idx == len(iterator) - 1


def test_none_length_cycle_iterator():
    """Test the infinite cycling function of `CycleIterator`."""
    iterator = CycleIterator(range(100))
    # call `__len__` directly: `len()` would raise a TypeError on a float return value
    assert iterator.__len__() == float("inf")

    # test infinite loop
    for idx, item in enumerate(iterator):
        if idx == 1000:
            break
    # 1000 % 100 == 0, so the iterator has just wrapped back to the first element
    assert item == 0


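# A minimal sketch of the finite-length cycling, assuming (as the two tests
# above do) that `CycleIterator(iterable, length)` restarts the iterable from
# the beginning until exactly `length` items have been produced:
def test_cycle_iterator_small_example():
    iterator = CycleIterator(range(3), 7)
    assert list(iterator) == [0, 1, 2, 0, 1, 2, 0]

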
@pytest.mark.parametrize(
    ["dataset_1", "dataset_2"],
    [
        ([list(range(10)), list(range(20))]),
        ([range(10), range(20)]),
        ([torch.randn(10, 3, 2), torch.randn(20, 5, 6)]),
        ([TensorDataset(torch.randn(10, 3, 2)), TensorDataset(torch.randn(20, 5, 6))]),
    ],
)
def test_combined_dataset(dataset_1, dataset_2):
    """Verify the length of the CombinedDataset."""
    datasets = [dataset_1, dataset_2]
    combined_dataset = CombinedDataset(datasets)

    assert combined_dataset.max_len == 20
    assert combined_dataset.min_len == len(combined_dataset) == 10


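# A minimal sketch of the length semantics asserted above, assuming
# `CombinedDataset` defaults to reporting the shortest dataset as its `len()`:
def test_combined_dataset_small_example():
    combined = CombinedDataset([range(3), range(5)])
    assert combined.max_len == 5
    assert combined.min_len == len(combined) == 3

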
def test_combined_dataset_length_mode_error():
    """Test that computing the length with an invalid mode raises `MisconfigurationException`."""
    dset = CombinedDataset([range(10)])
    with pytest.raises(MisconfigurationException, match="Invalid Mode"):
        dset._calc_num_data([range(10)], "test")


def test_combined_loader_iterator_dict_min_size():
    """Test `CombinedLoaderIterator` given mapping loaders."""
    loaders = {
        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
    }

    combined_iter = CombinedLoaderIterator(loaders)

    for idx, item in enumerate(combined_iter):
        assert isinstance(item, dict)
        assert len(item) == 2
        assert "a" in item and "b" in item

    # iteration stops as soon as the shortest loader is exhausted
    assert idx == min(len(loaders["a"]), len(loaders["b"])) - 1


def test_combined_loader_init_mode_error():
    """Test that constructing `CombinedLoader` with an invalid mode raises `MisconfigurationException`."""
    with pytest.raises(MisconfigurationException, match="Invalid Mode"):
        CombinedLoader([range(10)], "testtt")


def test_combined_loader_loader_type_error():
    """Test the TypeError when wrapping loaders of an unsupported type."""
    with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"):
        CombinedLoader(None, "max_size_cycle")


def test_combined_loader_calc_length_mode_error():
    """Test the TypeError when calculating the number of batches for an unsupported type."""
    with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"):
        CombinedLoader._calc_num_batches(None)


def test_combined_loader_dict_min_size():
    """Test `CombinedLoader` of mode 'min_size' given mapping loaders."""
    loaders = {
        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
    }

    combined_loader = CombinedLoader(loaders, "min_size")

    assert len(combined_loader) == min(len(v) for v in loaders.values())

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, dict)
        assert len(item) == 2
        assert "a" in item and "b" in item

    assert idx == len(combined_loader) - 1


def test_combined_loader_dict_max_size_cycle():
    """Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders."""
    loaders = {
        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
    }

    combined_loader = CombinedLoader(loaders, "max_size_cycle")

    assert len(combined_loader) == max(len(v) for v in loaders.values())

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, dict)
        assert len(item) == 2
        assert "a" in item and "b" in item

    assert idx == len(combined_loader) - 1


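# A minimal sketch contrasting the two modes on the same loaders, assuming the
# length semantics asserted in the two tests above ('min_size' stops with the
# shortest loader, 'max_size_cycle' runs until the longest one is exhausted):
def test_combined_loader_dict_modes_small_example():
    loaders = {
        "short": DataLoader(range(4), batch_size=2),  # 2 batches
        "long": DataLoader(range(10), batch_size=2),  # 5 batches
    }
    assert len(CombinedLoader(loaders, "min_size")) == 2
    assert len(CombinedLoader(loaders, "max_size_cycle")) == 5

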
def test_combined_loader_sequence_min_size():
    """Test `CombinedLoader` of mode 'min_size' given sequence loaders."""
    loaders = [
        torch.utils.data.DataLoader(range(10), batch_size=4),
        torch.utils.data.DataLoader(range(20), batch_size=5),
    ]

    combined_loader = CombinedLoader(loaders, "min_size")

    assert len(combined_loader) == min(len(v) for v in loaders)

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, Sequence)
        assert len(item) == 2

    assert idx == len(combined_loader) - 1


# NOTE: despite the `Test` name prefix, pytest skips collecting this helper class
# because it defines `__init__`.
class TestIterableDataset(IterableDataset):
    def __init__(self, size: int = 10):
        self.size = size

    def __iter__(self):
        self.sampler = SequentialSampler(range(self.size))
        self.sampler_iter = iter(self.sampler)
        return self

    def __next__(self):
        return next(self.sampler_iter)


@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle"])
|
|
@pytest.mark.parametrize("use_multiple_dataloaders", [False, True])
|
|
def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders):
|
|
"""Test `CombinedLoader` of mode 'min_size' given sequence loaders."""
|
|
if use_multiple_dataloaders:
|
|
loaders = [
|
|
torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
|
|
torch.utils.data.DataLoader(TestIterableDataset(20), batch_size=2),
|
|
]
|
|
else:
|
|
loaders = [
|
|
torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
|
|
]
|
|
|
|
combined_loader = CombinedLoader(loaders, mode)
|
|
|
|
has_break = False
|
|
|
|
for idx, item in enumerate(combined_loader):
|
|
assert isinstance(item, Sequence)
|
|
assert len(item) == 2 if use_multiple_dataloaders else 1
|
|
if not use_multiple_dataloaders and idx == 4:
|
|
has_break = True
|
|
break
|
|
|
|
if mode == "max_size_cycle":
|
|
assert combined_loader.loaders[0].state.done == (not has_break)
|
|
expected = (10 if mode == "max_size_cycle" else 5) if use_multiple_dataloaders else 5
|
|
assert (expected - 1) == idx, (mode, use_multiple_dataloaders)
|
|
|
|
|
|
@pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]])
|
|
def test_combined_loader_sequence_with_map_and_iterable(lengths):
|
|
class MyIterableDataset(IterableDataset):
|
|
def __init__(self, size: int = 10):
|
|
self.size = size
|
|
|
|
def __iter__(self):
|
|
self.sampler = SequentialSampler(range(self.size))
|
|
self.iter_sampler = iter(self.sampler)
|
|
return self
|
|
|
|
def __next__(self):
|
|
return next(self.iter_sampler)
|
|
|
|
class MyMapDataset(Dataset):
|
|
def __init__(self, size: int = 10):
|
|
self.size = size
|
|
|
|
def __getitem__(self, index):
|
|
return index
|
|
|
|
def __len__(self):
|
|
return self.size
|
|
|
|
x, y = lengths
|
|
loaders = [DataLoader(MyIterableDataset(x)), DataLoader(MyMapDataset(y))]
|
|
dataloader = CombinedLoader(loaders, mode="max_size_cycle")
|
|
counter = 0
|
|
for _ in dataloader:
|
|
counter += 1
|
|
assert counter == max(x, y)
|
|
|
|
|
|
def test_combined_loader_sequence_max_size_cycle():
    """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders."""
    loaders = [
        torch.utils.data.DataLoader(range(10), batch_size=4),
        torch.utils.data.DataLoader(range(20), batch_size=5),
    ]

    combined_loader = CombinedLoader(loaders, "max_size_cycle")

    assert len(combined_loader) == max(len(v) for v in loaders)

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, Sequence)
        assert len(item) == 2

    assert idx == len(combined_loader) - 1


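# A minimal sketch of the cycling itself, assuming (as the `CycleIterator`
# tests above show) that an exhausted shorter loader restarts from its first
# batch while the longer one keeps going:
def test_combined_loader_sequence_cycle_small_example():
    loaders = [
        DataLoader(range(2), batch_size=1),  # batches: 0, 1
        DataLoader(range(4), batch_size=1),  # batches: 0, 1, 2, 3
    ]
    combined_loader = CombinedLoader(loaders, "max_size_cycle")
    first = [int(item[0]) for item in combined_loader]
    # the shorter loader wraps around after two steps
    assert first == [0, 1, 0, 1]

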
@pytest.mark.parametrize(
    ["input_data", "compute_func", "expected_length"],
    [
        ([*range(10), list(range(1, 20))], min, 0),
        ([*range(10), list(range(1, 20))], max, 19),
        ([*range(10), {str(i): i for i in range(1, 20)}], min, 0),
        ([*range(10), {str(i): i for i in range(1, 20)}], max, 19),
        ({**{str(i): i for i in range(10)}, "nested": {str(i): i for i in range(1, 20)}}, min, 0),
        ({**{str(i): i for i in range(10)}, "nested": {str(i): i for i in range(1, 20)}}, max, 19),
        ({**{str(i): i for i in range(10)}, "nested": list(range(20))}, min, 0),
        ({**{str(i): i for i in range(10)}, "nested": list(range(20))}, max, 19),
    ],
)
def test_nested_calc_num_data(input_data, compute_func, expected_length):
    calculated_length = _nested_calc_num_data(input_data, compute_func)

    assert calculated_length == expected_length


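# A minimal sketch of what the parametrization above exercises, assuming
# `_nested_calc_num_data` recurses through nested sequences/mappings and
# reduces the integer leaves with `compute_func`:
def test_nested_calc_num_data_small_example():
    assert _nested_calc_num_data({"a": 3, "b": [7, 2]}, max) == 7
    assert _nested_calc_num_data({"a": 3, "b": [7, 2]}, min) == 2

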
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"})
@mock.patch("torch.cuda.device_count", return_value=2)
@mock.patch("torch.cuda.is_available", return_value=True)
def test_combined_data_loader_validation_test(cuda_available_mock, device_count_mock, tmpdir):
    """Make sure a `DistributedSampler` has been properly injected into every dataloader when using
    `CombinedLoader`."""

    class CustomDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    dataloader = CombinedLoader(
        {
            "a": DataLoader(CustomDataset(range(10))),
            "b": {"c": DataLoader(CustomDataset(range(10))), "d": DataLoader(CustomDataset(range(10)))},
            "e": [DataLoader(CustomDataset(range(10))), DataLoader(CustomDataset(range(10)))],
        }
    )

    trainer = Trainer(replace_sampler_ddp=True, accelerator="ddp", gpus=2)
    dataloader = trainer.auto_add_sampler(dataloader, shuffle=True)
    _count = 0

    def _assert_distributed_sampler(v):
        nonlocal _count
        _count += 1
        assert isinstance(v, DistributedSampler)

    apply_to_collection(dataloader.sampler, Sampler, _assert_distributed_sampler)
    # five dataloaders in the nested collection: "a", "b.c", "b.d", and the two entries in "e"
    assert _count == 5