From fd9a3803b87762b38cccb38253e547bd290f9f58 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 19 Jan 2023 13:49:25 +0100
Subject: [PATCH] Loop flattening: remove the default `.run()` implementation (#16427)

---
 src/pytorch_lightning/CHANGELOG.md            |  1 +
 .../loops/dataloader/dataloader_loop.py       |  4 +-
 .../loops/dataloader/evaluation_loop.py       | 23 ++++--
 .../loops/dataloader/prediction_loop.py       | 20 ++++-
 .../loops/epoch/evaluation_epoch_loop.py      | 24 ++++--
 .../loops/epoch/prediction_epoch_loop.py      | 30 +++++--
 .../loops/epoch/training_epoch_loop.py        | 15 +++-
 src/pytorch_lightning/loops/fit_loop.py       | 18 +++-
 src/pytorch_lightning/loops/loop.py           | 82 ++-----------------
 .../loops/optimization/manual_loop.py         | 16 +++-
 .../loops/optimization/optimizer_loop.py      | 18 +++-
 src/pytorch_lightning/trainer/trainer.py      |  9 +-
 tests/tests_pytorch/loops/test_loops.py       | 28 ++++++-
 13 files changed, 174 insertions(+), 114 deletions(-)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 664b2be831..a4c23a95c6 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -90,6 +90,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Removed `Loop.replace()` ([#16361](https://github.com/Lightning-AI/lightning/pull/16361))
   * Removed `Loop.connect()` ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
   * Removed the `trainer.{fit,validate,test,predict}_loop` properties ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
+  * Removed the default `Loop.run()` implementation ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
 
 - Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172))
   * Removed the `LightningModule.truncated_bptt_steps` attribute
diff --git a/src/pytorch_lightning/loops/dataloader/dataloader_loop.py b/src/pytorch_lightning/loops/dataloader/dataloader_loop.py
index 6c83b7fbcb..a313ff28df 100644
--- a/src/pytorch_lightning/loops/dataloader/dataloader_loop.py
+++ b/src/pytorch_lightning/loops/dataloader/dataloader_loop.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from abc import abstractmethod
-from typing import Any, Sequence
+from typing import Sequence
 
 from torch.utils.data import DataLoader
 
@@ -60,7 +60,7 @@ class DataLoaderLoop(Loop):
         else:
             self.dataloader_progress.reset_on_restart()
 
-    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
+    def on_advance_start(self) -> None:
         self.dataloader_progress.increment_ready()
 
     def on_advance_end(self) -> None:
diff --git a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py
index f2d840590e..15b4de4cf4 100644
--- a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -89,6 +89,22 @@ class EvaluationLoop(DataLoaderLoop):
         max_batches = self._get_max_batches()
         return sum(max_batches) == 0
 
+    def run(self) -> List[_OUT_DICT]:
+        if self.skip:
+            return []
+        self.reset()
+        self.on_run_start()
+        while not self.done:
+            try:
+                self.on_advance_start()
+                self.advance()
+                self.on_advance_end()
+                self._restarting = False
+            except StopIteration:
+                break
+            self._restarting = False
+        return self.on_run_end()
+
     def reset(self) -> None:
         """Resets the internal state of the loop."""
         self._max_batches = self._get_max_batches()
@@ -105,10 +121,7 @@ class EvaluationLoop(DataLoaderLoop):
         if self.done and self.trainer.state.fn != TrainerFn.FITTING:
             self.dataloader_progress.reset_on_run()
 
-    def on_skip(self) -> List:
-        return []
-
-    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+    def on_run_start(self) -> None:
         """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
         hooks."""
         data_fetcher_cls = _select_data_fetcher_type(self.trainer)
@@ -120,7 +133,7 @@ class EvaluationLoop(DataLoaderLoop):
         self._on_evaluation_start()
         self._on_evaluation_epoch_start()
 
-    def advance(self, *args: Any, **kwargs: Any) -> None:
+    def advance(self) -> None:
         """Performs evaluation on one single dataloader."""
         dataloader_idx = self.current_dataloader_idx
         dataloader = self.current_dataloader
diff --git a/src/pytorch_lightning/loops/dataloader/prediction_loop.py b/src/pytorch_lightning/loops/dataloader/prediction_loop.py
index 1f9df89a00..25163bddcc 100644
--- a/src/pytorch_lightning/loops/dataloader/prediction_loop.py
+++ b/src/pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -16,7 +16,7 @@ class PredictionLoop(DataLoaderLoop):
     def __init__(self) -> None:
         super().__init__()
         self.predictions: List[List[Any]] = []
-        self.epoch_batch_indices: List[List[int]] = []
+        self.epoch_batch_indices: List[List[List[int]]] = []  # used by PredictionWriter
         self.epoch_loop = PredictionEpochLoop()
 
         self._results = None  # for `trainer._results` access
@@ -66,6 +66,22 @@ class PredictionLoop(DataLoaderLoop):
     def skip(self) -> bool:
         return sum(self.max_batches) == 0
 
+    def run(self) -> Optional[_PREDICT_OUTPUT]:
+        if self.skip:
+            return None
+        self.reset()
+        self.on_run_start()
+        while not self.done:
+            try:
+                self.on_advance_start()
+                self.advance()
+                self.on_advance_end()
+                self._restarting = False
+            except StopIteration:
+                break
+            self._restarting = False
+        return self.on_run_end()
+
     def reset(self) -> None:
         """Resets the internal state of the loop for a new run."""
         self.predictions = []
@@ -84,7 +100,7 @@ class PredictionLoop(DataLoaderLoop):
         self._on_predict_start()
         self._on_predict_epoch_start()
 
-    def advance(self, *args: Any, **kwargs: Any) -> None:
+    def advance(self) -> None:
         """Predicts one entire dataloader."""
         dataloader = self.current_dataloader
         if dataloader is not None:
diff --git a/src/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/src/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index 5540382d45..ce0fcf17a3 100644
--- a/src/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/src/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -14,7 +14,7 @@
 
 from collections import OrderedDict
 from functools import lru_cache
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 from torch.utils.data import DataLoader
 
@@ -45,7 +45,7 @@ class EvaluationEpochLoop(Loop):
         self.batch_progress = BatchProgress()
 
         self._outputs: EPOCH_OUTPUT = []
-        self._dl_max_batches = 0
+        self._dl_max_batches: Union[int, float] = 0
         self._data_fetcher: Optional[AbstractDataFetcher] = None
         self._dataloader_state_dict: Dict[str, Any] = {}
         self._dl_batch_idx = [0]
@@ -55,6 +55,20 @@ class EvaluationEpochLoop(Loop):
         """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
         return self.batch_progress.current.completed >= self._dl_max_batches
 
+    def run(
+        self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict
+    ) -> EPOCH_OUTPUT:
+        self.reset()
+        self.on_run_start(data_fetcher, dl_max_batches, kwargs)
+        while not self.done:
+            try:
+                self.advance(data_fetcher, kwargs)
+                self._restarting = False
+            except StopIteration:
+                break
+            self._restarting = False
+        return self.on_run_end()
+
     def reset(self) -> None:
         """Resets the loop's internal state."""
         self._dl_max_batches = 0
@@ -70,7 +84,9 @@ class EvaluationEpochLoop(Loop):
         if self.done and self.trainer.state.fn != TrainerFn.FITTING:
             self.batch_progress.reset_on_run()
 
-    def on_run_start(self, data_fetcher: AbstractDataFetcher, dl_max_batches: int, kwargs: OrderedDict) -> None:
+    def on_run_start(
+        self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict
+    ) -> None:
         """Adds the passed arguments to the loop's state if necessary.
 
         Args:
@@ -103,14 +119,12 @@ class EvaluationEpochLoop(Loop):
     def advance(
         self,
         data_fetcher: AbstractDataFetcher,
-        dl_max_batches: int,
         kwargs: OrderedDict,
     ) -> None:
         """Calls the evaluation step with the corresponding hooks and updates the logger connector.
 
         Args:
             data_fetcher: iterator over the dataloader
-            dl_max_batches: maximum number of batches the dataloader can produce
             kwargs: the kwargs passed down to the hooks.
 
         Raises:
diff --git a/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
index 66794a8caf..1818c46094 100644
--- a/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
+++ b/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -1,5 +1,5 @@
 from collections import OrderedDict
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Tuple, Union
 
 import torch
 
@@ -22,7 +22,7 @@ class PredictionEpochLoop(Loop):
         self.current_batch_indices: List[int] = []
         self.batch_progress = Progress()
 
-        self._dl_max_batches = 0
+        self._dl_max_batches: Union[int, float] = 0
         self._num_dataloaders = 0
         self._warning_cache = WarningCache()
         self._seen_batch_indices: List[List[int]] = []
@@ -38,6 +38,24 @@ class PredictionEpochLoop(Loop):
         any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks)
         return self.return_predictions or any_pred
 
