3/n inter batch parallelism (#9052)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
thomas chaton 2021-08-24 19:45:54 +01:00 committed by GitHub
parent b9443a07b9
commit f959b13ab9
21 changed files with 271 additions and 661 deletions

View File

@ -199,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `on_train_epoch_end` from `Accelerator` ([#9035](https://github.com/PyTorchLightning/pytorch-lightning/pull/9035))
- Removed `InterBatchProcessor` in favor of `DataLoaderIterDataFetcher` ([#9052](https://github.com/PyTorchLightning/pytorch-lightning/pull/9052))
### Fixed
- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (

View File

@ -456,6 +456,15 @@ class LightningModule(
f" of {list(self._metric_attributes.values())}"
)
if (
self.trainer.training
and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True)
and batch_size is None
):
raise MisconfigurationException(
"With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided."
)
results.log(
self._current_fx_name,
name,
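For context on the new guard above: with `def training_step(self, dataloader_iter)`, the model fetches its own batches, so the loop never sees them and cannot infer a batch size for metric aggregation. A minimal sketch of a module that opts in and logs correctly; the class name and layer sizes are illustrative, only the hook signature and the `batch_size` requirement come from this commit:

from typing import Iterator

import torch
from pytorch_lightning import LightningModule


class IterStepModel(LightningModule):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, dataloader_iter: Iterator):
        batch = next(dataloader_iter)  # the model decides when to fetch
        loss = self.layer(batch).sum()
        # `self.log("train_loss", loss)` would raise MisconfigurationException here;
        # an explicit `batch_size` is required in this mode.
        self.log("train_loss", loss, batch_size=batch.size(0))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)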

View File

@ -17,4 +17,3 @@ from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
from pytorch_lightning.loops.processors import IteratorBatchProcessor # noqa: F401

View File

@ -547,12 +547,15 @@ class TrainingBatchLoop(Loop):
the keyword arguments for the training step
"""
# enable not needing to add opt_idx to training_step
step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
step_kwargs = OrderedDict([("batch", batch)])
lightning_module = self.trainer.lightning_module
training_step_fx = getattr(lightning_module, "training_step")
if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2):
step_kwargs["batch_idx"] = batch_idx
if len(self.trainer.optimizers) > 1:
training_step_fx = getattr(lightning_module, "training_step")
has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
if has_opt_idx_in_train_step:
if not lightning_module.automatic_optimization:
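With the `min_args=2` check above, `batch_idx` is now optional in `training_step`: it is forwarded when it is named explicitly or when the hook takes at least two parameters. A sketch of the two supported signatures (hypothetical classes, bodies elided):

from pytorch_lightning import LightningModule


class WithIndex(LightningModule):
    def training_step(self, batch, batch_idx):
        ...  # `batch_idx` is named, so the batch loop passes it


class WithoutIndex(LightningModule):
    def training_step(self, batch):
        ...  # a single parameter and no `batch_idx`, so the loop omits it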

View File

@ -98,17 +98,14 @@ class EvaluationLoop(DataLoaderLoop):
"""Performs evaluation on one single dataloader"""
void(*args, **kwargs)
dataloader_idx: int = self.current_dataloader_idx
dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
dataloader = self.trainer.data_connector.get_profiled_dataloader(
dataloader, dataloader_idx=self.current_dataloader_idx
)
dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
dataloader_iter = iter(dataloader)
dl_max_batches = self._max_batches[self.current_dataloader_idx]
dl_max_batches = self._max_batches[dataloader_idx]
dl_outputs = self.epoch_loop.run(
dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders
)
dl_outputs = self.epoch_loop.run(dataloader_iter, dataloader_idx, dl_max_batches, self.num_dataloaders)
# store batch level output per dataloader
if self.should_track_batch_outputs_for_epoch_end:

View File

@ -19,6 +19,7 @@ from deprecate import void
from torch import Tensor
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.types import STEP_OUTPUT
@ -37,6 +38,7 @@ class EvaluationEpochLoop(Loop):
self._num_dataloaders: Optional[int] = None
self.outputs: List[STEP_OUTPUT] = []
self.batch_progress = Progress()
self.dataloader_iter: Optional[Iterator] = None
@property
def done(self) -> bool:
@ -66,10 +68,12 @@ class EvaluationEpochLoop(Loop):
dl_max_batches: maximum number of batches the dataloader can produce
num_dataloaders: the total number of dataloaders
"""
void(dataloader_iter, dataloader_idx)
void(dataloader_idx)
self._dl_max_batches = dl_max_batches
self._num_dataloaders = num_dataloaders
self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_progress.current.ready)
def advance(
self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
) -> None:
@ -84,9 +88,9 @@ class EvaluationEpochLoop(Loop):
Raises:
StopIteration: If the current batch is None
"""
void(dl_max_batches, num_dataloaders)
void(dataloader_iter, dl_max_batches, num_dataloaders)
batch_idx, (batch, _) = next(dataloader_iter)
batch_idx, (batch, _) = next(self.dataloader_iter)
if batch is None:
raise StopIteration
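Note the second argument passed to `_prepare_dataloader_iter` in `on_run_start` above: enumeration starts at `batch_progress.current.ready`, so after a mid-epoch restore the batch indices continue instead of resetting to zero. A standalone illustration in plain Python (the batch values are made up):

remaining_batches = iter(["batch-2", "batch-3"])  # batches left after a restart
dataloader_iter = enumerate(remaining_batches, 2)  # two batches were already marked ready

batch_idx, batch = next(dataloader_iter)
assert (batch_idx, batch) == (2, "batch-2")  # indices continue rather than reset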

View File

@ -11,24 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Union
import torch
from pytorch_lightning import loops # import as loops to avoid circular imports
from pytorch_lightning.loops.batch import TrainingBatchLoop
from pytorch_lightning.loops.processors import IteratorBatchProcessor
from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import STEP_OUTPUT
# TODO: currently, the batch processor is only a loop when tbptt is enabled.
# As we introduce more specialized batch processors, we may want to choose a
# more suitable abstraction for them.
BATCH_LOOP_TYPE = Optional[Tuple[TrainingBatchLoop, IteratorBatchProcessor]]
class TrainingEpochLoop(loops.Loop):
"""
@ -50,7 +45,7 @@ class TrainingEpochLoop(loops.Loop):
self.batch_progress = Progress()
self.scheduler_progress = SchedulerProgress()
self.batch_loop: BATCH_LOOP_TYPE = None
self.batch_loop: Optional[TrainingBatchLoop] = None
self.val_loop: Optional["loops.EvaluationLoop"] = None
self._results = ResultCollection(training=True)
@ -81,7 +76,7 @@ class TrainingEpochLoop(loops.Loop):
def connect(
self,
batch_loop: BATCH_LOOP_TYPE = None,
batch_loop: TrainingBatchLoop = None,
val_loop: Optional["loops.EvaluationLoop"] = None,
) -> None:
"""Optionally connect a custom batch or validation loop to this training epoch loop."""
@ -102,14 +97,16 @@ class TrainingEpochLoop(loops.Loop):
self.scheduler_progress.current.reset()
self.batch_loop.optim_progress.reset_on_epoch()
def on_run_start(self, *args: Any, **kwargs: Any) -> None:
def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
# hook
self.trainer.logger_connector.on_epoch_start()
self.trainer.call_hook("on_epoch_start")
self.trainer.call_hook("on_train_epoch_start")
self.trainer.fit_loop.epoch_progress.increment_started()
def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_idx + 1)
def advance(self, *args: Any, **kwargs: Any) -> None:
"""Runs a single training batch.
Args:
@ -118,33 +115,18 @@ class TrainingEpochLoop(loops.Loop):
Raises:
StopIteration: When the epoch is canceled by the user returning -1
"""
if isinstance(self.batch_loop, IteratorBatchProcessor):
# By contract, when taking `dataloader_iter` as an argument,
# `training_step` is responsible for reporting `is_last` in the
# result dict, which is used to determine the stop condition for
# the epoch. So as long as `advance` is invoked, it's correct to
# assume that there are more batches to be processed.
self.batch_progress.increment_ready()
with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(dataloader_iter)
self.batch_progress.increment_processed()
is_last = batch_output.is_last
else:
_, (batch, is_last) = next(dataloader_iter)
batch_idx, (batch, is_last) = next(self.dataloader_iter)
if not self.trainer.data_connector.train_data_fetcher.store_on_device:
with self.trainer.profiler.profile("training_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch)
# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
# ------------------------------------
self.batch_progress.increment_ready()
with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, self.batch_idx)
with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, batch_idx)
self.batch_progress.increment_processed()
self.is_last_batch = is_last
@ -162,8 +144,7 @@ class TrainingEpochLoop(loops.Loop):
processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True)
# hook
if not isinstance(self.batch_loop, IteratorBatchProcessor):
self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0)
self.trainer.call_hook("on_batch_end")
self.trainer.logger_connector.on_batch_end()

View File

@ -1,15 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.processors.iterator_batch_processor import IteratorBatchProcessor # noqa: F401

View File

@ -1,174 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import OrderedDict
from typing import Any, Dict, Iterator, List, Optional, Tuple
import torch
import pytorch_lightning as pl
from pytorch_lightning.loops.utilities import (
_check_training_step_output,
_process_training_step_output,
check_finite_loss,
)
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
log = logging.getLogger(__name__)
class IteratorBatchProcessor:
"""
The processor for performing a training iteration when ``training_step`` needs access to the
dataloader. It is selected when the signature of ``training_step`` contains ``dataloader_iter``:
def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
The ``training_step`` is allowed to fetch multiple batches during one training iteration. The
framework provides a minimal amount of automation around model optimization. This
flexibility allows for easy experimentation with inter-batch parallelism techniques.
This processor doesn't support ``automatic_optimization`` and ``tbptt``. An error will be thrown
if the ``LightningModule`` or the ``Trainer`` is configured to use these features.
The ``training_step`` is responsible for reporting whether it has reached the last batch by
including an ``is_last`` field in the result dict. Failing to do so will result in an error.
The ``training_step`` should only optimize the model with one batch for the sake of API and
reporting consistency (TODO: consider removing this limitation).
Args:
trainer: a reference to the trainer
model: a reference to the lightning module (for config validation purposes only)
"""
def __init__(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
if is_overridden("on_train_batch_start", model):
raise MisconfigurationException(
"The model hook `on_train_batch_start` is not compatible with "
"taking a `dataloader_iter` argument in your `training_step`."
)
if is_overridden("on_train_batch_end", model):
raise MisconfigurationException(
"The model hook `on_train_batch_end` is not compatible with "
"taking a `dataloader_iter` argument in your `training_step`."
)
if is_overridden("tbptt_split_batch", model):
raise MisconfigurationException(
"The model hook `tbptt_split_batch` is not compatible with "
"taking a `dataloader_iter` argument in your `training_step`."
)
if trainer.accumulate_grad_batches != 1:
raise MisconfigurationException(
"`accumulate_grad_batches` can only be 1 when your "
"`training_step` takes `dataloader_iter` as an argument."
)
self.trainer = trainer
# The following fields are not used by the processor since it doesn't support automatic
# optimization and tbptt. They are initialized regardless since they are currently expected
# by `FitLoop` or `TrainingEpochLoop`.
# TODO: come up with an abstraction for "batch processors" so they can be better decoupled
# with parent loops.
self.accumulated_loss: Optional[torch.Tensor] = None
self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=1)
self.optim_progress = OptimizationProgress()
self.split_idx: int = 0
self._skip_backward = False
def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
"""
Returns the number of active optimizers.
"""
return len(self.trainer.optimizers)
def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, torch.optim.Optimizer]]:
"""
Returns the currently active optimizers.
Returns:
A list of tuples (opt_idx, optimizer) of currently active optimizers.
"""
return list(enumerate(self.trainer.optimizers))
def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]:
"""
Args:
dataloader_iter: the iterator over the dataloader producing the new batch
"""
batch_idx, (dataloader_iter, is_last) = next(dataloader_iter)
self.trainer.logger_connector.on_batch_start()
response = self.trainer.call_hook("on_batch_start")
if response == -1:
return AttributeDict(signal=-1)
self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()
# give the PL module a result for logging
model = self.trainer.lightning_module
# manually capture logged metrics
model._current_fx_name = "training_step"
step_kwargs = self._build_kwargs(dataloader_iter, batch_idx)
with self.trainer.profiler.profile("model_forward"):
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
self.trainer.accelerator.post_training_step()
training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
_check_training_step_output(self.trainer.lightning_module, training_step_output)
training_step_output, _ = _process_training_step_output(self.trainer, training_step_output)
if self.trainer.terminate_on_nan:
check_finite_loss(self.trainer.lightning_module, training_step_output.minimize)
batch_outputs = [[] for _ in range(len(self.trainer.optimizers))]
if training_step_output:
batch_outputs[0].append(training_step_output)
return AttributeDict(signal=0, training_step_output=batch_outputs, is_last=is_last)
def teardown(self) -> None:
"""
No-op. Only defined to comply with FitLoop's expectation.
"""
pass
# FIXME: To be deleted in next PR.
def _build_kwargs(self, dataloader_iter: Iterator, batch_idx: int) -> Dict[str, Any]:
"""Builds the keyword arguments for training_step
Args:
dataloader_iter: The dataloader to pass
batch_idx: the index of the current batch
Returns:
An ordered dict with the keyword arguments for the training step
"""
# enable not needing to add opt_idx to training_step
step_kwargs = OrderedDict([("dataloader_iter", dataloader_iter)])
training_step_fx = getattr(self.trainer.lightning_module, "training_step")
if is_param_in_hook_signature(training_step_fx, "batch_idx"):
step_kwargs["batch_idx"] = batch_idx
return step_kwargs

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Mapping, Optional, Tuple
from typing import Iterator, Mapping, Optional, Tuple
import torch
@ -20,6 +20,7 @@ import pytorch_lightning as pl
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.types import STEP_OUTPUT
@ -102,3 +103,11 @@ def _process_training_step_output(
if trainer.move_metrics_to_cpu:
results.cpu()
return results, hiddens
def _prepare_dataloader_iter(dataloader_iter: Iterator, batch_idx: int) -> Iterator:
"""Attach the dataloader"""
if not isinstance(dataloader_iter, DataLoaderIterDataFetcher):
dataloader_iter = enumerate(dataloader_iter, batch_idx)
# restore iteration
return dataloader_iter
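Both epoch loops then unpack each item as `batch_idx, (batch, is_last) = next(...)`. A self-contained sketch of that contract, with the data fetcher output simulated by a list of `(batch, is_last)` pairs:

fetched = [("batch-0", False), ("batch-1", False), ("batch-2", True)]  # simulated fetcher output
dataloader_iter = enumerate(iter(fetched), 0)  # what `_prepare_dataloader_iter` returns

batch_idx, (batch, is_last) = next(dataloader_iter)
assert (batch_idx, batch, is_last) == (0, "batch-0", False)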

View File

@ -16,6 +16,7 @@ from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
class ConfigValidator:
@ -34,6 +35,7 @@ class ConfigValidator:
self.__verify_train_loop_configuration(model)
self.__verify_eval_loop_configuration(model, "val")
self.__verify_manual_optimization_support(model)
self.__check_training_step_requires_dataloader_iter(model)
elif self.trainer.state.fn == TrainerFn.VALIDATING:
self.__verify_eval_loop_configuration(model, "val")
elif self.trainer.state.fn == TrainerFn.TESTING:
@ -128,3 +130,26 @@ class ConfigValidator:
f" Remove `Trainer(accumulate_grad_batches={self.trainer.accumulate_grad_batches})`"
" or switch to automatic optimization."
)
def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningModule"):
"""Check if the current `training_step` is requesting `dataloader_iter`."""
training_step_fx = getattr(model, "training_step")
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
if is_overridden("on_train_batch_start", model):
raise MisconfigurationException(
"The model hook `on_train_batch_start` is not compatible with "
"taking a `dataloader_iter` argument in your `training_step`."
)
if is_overridden("on_train_batch_end", model):
raise MisconfigurationException(
"The model hook `on_train_batch_end` is not compatible with "
"taking a `dataloader_iter` argument in your `training_step`."
)
if model.truncated_bptt_steps > 0:
raise MisconfigurationException(
"The model taking a `dataloader_iter` argument in your `training_step` "
"is incompatible with `truncated_bptt_steps > 0`."
)

View File

@ -91,21 +91,18 @@ class DataConnector:
self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
self.trainer._is_data_prepared = False
def _check_training_step_requires_dataloader_iter(self) -> bool:
training_step_fx = getattr(self.trainer.lightning_module, "training_step")
contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True)
return contains_dataloader_iter
def _select_data_fetcher(self) -> AbstractDataFetcher:
if self.trainer.sanity_checking:
return DataFetcher()
if self.trainer.training and self._check_training_step_requires_dataloader_iter():
training_step_fx = getattr(self.trainer.lightning_module, "training_step")
if self.trainer.training and is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
rank_zero_warn(
"Found `dataloader_iter` argument in the `training_step`. Note that the support for "
"this signature is experimental and the behavior is subject to change."
)
return DataLoaderIterDataFetcher()
elif self.trainer.training and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
# note: this is an experimental feature
if not self.trainer.training_type_plugin.on_gpu:
@ -124,9 +121,7 @@ class DataConnector:
profiler=self.trainer.profiler,
)
setattr(self, f"{stage}_data_fetcher", data_fetcher)
if isinstance(data_fetcher, DataLoaderIterDataFetcher):
return data_fetcher
return enumerate(data_fetcher)
return data_fetcher
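Summarized, `_select_data_fetcher` now resolves to one of three fetchers. A hedged sketch of the decision logic (string labels stand in for the fetcher classes; `select_fetcher_sketch` is a hypothetical helper, not part of the API):

import os


def select_fetcher_sketch(sanity_checking: bool, training: bool, wants_dataloader_iter: bool) -> str:
    if sanity_checking:
        return "DataFetcher"
    if training and wants_dataloader_iter:
        # chosen from the `training_step` signature; experimental, warns on selection
        return "DataLoaderIterDataFetcher"
    if training and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
        # experimental opt-in; the real code additionally requires a GPU plugin
        return "InterBatchParallelDataFetcher"
    return "DataFetcher"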
def prepare_data(self) -> None:
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
@ -250,6 +245,16 @@ class DataConnector:
if isinstance(loader, _PatchDataLoader):
loader.unpatch(model)
def teardown(self) -> None:
if self.train_data_fetcher:
self.train_data_fetcher.teardown()
if self.validate_data_fetcher:
self.validate_data_fetcher.teardown()
if self.test_data_fetcher:
self.test_data_fetcher.teardown()
if self.sanity_check_data_fetcher:
self.sanity_check_data_fetcher.teardown()
class _PatchDataLoader:
r"""

View File

@ -199,7 +199,11 @@ class LoggerConnector:
"""
def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
self.trainer._results.extract_batch_size(split_batch)
# when the user requests `dataloader_iter`, we can't track the batch_size,
# so tracking it is left to the user.
if not isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher):
self.trainer._results.extract_batch_size(split_batch)
self._batch_idx = batch_idx
self._split_idx = split_idx

View File

@ -28,7 +28,7 @@ from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loops import IteratorBatchProcessor, TrainingBatchLoop, TrainingEpochLoop
from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
from pytorch_lightning.loops.fit_loop import FitLoop
@ -77,12 +77,10 @@ from pytorch_lightning.utilities import (
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
log = logging.getLogger(__name__)
@ -920,18 +918,6 @@ class Trainer(
rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}")
self.checkpoint_connector.restore_model_weights(self._ckpt_path)
def _maybe_switch_to_iterator_batch_processor(self, model: "pl.LightningModule") -> None:
training_step_fx = getattr(model, "training_step")
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
log.warning(
"Found `dataloader_iter` argument in the `training_step`. Note that the support for "
"this signature is experimental and the behavior may subject to change."
)
batch_loop = IteratorBatchProcessor(self, model)
self.fit_loop.epoch_loop.connect(batch_loop)
# FIXME: Move this logic to data_connector after removing `IteratorBatchProcessor`
self.data_connector.data_fetcher = DataLoaderIterDataFetcher()
def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
# clean hparams
if hasattr(model, "hparams"):
@ -939,9 +925,6 @@ class Trainer(
self.config_validator.verify_loop_configurations(model)
if self.training:
self._maybe_switch_to_iterator_batch_processor(model)
# attach model log function to callback
self.callback_connector.attach_model_logging_functions(model)
@ -1077,6 +1060,7 @@ class Trainer(
# these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
# which need to happen before.
self.accelerator.teardown()
self.data_connector.teardown()
self._active_loop.teardown()
self.logger_connector.teardown()

View File

@ -390,7 +390,6 @@ class StepFuncDataLoaderIter:
def __next__(self) -> Any:
try:
data = next(self.iterator)
# FIXME: Link this to `batch_idx`.
self.data_fetcher.fetched += 1
return data
except StopIteration:

View File

@ -12,15 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable
from typing import Callable, Optional
def is_param_in_hook_signature(hook_fx: Callable, param: str, explicit: bool = False) -> bool:
def is_param_in_hook_signature(
hook_fx: Callable, param: str, explicit: bool = False, min_args: Optional[int] = None
) -> bool:
"""
Args:
hook_fx: the hook callable
param: the name of the parameter to check
explicit: whether the parameter has to be explicitly declared
min_args: whether the signature has at least `min_args` parameters
"""
hook_params = list(inspect.signature(hook_fx).parameters)
return param in hook_params or (not explicit and "args" in hook_params)
return (
param in hook_params
or (not explicit and "args" in hook_params)
or (isinstance(min_args, int) and len(hook_params) >= min_args)
)
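A few worked cases for the extended helper (plain functions stand in for bound hooks, so there is no `self` parameter to count):

from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature


def step_a(batch, batch_idx): ...
def step_b(batch): ...
def step_c(*args): ...

assert is_param_in_hook_signature(step_a, "batch_idx")  # named explicitly
assert not is_param_in_hook_signature(step_b, "batch_idx")  # absent and no *args
assert is_param_in_hook_signature(step_c, "batch_idx")  # *args may accept it
assert not is_param_in_hook_signature(step_c, "batch_idx", explicit=True)
assert is_param_in_hook_signature(step_a, "batch_idx", min_args=2)  # two parameters suffice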

View File

@ -17,7 +17,7 @@ from copy import deepcopy
import torch
import pytorch_lightning as pl
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from tests.helpers import BoringModel
@ -27,8 +27,6 @@ def test_finetuning_with_resume_from_checkpoint(tmpdir):
This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test
"""
seed_everything(4)
checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1)
class ExtendedBoringModel(BoringModel):
@ -75,9 +73,6 @@ def test_finetuning_with_resume_from_checkpoint(tmpdir):
results.append(deepcopy(trainer.callback_metrics))
best_model_paths.append(trainer.checkpoint_callback.best_model_path)
for idx in range(len(results) - 1):
assert results[idx]["val_loss"] > results[idx + 1]["val_loss"]
for idx, best_model_path in enumerate(best_model_paths):
if idx == 0:
assert best_model_path.endswith(f"epoch=0{idx}.ckpt")

View File

@ -1,190 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from statistics import mean
from typing import Iterator
import torch
from torch.utils.data import DataLoader, IterableDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.types import STEP_OUTPUT
from tests.helpers.runif import RunIf
def count_cycles_per_ms() -> float:
"""
Measure and return approximate number of cycles per millisecond for torch.cuda._sleep
Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py
"""
def measure() -> float:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
torch.cuda._sleep(1000000)
end.record()
end.synchronize()
cycles_per_ms = 1000000 / start.elapsed_time(end)
return cycles_per_ms
# Get 10 values and remove the 2 max and 2 min and return the avg.
# This is to avoid system disturbance that skew the results, e.g.
# the very first cuda call likely does a bunch of init, which takes
# much longer than subsequent calls.
#
# Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs
# and seems to return stable values. Therefore, we enable caching
# using lru_cache decorator above.
num = 10
vals = []
for _ in range(num):
vals.append(measure())
vals = sorted(vals)
return mean(vals[2 : num - 2])
_CYCLES_PER_MS = int(count_cycles_per_ms()) if torch.cuda.is_available() else 0
_BATCH_SIZE = 128
_EMB_SZ = 100
_EMB_DIM = 64
class RandomSparseDataset(IterableDataset):
def __init__(self, emb_dim: int, batch_size: int, count: int) -> None:
self.emb_dim = emb_dim
self.batch_size = batch_size
self.count = count
def __iter__(self):
for _ in range(self.count):
yield torch.randint(self.emb_dim, [self.batch_size])
class ToyDLRMModel(LightningModule):
"""
A toy model for mimicking the communication overhead of sharded embedding
modules in DLRM models.
DLRM models can be trained in a DDP-like fashion, where each trainer
receives different batches (embedding indices in this example). Since the
embeddings are sharded across trainers, the lookup process involves (1)
routing the indices to the trainer that possesses the corresponding
embeddings (2) performing local lookup (3) routing the embedding lookup
result back.
The toy model doesn't actually perform index/result routing. It simply
uses torch.cuda._sleep() to mimic the cost of the communication op (i.e.
a2a).
"""
def __init__(self):
super().__init__()
self.automatic_optimization = False
self.local_embedding = torch.nn.Embedding(_EMB_SZ, _EMB_DIM)
def _route_indices(self, batch: torch.Tensor, non_blocking=False):
"""
This can be parallelized across different batches since it's model
weight independent.
Why not run this in dataloader/datamodule?
- The routing logic depends on how model is sharded
- Putting this in data preprocessor changes the semantic of the model
"""
torch.cuda._sleep(_CYCLES_PER_MS * 1_000)
if not non_blocking:
torch.cuda.synchronize()
return batch
def _route_result(self, result: torch.Tensor, non_blocking=False):
torch.cuda._sleep(_CYCLES_PER_MS * 1_000)
if not non_blocking:
torch.cuda.synchronize()
return result
def forward(self, indices: torch.Tensor):
local_indices = self._route_indices(indices)
result = self.local_embedding(local_indices)
return self._route_result(result)
def training_step(self, batch: torch.Tensor, batch_idx: int) -> STEP_OUTPUT:
return self.forward(batch)
def configure_optimizers(self):
return torch.optim.SGD(self.local_embedding.parameters(), lr=0.1)
def train_dataloader(self):
return DataLoader(RandomSparseDataset(_EMB_DIM, _BATCH_SIZE, 5))
class AsyncToyDLRMModel(ToyDLRMModel):
def __init__(self):
super().__init__()
self.comm_stream = torch.cuda.Stream()
self.batch_i = None
self.batch_i_ready = torch.cuda.Event()
def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
if self.batch_i is None:
self.batch_i = next(dataloader_iter)
with torch.cuda.stream(self.comm_stream):
self._route_indices(self.batch_i, non_blocking=True)
self.batch_i_ready.record()
# Invariant: the routing for batch[i] has been kicked off
is_last = False
batch_ip1 = None
batch_ip1_ready = torch.cuda.Event()
try:
batch_ip1 = next(dataloader_iter)
with torch.cuda.stream(self.comm_stream):
self._route_indices(batch_ip1, non_blocking=True)
batch_ip1_ready.record()
except StopIteration:
is_last = True
self.batch_i_ready.wait()
result = self.local_embedding(self.batch_i)
self._route_result(result)
self.batch_i = batch_ip1
self.batch_i_ready = batch_ip1_ready
return {"is_last": is_last}
@RunIf(min_gpus=1)
def test_inter_batch_parallelism(tmpdir):
"""
Verify the speedup of a simple inter-batch parallelization use case enabled
by exposing `dataloader_iter` to `training_step`.
"""
begin_time = time.time()
m = AsyncToyDLRMModel()
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
trainer.fit(m)
async_duration = time.time() - begin_time
begin_time = time.time()
m = ToyDLRMModel()
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
trainer.fit(m)
sync_duration = time.time() - begin_time
# We expect 2x speedup. However, we only assert that the async
# training_step is faster in order to avoid flaky tests
assert async_duration < sync_duration, "Expect `AsyncToyDLRMModel` to train faster than `ToyDLRMModel`."

View File

@ -1,183 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterator
import pytest
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT
from tests.helpers import BoringModel, RandomDataset
_BATCH_SIZE = 32
_DATASET_LEN = 64
class DummyWaitable:
def __init__(self, val: Any) -> None:
self.val = val
def wait(self) -> Any:
return self.val
class AsyncBoringModel(BoringModel):
def __init__(self) -> None:
super().__init__()
self.automatic_optimization = False
self.batch_i_handle = None
self.num_batches_processed = 0
def _async_op(self, batch: Any) -> DummyWaitable:
return DummyWaitable(val=batch)
def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
if self.batch_i_handle is None:
batch_i_raw = next(dataloader_iter)
self.batch_i_handle = self._async_op(batch_i_raw)
# Invariant: _async_op for batch[i] has been initiated
batch_ip1_handle = None
is_last = False
try:
batch_ip1_raw = next(dataloader_iter)
batch_ip1_handle = self._async_op(batch_ip1_raw)
except StopIteration:
is_last = True
batch_i = self.batch_i_handle.wait()
pred = self.layer(batch_i)
loss = self.loss(batch_i, pred)
loss.backward()
self.optimizers().step()
self.optimizers().zero_grad()
self.batch_i_handle = batch_ip1_handle
self.num_batches_processed += 1
return {"loss": loss, "is_last": is_last}
def train_dataloader(self):
return DataLoader(RandomDataset(_BATCH_SIZE, _DATASET_LEN))
def test_training_step_with_dataloader_access(tmpdir) -> None:
"""
A baseline functional test for `training_step` with dataloader access.
"""
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
m = AsyncBoringModel()
trainer.fit(m)
assert m.num_batches_processed == _DATASET_LEN, f"Expect all {_DATASET_LEN} batches to be processed."
def test_stop_iteration(tmpdir) -> None:
"""
Verify that when `StopIteration` is raised within `training_step`, `fit()`
terminates as expected.
"""
EXPECT_NUM_BATCHES_PROCESSED = 2
class TestModel(AsyncBoringModel):
def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
output = super().training_step(dataloader_iter)
if self.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED:
raise StopIteration()
return output
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
m = TestModel()
trainer.fit(m)
assert (
m.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED
), "Expect {EXPECT_NUM_BATCHES_PROCESSED} batches to be processed."
def test_on_train_batch_start_overridden(tmpdir) -> None:
"""
Verify that a `MisconfigurationException` is raised when
`on_train_batch_start` is overridden on the `LightningModule`.
"""
class InvalidModel(AsyncBoringModel):
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
pass
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
m = InvalidModel()
with pytest.raises(MisconfigurationException):
trainer.fit(m)
def test_on_train_batch_end_overridden(tmpdir) -> None:
"""
Verify that a `MisconfigurationException` is raised when
`on_train_batch_end` is overridden on the `LightningModule`.
"""
class InvalidModel(AsyncBoringModel):
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
pass
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
m = InvalidModel()
with pytest.raises(MisconfigurationException):
trainer.fit(m)
def test_tbptt_split_batch_overridden(tmpdir) -> None:
"""
Verify that a `MisconfigurationException` is raised when
`tbptt_split_batch` is overridden on the `LightningModule`.
"""
class InvalidModel(AsyncBoringModel):
def tbptt_split_batch(self, batch, split_size):
pass
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
m = InvalidModel()
with pytest.raises(MisconfigurationException):
trainer.fit(m)
def test_accumulate_grad_batches(tmpdir) -> None:
"""
Verify that a `MisconfigurationException` is raised when
`accumulate_grad_batches` is not set to 1.
"""
trainer = Trainer(max_epochs=1, accumulate_grad_batches=2, default_root_dir=tmpdir)
m = AsyncBoringModel()
with pytest.raises(MisconfigurationException):
trainer.fit(m)
def test_is_last_not_set(tmpdir) -> None:
"""
Verify that a `MisconfigurationException` is raised when `training_step`
doesn't include "is_last" in the result dict.
"""
class InvalidModel(AsyncBoringModel):
def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
output = super().training_step(dataloader_iter)
del output["is_last"]
return output
trainer = Trainer(max_epochs=1, accumulate_grad_batches=2, default_root_dir=tmpdir)
m = InvalidModel()
with pytest.raises(MisconfigurationException):
trainer.fit(m)

View File

@ -419,7 +419,7 @@ def test_multiple_optimizers_step(tmpdir):
called = False
def on_after_backward(self):
def on_before_optimizer_step(self, *args):
self.called = True
norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
if not (torch.isinf(norm) or torch.isnan(norm)):

View File

@ -13,7 +13,7 @@
# limitations under the License.
import os
from time import time
from typing import Any
from typing import Any, Iterator
from unittest import mock
import pytest
@ -25,7 +25,8 @@ from pytorch_lightning import Trainer
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher
from tests.helpers.boring_model import BoringModel
from pytorch_lightning.utilities.types import STEP_OUTPUT
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
@ -125,7 +126,8 @@ def get_cycles_per_ms() -> float:
return sum(stats) / len(stats)
BATCH_SIZE = 128
BATCH_SIZE = 32
DATASET_LEN = 64
EMB_SZ = 100
EMB_DIM = 64
@ -176,6 +178,7 @@ class RecommenderModel(BoringModel):
def test_trainer_num_prefetch_batches(tmpdir):
model = RecommenderModel()
trainer_kwargs = dict(
default_root_dir=tmpdir,
max_epochs=1,
@ -190,8 +193,8 @@ def test_trainer_num_prefetch_batches(tmpdir):
trainer = Trainer(**trainer_kwargs)
trainer.fit(model)
t1 = time()
assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher)
global_step = trainer.global_step
torch.cuda.synchronize()
@ -199,9 +202,9 @@ def test_trainer_num_prefetch_batches(tmpdir):
trainer = Trainer(**trainer_kwargs)
trainer.fit(model)
t3 = time()
assert global_step == trainer.global_step == 4
assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher)
ratio = (t3 - t2) / (t1 - t0)
assert ratio > 1.1, ratio
@ -218,7 +221,7 @@ def test_fetching_dataloader_iter(automatic_optimization, tmpdir):
def training_step(self, dataloader_iter, batch_idx):
assert self.count == batch_idx
assert isinstance(self.trainer.data_connector.data_fetcher, DataLoaderIterDataFetcher)
assert isinstance(self.trainer.data_connector.train_data_fetcher, DataLoaderIterDataFetcher)
# fetch 2 batches
self.batches.append(next(dataloader_iter))
self.batches.append(next(dataloader_iter))
@ -227,7 +230,10 @@ def test_fetching_dataloader_iter(automatic_optimization, tmpdir):
assert isinstance(batch, torch.Tensor) or batch is None
self.count += 2
if self.automatic_optimization:
return super().training_step(batch, 0)
loss = super().training_step(batch, 0)
with pytest.raises(MisconfigurationException, match="dataloader_iter"):
self.log("train_loss", loss["loss"])
self.log("train_loss", loss["loss"], batch_size=1)
else:
opt = self.optimizers()
output = self(batch)
@ -236,10 +242,152 @@ def test_fetching_dataloader_iter(automatic_optimization, tmpdir):
loss.backward()
opt.step()
training_epoch_end = None
def training_epoch_end(self, *_):
assert self.trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33
assert self.trainer.data_connector.train_data_fetcher.fetched == 64
assert self.count == 64
model = TestModel(automatic_optimization=automatic_optimization)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
trainer.data_connector.data_fetcher = DataLoaderIterDataFetcher()
trainer.fit(model)
assert model.count == 64
class DummyWaitable:
def __init__(self, val: Any) -> None:
self.val = val
def wait(self) -> Any:
return self.val
class AsyncBoringModel(BoringModel):
def __init__(self) -> None:
super().__init__()
self.automatic_optimization = False
self.batch_i_handle = None
self.num_batches_processed = 0
def _async_op(self, batch: Any) -> DummyWaitable:
return DummyWaitable(val=batch)
def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
if self.batch_i_handle is None:
batch_i_raw = next(dataloader_iter)
self.batch_i_handle = self._async_op(batch_i_raw)
# Invariant: _async_op for batch[i] has been initiated
batch_ip1_handle = None
is_last = False
try:
batch_ip1_raw = next(dataloader_iter)
batch_ip1_handle = self._async_op(batch_ip1_raw)
except StopIteration:
is_last = True
batch_i = self.batch_i_handle.wait()
pred = self.layer(batch_i)
loss = self.loss(batch_i, pred)
loss.backward()
self.optimizers().step()
self.optimizers().zero_grad()
self.batch_i_handle = batch_ip1_handle
self.num_batches_processed += 1
return {"loss": loss, "is_last": is_last}
def train_dataloader(self):
return DataLoader(RandomDataset(BATCH_SIZE, DATASET_LEN))
def test_training_step_with_dataloader_access(tmpdir) -> None:
"""
A baseline functional test for `training_step` with dataloader access.
"""
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
m = AsyncBoringModel()
trainer.fit(m)
assert m.num_batches_processed == DATASET_LEN, f"Expect all {DATASET_LEN} batches to be processed."
@pytest.mark.parametrize("trigger_stop_iteration", [False, True])
def test_stop_iteration(trigger_stop_iteration, tmpdir):
"""
Verify that `StopIteration` properly terminates the training when it is triggered
from within the current `dataloader_iter`.
"""
EXPECT_NUM_BATCHES_PROCESSED = 2
class TestModel(AsyncBoringModel):
def __init__(self, trigger_stop_iteration) -> None:
super().__init__()
self.trigger_stop_iteration = trigger_stop_iteration
def training_step(self, dataloader_iter: Iterator, *args) -> STEP_OUTPUT:
output = super().training_step(dataloader_iter)
if self.trigger_stop_iteration and args[0] == EXPECT_NUM_BATCHES_PROCESSED:
raise StopIteration
return output
def train_dataloader(self):
if self.trigger_stop_iteration:
return DataLoader(RandomDataset(BATCH_SIZE, 2 * EXPECT_NUM_BATCHES_PROCESSED))
return DataLoader(RandomDataset(BATCH_SIZE, EXPECT_NUM_BATCHES_PROCESSED))
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
m = TestModel(trigger_stop_iteration)
trainer.fit(m)
expected = EXPECT_NUM_BATCHES_PROCESSED
if trigger_stop_iteration:
expected *= 2
assert m.num_batches_processed == expected
def test_on_train_batch_start_overridden(tmpdir) -> None:
"""
Verify that a `MisconfigurationException` is raised when
`on_train_batch_start` is overridden on the `LightningModule`.
"""
class InvalidModel(AsyncBoringModel):
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
pass
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
m = InvalidModel()
with pytest.raises(MisconfigurationException, match="The model hook `on_train_batch_start` is not compatible with"):
trainer.fit(m)
def test_on_train_batch_end_overridden(tmpdir) -> None:
"""
Verify that a `MisconfigurationException` is raised when
`on_train_batch_end` is overridden on the `LightningModule`.
"""
class InvalidModel(AsyncBoringModel):
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
pass
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
m = InvalidModel()
with pytest.raises(MisconfigurationException, match="The model hook `on_train_batch_end` is not compatible with"):
trainer.fit(m)
def test_tbptt_split_batch_overridden(tmpdir) -> None:
"""
Verify that a `MisconfigurationException` is raised when
`tbptt_split_batch` is overridden on the `LightningModule`.
"""
class InvalidModel(AsyncBoringModel):
def __init__(self) -> None:
super().__init__()
self.truncated_bptt_steps = 2
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
m = InvalidModel()
with pytest.raises(MisconfigurationException, match="is incompatible with `truncated_bptt_steps > 0`."):
trainer.fit(m)