From c5e1002fe4bc5f876d3f187999020ad8307e6051 Mon Sep 17 00:00:00 2001
From: Akash Kwatra
Date: Fri, 6 May 2022 02:48:12 -0700
Subject: [PATCH] Add profiling to dataloader `next()` (#12124)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 CHANGELOG.md                                |  3 +
 .../loops/epoch/evaluation_epoch_loop.py    | 15 ++++
 .../loops/epoch/prediction_epoch_loop.py    |  4 +-
 .../loops/epoch/training_epoch_loop.py      |  9 ++
 pytorch_lightning/utilities/fetching.py     | 17 +++-
 tests/utilities/test_fetching.py            | 86 +++++++++++++++++++
 6 files changed, 132 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 19f07010d4..ef9bcb4e02 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for `Trainer(deterministic="warn")` to warn instead of fail when a non-deterministic operation is encountered ([#12588](https://github.com/PyTorchLightning/pytorch-lightning/pull/12588))
 
 
+- Added profiling to the loops' dataloader `__next__` calls ([#12124](https://github.com/PyTorchLightning/pytorch-lightning/pull/12124))
+
+
 - Added `CollaborativeStrategy` ([#12842](https://github.com/PyTorchLightning/pytorch-lightning/pull/12842))
 
 
diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index 8c631bf23f..f6e49fa310 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -88,6 +88,21 @@ class EvaluationEpochLoop(Loop):
         # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
         data_fetcher.fetched += self.batch_progress.current.ready
 
+        stage = self.trainer.state.stage
+        assert stage is not None
+        stage = stage.dataloader_prefix
+        self._profiler_fetch_action = (
+            f"[{self.__class__.__name__}].{stage}_dataloader_idx_{kwargs.get('dataloader_idx', 0)}_next"
+        )
+        data_fetcher._start_profiler = self._on_before_fetch
+        data_fetcher._stop_profiler = self._on_after_fetch
+
+    def _on_before_fetch(self) -> None:
+        self.trainer.profiler.start(self._profiler_fetch_action)
+
+    def _on_after_fetch(self) -> None:
+        self.trainer.profiler.stop(self._profiler_fetch_action)
+
     def advance(  # type: ignore[override]
         self,
         data_fetcher: AbstractDataFetcher,
diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
index ca3555cd79..8833aeddd2 100644
--- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -89,7 +89,9 @@ class PredictionEpochLoop(Loop):
             num_dataloaders: the total number of dataloaders
             return_predictions: whether to return the obtained predictions
         """
-        batch_idx, batch = next(dataloader_iter)
+        action_name = f"[{self.__class__.__name__}].predict_dataloader_idx_{dataloader_idx}_next"
+        with self.trainer.profiler.profile(action_name):
+            batch_idx, batch = next(dataloader_iter)
         self._seen_batch_indices = self._get_batch_indices(dataloader_idx)
         # we need to truncate the list of batch indices due to prefetching in the dataloader and Lightning
         self._seen_batch_indices = self._seen_batch_indices[: (self.batch_progress.current.completed + 1)]
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index e059446bd4..4b15dafa10 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -154,6 +154,15 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
         # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
         data_fetcher.fetched += self.batch_progress.current.ready
 
+        data_fetcher._start_profiler = self._on_before_fetch
+        data_fetcher._stop_profiler = self._on_after_fetch
+
+    def _on_before_fetch(self) -> None:
+        self.trainer.profiler.start(f"[{self.__class__.__name__}].train_dataloader_next")
+
+    def _on_after_fetch(self) -> None:
+        self.trainer.profiler.stop(f"[{self.__class__.__name__}].train_dataloader_next")
+
     def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
         """Runs a single training batch.
 
diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py
index ea039fcb23..ff7e6080ba 100644
--- a/pytorch_lightning/utilities/fetching.py
+++ b/pytorch_lightning/utilities/fetching.py
@@ -35,6 +35,10 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 
 
+def _profile_nothing() -> None:
+    pass
+
+
 class AbstractDataFetcher(ABC):
 
     """This base class should be used to implement a fault tolerant ``DataFetcher``. It is required to override the
@@ -76,6 +80,8 @@
         self.dataloader_iter: Optional[Iterator] = None
         self.fetched: int = 0
         self.done: bool = False
+        self._start_profiler = _profile_nothing
+        self._stop_profiler = _profile_nothing
 
     def setup(self, dataloader: Iterable, **kwargs: Any) -> None:
         self._add_capture_metadata_collate(dataloader)
@@ -225,8 +231,12 @@
         if batch_to_device is not None:
             self.batch_to_device = batch_to_device
 
+    def on_fetch_start(self) -> Any:
+        self._start_profiler()
+
     def on_fetch_end(self, batch: Any, start_output: Any) -> None:
         """Hook to extend which handles the logic after fetching a batch."""
+        self._stop_profiler()
         self.batches.append(batch)
 
     def prefetching(self) -> None:
@@ -320,9 +330,12 @@
     def on_fetch_start(self) -> "torch.cuda.Event":
         # create a cuda event used to record the async stream of data to device.
-        return torch.cuda.Event()
+        event = torch.cuda.Event()
+        self._start_profiler()
+        return event
 
     def on_fetch_end(self, batch: Any, event: torch.cuda.Event) -> None:
+        self._stop_profiler()
         self.batches.append(batch)
         event.record()
         self.events.append(event)
 
@@ -344,7 +357,9 @@ class StepFuncDataLoaderIter(Iterator):
 
     def __next__(self) -> Any:
         try:
+            self.data_fetcher._start_profiler()
             data = next(self.iterator)
+            self.data_fetcher._stop_profiler()
             self.data_fetcher.fetched += 1
             return data
         except StopIteration as e:
diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py
index 843462ad22..b536a61036 100644
--- a/tests/utilities/test_fetching.py
+++ b/tests/utilities/test_fetching.py
@@ -21,6 +21,7 @@ import torch
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
 from pytorch_lightning import Callback, LightningDataModule, Trainer
+from pytorch_lightning.profiler import SimpleProfiler
 from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher
@@ -486,3 +487,88 @@ def test_transfer_hooks_with_unpacking(tmpdir):
     assert dm.count_called_on_before_batch_transfer == 4
     assert dm.count_called_transfer_batch_to_device == 4
     assert dm.count_called_on_after_batch_transfer == 4
+
+
+@RunIf(skip_windows=True)  # TODO: all durations are 0 on Windows
+def test_fetching_is_profiled():
+    """Test that fetching is profiled."""
+
+    class MyModel(BoringModel):
+        def validation_step(self, batch, batch_idx, dataloader_idx=0):
+            return super().validation_step(batch, batch_idx)
+
+        def val_dataloader(self):
+            return [super().val_dataloader(), super().val_dataloader()]
+
+        validation_epoch_end = None
+
+    model = MyModel()
+    fast_dev_run = 2
+    trainer = Trainer(
+        fast_dev_run=fast_dev_run,
+        profiler="simple",
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        logger=False,
+    )
+    trainer.fit(model)
+    trainer.test(model)
+    trainer.predict(model)
+
+    profiler = trainer.profiler
+    assert isinstance(profiler, SimpleProfiler)
+
+    # validation
+    for i in range(2):
+        key = f"[EvaluationEpochLoop].val_dataloader_idx_{i}_next"
+        assert key in profiler.recorded_durations
+        durations = profiler.recorded_durations[key]
+        assert len(durations) == fast_dev_run
+        assert all(d > 0 for d in durations)
+    # training
+    key = "[TrainingEpochLoop].train_dataloader_next"
+    assert key in profiler.recorded_durations
+    durations = profiler.recorded_durations[key]
+    assert len(durations) == fast_dev_run
+    assert all(d > 0 for d in durations)
+    # test
+    key = "[EvaluationEpochLoop].val_dataloader_idx_0_next"
+    assert key in profiler.recorded_durations
+    durations = profiler.recorded_durations[key]
+    assert len(durations) == fast_dev_run
+    assert all(d > 0 for d in durations)
+    # predict
+    key = "[PredictionEpochLoop].predict_dataloader_idx_0_next"
+    assert key in profiler.recorded_durations
+    durations = profiler.recorded_durations[key]
+    assert len(durations) == fast_dev_run
+    assert all(d > 0 for d in durations)
+
+    # now test profiling when the dataloader_iter is polled manually
+    class MyModel(BoringModel):
+        def training_step(self, dataloader_iter):
+            _ = next(dataloader_iter)
+            batch = next(dataloader_iter)
+            return super().training_step(batch, 0)
+
+    model = MyModel()
+    trainer = Trainer(
+        fast_dev_run=1,
+        profiler="simple",
+        limit_val_batches=0,
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        logger=False,
+    )
+    trainer.fit(model)
+
+    profiler = trainer.profiler
+    assert isinstance(profiler, SimpleProfiler)
+
+    key = "[TrainingEpochLoop].train_dataloader_next"
+    assert key in profiler.recorded_durations
+    durations = profiler.recorded_durations[key]
+    assert len(durations) == 2  # 2 polls in training_step
+    assert all(d > 0 for d in durations)
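
Note (not part of the patch itself): a minimal usage sketch of where the new profiler actions surface. With `profiler="simple"`, the time spent in each dataloader `next()` call is recorded under per-loop action names such as "[TrainingEpochLoop].train_dataloader_next". The `RandomDataset` and `TinyModel` classes below are illustrative stand-ins, not code from the repository.

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    # tiny synthetic dataset so the sketch is self-contained
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(4)


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=2)


trainer = Trainer(fast_dev_run=2, profiler="simple", logger=False, enable_checkpointing=False)
trainer.fit(TinyModel())
# the action name added by this patch; each entry is one `next()` call on the train dataloader
print(trainer.profiler.recorded_durations["[TrainingEpochLoop].train_dataloader_next"])

The evaluation and prediction loops follow the same pattern, using the `val`/`test`/`predict` prefixes and the dataloader index, as exercised by the keys asserted in the test above.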