Mark the loop classes as protected (#16445)
This commit is contained in:
parent 48e1c9c99c
commit 5891cdc940
@@ -313,9 +313,9 @@ with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile:

 </details>

-### Pro-level control of training loops (advanced users)
+### Pro-level control of optimization (advanced users)

-For complex/professional level work, you have optional full control of the training loop and optimizers.
+For complex/professional level work, you have optional full control of the optimizers.

 ```python
 class LitAutoEncoder(pl.LightningModule):
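For orientation, the "full control of the optimizers" wording refers to manual optimization. A minimal sketch of what the truncated `LitAutoEncoder` example is driving at, assuming hypothetical `encoder`/`decoder` submodules (illustrative only, not the README's exact code):

```python
import pytorch_lightning as pl
import torch
import torch.nn.functional as F


class LitAutoEncoderSketch(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # hypothetical submodules for illustration
        self.encoder = torch.nn.Linear(28 * 28, 3)
        self.decoder = torch.nn.Linear(3, 28 * 28)
        # opt out of automatic optimization to control the optimizers yourself
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        x, _ = batch
        x = x.view(x.size(0), -1)
        loss = F.mse_loss(self.decoder(self.encoder(x)), x)
        # manual zero_grad / backward / step instead of returning the loss
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```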
@@ -100,66 +100,6 @@ loggers
     tensorboard
     wandb

-loops
-^^^^^
-
-Base Classes
-""""""""""""
-
-.. currentmodule:: pytorch_lightning.loops
-
-.. autosummary::
-    :toctree: api
-    :nosignatures:
-    :template: classtemplate.rst
-
-    ~dataloader.dataloader_loop.DataLoaderLoop
-    ~loop.Loop
-
-Training
-""""""""
-
-.. currentmodule:: pytorch_lightning.loops
-
-.. autosummary::
-    :toctree: api
-    :nosignatures:
-    :template: classtemplate.rst
-
-    ~epoch.TrainingEpochLoop
-    FitLoop
-    ~optimization.ManualOptimization
-    ~optimization.OptimizerLoop
-
-
-Validation and Testing
-""""""""""""""""""""""
-
-.. currentmodule:: pytorch_lightning.loops
-
-.. autosummary::
-    :toctree: api
-    :nosignatures:
-    :template: classtemplate.rst
-
-    ~epoch.EvaluationEpochLoop
-    ~dataloader.EvaluationLoop
-
-
-Prediction
-""""""""""
-
-.. currentmodule:: pytorch_lightning.loops
-
-.. autosummary::
-    :toctree: api
-    :nosignatures:
-    :template: classtemplate.rst
-
-    ~epoch.PredictionEpochLoop
-    ~dataloader.PredictionLoop
-
-
 plugins
 ^^^^^^^
@@ -306,7 +306,7 @@ If you have multiple lines of code with similar functionalities, you can use cal
 Use a raw PyTorch loop
 ======================

-For certain types of work at the bleeding-edge of research, Lightning offers experts full control of their training loops in various ways.
+For certain types of work at the bleeding-edge of research, Lightning offers experts full control of optimization or the training loop in various ways.

 .. raw:: html
@@ -333,15 +333,6 @@ For certain types of work at the bleeding-edge of research, Lightning offers exp
         :image_height: 220px
         :height: 320

-    .. displayitem::
-        :header: Loops
-        :description: Enable meta-learning, reinforcement learning, GANs with full control.
-        :col_css: col-md-4
-        :image_center: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/loops.png
-        :button_link: ../extensions/loops.html
-        :image_height: 220px
-        :height: 320
-
 .. raw:: html

     </div>
@@ -247,10 +247,6 @@ class LightningAcceleratorVisitor(LightningVisitor):
     class_name = "Accelerator"


-class LightningLoopVisitor(LightningVisitor):
-    class_name = "Loop"
-
-
 class TorchMetricVisitor(LightningVisitor):
     class_name = "Metric"

@@ -290,7 +286,6 @@ class Scanner:
         LightningPrecisionPluginVisitor,
         LightningAcceleratorVisitor,
         LightningLoggerVisitor,
-        LightningLoopVisitor,
         TorchMetricVisitor,
         FabricVisitor,
         LightningProfilerVisitor,
@@ -113,6 +113,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Removed `Loop.connect()` ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
   * Removed the `trainer.{fit,validate,test,predict}_loop` properties ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
   * Removed the default `Loop.run()` implementation ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
+  * The loop classes are now marked as protected ([#16445](https://github.com/Lightning-AI/lightning/pull/16445))

 - Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172))
   * Removed the `LightningModule.truncated_bptt_steps` attribute
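Taken together with the hunks below, the practical impact of this changelog entry is an import rename for anyone reaching into these internals. A minimal migration sketch, assuming code that deliberately depends on the (now protected) loop classes:

```python
# Before #16445 the loop classes were importable under public names:
# from pytorch_lightning.loops import FitLoop, TrainingEpochLoop

# After #16445 the same classes carry a leading underscore, signalling
# that they are internal and outside the public API contract:
from pytorch_lightning.loops import _FitLoop
from pytorch_lightning.loops.epoch import _TrainingEpochLoop

fit_loop = _FitLoop(min_epochs=0, max_epochs=1)
fit_loop.epoch_loop = _TrainingEpochLoop()
```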
@@ -300,9 +300,9 @@ with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile:

 </details>

-### Pro-level control of training loops (advanced users)
+### Pro-level control of optimization (advanced users)

-For complex/professional level work, you have optional full control of the training loop and optimizers.
+For complex/professional level work, you have optional full control of the optimizers.

 ```python
 class LitAutoEncoder(pl.LightningModule):
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from pytorch_lightning.loops.loop import Loop  # noqa: F401 isort: skip (avoids circular imports)
-from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop  # noqa: F401
-from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop  # noqa: F401
-from pytorch_lightning.loops.fit_loop import FitLoop  # noqa: F401
-from pytorch_lightning.loops.optimization import ManualOptimization, OptimizerLoop  # noqa: F401
+from pytorch_lightning.loops.loop import _Loop  # noqa: F401 isort: skip (avoids circular imports)
+from pytorch_lightning.loops.dataloader import _DataLoaderLoop, _EvaluationLoop, _PredictionLoop  # noqa: F401
+from pytorch_lightning.loops.epoch import _EvaluationEpochLoop, _PredictionEpochLoop, _TrainingEpochLoop  # noqa: F401
+from pytorch_lightning.loops.fit_loop import _FitLoop  # noqa: F401
+from pytorch_lightning.loops.optimization import _ManualOptimization, _OptimizerLoop  # noqa: F401
@@ -12,6 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop  # noqa: F401
-from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop  # noqa: F401
-from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop  # noqa: F401
+from pytorch_lightning.loops.dataloader.dataloader_loop import _DataLoaderLoop  # noqa: F401
+from pytorch_lightning.loops.dataloader.evaluation_loop import _EvaluationLoop  # noqa: F401
+from pytorch_lightning.loops.dataloader.prediction_loop import _PredictionLoop  # noqa: F401
@@ -17,11 +17,11 @@ from typing import Sequence

 from torch.utils.data import DataLoader

-from pytorch_lightning.loops.loop import Loop
+from pytorch_lightning.loops.loop import _Loop
 from pytorch_lightning.trainer.progress import DataLoaderProgress


-class DataLoaderLoop(Loop):
+class _DataLoaderLoop(_Loop):
     """Base class to loop over all dataloaders."""

     def __init__(self) -> None:
@@ -23,8 +23,8 @@ from torch.utils.data.dataloader import DataLoader

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
-from pytorch_lightning.loops.dataloader import DataLoaderLoop
-from pytorch_lightning.loops.epoch import EvaluationEpochLoop
+from pytorch_lightning.loops.dataloader import _DataLoaderLoop
+from pytorch_lightning.loops.epoch import _EvaluationEpochLoop
 from pytorch_lightning.loops.utilities import _set_sampler_epoch
 from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
 from pytorch_lightning.trainer.states import TrainerFn
@@ -38,7 +38,7 @@ if _RICH_AVAILABLE:
     from rich.table import Column, Table


-class EvaluationLoop(DataLoaderLoop):
+class _EvaluationLoop(_DataLoaderLoop):
     """Top-level loop where validation/testing starts.

     It simply iterates over each evaluation dataloader from one to the next by calling ``EvaluationEpochLoop.run()`` in
@@ -47,7 +47,7 @@ class EvaluationLoop(DataLoaderLoop):

     def __init__(self, verbose: bool = True) -> None:
         super().__init__()
-        self.epoch_loop = EvaluationEpochLoop()
+        self.epoch_loop = _EvaluationEpochLoop()
         self.verbose = verbose

         self._results = _ResultCollection(training=False)
@@ -304,7 +304,7 @@ class EvaluationLoop(DataLoaderLoop):
     def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]:
         for k, v in data.items():
             if isinstance(v, dict):
-                for new_key in apply_to_collection(v, dict, EvaluationLoop._get_keys):
+                for new_key in apply_to_collection(v, dict, _EvaluationLoop._get_keys):
                     yield (k, *new_key)  # this need to be in parenthesis for older python versions
             else:
                 yield k,
@@ -317,13 +317,13 @@ class EvaluationLoop(DataLoaderLoop):
         result = data[target_start]
         if not rest:
             return result
-        return EvaluationLoop._find_value(result, rest)
+        return _EvaluationLoop._find_value(result, rest)

     @staticmethod
     def _print_results(results: List[_OUT_DICT], stage: str) -> None:
         # remove the dl idx suffix
         results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results]
-        metrics_paths = {k for keys in apply_to_collection(results, dict, EvaluationLoop._get_keys) for k in keys}
+        metrics_paths = {k for keys in apply_to_collection(results, dict, _EvaluationLoop._get_keys) for k in keys}
         if not metrics_paths:
             return
@@ -341,7 +341,7 @@ class EvaluationLoop(DataLoaderLoop):

         for result in results:
             for metric, row in zip(metrics_paths, rows):
-                val = EvaluationLoop._find_value(result, metric)
+                val = _EvaluationLoop._find_value(result, metric)
                 if val is not None:
                     if isinstance(val, Tensor):
                         val = val.item() if val.numel() == 1 else val.tolist()
@@ -2,26 +2,26 @@ from typing import Any, List, Optional, Sequence, Union

 from torch.utils.data import DataLoader

-from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
-from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop
+from pytorch_lightning.loops.dataloader.dataloader_loop import _DataLoaderLoop
+from pytorch_lightning.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop
 from pytorch_lightning.loops.utilities import _set_sampler_epoch
 from pytorch_lightning.strategies import DDPSpawnStrategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.types import _PREDICT_OUTPUT


-class PredictionLoop(DataLoaderLoop):
+class _PredictionLoop(_DataLoaderLoop):
     """Top-level loop where prediction starts.

-    It simply iterates over each predict dataloader from one to the next by calling ``PredictionEpochLoop.run()`` in its
-    ``advance()`` method.
+    It simply iterates over each predict dataloader from one to the next by calling ``_PredictionEpochLoop.run()`` in
+    its ``advance()`` method.
     """

     def __init__(self) -> None:
         super().__init__()
         self.predictions: List[List[Any]] = []
         self.epoch_batch_indices: List[List[List[int]]] = []  # used by PredictionWriter
-        self.epoch_loop = PredictionEpochLoop()
+        self.epoch_loop = _PredictionEpochLoop()

         self._results = None  # for `trainer._results` access
         self._return_predictions: bool = False
@@ -12,6 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from pytorch_lightning.loops.epoch.evaluation_epoch_loop import EvaluationEpochLoop  # noqa: F401
-from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop  # noqa: F401
-from pytorch_lightning.loops.epoch.training_epoch_loop import TrainingEpochLoop  # noqa: F401
+from pytorch_lightning.loops.epoch.evaluation_epoch_loop import _EvaluationEpochLoop  # noqa: F401
+from pytorch_lightning.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop  # noqa: F401
+from pytorch_lightning.loops.epoch.training_epoch_loop import _TrainingEpochLoop  # noqa: F401
@@ -18,7 +18,7 @@ from typing import Any, Dict, Optional, Union

 from torch.utils.data import DataLoader

-from pytorch_lightning.loops.loop import Loop
+from pytorch_lightning.loops.loop import _Loop
 from pytorch_lightning.trainer.progress import BatchProgress
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.trainer.supporters import CombinedLoader
@@ -33,7 +33,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT


-class EvaluationEpochLoop(Loop):
+class _EvaluationEpochLoop(_Loop):
     """This is the loop performing the evaluation.

     It mainly loops over the given dataloader and runs the validation or test step (depending on the trainer's current
@@ -4,7 +4,7 @@ from typing import Any, Dict, Iterator, List, Tuple, Union
 import torch

 from lightning_fabric.utilities import move_data_to_device
-from pytorch_lightning.loops.loop import Loop
+from pytorch_lightning.loops.loop import _Loop
 from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
 from pytorch_lightning.trainer.progress import Progress
 from pytorch_lightning.utilities.rank_zero import WarningCache
@@ -12,7 +12,7 @@ from pytorch_lightning.utilities.rank_zero import WarningCache
 warning_cache = WarningCache()


-class PredictionEpochLoop(Loop):
+class _PredictionEpochLoop(_Loop):
     """Loop performing prediction on arbitrary sequentially used dataloaders."""

     def __init__(self) -> None:
@@ -21,7 +21,7 @@ from lightning_utilities.core.apply_func import apply_to_collection

 import pytorch_lightning as pl
 from pytorch_lightning import loops  # import as loops to avoid circular imports
-from pytorch_lightning.loops.optimization import ManualOptimization, OptimizerLoop
+from pytorch_lightning.loops.optimization import _ManualOptimization, _OptimizerLoop
 from pytorch_lightning.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
 from pytorch_lightning.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
 from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached
@@ -39,7 +39,7 @@ _BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_
 _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]


-class TrainingEpochLoop(loops.Loop):
+class _TrainingEpochLoop(loops._Loop):
     """
     Iterates over all batches in the dataloader (one epoch) that the user returns in their
     :meth:`~pytorch_lightning.core.module.LightningModule.train_dataloader` method.
@@ -73,10 +73,10 @@ class TrainingEpochLoop(loops.Loop):
         self.batch_progress = BatchProgress()
         self.scheduler_progress = SchedulerProgress()

-        self.optimizer_loop = OptimizerLoop()
-        self.manual_loop = ManualOptimization()
+        self.optimizer_loop = _OptimizerLoop()
+        self.manual_loop = _ManualOptimization()

-        self.val_loop = loops.EvaluationLoop(verbose=False)
+        self.val_loop = loops._EvaluationLoop(verbose=False)

         self._results = _ResultCollection(training=True)
         self._outputs: _OUTPUTS_TYPE = []
@@ -15,8 +15,8 @@ import logging
 from typing import Any, Optional, Type

 import pytorch_lightning as pl
-from pytorch_lightning.loops import Loop
-from pytorch_lightning.loops.epoch import TrainingEpochLoop
+from pytorch_lightning.loops import _Loop
+from pytorch_lightning.loops.epoch import _TrainingEpochLoop
 from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE
 from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
 from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
@@ -31,7 +31,7 @@ from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signatu
 log = logging.getLogger(__name__)


-class FitLoop(Loop):
+class _FitLoop(_Loop):
     """This loop is the top-level loop where training starts.

     It simply counts the epochs and iterates from one to the next by calling ``TrainingEpochLoop.run()`` in its
@@ -73,7 +73,7 @@ class FitLoop(Loop):

         self.max_epochs = max_epochs
         self.min_epochs = min_epochs
-        self.epoch_loop = TrainingEpochLoop()
+        self.epoch_loop = _TrainingEpochLoop()
         self.epoch_progress = Progress()

         self._is_fresh_start_epoch: bool = True
@@ -116,13 +116,13 @@ class FitLoop(Loop):
             )
         self.epoch_loop.max_steps = value

-    @Loop.restarting.setter
+    @_Loop.restarting.setter
     def restarting(self, restarting: bool) -> None:
         # if the last epoch completely finished, we are not actually restarting
         values = self.epoch_progress.current.ready, self.epoch_progress.current.started
         epoch_unfinished = any(v != self.epoch_progress.current.processed for v in values)
         restarting = restarting and epoch_unfinished or self._iteration_based_training()
-        Loop.restarting.fset(self, restarting)  # call the parent setter
+        _Loop.restarting.fset(self, restarting)  # call the parent setter

     @property
     def prefetch_batches(self) -> int:
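A side note on the unchanged `restarting = ...` line above: `and` binds tighter than `or` in Python, so the expression is equivalent to this parenthesized form:

```python
restarting = (restarting and epoch_unfinished) or self._iteration_based_training()
```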
@@ -21,7 +21,7 @@ from pytorch_lightning.trainer.progress import BaseProgress
 from pytorch_lightning.utilities.imports import _fault_tolerant_training


-class Loop:
+class _Loop:
     """Basic Loops interface."""

     def __init__(self) -> None:
@@ -39,7 +39,7 @@ class Loop:
         """Connects this loop's trainer and its children."""
         self._trainer = trainer
         for v in self.__dict__.values():
-            if isinstance(v, Loop):
+            if isinstance(v, _Loop):
                 v.trainer = trainer

     @property
@@ -52,7 +52,7 @@ class Loop:
         """Connects this loop's restarting value and its children."""
         self._restarting = restarting
         for loop in vars(self).values():
-            if isinstance(loop, Loop):
+            if isinstance(loop, _Loop):
                 loop.restarting = restarting

     def on_save_checkpoint(self) -> Dict:
@@ -85,7 +85,7 @@ class Loop:
             key = prefix + k
             if isinstance(v, BaseProgress):
                 destination[key] = v.state_dict()
-            elif isinstance(v, Loop):
+            elif isinstance(v, _Loop):
                 v.state_dict(destination, key + ".")
             elif ft_enabled and isinstance(v, _ResultCollection):
                 # sync / unsync metrics
@@ -104,7 +104,7 @@ class Loop:
         """Loads the state of this loop and all its children."""
         self._load_from_state_dict(state_dict.copy(), prefix, metrics)
         for k, v in self.__dict__.items():
-            if isinstance(v, Loop):
+            if isinstance(v, _Loop):
                 v.load_state_dict(state_dict.copy(), prefix + k + ".")
         self.restarting = True
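The `state_dict`/`load_state_dict` hunks above show the recursive composition pattern: progress objects serialize in place, and each child `_Loop` recurses with a dotted key prefix. A hedged sketch of the resulting flattened layout, assuming a fit loop that owns an epoch loop (the key names here are illustrative, not an exact dump):

```python
# Hypothetical flattened state produced by _Loop.state_dict():
state = {
    "state_dict": {},                                # this loop's own state
    "epoch_progress": {"current": {}, "total": {}},  # BaseProgress -> v.state_dict()
    "epoch_loop.state_dict": {},                     # child _Loop -> recursion with "epoch_loop." prefix
    "epoch_loop.batch_progress": {"current": {}, "total": {}},
}
```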
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization  # noqa: F401
-from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop  # noqa: F401
+from pytorch_lightning.loops.optimization.manual_loop import _ManualOptimization  # noqa: F401
+from pytorch_lightning.loops.optimization.optimizer_loop import _OptimizerLoop  # noqa: F401
@@ -19,7 +19,7 @@ from typing import Any, Dict, Optional
 from torch import Tensor

 from pytorch_lightning.core.optimizer import do_nothing_closure
-from pytorch_lightning.loops import Loop
+from pytorch_lightning.loops import _Loop
 from pytorch_lightning.loops.optimization.closure import OutputResult
 from pytorch_lightning.loops.utilities import _build_training_step_kwargs
 from pytorch_lightning.trainer.progress import Progress, ReadyCompletedTracker
@@ -65,7 +65,7 @@ class ManualResult(OutputResult):
 _OUTPUTS_TYPE = Dict[str, Any]


-class ManualOptimization(Loop):
+class _ManualOptimization(_Loop):
     """A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens
     entirely in the :meth:`~pytorch_lightning.core.module.LightningModule.training_step` and therefore the user is
     responsible for back-propagating gradients and making calls to the optimizers.
@@ -22,7 +22,7 @@ from typing_extensions import OrderedDict

 from pytorch_lightning.accelerators import TPUAccelerator
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.loops import Loop
+from pytorch_lightning.loops import _Loop
 from pytorch_lightning.loops.optimization.closure import AbstractClosure, OutputResult
 from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior, _build_training_step_kwargs
 from pytorch_lightning.trainer.progress import OptimizationProgress
@@ -146,7 +146,7 @@ class Closure(AbstractClosure[ClosureResult]):
 _OUTPUTS_TYPE = Dict[int, Dict[str, Any]]


-class OptimizerLoop(Loop):
+class _OptimizerLoop(_Loop):
     """Iterates over one or multiple optimizers and for each one it calls the
     :meth:`~pytorch_lightning.core.module.LightningModule.training_step` method with the batch, the current batch index
     and the optimizer index if multiple optimizers are requested.
@@ -25,7 +25,7 @@ from torch.utils.data import DataLoader
 import pytorch_lightning as pl
 from lightning_fabric.utilities.warnings import PossibleUserWarning
 from pytorch_lightning.callbacks.timer import Timer
-from pytorch_lightning.loops import Loop
+from pytorch_lightning.loops import _Loop
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.strategies.strategy import Strategy
 from pytorch_lightning.trainer.progress import BaseProgress
@@ -182,11 +182,11 @@ def _is_max_limit_reached(current: int, maximum: int = -1) -> bool:
     return maximum != -1 and current >= maximum


-def _reset_progress(loop: Loop) -> None:
+def _reset_progress(loop: _Loop) -> None:
     for v in vars(loop).values():
         if isinstance(v, BaseProgress):
             v.reset()
-        elif isinstance(v, Loop):
+        elif isinstance(v, _Loop):
             _reset_progress(v)

@@ -54,9 +54,9 @@ from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.loggers import Logger
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
-from pytorch_lightning.loops import PredictionLoop, TrainingEpochLoop
-from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
-from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.loops import _PredictionLoop, _TrainingEpochLoop
+from pytorch_lightning.loops.dataloader.evaluation_loop import _EvaluationLoop
+from pytorch_lightning.loops.fit_loop import _FitLoop
 from pytorch_lightning.loops.utilities import _parse_loop_limits, _reset_progress
 from pytorch_lightning.plugins import PLUGIN_INPUT, PrecisionPlugin
 from pytorch_lightning.profilers import Profiler
@@ -373,11 +373,11 @@ class Trainer:
         self.tuner = Tuner(self)

         # init loops
-        self.fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
-        self.fit_loop.epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
-        self.validate_loop = EvaluationLoop()
-        self.test_loop = EvaluationLoop()
-        self.predict_loop = PredictionLoop()
+        self.fit_loop = _FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
+        self.fit_loop.epoch_loop = _TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
+        self.validate_loop = _EvaluationLoop()
+        self.test_loop = _EvaluationLoop()
+        self.predict_loop = _PredictionLoop()
         self.fit_loop.trainer = self
         self.validate_loop.trainer = self
         self.test_loop.trainer = self
@@ -1939,7 +1939,7 @@ class Trainer:
         return self.fit_loop.epoch_loop.batch_progress.is_last_batch

     @property
-    def _evaluation_loop(self) -> EvaluationLoop:
+    def _evaluation_loop(self) -> _EvaluationLoop:
         if self.state.fn == TrainerFn.FITTING:
             return self.fit_loop.epoch_loop.val_loop
         if self.state.fn == TrainerFn.VALIDATING:
@@ -1949,7 +1949,7 @@ class Trainer:
         raise RuntimeError("The `Trainer._evaluation_loop` property isn't defined. Accessed outside of scope")

     @property
-    def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop]]:
+    def _active_loop(self) -> Optional[Union[_FitLoop, _EvaluationLoop, _PredictionLoop]]:
         if self.training:
             return self.fit_loop
         if self.sanity_checking or self.evaluating:
@@ -11,7 +11,6 @@ if _is_pytorch_lightning_available():
     from pytorch_lightning.accelerators.accelerator import Accelerator
     from pytorch_lightning.callbacks import Callback
     from pytorch_lightning.loggers import Logger
-    from pytorch_lightning.loops import Loop
     from pytorch_lightning.plugins import PrecisionPlugin
     from pytorch_lightning.profilers import Profiler

@@ -42,9 +41,6 @@ if __name__ == "__main__":
     class BoringLogger(Logger):
         pass

-    class BoringLoop(Loop):
-        pass
-
     class BoringMetric(Metric):
         pass

@@ -55,7 +55,6 @@ def test_introspection_lightning_overrides():
         "Fabric",
         "Logger",
         "LightningModule",
-        "Loop",
         "Metric",
         "PrecisionPlugin",
         "Trainer",
@@ -18,7 +18,7 @@ import pytest

 from pytorch_lightning import LightningModule
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.loops import TrainingEpochLoop
+from pytorch_lightning.loops import _TrainingEpochLoop
 from pytorch_lightning.trainer.trainer import Trainer

 _out00 = {"loss": 0.0}
@@ -43,7 +43,7 @@ class TestPrepareOutputs:

     def prepare_outputs_training_epoch_end(self, batch_outputs, num_optimizers, automatic_optimization=True):
         return self.prepare_outputs(
-            TrainingEpochLoop._prepare_outputs_training_epoch_end,
+            _TrainingEpochLoop._prepare_outputs_training_epoch_end,
             batch_outputs,
             num_optimizers,
             automatic_optimization=automatic_optimization,
@@ -51,7 +51,7 @@ class TestPrepareOutputs:

     def prepare_outputs_training_batch_end(self, batch_outputs, num_optimizers, automatic_optimization=True):
         return self.prepare_outputs(
-            TrainingEpochLoop._prepare_outputs_training_batch_end,
+            _TrainingEpochLoop._prepare_outputs_training_batch_end,
             batch_outputs,
             num_optimizers,
             automatic_optimization=automatic_optimization,
@@ -24,7 +24,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden
 from tests_pytorch.helpers.runif import RunIf


-@mock.patch("pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop._on_evaluation_epoch_end")
+@mock.patch("pytorch_lightning.loops.dataloader.evaluation_loop._EvaluationLoop._on_evaluation_epoch_end")
 def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir):
     """Tests that `on_evaluation_epoch_end` is called for `on_validation_epoch_end` and `on_test_epoch_end`
     hooks."""
@@ -15,19 +15,19 @@ import os
 from unittest import mock
 from unittest.mock import Mock

-from pytorch_lightning.loops import FitLoop
+from pytorch_lightning.loops import _FitLoop
 from pytorch_lightning.trainer.trainer import Trainer


 def test_loops_state_dict():
     trainer = Trainer()

-    fit_loop = FitLoop()
+    fit_loop = _FitLoop()

     fit_loop.trainer = trainer
     state_dict = fit_loop.state_dict()

-    new_fit_loop = FitLoop()
+    new_fit_loop = _FitLoop()
     new_fit_loop.trainer = trainer

     new_fit_loop.load_state_dict(state_dict)
@@ -25,13 +25,13 @@ from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoad
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import Callback, ModelCheckpoint
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
-from pytorch_lightning.loops import Loop
+from pytorch_lightning.loops import _Loop
 from pytorch_lightning.trainer.progress import BaseProgress
 from tests_pytorch.helpers.runif import RunIf


 def test_restarting_loops_recursive():
-    class MyLoop(Loop):
+    class MyLoop(_Loop):
         def __init__(self, loop=None):
             super().__init__()
             self.child = loop
@@ -52,7 +52,7 @@ class CustomException(Exception):


 def test_loop_restore():
-    class Simple(Loop):
+    class Simple(_Loop):
         def __init__(self, dataset: Iterator):
             super().__init__()
             self.iteration_count = 0
@@ -119,7 +119,7 @@ def test_loop_hierarchy():
     class SimpleProgress(BaseProgress):
         increment: int = 0

-    class Simple(Loop):
+    class Simple(_Loop):
         def __init__(self, a):
             super().__init__()
             self.a = a
@@ -19,7 +19,7 @@ import torch

 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.loops import FitLoop
+from pytorch_lightning.loops import _FitLoop


 def test_outputs_format(tmpdir):
@@ -141,7 +141,7 @@ def test_should_stop_mid_epoch(tmpdir):


 def test_fit_loop_done_log_messages(caplog):
-    fit_loop = FitLoop(max_epochs=1)
+    fit_loop = _FitLoop(max_epochs=1)
     trainer = Mock(spec=Trainer)
     fit_loop.trainer = trainer

@@ -29,7 +29,7 @@ from pytorch_lightning import callbacks, Trainer
 from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
 from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning.loops.dataloader import EvaluationLoop
+from pytorch_lightning.loops.dataloader import _EvaluationLoop
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
@@ -845,7 +845,7 @@ def test_native_print_results(monkeypatch, inputs, expected):
     monkeypatch.setattr(imports, "_RICH_AVAILABLE", False)

     with redirect_stdout(StringIO()) as out:
-        EvaluationLoop._print_results(*inputs)
+        _EvaluationLoop._print_results(*inputs)
     expected = expected[1:]  # remove the initial line break from the """ string
     assert out.getvalue().replace(os.linesep, "\n") == expected.lstrip()

@@ -859,7 +859,7 @@ def test_native_print_results_encodings(monkeypatch, encoding):
     out = mock.Mock()
     out.encoding = encoding
     with redirect_stdout(out) as out:
-        EvaluationLoop._print_results(*inputs0)
+        _EvaluationLoop._print_results(*inputs0)

     # Attempt to encode everything the file is told to write with the given encoding
     for call_ in out.method_calls:
@@ -937,7 +937,7 @@ expected3 = """
 def test_rich_print_results(inputs, expected):
     console = get_console()
     with console.capture() as capture:
-        EvaluationLoop._print_results(*inputs)
+        _EvaluationLoop._print_results(*inputs)
     expected = expected[1:]  # remove the initial line break from the """ string
     assert capture.get() == expected.lstrip()

@@ -478,25 +478,25 @@ def test_fetching_is_profiled():

     # validation
     for i in range(2):
-        key = f"[EvaluationEpochLoop].val_dataloader_idx_{i}_next"
+        key = f"[_EvaluationEpochLoop].val_dataloader_idx_{i}_next"
         assert key in profiler.recorded_durations
         durations = profiler.recorded_durations[key]
         assert len(durations) == fast_dev_run
         assert all(d > 0 for d in durations)
     # training
-    key = "[TrainingEpochLoop].train_dataloader_next"
+    key = "[_TrainingEpochLoop].train_dataloader_next"
     assert key in profiler.recorded_durations
     durations = profiler.recorded_durations[key]
     assert len(durations) == fast_dev_run
     assert all(d > 0 for d in durations)
     # test
-    key = "[EvaluationEpochLoop].val_dataloader_idx_0_next"
+    key = "[_EvaluationEpochLoop].val_dataloader_idx_0_next"
     assert key in profiler.recorded_durations
     durations = profiler.recorded_durations[key]
     assert len(durations) == fast_dev_run
     assert all(d > 0 for d in durations)
     # predict
-    key = "[PredictionEpochLoop].predict_dataloader_idx_0_next"
+    key = "[_PredictionEpochLoop].predict_dataloader_idx_0_next"
     assert key in profiler.recorded_durations
     durations = profiler.recorded_durations[key]
     assert len(durations) == fast_dev_run
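These profiler keys change because they embed the loop's class name. A sketch of how such a key is presumably derived (a hypothetical helper, not the library's actual code):

```python
def _profile_key(loop: object, action: str) -> str:
    # e.g. a _TrainingEpochLoop instance yields "[_TrainingEpochLoop].train_dataloader_next"
    return f"[{type(loop).__name__}].{action}"
```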
@@ -524,7 +524,7 @@ def test_fetching_is_profiled():
     profiler = trainer.profiler
     assert isinstance(profiler, SimpleProfiler)

-    key = "[TrainingEpochLoop].train_dataloader_next"
+    key = "[_TrainingEpochLoop].train_dataloader_next"
     assert key in profiler.recorded_durations
     durations = profiler.recorded_durations[key]
     assert len(durations) == 2  # 2 polls in training_step