# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import random
from collections.abc import Iterable
from typing import Optional
from unittest import mock

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler, SequentialSampler
from torch.utils.data._utils.worker import get_worker_info
from torch.utils.data.dataloader import DataLoader, default_collate
from torch.utils.data.dataset import Dataset, IterableDataset

import tests.helpers.utils as tutils
from pytorch_lightning import Callback, seed_everything, Trainer
from pytorch_lightning.utilities.auto_restart import (
    _add_capture_metadata_collate,
    _dataloader_load_state_dict,
    _dataloader_to_state_dict,
    CaptureIterableDataset,
    FastForwardSampler,
)
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf


# Credit to PyTorch Team.
# Taken from:
# https://github.com/pytorch/pytorch/blob/3b977a0d2834d300c0301a0c6af98c8e939019ce/torch/utils/data/_utils/worker.py#L151
# Not available until torch 1.9.0
def _generate_state(base_seed, worker_id):
    INIT_A = 0x43B0D7E5
    MULT_A = 0x931E8875
    INIT_B = 0x8B51F9DD
    MULT_B = 0x58F38DED
    MIX_MULT_L = 0xCA01F9DD
    MIX_MULT_R = 0x4973F715
    XSHIFT = 4 * 8 // 2
    MASK32 = 0xFFFFFFFF

    entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
    pool = [0] * 4

    hash_const_A = INIT_A

    def hash(value):
        nonlocal hash_const_A
        value = (value ^ hash_const_A) & MASK32
        hash_const_A = (hash_const_A * MULT_A) & MASK32
        value = (value * hash_const_A) & MASK32
        value = (value ^ (value >> XSHIFT)) & MASK32
        return value

    def mix(x, y):
        result_x = (MIX_MULT_L * x) & MASK32
        result_y = (MIX_MULT_R * y) & MASK32
        result = (result_x - result_y) & MASK32
        result = (result ^ (result >> XSHIFT)) & MASK32
        return result

    # Add in the entropy to the pool.
    for i in range(len(pool)):
        pool[i] = hash(entropy[i])

    # Mix all bits together so late bits can affect earlier bits.
    for i_src in range(len(pool)):
        for i_dst in range(len(pool)):
            if i_src != i_dst:
                pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))

    hash_const_B = INIT_B
    state = []
    for i_dst in range(4):
        data_val = pool[i_dst]
        data_val = (data_val ^ hash_const_B) & MASK32
        hash_const_B = (hash_const_B * MULT_B) & MASK32
        data_val = (data_val * hash_const_B) & MASK32
        data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
        state.append(data_val)
    return state
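
# NOTE: the helper above reimplements the entropy-mixing scheme used by NumPy's ``SeedSequence``
# (and by torch >= 1.9 worker seeding) so a 4-word uint32 seed can be derived deterministically
# from ``(base_seed, worker_id)`` on older torch versions. ``MetaLearningDataset.set_seed`` below
# consumes it as, e.g.:
#
#     np.random.seed(_generate_state(initial_seed, worker_id))
#
# Exact equivalence with ``np.random.SeedSequence`` is an assumption here; the tests only rely on
# the output being deterministic for a given ``(base_seed, worker_id)`` pair.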


def test_fast_forward_getattr():
    """Test that attribute lookups on a ``FastForwardSampler`` are forwarded to the wrapped sampler."""
    dataset = range(15)
    sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, 3, False)
    index_batch_sampler = FastForwardSampler(batch_sampler)

    assert index_batch_sampler.batch_size == 3
    assert index_batch_sampler.sampler == sampler


def test_fast_forward_on_batch_sampler():
    """
    This test ensures ``FastForwardSampler`` applied to ``BatchSampler`` retrieves
    the correct next batch on restart.
    """
    dataset = range(15)
    sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, 3, False)
    index_batch_sampler = FastForwardSampler(batch_sampler)

    assert isinstance(index_batch_sampler, Iterable)

    index_batch_sampler_iter = iter(index_batch_sampler)

    assert next(index_batch_sampler_iter) == [0, 1, 2]
    assert next(index_batch_sampler_iter) == [3, 4, 5]

    state_dict = index_batch_sampler.state_dict(2)

    index_batch_sampler = FastForwardSampler(batch_sampler)
    index_batch_sampler.load_state_dict(state_dict)

    index_batch_sampler_iter = iter(index_batch_sampler)
    assert next(index_batch_sampler_iter) == [6, 7, 8]


def test_fast_forward_on_sequential_sampler():
    """
    This test ensures ``FastForwardSampler`` applied to ``SequentialSampler`` retrieves
    the correct next batch on restart.
    """
    dataset = range(15)
    sequential_sampler = SequentialSampler(dataset)
    sampler = FastForwardSampler(sequential_sampler)
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)

    batch_sampler_iter = iter(batch_sampler)

    assert next(batch_sampler_iter) == [0, 1, 2]
    assert next(batch_sampler_iter) == [3, 4, 5]

    state_dict = sampler.state_dict(2)
    assert state_dict[0]["current_iteration"] == 6

    sampler.load_state_dict(state_dict)

    batch_sampler_iter = iter(batch_sampler)
    assert next(batch_sampler_iter) == [6, 7, 8]


@pytest.mark.skipif(torch.cuda.is_available(), reason="todo (tchaton) Need more investigation")
def test_fast_forward_on_random_sampler():
    """
    This test ensures ``FastForwardSampler`` applied to ``RandomSampler`` retrieves
    the correct next batch on restart.
    """
    seed = 42
    seed_everything(42)

    dataset = range(15)
    generator = torch.Generator().manual_seed(seed)
    values = list(RandomSampler(dataset, generator=generator))

    generator = torch.Generator().manual_seed(seed)
    random_sampler = RandomSampler(dataset, generator=generator)
    sampler = FastForwardSampler(random_sampler)
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)

    batch_sampler_iter = iter(batch_sampler)

    assert next(batch_sampler_iter) == values[:3]
    assert next(batch_sampler_iter) == values[3:6]
    assert next(batch_sampler_iter) == values[6:9]

    state_dict = sampler.state_dict(3)
    assert state_dict[0]["current_iteration"] == 9
    state_dict[0]["current_iteration"] = 6

    seed_everything(42)
    generator = torch.Generator().manual_seed(seed)
    random_sampler = RandomSampler(dataset, generator=generator)
    sampler = FastForwardSampler(random_sampler)
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)
    sampler.load_state_dict(state_dict)

    batch_sampler_iter = iter(batch_sampler)
    assert next(batch_sampler_iter) == values[6:9]
    has_raised = False
    try:
        for _ in range(5):
            next(batch_sampler_iter)
    except StopIteration:
        has_raised = True
        assert sampler._current_iteration == 0
        sampler.load_state_dict(sampler.state_dict(0))
    assert has_raised
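
# The three tests above all follow the same fast-forward recipe:
#   1. wrap the user sampler: ``sampler = FastForwardSampler(user_sampler)``, then ``sampler.setup(batch_size)``
#   2. consume a few batches and snapshot progress with ``sampler.state_dict(num_batches_processed)``
#   3. rebuild the sampler from scratch and call ``sampler.load_state_dict(...)``
#   4. iteration then resumes from the first index that was not yet consumed.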


class RangeIterableDataset(IterableDataset):
    def __init__(self, data, num_workers: int, batch_size: int, state_dict=None, attr_name: str = "iter_sampler"):
        self.data = list(data)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.state_dict = state_dict
        self.attr_name = attr_name

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info and self.num_workers == 2:
            id = worker_info.id
            num_samples = len(self.data)
            if id == 0:
                self.data = list(self.data)[: num_samples // 2]
            else:
                self.data = list(self.data)[num_samples // 2 :]
            self.user_sampler = RandomSampler(self.data)
        else:
            self.user_sampler = RandomSampler(self.data)

        setattr(self, self.attr_name, iter(self.user_sampler))
        return self

    def __next__(self):
        iter_sampler = getattr(self, self.attr_name)
        return self.data[next(iter_sampler)]
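
# NOTE: ``RangeIterableDataset`` deliberately exposes its sampler iterator as an attribute
# (``iter_sampler`` by default, configurable via ``attr_name``). The tests below assume that
# ``CaptureIterableDataset`` discovers that attribute, fast-forwards it on restart, and attaches
# each worker's sampler state to the batches under ``AutoRestartBatchKeys.PL_RESTART_META``,
# keyed by the attribute name.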


@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 30 sec and should be skipped in Azure CI")
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_fast_forward_sampler_over_iterative_dataset(num_workers):
    """
    This test ensures ``FastForwardSampler`` and ``CaptureIterableDataset`` are properly used
    to capture the workers' states.
    """
    batch_size = 3
    initial_seed = seed_everything(42)
    generator = torch.Generator()
    generator.manual_seed(initial_seed)
    dataset = RangeIterableDataset(range(20), num_workers, batch_size, True)
    dataset = CaptureIterableDataset(dataset)

    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator)
    _add_capture_metadata_collate(dataloader)

    iter_dataloader = iter(dataloader)
    batches = []
    for _ in range(5):
        batches.append(next(iter_dataloader))

    # restart from after the first 2 batches and fetch the 3 remaining ones
    state_dict = {"iter_sampler": {}}
    for batch in batches[:2]:
        batch, _state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
        for k, v in _state_dict[0].items():
            state_dict[k].update(v)

    assert len(state_dict["iter_sampler"]) == (num_workers if num_workers > 1 else 1)

    initial_seed = seed_everything(42)
    generator.manual_seed(initial_seed)
    dataset = RangeIterableDataset(range(20), num_workers, batch_size, state_dict=state_dict)
    dataset = CaptureIterableDataset(dataset)
    dataset.load_state_dict(state_dict)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator)
    _add_capture_metadata_collate(dataloader)

    iter_dataloader = iter(dataloader)
    batches_restart = []
    for _ in range(3):
        batches_restart.append(next(iter_dataloader))

    assert torch.equal(batches_restart[0]["data"], batches[2]["data"])
    assert torch.equal(batches_restart[1]["data"], batches[3]["data"])
    assert torch.equal(batches_restart[2]["data"], batches[4]["data"])


def _setup_ddp(rank, worldsize):
    os.environ["MASTER_ADDR"] = "localhost"
    # MASTER_PORT is expected to be set by the caller (see ``tutils.set_random_master_port()`` below).

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=worldsize)


def _test_fast_forward_sampler_with_distributed_sampler(rank, worldsize):
    _setup_ddp(rank, worldsize)

    initial_seed = seed_everything(42)

    generator = torch.Generator()
    generator.manual_seed(initial_seed)

    num_workers = 2
    batch_size = 4

    dataset = range(30)
    sampler = FastForwardSampler(DistributedSampler(dataset, num_replicas=worldsize, rank=rank, seed=initial_seed))
    sampler.setup(batch_size)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, generator=generator, sampler=sampler
    )

    iter_dataloader = iter(dataloader)

    num_yielded = 0
    batches = []
    while True:
        try:
            batches.append(next(iter_dataloader))
            num_yielded += 1
        except StopIteration:
            break

    expected = torch.tensor([17, 27, 24]) if rank == 0 else torch.tensor([19, 5, 3])
    assert torch.equal(batches[-1], expected)

    assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16

    reload_state_dict = sampler.state_dict(num_yielded - 1)
    assert reload_state_dict[0]["current_iteration"] == 12

    sampler = FastForwardSampler(DistributedSampler(dataset, num_replicas=worldsize, rank=rank, seed=initial_seed))
    sampler.setup(batch_size)
    sampler.load_state_dict(reload_state_dict)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, generator=generator, sampler=sampler
    )

    iter_dataloader = iter(dataloader)

    batches = []
    while True:
        try:
            batches.append(next(iter_dataloader))
        except StopIteration:
            break

    assert torch.equal(batches[-1], expected)
    assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16
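
# A note on the numbers asserted above: with 30 samples and 2 replicas, the ``DistributedSampler``
# assigns 15 indices per rank, which the ``DataLoader`` groups into 4 batches (4 + 4 + 4 + 3).
# ``FastForwardSampler.state_dict(num_yielded)`` appears to record ``current_iteration`` as
# ``num_yielded * batch_size`` (4 * 4 == 16), and ``state_dict(num_yielded - 1)`` rewinds one batch
# to 12, which is where the reloaded sampler resumes before replaying the final partial batch.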


@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 25 sec and should be skipped in Azure CI")
@RunIf(skip_windows=True)
def test_fast_forward_sampler_with_distributed_sampler():
    """Make sure ``FastForwardSampler`` can fast-forward a ``DistributedSampler`` under DDP."""
    tutils.set_random_master_port()
    worldsize = 2
    mp.spawn(_test_fast_forward_sampler_with_distributed_sampler, args=(worldsize,), nprocs=worldsize)


class MetaLearningDataset(IterableDataset):
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        task_num_classes: int = 5,
        num_workers: Optional[int] = None,
        global_rank: Optional[int] = None,
        world_size: Optional[int] = None,
        initial_seed: Optional[int] = None,
        shuffle: bool = True,
        debugging: bool = False,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_workers = num_workers or 1
        self.global_rank = global_rank
        self.world_size = world_size
        self.task_num_classes = task_num_classes
        self.labels = labels = getattr(dataset, "labels", None)
        self.initial_seed = initial_seed
        self.generator: Optional[torch.Generator] = None
        self.current_task_iteration = 0
        self.shuffle = shuffle
        self.debugging = debugging

        if labels is None:
            raise MisconfigurationException(f"The provided dataset {self.dataset} should have a ``labels`` attribute.")

        if len(labels) != len(dataset):
            raise MisconfigurationException("The length of the provided ``labels`` doesn't match the dataset length.")

        if (isinstance(global_rank, int) and world_size is None) or (
            isinstance(world_size, int) and global_rank is None
        ):
            raise MisconfigurationException("Both ``world_size`` and ``global_rank`` should be provided!")

        self.unique_labels = np.unique(self.labels)

    @property
    def worker_id(self) -> int:
        worker_info = get_worker_info()
        return worker_info.id if worker_info else 0

    @property
    def is_distributed(self) -> bool:
        return self.world_size is not None and self.world_size > 1

    def set_seed(self, shared: bool = False):
        initial_seed = self.initial_seed + self.current_task_iteration
        if shared:
            seed = initial_seed
            np_seed = _generate_state(initial_seed, 0)
        else:
            seed = initial_seed + self.worker_id + self.global_rank + self.current_task_iteration
            np_seed = _generate_state(initial_seed, self.worker_id + self.global_rank)

        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(np_seed)

    def sample_task_indices(self):
        self.set_seed(shared=True)
        self.selected_indexes = np.random.choice(self.unique_labels, self.task_num_classes, replace=False)
        self.selected_indexes.sort()

        # subset of indices from the entire dataset where the labels are actually among the
        # task_num_classes selected_indexes
        self.task_indices = np.arange(len(self.dataset))[np.in1d(self.labels, self.selected_indexes)]
        self.task_length = len(self.task_indices)
        self.set_seed(shared=False)

    @property
    def worker_rank(self) -> int:
        worker_id = self.worker_id
        is_global_zero = self.global_rank == 0
        return self.global_rank + worker_id + int(not is_global_zero)

    def create_sampler(self):
        data = range(self.task_length)
        if self.world_size == 1 and self.num_workers in (0, 1):
            if self.shuffle:
                self.sampler = RandomSampler(data, generator=self.generator)
            else:
                self.sampler = SequentialSampler(data)
        else:
            num_workers = 1 if self.num_workers in (None, 0) else self.num_workers
            num_replicas = num_workers * self.world_size
            current_seed = self.initial_seed + self.current_task_iteration
            self.sampler = DistributedSampler(
                data, num_replicas=num_replicas, rank=self.worker_rank, shuffle=self.shuffle, seed=current_seed
            )

    def __iter__(self):
        if self.generator is None:
            self.generator = torch.Generator().manual_seed(self.initial_seed)
        self.sample_task_indices()
        self.create_sampler()
        self.batch_sampler = BatchSampler(self.sampler, batch_size=self.batch_size, drop_last=self.drop_last)
        self.iter_sampler = iter(self.batch_sampler)
        self.is_first_batch = True
        self.current_task_iteration += 1
        return self

    def increment_iteration(self):
        self.current_task_iteration += 1

    def __next__(self):
        # this is optional, but useful to accumulate gradient over the entire task.
        is_first_batch = self.is_first_batch if self.debugging else (self.is_first_batch and self.worker_id == 0)
        if is_first_batch:
            self.is_first_batch = False
            return {"task_length": len(self.batch_sampler), "selected_indexes": self.selected_indexes}

        random_indices = next(self.iter_sampler)
        task_indices = [self.task_indices[idx] for idx in random_indices]
        return default_collate([self.dataset[idx] for idx in task_indices])
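
# How ``MetaLearningDataset`` samples: each call to ``__iter__`` starts a new "task" by drawing
# ``task_num_classes`` labels with a seed shared across workers and ranks (``set_seed(shared=True)``),
# so every process agrees on the selected classes. Each worker then iterates over its shard of the
# matching indices using a per-worker seed, and the very first item it yields is task metadata
# (``task_length`` / ``selected_indexes``) rather than data, which is why the batch-count assertions
# below add 2 when ``num_workers == 2``.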


class ClassificationDataset(Dataset):
    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels
        assert len(self.inputs) == len(self.labels)

    def __getitem__(self, index):
        return (self.inputs[index], self.labels[index])

    def __len__(self):
        return len(self.inputs)


def _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(rank, worldsize):
    if worldsize > 1:
        _setup_ddp(rank, worldsize)

    def all_gather(tensor, world_size):
        tensor_list = [torch.zeros_like(tensor, dtype=torch.int64) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, tensor)
        return tensor_list

    initial_seed = seed_everything(42)

    generator = torch.Generator()
    generator.manual_seed(initial_seed)

    num_workers = 2
    batch_size = 4
    dataset_length = 60
    num_classes = 10

    labels = np.random.randint(0, num_classes, dataset_length)

    dataset = ClassificationDataset(range(dataset_length), labels)
    dataset = MetaLearningDataset(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        num_workers=num_workers,
        global_rank=rank,
        world_size=worldsize,
        initial_seed=initial_seed,
        debugging=True,
        shuffle=True,
    )
    dataset = CaptureIterableDataset(dataset)
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator)
    _add_capture_metadata_collate(dataloader)

    epoch_results = []
    for _ in range(2):
        iter_dataloader = iter(dataloader)
        batches = []
        while True:
            try:
                batches.append(next(iter_dataloader))
            except StopIteration:
                break
        epoch_results.append(batches)
        dataloader.dataset.dataset.current_task_iteration += 1

    assert len(epoch_results) == 2

    assert len(epoch_results[0]) == math.ceil((dataset_length / (num_workers * worldsize)) / batch_size) + 2

    if worldsize == 1:
        assert epoch_results[0][0]["data"]["task_length"] == epoch_results[0][1]["data"]["task_length"]
        assert torch.equal(
            epoch_results[0][0]["data"]["selected_indexes"], epoch_results[0][1]["data"]["selected_indexes"]
        )
        assert 0 in epoch_results[0][2][AutoRestartBatchKeys.PL_RESTART_META]["iter_sampler"]  # worker id 0
        assert 1 in epoch_results[0][3][AutoRestartBatchKeys.PL_RESTART_META]["iter_sampler"]  # worker id 1
        assert not torch.equal(epoch_results[0][2]["data"][0], epoch_results[0][3]["data"][0])
    else:
        first_task_metadata = all_gather(epoch_results[0][0]["data"]["task_length"], worldsize)
        second_task_metadata = all_gather(epoch_results[0][1]["data"]["task_length"], worldsize)
        assert torch.equal(first_task_metadata[0], first_task_metadata[1])
        assert torch.equal(second_task_metadata[0], second_task_metadata[1])
        assert torch.equal(first_task_metadata[0], second_task_metadata[1])

        first_batch_list = all_gather(epoch_results[0][2]["data"][0], worldsize)
        assert not torch.equal(first_batch_list[0], first_batch_list[1])
        second_batch_list = all_gather(epoch_results[0][3]["data"][0], worldsize)
        assert not torch.equal(second_batch_list[0], second_batch_list[1])

    # restarting on epoch 0 / real batch 2
    state_dict = {"iter_sampler": {}}
    for batch in epoch_results[0][2:4]:
        batch, _state_dict = CaptureIterableDataset.extract_samplers_state_dict_from_batch(batch)
        for k, v in _state_dict[0].items():
            state_dict[k].update(v)

    dataset = ClassificationDataset(range(dataset_length), labels)
    dataset = MetaLearningDataset(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        num_workers=num_workers,
        global_rank=rank,
        world_size=worldsize,
        initial_seed=initial_seed,
        debugging=True,
        shuffle=True,
    )

    dataset = CaptureIterableDataset(dataset)
    dataset.load_state_dict(state_dict)
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator)
    _add_capture_metadata_collate(dataloader)

    epoch_results_restart = []
    for _ in range(2):
        iter_dataloader = iter(dataloader)
        batches = []
        while True:
            try:
                batches.append(next(iter_dataloader))
            except StopIteration:
                break
        epoch_results_restart.append(batches)
        dataloader.dataset.dataset.increment_iteration()
        dataloader.dataset.reset_on_epoch()

    assert len(epoch_results_restart[0]) + 2 == len(epoch_results[0])
    epoch_tensors = [e["data"][0] for e in epoch_results[0][4:]]
    epoch_tensors_restart = [e["data"][0] for e in epoch_results_restart[0][2:]]

    for t, tr in zip(epoch_tensors, epoch_tensors_restart):
        assert torch.equal(t, tr)

    epoch_tensors = [e["data"][0] for e in epoch_results[1][2:]]
    epoch_tensors_restart = [e["data"][0] for e in epoch_results_restart[1][2:]]

    for t, tr in zip(epoch_tensors, epoch_tensors_restart):
        assert torch.equal(t, tr)


@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 45 sec and should be skipped in Azure CI")
def test_fast_forward_sampler_iterative_dataset():
    _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(0, 1)


@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 55 sec and should be skipped in Azure CI")
@RunIf(skip_windows=True)
def test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset():
    """Make sure ``FastForwardSampler`` can restart an iterable dataset driven by a ``DistributedSampler`` under DDP."""
    tutils.set_random_master_port()
    worldsize = 2
    mp.spawn(
        _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset, args=(worldsize,), nprocs=worldsize
    )


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@RunIf(max_torch="1.7")
def test_fault_tolerant_not_supported():
    assert not _fault_tolerant_training()


def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", wrap: bool = True):
    dataset = RangeIterableDataset(range(50), num_workers=num_workers, batch_size=batch_size, attr_name=attr_name)
    if wrap:
        dataset = CaptureIterableDataset(dataset)
    return dataset
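
# Fault-tolerant training is toggled through the ``PL_FAULT_TOLERANT_TRAINING`` environment variable,
# which ``_fault_tolerant_training()`` checks (together with the torch version, judging from the
# ``max_torch`` guard above). The tests flip it with, for example:
#
#     with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}):
#         ...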


def test_dataloader_to_state_dict_and_reload():
    """
    Note: these utilities are only used with a ``DataLoader`` wrapping a map-style (``Dataset``) dataset.
    """

    def create_dataloader():
        dataset = range(50)
        batch_size = 8
        sampler = FastForwardSampler(SequentialSampler(dataset))
        sampler.setup(batch_size)

        return DataLoader(dataset, sampler=sampler, batch_size=batch_size)

    dataloader = create_dataloader()
    iter_dataloader = iter(dataloader)
    _ = next(iter_dataloader)
    _ = next(iter_dataloader)

    # 2 batches of 8 samples have been consumed, hence 16.
    state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader)
    assert state_dict[0]["current_iteration"] == 16

    dataloader = create_dataloader()
    dataloader = _dataloader_load_state_dict(dataloader, state_dict)
    iter_dataloader = iter(dataloader)
    _ = next(iter_dataloader)

    # one more batch after reloading the state: 16 + 8 == 24.
    state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader)
    assert state_dict[0]["current_iteration"] == 24


@RunIf(min_torch="1.7.0")
@pytest.mark.parametrize("use_fault_tolerant", ["0", "1"])
def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir):
    """
    This test ensures the datasets and samplers are properly wrapped when fault-tolerant training is enabled.
    """

    class CustomBatchSampler(BatchSampler):
        pass

    dataset = range(50)

    class TestModel(BoringModel):
        def train_dataloader(self):
            return {
                "a": [
                    DataLoader(create_iterable_dataset(3, 1, wrap=False), num_workers=0, batch_size=3),
                    DataLoader(dataset, batch_size=8),
                    DataLoader(
                        dataset,
                        batch_sampler=CustomBatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False),
                    ),
                ],
                "b": DataLoader(
                    create_iterable_dataset(2, num_workers=1, attr_name="custom_sampler", wrap=False),
                    num_workers=0,
                    batch_size=2,
                ),
            }

        def training_step(self, batch, batch_idx):
            pass

    class Check(Callback):
        def on_train_batch_start(self, trainer, *_) -> None:
            loaders = trainer.train_dataloader.loaders
            if use_fault_tolerant == "1":
                assert isinstance(loaders["a"][0].loader.dataset, CaptureIterableDataset)
                assert isinstance(loaders["a"][1].loader.sampler, FastForwardSampler)
                assert isinstance(loaders["a"][2].loader.batch_sampler, FastForwardSampler)
                assert isinstance(loaders["b"].loader.dataset, CaptureIterableDataset)
            else:
                assert isinstance(loaders["a"][0].loader.dataset, RangeIterableDataset)
                assert isinstance(loaders["a"][1].loader.sampler, SequentialSampler)
                assert isinstance(loaders["a"][2].loader.batch_sampler, CustomBatchSampler)
                assert isinstance(loaders["b"].loader.dataset, RangeIterableDataset)

    with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": use_fault_tolerant}):
        model = TestModel()
        model.training_epoch_end = None
        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, callbacks=Check())
        trainer.fit(model)