# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import random
import random as python_random
from collections import defaultdict
from collections.abc import Iterable
from contextlib import suppress
from copy import deepcopy
from typing import List, Optional
from unittest import mock
from unittest.mock import ANY

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler, SequentialSampler
from torch.utils.data._utils.worker import get_worker_info
from torch.utils.data.dataloader import DataLoader, default_collate
from torch.utils.data.dataset import Dataset, IterableDataset

import tests.helpers.utils as tutils
from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities.auto_restart import (
    _add_capture_metadata_collate,
    _dataloader_load_state_dict,
    _dataloader_to_state_dict,
    CaptureIterableDataset,
    CaptureMapDataset,
    FastForwardSampler,
    MergedIteratorState,
)
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
from pytorch_lightning.utilities.fetching import DataFetcher
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf


# Credit to PyTorch Team.
# Taken from:
# https://github.com/pytorch/pytorch/blob/3b977a0d2834d300c0301a0c6af98c8e939019ce/torch/utils/data/_utils/worker.py#L151
# Not available until torch 1.9.0
def _generate_state(base_seed, worker_id):
    INIT_A = 0x43B0D7E5
    MULT_A = 0x931E8875
    INIT_B = 0x8B51F9DD
    MULT_B = 0x58F38DED
    MIX_MULT_L = 0xCA01F9DD
    MIX_MULT_R = 0x4973F715
    XSHIFT = 4 * 8 // 2
    MASK32 = 0xFFFFFFFF

    entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
    pool = [0] * 4

    hash_const_A = INIT_A

    def hash(value):
        nonlocal hash_const_A
        value = (value ^ hash_const_A) & MASK32
        hash_const_A = (hash_const_A * MULT_A) & MASK32
        value = (value * hash_const_A) & MASK32
        value = (value ^ (value >> XSHIFT)) & MASK32
        return value

    def mix(x, y):
        result_x = (MIX_MULT_L * x) & MASK32
        result_y = (MIX_MULT_R * y) & MASK32
        result = (result_x - result_y) & MASK32
        result = (result ^ (result >> XSHIFT)) & MASK32
        return result

    # Add in the entropy to the pool.
    for i in range(len(pool)):
        pool[i] = hash(entropy[i])

    # Mix all bits together so late bits can affect earlier bits.
    for i_src in range(len(pool)):
        for i_dst in range(len(pool)):
            if i_src != i_dst:
                pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))

    hash_const_B = INIT_B
    state = []
    for i_dst in range(4):
        data_val = pool[i_dst]
        data_val = (data_val ^ hash_const_B) & MASK32
        hash_const_B = (hash_const_B * MULT_B) & MASK32
        data_val = (data_val * hash_const_B) & MASK32
        data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
        state.append(data_val)
    return state


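# NOTE: a usage sketch for `_generate_state` (explanatory note, based on how it is called below):
# it derives four 32-bit words from `(base_seed, worker_id)` in the same spirit as NumPy's
# `SeedSequence`, so each DataLoader worker can reseed NumPy deterministically, e.g.
#   np.random.seed(_generate_state(initial_seed, worker_id))
# as done in `MetaLearningDataset.set_seed` further down.

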
def test_fast_forward_getattr():
    dataset = range(15)
    sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, 3, False)
    index_batch_sampler = FastForwardSampler(batch_sampler)

    assert index_batch_sampler.batch_size == 3
    assert index_batch_sampler.sampler == sampler


def test_fast_forward_on_batch_sampler():
    """This test ensures ``FastForwardSampler`` applied to ``BatchSampler`` correctly retrieves the right next batch
    on restart."""
    dataset = range(15)
    sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, 3, False)
    index_batch_sampler = FastForwardSampler(batch_sampler)

    assert isinstance(index_batch_sampler, Iterable)

    index_batch_sampler_iter = iter(index_batch_sampler)

    assert next(index_batch_sampler_iter) == [0, 1, 2]
    assert next(index_batch_sampler_iter) == [3, 4, 5]

    state_dict = index_batch_sampler.state_dict(2)

    index_batch_sampler = FastForwardSampler(batch_sampler)
    index_batch_sampler.load_state_dict(state_dict)

    index_batch_sampler_iter = iter(index_batch_sampler)
    assert next(index_batch_sampler_iter) == [6, 7, 8]


def test_fast_forward_on_sequential_sampler():
    """This test ensures ``FastForwardSampler`` applied to ``SequentialSampler`` correctly retrieves the right next
    batch on restart."""
    dataset = range(15)
    sequential_sampler = SequentialSampler(dataset)
    sampler = FastForwardSampler(sequential_sampler)
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)

    batch_sampler_iter = iter(batch_sampler)

    assert next(batch_sampler_iter) == [0, 1, 2]
    assert next(batch_sampler_iter) == [3, 4, 5]

    state_dict = sampler.state_dict(2)
    assert state_dict[0]["current_iteration"] == 6

    sampler.load_state_dict(state_dict)

    batch_sampler_iter = iter(batch_sampler)
    assert next(batch_sampler_iter) == [6, 7, 8]


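# NOTE: in the two tests above, `state_dict(n)` receives the number of batches already fetched, and
# the stored `current_iteration` appears to be `n * batch_size` (2 * 3 == 6 here), which is the
# index the wrapped sampler fast-forwards to on the next `iter(...)` after `load_state_dict`.

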
@pytest.mark.skipif(torch.cuda.is_available(), reason="todo (tchaton) Need more investigation")
def test_fast_forward_on_random_sampler():
    """This test ensures ``FastForwardSampler`` applied to ``RandomSampler`` correctly retrieves the right next
    batch on restart."""
    seed = 42
    seed_everything(42)

    dataset = range(15)
    generator = torch.Generator().manual_seed(seed)
    values = list(RandomSampler(dataset, generator=generator))

    generator = torch.Generator().manual_seed(seed)
    random_sampler = RandomSampler(dataset, generator=generator)
    sampler = FastForwardSampler(random_sampler)
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)

    batch_sampler_iter = iter(batch_sampler)

    assert next(batch_sampler_iter) == values[:3]
    assert next(batch_sampler_iter) == values[3:6]
    assert next(batch_sampler_iter) == values[6:9]

    state_dict = sampler.state_dict(3)
    assert state_dict[0]["current_iteration"] == 9
    state_dict[0]["current_iteration"] = 6

    seed_everything(42)
    generator = torch.Generator().manual_seed(seed)
    random_sampler = RandomSampler(dataset, generator=generator)
    sampler = FastForwardSampler(random_sampler)
    sampler.setup(3)
    batch_sampler = BatchSampler(sampler, 3, False)
    sampler.load_state_dict(state_dict)

    batch_sampler_iter = iter(batch_sampler)
    assert next(batch_sampler_iter) == values[6:9]
    has_raised = False
    try:
        for _ in range(5):
            next(batch_sampler_iter)
    except StopIteration:
        has_raised = True
        assert sampler._current_iteration == 0
        sampler.load_state_dict(sampler.state_dict(0))
    assert has_raised


class RangeIterableDataset(IterableDataset):
    def __init__(self, data, num_workers: int, batch_size: int, state_dict=None, attr_name: str = "iter_sampler"):
        self.data = list(data)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.state_dict = state_dict
        self.attr_name = attr_name

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info and self.num_workers == 2:
            id = worker_info.id
            num_samples = len(self.data)
            if id == 0:
                self.data = list(self.data)[: num_samples // 2]
            else:
                self.data = list(self.data)[num_samples // 2 :]
            self.user_sampler = RandomSampler(self.data)
        else:
            self.user_sampler = RandomSampler(self.data)

        setattr(self, self.attr_name, iter(self.user_sampler))
        return self

    def __next__(self):
        iter_sampler = getattr(self, self.attr_name)
        return self.data[next(iter_sampler)]


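# NOTE: `RangeIterableDataset` deliberately keeps its random sampler iterator under a configurable
# attribute name (`attr_name`, "iter_sampler" by default) so that `CaptureIterableDataset` can
# discover and snapshot the per-worker sampler state; with 2 workers each worker iterates over
# half of the data.

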
@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 30 sec and should be skipped in Azure CI")
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_fast_forward_sampler_over_iterable_dataset(num_workers):
    """This test ensures ``FastForwardSampler`` and ``CaptureIterableDataset`` are properly used to capture
    worker states."""
    batch_size = 3
    initial_seed = seed_everything(42)
    generator = torch.Generator()
    generator.manual_seed(initial_seed)
    dataset = RangeIterableDataset(range(20), num_workers, batch_size, True)
    dataset = CaptureIterableDataset(dataset)

    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator)
    _add_capture_metadata_collate(dataloader)

    iter_dataloader = iter(dataloader)
    batches = []
    for _ in range(5):
        batches.append(next(iter_dataloader))

    # restarting on batch_1 and getting 3 extra batches

    state_dict = {"iter_sampler": {}}
    for batch in batches[:2]:
        batch, _state_dict = batch["data"], batch[AutoRestartBatchKeys.PL_RESTART_META]
        for k, v in _state_dict.items():
            state_dict[k].update(v)

    assert len(state_dict["iter_sampler"]) == (num_workers if num_workers > 1 else 1)

    initial_seed = seed_everything(42)
    generator.manual_seed(initial_seed)
    dataset = RangeIterableDataset(range(20), num_workers, batch_size, state_dict=state_dict)
    dataset = CaptureIterableDataset(dataset)
    dataset.load_state_dict(state_dict)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator)
    _add_capture_metadata_collate(dataloader)

    iter_dataloader = iter(dataloader)
    batches_restart = []
    for _ in range(3):
        batches_restart.append(next(iter_dataloader))

    assert torch.equal(batches_restart[0]["data"], batches[2]["data"])
    assert torch.equal(batches_restart[1]["data"], batches[3]["data"])
    assert torch.equal(batches_restart[2]["data"], batches[4]["data"])


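# NOTE: a rough sketch of what a batch produced with `_add_capture_metadata_collate` looks like in
# the test above (the shape is illustrative, taken from how the test indexes into the batch):
#   {
#       "data": <collated user batch>,
#       AutoRestartBatchKeys.PL_RESTART_META: {"iter_sampler": {<worker_id>: <sampler state>}},
#   }

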
def _setup_ddp(rank, worldsize):
    os.environ["MASTER_ADDR"] = "localhost"

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=worldsize)


def _test_fast_forward_sampler_with_distributed_sampler(rank, worldsize):
    _setup_ddp(rank, worldsize)

    initial_seed = seed_everything(42)

    generator = torch.Generator()
    generator.manual_seed(initial_seed)

    num_workers = 2
    batch_size = 4

    dataset = range(30)
    sampler = FastForwardSampler(DistributedSampler(dataset, num_replicas=worldsize, rank=rank, seed=initial_seed))
    sampler.setup(batch_size)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, generator=generator, sampler=sampler
    )

    iter_dataloader = iter(dataloader)

    num_yielded = 0
    batches = []
    while True:
        try:
            batches.append(next(iter_dataloader))
            num_yielded += 1
        except StopIteration:
            break

    expected = torch.tensor([17, 27, 24]) if rank == 0 else torch.tensor([19, 5, 3])
    assert torch.equal(batches[-1], expected)

    assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16

    reload_state_dict = sampler.state_dict(num_yielded - 1)
    assert reload_state_dict[0]["current_iteration"] == 12

    sampler = FastForwardSampler(DistributedSampler(dataset, num_replicas=worldsize, rank=rank, seed=initial_seed))
    sampler.setup(batch_size)
    sampler.load_state_dict(reload_state_dict)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, generator=generator, sampler=sampler
    )

    iter_dataloader = iter(dataloader)

    batches = []
    while True:
        try:
            batches.append(next(iter_dataloader))
        except StopIteration:
            break

    assert torch.equal(batches[-1], expected)
    assert sampler.state_dict(num_yielded)[0]["current_iteration"] == 16


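# NOTE: with 30 indices split across 2 ranks, each rank sees 15 indices, i.e. 4 batches of size 4
# (the last one holding 3 indices). `state_dict(num_yielded)` therefore reports
# `current_iteration == 16` (4 batches * batch_size), and rewinding by one batch gives 12, which is
# the state the helper above reloads from.

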
@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 25 sec and should be skipped in Azure CI")
@RunIf(skip_windows=True)
def test_fast_forward_sampler_with_distributed_sampler():
    """Make sure ``FastForwardSampler`` works with ``DistributedSampler`` under DDP."""
    tutils.set_random_main_port()
    worldsize = 2
    mp.spawn(_test_fast_forward_sampler_with_distributed_sampler, args=(worldsize,), nprocs=worldsize)


class MetaLearningDataset(IterableDataset):
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        task_num_classes: int = 5,
        num_workers: Optional[int] = None,
        global_rank: Optional[int] = None,
        world_size: Optional[int] = None,
        initial_seed: Optional[int] = None,
        shuffle: bool = True,
        debugging: bool = False,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_workers = num_workers or 1
        self.global_rank = global_rank
        self.world_size = world_size
        self.task_num_classes = task_num_classes
        self.labels = labels = getattr(dataset, "labels")
        self.initial_seed = initial_seed
        self.generator: Optional[torch.Generator] = None
        self.current_task_iteration = 0
        self.shuffle = shuffle
        self.debugging = debugging

        if labels is None:
            raise MisconfigurationException(f"Provided {self.dataset} should have an attribute labels.")

        if len(labels) != len(dataset):
            raise MisconfigurationException("The provided ``labels`` don't match the dataset length.")

        if (isinstance(global_rank, int) and world_size is None) or (
            isinstance(world_size, int) and global_rank is None
        ):
            raise MisconfigurationException("Both ``world_size`` and ``global_rank`` should be provided!")

        self.unique_labels = np.unique(self.labels)

    @property
    def worker_id(self) -> int:
        worker_info = get_worker_info()
        return worker_info.id if worker_info else 0

    @property
    def is_distributed(self) -> bool:
        return self.world_size is not None and self.world_size > 1

    def set_seed(self, shared: bool = False):
        initial_seed = self.initial_seed + self.current_task_iteration
        if shared:
            seed = initial_seed
            np_seed = _generate_state(initial_seed, 0)
        else:
            seed = initial_seed + self.worker_id + self.global_rank + self.current_task_iteration
            np_seed = _generate_state(initial_seed, self.worker_id + self.global_rank)

        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(np_seed)

    def sample_task_indices(self):
        self.set_seed(shared=True)
        self.selected_indexes = np.random.choice(self.unique_labels, self.task_num_classes, replace=False)
        self.selected_indexes.sort()

        # subset of indices from the entire dataset where the labels are actually among the
        # task_num_classes selected_indexes

        self.task_indices = np.arange(len(self.dataset))[np.in1d(self.labels, self.selected_indexes)]
        self.task_length = len(self.task_indices)
        self.set_seed(shared=False)

    @property
    def worker_rank(self) -> int:
        worker_id = self.worker_id
        is_global_zero = self.global_rank == 0
        return self.global_rank + worker_id + int(not is_global_zero)

    def create_sampler(self):
        data = range(self.task_length)
        if self.world_size == 1 and self.num_workers in (0, 1):
            if self.shuffle:
                self.sampler = RandomSampler(data, generator=self.generator)
            else:
                self.sampler = SequentialSampler(data)
        else:
            num_workers = 1 if self.num_workers in (None, 0) else self.num_workers
            num_replicas = num_workers * self.world_size
            current_seed = self.initial_seed + self.current_task_iteration
            self.sampler = DistributedSampler(
                data, num_replicas=num_replicas, rank=self.worker_rank, shuffle=self.shuffle, seed=current_seed
            )

    def __iter__(self):
        if self.generator is None:
            self.generator = torch.Generator().manual_seed(self.initial_seed)
        self.sample_task_indices()
        self.create_sampler()
        self.batch_sampler = BatchSampler(self.sampler, batch_size=self.batch_size, drop_last=self.drop_last)
        self.iter_sampler = iter(self.batch_sampler)
        self.is_first_batch = True
        self.current_task_iteration += 1
        return self

    def increment_iteration(self):
        self.current_task_iteration += 1

    def __next__(self):
        # this is optional, but useful to accumulate gradient over the entire task.
        is_first_batch = self.is_first_batch if self.debugging else (self.is_first_batch and self.worker_id == 0)
        if is_first_batch:
            self.is_first_batch = False
            return {"task_length": len(self.batch_sampler), "selected_indexes": self.selected_indexes}

        random_indices = next(self.iter_sampler)
        task_indices = [self.task_indices[idx] for idx in random_indices]
        return default_collate([self.dataset[idx] for idx in task_indices])


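# NOTE: high-level flow of `MetaLearningDataset` (summary of the code above): every `__iter__`
# call picks `task_num_classes` labels with a seed shared across workers/ranks, builds a
# (Distributed)Sampler over the matching subset of indices, and yields a metadata dict
# (`task_length`, `selected_indexes`) as the first "batch" before the actual task batches.

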
class ClassificationDataset(Dataset):
    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels
        assert len(self.inputs) == len(self.labels)

    def __getitem__(self, index):
        return (self.inputs[index], self.labels[index])

    def __len__(self):
        return len(self.inputs)


def _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(rank, worldsize):
    if worldsize > 1:
        _setup_ddp(rank, worldsize)

    def all_gather(tensor, world_size):
        tensor_list = [torch.zeros_like(tensor, dtype=torch.int64) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, tensor)
        return tensor_list

    initial_seed = seed_everything(42)

    generator = torch.Generator()
    generator.manual_seed(initial_seed)

    num_workers = 2
    batch_size = 4
    dataset_length = 60
    num_classes = 10

    labels = np.random.randint(0, num_classes, dataset_length)

    dataset = ClassificationDataset(range(dataset_length), labels)
    dataset = MetaLearningDataset(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        num_workers=num_workers,
        global_rank=rank,
        world_size=worldsize,
        initial_seed=initial_seed,
        debugging=True,
        shuffle=True,
    )
    dataset = CaptureIterableDataset(dataset)
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator)
    _add_capture_metadata_collate(dataloader)

    epoch_results = []
    for _ in range(2):
        iter_dataloader = iter(dataloader)
        batches = []
        while True:
            try:
                batches.append(next(iter_dataloader))
            except StopIteration:
                break
        epoch_results.append(batches)
        dataloader.dataset.dataset.current_task_iteration += 1

    assert len(epoch_results) == 2

    assert len(epoch_results[0]) == math.ceil((dataset_length / (num_workers * worldsize)) / batch_size) + 2

    if worldsize == 1:
        assert epoch_results[0][0]["data"]["task_length"] == epoch_results[0][1]["data"]["task_length"]
        assert torch.equal(
            epoch_results[0][0]["data"]["selected_indexes"], epoch_results[0][1]["data"]["selected_indexes"]
        )
        assert 0 in epoch_results[0][2][AutoRestartBatchKeys.PL_RESTART_META]["iter_sampler"]  # worker id 0
        assert 1 in epoch_results[0][3][AutoRestartBatchKeys.PL_RESTART_META]["iter_sampler"]  # worker id 1
        assert not torch.equal(epoch_results[0][2]["data"][0], epoch_results[0][3]["data"][0])
    else:
        first_task_metadata = all_gather(epoch_results[0][0]["data"]["task_length"], worldsize)
        second_task_metadata = all_gather(epoch_results[0][1]["data"]["task_length"], worldsize)
        assert torch.equal(first_task_metadata[0], first_task_metadata[1])
        assert torch.equal(second_task_metadata[0], second_task_metadata[1])
        assert torch.equal(first_task_metadata[0], second_task_metadata[1])

        first_batch_list = all_gather(epoch_results[0][2]["data"][0], worldsize)
        assert not torch.equal(first_batch_list[0], first_batch_list[1])
        second_batch_list = all_gather(epoch_results[0][3]["data"][0], worldsize)
        assert not torch.equal(second_batch_list[0], second_batch_list[1])

    # restarting on epoch 0 / real batch 2
    state_dict = {"iter_sampler": {}}
    for batch in epoch_results[0][2:4]:
        batch, _state_dict = batch["data"], batch[AutoRestartBatchKeys.PL_RESTART_META]
        for k, v in _state_dict.items():
            state_dict[k].update(v)

    dataset = ClassificationDataset(range(dataset_length), labels)
    dataset = MetaLearningDataset(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        num_workers=num_workers,
        global_rank=rank,
        world_size=worldsize,
        initial_seed=initial_seed,
        debugging=True,
        shuffle=True,
    )

    dataset = CaptureIterableDataset(dataset)
    dataset.load_state_dict(state_dict)
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator)
    _add_capture_metadata_collate(dataloader)

    epoch_results_restart = []
    for _ in range(2):
        iter_dataloader = iter(dataloader)
        batches = []
        while True:
            try:
                batches.append(next(iter_dataloader))
            except StopIteration:
                break
        epoch_results_restart.append(batches)
        dataloader.dataset.dataset.increment_iteration()
        dataloader.dataset.reset_on_epoch()

    assert len(epoch_results_restart[0]) + 2 == len(epoch_results[0])
    epoch_tensors = [e["data"][0] for e in epoch_results[0][4:]]
    epoch_tensors_restart = [e["data"][0] for e in epoch_results_restart[0][2:]]

    for t, tr in zip(epoch_tensors, epoch_tensors_restart):
        assert torch.equal(t, tr)

    epoch_tensors = [e["data"][0] for e in epoch_results[1][2:]]
    epoch_tensors_restart = [e["data"][0] for e in epoch_results_restart[1][2:]]

    for t, tr in zip(epoch_tensors, epoch_tensors_restart):
        assert torch.equal(t, tr)


@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 45 sec and should be skipped in Azure CI")
def test_fast_forward_sampler_iterative_dataset():
    _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(0, 1)


@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 55 sec and should be skipped in Azure CI")
@RunIf(skip_windows=True)
def test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset():
    """Make sure ``FastForwardSampler`` works with ``DistributedSampler`` and an iterable dataset under DDP."""
    tutils.set_random_main_port()
    worldsize = 2
    mp.spawn(
        _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset, args=(worldsize,), nprocs=worldsize
    )


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@RunIf(max_torch="1.7")
def test_fault_tolerant_not_supported():
    assert not _fault_tolerant_training()


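# NOTE: `_fault_tolerant_training()` requires both the `PL_FAULT_TOLERANT_TRAINING` flag and a
# recent enough torch; the test above pins `max_torch="1.7"` to check that the flag alone is not
# sufficient on older versions.

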
def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", wrap: bool = True):
    dataset = RangeIterableDataset(range(50), num_workers=num_workers, batch_size=batch_size, attr_name=attr_name)
    if wrap:
        dataset = CaptureIterableDataset(dataset)
    return dataset


def test_dataloader_to_state_dict_and_reload():
    """
    Note: These utilities are used only with a DataLoader wrapping a ``mapping`` based dataset.
    """

    def create_dataloader():
        dataset = range(50)
        batch_size = 8
        sampler = FastForwardSampler(SequentialSampler(dataset))
        sampler.setup(batch_size)

        return DataLoader(dataset, sampler=sampler, batch_size=batch_size)

    dataloader = create_dataloader()
    iter_dataloader = iter(dataloader)
    _ = next(iter_dataloader)
    _ = next(iter_dataloader)

    state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader)
    assert state_dict == {
        "num_workers": 0,
        "previous_worker": None,
        0: {"current_iteration": 16},
    }

    dataloader = create_dataloader()
    dataloader = _dataloader_load_state_dict(dataloader, state_dict)
    iter_dataloader = iter(dataloader)
    _ = next(iter_dataloader)

    state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader)
    assert state_dict == {
        "num_workers": 0,
        "previous_worker": None,
        0: {"current_iteration": 24},
    }


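# NOTE: arithmetic behind the asserts above: 2 fetched batches of size 8 leave the
# `FastForwardSampler` at `current_iteration == 16`; after `_dataloader_load_state_dict` and one
# more batch the counter moves to 24, i.e. the restored dataloader continues where the first one
# stopped.

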
@RunIf(min_torch="1.7.0")
@pytest.mark.parametrize("use_fault_tolerant", ["0", "1"])
def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir):
    """This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled."""

    class CustomBatchSampler(BatchSampler):
        pass

    dataset = range(50)

    class TestModel(BoringModel):
        def train_dataloader(self):
            return {
                "a": [
                    DataLoader(create_iterable_dataset(3, 1, wrap=False), num_workers=0, batch_size=3),
                    DataLoader(dataset, batch_size=8),
                    DataLoader(
                        dataset,
                        batch_sampler=CustomBatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False),
                    ),
                ],
                "b": DataLoader(
                    create_iterable_dataset(2, num_workers=1, attr_name="custom_sampler", wrap=False),
                    num_workers=0,
                    batch_size=2,
                ),
            }

        def training_step(self, batch, batch_idx):
            assert batch == {
                "a": [ANY, ANY, ANY],
                "b": ANY,
            }

        def validation_step(self, batch, batch_idx):
            assert isinstance(batch, torch.Tensor)

        validation_epoch_end = None

    class Check(Callback):
        def on_train_batch_start(self, trainer, *_) -> None:
            loaders = trainer.train_dataloader.loaders
            if use_fault_tolerant == "1":
                assert isinstance(loaders["a"][0].loader.dataset, CaptureIterableDataset)
                assert isinstance(loaders["a"][1].loader.sampler, FastForwardSampler)
                assert isinstance(loaders["a"][1].loader.dataset, CaptureMapDataset)
                assert isinstance(loaders["a"][2].loader.batch_sampler, FastForwardSampler)
                assert isinstance(loaders["a"][2].loader.dataset, CaptureMapDataset)
                assert isinstance(loaders["b"].loader.dataset, CaptureIterableDataset)
            else:
                assert isinstance(loaders["a"][0].loader.dataset, RangeIterableDataset)
                assert isinstance(loaders["a"][1].loader.sampler, SequentialSampler)
                assert not isinstance(loaders["a"][1].loader.dataset, CaptureMapDataset)
                assert isinstance(loaders["a"][2].loader.batch_sampler, CustomBatchSampler)
                assert not isinstance(loaders["a"][2].loader.dataset, CaptureMapDataset)
                assert isinstance(loaders["b"].loader.dataset, RangeIterableDataset)

    with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": use_fault_tolerant}):
        model = TestModel()
        model.training_epoch_end = None
        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, callbacks=Check())
        trainer.fit(model)


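# NOTE: summary of the wrapping checked above when `PL_FAULT_TOLERANT_TRAINING=1`: map-style
# datasets are wrapped in `CaptureMapDataset` and their (batch_)samplers in `FastForwardSampler`,
# while iterable datasets are wrapped in `CaptureIterableDataset`; with the flag off the original
# classes are left untouched.

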
class SequentialGetItemDataset(Dataset):
    def __init__(self, length, *_):
        self.len = length

    def __getitem__(self, index):
        return torch.tensor([index]).float()

    def __len__(self):
        return self.len


class RandomGetItemDataset(Dataset):
    """A dataset with random elements generated using global rng from torch, numpy and python."""

    def __init__(self, length, size):
        self.size = size
        self.len = length

    def __getitem__(self, index):
        t = torch.rand(self.size)
        n = torch.from_numpy(np.random.rand(self.size))
        p = torch.tensor([python_random.random() for _ in range(self.size)])
        sample = (index + (t + n + p) / 10).float()
        return sample

    def __len__(self):
        return self.len


# TODO: test with `RandomGeneratorGetItemDataset`
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@RunIf(min_torch="1.7.0")
@pytest.mark.parametrize(
    "dataset_class",
    [
        SequentialGetItemDataset,
        RandomGetItemDataset,
        # RandomGeneratorGetItemDataset,
    ],
)
@pytest.mark.parametrize("num_workers", [0])
@pytest.mark.parametrize("batch_size", [1, 2, 3])
def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size):
    """Test that the sequence of batches coming from a random number generator continues with the correct sequence
    after reloading the state."""

    def create_dataset_sampler():
        dset = CaptureMapDataset(dataset_class(16, 8))
        random_sampler = RandomSampler(dset, generator=torch.Generator())
        return dset, random_sampler

    def create_dataloader_sampler(dset, sampler):
        sampler = FastForwardSampler(sampler)
        sampler.setup(batch_size)
        dl = DataLoader(dset, num_workers=num_workers, sampler=sampler, batch_size=batch_size)
        _add_capture_metadata_collate(dl)
        return dl, sampler

    def fetch(fetcher, prefetch_iter, num_batches_fetched):
        batch, _ = next(prefetch_iter)

        state: List[MergedIteratorState] = fetcher.state
        assert len(state) == 1
        assert isinstance(state[0], MergedIteratorState)

        assert len(fetcher.dataloader_iter.cache_states) == 1
        if num_workers == 0:
            assert state[0].state[0].num_batches_fetched == num_batches_fetched
        return state

    dataset, random_sampler = create_dataset_sampler()
    dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler)

    fetcher = DataFetcher()
    fetcher.setup(dataloader)
    prefetch_iter = iter(fetcher)

    # fetch 4 batches
    fetch(fetcher, prefetch_iter, 1)
    fetch(fetcher, prefetch_iter, 2)
    fetch(fetcher, prefetch_iter, 3)

    # (A) capture the state after fetching 4 batches
    state = fetch(fetcher, prefetch_iter, 4)
    state = deepcopy(state[0])

    # (B) simulate 2 additional batches
    batch05, _ = next(prefetch_iter)
    batch06, _ = next(prefetch_iter)

    # start reloading
    dataset, random_sampler = create_dataset_sampler()
    dataloader, ff_sampler = create_dataloader_sampler(dataset, random_sampler)

    # load the state dict saved at (A)
    ff_sampler.load_state_dict(state.sampler_states)
    dataset.load_state_dict(state.dataset_states, latest_worker_id=state.latest_worker_id, num_workers=num_workers)

    prefetcher = DataFetcher()
    prefetcher.setup(dataloader)
    prefetch_iter = iter(prefetcher)

    # fetch 2 random batches, these should match exactly the batches seen at (B)
    batch05_restart, _ = next(prefetch_iter)
    batch06_restart, _ = next(prefetch_iter)

    assert torch.equal(batch05, batch05_restart)
    assert torch.equal(batch06, batch06_restart)


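# NOTE: as exercised above, a `MergedIteratorState` roughly bundles `sampler_states` (fed to
# `FastForwardSampler.load_state_dict`) and `dataset_states` plus `latest_worker_id` (fed to
# `CaptureMapDataset.load_state_dict`), which together are enough to replay the exact same random
# batches after a restart.

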
class CustomException(Exception):
    pass


class SequentialIterableDataset(IterableDataset):
    def __init__(self, length, *_):
        self.len = length
        self.sampler = SequentialSampler(range(self.len))

    def __iter__(self):
        self.sampler_iter = iter(self.sampler)
        return self

    def __next__(self):
        indices = next(self.sampler_iter)
        return torch.tensor([indices]).float()


class SequentialDictIterableDataset(SequentialIterableDataset):
    def __next__(self):
        indices = next(self.sampler_iter)
        return {"data": torch.tensor([indices]).float()}


class TestModel(LightningModule):
    def __init__(self, fail_on_step: int = -1):
        super().__init__()
        self.layer = torch.nn.Linear(1, 2)
        self.seen_batches = []
        self.fail_on_step = fail_on_step

    def training_step(self, batch, batch_idx):
        if self.global_step == self.fail_on_step:
            raise CustomException()
        batch = batch["data"] if isinstance(batch, dict) else batch
        self.seen_batches.append(torch.stack(batch) if isinstance(batch, list) else batch)
        loss = sum(self.layer(b).sum() for b in batch)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1, ckpt_path=None):
    seed_everything(1)
    train_dataloader = [
        DataLoader(dataset_class(3, 1), batch_size=1, num_workers=0) for dataset_class in dataset_classes
    ]
    train_dataloader = train_dataloader[0] if len(train_dataloader) == 1 else train_dataloader
    model = TestModel(fail_on_step=fail_on_step)
    trainer = Trainer(**trainer_kwargs)
    with suppress(CustomException):
        trainer.fit(model, train_dataloader=train_dataloader, ckpt_path=ckpt_path)
    return model.seen_batches, model.parameters()


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@RunIf(min_torch="1.7.0")
@pytest.mark.parametrize(
    "dataset_classes",
    [
        # single training dataset
        [RandomGetItemDataset],
        [SequentialIterableDataset],
        [SequentialDictIterableDataset],
        # multiple training datasets (combined dataloader)
        [SequentialGetItemDataset, SequentialIterableDataset],
        [SequentialIterableDataset, SequentialIterableDataset],
        # [RandomGetItemDataset, RandomGetItemDataset],  # TODO: support in the future
    ],
)
@pytest.mark.parametrize("multiple_trainloader_mode", ["min_size", "max_size_cycle"])
def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, multiple_trainloader_mode):
    """Test that the Trainer can resume from a failed run in the case of several types of datasets."""
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=3,
        enable_progress_bar=False,
        enable_model_summary=False,
        multiple_trainloader_mode=multiple_trainloader_mode,
    )

    all_batches, weights0 = _run_training(trainer_kwargs, dataset_classes)
    all_batches = torch.stack(all_batches)
    assert len(all_batches) == 9

    # Simulate 1st failure
    complete_batches, _ = _run_training(trainer_kwargs, dataset_classes, fail_on_step=4)
    assert len(complete_batches) == 4

    checkpoint_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")
    assert os.path.exists(checkpoint_path)

    # Resume after failure
    resumed_batches, weights1 = _run_training(
        trainer_kwargs, dataset_classes, fail_on_step=-1, ckpt_path=checkpoint_path
    )
    assert len(resumed_batches) == 5

    # the resumed batches should match the batches of the successful training
    all_batches_resumed = torch.stack(complete_batches + resumed_batches)
    assert len(all_batches_resumed) == 9
    assert torch.equal(all_batches, all_batches_resumed)

    # the final weights of a resumed training should equal the weights of an uninterrupted training
    for w0, w1 in zip(weights0, weights1):
        assert w0 is not w1
        assert torch.allclose(w0, w1)


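# NOTE: batch accounting for the test above: 3 epochs * 3 samples with batch_size 1 gives 9
# batches in an uninterrupted run; failing at `global_step == 4` leaves 4 completed batches, and
# the resumed run contributes the remaining 5, so the concatenation must match the original 9.

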
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@RunIf(min_torch="1.7.0")
@pytest.mark.parametrize(
    ["train_datasets", "val_datasets"],
    [
        ([RandomGetItemDataset], [RandomGetItemDataset]),
        ([RandomGetItemDataset], [RandomGetItemDataset, RandomGetItemDataset]),
    ],
)
@pytest.mark.parametrize(
    "val_check_interval",
    [
        pytest.param(
            0.5,
            marks=pytest.mark.xfail(
                reason=(
                    "TODO: the `train_dataloader` random state overrides the validation state when restarting training"
                )
            ),
        ),
        1.0,
    ],
)
def test_auto_restart_within_validation_loop(train_datasets, val_datasets, val_check_interval, tmpdir):
    n_val_dataloaders = len(val_datasets)
    stop_dataloader = n_val_dataloaders - 1
    stop_batch = 1

    class ValidationLoopTestModel(LightningModule):
        def __init__(self, should_fail):
            super().__init__()
            self.layer = torch.nn.Linear(1, 2)
            self.should_fail = should_fail
            self.training_batches = []
            self.validation_batches = defaultdict(list)

        def step(self, batch):
            return sum(self.layer(b).sum() for b in batch)

        def training_step(self, batch, batch_idx):
            self.training_batches.append(batch)
            return self.step(batch)

        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            if self.should_fail and stop_dataloader == dataloader_idx and batch_idx == stop_batch:
                raise CustomException
            self.validation_batches[dataloader_idx].append(batch)
            return self.step(batch)

        def configure_optimizers(self):
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)

        def train_dataloader(self):
            return [DataLoader(cls(4, 1)) for cls in train_datasets]

        def val_dataloader(self):
            return [DataLoader(cls(4, 1)) for cls in val_datasets]

    def run(should_fail, resume):
        if not resume:
            seed_everything(42)

        model = ValidationLoopTestModel(should_fail)

        ckpt_path = str(tmpdir / ".pl_auto_save.ckpt") if resume else None
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            val_check_interval=val_check_interval,
            num_sanity_val_steps=0,
        )
        if should_fail:
            with pytest.raises(CustomException):
                trainer.fit(model, ckpt_path=ckpt_path)
        else:
            trainer.fit(model, ckpt_path=ckpt_path)

        return model.training_batches, model.validation_batches

    total_train_batches, total_val_batches = run(should_fail=False, resume=False)
    pre_fail_train_batches, pre_fail_val_batches = run(should_fail=True, resume=False)
    post_fail_train_batches, post_fail_val_batches = run(should_fail=False, resume=True)

    torch.testing.assert_allclose(total_train_batches, pre_fail_train_batches + post_fail_train_batches)
    for k in total_val_batches:
        torch.testing.assert_allclose(total_val_batches[k], pre_fail_val_batches[k] + post_fail_val_batches[k])


class TestAutoRestartModelUnderSignal(BoringModel):
    def __init__(self, should_signal: bool, failure_on_step: bool, failure_on_training: bool, on_last_batch: bool):
        super().__init__()
        self.should_signal = should_signal
        self.failure_on_step = failure_on_step
        self.failure_on_training = failure_on_training
        self.on_last_batch = on_last_batch
        self.seen_train_batches = []

    def _signal(self):
        if self.should_signal:
            # simulate `os.kill(os.getpid(), signal.SIGUSR1)`
            self.trainer._terminate_gracefully = True

    def training_step(self, batch, batch_idx):
        self.seen_train_batches.append(batch)
        should_signal = self.trainer.fit_loop.epoch_loop._is_training_done if self.on_last_batch else batch_idx == 2
        if self.failure_on_step and self.failure_on_training and should_signal:
            self._signal()
        return super().training_step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        should_signal = (
            self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.is_last_batch
            if self.on_last_batch
            else batch_idx == 2
        )
        if self.failure_on_step and not self.failure_on_training and should_signal:
            self._signal()
        return super().validation_step(batch, batch_idx)

    def training_epoch_end(self, outputs) -> None:
        if not self.failure_on_step and self.failure_on_training:
            self._signal()

    def validation_epoch_end(self, outputs) -> None:
        if not self.failure_on_step and not self.failure_on_training:
            self._signal()

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 4))

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 4))


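# NOTE: `TestAutoRestartModelUnderSignal._signal` only flips `trainer._terminate_gracefully`
# (standing in for a real SIGUSR1); the trainer is then expected to raise
# `ExitGracefullyException` at the loop boundary named by `status` below and to drop a
# ".pl_auto_save.ckpt" checkpoint that the follow-up run restores from.

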
def _fit_model(
    tmpdir, should_signal, val_check_interval, failure_on_step, failure_on_training, on_last_batch, status=None
):
    seed_everything(42)
    model = TestAutoRestartModelUnderSignal(should_signal, failure_on_step, failure_on_training, on_last_batch)

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=4,
        limit_val_batches=4,
        val_check_interval=val_check_interval,
        num_sanity_val_steps=0,
    )

    trainer = Trainer(**trainer_kwargs)
    if should_signal:
        with pytest.raises(ExitGracefullyException, match=status):
            trainer.fit(model)
    else:
        trainer.fit(model)
    assert trainer._terminate_gracefully == should_signal

    return model


@pytest.mark.parametrize("on_last_batch", [False, True])
|
|
@pytest.mark.parametrize("val_check_interval", [0.5, 1.0])
|
|
@pytest.mark.parametrize("failure_on_training", [False, True])
|
|
@pytest.mark.parametrize("failure_on_step", [False, True])
|
|
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
|
|
@RunIf(min_torch="1.7.0", skip_windows=True)
|
|
def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on_training, failure_on_step, tmpdir):
|
|
"""This test asserts that if a signal is being sent during the training / validation phase, the model should
|
|
restart in a reproducible way."""
|
|
|
|
model_total = _fit_model(tmpdir, False, val_check_interval, failure_on_step, failure_on_training, on_last_batch)
|
|
|
|
if failure_on_step:
|
|
if on_last_batch:
|
|
if failure_on_training:
|
|
# Breaking on first validation batch.
|
|
# This is done to capture the random state of the validation dataloader.
|
|
status = "EvaluationEpochLoop:advance"
|
|
else:
|
|
# when breaking on last batch of validation, we should exist on `run_end` val_check_interval == 1.0
|
|
status = (
|
|
"TrainingEpochLoop:on_run_end" if val_check_interval == 1.0 else "TrainingEpochLoop:on_advance_end"
|
|
)
|
|
else:
|
|
status = "TrainingEpochLoop:on_advance_end" if failure_on_training else "EvaluationEpochLoop:advance"
|
|
else:
|
|
if val_check_interval == 1.0:
|
|
status = "TrainingEpochLoop:on_run_end"
|
|
else:
|
|
# `training_epoch_end` happens after `validation_epoch_end` since Lightning v1.4
|
|
status = "TrainingEpochLoop:on_run_end" if failure_on_training else "TrainingEpochLoop:on_advance_end"
|
|
|
|
model_signaled = _fit_model(
|
|
tmpdir, True, val_check_interval, failure_on_step, failure_on_training, on_last_batch, status=status
|
|
)
|
|
checkpoint_path = str(tmpdir / ".pl_auto_save.ckpt")
|
|
assert os.path.exists(checkpoint_path)
|
|
model_restarted = _fit_model(tmpdir, False, val_check_interval, failure_on_step, failure_on_training, on_last_batch)
|
|
|
|
# check the batches
|
|
actual = torch.cat(model_signaled.seen_train_batches + model_restarted.seen_train_batches)
|
|
expected = torch.cat(model_total.seen_train_batches)
|
|
assert torch.equal(actual, expected)
|
|
|
|
# FIXME: why `on_last_batch` doesn't work ?
|
|
if failure_on_step and failure_on_training and not on_last_batch:
|
|
assert not torch.equal(model_total.layer.weight, model_signaled.layer.weight)
|
|
assert torch.equal(model_restarted.layer.weight, model_total.layer.weight)
|
|
|
|
checkpoint = torch.load(checkpoint_path)["loops"]["fit_loop"]
|
|
p = checkpoint["epoch_loop.batch_progress"]
|
|
if p["is_last_batch"] and p["current"]["completed"] == 4:
|
|
assert "dataloader_state_dict" not in checkpoint["epoch_loop.state_dict"]
|
|
else:
|
|
assert "dataloader_state_dict" in checkpoint["epoch_loop.state_dict"]
|
|
|
|
state_dict = checkpoint["epoch_loop.val_loop.epoch_loop.state_dict"]
|
|
p = checkpoint["epoch_loop.val_loop.epoch_loop.batch_progress"]
|
|
if (p["is_last_batch"] and p["current"]["completed"] == 4) or p["current"]["ready"] == 0:
|
|
assert "dataloader_state_dict" not in state_dict
|
|
else:
|
|
assert "dataloader_state_dict" in state_dict
|