+    def run(
+        self,
+        dataloader_iter: Iterator,
+        dataloader_idx: int,
+        dl_max_batches: Union[int, float],
+        num_dataloaders: int,
+    ) -> Tuple[List[Any], List[List[int]]]:
+        self.reset()
+        self.on_run_start(dataloader_idx, dl_max_batches, num_dataloaders)
+        while not self.done:
+            try:
+                self.advance(dataloader_iter, dataloader_idx)
+                self._restarting = False
+            except StopIteration:
+                break
+            self._restarting = False
+        return self.on_run_end()
+
     def reset(self) -> None:
         """Resets the loops internal state."""
         self._seen_batch_indices = []
@@ -46,15 +64,13 @@ class PredictionEpochLoop(Loop):
 
     def on_run_start(
         self,
-        dataloader_iter: Iterator,
         dataloader_idx: int,
-        dl_max_batches: int,
+        dl_max_batches: Union[int, float],
         num_dataloaders: int,
     ) -> None:
         """Prepares the loops internal state.
 
         Args:
-            dataloader_iter: the iterator over the current dataloader
             dataloader_idx: the index of the current dataloader
             dl_max_batches: the maximum number of batches the current loader can produce
             num_dataloaders: the total number of dataloaders
@@ -68,16 +84,12 @@ class PredictionEpochLoop(Loop):
         self,
         dataloader_iter: Iterator,
         dataloader_idx: int,
-        dl_max_batches: int,
-        num_dataloaders: int,
     ) -> None:
         """Runs one prediction step.
 
         Args:
             dataloader_iter: the iterator over the current dataloader
             dataloader_idx: the index of the current dataloader
