# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from time import time
from typing import Any, Iterator
from unittest import mock

import pytest
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset

from pytorch_lightning import Callback, LightningDataModule, Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.profilers import SimpleProfiler
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher
from pytorch_lightning.utilities.types import STEP_OUTPUT
from tests_pytorch.helpers.runif import RunIf


class IterDataset(IterableDataset):
    """Iterable dataset (no ``__len__``) that yields 1, 2, 3."""

    def __iter__(self):
        yield 1
        yield 2
        yield 3


class SizedDataset(Dataset):
    """Map-style dataset of length 3 whose items are 1, 2, 3."""

    def __len__(self):
        return 3

    def __getitem__(self, idx):
        return idx + 1


@pytest.mark.parametrize("use_combined_loader", [False, True])
@pytest.mark.parametrize("dataset_cls", [IterDataset, SizedDataset])
@pytest.mark.parametrize("prefetch_batches", list(range(5)))
def test_prefetch_iterator(use_combined_loader, dataset_cls, prefetch_batches):
    """Check ``DataFetcher``'s ``fetched``/``done`` bookkeeping for several prefetch depths.

    The fetcher is iterated twice to make sure re-iterating resets its state.
    """
    fetcher = DataFetcher(prefetch_batches=prefetch_batches)
    assert fetcher.prefetch_batches == prefetch_batches

    if use_combined_loader:
        loader = CombinedLoader([DataLoader(dataset_cls()), DataLoader(dataset_cls())])
    else:
        loader = DataLoader(dataset_cls())
    fetcher.setup(loader)

    def generate():
        # Snapshot the counters at the moment each batch is yielded.
        generated = [(fetcher.fetched, data, fetcher.done) for data in fetcher]
        assert fetcher.fetched == 3
        assert fetcher.done
        return generated

    # we can only know the last batch with sized iterables or when we prefetch
    is_last_batch = [False, False, prefetch_batches > 0 or dataset_cls is SizedDataset]
    # When the i-th (1-based) batch is yielded, min(i + prefetch_batches, 3) batches have been fetched.
    fetched = list(range(prefetch_batches + 1, 4))
    fetched += [3] * (3 - len(fetched))
    batches = [[1, 1], [2, 2], [3, 3]] if use_combined_loader else [1, 2, 3]
    expected = list(zip(fetched, batches, is_last_batch))
    assert len(expected) == 3

    assert generate() == expected
    # validate reset works properly.
    assert generate() == expected
    assert fetcher.fetched == 3


@pytest.mark.parametrize("use_combined_loader", [False, True])
def test_profiler_closing(use_combined_loader):
    """Tests if the profiler terminates upon raising a StopIteration on an iterable dataset."""

    class TestDataset(IterableDataset):
        def __init__(self):
            self.list = list(range(1))

        def __iter__(self):
            return iter(self.list)

    fetcher = DataFetcher()
    if use_combined_loader:
        loader = CombinedLoader([DataLoader(TestDataset()), DataLoader(TestDataset())])
    else:
        loader = DataLoader(TestDataset())
    fetcher.setup(loader)
    profiler = SimpleProfiler()
    # Route the fetcher's profiling hooks to a single "test" action so we can inspect it below.
    fetcher._start_profiler = lambda: profiler.start("test")
    fetcher._stop_profiler = lambda: profiler.stop("test")
    iter(fetcher)  # on epoch 0 start
    next(fetcher)  # raises StopIteration exception
    # No action may be left open after the StopIteration.
    assert not bool(profiler.current_actions)


class EmptyIterDataset(IterableDataset):
    """Iterable dataset that yields nothing."""

    def __iter__(self):
        return iter([])


class EmptySizedDataset(Dataset):
    """Map-style dataset of length 0."""

    def __len__(self):
        return 0


@pytest.mark.parametrize("dataset_cls", [EmptyIterDataset, EmptySizedDataset])
@pytest.mark.parametrize("prefetch_batches", list(range(2)))
def test_empty_prefetch_iterator(dataset_cls, prefetch_batches):
    """An empty loader yields no batches and leaves the fetcher ``done`` afterwards."""
    loader = DataLoader(dataset_cls())
    fetcher = DataFetcher(prefetch_batches=prefetch_batches)
    fetcher.setup(loader)

    assert not fetcher.done
    assert not list(fetcher)
    assert fetcher.done


