# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Sequence
from unittest import mock

import pytest
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.supporters import (
    _nested_calc_num_data,
    CombinedDataset,
    CombinedLoader,
    CombinedLoaderIterator,
    CycleIterator,
    TensorRunningAccum,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import RandomDataset


def test_tensor_running_accum_reset():
    """Test that reset would set all attributes to the initialization state."""
    window_length = 10

    accum = TensorRunningAccum(window_length=window_length)
    assert accum.last() is None
    assert accum.mean() is None

    accum.append(torch.tensor(1.5))
    assert accum.last() == torch.tensor(1.5)
    assert accum.mean() == torch.tensor(1.5)

    accum.reset()
    assert accum.window_length == window_length
    assert accum.memory is None
    assert accum.current_idx == 0
    assert accum.last_idx is None
    assert not accum.rotated


def test_cycle_iterator():
    """Test the cycling function of `CycleIterator`: a 100-item range cycled to an explicit length of 1000."""
    iterator = CycleIterator(range(100), 1000)
    assert len(iterator) == 1000
    for idx, item in enumerate(iterator):
        assert item < 100

    # The iterator must stop after exactly `len(iterator)` items.
    assert idx == len(iterator) - 1


def test_none_length_cycle_iterator():
    """Test the infinite cycling function of `CycleIterator` when no length is given."""
    iterator = CycleIterator(range(100))
    assert iterator.__len__() == float("inf")

    # test infinite loop
    for idx, item in enumerate(iterator):
        if idx == 1000:
            break
    # 1000 is a multiple of the range length (100), so the cycle is back at its first element.
    assert item == 0


@pytest.mark.parametrize(
    ["dataset_1", "dataset_2"],
    [
        ([list(range(10)), list(range(20))]),
        ([range(10), range(20)]),
        ([torch.randn(10, 3, 2), torch.randn(20, 5, 6)]),
        ([TensorDataset(torch.randn(10, 3, 2)), TensorDataset(torch.randn(20, 5, 6))]),
    ],
)
def test_combined_dataset(dataset_1, dataset_2):
    """Verify the length of the CombinedDataset."""
    datasets = [dataset_1, dataset_2]
    combined_dataset = CombinedDataset(datasets)

    assert combined_dataset.max_len == 20
    assert combined_dataset.min_len == len(combined_dataset) == 10


def test_combined_dataset_length_mode_error():
    """Test that `CombinedDataset._calc_num_data` rejects an unknown mode with `MisconfigurationException`."""
    dset = CombinedDataset([range(10)])
    with pytest.raises(MisconfigurationException, match="Invalid Mode"):
        dset._calc_num_data([range(10)], "test")


def test_combined_loader_iterator_dict_min_size():
    """Test `CombinedLoaderIterator` given mapping loaders."""
    loaders = {
        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
    }

    combined_iter = CombinedLoaderIterator(loaders)

    for idx, item in enumerate(combined_iter):
        assert isinstance(item, dict)
        assert len(item) == 2
        assert "a" in item and "b" in item

    # Iteration stops with the shortest loader.
    assert idx == min(len(loaders["a"]), len(loaders["b"])) - 1


def test_combined_loader_init_mode_error():
    """Test the MisconfigurationException when constructing `CombinedLoader` with an invalid mode."""
    with pytest.raises(MisconfigurationException, match="Invalid Mode"):
        CombinedLoader([range(10)], "testtt")


def test_combined_loader_loader_type_error():
    """Test the TypeError when wrapping the loaders."""
    with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"):
        CombinedLoader(None, "max_size_cycle")


def test_combined_loader_calc_length_mode_error():
    """Test the TypeError when calculating the number of batches."""
    with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"):
        CombinedLoader._calc_num_batches(None)


def test_combined_loader_dict_min_size():
    """Test `CombinedLoader` of mode 'min_size' given mapping loaders."""
    loaders = {
        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
    }

    combined_loader = CombinedLoader(loaders, "min_size")

    assert len(combined_loader) == min(len(v) for v in loaders.values())

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, dict)
        assert len(item) == 2
        assert "a" in item and "b" in item

    assert idx == len(combined_loader) - 1


def test_combined_loader_dict_max_size_cycle():
    """Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders."""
    loaders = {
        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
    }

    combined_loader = CombinedLoader(loaders, "max_size_cycle")

    assert len(combined_loader) == max(len(v) for v in loaders.values())

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, dict)
        assert len(item) == 2
        assert "a" in item and "b" in item

    assert idx == len(combined_loader) - 1


def test_combined_loader_sequence_min_size():
    """Test `CombinedLoader` of mode 'min_size' given sequence loaders."""
    loaders = [
        torch.utils.data.DataLoader(range(10), batch_size=4),
        torch.utils.data.DataLoader(range(20), batch_size=5),
    ]

    combined_loader = CombinedLoader(loaders, "min_size")

    assert len(combined_loader) == min(len(v) for v in loaders)

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, Sequence)
        assert len(item) == 2

    assert idx == len(combined_loader) - 1