-            dl_max_batches: the maximum number of batches the current loader can produce
-            num_dataloaders: the total number of dataloaders
         """
         action_name = f"[{self.__class__.__name__}].predict_dataloader_idx_{dataloader_idx}_next"
         with self.trainer.profiler.profile(action_name):
diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 3abcfb9520..5474cbf60e 100644
--- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -39,7 +39,7 @@ _BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_
 _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
 
 
-class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
+class TrainingEpochLoop(loops.Loop):
     """Runs over all batches in a dataloader (one epoch).
 
     Args:
@@ -121,6 +121,19 @@ class TrainingEpochLoop(loops.Loop):
 
         return False
 
+    def run(self, data_fetcher: AbstractDataFetcher) -> _OUTPUTS_TYPE:
+        self.reset()
+        self.on_run_start(data_fetcher)
+        while not self.done:
+            try:
+                self.advance(data_fetcher)
+                self.on_advance_end()
+                self._restarting = False
+            except StopIteration:
+                break
+            self._restarting = False
+        return self.on_run_end()
+
     def reset(self) -> None:
         """Resets the internal state of the loop for a new run."""
         if self.restarting:
diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py
index d47a5cec86..39e103d7ba 100644
--- a/src/pytorch_lightning/loops/fit_loop.py
+++ b/src/pytorch_lightning/loops/fit_loop.py
@@ -31,7 +31,7 @@ from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signatu
 log = logging.getLogger(__name__)
 
 
-class FitLoop(Loop[None]):
+class FitLoop(Loop):
     """This Loop iterates over the epochs to run the training.
 
     Args:
@@ -169,6 +169,22 @@ class FitLoop(Loop):
         # until `on_run_start`, we use `limit_train_batches` instead
         return self.done or self.trainer.limit_train_batches == 0
 
+    def run(self) -> None:
+        if self.skip:
+            return
+        self.reset()
+        self.on_run_start()
+        while not self.done:
+            try:
+                self.on_advance_start()
+                self.advance()
+                self.on_advance_end()
+                self._restarting = False
+            except StopIteration:
+                break
+            self._restarting = False
+        self.on_run_end()
+
     def reset(self) -> None:
         """Resets the internal state of this loop."""
         if self.restarting:
diff --git a/src/pytorch_lightning/loops/loop.py b/src/pytorch_lightning/loops/loop.py
index 461b0a6a2f..00b5757a45 100644
--- a/src/pytorch_lightning/loops/loop.py
+++ b/src/pytorch_lightning/loops/loop.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, Optional, TypeVar
+from typing import Any, Dict, Optional
 
 from torchmetrics import Metric
 
@@ -21,28 +21,13 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import _Result
 from pytorch_lightning.trainer.progress import BaseProgress
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 
-T = TypeVar("T")  # the output type of `run`
-
 
-class Loop(ABC, Generic[T]):
+class Loop(ABC):
     """Basic Loops interface. All classes derived from this must implement the following properties and methods:
 
