Mark the loop classes as protected (#16445)

This commit is contained in:
Carlos Mocholí 2023-01-23 16:30:13 +00:00 committed by GitHub
parent 48e1c9c99c
commit 5891cdc940
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 94 additions and 172 deletions

View File

@ -313,9 +313,9 @@ with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile:
</details>
### Pro-level control of training loops (advanced users)
### Pro-level control of optimization (advanced users)
For complex/professional level work, you have optional full control of the training loop and optimizers.
For complex/professional level work, you have optional full control of the optimizers.
```python
class LitAutoEncoder(pl.LightningModule):

View File

@ -100,66 +100,6 @@ loggers
tensorboard
wandb
loops
^^^^^
Base Classes
""""""""""""
.. currentmodule:: pytorch_lightning.loops
.. autosummary::
:toctree: api
:nosignatures:
:template: classtemplate.rst
~dataloader.dataloader_loop.DataLoaderLoop
~loop.Loop
Training
""""""""
.. currentmodule:: pytorch_lightning.loops
.. autosummary::
:toctree: api
:nosignatures:
:template: classtemplate.rst
~epoch.TrainingEpochLoop
FitLoop
~optimization.ManualOptimization
~optimization.OptimizerLoop
Validation and Testing
""""""""""""""""""""""
.. currentmodule:: pytorch_lightning.loops
.. autosummary::
:toctree: api
:nosignatures:
:template: classtemplate.rst
~epoch.EvaluationEpochLoop
~dataloader.EvaluationLoop
Prediction
""""""""""
.. currentmodule:: pytorch_lightning.loops
.. autosummary::
:toctree: api
:nosignatures:
:template: classtemplate.rst
~epoch.PredictionEpochLoop
~dataloader.PredictionLoop
plugins
^^^^^^^

View File

@ -306,7 +306,7 @@ If you have multiple lines of code with similar functionalities, you can use cal
Use a raw PyTorch loop
======================
For certain types of work at the bleeding-edge of research, Lightning offers experts full control of their training loops in various ways.
For certain types of work at the bleeding-edge of research, Lightning offers experts full control of optimization or the training loop in various ways.
.. raw:: html
@ -333,15 +333,6 @@ For certain types of work at the bleeding-edge of research, Lightning offers exp
:image_height: 220px
:height: 320
.. displayitem::
:header: Loops
:description: Enable meta-learning, reinforcement learning, GANs with full control.
:col_css: col-md-4
:image_center: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/loops.png
:button_link: ../extensions/loops.html
:image_height: 220px
:height: 320
.. raw:: html
</div>

View File

@ -247,10 +247,6 @@ class LightningAcceleratorVisitor(LightningVisitor):
class_name = "Accelerator"
class LightningLoopVisitor(LightningVisitor):
class_name = "Loop"
class TorchMetricVisitor(LightningVisitor):
class_name = "Metric"
@ -290,7 +286,6 @@ class Scanner:
LightningPrecisionPluginVisitor,
LightningAcceleratorVisitor,
LightningLoggerVisitor,
LightningLoopVisitor,
TorchMetricVisitor,
FabricVisitor,
LightningProfilerVisitor,

View File

@ -113,6 +113,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Removed `Loop.connect()` ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
* Removed the `trainer.{fit,validate,test,predict}_loop` properties ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
* Removed the default `Loop.run()` implementation ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
* The loop classes are now marked as protected ([#16445](https://github.com/Lightning-AI/lightning/pull/16445))
- Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172))
* Removed the `LightningModule.truncated_bptt_steps` attribute

View File

@ -300,9 +300,9 @@ with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile:
</details>
### Pro-level control of training loops (advanced users)
### Pro-level control of optimization (advanced users)
For complex/professional level work, you have optional full control of the training loop and optimizers.
For complex/professional level work, you have optional full control of the optimizers.
```python
class LitAutoEncoder(pl.LightningModule):

View File

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.loop import Loop # noqa: F401 isort: skip (avoids circular imports)
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
from pytorch_lightning.loops.optimization import ManualOptimization, OptimizerLoop # noqa: F401
from pytorch_lightning.loops.loop import _Loop # noqa: F401 isort: skip (avoids circular imports)
from pytorch_lightning.loops.dataloader import _DataLoaderLoop, _EvaluationLoop, _PredictionLoop # noqa: F401
from pytorch_lightning.loops.epoch import _EvaluationEpochLoop, _PredictionEpochLoop, _TrainingEpochLoop # noqa: F401
from pytorch_lightning.loops.fit_loop import _FitLoop # noqa: F401
from pytorch_lightning.loops.optimization import _ManualOptimization, _OptimizerLoop # noqa: F401

View File

@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop # noqa: F401
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop # noqa: F401
from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop # noqa: F401
from pytorch_lightning.loops.dataloader.dataloader_loop import _DataLoaderLoop # noqa: F401
from pytorch_lightning.loops.dataloader.evaluation_loop import _EvaluationLoop # noqa: F401
from pytorch_lightning.loops.dataloader.prediction_loop import _PredictionLoop # noqa: F401

View File

@ -17,11 +17,11 @@ from typing import Sequence
from torch.utils.data import DataLoader
from pytorch_lightning.loops.loop import Loop
from pytorch_lightning.loops.loop import _Loop
from pytorch_lightning.trainer.progress import DataLoaderProgress
class DataLoaderLoop(Loop):
class _DataLoaderLoop(_Loop):
"""Base class to loop over all dataloaders."""
def __init__(self) -> None:

View File

@ -23,8 +23,8 @@ from torch.utils.data.dataloader import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.loops.dataloader import _DataLoaderLoop
from pytorch_lightning.loops.epoch import _EvaluationEpochLoop
from pytorch_lightning.loops.utilities import _set_sampler_epoch
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from pytorch_lightning.trainer.states import TrainerFn
@ -38,7 +38,7 @@ if _RICH_AVAILABLE:
from rich.table import Column, Table
class EvaluationLoop(DataLoaderLoop):
class _EvaluationLoop(_DataLoaderLoop):
"""Top-level loop where validation/testing starts.
It simply iterates over each evaluation dataloader from one to the next by calling ``EvaluationEpochLoop.run()`` in
@ -47,7 +47,7 @@ class EvaluationLoop(DataLoaderLoop):
def __init__(self, verbose: bool = True) -> None:
super().__init__()
self.epoch_loop = EvaluationEpochLoop()
self.epoch_loop = _EvaluationEpochLoop()
self.verbose = verbose
self._results = _ResultCollection(training=False)
@ -304,7 +304,7 @@ class EvaluationLoop(DataLoaderLoop):
def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]:
for k, v in data.items():
if isinstance(v, dict):
for new_key in apply_to_collection(v, dict, EvaluationLoop._get_keys):
for new_key in apply_to_collection(v, dict, _EvaluationLoop._get_keys):
yield (k, *new_key) # this need to be in parenthesis for older python versions
else:
yield k,
@ -317,13 +317,13 @@ class EvaluationLoop(DataLoaderLoop):
result = data[target_start]
if not rest:
return result
return EvaluationLoop._find_value(result, rest)
return _EvaluationLoop._find_value(result, rest)
@staticmethod
def _print_results(results: List[_OUT_DICT], stage: str) -> None:
# remove the dl idx suffix
results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results]
metrics_paths = {k for keys in apply_to_collection(results, dict, EvaluationLoop._get_keys) for k in keys}
metrics_paths = {k for keys in apply_to_collection(results, dict, _EvaluationLoop._get_keys) for k in keys}
if not metrics_paths:
return
@ -341,7 +341,7 @@ class EvaluationLoop(DataLoaderLoop):
for result in results:
for metric, row in zip(metrics_paths, rows):
val = EvaluationLoop._find_value(result, metric)
val = _EvaluationLoop._find_value(result, metric)
if val is not None:
if isinstance(val, Tensor):
val = val.item() if val.numel() == 1 else val.tolist()

View File

@ -2,26 +2,26 @@ from typing import Any, List, Optional, Sequence, Union
from torch.utils.data import DataLoader
from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop
from pytorch_lightning.loops.dataloader.dataloader_loop import _DataLoaderLoop
from pytorch_lightning.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop
from pytorch_lightning.loops.utilities import _set_sampler_epoch
from pytorch_lightning.strategies import DDPSpawnStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
class PredictionLoop(DataLoaderLoop):
class _PredictionLoop(_DataLoaderLoop):
"""Top-level loop where prediction starts.
It simply iterates over each predict dataloader from one to the next by calling ``PredictionEpochLoop.run()`` in its
``advance()`` method.
It simply iterates over each predict dataloader from one to the next by calling ``_PredictionEpochLoop.run()`` in
its ``advance()`` method.
"""
def __init__(self) -> None:
super().__init__()
self.predictions: List[List[Any]] = []
self.epoch_batch_indices: List[List[List[int]]] = [] # used by PredictionWriter
self.epoch_loop = PredictionEpochLoop()
self.epoch_loop = _PredictionEpochLoop()
self._results = None # for `trainer._results` access
self._return_predictions: bool = False

View File

@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.epoch.evaluation_epoch_loop import EvaluationEpochLoop # noqa: F401
from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop # noqa: F401
from pytorch_lightning.loops.epoch.training_epoch_loop import TrainingEpochLoop # noqa: F401
from pytorch_lightning.loops.epoch.evaluation_epoch_loop import _EvaluationEpochLoop # noqa: F401
from pytorch_lightning.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop # noqa: F401
from pytorch_lightning.loops.epoch.training_epoch_loop import _TrainingEpochLoop # noqa: F401

View File

@ -18,7 +18,7 @@ from typing import Any, Dict, Optional, Union
from torch.utils.data import DataLoader
from pytorch_lightning.loops.loop import Loop
from pytorch_lightning.loops.loop import _Loop
from pytorch_lightning.trainer.progress import BatchProgress
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.trainer.supporters import CombinedLoader
@ -33,7 +33,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
class EvaluationEpochLoop(Loop):
class _EvaluationEpochLoop(_Loop):
"""This is the loop performing the evaluation.
It mainly loops over the given dataloader and runs the validation or test step (depending on the trainer's current

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, Iterator, List, Tuple, Union
import torch
from lightning_fabric.utilities import move_data_to_device
from pytorch_lightning.loops.loop import Loop
from pytorch_lightning.loops.loop import _Loop
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.utilities.rank_zero import WarningCache
@ -12,7 +12,7 @@ from pytorch_lightning.utilities.rank_zero import WarningCache
warning_cache = WarningCache()
class PredictionEpochLoop(Loop):
class _PredictionEpochLoop(_Loop):
"""Loop performing prediction on arbitrary sequentially used dataloaders."""
def __init__(self) -> None:

View File

@ -21,7 +21,7 @@ from lightning_utilities.core.apply_func import apply_to_collection
import pytorch_lightning as pl
from pytorch_lightning import loops # import as loops to avoid circular imports
from pytorch_lightning.loops.optimization import ManualOptimization, OptimizerLoop
from pytorch_lightning.loops.optimization import _ManualOptimization, _OptimizerLoop
from pytorch_lightning.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
from pytorch_lightning.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached
@ -39,7 +39,7 @@ _BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_
_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
class TrainingEpochLoop(loops.Loop):
class _TrainingEpochLoop(loops._Loop):
"""
Iterates over all batches in the dataloader (one epoch) that the user returns in their
:meth:`~pytorch_lightning.core.module.LightningModule.train_dataloader` method.
@ -73,10 +73,10 @@ class TrainingEpochLoop(loops.Loop):
self.batch_progress = BatchProgress()
self.scheduler_progress = SchedulerProgress()
self.optimizer_loop = OptimizerLoop()
self.manual_loop = ManualOptimization()
self.optimizer_loop = _OptimizerLoop()
self.manual_loop = _ManualOptimization()
self.val_loop = loops.EvaluationLoop(verbose=False)
self.val_loop = loops._EvaluationLoop(verbose=False)
self._results = _ResultCollection(training=True)
self._outputs: _OUTPUTS_TYPE = []

View File

@ -15,8 +15,8 @@ import logging
from typing import Any, Optional, Type
import pytorch_lightning as pl
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
from pytorch_lightning.loops import _Loop
from pytorch_lightning.loops.epoch import _TrainingEpochLoop
from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE
from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
@ -31,7 +31,7 @@ from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signatu
log = logging.getLogger(__name__)
class FitLoop(Loop):
class _FitLoop(_Loop):
"""This loop is the top-level loop where training starts.
It simply counts the epochs and iterates from one to the next by calling ``TrainingEpochLoop.run()`` in its
@ -73,7 +73,7 @@ class FitLoop(Loop):
self.max_epochs = max_epochs
self.min_epochs = min_epochs
self.epoch_loop = TrainingEpochLoop()
self.epoch_loop = _TrainingEpochLoop()
self.epoch_progress = Progress()
self._is_fresh_start_epoch: bool = True
@ -116,13 +116,13 @@ class FitLoop(Loop):
)
self.epoch_loop.max_steps = value
@Loop.restarting.setter
@_Loop.restarting.setter
def restarting(self, restarting: bool) -> None:
# if the last epoch completely finished, we are not actually restarting
values = self.epoch_progress.current.ready, self.epoch_progress.current.started
epoch_unfinished = any(v != self.epoch_progress.current.processed for v in values)
restarting = restarting and epoch_unfinished or self._iteration_based_training()
Loop.restarting.fset(self, restarting) # call the parent setter
_Loop.restarting.fset(self, restarting) # call the parent setter
@property
def prefetch_batches(self) -> int:

View File

@ -21,7 +21,7 @@ from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities.imports import _fault_tolerant_training
class Loop:
class _Loop:
"""Basic Loops interface."""
def __init__(self) -> None:
@ -39,7 +39,7 @@ class Loop:
"""Connects this loop's trainer and its children."""
self._trainer = trainer
for v in self.__dict__.values():
if isinstance(v, Loop):
if isinstance(v, _Loop):
v.trainer = trainer
@property
@ -52,7 +52,7 @@ class Loop:
"""Connects this loop's restarting value and its children."""
self._restarting = restarting
for loop in vars(self).values():
if isinstance(loop, Loop):
if isinstance(loop, _Loop):
loop.restarting = restarting
def on_save_checkpoint(self) -> Dict:
@ -85,7 +85,7 @@ class Loop:
key = prefix + k
if isinstance(v, BaseProgress):
destination[key] = v.state_dict()
elif isinstance(v, Loop):
elif isinstance(v, _Loop):
v.state_dict(destination, key + ".")
elif ft_enabled and isinstance(v, _ResultCollection):
# sync / unsync metrics
@ -104,7 +104,7 @@ class Loop:
"""Loads the state of this loop and all its children."""
self._load_from_state_dict(state_dict.copy(), prefix, metrics)
for k, v in self.__dict__.items():
if isinstance(v, Loop):
if isinstance(v, _Loop):
v.load_state_dict(state_dict.copy(), prefix + k + ".")
self.restarting = True

View File

@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization # noqa: F401
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop # noqa: F401
from pytorch_lightning.loops.optimization.manual_loop import _ManualOptimization # noqa: F401
from pytorch_lightning.loops.optimization.optimizer_loop import _OptimizerLoop # noqa: F401

View File

@ -19,7 +19,7 @@ from typing import Any, Dict, Optional
from torch import Tensor
from pytorch_lightning.core.optimizer import do_nothing_closure
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops import _Loop
from pytorch_lightning.loops.optimization.closure import OutputResult
from pytorch_lightning.loops.utilities import _build_training_step_kwargs
from pytorch_lightning.trainer.progress import Progress, ReadyCompletedTracker
@ -65,7 +65,7 @@ class ManualResult(OutputResult):
_OUTPUTS_TYPE = Dict[str, Any]
class ManualOptimization(Loop):
class _ManualOptimization(_Loop):
"""A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens
entirely in the :meth:`~pytorch_lightning.core.module.LightningModule.training_step` and therefore the user is
responsible for back-propagating gradients and making calls to the optimizers.

View File

@ -22,7 +22,7 @@ from typing_extensions import OrderedDict
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops import _Loop
from pytorch_lightning.loops.optimization.closure import AbstractClosure, OutputResult
from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior, _build_training_step_kwargs
from pytorch_lightning.trainer.progress import OptimizationProgress
@ -146,7 +146,7 @@ class Closure(AbstractClosure[ClosureResult]):
_OUTPUTS_TYPE = Dict[int, Dict[str, Any]]
class OptimizerLoop(Loop):
class _OptimizerLoop(_Loop):
"""Iterates over one or multiple optimizers and for each one it calls the
:meth:`~pytorch_lightning.core.module.LightningModule.training_step` method with the batch, the current batch index
and the optimizer index if multiple optimizers are requested.

View File

@ -25,7 +25,7 @@ from torch.utils.data import DataLoader
import pytorch_lightning as pl
from lightning_fabric.utilities.warnings import PossibleUserWarning
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops import _Loop
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.trainer.progress import BaseProgress
@ -182,11 +182,11 @@ def _is_max_limit_reached(current: int, maximum: int = -1) -> bool:
return maximum != -1 and current >= maximum
def _reset_progress(loop: Loop) -> None:
def _reset_progress(loop: _Loop) -> None:
for v in vars(loop).values():
if isinstance(v, BaseProgress):
v.reset()
elif isinstance(v, Loop):
elif isinstance(v, _Loop):
_reset_progress(v)

View File

@ -54,9 +54,9 @@ from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.loggers import Logger
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loops import PredictionLoop, TrainingEpochLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.loops import _PredictionLoop, _TrainingEpochLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import _EvaluationLoop
from pytorch_lightning.loops.fit_loop import _FitLoop
from pytorch_lightning.loops.utilities import _parse_loop_limits, _reset_progress
from pytorch_lightning.plugins import PLUGIN_INPUT, PrecisionPlugin
from pytorch_lightning.profilers import Profiler
@ -373,11 +373,11 @@ class Trainer:
self.tuner = Tuner(self)
# init loops
self.fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
self.fit_loop.epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
self.validate_loop = EvaluationLoop()
self.test_loop = EvaluationLoop()
self.predict_loop = PredictionLoop()
self.fit_loop = _FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
self.fit_loop.epoch_loop = _TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
self.validate_loop = _EvaluationLoop()
self.test_loop = _EvaluationLoop()
self.predict_loop = _PredictionLoop()
self.fit_loop.trainer = self
self.validate_loop.trainer = self
self.test_loop.trainer = self
@ -1939,7 +1939,7 @@ class Trainer:
return self.fit_loop.epoch_loop.batch_progress.is_last_batch
@property
def _evaluation_loop(self) -> EvaluationLoop:
def _evaluation_loop(self) -> _EvaluationLoop:
if self.state.fn == TrainerFn.FITTING:
return self.fit_loop.epoch_loop.val_loop
if self.state.fn == TrainerFn.VALIDATING:
@ -1949,7 +1949,7 @@ class Trainer:
raise RuntimeError("The `Trainer._evaluation_loop` property isn't defined. Accessed outside of scope")
@property
def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop]]:
def _active_loop(self) -> Optional[Union[_FitLoop, _EvaluationLoop, _PredictionLoop]]:
if self.training:
return self.fit_loop
if self.sanity_checking or self.evaluating:

View File

@ -11,7 +11,6 @@ if _is_pytorch_lightning_available():
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import Logger
from pytorch_lightning.loops import Loop
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.profilers import Profiler
@ -42,9 +41,6 @@ if __name__ == "__main__":
class BoringLogger(Logger):
pass
class BoringLoop(Loop):
pass
class BoringMetric(Metric):
pass

View File

@ -55,7 +55,6 @@ def test_introspection_lightning_overrides():
"Fabric",
"Logger",
"LightningModule",
"Loop",
"Metric",
"PrecisionPlugin",
"Trainer",

View File

@ -18,7 +18,7 @@ import pytest
from pytorch_lightning import LightningModule
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.loops import TrainingEpochLoop
from pytorch_lightning.loops import _TrainingEpochLoop
from pytorch_lightning.trainer.trainer import Trainer
_out00 = {"loss": 0.0}
@ -43,7 +43,7 @@ class TestPrepareOutputs:
def prepare_outputs_training_epoch_end(self, batch_outputs, num_optimizers, automatic_optimization=True):
return self.prepare_outputs(
TrainingEpochLoop._prepare_outputs_training_epoch_end,
_TrainingEpochLoop._prepare_outputs_training_epoch_end,
batch_outputs,
num_optimizers,
automatic_optimization=automatic_optimization,
@ -51,7 +51,7 @@ class TestPrepareOutputs:
def prepare_outputs_training_batch_end(self, batch_outputs, num_optimizers, automatic_optimization=True):
return self.prepare_outputs(
TrainingEpochLoop._prepare_outputs_training_batch_end,
_TrainingEpochLoop._prepare_outputs_training_batch_end,
batch_outputs,
num_optimizers,
automatic_optimization=automatic_optimization,

View File

@ -24,7 +24,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden
from tests_pytorch.helpers.runif import RunIf
@mock.patch("pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop._on_evaluation_epoch_end")
@mock.patch("pytorch_lightning.loops.dataloader.evaluation_loop._EvaluationLoop._on_evaluation_epoch_end")
def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir):
"""Tests that `on_evaluation_epoch_end` is called for `on_validation_epoch_end` and `on_test_epoch_end`
hooks."""

View File

@ -15,19 +15,19 @@ import os
from unittest import mock
from unittest.mock import Mock
from pytorch_lightning.loops import FitLoop
from pytorch_lightning.loops import _FitLoop
from pytorch_lightning.trainer.trainer import Trainer
def test_loops_state_dict():
trainer = Trainer()
fit_loop = FitLoop()
fit_loop = _FitLoop()
fit_loop.trainer = trainer
state_dict = fit_loop.state_dict()
new_fit_loop = FitLoop()
new_fit_loop = _FitLoop()
new_fit_loop.trainer = trainer
new_fit_loop.load_state_dict(state_dict)

View File

@ -25,13 +25,13 @@ from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoad
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops import _Loop
from pytorch_lightning.trainer.progress import BaseProgress
from tests_pytorch.helpers.runif import RunIf
def test_restarting_loops_recursive():
class MyLoop(Loop):
class MyLoop(_Loop):
def __init__(self, loop=None):
super().__init__()
self.child = loop
@ -52,7 +52,7 @@ class CustomException(Exception):
def test_loop_restore():
class Simple(Loop):
class Simple(_Loop):
def __init__(self, dataset: Iterator):
super().__init__()
self.iteration_count = 0
@ -119,7 +119,7 @@ def test_loop_hierarchy():
class SimpleProgress(BaseProgress):
increment: int = 0
class Simple(Loop):
class Simple(_Loop):
def __init__(self, a):
super().__init__()
self.a = a

View File

@ -19,7 +19,7 @@ import torch
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.loops import FitLoop
from pytorch_lightning.loops import _FitLoop
def test_outputs_format(tmpdir):
@ -141,7 +141,7 @@ def test_should_stop_mid_epoch(tmpdir):
def test_fit_loop_done_log_messages(caplog):
fit_loop = FitLoop(max_epochs=1)
fit_loop = _FitLoop(max_epochs=1)
trainer = Mock(spec=Trainer)
fit_loop.trainer = trainer

View File

@ -29,7 +29,7 @@ from pytorch_lightning import callbacks, Trainer
from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loops.dataloader import EvaluationLoop
from pytorch_lightning.loops.dataloader import _EvaluationLoop
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
@ -845,7 +845,7 @@ def test_native_print_results(monkeypatch, inputs, expected):
monkeypatch.setattr(imports, "_RICH_AVAILABLE", False)
with redirect_stdout(StringIO()) as out:
EvaluationLoop._print_results(*inputs)
_EvaluationLoop._print_results(*inputs)
expected = expected[1:] # remove the initial line break from the """ string
assert out.getvalue().replace(os.linesep, "\n") == expected.lstrip()
@ -859,7 +859,7 @@ def test_native_print_results_encodings(monkeypatch, encoding):
out = mock.Mock()
out.encoding = encoding
with redirect_stdout(out) as out:
EvaluationLoop._print_results(*inputs0)
_EvaluationLoop._print_results(*inputs0)
# Attempt to encode everything the file is told to write with the given encoding
for call_ in out.method_calls:
@ -937,7 +937,7 @@ expected3 = """
def test_rich_print_results(inputs, expected):
console = get_console()
with console.capture() as capture:
EvaluationLoop._print_results(*inputs)
_EvaluationLoop._print_results(*inputs)
expected = expected[1:] # remove the initial line break from the """ string
assert capture.get() == expected.lstrip()

View File

@ -478,25 +478,25 @@ def test_fetching_is_profiled():
# validation
for i in range(2):
key = f"[EvaluationEpochLoop].val_dataloader_idx_{i}_next"
key = f"[_EvaluationEpochLoop].val_dataloader_idx_{i}_next"
assert key in profiler.recorded_durations
durations = profiler.recorded_durations[key]
assert len(durations) == fast_dev_run
assert all(d > 0 for d in durations)
# training
key = "[TrainingEpochLoop].train_dataloader_next"
key = "[_TrainingEpochLoop].train_dataloader_next"
assert key in profiler.recorded_durations
durations = profiler.recorded_durations[key]
assert len(durations) == fast_dev_run
assert all(d > 0 for d in durations)
# test
key = "[EvaluationEpochLoop].val_dataloader_idx_0_next"
key = "[_EvaluationEpochLoop].val_dataloader_idx_0_next"
assert key in profiler.recorded_durations
durations = profiler.recorded_durations[key]
assert len(durations) == fast_dev_run
assert all(d > 0 for d in durations)
# predict
key = "[PredictionEpochLoop].predict_dataloader_idx_0_next"
key = "[_PredictionEpochLoop].predict_dataloader_idx_0_next"
assert key in profiler.recorded_durations
durations = profiler.recorded_durations[key]
assert len(durations) == fast_dev_run
@ -524,7 +524,7 @@ def test_fetching_is_profiled():
profiler = trainer.profiler
assert isinstance(profiler, SimpleProfiler)
key = "[TrainingEpochLoop].train_dataloader_next"
key = "[_TrainingEpochLoop].train_dataloader_next"
assert key in profiler.recorded_durations
durations = profiler.recorded_durations[key]
assert len(durations) == 2 # 2 polls in training_step