class TestIterableDataset(IterableDataset):
    """Iterable-style dataset yielding the indices ``0 .. size - 1`` in order."""

    # Not a test class despite the `Test` prefix; stops pytest from trying to collect it.
    __test__ = False

    def __init__(self, size: int = 10):
        self.size = size

    def __iter__(self):
        # A fresh sampler iterator per `iter()` call so the dataset can be re-iterated.
        self.sampler = SequentialSampler(range(self.size))
        self.sampler_iter = iter(self.sampler)
        return self

    def __next__(self):
        return next(self.sampler_iter)


@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle"])
@pytest.mark.parametrize("use_multiple_dataloaders", [False, True])
def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders):
    """Test `CombinedLoader` in modes 'min_size' and 'max_size_cycle' given sequence loaders over iterable
    datasets."""
    if use_multiple_dataloaders:
        loaders = [
            torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
            torch.utils.data.DataLoader(TestIterableDataset(20), batch_size=2),
        ]
    else:
        loaders = [
            torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
        ]

    combined_loader = CombinedLoader(loaders, mode)

    has_break = False

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, Sequence)
        # Parenthesized on purpose: `assert a == b if c else d` parses as `assert (a == b) if c else d`,
        # which would assert the truthy constant `d` (and check nothing) when `c` is False.
        assert len(item) == (2 if use_multiple_dataloaders else 1)
        # A single 10-item loader with batch_size=2 has 5 batches; break on its last one (idx 4).
        if not use_multiple_dataloaders and idx == 4:
            has_break = True
            break

    if mode == "max_size_cycle":
        assert combined_loader.loaders[0].state.done == (not has_break)
    expected = (10 if mode == "max_size_cycle" else 5) if use_multiple_dataloaders else 5
    assert (expected - 1) == idx, (mode, use_multiple_dataloaders)


@pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]])
def test_combined_loader_sequence_with_map_and_iterable(lengths):
    """Test that 'max_size_cycle' over one iterable-style and one map-style loader yields as many batches as
    the longer of the two."""

    class MyIterableDataset(IterableDataset):
        def __init__(self, size: int = 10):
            self.size = size

        def __iter__(self):
            self.sampler = SequentialSampler(range(self.size))
            self.iter_sampler = iter(self.sampler)
            return self

        def __next__(self):
            return next(self.iter_sampler)

    class MyMapDataset(Dataset):
        def __init__(self, size: int = 10):
            self.size = size

        def __getitem__(self, index):
            return index

        def __len__(self):
            return self.size

    x, y = lengths
    loaders = [DataLoader(MyIterableDataset(x)), DataLoader(MyMapDataset(y))]
    dataloader = CombinedLoader(loaders, mode="max_size_cycle")
    counter = 0
    for _ in dataloader:
        counter += 1
    # Default batch_size=1, so the number of batches equals the number of samples.
    assert counter == max(x, y)


def test_combined_loader_sequence_max_size_cycle():
    """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders."""
    loaders = [
        torch.utils.data.DataLoader(range(10), batch_size=4),
        torch.utils.data.DataLoader(range(20), batch_size=5),
    ]

    combined_loader = CombinedLoader(loaders, "max_size_cycle")

    assert len(combined_loader) == max(len(v) for v in loaders)

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, Sequence)
        assert len(item) == 2

    assert idx == len(combined_loader) - 1


@pytest.mark.parametrize(
    ["input_data", "compute_func", "expected_length"],
    [
        ([*range(10), list(range(1, 20))], min, 0),
        ([*range(10), list(range(1, 20))], max, 19),
        ([*range(10), {str(i): i for i in range(1, 20)}], min, 0),
        ([*range(10), {str(i): i for i in range(1, 20)}], max, 19),
        ({**{str(i): i for i in range(10)}, "nested": {str(i): i for i in range(1, 20)}}, min, 0),
        ({**{str(i): i for i in range(10)}, "nested": {str(i): i for i in range(1, 20)}}, max, 19),
        ({**{str(i): i for i in range(10)}, "nested": list(range(20))}, min, 0),
        ({**{str(i): i for i in range(10)}, "nested": list(range(20))}, max, 19),
    ],
)
def test_nested_calc_num_data(input_data, compute_func, expected_length):
    """Test that `_nested_calc_num_data` reduces the integer leaves of nested lists/dicts with `compute_func`."""
    calculated_length = _nested_calc_num_data(input_data, compute_func)

    assert calculated_length == expected_length


@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"})
@mock.patch("torch.cuda.device_count", return_value=2)
@mock.patch("torch.cuda.is_available", return_value=True)
@pytest.mark.parametrize("use_fault_tolerant", [False, True])
@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
def test_combined_data_loader_validation_test(
    cuda_available_mock, device_count_mock, use_fault_tolerant, replace_sampler_ddp, tmpdir
):
    """This test makes sure distributed sampler has been properly injected in dataloaders when using
    CombinedLoader.

    CUDA availability and the device count are mocked so that a 2-GPU DDP ``Trainer`` can be constructed.
    ``mock.patch`` decorators inject their mocks bottom-up, hence ``cuda_available_mock`` comes first.
    """

    class CustomDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    class CustomSampler(RandomSampler):
        def __init__(self, data_source, name) -> None:
            super().__init__(data_source)
            self.name = name

    dataset = CustomDataset(range(10))
    dataloader = CombinedLoader(
        {
            "a": DataLoader(CustomDataset(range(10))),
            "b": DataLoader(dataset, sampler=CustomSampler(dataset, "custom_sampler")),
            "c": {"c": DataLoader(CustomDataset(range(10))), "d": DataLoader(CustomDataset(range(10)))},
            "d": [DataLoader(CustomDataset(range(10))), DataLoader(CustomDataset(range(10)))],
        }
    )

    # The fault-tolerant flag must be set while the dataloaders are prepared.
    with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(use_fault_tolerant))}):
        trainer = Trainer(replace_sampler_ddp=replace_sampler_ddp, strategy="ddp", gpus=2)
        dataloader = trainer.prepare_dataloader(dataloader, shuffle=True)

    _count = 0
    _has_fastforward_sampler = False

    def _assert_distributed_sampler(v):
        nonlocal _count
        nonlocal _has_fastforward_sampler
        _count += 1

        if use_fault_tolerant:
            _has_fastforward_sampler = True
            assert isinstance(v, FastForwardSampler)
            v = v._sampler

        if replace_sampler_ddp:
            assert isinstance(v, DistributedSampler)
        else:
            assert isinstance(v, (SequentialSampler, CustomSampler))

    apply_to_collection(dataloader.sampler, Sampler, _assert_distributed_sampler)
    # Six leaf dataloaders: "a", "b", "c"/"c", "c"/"d", "d"[0], "d"[1].
    assert _count == 6
    assert _has_fastforward_sampler == use_fault_tolerant

    def _assert_dataset(loader):
        d = loader.dataset
        if use_fault_tolerant:
            assert isinstance(d, CaptureMapDataset)
        else:
            assert isinstance(d, CustomDataset)

    apply_to_collection(dataloader.loaders, DataLoader, _assert_dataset)


@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
def test_combined_data_loader_with_max_size_cycle_and_ddp(replace_sampler_ddp, tmpdir):
    """This test makes sure distributed sampler has been properly injected in dataloaders when using
    CombinedLoader with ddp and `max_size_cycle` mode."""
    trainer = Trainer(strategy="ddp", accelerator="auto", devices=2, replace_sampler_ddp=replace_sampler_ddp)

    dataloader = CombinedLoader(
        {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)},
    )

    dataloader = trainer.prepare_dataloader(dataloader, shuffle=False)
    # The conditional expressions below are parenthesized: without parentheses, `assert a == b if c else d`
    # parses as `assert (a == b) if c else d` and is vacuously true when `replace_sampler_ddp` is False.
    assert len(dataloader) == (4 if replace_sampler_ddp else 8)

    for a_length in [6, 8, 10]:
        dataloader = CombinedLoader(
            {
                "a": DataLoader(range(a_length), batch_size=1),
                "b": DataLoader(range(8), batch_size=1),
            },
            mode="max_size_cycle",
        )

        length = max(a_length, 8)
        assert len(dataloader) == length
        dataloader = trainer.prepare_dataloader(dataloader, shuffle=False)
        # With 2 replicas each rank sees half of the (longest) loader.
        assert len(dataloader) == (length // 2 if replace_sampler_ddp else length)
        if replace_sampler_ddp:
            last_batch = list(dataloader)[-1]
            if a_length == 6:
                assert last_batch == {"a": torch.tensor([0]), "b": torch.tensor([6])}
            elif a_length == 8:
                assert last_batch == {"a": torch.tensor([6]), "b": torch.tensor([6])}
            elif a_length == 10:
                assert last_batch == {"a": torch.tensor([8]), "b": torch.tensor([0])}

    class InfiniteDataset(IterableDataset):
        def __iter__(self):
            while True:
                yield 1

    dataloader = CombinedLoader(
        {
            "a": DataLoader(InfiniteDataset(), batch_size=1),
            "b": DataLoader(range(8), batch_size=1),
        },
        mode="max_size_cycle",
    )
    assert get_len(dataloader) == float("inf")
    assert len(dataloader.loaders["b"].loader) == 8
    dataloader = trainer.prepare_dataloader(dataloader, shuffle=False)
    assert len(dataloader.loaders["b"].loader) == (4 if replace_sampler_ddp else 8)
    assert get_len(dataloader) == float("inf")