-        * :attr:`done` (property): Condition to break the loop
-        * :attr:`reset` (method): Resets the internal state between multiple calls of :attr:`run`
-        * :attr:`advance` (method): Implements one step of the loop
-
-    This class implements the following loop structure:
-
-    .. code-block:: python
-
-        on_run_start()
-
-        while not done:
-            on_advance_start()
-            advance()
-            on_advance_end()
-
-        on_run_end()
+    * :attr:`done` (property): Condition to break the loop
+    * :attr:`reset` (method): Resets the internal state between multiple calls of :attr:`run`
+    * :attr:`advance` (method): Implements one step of the loop
     """
 
     def __init__(self) -> None:
@@ -100,58 +85,9 @@ class Loop(ABC, Generic[T]):
         """
         return False
 
-    def on_skip(self) -> T:
-        """The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`.
-
-        Returns:
-            the default output value of :meth:`on_run_end`
-        """
-
-    def run(self, *args: Any, **kwargs: Any) -> T:
-        """The main entry point to the loop.
-
-        Will frequently check the :attr:`done` condition and calls :attr:`advance`
-        until :attr:`done` evaluates to ``True``.
-
-        Override this if you wish to change the default behavior. The default implementation is:
-
-        Example::
-
-            def run(self, *args, **kwargs):
-                if self.skip:
-                    return self.on_skip()
-
-                self.reset()
-                self.on_run_start(*args, **kwargs)
-
-                while not self.done:
-                    self.advance(*args, **kwargs)
-
-                output = self.on_run_end()
-                return output
-
-        Returns:
-            The output of :attr:`on_run_end` (often outputs collected from each step of the loop)
-        """
-        if self.skip:
-            return self.on_skip()
-
-        self.reset()
-
-        self.on_run_start(*args, **kwargs)
-
-        while not self.done:
-            try:
-                self.on_advance_start(*args, **kwargs)
-                self.advance(*args, **kwargs)
-                self.on_advance_end()
-                self._restarting = False
-            except StopIteration:
-                break
-            self._restarting = False
-
-        output = self.on_run_end()
-        return output
+    @abstractmethod
+    def run(self, *args: Any, **kwargs: Any) -> Any:
+        """The main entry point to the loop."""
 
     @abstractmethod
     def reset(self) -> None:
@@ -195,7 +131,7 @@ class Loop(ABC, Generic[T]):
     def on_advance_end(self) -> None:
         """Hook to be called each time after :attr:`advance` is called."""
 
-    def on_run_end(self) -> T:
+    def on_run_end(self) -> Any:
         """Hook to be called at the end of the run.
 
         Its return argument is returned from :attr:`run`.
diff --git a/src/pytorch_lightning/loops/optimization/manual_loop.py b/src/pytorch_lightning/loops/optimization/manual_loop.py
index 5fe61c3055..f452bc8bff 100644
--- a/src/pytorch_lightning/loops/optimization/manual_loop.py
+++ b/src/pytorch_lightning/loops/optimization/manual_loop.py
@@ -64,7 +64,7 @@ class ManualResult(OutputResult):
 _OUTPUTS_TYPE = Dict[str, Any]
 
 
-class ManualOptimization(Loop[_OUTPUTS_TYPE]):
+class ManualOptimization(Loop):
     """A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens
     entirely in the :meth:`~pytorch_lightning.core.module.LightningModule.training_step` and therefore the user is
     responsible for back-propagating gradients and making calls to the optimizers.
@@ -88,10 +88,22 @@ class ManualOptimization(Loop):
     def done(self) -> bool:
         return self._done
 
+    def run(self, kwargs: OrderedDict) -> _OUTPUTS_TYPE:
+        self.reset()
+        self.on_run_start()
+        while not self.done:
+            try:
+                self.advance(kwargs)
+                self._restarting = False
+            except StopIteration:
+                break
+            self._restarting = False
+        return self.on_run_end()
+
     def reset(self) -> None:
         self._done = False
 
-    def on_run_start(self, *_: Any, **__: Any) -> None:
+    def on_run_start(self) -> None:
         # inject logic around the optimizer step
         for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
             lightning_optimizer._on_before_step = self._on_before_step
