diff --git a/README.md b/README.md index c50a169dbe..2f00173fcd 100644 --- a/README.md +++ b/README.md @@ -313,9 +313,9 @@ with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile: -### Pro-level control of training loops (advanced users) +### Pro-level control of optimization (advanced users) -For complex/professional level work, you have optional full control of the training loop and optimizers. +For complex/professional level work, you have optional full control of the optimizers. ```python class LitAutoEncoder(pl.LightningModule): diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index 3ef38e141b..c44582a967 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -100,66 +100,6 @@ loggers tensorboard wandb -loops -^^^^^ - -Base Classes -"""""""""""" - -.. currentmodule:: pytorch_lightning.loops - -.. autosummary:: - :toctree: api - :nosignatures: - :template: classtemplate.rst - - ~dataloader.dataloader_loop.DataLoaderLoop - ~loop.Loop - -Training -"""""""" - -.. currentmodule:: pytorch_lightning.loops - -.. autosummary:: - :toctree: api - :nosignatures: - :template: classtemplate.rst - - ~epoch.TrainingEpochLoop - FitLoop - ~optimization.ManualOptimization - ~optimization.OptimizerLoop - - -Validation and Testing -"""""""""""""""""""""" - -.. currentmodule:: pytorch_lightning.loops - -.. autosummary:: - :toctree: api - :nosignatures: - :template: classtemplate.rst - - ~epoch.EvaluationEpochLoop - ~dataloader.EvaluationLoop - - -Prediction -"""""""""" - -.. currentmodule:: pytorch_lightning.loops - -.. autosummary:: - :toctree: api - :nosignatures: - :template: classtemplate.rst - - ~epoch.PredictionEpochLoop - ~dataloader.PredictionLoop - - plugins ^^^^^^^ diff --git a/docs/source-pytorch/starter/introduction.rst b/docs/source-pytorch/starter/introduction.rst index 47098062e8..45ede778ca 100644 --- a/docs/source-pytorch/starter/introduction.rst +++ b/docs/source-pytorch/starter/introduction.rst @@ -306,7 +306,7 @@ If you have multiple lines of code with similar functionalities, you can use cal Use a raw PyTorch loop ====================== -For certain types of work at the bleeding-edge of research, Lightning offers experts full control of their training loops in various ways. +For certain types of work at the bleeding-edge of research, Lightning offers experts full control of optimization or the training loop in various ways. .. raw:: html @@ -333,15 +333,6 @@ For certain types of work at the bleeding-edge of research, Lightning offers exp :image_height: 220px :height: 320 -.. displayitem:: - :header: Loops - :description: Enable meta-learning, reinforcement learning, GANs with full control. - :col_css: col-md-4 - :image_center: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/loops.png - :button_link: ../extensions/loops.html - :image_height: 220px - :height: 320 - .. raw:: html diff --git a/src/lightning_app/utilities/introspection.py b/src/lightning_app/utilities/introspection.py index 3da59f488c..a0a222e9e5 100644 --- a/src/lightning_app/utilities/introspection.py +++ b/src/lightning_app/utilities/introspection.py @@ -247,10 +247,6 @@ class LightningAcceleratorVisitor(LightningVisitor): class_name = "Accelerator" -class LightningLoopVisitor(LightningVisitor): - class_name = "Loop" - - class TorchMetricVisitor(LightningVisitor): class_name = "Metric" @@ -290,7 +286,6 @@ class Scanner: LightningPrecisionPluginVisitor, LightningAcceleratorVisitor, LightningLoggerVisitor, - LightningLoopVisitor, TorchMetricVisitor, FabricVisitor, LightningProfilerVisitor, diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index f8945a6d1b..951a244587 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -113,6 +113,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Removed `Loop.connect()` ([#16384](https://github.com/Lightning-AI/lightning/pull/16384)) * Removed the `trainer.{fit,validate,test,predict}_loop` properties ([#16384](https://github.com/Lightning-AI/lightning/pull/16384)) * Removed the default `Loop.run()` implementation ([#16384](https://github.com/Lightning-AI/lightning/pull/16384)) + * The loop classes are now marked as protected ([#16445](https://github.com/Lightning-AI/lightning/pull/16445)) - Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172)) * Removed the `LightningModule.truncated_bptt_steps` attribute diff --git a/src/pytorch_lightning/README.md b/src/pytorch_lightning/README.md index f4558594cd..9a76d6a847 100644 --- a/src/pytorch_lightning/README.md +++ b/src/pytorch_lightning/README.md @@ -300,9 +300,9 @@ with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile: -### Pro-level control of training loops (advanced users) +### Pro-level control of optimization (advanced users) -For complex/professional level work, you have optional full control of the training loop and optimizers. +For complex/professional level work, you have optional full control of the optimizers. ```python class LitAutoEncoder(pl.LightningModule): diff --git a/src/pytorch_lightning/loops/__init__.py b/src/pytorch_lightning/loops/__init__.py index 5fde69c150..ef3b19d57d 100644 --- a/src/pytorch_lightning/loops/__init__.py +++ b/src/pytorch_lightning/loops/__init__.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.loops.loop import Loop # noqa: F401 isort: skip (avoids circular imports) -from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401 -from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401 -from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401 -from pytorch_lightning.loops.optimization import ManualOptimization, OptimizerLoop # noqa: F401 +from pytorch_lightning.loops.loop import _Loop # noqa: F401 isort: skip (avoids circular imports) +from pytorch_lightning.loops.dataloader import _DataLoaderLoop, _EvaluationLoop, _PredictionLoop # noqa: F401 +from pytorch_lightning.loops.epoch import _EvaluationEpochLoop, _PredictionEpochLoop, _TrainingEpochLoop # noqa: F401 +from pytorch_lightning.loops.fit_loop import _FitLoop # noqa: F401 +from pytorch_lightning.loops.optimization import _ManualOptimization, _OptimizerLoop # noqa: F401 diff --git a/src/pytorch_lightning/loops/dataloader/__init__.py b/src/pytorch_lightning/loops/dataloader/__init__.py index db2b2f7926..1d0189ef17 100644 --- a/src/pytorch_lightning/loops/dataloader/__init__.py +++ b/src/pytorch_lightning/loops/dataloader/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop # noqa: F401 -from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop # noqa: F401 -from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop # noqa: F401 +from pytorch_lightning.loops.dataloader.dataloader_loop import _DataLoaderLoop # noqa: F401 +from pytorch_lightning.loops.dataloader.evaluation_loop import _EvaluationLoop # noqa: F401 +from pytorch_lightning.loops.dataloader.prediction_loop import _PredictionLoop # noqa: F401 diff --git a/src/pytorch_lightning/loops/dataloader/dataloader_loop.py b/src/pytorch_lightning/loops/dataloader/dataloader_loop.py index a313ff28df..03fc2aecf0 100644 --- a/src/pytorch_lightning/loops/dataloader/dataloader_loop.py +++ b/src/pytorch_lightning/loops/dataloader/dataloader_loop.py @@ -17,11 +17,11 @@ from typing import Sequence from torch.utils.data import DataLoader -from pytorch_lightning.loops.loop import Loop +from pytorch_lightning.loops.loop import _Loop from pytorch_lightning.trainer.progress import DataLoaderProgress -class DataLoaderLoop(Loop): +class _DataLoaderLoop(_Loop): """Base class to loop over all dataloaders.""" def __init__(self) -> None: diff --git a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py index 9c82050c24..ac1a8c3209 100644 --- a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -23,8 +23,8 @@ from torch.utils.data.dataloader import DataLoader import pytorch_lightning as pl from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE -from pytorch_lightning.loops.dataloader import DataLoaderLoop -from pytorch_lightning.loops.epoch import EvaluationEpochLoop +from pytorch_lightning.loops.dataloader import _DataLoaderLoop +from pytorch_lightning.loops.epoch import _EvaluationEpochLoop from pytorch_lightning.loops.utilities import _set_sampler_epoch from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection from pytorch_lightning.trainer.states import TrainerFn @@ -38,7 +38,7 @@ if _RICH_AVAILABLE: from rich.table import Column, Table -class EvaluationLoop(DataLoaderLoop): +class _EvaluationLoop(_DataLoaderLoop): """Top-level loop where validation/testing starts. It simply iterates over each evaluation dataloader from one to the next by calling ``EvaluationEpochLoop.run()`` in @@ -47,7 +47,7 @@ class EvaluationLoop(DataLoaderLoop): def __init__(self, verbose: bool = True) -> None: super().__init__() - self.epoch_loop = EvaluationEpochLoop() + self.epoch_loop = _EvaluationEpochLoop() self.verbose = verbose self._results = _ResultCollection(training=False) @@ -304,7 +304,7 @@ class EvaluationLoop(DataLoaderLoop): def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]: for k, v in data.items(): if isinstance(v, dict): - for new_key in apply_to_collection(v, dict, EvaluationLoop._get_keys): + for new_key in apply_to_collection(v, dict, _EvaluationLoop._get_keys): yield (k, *new_key) # this need to be in parenthesis for older python versions else: yield k, @@ -317,13 +317,13 @@ class EvaluationLoop(DataLoaderLoop): result = data[target_start] if not rest: return result - return EvaluationLoop._find_value(result, rest) + return _EvaluationLoop._find_value(result, rest) @staticmethod def _print_results(results: List[_OUT_DICT], stage: str) -> None: # remove the dl idx suffix results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results] - metrics_paths = {k for keys in apply_to_collection(results, dict, EvaluationLoop._get_keys) for k in keys} + metrics_paths = {k for keys in apply_to_collection(results, dict, _EvaluationLoop._get_keys) for k in keys} if not metrics_paths: return @@ -341,7 +341,7 @@ class EvaluationLoop(DataLoaderLoop): for result in results: for metric, row in zip(metrics_paths, rows): - val = EvaluationLoop._find_value(result, metric) + val = _EvaluationLoop._find_value(result, metric) if val is not None: if isinstance(val, Tensor): val = val.item() if val.numel() == 1 else val.tolist() diff --git a/src/pytorch_lightning/loops/dataloader/prediction_loop.py b/src/pytorch_lightning/loops/dataloader/prediction_loop.py index 86de2b45c7..e779b13f2e 100644 --- a/src/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/src/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -2,26 +2,26 @@ from typing import Any, List, Optional, Sequence, Union from torch.utils.data import DataLoader -from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop -from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop +from pytorch_lightning.loops.dataloader.dataloader_loop import _DataLoaderLoop +from pytorch_lightning.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop from pytorch_lightning.loops.utilities import _set_sampler_epoch from pytorch_lightning.strategies import DDPSpawnStrategy from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import _PREDICT_OUTPUT -class PredictionLoop(DataLoaderLoop): +class _PredictionLoop(_DataLoaderLoop): """Top-level loop where prediction starts. - It simply iterates over each predict dataloader from one to the next by calling ``PredictionEpochLoop.run()`` in its - ``advance()`` method. + It simply iterates over each predict dataloader from one to the next by calling ``_PredictionEpochLoop.run()`` in + its ``advance()`` method. """ def __init__(self) -> None: super().__init__() self.predictions: List[List[Any]] = [] self.epoch_batch_indices: List[List[List[int]]] = [] # used by PredictionWriter - self.epoch_loop = PredictionEpochLoop() + self.epoch_loop = _PredictionEpochLoop() self._results = None # for `trainer._results` access self._return_predictions: bool = False diff --git a/src/pytorch_lightning/loops/epoch/__init__.py b/src/pytorch_lightning/loops/epoch/__init__.py index 789953937a..a150858b0a 100644 --- a/src/pytorch_lightning/loops/epoch/__init__.py +++ b/src/pytorch_lightning/loops/epoch/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.loops.epoch.evaluation_epoch_loop import EvaluationEpochLoop # noqa: F401 -from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop # noqa: F401 -from pytorch_lightning.loops.epoch.training_epoch_loop import TrainingEpochLoop # noqa: F401 +from pytorch_lightning.loops.epoch.evaluation_epoch_loop import _EvaluationEpochLoop # noqa: F401 +from pytorch_lightning.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop # noqa: F401 +from pytorch_lightning.loops.epoch.training_epoch_loop import _TrainingEpochLoop # noqa: F401 diff --git a/src/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/src/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index ce0fcf17a3..3a9027f7b5 100644 --- a/src/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -18,7 +18,7 @@ from typing import Any, Dict, Optional, Union from torch.utils.data import DataLoader -from pytorch_lightning.loops.loop import Loop +from pytorch_lightning.loops.loop import _Loop from pytorch_lightning.trainer.progress import BatchProgress from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.trainer.supporters import CombinedLoader @@ -33,7 +33,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT -class EvaluationEpochLoop(Loop): +class _EvaluationEpochLoop(_Loop): """This is the loop performing the evaluation. It mainly loops over the given dataloader and runs the validation or test step (depending on the trainer's current diff --git a/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 1818c46094..3e2c2f8684 100644 --- a/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Iterator, List, Tuple, Union import torch from lightning_fabric.utilities import move_data_to_device -from pytorch_lightning.loops.loop import Loop +from pytorch_lightning.loops.loop import _Loop from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.utilities.rank_zero import WarningCache @@ -12,7 +12,7 @@ from pytorch_lightning.utilities.rank_zero import WarningCache warning_cache = WarningCache() -class PredictionEpochLoop(Loop): +class _PredictionEpochLoop(_Loop): """Loop performing prediction on arbitrary sequentially used dataloaders.""" def __init__(self) -> None: diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py index 300796c8f3..4c5eb2c868 100644 --- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -21,7 +21,7 @@ from lightning_utilities.core.apply_func import apply_to_collection import pytorch_lightning as pl from pytorch_lightning import loops # import as loops to avoid circular imports -from pytorch_lightning.loops.optimization import ManualOptimization, OptimizerLoop +from pytorch_lightning.loops.optimization import _ManualOptimization, _OptimizerLoop from pytorch_lightning.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE from pytorch_lightning.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached @@ -39,7 +39,7 @@ _BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_ _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE] -class TrainingEpochLoop(loops.Loop): +class _TrainingEpochLoop(loops._Loop): """ Iterates over all batches in the dataloader (one epoch) that the user returns in their :meth:`~pytorch_lightning.core.module.LightningModule.train_dataloader` method. @@ -73,10 +73,10 @@ class TrainingEpochLoop(loops.Loop): self.batch_progress = BatchProgress() self.scheduler_progress = SchedulerProgress() - self.optimizer_loop = OptimizerLoop() - self.manual_loop = ManualOptimization() + self.optimizer_loop = _OptimizerLoop() + self.manual_loop = _ManualOptimization() - self.val_loop = loops.EvaluationLoop(verbose=False) + self.val_loop = loops._EvaluationLoop(verbose=False) self._results = _ResultCollection(training=True) self._outputs: _OUTPUTS_TYPE = [] diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py index eecb3ce244..03ada08f75 100644 --- a/src/pytorch_lightning/loops/fit_loop.py +++ b/src/pytorch_lightning/loops/fit_loop.py @@ -15,8 +15,8 @@ import logging from typing import Any, Optional, Type import pytorch_lightning as pl -from pytorch_lightning.loops import Loop -from pytorch_lightning.loops.epoch import TrainingEpochLoop +from pytorch_lightning.loops import _Loop +from pytorch_lightning.loops.epoch import _TrainingEpochLoop from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection @@ -31,7 +31,7 @@ from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signatu log = logging.getLogger(__name__) -class FitLoop(Loop): +class _FitLoop(_Loop): """This loop is the top-level loop where training starts. It simply counts the epochs and iterates from one to the next by calling ``TrainingEpochLoop.run()`` in its @@ -73,7 +73,7 @@ class FitLoop(Loop): self.max_epochs = max_epochs self.min_epochs = min_epochs - self.epoch_loop = TrainingEpochLoop() + self.epoch_loop = _TrainingEpochLoop() self.epoch_progress = Progress() self._is_fresh_start_epoch: bool = True @@ -116,13 +116,13 @@ class FitLoop(Loop): ) self.epoch_loop.max_steps = value - @Loop.restarting.setter + @_Loop.restarting.setter def restarting(self, restarting: bool) -> None: # if the last epoch completely finished, we are not actually restarting values = self.epoch_progress.current.ready, self.epoch_progress.current.started epoch_unfinished = any(v != self.epoch_progress.current.processed for v in values) restarting = restarting and epoch_unfinished or self._iteration_based_training() - Loop.restarting.fset(self, restarting) # call the parent setter + _Loop.restarting.fset(self, restarting) # call the parent setter @property def prefetch_batches(self) -> int: diff --git a/src/pytorch_lightning/loops/loop.py b/src/pytorch_lightning/loops/loop.py index 6bc2ee4698..e31378a7f5 100644 --- a/src/pytorch_lightning/loops/loop.py +++ b/src/pytorch_lightning/loops/loop.py @@ -21,7 +21,7 @@ from pytorch_lightning.trainer.progress import BaseProgress from pytorch_lightning.utilities.imports import _fault_tolerant_training -class Loop: +class _Loop: """Basic Loops interface.""" def __init__(self) -> None: @@ -39,7 +39,7 @@ class Loop: """Connects this loop's trainer and its children.""" self._trainer = trainer for v in self.__dict__.values(): - if isinstance(v, Loop): + if isinstance(v, _Loop): v.trainer = trainer @property @@ -52,7 +52,7 @@ class Loop: """Connects this loop's restarting value and its children.""" self._restarting = restarting for loop in vars(self).values(): - if isinstance(loop, Loop): + if isinstance(loop, _Loop): loop.restarting = restarting def on_save_checkpoint(self) -> Dict: @@ -85,7 +85,7 @@ class Loop: key = prefix + k if isinstance(v, BaseProgress): destination[key] = v.state_dict() - elif isinstance(v, Loop): + elif isinstance(v, _Loop): v.state_dict(destination, key + ".") elif ft_enabled and isinstance(v, _ResultCollection): # sync / unsync metrics @@ -104,7 +104,7 @@ class Loop: """Loads the state of this loop and all its children.""" self._load_from_state_dict(state_dict.copy(), prefix, metrics) for k, v in self.__dict__.items(): - if isinstance(v, Loop): + if isinstance(v, _Loop): v.load_state_dict(state_dict.copy(), prefix + k + ".") self.restarting = True diff --git a/src/pytorch_lightning/loops/optimization/__init__.py b/src/pytorch_lightning/loops/optimization/__init__.py index 07249b6a13..94c63f7b7f 100644 --- a/src/pytorch_lightning/loops/optimization/__init__.py +++ b/src/pytorch_lightning/loops/optimization/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization # noqa: F401 -from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop # noqa: F401 +from pytorch_lightning.loops.optimization.manual_loop import _ManualOptimization # noqa: F401 +from pytorch_lightning.loops.optimization.optimizer_loop import _OptimizerLoop # noqa: F401 diff --git a/src/pytorch_lightning/loops/optimization/manual_loop.py b/src/pytorch_lightning/loops/optimization/manual_loop.py index 172b19f230..b13275e066 100644 --- a/src/pytorch_lightning/loops/optimization/manual_loop.py +++ b/src/pytorch_lightning/loops/optimization/manual_loop.py @@ -19,7 +19,7 @@ from typing import Any, Dict, Optional from torch import Tensor from pytorch_lightning.core.optimizer import do_nothing_closure -from pytorch_lightning.loops import Loop +from pytorch_lightning.loops import _Loop from pytorch_lightning.loops.optimization.closure import OutputResult from pytorch_lightning.loops.utilities import _build_training_step_kwargs from pytorch_lightning.trainer.progress import Progress, ReadyCompletedTracker @@ -65,7 +65,7 @@ class ManualResult(OutputResult): _OUTPUTS_TYPE = Dict[str, Any] -class ManualOptimization(Loop): +class _ManualOptimization(_Loop): """A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens entirely in the :meth:`~pytorch_lightning.core.module.LightningModule.training_step` and therefore the user is responsible for back-propagating gradients and making calls to the optimizers. diff --git a/src/pytorch_lightning/loops/optimization/optimizer_loop.py b/src/pytorch_lightning/loops/optimization/optimizer_loop.py index cce34374fd..6b6a7b51e5 100644 --- a/src/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/src/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -22,7 +22,7 @@ from typing_extensions import OrderedDict from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.loops import Loop +from pytorch_lightning.loops import _Loop from pytorch_lightning.loops.optimization.closure import AbstractClosure, OutputResult from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior, _build_training_step_kwargs from pytorch_lightning.trainer.progress import OptimizationProgress @@ -146,7 +146,7 @@ class Closure(AbstractClosure[ClosureResult]): _OUTPUTS_TYPE = Dict[int, Dict[str, Any]] -class OptimizerLoop(Loop): +class _OptimizerLoop(_Loop): """Iterates over one or multiple optimizers and for each one it calls the :meth:`~pytorch_lightning.core.module.LightningModule.training_step` method with the batch, the current batch index and the optimizer index if multiple optimizers are requested. diff --git a/src/pytorch_lightning/loops/utilities.py b/src/pytorch_lightning/loops/utilities.py index 342cded638..23591990db 100644 --- a/src/pytorch_lightning/loops/utilities.py +++ b/src/pytorch_lightning/loops/utilities.py @@ -25,7 +25,7 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl from lightning_fabric.utilities.warnings import PossibleUserWarning from pytorch_lightning.callbacks.timer import Timer -from pytorch_lightning.loops import Loop +from pytorch_lightning.loops import _Loop from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.trainer.progress import BaseProgress @@ -182,11 +182,11 @@ def _is_max_limit_reached(current: int, maximum: int = -1) -> bool: return maximum != -1 and current >= maximum -def _reset_progress(loop: Loop) -> None: +def _reset_progress(loop: _Loop) -> None: for v in vars(loop).values(): if isinstance(v, BaseProgress): v.reset() - elif isinstance(v, Loop): + elif isinstance(v, _Loop): _reset_progress(v) diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 5044bd1196..d71e38debd 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -54,9 +54,9 @@ from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.loggers import Logger from pytorch_lightning.loggers.tensorboard import TensorBoardLogger -from pytorch_lightning.loops import PredictionLoop, TrainingEpochLoop -from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop -from pytorch_lightning.loops.fit_loop import FitLoop +from pytorch_lightning.loops import _PredictionLoop, _TrainingEpochLoop +from pytorch_lightning.loops.dataloader.evaluation_loop import _EvaluationLoop +from pytorch_lightning.loops.fit_loop import _FitLoop from pytorch_lightning.loops.utilities import _parse_loop_limits, _reset_progress from pytorch_lightning.plugins import PLUGIN_INPUT, PrecisionPlugin from pytorch_lightning.profilers import Profiler @@ -373,11 +373,11 @@ class Trainer: self.tuner = Tuner(self) # init loops - self.fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs) - self.fit_loop.epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps) - self.validate_loop = EvaluationLoop() - self.test_loop = EvaluationLoop() - self.predict_loop = PredictionLoop() + self.fit_loop = _FitLoop(min_epochs=min_epochs, max_epochs=max_epochs) + self.fit_loop.epoch_loop = _TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps) + self.validate_loop = _EvaluationLoop() + self.test_loop = _EvaluationLoop() + self.predict_loop = _PredictionLoop() self.fit_loop.trainer = self self.validate_loop.trainer = self self.test_loop.trainer = self @@ -1939,7 +1939,7 @@ class Trainer: return self.fit_loop.epoch_loop.batch_progress.is_last_batch @property - def _evaluation_loop(self) -> EvaluationLoop: + def _evaluation_loop(self) -> _EvaluationLoop: if self.state.fn == TrainerFn.FITTING: return self.fit_loop.epoch_loop.val_loop if self.state.fn == TrainerFn.VALIDATING: @@ -1949,7 +1949,7 @@ class Trainer: raise RuntimeError("The `Trainer._evaluation_loop` property isn't defined. Accessed outside of scope") @property - def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop]]: + def _active_loop(self) -> Optional[Union[_FitLoop, _EvaluationLoop, _PredictionLoop]]: if self.training: return self.fit_loop if self.sanity_checking or self.evaluating: diff --git a/tests/tests_app/core/scripts/lightning_overrides.py b/tests/tests_app/core/scripts/lightning_overrides.py index 162cb7ee5a..287f257619 100644 --- a/tests/tests_app/core/scripts/lightning_overrides.py +++ b/tests/tests_app/core/scripts/lightning_overrides.py @@ -11,7 +11,6 @@ if _is_pytorch_lightning_available(): from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.callbacks import Callback from pytorch_lightning.loggers import Logger - from pytorch_lightning.loops import Loop from pytorch_lightning.plugins import PrecisionPlugin from pytorch_lightning.profilers import Profiler @@ -42,9 +41,6 @@ if __name__ == "__main__": class BoringLogger(Logger): pass - class BoringLoop(Loop): - pass - class BoringMetric(Metric): pass diff --git a/tests/tests_app/utilities/test_introspection.py b/tests/tests_app/utilities/test_introspection.py index f08d157559..e701ffcf51 100644 --- a/tests/tests_app/utilities/test_introspection.py +++ b/tests/tests_app/utilities/test_introspection.py @@ -55,7 +55,6 @@ def test_introspection_lightning_overrides(): "Fabric", "Logger", "LightningModule", - "Loop", "Metric", "PrecisionPlugin", "Trainer", diff --git a/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py b/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py index b6e9757c83..6fefe498e7 100644 --- a/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py +++ b/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py @@ -18,7 +18,7 @@ import pytest from pytorch_lightning import LightningModule from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.loops import TrainingEpochLoop +from pytorch_lightning.loops import _TrainingEpochLoop from pytorch_lightning.trainer.trainer import Trainer _out00 = {"loss": 0.0} @@ -43,7 +43,7 @@ class TestPrepareOutputs: def prepare_outputs_training_epoch_end(self, batch_outputs, num_optimizers, automatic_optimization=True): return self.prepare_outputs( - TrainingEpochLoop._prepare_outputs_training_epoch_end, + _TrainingEpochLoop._prepare_outputs_training_epoch_end, batch_outputs, num_optimizers, automatic_optimization=automatic_optimization, @@ -51,7 +51,7 @@ class TestPrepareOutputs: def prepare_outputs_training_batch_end(self, batch_outputs, num_optimizers, automatic_optimization=True): return self.prepare_outputs( - TrainingEpochLoop._prepare_outputs_training_batch_end, + _TrainingEpochLoop._prepare_outputs_training_batch_end, batch_outputs, num_optimizers, automatic_optimization=automatic_optimization, diff --git a/tests/tests_pytorch/loops/test_evaluation_loop.py b/tests/tests_pytorch/loops/test_evaluation_loop.py index 143e014009..8db9f2173f 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop.py @@ -24,7 +24,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden from tests_pytorch.helpers.runif import RunIf -@mock.patch("pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop._on_evaluation_epoch_end") +@mock.patch("pytorch_lightning.loops.dataloader.evaluation_loop._EvaluationLoop._on_evaluation_epoch_end") def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): """Tests that `on_evaluation_epoch_end` is called for `on_validation_epoch_end` and `on_test_epoch_end` hooks.""" diff --git a/tests/tests_pytorch/loops/test_loop_state_dict.py b/tests/tests_pytorch/loops/test_loop_state_dict.py index 72d846e535..afea14bf8b 100644 --- a/tests/tests_pytorch/loops/test_loop_state_dict.py +++ b/tests/tests_pytorch/loops/test_loop_state_dict.py @@ -15,19 +15,19 @@ import os from unittest import mock from unittest.mock import Mock -from pytorch_lightning.loops import FitLoop +from pytorch_lightning.loops import _FitLoop from pytorch_lightning.trainer.trainer import Trainer def test_loops_state_dict(): trainer = Trainer() - fit_loop = FitLoop() + fit_loop = _FitLoop() fit_loop.trainer = trainer state_dict = fit_loop.state_dict() - new_fit_loop = FitLoop() + new_fit_loop = _FitLoop() new_fit_loop.trainer = trainer new_fit_loop.load_state_dict(state_dict) diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 30abb7bff7..1e68b879c8 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -25,13 +25,13 @@ from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoad from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset -from pytorch_lightning.loops import Loop +from pytorch_lightning.loops import _Loop from pytorch_lightning.trainer.progress import BaseProgress from tests_pytorch.helpers.runif import RunIf def test_restarting_loops_recursive(): - class MyLoop(Loop): + class MyLoop(_Loop): def __init__(self, loop=None): super().__init__() self.child = loop @@ -52,7 +52,7 @@ class CustomException(Exception): def test_loop_restore(): - class Simple(Loop): + class Simple(_Loop): def __init__(self, dataset: Iterator): super().__init__() self.iteration_count = 0 @@ -119,7 +119,7 @@ def test_loop_hierarchy(): class SimpleProgress(BaseProgress): increment: int = 0 - class Simple(Loop): + class Simple(_Loop): def __init__(self, a): super().__init__() self.a = a diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py index ffa3e40995..e989d6bec3 100644 --- a/tests/tests_pytorch/loops/test_training_loop.py +++ b/tests/tests_pytorch/loops/test_training_loop.py @@ -19,7 +19,7 @@ import torch from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.loops import FitLoop +from pytorch_lightning.loops import _FitLoop def test_outputs_format(tmpdir): @@ -141,7 +141,7 @@ def test_should_stop_mid_epoch(tmpdir): def test_fit_loop_done_log_messages(caplog): - fit_loop = FitLoop(max_epochs=1) + fit_loop = _FitLoop(max_epochs=1) trainer = Mock(spec=Trainer) fit_loop.trainer = trainer diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py index 28ed3ba80f..32b0358c79 100644 --- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py @@ -29,7 +29,7 @@ from pytorch_lightning import callbacks, Trainer from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.loops.dataloader import EvaluationLoop +from pytorch_lightning.loops.dataloader import _EvaluationLoop from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0 @@ -845,7 +845,7 @@ def test_native_print_results(monkeypatch, inputs, expected): monkeypatch.setattr(imports, "_RICH_AVAILABLE", False) with redirect_stdout(StringIO()) as out: - EvaluationLoop._print_results(*inputs) + _EvaluationLoop._print_results(*inputs) expected = expected[1:] # remove the initial line break from the """ string assert out.getvalue().replace(os.linesep, "\n") == expected.lstrip() @@ -859,7 +859,7 @@ def test_native_print_results_encodings(monkeypatch, encoding): out = mock.Mock() out.encoding = encoding with redirect_stdout(out) as out: - EvaluationLoop._print_results(*inputs0) + _EvaluationLoop._print_results(*inputs0) # Attempt to encode everything the file is told to write with the given encoding for call_ in out.method_calls: @@ -937,7 +937,7 @@ expected3 = """ def test_rich_print_results(inputs, expected): console = get_console() with console.capture() as capture: - EvaluationLoop._print_results(*inputs) + _EvaluationLoop._print_results(*inputs) expected = expected[1:] # remove the initial line break from the """ string assert capture.get() == expected.lstrip() diff --git a/tests/tests_pytorch/utilities/test_fetching.py b/tests/tests_pytorch/utilities/test_fetching.py index f04e3e17dd..1c68fae71f 100644 --- a/tests/tests_pytorch/utilities/test_fetching.py +++ b/tests/tests_pytorch/utilities/test_fetching.py @@ -478,25 +478,25 @@ def test_fetching_is_profiled(): # validation for i in range(2): - key = f"[EvaluationEpochLoop].val_dataloader_idx_{i}_next" + key = f"[_EvaluationEpochLoop].val_dataloader_idx_{i}_next" assert key in profiler.recorded_durations durations = profiler.recorded_durations[key] assert len(durations) == fast_dev_run assert all(d > 0 for d in durations) # training - key = "[TrainingEpochLoop].train_dataloader_next" + key = "[_TrainingEpochLoop].train_dataloader_next" assert key in profiler.recorded_durations durations = profiler.recorded_durations[key] assert len(durations) == fast_dev_run assert all(d > 0 for d in durations) # test - key = "[EvaluationEpochLoop].val_dataloader_idx_0_next" + key = "[_EvaluationEpochLoop].val_dataloader_idx_0_next" assert key in profiler.recorded_durations durations = profiler.recorded_durations[key] assert len(durations) == fast_dev_run assert all(d > 0 for d in durations) # predict - key = "[PredictionEpochLoop].predict_dataloader_idx_0_next" + key = "[_PredictionEpochLoop].predict_dataloader_idx_0_next" assert key in profiler.recorded_durations durations = profiler.recorded_durations[key] assert len(durations) == fast_dev_run @@ -524,7 +524,7 @@ def test_fetching_is_profiled(): profiler = trainer.profiler assert isinstance(profiler, SimpleProfiler) - key = "[TrainingEpochLoop].train_dataloader_next" + key = "[_TrainingEpochLoop].train_dataloader_next" assert key in profiler.recorded_durations durations = profiler.recorded_durations[key] assert len(durations) == 2 # 2 polls in training_step