def test_misconfiguration_error():
    """``loader_iters`` raises before ``iter()`` has been called on the fetcher and is available after."""
    fetcher = DataFetcher()
    loader = DataLoader(range(10))
    fetcher.setup(loader)
    assert fetcher.loaders == loader
    with pytest.raises(
        MisconfigurationException, match="The `dataloader_iter` isn't available outside the __iter__ context."
    ):
        fetcher.loader_iters
    iter(fetcher)
    assert fetcher.loader_iters


def get_cycles_per_ms() -> float:
    """Get 10 values and remove the 2 max and 2 min and return the avg.

    This is to avoid system disturbance that skew the results, e.g. the very first cuda call likely does a bunch of
    init, which takes much longer than subsequent calls.
    """

    def measure() -> float:
        """Measure and return approximate number of cycles per millisecond for `torch.cuda._sleep`

        Copied from: https://github.com/pytorch/pytorch/blob/v1.9.0/test/test_cuda.py#L81.
        """
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        torch.cuda._sleep(1000000)
        end.record()
        end.synchronize()
        cycles_per_ms = 1000000 / start.elapsed_time(end)
        return cycles_per_ms

    num = 10
    vals = sorted(measure() for _ in range(num))
    # Trimmed mean: drop the two smallest and the two largest samples.
    stats = vals[2 : num - 2]
    return sum(stats) / len(stats)


BATCH_SIZE = 32
DATASET_LEN = 64
EMB_SZ = 100  # number of rows in the embedding table
EMB_DIM = 64  # width of each embedding vector


class RandomIndicesDataset(Dataset):
    """16 items, each a vector of ``BATCH_SIZE`` random indices into the ``EMB_SZ``-row embedding table."""

    def __getitem__(self, index):
        # Indices must lie in [0, EMB_SZ) to be valid for `torch.nn.Embedding(EMB_SZ, ...)`.
        return torch.randint(EMB_SZ, [BATCH_SIZE])

    def __len__(self):
        return 16


class RecommenderModel(BoringModel):
    """Embedding-lookup model whose batch transfer and step-end hooks emulate expensive GPU work.

    Both ``on_after_batch_transfer`` and ``training_step_end`` spin the GPU for roughly 50 ms using
    ``torch.cuda._sleep``, calibrated by :func:`get_cycles_per_ms`. Requires CUDA.
    """

    def __init__(self):
        super().__init__()
        self.layer = None
        self.local_embedding = torch.nn.Embedding(EMB_SZ, EMB_DIM)
        self.CYCLES_PER_MS = int(get_cycles_per_ms())

    def forward(self, indices: torch.Tensor):
        result = self.local_embedding(indices)
        return result

    def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        # emulate heavy routine
        torch.cuda._sleep(self.CYCLES_PER_MS * 50)
        return batch

    def training_step_end(self, training_step_outputs):
        # emulate heavy routine
        torch.cuda._sleep(self.CYCLES_PER_MS * 50)
        return training_step_outputs

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomIndicesDataset(), batch_size=4)

    def val_dataloader(self):
        return DataLoader(RandomIndicesDataset(), batch_size=4)

    def test_dataloader(self):
        return DataLoader(RandomIndicesDataset(), batch_size=4)


@pytest.mark.flaky(reruns=3)
@pytest.mark.parametrize("accelerator", [pytest.param("cuda", marks=RunIf(min_cuda_gpus=1))])
def test_trainer_num_prefetch_batches(tmpdir, accelerator):
    """Compare a fit with ``PL_INTER_BATCH_PARALLELISM=1`` against a regular fit.

    The callback checks which fetcher each run used. The regular run must take more than 1.1x the wall-clock time of
    the inter-batch-parallel run.
    """
    model = RecommenderModel()

    class AssertFetcher(Callback):
        def __init__(self, check_inter_batch):
            self._check_inter_batch = check_inter_batch

        def on_train_epoch_end(self, trainer, lightning_module):
            fetcher = trainer.fit_loop._data_fetcher
            assert isinstance(fetcher, InterBatchParallelDataFetcher if self._check_inter_batch else DataFetcher)
            assert fetcher.prefetch_batches == int(self._check_inter_batch)

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        accelerator=accelerator,
        devices=1,
        limit_train_batches=4,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        enable_progress_bar=False,
    )

    trainer = Trainer(**trainer_kwargs, callbacks=AssertFetcher(check_inter_batch=True))

    with mock.patch.dict(os.environ, {"PL_INTER_BATCH_PARALLELISM": "1"}):
        t0 = time()
        trainer.fit(model)
        t1 = time()
        inter_batch_duration = t1 - t0
        global_step = trainer.global_step

    torch.cuda.synchronize()

    trainer = Trainer(**trainer_kwargs, callbacks=AssertFetcher(check_inter_batch=False))
    t2 = time()
    trainer.fit(model)
    t3 = time()
    regular_duration = t3 - t2

    assert global_step == trainer.global_step == 4
    ratio = regular_duration / inter_batch_duration
    assert ratio > 1.1, (regular_duration, inter_batch_duration, ratio)


