lightning/tests/trainer/test_supporters.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections.abc import Sequence
from unittest import mock

import pytest
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import Sampler

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.supporters import (
    _nested_calc_num_data,
    CombinedDataset,
    CombinedLoader,
    CombinedLoaderIterator,
    CycleIterator,
    prefetch_iterator,
    TensorRunningAccum,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_tensor_running_accum_reset():
    """Test that `reset` restores all attributes to their initialization state"""
window_length = 10
accum = TensorRunningAccum(window_length=window_length)
assert accum.last() is None
assert accum.mean() is None
accum.append(torch.tensor(1.5))
assert accum.last() == torch.tensor(1.5)
assert accum.mean() == torch.tensor(1.5)
accum.reset()
assert accum.window_length == window_length
assert accum.memory is None
assert accum.current_idx == 0
assert accum.last_idx is None
assert not accum.rotated
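

# Hypothetical follow-up sketch, assuming `TensorRunningAccum` keeps a ring buffer
# of the last `window_length` values: once the buffer overflows, the mean should
# cover only the most recent entries.
def test_tensor_running_accum_window():
    accum = TensorRunningAccum(window_length=4)
    for i in range(1, 7):  # append 1.0 .. 6.0 into a window of size 4
        accum.append(torch.tensor(float(i)))
    assert accum.last() == torch.tensor(6.0)
    # the window now holds 3.0, 4.0, 5.0 and 6.0 -> mean 4.5
    assert accum.mean() == torch.tensor(4.5)
    assert accum.rotated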


def test_cycle_iterator():
"""Test the cycling function of `CycleIterator`"""
iterator = CycleIterator(range(100), 1000)
assert len(iterator) == 1000
for idx, item in enumerate(iterator):
assert item < 100
assert idx == len(iterator) - 1
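

# Hypothetical follow-up sketch, assuming `CycleIterator` restarts the underlying
# iterable from the beginning until the requested length is reached.
def test_cycle_iterator_wraps_around():
    iterator = CycleIterator(range(3), 7)
    assert list(iterator) == [0, 1, 2, 0, 1, 2, 0]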


def test_none_length_cycle_iterator():
"""Test the infinite cycling function of `CycleIterator`"""
iterator = CycleIterator(range(100))
assert iterator.__len__() == float("inf")
# test infinite loop
for idx, item in enumerate(iterator):
if idx == 1000:
break
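    # after 1000 steps the iterator has wrapped around ten times: 1000 % 100 == 0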
assert item == 0


def test_prefetch_iterator():
    """Test the prefetch_iterator with PyTorch IterableDataset."""

    class IterDataset(IterableDataset):
        def __iter__(self):
            yield 1
            yield 2
            yield 3

    dataset = IterDataset()
    iterator = prefetch_iterator(dataset)
    # each element is paired with a flag telling whether it is the last one
    assert list(iterator) == [(1, False), (2, False), (3, True)]

    class EmptyIterDataset(IterableDataset):
        def __iter__(self):
            return iter([])

    dataset = EmptyIterDataset()
    iterator = prefetch_iterator(dataset)
    assert list(iterator) == []
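

# Usage sketch (hypothetical test): the `is_last` flag lets a consumer run
# end-of-epoch logic inside the loop without peeking ahead at the next element.
def test_prefetch_iterator_flags_only_last_element():
    class RangeDataset(IterableDataset):
        def __iter__(self):
            return iter(range(5))

    flags = [is_last for _, is_last in prefetch_iterator(RangeDataset())]
    assert flags == [False, False, False, False, True]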


@pytest.mark.parametrize(
["dataset_1", "dataset_2"],
[
([list(range(10)), list(range(20))]),
([range(10), range(20)]),
([torch.randn(10, 3, 2), torch.randn(20, 5, 6)]),
([TensorDataset(torch.randn(10, 3, 2)), TensorDataset(torch.randn(20, 5, 6))]),
],
)
def test_combined_dataset(dataset_1, dataset_2):
"""Verify the length of the CombinedDataset"""
datasets = [dataset_1, dataset_2]
combined_dataset = CombinedDataset(datasets)
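    # `max_len`/`min_len` report the longest and the shortest underlying dataset;
    # `len()` follows `min_len` under the (assumed) default 'min_size' mode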
assert combined_dataset.max_len == 20
assert combined_dataset.min_len == len(combined_dataset) == 10


def test_combined_dataset_length_mode_error():
    """Test that an invalid length mode raises a `MisconfigurationException`"""
dset = CombinedDataset([range(10)])
with pytest.raises(MisconfigurationException, match="Invalid Mode"):
dset._calc_num_data([range(10)], "test")


def test_combined_loader_iterator_dict_min_size():
"""Test `CombinedLoaderIterator` given mapping loaders"""
loaders = {
"a": torch.utils.data.DataLoader(range(10), batch_size=4),
"b": torch.utils.data.DataLoader(range(20), batch_size=5),
}
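    # "a" yields ceil(10 / 4) == 3 batches and "b" yields ceil(20 / 5) == 4,
    # so the combined iterator stops with the shorter loader after 3 steps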
combined_iter = CombinedLoaderIterator(loaders)
for idx, item in enumerate(combined_iter):
assert isinstance(item, dict)
assert len(item) == 2
assert "a" in item and "b" in item
assert idx == min(len(loaders["a"]), len(loaders["b"])) - 1


def test_combined_loader_init_mode_error():
    """Test the MisconfigurationException when constructing `CombinedLoader` with an invalid mode"""
with pytest.raises(MisconfigurationException, match="Invalid Mode"):
CombinedLoader([range(10)], "testtt")


def test_combined_loader_loader_type_error():
    """Test the TypeError when wrapping loaders of an unsupported type"""
with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"):
CombinedLoader(None, "max_size_cycle")


def test_combined_loader_calc_length_mode_error():
    """Test the TypeError when calculating the number of batches for unsupported data"""
with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"):
CombinedLoader._calc_num_batches(None)


def test_combined_loader_dict_min_size():
"""Test `CombinedLoader` of mode 'min_size' given mapping loaders"""
loaders = {
"a": torch.utils.data.DataLoader(range(10), batch_size=4),
"b": torch.utils.data.DataLoader(range(20), batch_size=5),
}
combined_loader = CombinedLoader(loaders, "min_size")
assert len(combined_loader) == min(len(v) for v in loaders.values())
for idx, item in enumerate(combined_loader):
assert isinstance(item, dict)
assert len(item) == 2
assert "a" in item and "b" in item
assert idx == len(combined_loader) - 1


def test_combined_loader_dict_max_size_cycle():
"""Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders"""
loaders = {
"a": torch.utils.data.DataLoader(range(10), batch_size=4),
"b": torch.utils.data.DataLoader(range(20), batch_size=5),
}
combined_loader = CombinedLoader(loaders, "max_size_cycle")
assert len(combined_loader) == max(len(v) for v in loaders.values())
for idx, item in enumerate(combined_loader):
assert isinstance(item, dict)
assert len(item) == 2
assert "a" in item and "b" in item
assert idx == len(combined_loader) - 1
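

# Hypothetical follow-up sketch: in 'max_size_cycle' mode the shorter loader is
# expected to wrap around, so its first batch should reappear once it is exhausted.
def test_combined_loader_dict_max_size_cycle_wraps_shorter_loader():
    loaders = {
        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
    }
    combined_loader = CombinedLoader(loaders, "max_size_cycle")
    batches = list(combined_loader)
    assert len(batches) == 4
    # loader "a" is only 3 batches long, so its first batch reappears on step 4
    assert torch.equal(batches[3]["a"], batches[0]["a"])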


def test_combined_loader_sequence_min_size():
"""Test `CombinedLoader` of mode 'min_size' given sequence loaders"""
loaders = [
torch.utils.data.DataLoader(range(10), batch_size=4),
torch.utils.data.DataLoader(range(20), batch_size=5),
]
combined_loader = CombinedLoader(loaders, "min_size")
assert len(combined_loader) == min(len(v) for v in loaders)
for idx, item in enumerate(combined_loader):
assert isinstance(item, Sequence)
assert len(item) == 2
assert idx == len(combined_loader) - 1


def test_combined_loader_sequence_max_size_cycle():
"""Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders"""
loaders = [
torch.utils.data.DataLoader(range(10), batch_size=4),
torch.utils.data.DataLoader(range(20), batch_size=5),
]
combined_loader = CombinedLoader(loaders, "max_size_cycle")
assert len(combined_loader) == max(len(v) for v in loaders)
for idx, item in enumerate(combined_loader):
assert isinstance(item, Sequence)
assert len(item) == 2
assert idx == len(combined_loader) - 1
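

# `_nested_calc_num_data` treats plain ints as leaf lengths and recurses into
# sequences and mappings, reducing everything with `compute_func`: e.g. `min`
# over [0..9] plus a nested [1..19] is 0, while `max` is 19.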
@pytest.mark.parametrize(
["input_data", "compute_func", "expected_length"],
[
([*range(10), list(range(1, 20))], min, 0),
([*range(10), list(range(1, 20))], max, 19),
([*range(10), {str(i): i for i in range(1, 20)}], min, 0),
([*range(10), {str(i): i for i in range(1, 20)}], max, 19),
({**{str(i): i for i in range(10)}, "nested": {str(i): i for i in range(1, 20)}}, min, 0),
({**{str(i): i for i in range(10)}, "nested": {str(i): i for i in range(1, 20)}}, max, 19),
({**{str(i): i for i in range(10)}, "nested": list(range(20))}, min, 0),
({**{str(i): i for i in range(10)}, "nested": list(range(20))}, max, 19),
],
)
def test_nested_calc_num_data(input_data, compute_func, expected_length):
calculated_length = _nested_calc_num_data(input_data, compute_func)
assert calculated_length == expected_length


@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"})
@mock.patch("torch.cuda.device_count", return_value=2)
@mock.patch("torch.cuda.is_available", return_value=True)
def test_combined_data_loader_validation_test(cuda_available_mock, device_count_mock, tmpdir):
"""
This test makes sure distributed sampler has been properly injected in dataloaders
when using CombinedLoader
"""

    class CustomDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    dataloader = CombinedLoader(
{
"a": DataLoader(CustomDataset(range(10))),
"b": {"c": DataLoader(CustomDataset(range(10))), "d": DataLoader(CustomDataset(range(10)))},
"e": [DataLoader(CustomDataset(range(10))), DataLoader(CustomDataset(range(10)))],
}
)
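    # the nested structure above contains 5 dataloaders in total ("a", "b.c", "b.d"
    # and the two entries of "e"), hence 5 distributed samplers are expected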
    trainer = Trainer(replace_sampler_ddp=True, accelerator="ddp", gpus=2)
    dataloader = trainer.auto_add_sampler(dataloader, shuffle=True)
    _count = 0

    def _assert_distributed_sampler(v):
        nonlocal _count
        _count += 1
        assert isinstance(v, DistributedSampler)

    apply_to_collection(dataloader.sampler, Sampler, _assert_distributed_sampler)
    assert _count == 5