Add profiling to dataloader `next()` (#12124)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
commit c5e1002fe4 (parent 7ce948edb6)
@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `Trainer(deterministic="warn")` to warn instead of fail when a non-deterministic operation is encountered ([#12588](https://github.com/PyTorchLightning/pytorch-lightning/pull/12588))

- Added profiling to the loops' dataloader `__next__` calls ([#12124](https://github.com/PyTorchLightning/pytorch-lightning/pull/12124))

- Added `CollaborativeStrategy` ([#12842](https://github.com/PyTorchLightning/pytorch-lightning/pull/12842))

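The new fetch timings appear under whichever profiler is attached to the `Trainer`. As a rough usage sketch (hypothetical `TinyModel`/`RandomDataset`, not part of this diff), enabling the simple profiler is enough to surface the new `*_dataloader_*_next` actions in the report:

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(32)


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=2)


# With a profiler enabled, the dataloader `next()` timings added by this PR are
# recorded under actions such as "[TrainingEpochLoop].train_dataloader_next".
trainer = Trainer(fast_dev_run=2, profiler="simple", logger=False, enable_checkpointing=False)
trainer.fit(TinyModel())
print(trainer.profiler.summary())
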
@@ -88,6 +88,21 @@ class EvaluationEpochLoop(Loop):
        # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
        data_fetcher.fetched += self.batch_progress.current.ready

        stage = self.trainer.state.stage
        assert stage is not None
        stage = stage.dataloader_prefix
        self._profiler_fetch_action = (
            f"[{self.__class__.__name__}].{stage}_dataloader_idx_{kwargs.get('dataloader_idx', 0)}_next"
        )
        data_fetcher._start_profiler = self._on_before_fetch
        data_fetcher._stop_profiler = self._on_after_fetch

    def _on_before_fetch(self) -> None:
        self.trainer.profiler.start(self._profiler_fetch_action)

    def _on_after_fetch(self) -> None:
        self.trainer.profiler.stop(self._profiler_fetch_action)

    def advance(  # type: ignore[override]
        self,
        data_fetcher: AbstractDataFetcher,

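For reference, the f-string above resolves to one profiler action per stage prefix and dataloader index; these are the same keys asserted in the test at the bottom of this diff. A tiny illustration with hardcoded values (not part of the diff):

# Illustration only: how the action name is composed for two validation dataloaders.
class_name = "EvaluationEpochLoop"
stage_prefix = "val"  # e.g. RunningStage.VALIDATING.dataloader_prefix
for dataloader_idx in range(2):
    print(f"[{class_name}].{stage_prefix}_dataloader_idx_{dataloader_idx}_next")
# -> [EvaluationEpochLoop].val_dataloader_idx_0_next
# -> [EvaluationEpochLoop].val_dataloader_idx_1_next
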
@@ -89,7 +89,9 @@ class PredictionEpochLoop(Loop):
            num_dataloaders: the total number of dataloaders
            return_predictions: whether to return the obtained predictions
        """
        batch_idx, batch = next(dataloader_iter)
        action_name = f"[{self.__class__.__name__}].predict_dataloader_idx_{dataloader_idx}_next"
        with self.trainer.profiler.profile(action_name):
            batch_idx, batch = next(dataloader_iter)
        self._seen_batch_indices = self._get_batch_indices(dataloader_idx)
        # we need to truncate the list of batch indices due to prefetching in the dataloader and Lightning
        self._seen_batch_indices = self._seen_batch_indices[: (self.batch_progress.current.completed + 1)]

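The prediction loop wraps the fetch in the profiler's context-manager form rather than the start/stop hooks used by the other loops, since `next()` is called in exactly one place here. Both styles record into the same profiler; a minimal, self-contained comparison (the action name and workload are placeholders):

from pytorch_lightning.profiler import SimpleProfiler

profiler = SimpleProfiler()

# Context-manager style, as used in the prediction loop above:
with profiler.profile("my_action"):
    sum(range(10_000))  # placeholder for next(dataloader_iter)

# Explicit start/stop style, as used by the evaluation and training epoch loops:
profiler.start("my_action")
sum(range(10_000))
profiler.stop("my_action")

print(profiler.recorded_durations["my_action"])  # two durations recorded under one action
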
@@ -154,6 +154,15 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
        # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
        data_fetcher.fetched += self.batch_progress.current.ready

        data_fetcher._start_profiler = self._on_before_fetch
        data_fetcher._stop_profiler = self._on_after_fetch

    def _on_before_fetch(self) -> None:
        self.trainer.profiler.start(f"[{self.__class__.__name__}].train_dataloader_next")

    def _on_after_fetch(self) -> None:
        self.trainer.profiler.stop(f"[{self.__class__.__name__}].train_dataloader_next")

    def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
        """Runs a single training batch.

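What gets measured here is the blocking cost of pulling the next batch out of the training dataloader. As a toy illustration (hypothetical dataset, not part of the diff), a slow `__getitem__` with no worker prefetching would show up directly in the `[TrainingEpochLoop].train_dataloader_next` timings:

import time

import torch
from torch.utils.data import DataLoader, Dataset


class SlowDataset(Dataset):
    """Toy dataset whose per-item cost is paid inside the dataloader's `next()` call."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        time.sleep(0.05)  # simulated I/O or preprocessing
        return torch.randn(32)


# With num_workers=0 there is no background prefetching, so each `next()` on this
# loader blocks for roughly 0.05 s per sample in the batch.
loader = DataLoader(SlowDataset(), batch_size=2, num_workers=0)
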
@@ -35,6 +35,10 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training


def _profile_nothing() -> None:
    pass


class AbstractDataFetcher(ABC):

    """This base class should be used to implement a fault tolerant ``DataFetcher``. It is required to override the

@@ -76,6 +80,8 @@ class AbstractDataFetcher(ABC):
        self.dataloader_iter: Optional[Iterator] = None
        self.fetched: int = 0
        self.done: bool = False
        self._start_profiler = _profile_nothing
        self._stop_profiler = _profile_nothing

    def setup(self, dataloader: Iterable, **kwargs: Any) -> None:
        self._add_capture_metadata_collate(dataloader)

@@ -225,8 +231,12 @@ class DataFetcher(AbstractDataFetcher):
        if batch_to_device is not None:
            self.batch_to_device = batch_to_device

    def on_fetch_start(self) -> Any:
        self._start_profiler()

    def on_fetch_end(self, batch: Any, start_output: Any) -> None:
        """Hook to extend which handles the logic after fetching a batch."""
        self._stop_profiler()
        self.batches.append(batch)

    def prefetching(self) -> None:

@@ -320,9 +330,12 @@ class InterBatchParallelDataFetcher(DataFetcher):

    def on_fetch_start(self) -> "torch.cuda.Event":
        # create a cuda event used to record the async stream of data to device.
        return torch.cuda.Event()
        event = torch.cuda.Event()
        self._start_profiler()
        return event

    def on_fetch_end(self, batch: Any, event: torch.cuda.Event) -> None:
        self._stop_profiler()
        self.batches.append(batch)
        event.record()
        self.events.append(event)

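For background, the CUDA event created above is what this fetcher already used to track when the asynchronous copy of a prefetched batch to the device has been enqueued; the profiler hooks are simply threaded through the same `on_fetch_start`/`on_fetch_end` pair. A small, hypothetical illustration of the event pattern (requires a CUDA device; not part of the diff):

import torch

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    event = torch.cuda.Event()
    with torch.cuda.stream(copy_stream):
        batch = torch.randn(1024, 1024).pin_memory().to("cuda", non_blocking=True)
        event.record()  # marks the point in `copy_stream` after the async copy is enqueued
    # a consumer makes the default stream wait for the copy before using `batch`
    event.wait(torch.cuda.default_stream())
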
@@ -344,7 +357,9 @@ class StepFuncDataLoaderIter(Iterator):

    def __next__(self) -> Any:
        try:
            self.data_fetcher._start_profiler()
            data = next(self.iterator)
            self.data_fetcher._stop_profiler()
            self.data_fetcher.fetched += 1
            return data
        except StopIteration as e:

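Taken together, the fetcher-side changes follow one pattern: the hooks default to the no-op `_profile_nothing` so a fetcher works standalone, and the loops swap in real profiler callbacks by plain attribute assignment. A stripped-down, self-contained sketch of that pattern (illustrative names, not the library's classes):

from typing import Any, Iterator


def _noop() -> None:
    pass


class MiniFetcher:
    """Minimal stand-in showing the injectable start/stop profiling hooks."""

    def __init__(self, iterator: Iterator) -> None:
        self.iterator = iterator
        self.fetched = 0
        # No-op defaults: profiling only happens if a loop overrides these.
        self._start_profiler = _noop
        self._stop_profiler = _noop

    def __next__(self) -> Any:
        self._start_profiler()
        try:
            data = next(self.iterator)
        finally:
            self._stop_profiler()
        self.fetched += 1
        return data


# A "loop" wires in its own callbacks without subclassing the fetcher:
fetcher = MiniFetcher(iter(["batch_0", "batch_1"]))
fetcher._start_profiler = lambda: print("start fetch")
fetcher._stop_profiler = lambda: print("stop fetch")
print(next(fetcher))  # prints: start fetch, stop fetch, batch_0
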
@@ -21,6 +21,7 @@ import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset

from pytorch_lightning import Callback, LightningDataModule, Trainer
from pytorch_lightning.profiler import SimpleProfiler
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher

@@ -486,3 +487,88 @@ def test_transfer_hooks_with_unpacking(tmpdir):
    assert dm.count_called_on_before_batch_transfer == 4
    assert dm.count_called_transfer_batch_to_device == 4
    assert dm.count_called_on_after_batch_transfer == 4


@RunIf(skip_windows=True)  # TODO: all durations are 0 on Windows
def test_fetching_is_profiled():
    """Test that fetching is profiled."""

    class MyModel(BoringModel):
        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            return super().validation_step(batch, batch_idx)

        def val_dataloader(self):
            return [super().val_dataloader(), super().val_dataloader()]

        validation_epoch_end = None

    model = MyModel()
    fast_dev_run = 2
    trainer = Trainer(
        fast_dev_run=fast_dev_run,
        profiler="simple",
        enable_model_summary=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
    )
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model)

    profiler = trainer.profiler
    assert isinstance(profiler, SimpleProfiler)

    # validation
    for i in range(2):
        key = f"[EvaluationEpochLoop].val_dataloader_idx_{i}_next"
        assert key in profiler.recorded_durations
        durations = profiler.recorded_durations[key]
        assert len(durations) == fast_dev_run
        assert all(d > 0 for d in durations)
    # training
    key = "[TrainingEpochLoop].train_dataloader_next"
    assert key in profiler.recorded_durations
    durations = profiler.recorded_durations[key]
    assert len(durations) == fast_dev_run
    assert all(d > 0 for d in durations)
    # test
    key = "[EvaluationEpochLoop].test_dataloader_idx_0_next"
    assert key in profiler.recorded_durations
    durations = profiler.recorded_durations[key]
    assert len(durations) == fast_dev_run
    assert all(d > 0 for d in durations)
    # predict
    key = "[PredictionEpochLoop].predict_dataloader_idx_0_next"
    assert key in profiler.recorded_durations
    durations = profiler.recorded_durations[key]
    assert len(durations) == fast_dev_run
    assert all(d > 0 for d in durations)

    # now test profiling when the dataloader_iter is polled manually
    class MyModel(BoringModel):
        def training_step(self, dataloader_iter):
            _ = next(dataloader_iter)
            batch = next(dataloader_iter)
            return super().training_step(batch, 0)

    model = MyModel()
    trainer = Trainer(
        fast_dev_run=1,
        profiler="simple",
        limit_val_batches=0,
        enable_model_summary=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
        logger=False,
    )
    trainer.fit(model)

    profiler = trainer.profiler
    assert isinstance(profiler, SimpleProfiler)

    key = "[TrainingEpochLoop].train_dataloader_next"
    assert key in profiler.recorded_durations
    durations = profiler.recorded_durations[key]
    assert len(durations) == 2  # 2 polls in training_step
    assert all(d > 0 for d in durations)