@pytest.mark.parametrize("automatic_optimization", [False, True])
def test_fetching_dataloader_iter_opt(automatic_optimization, tmpdir):
    """A ``training_step`` that takes ``dataloader_iter`` and pulls two batches per step, for both optimization modes."""

    class TestModel(BoringModel):
        def __init__(self, *args, automatic_optimization: bool = False, **kwargs):
            super().__init__(*args, **kwargs)
            self.automatic_optimization = automatic_optimization
            self.count = 0
            self.batches = []

        def training_step(self, dataloader_iter, batch_idx):
            assert self.count == batch_idx
            assert isinstance(self.trainer.fit_loop._data_fetcher, DataLoaderIterDataFetcher)
            # fetch 2 batches
            self.batches.append(next(dataloader_iter))
            self.batches.append(next(dataloader_iter))

            batch = self.batches.pop(0)
            assert isinstance(batch, torch.Tensor) or batch is None
            self.count += 2
            if self.automatic_optimization:
                loss = super().training_step(batch, 0)
                # Logging without an explicit batch_size must fail when the step receives `dataloader_iter`.
                with pytest.raises(MisconfigurationException, match="dataloader_iter"):
                    self.log("train_loss", loss["loss"])
                self.log("train_loss", loss["loss"], batch_size=1)
            else:
                opt = self.optimizers()
                output = self(batch)
                loss = self.loss(batch, output)
                opt.zero_grad()
                loss.backward()
                opt.step()

        def training_epoch_end(self, *_):
            # Two batches are consumed per step, so 64 batches are fetched over 32 steps. `ready` is 33 here;
            # presumably the extra one is the step whose `next()` exhausts the iterator — TODO confirm.
            assert self.trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33
            assert self.trainer.fit_loop._data_fetcher.fetched == 64
            assert self.count == 64

    model = TestModel(automatic_optimization=automatic_optimization)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    trainer.fit(model)


@pytest.mark.parametrize("fn", ("validate", "test"))
def test_fetching_dataloader_iter_running_stages(fn, tmpdir):
    """Validation and test steps can take ``dataloader_iter``; ``fetched`` advances by one per ``next()``."""

    class TestModel(BoringModel):
        def fetch(self, data_fetcher, dataloader_iter, batch_idx):
            assert isinstance(data_fetcher, DataLoaderIterDataFetcher)
            assert data_fetcher.fetched == batch_idx
            batch = next(dataloader_iter)
            assert data_fetcher.fetched == batch_idx + 1
            return batch

        def validation_step(self, dataloader_iter, batch_idx):
            batch = self.fetch(self.trainer.validate_loop._data_fetcher, dataloader_iter, batch_idx)
            return super().validation_step(batch, batch_idx)

        def test_step(self, dataloader_iter, batch_idx):
            batch = self.fetch(self.trainer.test_loop._data_fetcher, dataloader_iter, batch_idx)
            return super().test_step(batch, batch_idx)

    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
    if fn == "validate":
        trainer.validate(model)
    elif fn == "test":
        trainer.test(model)


class DummyWaitable:
    """Stand-in for an async handle: ``wait()`` immediately returns the stored value."""

    def __init__(self, val: Any) -> None:
        self.val = val

    def wait(self) -> Any:
        return self.val


