Add profiling to dataloader `next()` (#12124)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Akash Kwatra authored on 2022-05-06 02:48:12 -07:00 (committed by GitHub)
parent 7ce948edb6
commit c5e1002fe4
6 changed files with 132 additions and 2 deletions


@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `Trainer(deterministic="warn")` to warn instead of fail when a non-deterministic operation is encountered ([#12588](https://github.com/PyTorchLightning/pytorch-lightning/pull/12588))
- Added profiling to the loops' dataloader `__next__` calls ([#12124](https://github.com/PyTorchLightning/pytorch-lightning/pull/12124))
- Added `CollaborativeStrategy` ([#12842](https://github.com/PyTorchLightning/pytorch-lightning/pull/12842))
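
For context on how the new timings surface, a minimal usage sketch (assuming `model` is any `LightningModule` with dataloaders configured; it is not part of this diff):

from pytorch_lightning import Trainer

trainer = Trainer(fast_dev_run=2, profiler="simple", logger=False)
trainer.fit(model)  # `model` assumed to be a LightningModule

# each dataloader __next__ call is now recorded under a per-loop action name
print(trainer.profiler.recorded_durations["[TrainingEpochLoop].train_dataloader_next"])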


@@ -88,6 +88,21 @@ class EvaluationEpochLoop(Loop):
        # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
        data_fetcher.fetched += self.batch_progress.current.ready

        stage = self.trainer.state.stage
        assert stage is not None
        stage = stage.dataloader_prefix
        self._profiler_fetch_action = (
            f"[{self.__class__.__name__}].{stage}_dataloader_idx_{kwargs.get('dataloader_idx', 0)}_next"
        )
        data_fetcher._start_profiler = self._on_before_fetch
        data_fetcher._stop_profiler = self._on_after_fetch

    def _on_before_fetch(self) -> None:
        self.trainer.profiler.start(self._profiler_fetch_action)

    def _on_after_fetch(self) -> None:
        self.trainer.profiler.stop(self._profiler_fetch_action)

    def advance(  # type: ignore[override]
        self,
        data_fetcher: AbstractDataFetcher,
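
The loop injects `_on_before_fetch`/`_on_after_fetch` into the fetcher rather than wrapping `next()` at the call site because, with prefetching, the actual pull from the dataloader iterator happens inside the fetcher. A minimal standalone sketch of this callback pattern (not Lightning's actual classes; `TinyFetcher` and `_noop` are made up for illustration):

from typing import Any, Callable, Iterator


def _noop() -> None:
    pass


class TinyFetcher:
    """Illustrative only: calls swappable hooks around every pull from the iterator."""

    def __init__(self, iterator: Iterator) -> None:
        self.iterator = iterator
        self._start_profiler: Callable[[], None] = _noop
        self._stop_profiler: Callable[[], None] = _noop

    def __iter__(self) -> "TinyFetcher":
        return self

    def __next__(self) -> Any:
        self._start_profiler()
        try:
            return next(self.iterator)
        finally:
            self._stop_profiler()

A loop (or any other owner) attaches timing by assigning its own callables to `_start_profiler`/`_stop_profiler`, which is exactly what `_on_before_fetch`/`_on_after_fetch` do above.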


@@ -89,7 +89,9 @@ class PredictionEpochLoop(Loop):
            num_dataloaders: the total number of dataloaders
            return_predictions: whether to return the obtained predictions
        """
        # previously a bare call: batch_idx, batch = next(dataloader_iter)
        action_name = f"[{self.__class__.__name__}].predict_dataloader_idx_{dataloader_idx}_next"
        with self.trainer.profiler.profile(action_name):
            batch_idx, batch = next(dataloader_iter)

        self._seen_batch_indices = self._get_batch_indices(dataloader_idx)
        # we need to truncate the list of batch indices due to prefetching in the dataloader and Lightning
        self._seen_batch_indices = self._seen_batch_indices[: (self.batch_progress.current.completed + 1)]
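
Unlike the evaluation and training loops, the prediction loop polls `dataloader_iter` itself, so it wraps the call in the profiler's `profile()` context manager rather than fetcher hooks. A small sketch of that form (assuming a running `trainer` and some iterator `it`, both made up here):

action = "[PredictionEpochLoop].predict_dataloader_idx_0_next"
with trainer.profiler.profile(action):  # records wall time for the block, like paired start()/stop()
    batch_idx, batch = next(it)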


@@ -154,6 +154,15 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
        # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
        data_fetcher.fetched += self.batch_progress.current.ready
        data_fetcher._start_profiler = self._on_before_fetch
        data_fetcher._stop_profiler = self._on_after_fetch

    def _on_before_fetch(self) -> None:
        self.trainer.profiler.start(f"[{self.__class__.__name__}].train_dataloader_next")

    def _on_after_fetch(self) -> None:
        self.trainer.profiler.stop(f"[{self.__class__.__name__}].train_dataloader_next")

    def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
        """Runs a single training batch.


@@ -35,6 +35,10 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training


def _profile_nothing() -> None:
    pass


class AbstractDataFetcher(ABC):
    """This base class should be used to implement a fault tolerant ``DataFetcher``. It is required to override the
@@ -76,6 +80,8 @@ class AbstractDataFetcher(ABC):
        self.dataloader_iter: Optional[Iterator] = None
        self.fetched: int = 0
        self.done: bool = False
        self._start_profiler = _profile_nothing
        self._stop_profiler = _profile_nothing

    def setup(self, dataloader: Iterable, **kwargs: Any) -> None:
        self._add_capture_metadata_collate(dataloader)
@@ -225,8 +231,12 @@ class DataFetcher(AbstractDataFetcher):
        if batch_to_device is not None:
            self.batch_to_device = batch_to_device

    def on_fetch_start(self) -> Any:
        self._start_profiler()

    def on_fetch_end(self, batch: Any, start_output: Any) -> None:
        """Hook to extend which handles the logic after fetching a batch."""
        self._stop_profiler()
        self.batches.append(batch)

    def prefetching(self) -> None:
@@ -320,9 +330,12 @@ class InterBatchParallelDataFetcher(DataFetcher):
    def on_fetch_start(self) -> "torch.cuda.Event":
        # create a cuda event used to record the async stream of data to device.
        # previously a bare: return torch.cuda.Event()
        event = torch.cuda.Event()
        self._start_profiler()
        return event

    def on_fetch_end(self, batch: Any, event: torch.cuda.Event) -> None:
        self._stop_profiler()
        self.batches.append(batch)
        event.record()
        self.events.append(event)
@@ -344,7 +357,9 @@ class StepFuncDataLoaderIter(Iterator):
    def __next__(self) -> Any:
        try:
            self.data_fetcher._start_profiler()
            data = next(self.iterator)
            self.data_fetcher._stop_profiler()
            self.data_fetcher.fetched += 1
            return data
        except StopIteration as e:
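
The `_profile_nothing` defaults keep fetchers usable with no profiler attached; once a loop wires in its hooks, each poll boils down to a `start()`/`stop()` pair on the trainer's profiler. A standalone sketch of that pairing (hypothetical wiring; `iterator` is assumed):

from pytorch_lightning.profiler import SimpleProfiler

profiler = SimpleProfiler()
action = "[TrainingEpochLoop].train_dataloader_next"

profiler.start(action)       # what the injected _start_profiler hook calls
batch = next(iterator)       # the fetcher's actual pull from the dataloader iterator
profiler.stop(action)        # what the injected _stop_profiler hook calls

print(profiler.recorded_durations[action])  # one wall-clock duration per pull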


@@ -21,6 +21,7 @@ import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset
from pytorch_lightning import Callback, LightningDataModule, Trainer
from pytorch_lightning.profiler import SimpleProfiler
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher
@@ -486,3 +487,88 @@ def test_transfer_hooks_with_unpacking(tmpdir):
    assert dm.count_called_on_before_batch_transfer == 4
    assert dm.count_called_transfer_batch_to_device == 4
    assert dm.count_called_on_after_batch_transfer == 4


@RunIf(skip_windows=True)  # TODO: all durations are 0 on Windows
def test_fetching_is_profiled():
    """Test that fetching is profiled."""

    class MyModel(BoringModel):
        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            return super().validation_step(batch, batch_idx)

        def val_dataloader(self):
            return [super().val_dataloader(), super().val_dataloader()]

        validation_epoch_end = None
    model = MyModel()
    fast_dev_run = 2
    trainer = Trainer(
        fast_dev_run=fast_dev_run,
        profiler="simple",
        enable_model_summary=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
    )
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model)

    profiler = trainer.profiler
    assert isinstance(profiler, SimpleProfiler)
    # validation
    for i in range(2):
        key = f"[EvaluationEpochLoop].val_dataloader_idx_{i}_next"
        assert key in profiler.recorded_durations
        durations = profiler.recorded_durations[key]
        assert len(durations) == fast_dev_run
        assert all(d > 0 for d in durations)

    # training
    key = "[TrainingEpochLoop].train_dataloader_next"
    assert key in profiler.recorded_durations
    durations = profiler.recorded_durations[key]
    assert len(durations) == fast_dev_run
    assert all(d > 0 for d in durations)
    # test
    key = "[EvaluationEpochLoop].test_dataloader_idx_0_next"
    assert key in profiler.recorded_durations
    durations = profiler.recorded_durations[key]
    assert len(durations) == fast_dev_run
    assert all(d > 0 for d in durations)

    # predict
    key = "[PredictionEpochLoop].predict_dataloader_idx_0_next"
    assert key in profiler.recorded_durations
    durations = profiler.recorded_durations[key]
    assert len(durations) == fast_dev_run
    assert all(d > 0 for d in durations)
    # now test profiling when the dataloader_iter is polled manually
    class MyModel(BoringModel):
        def training_step(self, dataloader_iter):
            _ = next(dataloader_iter)
            batch = next(dataloader_iter)
            return super().training_step(batch, 0)

    model = MyModel()
    trainer = Trainer(
        fast_dev_run=1,
        profiler="simple",
        limit_val_batches=0,
        enable_model_summary=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
    )
    trainer.fit(model)

    profiler = trainer.profiler
    assert isinstance(profiler, SimpleProfiler)

    key = "[TrainingEpochLoop].train_dataloader_next"
    assert key in profiler.recorded_durations
    durations = profiler.recorded_durations[key]
    assert len(durations) == 2  # 2 polls in training_step
    assert all(d > 0 for d in durations)