3/n inter batch parallelism (#9052)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
parent
b9443a07b9
commit
f959b13ab9
|
@ -199,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Removed `on_train_epoch_end` from `Accelerator` ([#9035](https://github.com/PyTorchLightning/pytorch-lightning/pull/9035))
|
||||
|
||||
|
||||
- Removed `IteratorBatchProcessor` in favor of `DataLoaderIterDataFetcher` ([#9052](https://github.com/PyTorchLightning/pytorch-lightning/pull/9052))
|
||||
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
|
||||
|
|
|
@ -456,6 +456,15 @@ class LightningModule(
|
|||
f" of {list(self._metric_attributes.values())}"
|
||||
)
|
||||
|
||||
if (
|
||||
self.trainer.training
|
||||
and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True)
|
||||
and batch_size is None
|
||||
):
|
||||
raise MisconfigurationException(
|
||||
"With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided."
|
||||
)
|
||||
|
||||
results.log(
|
||||
self._current_fx_name,
|
||||
name,
|
||||
|
|
|
@ -17,4 +17,3 @@ from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401
|
|||
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
|
||||
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
|
||||
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
|
||||
from pytorch_lightning.loops.processors import IteratorBatchProcessor # noqa: F401
|
||||
|
|
|
@ -547,12 +547,15 @@ class TrainingBatchLoop(Loop):
|
|||
the keyword arguments for the training step
|
||||
"""
|
||||
# enable not needing to add opt_idx to training_step
|
||||
step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
|
||||
step_kwargs = OrderedDict([("batch", batch)])
|
||||
|
||||
lightning_module = self.trainer.lightning_module
|
||||
training_step_fx = getattr(lightning_module, "training_step")
|
||||
|
||||
if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2):
|
||||
step_kwargs["batch_idx"] = batch_idx
|
||||
|
||||
if len(self.trainer.optimizers) > 1:
|
||||
training_step_fx = getattr(lightning_module, "training_step")
|
||||
has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
|
||||
if has_opt_idx_in_train_step:
|
||||
if not lightning_module.automatic_optimization:
|
||||
|
|
|
@ -98,17 +98,14 @@ class EvaluationLoop(DataLoaderLoop):
|
|||
"""Performs evaluation on one single dataloader"""
|
||||
void(*args, **kwargs)
|
||||
|
||||
dataloader_idx: int = self.current_dataloader_idx
|
||||
dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
|
||||
dataloader = self.trainer.data_connector.get_profiled_dataloader(
|
||||
dataloader, dataloader_idx=self.current_dataloader_idx
|
||||
)
|
||||
dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
|
||||
dataloader_iter = iter(dataloader)
|
||||
|
||||
dl_max_batches = self._max_batches[self.current_dataloader_idx]
|
||||
dl_max_batches = self._max_batches[dataloader_idx]
|
||||
|
||||
dl_outputs = self.epoch_loop.run(
|
||||
dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders
|
||||
)
|
||||
dl_outputs = self.epoch_loop.run(dataloader_iter, dataloader_idx, dl_max_batches, self.num_dataloaders)
|
||||
|
||||
# store batch level output per dataloader
|
||||
if self.should_track_batch_outputs_for_epoch_end:
|
||||
|
|
|
@ -19,6 +19,7 @@ from deprecate import void
|
|||
from torch import Tensor
|
||||
|
||||
from pytorch_lightning.loops.base import Loop
|
||||
from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
|
||||
from pytorch_lightning.trainer.progress import Progress
|
||||
from pytorch_lightning.utilities.memory import recursive_detach
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
|
@ -37,6 +38,7 @@ class EvaluationEpochLoop(Loop):
|
|||
self._num_dataloaders: Optional[int] = None
|
||||
self.outputs: List[STEP_OUTPUT] = []
|
||||
self.batch_progress = Progress()
|
||||
self.dataloader_iter: Optional[Iterator] = None
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
|
@ -66,10 +68,12 @@ class EvaluationEpochLoop(Loop):
|
|||
dl_max_batches: maximum number of batches the dataloader can produce
|
||||
num_dataloaders: the total number of dataloaders
|
||||
"""
|
||||
void(dataloader_iter, dataloader_idx)
|
||||
void(dataloader_idx)
|
||||
self._dl_max_batches = dl_max_batches
|
||||
self._num_dataloaders = num_dataloaders
|
||||
|
||||
self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_progress.current.ready)
|
||||
|
||||
def advance(
|
||||
self, dataloader_iter: Iterator, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
|
||||
) -> None:
|
||||
|
@ -84,9 +88,9 @@ class EvaluationEpochLoop(Loop):
|
|||
Raises:
|
||||
StopIteration: If the current batch is None
|
||||
"""
|
||||
void(dl_max_batches, num_dataloaders)
|
||||
void(dataloader_iter, dl_max_batches, num_dataloaders)
|
||||
|
||||
batch_idx, (batch, _) = next(dataloader_iter)
|
||||
batch_idx, (batch, _) = next(self.dataloader_iter)
|
||||
|
||||
if batch is None:
|
||||
raise StopIteration
|
||||
|
|
|
@ -11,24 +11,19 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import loops # import as loops to avoid circular imports
|
||||
from pytorch_lightning.loops.batch import TrainingBatchLoop
|
||||
from pytorch_lightning.loops.processors import IteratorBatchProcessor
|
||||
from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
|
||||
from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
|
||||
# TODO: currently, the batch processor is only a loop when tbptt is enabled.
|
||||
# As we introduce more specialized batch processors, we may want to choose a
|
||||
# more suitable abstraction for them.
|
||||
BATCH_LOOP_TYPE = Optional[Tuple[TrainingBatchLoop, IteratorBatchProcessor]]
|
||||
|
||||
|
||||
class TrainingEpochLoop(loops.Loop):
|
||||
"""
|
||||
|
@ -50,7 +45,7 @@ class TrainingEpochLoop(loops.Loop):
|
|||
self.batch_progress = Progress()
|
||||
self.scheduler_progress = SchedulerProgress()
|
||||
|
||||
self.batch_loop: BATCH_LOOP_TYPE = None
|
||||
self.batch_loop: Optional[TrainingBatchLoop] = None
|
||||
self.val_loop: Optional["loops.EvaluationLoop"] = None
|
||||
|
||||
self._results = ResultCollection(training=True)
|
||||
|
@ -81,7 +76,7 @@ class TrainingEpochLoop(loops.Loop):
|
|||
|
||||
def connect(
|
||||
self,
|
||||
batch_loop: BATCH_LOOP_TYPE = None,
|
||||
batch_loop: TrainingBatchLoop = None,
|
||||
val_loop: Optional["loops.EvaluationLoop"] = None,
|
||||
) -> None:
|
||||
"""Optionally connect a custom batch or validation loop to this training epoch loop."""
|
||||
|
@ -102,14 +97,16 @@ class TrainingEpochLoop(loops.Loop):
|
|||
self.scheduler_progress.current.reset()
|
||||
self.batch_loop.optim_progress.reset_on_epoch()
|
||||
|
||||
def on_run_start(self, *args: Any, **kwargs: Any) -> None:
|
||||
def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
|
||||
# hook
|
||||
self.trainer.logger_connector.on_epoch_start()
|
||||
self.trainer.call_hook("on_epoch_start")
|
||||
self.trainer.call_hook("on_train_epoch_start")
|
||||
self.trainer.fit_loop.epoch_progress.increment_started()
|
||||
|
||||
def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
|
||||
self.dataloader_iter = _prepare_dataloader_iter(dataloader_iter, self.batch_idx + 1)
|
||||
|
||||
def advance(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Runs a single training batch.
|
||||
|
||||
Args:
|
||||
|
@ -118,33 +115,18 @@ class TrainingEpochLoop(loops.Loop):
|
|||
Raises:
|
||||
StopIteration: When the epoch is canceled by the user returning -1
|
||||
"""
|
||||
if isinstance(self.batch_loop, IteratorBatchProcessor):
|
||||
# By contract, when taking `dataloader_iter` as an argument,
|
||||
# `training_step` is responsible for reporting `is_last` in the
|
||||
# result dict, which is used to determine the stop condition for
|
||||
# the epoch. So as long as `advance` is invoked, it's correct to
|
||||
# assume that there are more batches to be processed.
|
||||
self.batch_progress.increment_ready()
|
||||
with self.trainer.profiler.profile("run_training_batch"):
|
||||
batch_output = self.batch_loop.run(dataloader_iter)
|
||||
self.batch_progress.increment_processed()
|
||||
is_last = batch_output.is_last
|
||||
else:
|
||||
_, (batch, is_last) = next(dataloader_iter)
|
||||
batch_idx, (batch, is_last) = next(self.dataloader_iter)
|
||||
|
||||
if not self.trainer.data_connector.train_data_fetcher.store_on_device:
|
||||
with self.trainer.profiler.profile("training_batch_to_device"):
|
||||
batch = self.trainer.accelerator.batch_to_device(batch)
|
||||
if not self.trainer.data_connector.train_data_fetcher.store_on_device:
|
||||
with self.trainer.profiler.profile("training_batch_to_device"):
|
||||
batch = self.trainer.accelerator.batch_to_device(batch)
|
||||
|
||||
# ------------------------------------
|
||||
# TRAINING_STEP + TRAINING_STEP_END
|
||||
# ------------------------------------
|
||||
self.batch_progress.increment_ready()
|
||||
self.batch_progress.increment_ready()
|
||||
|
||||
with self.trainer.profiler.profile("run_training_batch"):
|
||||
batch_output = self.batch_loop.run(batch, self.batch_idx)
|
||||
with self.trainer.profiler.profile("run_training_batch"):
|
||||
batch_output = self.batch_loop.run(batch, batch_idx)
|
||||
|
||||
self.batch_progress.increment_processed()
|
||||
self.batch_progress.increment_processed()
|
||||
|
||||
self.is_last_batch = is_last
|
||||
|
||||
|
@ -162,8 +144,7 @@ class TrainingEpochLoop(loops.Loop):
|
|||
processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True)
|
||||
|
||||
# hook
|
||||
if not isinstance(self.batch_loop, IteratorBatchProcessor):
|
||||
self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0)
|
||||
self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0)
|
||||
self.trainer.call_hook("on_batch_end")
|
||||
self.trainer.logger_connector.on_batch_end()
|
||||
|
||||
|
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pytorch_lightning.loops.processors.iterator_batch_processor import IteratorBatchProcessor # noqa: F401
|
|
@ -1,174 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.loops.utilities import (
|
||||
_check_training_step_output,
|
||||
_process_training_step_output,
|
||||
check_finite_loss,
|
||||
)
|
||||
from pytorch_lightning.trainer.progress import OptimizationProgress
|
||||
from pytorch_lightning.trainer.supporters import TensorRunningAccum
|
||||
from pytorch_lightning.utilities import AttributeDict
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IteratorBatchProcessor:
|
||||
"""
|
||||
The processor for performing a training iteration when ``training_step`` needs access to the
|
||||
dataloader. It is selected when the signature of ``training_step`` contains ``dataloader_iter``:
|
||||
|
||||
def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
|
||||
|
||||
The ``training_step`` is allowed to fetch multiple batches during one training iteration. The
|
||||
framework provides a minimal amount of automation with regard to model optimization. The
|
||||
flexibility allows for ease of experimentation with inter-batch parallelism techniques.
|
||||
|
||||
This processor doesn't support ``automatic_optimization`` and ``tbptt``. An error will be thrown
|
||||
if the ``LightningModule`` or the ``Trainer`` is configured to use these features.
|
||||
|
||||
The ``training_step`` is responsible for reporting whether it has reached the last batch by
|
||||
including an ``is_last`` field in the result dict. Failing to do so will result in an error.
|
||||
|
||||
The ``training_step`` should only optimize the model with one batch for the sake of API and
|
||||
reporting consistency (TODO: consider removing this limitation).
|
||||
|
||||
Args:
|
||||
trainer: a reference to the trainer
|
||||
model: a reference to the lightning module (for config validation purposes only)
|
||||
"""
|
||||
|
||||
def __init__(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
|
||||
if is_overridden("on_train_batch_start", model):
|
||||
raise MisconfigurationException(
|
||||
"The model hook `on_train_batch_start` is not compatible with "
|
||||
"taking a `dataloader_iter` argument in your `training_step`."
|
||||
)
|
||||
if is_overridden("on_train_batch_end", model):
|
||||
raise MisconfigurationException(
|
||||
"The model hook `on_train_batch_end` is not compatible with "
|
||||
"taking a `dataloader_iter` argument in your `training_step`."
|
||||
)
|
||||
if is_overridden("tbptt_split_batch", model):
|
||||
raise MisconfigurationException(
|
||||
"The model hook `tbptt_split_batch` is not compatible with "
|
||||
"taking a `dataloader_iter` argument in your `training_step`."
|
||||
)
|
||||
if trainer.accumulate_grad_batches != 1:
|
||||
raise MisconfigurationException(
|
||||
"`accumulate_grad_batches` can only be 1 when your "
|
||||
"`training_step` takes `dataloader_iter` as an argument."
|
||||
)
|
||||
|
||||
self.trainer = trainer
|
||||
|
||||
# The following field is not used by the processor since it doesn't support automatic
|
||||
# optimization and tbptt. Initializing them regardless since they are currently expected by
|
||||
# `FitLoop` or `TrainingEpochLoop`.
|
||||
# TODO: come up with an abstraction for "batch processors" so they can be better decoupled
|
||||
# with parent loops.
|
||||
self.accumulated_loss: Optional[torch.Tensor] = None
|
||||
self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=1)
|
||||
self.optim_progress = OptimizationProgress()
|
||||
self.split_idx: int = 0
|
||||
self._skip_backward = False
|
||||
|
||||
def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
|
||||
"""
|
||||
Returns the number of active optimizers.
|
||||
"""
|
||||
return len(self.trainer.optimizers)
|
||||
|
||||
def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, torch.optim.Optimizer]]:
|
||||
"""
|
||||
Returns the currently active optimizers.
|
||||
|
||||
Returns:
|
||||
A list of tuples (opt_idx, optimizer) of currently active optimizers.
|
||||
"""
|
||||
return list(enumerate(self.trainer.optimizers))
|
||||
|
||||
def run(self, dataloader_iter: Iterator) -> Optional[AttributeDict]:
|
||||
"""
|
||||
Args:
|
||||
dataloader_iter: the iterator over the dataloader producing the new batch
|
||||
"""
|
||||
batch_idx, (dataloader_iter, is_last) = next(dataloader_iter)
|
||||
|
||||
self.trainer.logger_connector.on_batch_start()
|
||||
response = self.trainer.call_hook("on_batch_start")
|
||||
if response == -1:
|
||||
return AttributeDict(signal=-1)
|
||||
|
||||
self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()
|
||||
|
||||
# give the PL module a result for logging
|
||||
model = self.trainer.lightning_module
|
||||
# manually capture logged metrics
|
||||
model._current_fx_name = "training_step"
|
||||
step_kwargs = self._build_kwargs(dataloader_iter, batch_idx)
|
||||
|
||||
with self.trainer.profiler.profile("model_forward"):
|
||||
with self.trainer.profiler.profile("training_step"):
|
||||
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
|
||||
self.trainer.accelerator.post_training_step()
|
||||
|
||||
training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
|
||||
_check_training_step_output(self.trainer.lightning_module, training_step_output)
|
||||
|
||||
training_step_output, _ = _process_training_step_output(self.trainer, training_step_output)
|
||||
|
||||
if self.trainer.terminate_on_nan:
|
||||
check_finite_loss(self.trainer.lightning_module, training_step_output.minimize)
|
||||
|
||||
batch_outputs = [[] for _ in range(len(self.trainer.optimizers))]
|
||||
if training_step_output:
|
||||
batch_outputs[0].append(training_step_output)
|
||||
|
||||
return AttributeDict(signal=0, training_step_output=batch_outputs, is_last=is_last)
|
||||
|
||||
def teardown(self) -> None:
|
||||
"""
|
||||
No-op. Only defined to comply with FitLoop's expectation.
|
||||
"""
|
||||
pass
|
||||
|
||||
# FIXME: To be deleted in next PR.
|
||||
def _build_kwargs(self, dataloader_iter: Iterator, batch_idx: int) -> Dict[str, Any]:
|
||||
"""Builds the keyword arguments for training_step
|
||||
|
||||
Args:
|
||||
dataloader_iter: The dataloader to pass
|
||||
batch_idx: the index of the current batch
|
||||
|
||||
Returns:
|
||||
An ordered dict with the keyword arguments for the training step
|
||||
"""
|
||||
# enable not needing to add opt_idx to training_step
|
||||
step_kwargs = OrderedDict([("dataloader_iter", dataloader_iter)])
|
||||
|
||||
training_step_fx = getattr(self.trainer.lightning_module, "training_step")
|
||||
if is_param_in_hook_signature(training_step_fx, "batch_idx"):
|
||||
step_kwargs["batch_idx"] = batch_idx
|
||||
|
||||
return step_kwargs
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Mapping, Optional, Tuple
|
||||
from typing import Iterator, Mapping, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -20,6 +20,7 @@ import pytorch_lightning as pl
|
|||
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher
|
||||
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
|
||||
|
@ -102,3 +103,11 @@ def _process_training_step_output(
|
|||
if trainer.move_metrics_to_cpu:
|
||||
results.cpu()
|
||||
return results, hiddens
|
||||
|
||||
|
||||
def _prepare_dataloader_iter(dataloader_iter: Iterator, batch_idx: int) -> Iterator:
|
||||
"""Wrap ``dataloader_iter`` in ``enumerate`` from ``batch_idx`` unless it is a ``DataLoaderIterDataFetcher``."""
|
||||
if not isinstance(dataloader_iter, DataLoaderIterDataFetcher):
|
||||
dataloader_iter = enumerate(dataloader_iter, batch_idx)
|
||||
# restore iteration
|
||||
return dataloader_iter
|
||||
|
|
|
@ -16,6 +16,7 @@ from pytorch_lightning.trainer.states import TrainerFn
|
|||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
|
||||
|
||||
class ConfigValidator:
|
||||
|
@ -34,6 +35,7 @@ class ConfigValidator:
|
|||
self.__verify_train_loop_configuration(model)
|
||||
self.__verify_eval_loop_configuration(model, "val")
|
||||
self.__verify_manual_optimization_support(model)
|
||||
self.__check_training_step_requires_dataloader_iter(model)
|
||||
elif self.trainer.state.fn == TrainerFn.VALIDATING:
|
||||
self.__verify_eval_loop_configuration(model, "val")
|
||||
elif self.trainer.state.fn == TrainerFn.TESTING:
|
||||
|
@ -128,3 +130,26 @@ class ConfigValidator:
|
|||
f" Remove `Trainer(accumulate_grad_batches={self.trainer.accumulate_grad_batches})`"
|
||||
" or switch to automatic optimization."
|
||||
)
|
||||
|
||||
def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningModule"):
|
||||
"""Check if the current `training_step` is requesting `dataloader_iter`."""
|
||||
training_step_fx = getattr(model, "training_step")
|
||||
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
|
||||
|
||||
if is_overridden("on_train_batch_start", model):
|
||||
raise MisconfigurationException(
|
||||
"The model hook `on_train_batch_start` is not compatible with "
|
||||
"taking a `dataloader_iter` argument in your `training_step`."
|
||||
)
|
||||
|
||||
if is_overridden("on_train_batch_end", model):
|
||||
raise MisconfigurationException(
|
||||
"The model hook `on_train_batch_end` is not compatible with "
|
||||
"taking a `dataloader_iter` argument in your `training_step`."
|
||||
)
|
||||
|
||||
if model.truncated_bptt_steps > 0:
|
||||
raise MisconfigurationException(
|
||||
"The model taking a `dataloader_iter` argument in your `training_step` "
|
||||
"is incompatible with `truncated_bptt_steps > 0`."
|
||||
)
|
||||
|
|
|
@ -91,21 +91,18 @@ class DataConnector:
|
|||
self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
|
||||
self.trainer._is_data_prepared = False
|
||||
|
||||
def _check_training_step_requires_dataloader_iter(self) -> bool:
|
||||
training_step_fx = getattr(self.trainer.lightning_module, "training_step")
|
||||
contains_dataloader_iter = is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True)
|
||||
return contains_dataloader_iter
|
||||
|
||||
def _select_data_fetcher(self) -> AbstractDataFetcher:
|
||||
if self.trainer.sanity_checking:
|
||||
return DataFetcher()
|
||||
|
||||
if self.trainer.training and self._check_training_step_requires_dataloader_iter():
|
||||
training_step_fx = getattr(self.trainer.lightning_module, "training_step")
|
||||
if self.trainer.training and is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
|
||||
rank_zero_warn(
|
||||
"Found `dataloader_iter` argument in the `training_step`. Note that the support for "
|
||||
"this signature is experimental and the behavior is subject to change."
|
||||
)
|
||||
return DataLoaderIterDataFetcher()
|
||||
|
||||
elif self.trainer.training and os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
|
||||
# note: this is an experimental feature
|
||||
if not self.trainer.training_type_plugin.on_gpu:
|
||||
|
@ -124,9 +121,7 @@ class DataConnector:
|
|||
profiler=self.trainer.profiler,
|
||||
)
|
||||
setattr(self, f"{stage}_data_fetcher", data_fetcher)
|
||||
if isinstance(data_fetcher, DataLoaderIterDataFetcher):
|
||||
return data_fetcher
|
||||
return enumerate(data_fetcher)
|
||||
return data_fetcher
|
||||
|
||||
def prepare_data(self) -> None:
|
||||
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
|
||||
|
@ -250,6 +245,16 @@ class DataConnector:
|
|||
if isinstance(loader, _PatchDataLoader):
|
||||
loader.unpatch(model)
|
||||
|
||||
def teardown(self) -> None:
|
||||
if self.train_data_fetcher:
|
||||
self.train_data_fetcher.teardown()
|
||||
if self.validate_data_fetcher:
|
||||
self.validate_data_fetcher.teardown()
|
||||
if self.test_data_fetcher:
|
||||
self.test_data_fetcher.teardown()
|
||||
if self.sanity_check_data_fetcher:
|
||||
self.sanity_check_data_fetcher.teardown()
|
||||
|
||||
|
||||
class _PatchDataLoader:
|
||||
r"""
|
||||
|
|
|
@ -199,7 +199,11 @@ class LoggerConnector:
|
|||
"""
|
||||
|
||||
def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
|
||||
self.trainer._results.extract_batch_size(split_batch)
|
||||
# when the user requests `dataloader_iter`, we can't track the batch_size
|
||||
# and this is left to user responsibility.
|
||||
if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher):
|
||||
self.trainer._results.extract_batch_size(split_batch)
|
||||
|
||||
self._batch_idx = batch_idx
|
||||
self._split_idx = split_idx
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
|
|||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.core.datamodule import LightningDataModule
|
||||
from pytorch_lightning.loggers import LightningLoggerBase
|
||||
from pytorch_lightning.loops import IteratorBatchProcessor, TrainingBatchLoop, TrainingEpochLoop
|
||||
from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop
|
||||
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
|
||||
from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
|
||||
from pytorch_lightning.loops.fit_loop import FitLoop
|
||||
|
@ -77,12 +77,10 @@ from pytorch_lightning.utilities import (
|
|||
from pytorch_lightning.utilities.debugging import InternalDebugger
|
||||
from pytorch_lightning.utilities.distributed import distributed_available
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.fetching import DataLoaderIterDataFetcher
|
||||
from pytorch_lightning.utilities.imports import _fault_tolerant_training
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
|
||||
from pytorch_lightning.utilities.seed import reset_seed
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -920,18 +918,6 @@ class Trainer(
|
|||
rank_zero_info(f"Loading model weights from checkpoint at {self._ckpt_path}")
|
||||
self.checkpoint_connector.restore_model_weights(self._ckpt_path)
|
||||
|
||||
def _maybe_switch_to_iterator_batch_processor(self, model: "pl.LightningModule") -> None:
|
||||
training_step_fx = getattr(model, "training_step")
|
||||
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
|
||||
log.warning(
|
||||
"Found `dataloader_iter` argument in the `training_step`. Note that the support for "
|
||||
"this signature is experimental and the behavior may subject to change."
|
||||
)
|
||||
batch_loop = IteratorBatchProcessor(self, model)
|
||||
self.fit_loop.epoch_loop.connect(batch_loop)
|
||||
# FIXME: Move this logic to data_connector after removing `IteratorBatchProcessor`
|
||||
self.data_connector.data_fetcher = DataLoaderIterDataFetcher()
|
||||
|
||||
def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
|
||||
# clean hparams
|
||||
if hasattr(model, "hparams"):
|
||||
|
@ -939,9 +925,6 @@ class Trainer(
|
|||
|
||||
self.config_validator.verify_loop_configurations(model)
|
||||
|
||||
if self.training:
|
||||
self._maybe_switch_to_iterator_batch_processor(model)
|
||||
|
||||
# attach model log function to callback
|
||||
self.callback_connector.attach_model_logging_functions(model)
|
||||
|
||||
|
@ -1077,6 +1060,7 @@ class Trainer(
|
|||
# these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
|
||||
# which need to happen before.
|
||||
self.accelerator.teardown()
|
||||
self.data_connector.teardown()
|
||||
self._active_loop.teardown()
|
||||
self.logger_connector.teardown()
|
||||
|
||||
|
|
|
@ -390,7 +390,6 @@ class StepFuncDataLoaderIter:
|
|||
def __next__(self) -> Any:
|
||||
try:
|
||||
data = next(self.iterator)
|
||||
# FIXME: Link this to `batch_idx`.
|
||||
self.data_fetcher.fetched += 1
|
||||
return data
|
||||
except StopIteration:
|
||||
|
|
|
@ -12,15 +12,22 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from typing import Callable
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
||||
def is_param_in_hook_signature(hook_fx: Callable, param: str, explicit: bool = False) -> bool:
|
||||
def is_param_in_hook_signature(
|
||||
hook_fx: Callable, param: str, explicit: bool = False, min_args: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Args:
|
||||
hook_fx: the hook callable
|
||||
param: the name of the parameter to check
|
||||
explicit: whether the parameter has to be explicitly declared
|
||||
min_args: whether the `signature` has at least `min_args` parameters
|
||||
"""
|
||||
hook_params = list(inspect.signature(hook_fx).parameters)
|
||||
return param in hook_params or (not explicit and "args" in hook_params)
|
||||
return (
|
||||
param in hook_params
|
||||
or (not explicit and "args" in hook_params)
|
||||
or (isinstance(min_args, int) and len(hook_params) >= min_args)
|
||||
)
|
||||
|
|
|
@ -17,7 +17,7 @@ from copy import deepcopy
|
|||
import torch
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning import seed_everything, Trainer
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from tests.helpers import BoringModel
|
||||
|
||||
|
@ -27,8 +27,6 @@ def test_finetuning_with_resume_from_checkpoint(tmpdir):
|
|||
This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test
|
||||
"""
|
||||
|
||||
seed_everything(4)
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1)
|
||||
|
||||
class ExtendedBoringModel(BoringModel):
|
||||
|
@ -75,9 +73,6 @@ def test_finetuning_with_resume_from_checkpoint(tmpdir):
|
|||
results.append(deepcopy(trainer.callback_metrics))
|
||||
best_model_paths.append(trainer.checkpoint_callback.best_model_path)
|
||||
|
||||
for idx in range(len(results) - 1):
|
||||
assert results[idx]["val_loss"] > results[idx + 1]["val_loss"]
|
||||
|
||||
for idx, best_model_path in enumerate(best_model_paths):
|
||||
if idx == 0:
|
||||
assert best_model_path.endswith(f"epoch=0{idx}.ckpt")
|
||||
|
|
|
@ -1,190 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import time
|
||||
from statistics import mean
|
||||
from typing import Iterator
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, IterableDataset
|
||||
|
||||
from pytorch_lightning import LightningModule, Trainer
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
from tests.helpers.runif import RunIf
|
||||
|
||||
|
||||
def count_cycles_per_ms() -> float:
|
||||
"""
|
||||
Measure and return approximate number of cycles per millisecond for torch.cuda._sleep
|
||||
|
||||
Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py
|
||||
"""
|
||||
|
||||
def measure() -> float:
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start.record()
|
||||
torch.cuda._sleep(1000000)
|
||||
end.record()
|
||||
end.synchronize()
|
||||
cycles_per_ms = 1000000 / start.elapsed_time(end)
|
||||
return cycles_per_ms
|
||||
|
||||
# Get 10 values and remove the 2 max and 2 min and return the avg.
|
||||
# This is to avoid system disturbances that skew the results, e.g.
|
||||
# the very first cuda call likely does a bunch of init, which takes
|
||||
# much longer than subsequent calls.
|
||||
#
|
||||
# Tested on both Tesla V100, Quadro GP100, Titan RTX, RTX 3090 GPUs
|
||||
# and seems to return stable values. Therefore, we enable caching
|
||||
# using lru_cache decorator above.
|
||||
num = 10
|
||||
vals = []
|
||||
for _ in range(num):
|
||||
vals.append(measure())
|
||||
vals = sorted(vals)
|
||||
return mean(vals[2 : num - 2])
|
||||
|
||||
|
||||
_CYCLES_PER_MS = int(count_cycles_per_ms()) if torch.cuda.is_available() else 0
|
||||
_BATCH_SIZE = 128
|
||||
_EMB_SZ = 100
|
||||
_EMB_DIM = 64
|
||||
|
||||
|
||||
class RandomSparseDataset(IterableDataset):
|
||||
def __init__(self, emb_dim: int, batch_size: int, count: int) -> None:
|
||||
self.emb_dim = emb_dim
|
||||
self.batch_size = batch_size
|
||||
self.count = count
|
||||
|
||||
def __iter__(self):
|
||||
for _ in range(self.count):
|
||||
yield torch.randint(self.emb_dim, [self.batch_size])
|
||||
|
||||
|
||||
class ToyDLRMModel(LightningModule):
|
||||
"""
|
||||
A toy model for mimicking the communication overhead of sharded embedding
|
||||
modules in DLRM models.
|
||||
|
||||
DLRM models can be trained in a DDP-like fashion, where each trainer
|
||||
receives different batches (embedding indices in this example). Since the
|
||||
embeddings are sharded across trainers, the lookup process involves (1)
|
||||
routing the indices to the trainer that possesses the corresponding
|
||||
embeddings (2) performing local lookup (3) routing the embedding lookup
|
||||
result back.
|
||||
|
||||
The toy model doesn't actually perform index/result routing. It simply
|
||||
uses torch.cuda._sleep() to mimic the cost of the communication op (i.e.
|
||||
a2a).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.automatic_optimization = False
|
||||
self.local_embedding = torch.nn.Embedding(_EMB_SZ, _EMB_DIM)
|
||||
|
||||
def _route_indices(self, batch: torch.Tensor, non_blocking=False):
|
||||
"""
|
||||
This can be parallelized across different batches since it's model
|
||||
weight independent.
|
||||
|
||||
Why not run this in dataloader/datamodule?
|
||||
- The routing logic depends on how model is sharded
|
||||
- Putting this in data preprocessor changes the semantic of the model
|
||||
"""
|
||||
torch.cuda._sleep(_CYCLES_PER_MS * 1_000)
|
||||
if not non_blocking:
|
||||
torch.cuda.synchronize()
|
||||
return batch
|
||||
|
||||
def _route_result(self, result: torch.Tensor, non_blocking=False):
|
||||
torch.cuda._sleep(_CYCLES_PER_MS * 1_000)
|
||||
if not non_blocking:
|
||||
torch.cuda.synchronize()
|
||||
return result
|
||||
|
||||
def forward(self, indices: torch.Tensor):
|
||||
local_indices = self._route_indices(indices)
|
||||
result = self.local_embedding(local_indices)
|
||||
return self._route_result(result)
|
||||
|
||||
def training_step(self, batch: torch.Tensor, batch_idx: int) -> STEP_OUTPUT:
|
||||
return self.forward(batch)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.SGD(self.local_embedding.parameters(), lr=0.1)
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(RandomSparseDataset(_EMB_DIM, _BATCH_SIZE, 5))
|
||||
|
||||
|
||||
class AsyncToyDLRMModel(ToyDLRMModel):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.comm_stream = torch.cuda.Stream()
|
||||
self.batch_i = None
|
||||
self.batch_i_ready = torch.cuda.Event()
|
||||
|
||||
def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
|
||||
if self.batch_i is None:
|
||||
self.batch_i = next(dataloader_iter)
|
||||
with torch.cuda.stream(self.comm_stream):
|
||||
self._route_indices(self.batch_i, non_blocking=True)
|
||||
self.batch_i_ready.record()
|
||||
|
||||
# Invariant: the routing for batch[i] has been kicked off
|
||||
is_last = False
|
||||
batch_ip1 = None
|
||||
batch_ip1_ready = torch.cuda.Event()
|
||||
try:
|
||||
batch_ip1 = next(dataloader_iter)
|
||||
with torch.cuda.stream(self.comm_stream):
|
||||
self._route_indices(batch_ip1, non_blocking=True)
|
||||
batch_ip1_ready.record()
|
||||
except StopIteration:
|
||||
is_last = True
|
||||
|
||||
self.batch_i_ready.wait()
|
||||
|
||||
result = self.local_embedding(self.batch_i)
|
||||
self._route_result(result)
|
||||
|
||||
self.batch_i = batch_ip1
|
||||
self.batch_i_ready = batch_ip1_ready
|
||||
|
||||
return {"is_last": is_last}
|
||||
|
||||
|
||||
@RunIf(min_gpus=1)
|
||||
def test_inter_batch_parallelism(tmpdir):
|
||||
"""
|
||||
Verify the speedup of a simple inter-batch parallelization use case enabled
|
||||
by exposing `dataloader_iter` to `training_step`.
|
||||
"""
|
||||
begin_time = time.time()
|
||||
m = AsyncToyDLRMModel()
|
||||
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
|
||||
trainer.fit(m)
|
||||
async_duration = time.time() - begin_time
|
||||
|
||||
begin_time = time.time()
|
||||
m = ToyDLRMModel()
|
||||
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
|
||||
trainer.fit(m)
|
||||
sync_duration = time.time() - begin_time
|
||||
|
||||
# We expect 2x speedup. However, we only assert that the async
|
||||
# training_step is faster in order to avoid flaky tests
|
||||
assert async_duration < sync_duration, "Expect `AsyncToyDLRMModel` to train faster than `ToyDLRMModel`."
|
|
@ -1,183 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Iterator
|
||||
|
||||
import pytest
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
from tests.helpers import BoringModel, RandomDataset
|
||||
|
||||
_BATCH_SIZE = 32
|
||||
_DATASET_LEN = 64
|
||||
|
||||
|
||||
class DummyWaitable:
|
||||
def __init__(self, val: Any) -> None:
|
||||
self.val = val
|
||||
|
||||
def wait(self) -> Any:
|
||||
return self.val
|
||||
|
||||
|
||||
class AsyncBoringModel(BoringModel):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.automatic_optimization = False
|
||||
self.batch_i_handle = None
|
||||
self.num_batches_processed = 0
|
||||
|
||||
def _async_op(self, batch: Any) -> DummyWaitable:
|
||||
return DummyWaitable(val=batch)
|
||||
|
||||
def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
|
||||
if self.batch_i_handle is None:
|
||||
batch_i_raw = next(dataloader_iter)
|
||||
self.batch_i_handle = self._async_op(batch_i_raw)
|
||||
|
||||
# Invariant: _async_op for batch[i] has been initiated
|
||||
batch_ip1_handle = None
|
||||
is_last = False
|
||||
try:
|
||||
batch_ip1_raw = next(dataloader_iter)
|
||||
batch_ip1_handle = self._async_op(batch_ip1_raw)
|
||||
except StopIteration:
|
||||
is_last = True
|
||||
|
||||
batch_i = self.batch_i_handle.wait()
|
||||
|
||||
pred = self.layer(batch_i)
|
||||
loss = self.loss(batch_i, pred)
|
||||
loss.backward()
|
||||
self.optimizers().step()
|
||||
self.optimizers().zero_grad()
|
||||
|
||||
self.batch_i_handle = batch_ip1_handle
|
||||
self.num_batches_processed += 1
|
||||
|
||||
return {"loss": loss, "is_last": is_last}
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(RandomDataset(_BATCH_SIZE, _DATASET_LEN))
|
||||
|
||||
|
||||
def test_training_step_with_dataloader_access(tmpdir) -> None:
|
||||
"""
|
||||
A baseline functional test for `training_step` with dataloader access.
|
||||
"""
|
||||
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
|
||||
m = AsyncBoringModel()
|
||||
trainer.fit(m)
|
||||
assert m.num_batches_processed == _DATASET_LEN, f"Expect all {_DATASET_LEN} batches to be processed."
|
||||
|
||||
|
||||
def test_stop_iteration(tmpdir) -> None:
|
||||
"""
|
||||
Verify that when `StopIteration` is raised within `training_step`, `fit()`
|
||||
terminates as expected.
|
||||
"""
|
||||
EXPECT_NUM_BATCHES_PROCESSED = 2
|
||||
|
||||
class TestModel(AsyncBoringModel):
|
||||
def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
|
||||
output = super().training_step(dataloader_iter)
|
||||
if self.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED:
|
||||
raise StopIteration()
|
||||
return output
|
||||
|
||||
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
|
||||
m = TestModel()
|
||||
trainer.fit(m)
|
||||
assert (
|
||||
m.num_batches_processed == EXPECT_NUM_BATCHES_PROCESSED
|
||||
), "Expect {EXPECT_NUM_BATCHES_PROCESSED} batches to be processed."
|
||||
|
||||
|
||||
def test_on_train_batch_start_overridden(tmpdir) -> None:
|
||||
"""
|
||||
Verify that a `MisconfigurationException` is raised when
|
||||
`on_train_batch_start` is overridden on the `LightningModule`.
|
||||
"""
|
||||
|
||||
class InvalidModel(AsyncBoringModel):
|
||||
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
|
||||
pass
|
||||
|
||||
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
|
||||
m = InvalidModel()
|
||||
with pytest.raises(MisconfigurationException):
|
||||
trainer.fit(m)
|
||||
|
||||
|
||||
def test_on_train_batch_end_overridden(tmpdir) -> None:
|
||||
"""
|
||||
Verify that a `MisconfigurationException` is raised when
|
||||
`on_train_batch_end` is overridden on the `LightningModule`.
|
||||
"""
|
||||
|
||||
class InvalidModel(AsyncBoringModel):
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
pass
|
||||
|
||||
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
|
||||
m = InvalidModel()
|
||||
with pytest.raises(MisconfigurationException):
|
||||
trainer.fit(m)
|
||||
|
||||
|
||||
def test_tbptt_split_batch_overridden(tmpdir) -> None:
|
||||
"""
|
||||
Verify that a `MisconfigurationException` is raised when
|
||||
`tbptt_split_batch` is overridden on the `LightningModule`.
|
||||
"""
|
||||
|
||||
class InvalidModel(AsyncBoringModel):
|
||||
def tbptt_split_batch(self, batch, split_size):
|
||||
pass
|
||||
|
||||
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
|
||||
m = InvalidModel()
|
||||
with pytest.raises(MisconfigurationException):
|
||||
trainer.fit(m)
|
||||
|
||||
|
||||
def test_accumulate_grad_batches(tmpdir) -> None:
|
||||
"""
|
||||
Verify that a `MisconfigurationException` is raised when
|
||||
`accumulate_grad_batches` is not set to 1.
|
||||
"""
|
||||
trainer = Trainer(max_epochs=1, accumulate_grad_batches=2, default_root_dir=tmpdir)
|
||||
m = AsyncBoringModel()
|
||||
with pytest.raises(MisconfigurationException):
|
||||
trainer.fit(m)
|
||||
|
||||
|
||||
def test_is_last_not_set(tmpdir) -> None:
|
||||
"""
|
||||
Verify that a `MisconfigurationException` is raised when `training_step`
|
||||
doesn't include "is_last" in the result dict.
|
||||
"""
|
||||
|
||||
class InvalidModel(AsyncBoringModel):
|
||||
def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
|
||||
output = super().training_step(dataloader_iter)
|
||||
del output["is_last"]
|
||||
return output
|
||||
|
||||
trainer = Trainer(max_epochs=1, accumulate_grad_batches=2, default_root_dir=tmpdir)
|
||||
m = InvalidModel()
|
||||
with pytest.raises(MisconfigurationException):
|
||||
trainer.fit(m)
|
|
@ -419,7 +419,7 @@ def test_multiple_optimizers_step(tmpdir):
|
|||
|
||||
called = False
|
||||
|
||||
def on_after_backward(self):
|
||||
def on_before_optimizer_step(self, *args):
|
||||
self.called = True
|
||||
norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
|
||||
if not (torch.isinf(norm) or torch.isnan(norm)):
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
from time import time
|
||||
from typing import Any
|
||||
from typing import Any, Iterator
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
@ -25,7 +25,8 @@ from pytorch_lightning import Trainer
|
|||
from pytorch_lightning.trainer.supporters import CombinedLoader
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher
|
||||
from tests.helpers.boring_model import BoringModel
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
from tests.helpers import BoringModel, RandomDataset
|
||||
from tests.helpers.runif import RunIf
|
||||
|
||||
|
||||
|
@ -125,7 +126,8 @@ def get_cycles_per_ms() -> float:
|
|||
return sum(stats) / len(stats)
|
||||
|
||||
|
||||
BATCH_SIZE = 128
|
||||
BATCH_SIZE = 32
|
||||
DATASET_LEN = 64
|
||||
EMB_SZ = 100
|
||||
EMB_DIM = 64
|
||||
|
||||
|
@ -176,6 +178,7 @@ class RecommenderModel(BoringModel):
|
|||
def test_trainer_num_prefetch_batches(tmpdir):
|
||||
|
||||
model = RecommenderModel()
|
||||
|
||||
trainer_kwargs = dict(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
|
@ -190,8 +193,8 @@ def test_trainer_num_prefetch_batches(tmpdir):
|
|||
trainer = Trainer(**trainer_kwargs)
|
||||
trainer.fit(model)
|
||||
t1 = time()
|
||||
global_step = trainer.global_step
|
||||
assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher)
|
||||
global_step = trainer.global_step
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
@ -199,9 +202,9 @@ def test_trainer_num_prefetch_batches(tmpdir):
|
|||
trainer = Trainer(**trainer_kwargs)
|
||||
trainer.fit(model)
|
||||
t3 = time()
|
||||
assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher)
|
||||
|
||||
assert global_step == trainer.global_step == 4
|
||||
assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher)
|
||||
ratio = (t3 - t2) / (t1 - t0)
|
||||
assert ratio > 1.1, ratio
|
||||
|
||||
|
@ -218,7 +221,7 @@ def test_fetching_dataloader_iter(automatic_optimization, tmpdir):
|
|||
|
||||
def training_step(self, dataloader_iter, batch_idx):
|
||||
assert self.count == batch_idx
|
||||
assert isinstance(self.trainer.data_connector.data_fetcher, DataLoaderIterDataFetcher)
|
||||
assert isinstance(self.trainer.data_connector.train_data_fetcher, DataLoaderIterDataFetcher)
|
||||
# fetch 2 batches
|
||||
self.batches.append(next(dataloader_iter))
|
||||
self.batches.append(next(dataloader_iter))
|
||||
|
@ -227,7 +230,10 @@ def test_fetching_dataloader_iter(automatic_optimization, tmpdir):
|
|||
assert isinstance(batch, torch.Tensor) or batch is None
|
||||
self.count += 2
|
||||
if self.automatic_optimization:
|
||||
return super().training_step(batch, 0)
|
||||
loss = super().training_step(batch, 0)
|
||||
with pytest.raises(MisconfigurationException, match="dataloader_iter"):
|
||||
self.log("train_loss", loss["loss"])
|
||||
self.log("train_loss", loss["loss"], batch_size=1)
|
||||
else:
|
||||
opt = self.optimizers()
|
||||
output = self(batch)
|
||||
|
@ -236,10 +242,152 @@ def test_fetching_dataloader_iter(automatic_optimization, tmpdir):
|
|||
loss.backward()
|
||||
opt.step()
|
||||
|
||||
training_epoch_end = None
|
||||
def training_epoch_end(self, *_):
|
||||
assert self.trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33
|
||||
assert self.trainer.data_connector.train_data_fetcher.fetched == 64
|
||||
assert self.count == 64
|
||||
|
||||
model = TestModel(automatic_optimization=automatic_optimization)
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
|
||||
trainer.data_connector.data_fetcher = DataLoaderIterDataFetcher()
|
||||
trainer.fit(model)
|
||||
assert model.count == 64
|
||||
|
||||
|
||||
class DummyWaitable:
|
||||
def __init__(self, val: Any) -> None:
|
||||
self.val = val
|
||||
|
||||
def wait(self) -> Any:
|
||||
return self.val
|
||||
|
||||
|
||||
class AsyncBoringModel(BoringModel):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.automatic_optimization = False
|
||||
self.batch_i_handle = None
|
||||
self.num_batches_processed = 0
|
||||
|
||||
def _async_op(self, batch: Any) -> DummyWaitable:
|
||||
return DummyWaitable(val=batch)
|
||||
|
||||
def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
|
||||
if self.batch_i_handle is None:
|
||||
batch_i_raw = next(dataloader_iter)
|
||||
self.batch_i_handle = self._async_op(batch_i_raw)
|
||||
|
||||
# Invariant: _async_op for batch[i] has been initiated
|
||||
batch_ip1_handle = None
|
||||
is_last = False
|
||||
try:
|
||||
batch_ip1_raw = next(dataloader_iter)
|
||||
batch_ip1_handle = self._async_op(batch_ip1_raw)
|
||||
except StopIteration:
|
||||
is_last = True
|
||||
|
||||
batch_i = self.batch_i_handle.wait()
|
||||
|
||||
pred = self.layer(batch_i)
|
||||
loss = self.loss(batch_i, pred)
|
||||
loss.backward()
|
||||
self.optimizers().step()
|
||||
self.optimizers().zero_grad()
|
||||
|
||||
self.batch_i_handle = batch_ip1_handle
|
||||
self.num_batches_processed += 1
|
||||
|
||||
return {"loss": loss, "is_last": is_last}
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(RandomDataset(BATCH_SIZE, DATASET_LEN))
|
||||
|
||||
|
||||
def test_training_step_with_dataloader_access(tmpdir) -> None:
|
||||
"""
|
||||
A baseline functional test for `training_step` with dataloader access.
|
||||
"""
|
||||
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
|
||||
m = AsyncBoringModel()
|
||||
trainer.fit(m)
|
||||
assert m.num_batches_processed == DATASET_LEN, f"Expect all {DATASET_LEN} batches to be processed."
|
||||
|
||||
|
||||
@pytest.mark.parametrize("trigger_stop_iteration", [False, True])
|
||||
def test_stop_iteration(trigger_stop_iteration, tmpdir):
|
||||
"""
|
||||
Verify that StopIteration properly terminates the training when this is triggered
|
||||
from the current `dataloader_iter`
|
||||
"""
|
||||
EXPECT_NUM_BATCHES_PROCESSED = 2
|
||||
|
||||
class TestModel(AsyncBoringModel):
|
||||
def __init__(self, trigger_stop_iteration) -> None:
|
||||
super().__init__()
|
||||
self.trigger_stop_iteration = trigger_stop_iteration
|
||||
|
||||
def training_step(self, dataloader_iter: Iterator, *args) -> STEP_OUTPUT:
|
||||
output = super().training_step(dataloader_iter)
|
||||
if self.trigger_stop_iteration and args[0] == EXPECT_NUM_BATCHES_PROCESSED:
|
||||
raise StopIteration
|
||||
return output
|
||||
|
||||
def train_dataloader(self):
|
||||
if self.trigger_stop_iteration:
|
||||
return DataLoader(RandomDataset(BATCH_SIZE, 2 * EXPECT_NUM_BATCHES_PROCESSED))
|
||||
return DataLoader(RandomDataset(BATCH_SIZE, EXPECT_NUM_BATCHES_PROCESSED))
|
||||
|
||||
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
|
||||
m = TestModel(trigger_stop_iteration)
|
||||
trainer.fit(m)
|
||||
expected = EXPECT_NUM_BATCHES_PROCESSED
|
||||
if trigger_stop_iteration:
|
||||
expected *= 2
|
||||
assert m.num_batches_processed == expected
|
||||
|
||||
|
||||
def test_on_train_batch_start_overridden(tmpdir) -> None:
|
||||
"""
|
||||
Verify that a `MisconfigurationException` is raised when
|
||||
`on_train_batch_start` is overridden on the `LightningModule`.
|
||||
"""
|
||||
|
||||
class InvalidModel(AsyncBoringModel):
|
||||
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
|
||||
pass
|
||||
|
||||
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
|
||||
m = InvalidModel()
|
||||
with pytest.raises(MisconfigurationException, match="The model hook `on_train_batch_start` is not compatible with"):
|
||||
trainer.fit(m)
|
||||
|
||||
|
||||
def test_on_train_batch_end_overridden(tmpdir) -> None:
|
||||
"""
|
||||
Verify that a `MisconfigurationException` is raised when
|
||||
`on_train_batch_end` is overridden on the `LightningModule`.
|
||||
"""
|
||||
|
||||
class InvalidModel(AsyncBoringModel):
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
pass
|
||||
|
||||
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
|
||||
m = InvalidModel()
|
||||
with pytest.raises(MisconfigurationException, match="The model hook `on_train_batch_end` is not compatible with"):
|
||||
trainer.fit(m)
|
||||
|
||||
|
||||
def test_tbptt_split_batch_overridden(tmpdir) -> None:
|
||||
"""
|
||||
Verify that a `MisconfigurationException` is raised when
|
||||
`tbptt_split_batch` is overridden on the `LightningModule`.
|
||||
"""
|
||||
|
||||
class InvalidModel(AsyncBoringModel):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.truncated_bptt_steps = 2
|
||||
|
||||
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
|
||||
m = InvalidModel()
|
||||
with pytest.raises(MisconfigurationException, match="is incompatible with `truncated_bptt_steps > 0`."):
|
||||
trainer.fit(m)
|
||||
|
|
Loading…
Reference in New Issue