class AsyncBoringModel(BoringModel):
    """Manual-optimization model whose ``training_step`` consumes ``dataloader_iter`` one batch ahead.

    Each step kicks off the "async op" for the next batch before computing on the current one. It returns
    ``is_last=True`` once the iterator is exhausted.
    """

    def __init__(self) -> None:
        super().__init__()
        self.automatic_optimization = False
        self.batch_i_handle = None
        self.num_batches_processed = 0

    def _async_op(self, batch: Any) -> DummyWaitable:
        return DummyWaitable(val=batch)

    def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
        # First step only: there is no in-flight handle yet, so start one for batch i.
        if self.batch_i_handle is None:
            batch_i_raw = next(dataloader_iter)
            self.batch_i_handle = self._async_op(batch_i_raw)

        # Invariant: _async_op for batch[i] has been initiated
        batch_ip1_handle = None
        is_last = False
        try:
            batch_ip1_raw = next(dataloader_iter)
            batch_ip1_handle = self._async_op(batch_ip1_raw)
        except StopIteration:
            is_last = True

        batch_i = self.batch_i_handle.wait()

        pred = self.layer(batch_i)
        loss = self.loss(batch_i, pred)
        loss.backward()
        self.optimizers().step()
        self.optimizers().zero_grad()

        # Hand the already-started batch i+1 over to the next step.
        self.batch_i_handle = batch_ip1_handle
        self.num_batches_processed += 1

        return {"loss": loss, "is_last": is_last}

    def train_dataloader(self):
        return DataLoader(RandomDataset(BATCH_SIZE, DATASET_LEN))


def test_training_step_with_dataloader_access(tmpdir) -> None:
    """A baseline functional test for `training_step` with dataloader access."""
    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
    m = AsyncBoringModel()
    trainer.fit(m)
    assert m.num_batches_processed == DATASET_LEN, f"Expect all {DATASET_LEN} batches to be processed."
@pytest.mark.parametrize("trigger_stop_iteration", [False, True])
def test_stop_iteration(trigger_stop_iteration, tmpdir):
    """Verify that StopIteration properly terminates the training when this is triggered from the current
    `dataloader_iter`"""
    EXPECT_NUM_BATCHES_PROCESSED = 2

    class TestModel(AsyncBoringModel):
        def __init__(self, trigger_stop_iteration) -> None:
            super().__init__()
            self.trigger_stop_iteration = trigger_stop_iteration

        def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
            output = super().training_step(dataloader_iter)
            batch_idx = self.trainer.fit_loop.epoch_loop.batch_idx
            if self.trigger_stop_iteration and batch_idx == EXPECT_NUM_BATCHES_PROCESSED:
                raise StopIteration
            return output

        def train_dataloader(self):
            if self.trigger_stop_iteration:
                return DataLoader(RandomDataset(BATCH_SIZE, 2 * EXPECT_NUM_BATCHES_PROCESSED))
            return DataLoader(RandomDataset(BATCH_SIZE, EXPECT_NUM_BATCHES_PROCESSED))

    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
    m = TestModel(trigger_stop_iteration)
    trainer.fit(m)
    # In both cases the expected count equals the length of the dataset returned by `train_dataloader`.
    expected = EXPECT_NUM_BATCHES_PROCESSED
    if trigger_stop_iteration:
        expected *= 2
    assert m.num_batches_processed == expected


def test_on_train_batch_start_overridden(tmpdir) -> None:
    """Verify that a `MisconfigurationException` is raised when `on_train_batch_start` is overridden on a
    `LightningModule` whose `training_step` takes `dataloader_iter`."""

    class InvalidModel(AsyncBoringModel):
        def on_train_batch_start(self, batch, batch_idx):
            pass

    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
    m = InvalidModel()
    with pytest.raises(MisconfigurationException, match="The model hook `on_train_batch_start` is not compatible with"):
        trainer.fit(m)


def test_on_train_batch_end_overridden(tmpdir) -> None:
    """Verify that a `MisconfigurationException` is raised when `on_train_batch_end` is overridden on a
    `LightningModule` whose `training_step` takes `dataloader_iter`."""

    class InvalidModel(AsyncBoringModel):
        def on_train_batch_end(self, outputs, batch, batch_idx):
            pass

    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
    m = InvalidModel()
    with pytest.raises(MisconfigurationException, match="The model hook `on_train_batch_end` is not compatible with"):
        trainer.fit(m)


def test_tbptt_split_batch_overridden(tmpdir) -> None:
    """Verify that a `MisconfigurationException` is raised when `truncated_bptt_steps > 0` is set on a
    `LightningModule` whose `training_step` takes `dataloader_iter`."""

    class InvalidModel(AsyncBoringModel):
        def __init__(self) -> None:
            super().__init__()
            self.truncated_bptt_steps = 2

    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
    m = InvalidModel()
    with pytest.raises(MisconfigurationException, match="is incompatible with `truncated_bptt_steps > 0`."):
        trainer.fit(m)


