Loop flattening: remove the default `.run()` implementation (#16427)

Carlos Mocholí 2023-01-19 13:49:25 +01:00 committed by Luca Antiga
parent f031f1e453
commit fd9a3803b8
13 changed files with 174 additions and 114 deletions

View File

@ -90,6 +90,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Removed `Loop.replace()` ([#16361](https://github.com/Lightning-AI/lightning/pull/16361))
* Removed `Loop.connect()` ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
* Removed the `trainer.{fit,validate,test,predict}_loop` properties ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
* Removed the default `Loop.run()` implementation ([#16427](https://github.com/Lightning-AI/lightning/pull/16427))
- Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172))
* Removed the `LightningModule.truncated_bptt_steps` attribute

View File

@ -13,7 +13,7 @@
# limitations under the License.
from abc import abstractmethod
from typing import Any, Sequence
from typing import Sequence
from torch.utils.data import DataLoader
@ -60,7 +60,7 @@ class DataLoaderLoop(Loop):
else:
self.dataloader_progress.reset_on_restart()
def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
def on_advance_start(self) -> None:
self.dataloader_progress.increment_ready()
def on_advance_end(self) -> None:

View File

@ -89,6 +89,22 @@ class EvaluationLoop(DataLoaderLoop):
max_batches = self._get_max_batches()
return sum(max_batches) == 0
def run(self) -> List[_OUT_DICT]:
if self.skip:
return []
self.reset()
self.on_run_start()
while not self.done:
try:
self.on_advance_start()
self.advance()
self.on_advance_end()
self._restarting = False
except StopIteration:
break
self._restarting = False
return self.on_run_end()
def reset(self) -> None:
"""Resets the internal state of the loop."""
self._max_batches = self._get_max_batches()
@ -105,10 +121,7 @@ class EvaluationLoop(DataLoaderLoop):
if self.done and self.trainer.state.fn != TrainerFn.FITTING:
self.dataloader_progress.reset_on_run()
def on_skip(self) -> List:
return []
def on_run_start(self, *args: Any, **kwargs: Any) -> None:
def on_run_start(self) -> None:
"""Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
hooks."""
data_fetcher_cls = _select_data_fetcher_type(self.trainer)
@ -120,7 +133,7 @@ class EvaluationLoop(DataLoaderLoop):
self._on_evaluation_start()
self._on_evaluation_epoch_start()
def advance(self, *args: Any, **kwargs: Any) -> None:
def advance(self) -> None:
"""Performs evaluation on one single dataloader."""
dataloader_idx = self.current_dataloader_idx
dataloader = self.current_dataloader

View File

@ -16,7 +16,7 @@ class PredictionLoop(DataLoaderLoop):
def __init__(self) -> None:
super().__init__()
self.predictions: List[List[Any]] = []
self.epoch_batch_indices: List[List[int]] = []
self.epoch_batch_indices: List[List[List[int]]] = [] # used by PredictionWriter
self.epoch_loop = PredictionEpochLoop()
self._results = None # for `trainer._results` access
@ -66,6 +66,22 @@ class PredictionLoop(DataLoaderLoop):
def skip(self) -> bool:
return sum(self.max_batches) == 0
def run(self) -> Optional[_PREDICT_OUTPUT]:
if self.skip:
return None
self.reset()
self.on_run_start()
while not self.done:
try:
self.on_advance_start()
self.advance()
self.on_advance_end()
self._restarting = False
except StopIteration:
break
self._restarting = False
return self.on_run_end()
def reset(self) -> None:
"""Resets the internal state of the loop for a new run."""
self.predictions = []
@ -84,7 +100,7 @@ class PredictionLoop(DataLoaderLoop):
self._on_predict_start()
self._on_predict_epoch_start()
def advance(self, *args: Any, **kwargs: Any) -> None:
def advance(self) -> None:
"""Predicts one entire dataloader."""
dataloader = self.current_dataloader
if dataloader is not None:

View File

@ -14,7 +14,7 @@
from collections import OrderedDict
from functools import lru_cache
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
from torch.utils.data import DataLoader
@ -45,7 +45,7 @@ class EvaluationEpochLoop(Loop):
self.batch_progress = BatchProgress()
self._outputs: EPOCH_OUTPUT = []
self._dl_max_batches = 0
self._dl_max_batches: Union[int, float] = 0
self._data_fetcher: Optional[AbstractDataFetcher] = None
self._dataloader_state_dict: Dict[str, Any] = {}
self._dl_batch_idx = [0]
@ -55,6 +55,20 @@ class EvaluationEpochLoop(Loop):
"""Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
return self.batch_progress.current.completed >= self._dl_max_batches
def run(
self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict
) -> EPOCH_OUTPUT:
self.reset()
self.on_run_start(data_fetcher, dl_max_batches, kwargs)
while not self.done:
try:
self.advance(data_fetcher, kwargs)
self._restarting = False
except StopIteration:
break
self._restarting = False
return self.on_run_end()
def reset(self) -> None:
"""Resets the loop's internal state."""
self._dl_max_batches = 0
@ -70,7 +84,9 @@ class EvaluationEpochLoop(Loop):
if self.done and self.trainer.state.fn != TrainerFn.FITTING:
self.batch_progress.reset_on_run()
def on_run_start(self, data_fetcher: AbstractDataFetcher, dl_max_batches: int, kwargs: OrderedDict) -> None:
def on_run_start(
self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict
) -> None:
"""Adds the passed arguments to the loop's state if necessary.
Args:
@ -103,14 +119,12 @@ class EvaluationEpochLoop(Loop):
def advance(
self,
data_fetcher: AbstractDataFetcher,
dl_max_batches: int,
kwargs: OrderedDict,
) -> None:
"""Calls the evaluation step with the corresponding hooks and updates the logger connector.
Args:
data_fetcher: iterator over the dataloader
dl_max_batches: maximum number of batches the dataloader can produce
kwargs: the kwargs passed down to the hooks.
Raises:

View File

@ -1,5 +1,5 @@
from collections import OrderedDict
from typing import Any, Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List, Tuple, Union
import torch
@ -22,7 +22,7 @@ class PredictionEpochLoop(Loop):
self.current_batch_indices: List[int] = []
self.batch_progress = Progress()
self._dl_max_batches = 0
self._dl_max_batches: Union[int, float] = 0
self._num_dataloaders = 0
self._warning_cache = WarningCache()
self._seen_batch_indices: List[List[int]] = []
@ -38,6 +38,24 @@ class PredictionEpochLoop(Loop):
any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks)
return self.return_predictions or any_pred
def run(
self,
dataloader_iter: Iterator,
dataloader_idx: int,
dl_max_batches: Union[int, float],
num_dataloaders: int,
) -> Tuple[List[Any], List[List[int]]]:
self.reset()
self.on_run_start(dataloader_idx, dl_max_batches, num_dataloaders)
while not self.done:
try:
self.advance(dataloader_iter, dataloader_idx)
self._restarting = False
except StopIteration:
break
self._restarting = False
return self.on_run_end()
def reset(self) -> None:
"""Resets the loops internal state."""
self._seen_batch_indices = []
@ -46,15 +64,13 @@ class PredictionEpochLoop(Loop):
def on_run_start(
self,
dataloader_iter: Iterator,
dataloader_idx: int,
dl_max_batches: int,
dl_max_batches: Union[int, float],
num_dataloaders: int,
) -> None:
"""Prepares the loops internal state.
Args:
dataloader_iter: the iterator over the current dataloader
dataloader_idx: the index of the current dataloader
dl_max_batches: the maximum number of batches the current loader can produce
num_dataloaders: the total number of dataloaders
@ -68,16 +84,12 @@ class PredictionEpochLoop(Loop):
self,
dataloader_iter: Iterator,
dataloader_idx: int,
dl_max_batches: int,
num_dataloaders: int,
) -> None:
"""Runs one prediction step.
Args:
dataloader_iter: the iterator over the current dataloader
dataloader_idx: the index of the current dataloader
dl_max_batches: the maximum number of batches the current loader can produce
num_dataloaders: the total number of dataloaders
"""
action_name = f"[{self.__class__.__name__}].predict_dataloader_idx_{dataloader_idx}_next"
with self.trainer.profiler.profile(action_name):

View File

@ -39,7 +39,7 @@ _BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_
_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
class TrainingEpochLoop(loops.Loop):
"""Runs over all batches in a dataloader (one epoch).
Args:
@ -121,6 +121,19 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
return False
def run(self, data_fetcher: AbstractDataFetcher) -> _OUTPUTS_TYPE:
self.reset()
self.on_run_start(data_fetcher)
while not self.done:
try:
self.advance(data_fetcher)
self.on_advance_end()
self._restarting = False
except StopIteration:
break
self._restarting = False
return self.on_run_end()
def reset(self) -> None:
"""Resets the internal state of the loop for a new run."""
if self.restarting:

View File

@ -31,7 +31,7 @@ from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signatu
log = logging.getLogger(__name__)
class FitLoop(Loop[None]):
class FitLoop(Loop):
"""This Loop iterates over the epochs to run the training.
Args:
@ -169,6 +169,22 @@ class FitLoop(Loop[None]):
# until `on_run_start`, we use `limit_train_batches` instead
return self.done or self.trainer.limit_train_batches == 0
def run(self) -> None:
if self.skip:
return
self.reset()
self.on_run_start()
while not self.done:
try:
self.on_advance_start()
self.advance()
self.on_advance_end()
self._restarting = False
except StopIteration:
break
self._restarting = False
self.on_run_end()
def reset(self) -> None:
"""Resets the internal state of this loop."""
if self.restarting:

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Optional, TypeVar
from typing import Any, Dict, Optional
from torchmetrics import Metric
@ -21,28 +21,13 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import _Result
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities.imports import _fault_tolerant_training
T = TypeVar("T") # the output type of `run`
class Loop(ABC, Generic[T]):
class Loop(ABC):
"""Basic Loops interface. All classes derived from this must implement the following properties and methods:
* :attr:`done` (property): Condition to break the loop
* :attr:`reset` (method): Resets the internal state between multiple calls of :attr:`run`
* :attr:`advance` (method): Implements one step of the loop
This class implements the following loop structure:
.. code-block:: python
on_run_start()
while not done:
on_advance_start()
advance()
on_advance_end()
on_run_end()
* :attr:`done` (property): Condition to break the loop
* :attr:`reset` (method): Resets the internal state between multiple calls of :attr:`run`
* :attr:`advance` (method): Implements one step of the loop
"""
def __init__(self) -> None:
@ -100,58 +85,9 @@ class Loop(ABC, Generic[T]):
"""
return False
def on_skip(self) -> T:
"""The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`.
Returns:
the default output value of :meth:`on_run_end`
"""
def run(self, *args: Any, **kwargs: Any) -> T:
"""The main entry point to the loop.
Will frequently check the :attr:`done` condition and call :attr:`advance`
until :attr:`done` evaluates to ``True``.
Override this if you wish to change the default behavior. The default implementation is:
Example::
def run(self, *args, **kwargs):
if self.skip:
return self.on_skip()
self.reset()
self.on_run_start(*args, **kwargs)
while not self.done:
self.advance(*args, **kwargs)
output = self.on_run_end()
return output
Returns:
The output of :attr:`on_run_end` (often outputs collected from each step of the loop)
"""
if self.skip:
return self.on_skip()
self.reset()
self.on_run_start(*args, **kwargs)
while not self.done:
try:
self.on_advance_start(*args, **kwargs)
self.advance(*args, **kwargs)
self.on_advance_end()
self._restarting = False
except StopIteration:
break
self._restarting = False
output = self.on_run_end()
return output
@abstractmethod
def run(self, *args: Any, **kwargs: Any) -> Any:
"""The main entry point to the loop."""
@abstractmethod
def reset(self) -> None:
@ -195,7 +131,7 @@ class Loop(ABC, Generic[T]):
def on_advance_end(self) -> None:
"""Hook to be called each time after :attr:`advance` is called."""
def on_run_end(self) -> T:
def on_run_end(self) -> Any:
"""Hook to be called at the end of the run.
Its return argument is returned from :attr:`run`.
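With this change `Loop.run()` is abstract, so every built-in loop (and any user-defined loop) spells out its own iteration instead of inheriting the old default body. Below is a minimal sketch, not part of the commit, of a toy subclass written against the new interface; `MyCountingLoop` and its counting logic are invented for illustration, and only the `Loop` hooks plus the `skip` and `_restarting` attributes come from the class shown above.

```python
from pytorch_lightning.loops import Loop


class MyCountingLoop(Loop):
    """Toy loop that counts to ``limit``, supplying its own ``run()``."""

    def __init__(self, limit: int) -> None:
        super().__init__()
        self.limit = limit
        self.count = 0

    @property
    def done(self) -> bool:
        return self.count >= self.limit

    def reset(self) -> None:
        self.count = 0

    def advance(self) -> None:
        self.count += 1

    def run(self) -> int:
        # the skeleton the built-in loops now repeat explicitly
        if self.skip:
            return self.count
        self.reset()
        self.on_run_start()
        while not self.done:
            try:
                self.on_advance_start()
                self.advance()
                self.on_advance_end()
                self._restarting = False
            except StopIteration:
                break
        self._restarting = False
        self.on_run_end()
        return self.count


assert MyCountingLoop(limit=3).run() == 3
```

Note how the built-in loops above keep only the hooks they actually use (the epoch loops, for instance, drop `on_advance_start`/`on_advance_end`), which is the point of removing the one-size-fits-all default.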

View File

@ -64,7 +64,7 @@ class ManualResult(OutputResult):
_OUTPUTS_TYPE = Dict[str, Any]
class ManualOptimization(Loop[_OUTPUTS_TYPE]):
class ManualOptimization(Loop):
"""A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens
entirely in the :meth:`~pytorch_lightning.core.module.LightningModule.training_step` and therefore the user is
responsible for back-propagating gradients and making calls to the optimizers.
@ -88,10 +88,22 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
def done(self) -> bool:
return self._done
def run(self, kwargs: OrderedDict) -> _OUTPUTS_TYPE:
self.reset()
self.on_run_start()
while not self.done:
try:
self.advance(kwargs)
self._restarting = False
except StopIteration:
break
self._restarting = False
return self.on_run_end()
def reset(self) -> None:
self._done = False
def on_run_start(self, *_: Any, **__: Any) -> None:
def on_run_start(self) -> None:
# inject logic around the optimizer step
for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
lightning_optimizer._on_before_step = self._on_before_step

View File

@ -146,7 +146,7 @@ class Closure(AbstractClosure[ClosureResult]):
_OUTPUTS_TYPE = Dict[int, Dict[str, Any]]
class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
class OptimizerLoop(Loop):
"""Runs over a sequence of optimizers.
This loop implements what is known in Lightning as Automatic Optimization.
@ -172,6 +172,18 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
"""Returns ``True`` when the last optimizer in the sequence has run."""
return self.optim_progress.optimizer_position >= len(self._indices)
def run(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> _OUTPUTS_TYPE:
self.reset()
self.on_run_start(optimizers)
while not self.done:
try:
self.advance(kwargs)
self._restarting = False
except StopIteration:
break
self._restarting = False
return self.on_run_end()
def reset(self) -> None:
if not self.restarting:
# when reset() is called from outside (manually), we reset the loop progress
@ -180,12 +192,12 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
self.optim_progress.reset_on_restart()
self._outputs = {}
def on_run_start(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> None:
def on_run_start(self, optimizers: List[Tuple[int, Optimizer]]) -> None:
self._indices, self._optimizers = zip(*optimizers)
if self.done:
self.optim_progress.optimizer_position = 0
def advance(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> None:
def advance(self, kwargs: OrderedDict) -> None:
kwargs = self._build_kwargs(kwargs, self.optimizer_idx)
result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])

View File

@ -43,6 +43,7 @@ from torch.utils.data import DataLoader
from typing_extensions import Literal
import pytorch_lightning as pl
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
from lightning_fabric.utilities.cloud_io import get_filesystem
from lightning_fabric.utilities.data import _auto_add_worker_init_fn
from lightning_fabric.utilities.types import _PATH
@ -1113,13 +1114,7 @@ class Trainer:
eval_loop_results = self._evaluation_loop.run()
# remove the tensors from the eval results
for result in eval_loop_results:
if isinstance(result, dict):
for k, v in result.items():
if isinstance(v, Tensor):
result[k] = v.cpu().item()
return eval_loop_results
return convert_tensors_to_scalars(eval_loop_results)
def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
self.reset_predict_dataloader(self.lightning_module)

View File

@ -46,6 +46,9 @@ def test_restarting_loops_recursive():
def advance(self, *args, **kwargs):
pass
def run(self):
pass
loop = MyLoop(MyLoop(MyLoop()))
assert not loop.restarting
@ -68,14 +71,21 @@ def test_loop_restore():
self.iteration_count = 0
self.dataset = dataset
@property
def skip(self) -> bool:
return False
@property
def done(self) -> bool:
return self.iteration_count > len(self.dataset)
def run(self):
self.reset()
while not self.done:
try:
self.advance()
self.on_advance_end()
self._restarting = False
except StopIteration:
break
self._restarting = False
def reset(self) -> None:
self.iter_dataset = iter(self.dataset)
if self.restarting:
@ -135,6 +145,16 @@ def test_loop_hierarchy():
self.a = a
self.progress = SimpleProgress()
def run(self):
while not self.done:
try:
self.advance()
self.on_advance_end()
self._restarting = False
except StopIteration:
break
self._restarting = False
def advance(self, *args: Any, **kwargs: Any) -> None:
loop = getattr(self, "loop_child", None)
if not loop: