# Copyright The PyTorch Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Iterable import pytest from torch.utils.data import BatchSampler, SequentialSampler from pytorch_lightning import seed_everything from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler from pytorch_lightning.utilities.data import has_len @pytest.mark.parametrize("shuffle", [False, True]) def test_unrepeated_distributed_sampler(shuffle, tmpdir): """Test each rank will receive a different number of elements.""" seed_everything(42) world_size = 4 samplers = [] dataset = range(103) for rank in range(world_size): samplers.append(UnrepeatedDistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=shuffle)) indices = [list(s) for s in samplers] assert len(indices[0]) == 26 assert len(indices[1]) == 26 assert len(indices[2]) == 26 assert len(indices[3]) == 25 assert indices[0][-1] == 18 if shuffle else 100 assert indices[1][-1] == 30 if shuffle else 101 assert indices[2][-1] == 29 if shuffle else 102 assert indices[3][-1] == 35 if shuffle else 99 def test_index_batch_sampler(tmpdir): """Test `IndexBatchSampler` properly extracts indices.""" dataset = range(15) sampler = SequentialSampler(dataset) batch_sampler = BatchSampler(sampler, 3, False) index_batch_sampler = IndexBatchSamplerWrapper(batch_sampler) assert batch_sampler.batch_size == index_batch_sampler.batch_size assert batch_sampler.drop_last == index_batch_sampler.drop_last assert batch_sampler.sampler is sampler assert list(index_batch_sampler) == index_batch_sampler.seen_batch_indices def test_index_batch_sampler_methods(): dataset = range(15) sampler = SequentialSampler(dataset) batch_sampler = BatchSampler(sampler, 3, False) index_batch_sampler = IndexBatchSamplerWrapper(batch_sampler) assert isinstance(index_batch_sampler, Iterable) assert has_len(index_batch_sampler)