def test_transfer_hooks_with_unpacking(tmpdir):
    """This test asserts the `transfer_batch` hooks are called only once per batch."""

    class RandomDictDataset(RandomDataset):
        def __getitem__(self, index):
            return {"x": self.data[index], "y_true": torch.ones((2,)), "other": torch.ones((1,))}

    class BoringDataModule(LightningDataModule):
        count_called_on_before_batch_transfer = 0
        count_called_transfer_batch_to_device = 0
        count_called_on_after_batch_transfer = 0

        def train_dataloader(self):
            return DataLoader(RandomDictDataset(32, 2))

        def val_dataloader(self):
            return DataLoader(RandomDictDataset(32, 2))

        def on_before_batch_transfer(self, batch, dataloader_idx: int):
            self.count_called_on_before_batch_transfer += 1
            # Unpack the dict into an (x, y_true) tuple; the model steps below rely on this shape.
            return batch["x"], batch["y_true"]

        def transfer_batch_to_device(self, *args, **kwargs):
            self.count_called_transfer_batch_to_device += 1
            return super().transfer_batch_to_device(*args, **kwargs)

        def on_after_batch_transfer(self, batch, dataloader_idx: int):
            self.count_called_on_after_batch_transfer += 1
            return super().on_after_batch_transfer(batch, dataloader_idx)

    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            x, _ = batch
            return super().training_step(x, batch_idx)

        def validation_step(self, batch, batch_idx):
            x, _ = batch
            return super().validation_step(x, batch_idx)

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, num_sanity_val_steps=0)
    dm = BoringDataModule()
    trainer.fit(TestModel(), datamodule=dm)
    # 2 training batches + 2 validation batches, each transferred exactly once.
    assert dm.count_called_on_before_batch_transfer == 4
    assert dm.count_called_transfer_batch_to_device == 4
    assert dm.count_called_on_after_batch_transfer == 4


@RunIf(skip_windows=True)  # TODO: all durations are 0 on Windows
def test_fetching_is_profiled():
    """Test that fetching is profiled."""

    def assert_fetch_profiled(profiler, key, expected_calls):
        """Assert `key` was recorded `expected_calls` times, each with a strictly positive duration."""
        assert key in profiler.recorded_durations
        durations = profiler.recorded_durations[key]
        assert len(durations) == expected_calls
        assert all(d > 0 for d in durations)

    class MyModel(BoringModel):
        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            return super().validation_step(batch, batch_idx)

        def val_dataloader(self):
            return [super().val_dataloader(), super().val_dataloader()]

        validation_epoch_end = None

    model = MyModel()
    fast_dev_run = 2
    trainer = Trainer(
        fast_dev_run=fast_dev_run,
        profiler="simple",
        enable_model_summary=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
    )
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model)

    profiler = trainer.profiler
    assert isinstance(profiler, SimpleProfiler)

    # validation
    for i in range(2):
        assert_fetch_profiled(profiler, f"[EvaluationEpochLoop].val_dataloader_idx_{i}_next", fast_dev_run)

    # training
    assert_fetch_profiled(profiler, "[TrainingEpochLoop].train_dataloader_next", fast_dev_run)

    # test (previously this re-checked the validation key, which `fit` had already populated)
    assert_fetch_profiled(profiler, "[EvaluationEpochLoop].test_dataloader_idx_0_next", fast_dev_run)

    # predict
    assert_fetch_profiled(profiler, "[PredictionEpochLoop].predict_dataloader_idx_0_next", fast_dev_run)

    # now test profiling when the dataloader_iter is polled manually
    class ManualPollingModel(BoringModel):
        def training_step(self, dataloader_iter):
            _ = next(dataloader_iter)
            batch = next(dataloader_iter)
            return super().training_step(batch, 0)

    model = ManualPollingModel()
    trainer = Trainer(
        fast_dev_run=1,
        profiler="simple",
        limit_val_batches=0,
        enable_model_summary=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
    )
    trainer.fit(model)

    profiler = trainer.profiler
    assert isinstance(profiler, SimpleProfiler)

    # 2 polls in training_step
    assert_fetch_profiled(profiler, "[TrainingEpochLoop].train_dataloader_next", 2)