# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import random
from collections.abc import Iterable
from typing import Optional
from unittest import mock

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler, SequentialSampler
from torch.utils.data._utils.worker import get_worker_info
from torch.utils.data.dataloader import DataLoader, default_collate
from torch.utils.data.dataset import Dataset, IterableDataset

import tests.helpers.utils as tutils
from pytorch_lightning import Callback, seed_everything, Trainer
from pytorch_lightning.utilities.auto_restart import (
    _add_capture_metadata_collate,
    _dataloader_load_state_dict,
    _dataloader_to_state_dict,
    CaptureIterableDataset,
    FastForwardSampler,
)
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf


# Credit to PyTorch Team.
# Taken from:
# https://github.com/pytorch/pytorch/blob/3b977a0d2834d300c0301a0c6af98c8e939019ce/torch/utils/data/_utils/worker.py#L151
# Not available until torch 1.9.0
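# It spreads a base seed and a worker id into four 32-bit words, mirroring the
# SeedSequence-style scrambling PyTorch 1.9+ uses for per-worker numpy seeding.
# ``MetaLearningDataset.set_seed`` below feeds the result straight to ``np.random.seed``,
# giving every (rank, worker) pair a distinct but reproducible numpy state, roughly:
#
#     np.random.seed(_generate_state(base_seed, worker_id))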
def _generate_state(base_seed, worker_id):
    INIT_A = 0x43B0D7E5
    MULT_A = 0x931E8875
    INIT_B = 0x8B51F9DD
    MULT_B = 0x58F38DED
    MIX_MULT_L = 0xCA01F9DD
    MIX_MULT_R = 0x4973F715
    XSHIFT = 4 * 8 // 2
    MASK32 = 0xFFFFFFFF

    entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
    pool = [0] * 4
    hash_const_A = INIT_A

    def hash(value):
        nonlocal hash_const_A
        value = (value ^ hash_const_A) & MASK32
        hash_const_A = (hash_const_A * MULT_A) & MASK32
        value = (value * hash_const_A) & MASK32
        value = (value ^ (value >> XSHIFT)) & MASK32
        return value

    def mix(x, y):
        result_x = (MIX_MULT_L * x) & MASK32
        result_y = (MIX_MULT_R * y) & MASK32
        result = (result_x - result_y) & MASK32
        result = (result ^ (result >> XSHIFT)) & MASK32
        return result

    # Add in the entropy to the pool.
    for i in range(len(pool)):
        pool[i] = hash(entropy[i])

    # Mix all bits together so late bits can affect earlier bits.
    for i_src in range(len(pool)):
        for i_dst in range(len(pool)):
            if i_src != i_dst:
                pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))

    hash_const_B = INIT_B
    state = []
    for i_dst in range(4):
        data_val = pool[i_dst]
        data_val = (data_val ^ hash_const_B) & MASK32
        hash_const_B = (hash_const_B * MULT_B) & MASK32
        data_val = (data_val * hash_const_B) & MASK32
        data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
        state.append(data_val)
    return state


def test_fast_forward_getattr():
    dataset = range(15)
    sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, 3, False)
    index_batch_sampler = FastForwardSampler(batch_sampler)

    assert index_batch_sampler.batch_size == 3
    assert index_batch_sampler.sampler == sampler


def test_fast_forward_on_batch_sampler():
    """
    This test ensures ``FastForwardSampler`` applied to ``BatchSampler`` correctly retrieves
    the right next batch on restart.
    """
    dataset = range(15)
    sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, 3, False)
    index_batch_sampler = FastForwardSampler(batch_sampler)

    assert isinstance(index_batch_sampler, Iterable)

    index_batch_sampler_iter = iter(index_batch_sampler)

    assert next(index_batch_sampler_iter) == [0, 1, 2]
    assert next(index_batch_sampler_iter) == [3, 4, 5]

    state_dict = index_batch_sampler.state_dict(2)

    index_batch_sampler = FastForwardSampler(batch_sampler)
    index_batch_sampler.load_state_dict(state_dict)

    index_batch_sampler_iter = iter(index_batch_sampler)
    assert next(index_batch_sampler_iter) == [6, 7, 8]


def test_fast_forward_on_sequential_sampler():
    """
    This test ensures ``FastForwardSampler`` applied to ``SequentialSampler`` correctly retrieves
    the right next batch on restart.
    """
    dataset = range(15)
    sequential_sampler = SequentialSampler(dataset)
    sampler = FastForwardSampler(sequential_sampler)
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)

    batch_sampler_iter = iter(batch_sampler)

    assert next(batch_sampler_iter) == [0, 1, 2]
    assert next(batch_sampler_iter) == [3, 4, 5]

    state_dict = sampler.state_dict(2)
    assert state_dict[0]["current_iteration"] == 6

    sampler.load_state_dict(state_dict)

    batch_sampler_iter = iter(batch_sampler)
    assert next(batch_sampler_iter) == [6, 7, 8]


@pytest.mark.skipif(torch.cuda.is_available(), reason="todo (tchaton) Need more investigation")
def test_fast_forward_on_random_sampler():
    """
    This test ensures ``FastForwardSampler`` applied to ``RandomSampler`` correctly retrieves
    the right next batch on restart.
    """
    seed = 42
    seed_everything(42)

    dataset = range(15)
    generator = torch.Generator().manual_seed(seed)
    values = list(RandomSampler(dataset, generator=generator))

    generator = torch.Generator().manual_seed(seed)
    random_sampler = RandomSampler(dataset, generator=generator)
    sampler = FastForwardSampler(random_sampler)
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)

    batch_sampler_iter = iter(batch_sampler)

    assert next(batch_sampler_iter) == values[:3]
    assert next(batch_sampler_iter) == values[3:6]
    assert next(batch_sampler_iter) == values[6:9]

    state_dict = sampler.state_dict(3)
    assert state_dict[0]["current_iteration"] == 9
    state_dict[0]["current_iteration"] = 6

    seed_everything(42)
    generator = torch.Generator().manual_seed(seed)
    random_sampler = RandomSampler(dataset, generator=generator)
    sampler = FastForwardSampler(random_sampler)
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)
    sampler.load_state_dict(state_dict)

    batch_sampler_iter = iter(batch_sampler)

    assert next(batch_sampler_iter) == values[6:9]

    has_raised = False
    try:
        for _ in range(5):
            next(batch_sampler_iter)
    except StopIteration:
        has_raised = True
        assert sampler._current_iteration == 0
        sampler.load_state_dict(sampler.state_dict(0))
    assert has_raised


class RangeIterableDataset(IterableDataset):
    def __init__(self, data, num_workers: int, batch_size: int, state_dict=None, attr_name: str = "iter_sampler"):
        self.data = list(data)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.state_dict = state_dict
        self.attr_name = attr_name

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info and self.num_workers == 2:
            id = worker_info.id
            num_samples = len(self.data)
            if id == 0:
                self.data = list(self.data)[: num_samples // 2]
            else:
                self.data = list(self.data)[num_samples // 2 :]
            self.user_sampler = RandomSampler(self.data)
        else:
            self.user_sampler = RandomSampler(self.data)

        setattr(self, self.attr_name, iter(self.user_sampler))
        return self

    def __next__(self):
        iter_sampler = getattr(self, self.attr_name)
        return self.data[next(iter_sampler)]


@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 30 sec and should be skipped in Azure CI")
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_fast_forward_sampler_over_iterative_dataset(num_workers):
    """
    This test ensures ``FastForwardSampler`` and ``CaptureIterableDataset`` are properly being
    used to capture worker states.
    """
    batch_size = 3
    initial_seed = seed_everything(42)
    generator = torch.Generator()
    generator.manual_seed(initial_seed)
    dataset = RangeIterableDataset(range(20), num_workers, batch_size, True)
    dataset = CaptureIterableDataset(dataset)

    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator)
    _add_capture_metadata_collate(dataloader)

    iter_dataloader = iter(dataloader)
    batches = []
    for _ in range(5):
        batches.append(next(iter_dataloader))

    # restarting on batch_1 and getting 3 extra batches

    state_dict = {"iter_sampler": {}}
    for batch in batches[:2]:
        batch, _state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
        for k, v in _state_dict[0].items():
            state_dict[k].update(v)

    assert len(state_dict["iter_sampler"]) == (num_workers if num_workers > 1 else 1)

    initial_seed = seed_everything(42)
    generator.manual_seed(initial_seed)
    dataset = RangeIterableDataset(range(20), num_workers, batch_size, state_dict=state_dict)
    dataset = CaptureIterableDataset(dataset)
    dataset.load_state_dict(state_dict)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator)
    _add_capture_metadata_collate(dataloader)

    iter_dataloader = iter(dataloader)
    batches_restart = []
    for _ in range(3):
        batches_restart.append(next(iter_dataloader))

    assert torch.equal(batches_restart[0]["data"], batches[2]["data"])
    assert torch.equal(batches_restart[1]["data"], batches[3]["data"])
    assert torch.equal(batches_restart[2]["data"], batches[4]["data"])


def _setup_ddp(rank, worldsize):
    os.environ["MASTER_ADDR"] = "localhost"

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=worldsize)


def _test_fast_forward_sampler_with_distributed_sampler(rank, worldsize):
    _setup_ddp(rank, worldsize)

    initial_seed = seed_everything(42)

    generator = torch.Generator()
    generator.manual_seed(initial_seed)

    num_workers = 2
    batch_size = 4

    dataset = range(30)
    sampler = FastForwardSampler(DistributedSampler(dataset, num_replicas=worldsize, rank=rank, seed=initial_seed))
    sampler.setup(batch_size)

    dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, generator=generator, sampler=sampler
    )

    iter_dataloader = iter(dataloader)

    num_yielded = 0
    batches = []
    while True:
        try:
            batches.append(next(iter_dataloader))
            num_yielded += 1
        except StopIteration:
            break

    expected = torch.tensor([17, 27, 24]) if rank == 0 else torch.tensor([19, 5, 3])
    assert torch.equal(batches[-1], expected)

    assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16

    reload_state_dict = sampler.state_dict(num_yielded - 1)
    assert reload_state_dict[0]["current_iteration"] == 12

    sampler = FastForwardSampler(DistributedSampler(dataset, num_replicas=worldsize, rank=rank, seed=initial_seed))
    sampler.setup(batch_size)
    sampler.load_state_dict(reload_state_dict)

    dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, generator=generator, sampler=sampler
    )

    iter_dataloader = iter(dataloader)

    batches = []
    while True:
        try:
            batches.append(next(iter_dataloader))
        except StopIteration:
            break

    assert torch.equal(batches[-1], expected)
    assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16


@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 25 sec and should be skipped in Azure CI")
@RunIf(skip_windows=True)
def test_fast_forward_sampler_with_distributed_sampler():
    """Make sure ``FastForwardSampler`` works with ``DistributedSampler`` under DDP."""
    tutils.set_random_master_port()
    worldsize = 2
    mp.spawn(_test_fast_forward_sampler_with_distributed_sampler, args=(worldsize,), nprocs=worldsize)


class MetaLearningDataset(IterableDataset):
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        task_num_classes: int = 5,
        num_workers: Optional[int] = None,
        global_rank: Optional[int] = None,
        world_size: Optional[int] = None,
        initial_seed: Optional[int] = None,
        shuffle: bool = True,
        debugging: bool = False,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_workers = num_workers or 1
        self.global_rank = global_rank
        self.world_size = world_size
        self.task_num_classes = task_num_classes
        self.labels = labels = getattr(dataset, "labels")
        self.initial_seed = initial_seed
        self.generator: Optional[torch.Generator] = None
        self.current_task_iteration = 0
        self.shuffle = shuffle
        self.debugging = debugging

        if labels is None:
            raise MisconfigurationException(f"Provided {self.dataset} should have an attribute labels.")

        if len(labels) != len(dataset):
            raise MisconfigurationException("The provided ``labels`` don't match the dataset length.")

        if (isinstance(global_rank, int) and world_size is None) or (
            isinstance(world_size, int) and global_rank is None
        ):
            raise MisconfigurationException("Both ``world_size`` and ``global_rank`` should be provided!")

        self.unique_labels = np.unique(self.labels)

    @property
    def worker_id(self) -> int:
        worker_info = get_worker_info()
        return worker_info.id if worker_info else 0

    @property
    def is_distributed(self) -> bool:
        return self.world_size is not None and self.world_size > 1

    def set_seed(self, shared: bool = False):
        initial_seed = self.initial_seed + self.current_task_iteration
        if shared:
            seed = initial_seed
            np_seed = _generate_state(initial_seed, 0)
        else:
            seed = initial_seed + self.worker_id + self.global_rank + self.current_task_iteration
            np_seed = _generate_state(initial_seed, self.worker_id + self.global_rank)

        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(np_seed)

    def sample_task_indices(self):
        self.set_seed(shared=True)
        self.selected_indexes = np.random.choice(self.unique_labels, self.task_num_classes, replace=False)
        self.selected_indexes.sort()

        # subset of indices from the entire dataset where the labels are actually among the
        # task_num_classes selected_indexes
        self.task_indices = np.arange(len(self.dataset))[np.in1d(self.labels, self.selected_indexes)]
        self.task_length = len(self.task_indices)
        self.set_seed(shared=False)

    @property
    def worker_rank(self) -> int:
        worker_id = self.worker_id
        is_global_zero = self.global_rank == 0
        return self.global_rank + worker_id + int(not is_global_zero)

    def create_sampler(self):
        data = range(self.task_length)
        if self.world_size == 1 and self.num_workers in (0, 1):
            if self.shuffle:
                self.sampler = RandomSampler(data, generator=self.generator)
            else:
                self.sampler = SequentialSampler(data)
        else:
            num_workers = 1 if self.num_workers in (None, 0) else self.num_workers
            num_replicas = num_workers * self.world_size
            current_seed = self.initial_seed + self.current_task_iteration
            self.sampler = DistributedSampler(
                data, num_replicas=num_replicas, rank=self.worker_rank, shuffle=self.shuffle, seed=current_seed
            )

    def __iter__(self):
        if self.generator is None:
            self.generator = torch.Generator().manual_seed(self.initial_seed)
        self.sample_task_indices()
        self.create_sampler()
        self.batch_sampler = BatchSampler(self.sampler, batch_size=self.batch_size, drop_last=self.drop_last)
        self.iter_sampler = iter(self.batch_sampler)
        self.is_first_batch = True
        self.current_task_iteration += 1
        return self

    def increment_iteration(self):
        self.current_task_iteration += 1

    def __next__(self):
        # this is optional, but useful to accumulate gradient over the entire task.
        is_first_batch = self.is_first_batch if self.debugging else (self.is_first_batch and self.worker_id == 0)
        if is_first_batch:
            self.is_first_batch = False
            return {"task_length": len(self.batch_sampler), "selected_indexes": self.selected_indexes}

        random_indices = next(self.iter_sampler)
        task_indices = [self.task_indices[idx] for idx in random_indices]
        return default_collate([self.dataset[idx] for idx in task_indices])


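# ``MetaLearningDataset`` above is intentionally convoluted so the restart tests below get
# realistic worker/rank interactions: each task first draws ``task_num_classes`` labels with
# a seed shared across ranks and workers, narrows the dataset to the matching indices, and
# then shards those indices with a ``DistributedSampler`` whose replica id mixes the global
# rank and the dataloader worker id. With ``debugging=True`` (as used below), the first item
# each worker yields per task is a metadata dict (task length and selected classes) instead
# of a collated batch, which is why the assertions below add 2 to the expected epoch length
# and slice the real batches with ``[2:]`` when ``num_workers == 2``.

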
class ClassificationDataset(Dataset):
    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels
        assert len(self.inputs) == len(self.labels)

    def __getitem__(self, index):
        return (self.inputs[index], self.labels[index])

    def __len__(self):
        return len(self.inputs)


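# The helper below runs the meta-learning pipeline for two epochs, pulls the sampler state
# that ``CaptureIterableDataset`` attaches to batches 2 and 3 (the first real batch from
# each of the two workers), rebuilds the whole pipeline from that state, and then checks
# that the restarted run reproduces the remaining real batches of epoch 0 and all the real
# batches of epoch 1. With ``worldsize > 1`` it additionally all-gathers the task metadata
# and the first batches, so ranks must agree on the task but not on the data they receive.

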
def _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(rank, worldsize):
    if worldsize > 1:
        _setup_ddp(rank, worldsize)

    def all_gather(tensor, world_size):
        tensor_list = [torch.zeros_like(tensor, dtype=torch.int64) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, tensor)
        return tensor_list

    initial_seed = seed_everything(42)

    generator = torch.Generator()
    generator.manual_seed(initial_seed)

    num_workers = 2
    batch_size = 4
    dataset_length = 60
    num_classes = 10

    labels = np.random.randint(0, num_classes, dataset_length)

    dataset = ClassificationDataset(range(dataset_length), labels)
    dataset = MetaLearningDataset(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        num_workers=num_workers,
        global_rank=rank,
        world_size=worldsize,
        initial_seed=initial_seed,
        debugging=True,
        shuffle=True,
    )
    dataset = CaptureIterableDataset(dataset)
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator)
    _add_capture_metadata_collate(dataloader)

    epoch_results = []
    for _ in range(2):
        iter_dataloader = iter(dataloader)
        batches = []
        while True:
            try:
                batches.append(next(iter_dataloader))
            except StopIteration:
                break
        epoch_results.append(batches)
        dataloader.dataset.dataset.current_task_iteration += 1

    assert len(epoch_results) == 2

    assert len(epoch_results[0]) == math.ceil((dataset_length / (num_workers * worldsize)) / batch_size) + 2

    if worldsize == 1:
        assert epoch_results[0][0]["data"]["task_length"] == epoch_results[0][1]["data"]["task_length"]
        assert torch.equal(
            epoch_results[0][0]["data"]["selected_indexes"], epoch_results[0][1]["data"]["selected_indexes"]
        )
        assert 0 in epoch_results[0][2][AutoRestartBatchKeys.PL_RESTART_META]["iter_sampler"]  # worker id 0
        assert 1 in epoch_results[0][3][AutoRestartBatchKeys.PL_RESTART_META]["iter_sampler"]  # worker id 1
        assert not torch.equal(epoch_results[0][2]["data"][0], epoch_results[0][3]["data"][0])
    else:
        first_task_metadata = all_gather(epoch_results[0][0]["data"]["task_length"], worldsize)
        second_task_metadata = all_gather(epoch_results[0][1]["data"]["task_length"], worldsize)
        assert torch.equal(first_task_metadata[0], first_task_metadata[1])
        assert torch.equal(second_task_metadata[0], second_task_metadata[1])
        assert torch.equal(first_task_metadata[0], second_task_metadata[1])

        first_batch_list = all_gather(epoch_results[0][2]["data"][0], worldsize)
        assert not torch.equal(first_batch_list[0], first_batch_list[1])
        second_batch_list = all_gather(epoch_results[0][3]["data"][0], worldsize)
        assert not torch.equal(second_batch_list[0], second_batch_list[1])

    # restarting on epoch 0 / real batch 2
    state_dict = {"iter_sampler": {}}
    for batch in epoch_results[0][2:4]:
        batch, _state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
        for k, v in _state_dict[0].items():
            state_dict[k].update(v)

    dataset = ClassificationDataset(range(dataset_length), labels)
    dataset = MetaLearningDataset(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        num_workers=num_workers,
        global_rank=rank,
        world_size=worldsize,
        initial_seed=initial_seed,
        debugging=True,
        shuffle=True,
    )

    dataset = CaptureIterableDataset(dataset)
    dataset.load_state_dict(state_dict)
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator)
    _add_capture_metadata_collate(dataloader)

    epoch_results_restart = []
    for _ in range(2):
        iter_dataloader = iter(dataloader)
        batches = []
        while True:
            try:
                batches.append(next(iter_dataloader))
            except StopIteration:
                break
        epoch_results_restart.append(batches)
        dataloader.dataset.dataset.increment_iteration()
        dataloader.dataset.reset_on_epoch()

    assert len(epoch_results_restart[0]) + 2 == len(epoch_results[0])
    epoch_tensors = [e["data"][0] for e in epoch_results[0][4:]]
    epoch_tensors_restart = [e["data"][0] for e in epoch_results_restart[0][2:]]

    for t, tr in zip(epoch_tensors, epoch_tensors_restart):
        assert torch.equal(t, tr)

    epoch_tensors = [e["data"][0] for e in epoch_results[1][2:]]
    epoch_tensors_restart = [e["data"][0] for e in epoch_results_restart[1][2:]]

    for t, tr in zip(epoch_tensors, epoch_tensors_restart):
        assert torch.equal(t, tr)


@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 45 sec and should be skipped in Azure CI")
def test_fast_forward_sampler_iterative_dataset():
    _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(0, 1)


@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 55 sec and should be skipped in Azure CI")
@RunIf(skip_windows=True)
def test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset():
    """Make sure ``FastForwardSampler`` works with ``DistributedSampler`` and an iterable dataset under DDP."""
    tutils.set_random_master_port()
    worldsize = 2
    mp.spawn(
        _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset, args=(worldsize,), nprocs=worldsize
    )


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@RunIf(max_torch="1.7")
def test_fault_tolerant_not_supported():
    assert not _fault_tolerant_training()


def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", wrap: bool = True):
    dataset = RangeIterableDataset(range(50), num_workers=num_workers, batch_size=batch_size, attr_name=attr_name)
    if wrap:
        dataset = CaptureIterableDataset(dataset)
    return dataset


def test_dataloader_to_state_dict_and_reload():
    """
    Note: Those utilities are used only with a ``DataLoader`` wrapping a map-style dataset.
    """

    def create_dataloader():
        dataset = range(50)
        batch_size = 8
        sampler = FastForwardSampler(SequentialSampler(dataset))
        sampler.setup(batch_size)

        return DataLoader(dataset, sampler=sampler, batch_size=batch_size)

    dataloader = create_dataloader()
    iter_dataloader = iter(dataloader)
    _ = next(iter_dataloader)
    _ = next(iter_dataloader)

    state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader)
    assert state_dict[0]["current_iteration"] == 16

    dataloader = create_dataloader()
    dataloader = _dataloader_load_state_dict(dataloader, state_dict)
    iter_dataloader = iter(dataloader)
    _ = next(iter_dataloader)

    state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader)
    assert state_dict[0]["current_iteration"] == 24


@RunIf(min_torch="1.7.0")
@pytest.mark.parametrize("use_fault_tolerant", ["0", "1"])
def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir):
    """
    This test ensures the dataset and sampler are properly wrapped when fault-tolerant training is enabled.
    """

    class CustomBatchSampler(BatchSampler):
        pass

    dataset = range(50)

    class TestModel(BoringModel):
        def train_dataloader(self):
            return {
                "a": [
                    DataLoader(create_iterable_dataset(3, 1, wrap=False), num_workers=0, batch_size=3),
                    DataLoader(dataset, batch_size=8),
                    DataLoader(
                        dataset,
                        batch_sampler=CustomBatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False),
                    ),
                ],
                "b": DataLoader(
                    create_iterable_dataset(2, num_workers=1, attr_name="custom_sampler", wrap=False),
                    num_workers=0,
                    batch_size=2,
                ),
            }

        def training_step(self, batch, batch_idx):
            pass

    class Check(Callback):
        def on_train_batch_start(self, trainer, *_) -> None:
            loaders = trainer.train_dataloader.loaders
            if use_fault_tolerant == "1":
                assert isinstance(loaders["a"][0].loader.dataset, CaptureIterableDataset)
                assert isinstance(loaders["a"][1].loader.sampler, FastForwardSampler)
                assert isinstance(loaders["a"][2].loader.batch_sampler, FastForwardSampler)
                assert isinstance(loaders["b"].loader.dataset, CaptureIterableDataset)
            else:
                assert isinstance(loaders["a"][0].loader.dataset, RangeIterableDataset)
                assert isinstance(loaders["a"][1].loader.sampler, SequentialSampler)
                assert isinstance(loaders["a"][2].loader.batch_sampler, CustomBatchSampler)
                assert isinstance(loaders["b"].loader.dataset, RangeIterableDataset)

    with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": use_fault_tolerant}):
        model = TestModel()
        model.training_epoch_end = None
        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, callbacks=Check())
        trainer.fit(model)