diff --git a/src/pytorch_lightning/loops/optimization/optimizer_loop.py b/src/pytorch_lightning/loops/optimization/optimizer_loop.py
index 07284198aa..d7878f21ad 100644
--- a/src/pytorch_lightning/loops/optimization/optimizer_loop.py
+++ b/src/pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -146,7 +146,7 @@ class Closure(AbstractClosure[ClosureResult]):
 _OUTPUTS_TYPE = Dict[int, Dict[str, Any]]
 
 
-class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
+class OptimizerLoop(Loop):
     """Runs over a sequence of optimizers.
 
     This loop implements what is known in Lightning as Automatic Optimization.
@@ -172,6 +172,18 @@ class OptimizerLoop(Loop):
         """Returns ``True`` when the last optimizer in the sequence has run."""
         return self.optim_progress.optimizer_position >= len(self._indices)
 
+    def run(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> _OUTPUTS_TYPE:
+        self.reset()
+        self.on_run_start(optimizers)
+        while not self.done:
+            try:
+                self.advance(kwargs)
+                self._restarting = False
+            except StopIteration:
+                break
+            self._restarting = False
+        return self.on_run_end()
+
     def reset(self) -> None:
         if not self.restarting:
             # when reset() is called from outside (manually), we reset the loop progress
@@ -180,12 +192,12 @@ class OptimizerLoop(Loop):
             self.optim_progress.reset_on_restart()
         self._outputs = {}
 
-    def on_run_start(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> None:
+    def on_run_start(self, optimizers: List[Tuple[int, Optimizer]]) -> None:
         self._indices, self._optimizers = zip(*optimizers)
         if self.done:
             self.optim_progress.optimizer_position = 0
 
-    def advance(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> None:
+    def advance(self, kwargs: OrderedDict) -> None:
         kwargs = self._build_kwargs(kwargs, self.optimizer_idx)
 
         result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py
index 0dab4ee15f..5044bd1196 100644
--- a/src/pytorch_lightning/trainer/trainer.py
+++ b/src/pytorch_lightning/trainer/trainer.py
@@ -43,6 +43,7 @@ from torch.utils.data import DataLoader
 from typing_extensions import Literal
 
 import pytorch_lightning as pl
+from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
 from lightning_fabric.utilities.cloud_io import get_filesystem
 from lightning_fabric.utilities.data import _auto_add_worker_init_fn
 from lightning_fabric.utilities.types import _PATH
@@ -1113,13 +1114,7 @@ class Trainer:
             eval_loop_results = self._evaluation_loop.run()
 
         # remove the tensors from the eval results
-        for result in eval_loop_results:
-            if isinstance(result, dict):
-                for k, v in result.items():
-                    if isinstance(v, Tensor):
-                        result[k] = v.cpu().item()
-
-        return eval_loop_results
+        return convert_tensors_to_scalars(eval_loop_results)
 
     def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
         self.reset_predict_dataloader(self.lightning_module)
diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py
index dbb944ae33..7dc64a85f2 100644
--- a/tests/tests_pytorch/loops/test_loops.py
+++ b/tests/tests_pytorch/loops/test_loops.py
@@ -46,6 +46,9 @@ def test_restarting_loops_recursive():
         def advance(self, *args, **kwargs):
             pass
 
+        def run(self):
+            pass
+
     loop = MyLoop(MyLoop(MyLoop()))
     assert not loop.restarting
 
@@ -68,14 +71,21 @@ def test_loop_restore():
             self.iteration_count = 0
             self.dataset = dataset
 
-        @property
-        def skip(self) -> bool:
-            return False
-
         @property
         def done(self) -> bool:
            return self.iteration_count > len(self.dataset)
 
+        def run(self):
+            self.reset()
+            while not self.done:
+                try:
+                    self.advance()
+                    self.on_advance_end()
+                    self._restarting = False
+                except StopIteration:
+                    break
+                self._restarting = False
+
         def reset(self) -> None:
             self.iter_dataset = iter(self.dataset)
             if self.restarting:
@@ -135,6 +145,16 @@ def test_loop_hierarchy():
             self.a = a
             self.progress = SimpleProgress()
 
+        def run(self):
+            while not self.done:
+                try:
+                    self.advance()
+                    self.on_advance_end()
+                    self._restarting = False
+                except StopIteration:
+                    break
+                self._restarting = False
+
         def advance(self, *args: Any, **kwargs: Any) -> None:
             loop = getattr(self, "loop_child", None)
             if not loop: