Loop flattening: remove the default `.run()` implementation (#16427)
parent f031f1e453
commit fd9a3803b8
@@ -90,6 +90,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Removed `Loop.replace()` ([#16361](https://github.com/Lightning-AI/lightning/pull/16361))
 * Removed `Loop.connect()` ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
 * Removed the `trainer.{fit,validate,test,predict}_loop` properties ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
+* Removed the default `Loop.run()` implementation ([#16427](https://github.com/Lightning-AI/lightning/pull/16427))
 - Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172))
 * Removed the `LightningModule.truncated_bptt_steps` attribute
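
Because `Loop.run()` no longer has a default implementation (it becomes an abstract method later in this diff), anything that subclasses `Loop` has to spell out its own iteration. The following is a minimal sketch of that requirement; `CountingLoop`, its `max_steps` argument, and the integer return value are illustrative only, not part of the Lightning API.

```python
from pytorch_lightning.loops.loop import Loop


class CountingLoop(Loop):
    """Hypothetical loop used only to show that ``run()`` must now be written by the subclass."""

    def __init__(self, max_steps: int) -> None:
        super().__init__()
        self.max_steps = max_steps
        self.count = 0

    @property
    def done(self) -> bool:
        return self.count >= self.max_steps

    def reset(self) -> None:
        self.count = 0

    def advance(self) -> None:
        self.count += 1

    def run(self) -> int:
        # Mirrors the while/try/StopIteration shape the built-in loops inline below.
        self.reset()
        while not self.done:
            try:
                self.advance()
                self._restarting = False
            except StopIteration:
                break
        self._restarting = False
        return self.count


loop = CountingLoop(max_steps=3)
assert loop.run() == 3
```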

@@ -13,7 +13,7 @@
 # limitations under the License.

 from abc import abstractmethod
-from typing import Any, Sequence
+from typing import Sequence

 from torch.utils.data import DataLoader

@@ -60,7 +60,7 @@ class DataLoaderLoop(Loop):
         else:
             self.dataloader_progress.reset_on_restart()

-    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
+    def on_advance_start(self) -> None:
         self.dataloader_progress.increment_ready()

     def on_advance_end(self) -> None:

@@ -89,6 +89,22 @@ class EvaluationLoop(DataLoaderLoop):
         max_batches = self._get_max_batches()
         return sum(max_batches) == 0
+
+    def run(self) -> List[_OUT_DICT]:
+        if self.skip:
+            return []
+        self.reset()
+        self.on_run_start()
+        while not self.done:
+            try:
+                self.on_advance_start()
+                self.advance()
+                self.on_advance_end()
+                self._restarting = False
+            except StopIteration:
+                break
+        self._restarting = False
+        return self.on_run_end()

     def reset(self) -> None:
         """Resets the internal state of the loop."""
         self._max_batches = self._get_max_batches()

@@ -105,10 +121,7 @@ class EvaluationLoop(DataLoaderLoop):
         if self.done and self.trainer.state.fn != TrainerFn.FITTING:
             self.dataloader_progress.reset_on_run()

-    def on_skip(self) -> List:
-        return []
-
-    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
+    def on_run_start(self) -> None:
         """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
         hooks."""
         data_fetcher_cls = _select_data_fetcher_type(self.trainer)

@@ -120,7 +133,7 @@ class EvaluationLoop(DataLoaderLoop):
         self._on_evaluation_start()
         self._on_evaluation_epoch_start()

-    def advance(self, *args: Any, **kwargs: Any) -> None:
+    def advance(self) -> None:
         """Performs evaluation on one single dataloader."""
         dataloader_idx = self.current_dataloader_idx
         dataloader = self.current_dataloader

@@ -16,7 +16,7 @@ class PredictionLoop(DataLoaderLoop):
     def __init__(self) -> None:
         super().__init__()
         self.predictions: List[List[Any]] = []
-        self.epoch_batch_indices: List[List[int]] = []
+        self.epoch_batch_indices: List[List[List[int]]] = []  # used by PredictionWriter
         self.epoch_loop = PredictionEpochLoop()

         self._results = None  # for `trainer._results` access
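
The widened annotation spells out the nesting that the prediction writer callback relies on: one list per dataloader, one list per batch, and the sample indices inside each batch. A small illustration with made-up values:

```python
# Illustrative shape only: two dataloaders, two batches each.
epoch_batch_indices = [
    [[0, 1, 2, 3], [4, 5, 6, 7]],  # dataloader 0: indices of batch 0 and batch 1
    [[0, 1], [2, 3]],              # dataloader 1: indices of batch 0 and batch 1
]
```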

@@ -66,6 +66,22 @@ class PredictionLoop(DataLoaderLoop):
     def skip(self) -> bool:
         return sum(self.max_batches) == 0
+
+    def run(self) -> Optional[_PREDICT_OUTPUT]:
+        if self.skip:
+            return None
+        self.reset()
+        self.on_run_start()
+        while not self.done:
+            try:
+                self.on_advance_start()
+                self.advance()
+                self.on_advance_end()
+                self._restarting = False
+            except StopIteration:
+                break
+        self._restarting = False
+        return self.on_run_end()

     def reset(self) -> None:
         """Resets the internal state of the loop for a new run."""
         self.predictions = []

@@ -84,7 +100,7 @@ class PredictionLoop(DataLoaderLoop):
         self._on_predict_start()
         self._on_predict_epoch_start()

-    def advance(self, *args: Any, **kwargs: Any) -> None:
+    def advance(self) -> None:
         """Predicts one entire dataloader."""
         dataloader = self.current_dataloader
         if dataloader is not None:

@@ -14,7 +14,7 @@

 from collections import OrderedDict
 from functools import lru_cache
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union

 from torch.utils.data import DataLoader

@@ -45,7 +45,7 @@ class EvaluationEpochLoop(Loop):
         self.batch_progress = BatchProgress()

         self._outputs: EPOCH_OUTPUT = []
-        self._dl_max_batches = 0
+        self._dl_max_batches: Union[int, float] = 0
         self._data_fetcher: Optional[AbstractDataFetcher] = None
         self._dataloader_state_dict: Dict[str, Any] = {}
         self._dl_batch_idx = [0]
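
The `Union[int, float]` annotation reflects that the batch budget is not always an integer: with an iterable-style dataloader of unknown length, the limit can resolve to `float("inf")`. The comparison used by `done` still behaves sensibly in that case, as this quick sketch shows:

```python
# Sketch: an "infinite" budget never finishes by count alone; such a loop ends only
# when the data fetcher raises StopIteration.
dl_max_batches = float("inf")
completed = 10
print(completed >= dl_max_batches)  # False

dl_max_batches = 10
print(completed >= dl_max_batches)  # True
```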

@@ -55,6 +55,20 @@ class EvaluationEpochLoop(Loop):
         """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
         return self.batch_progress.current.completed >= self._dl_max_batches
+
+    def run(
+        self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict
+    ) -> EPOCH_OUTPUT:
+        self.reset()
+        self.on_run_start(data_fetcher, dl_max_batches, kwargs)
+        while not self.done:
+            try:
+                self.advance(data_fetcher, kwargs)
+                self._restarting = False
+            except StopIteration:
+                break
+        self._restarting = False
+        return self.on_run_end()

     def reset(self) -> None:
         """Resets the loop's internal state."""
         self._dl_max_batches = 0

@@ -70,7 +84,9 @@ class EvaluationEpochLoop(Loop):
         if self.done and self.trainer.state.fn != TrainerFn.FITTING:
             self.batch_progress.reset_on_run()

-    def on_run_start(self, data_fetcher: AbstractDataFetcher, dl_max_batches: int, kwargs: OrderedDict) -> None:
+    def on_run_start(
+        self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict
+    ) -> None:
         """Adds the passed arguments to the loop's state if necessary.

         Args:

@@ -103,14 +119,12 @@ class EvaluationEpochLoop(Loop):
     def advance(
         self,
         data_fetcher: AbstractDataFetcher,
-        dl_max_batches: int,
         kwargs: OrderedDict,
     ) -> None:
         """Calls the evaluation step with the corresponding hooks and updates the logger connector.

         Args:
             data_fetcher: iterator over the dataloader
-            dl_max_batches: maximum number of batches the dataloader can produce
             kwargs: the kwargs passed down to the hooks.

         Raises:

@@ -1,5 +1,5 @@
 from collections import OrderedDict
-from typing import Any, Dict, Iterator, List, Tuple
+from typing import Any, Dict, Iterator, List, Tuple, Union

 import torch

@@ -22,7 +22,7 @@ class PredictionEpochLoop(Loop):
         self.current_batch_indices: List[int] = []
         self.batch_progress = Progress()

-        self._dl_max_batches = 0
+        self._dl_max_batches: Union[int, float] = 0
         self._num_dataloaders = 0
         self._warning_cache = WarningCache()
         self._seen_batch_indices: List[List[int]] = []

@@ -38,6 +38,24 @@ class PredictionEpochLoop(Loop):
         any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks)
         return self.return_predictions or any_pred
+
+    def run(
+        self,
+        dataloader_iter: Iterator,
+        dataloader_idx: int,
+        dl_max_batches: Union[int, float],
+        num_dataloaders: int,
+    ) -> Tuple[List[Any], List[List[int]]]:
+        self.reset()
+        self.on_run_start(dataloader_idx, dl_max_batches, num_dataloaders)
+        while not self.done:
+            try:
+                self.advance(dataloader_iter, dataloader_idx)
+                self._restarting = False
+            except StopIteration:
+                break
+        self._restarting = False
+        return self.on_run_end()

     def reset(self) -> None:
         """Resets the loops internal state."""
         self._seen_batch_indices = []

@@ -46,15 +64,13 @@ class PredictionEpochLoop(Loop):

     def on_run_start(
         self,
-        dataloader_iter: Iterator,
         dataloader_idx: int,
-        dl_max_batches: int,
+        dl_max_batches: Union[int, float],
         num_dataloaders: int,
     ) -> None:
         """Prepares the loops internal state.

         Args:
-            dataloader_iter: the iterator over the current dataloader
             dataloader_idx: the index of the current dataloader
             dl_max_batches: the maximum number of batches the current loader can produce
             num_dataloaders: the total number of dataloaders

@@ -68,16 +84,12 @@ class PredictionEpochLoop(Loop):
         self,
         dataloader_iter: Iterator,
         dataloader_idx: int,
-        dl_max_batches: int,
-        num_dataloaders: int,
     ) -> None:
         """Runs one prediction step.

         Args:
             dataloader_iter: the iterator over the current dataloader
             dataloader_idx: the index of the current dataloader
-            dl_max_batches: the maximum number of batches the current loader can produce
-            num_dataloaders: the total number of dataloaders
         """
         action_name = f"[{self.__class__.__name__}].predict_dataloader_idx_{dataloader_idx}_next"
         with self.trainer.profiler.profile(action_name):
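
Taken together, `PredictionEpochLoop.run` now receives its inputs explicitly and returns the collected predictions together with the batch indices it saw. A hedged sketch of a call site (the real wiring presumably lives in `PredictionLoop.advance`; the helper below and its arguments are hypothetical):

```python
from typing import Any, List, Tuple


def drive_one_dataloader(epoch_loop, dataloader) -> Tuple[List[Any], List[List[int]]]:
    """Hypothetical helper showing how the flattened signature is called."""
    predictions, batch_indices = epoch_loop.run(
        iter(dataloader),  # dataloader_iter
        0,                 # dataloader_idx
        len(dataloader),   # dl_max_batches (may be float("inf") for iterable-style data)
        1,                 # num_dataloaders
    )
    return predictions, batch_indices
```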

@@ -39,7 +39,7 @@ _BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_
 _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]


-class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
+class TrainingEpochLoop(loops.Loop):
     """Runs over all batches in a dataloader (one epoch).

     Args:

@@ -121,6 +121,19 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):

         return False
+
+    def run(self, data_fetcher: AbstractDataFetcher) -> _OUTPUTS_TYPE:
+        self.reset()
+        self.on_run_start(data_fetcher)
+        while not self.done:
+            try:
+                self.advance(data_fetcher)
+                self.on_advance_end()
+                self._restarting = False
+            except StopIteration:
+                break
+        self._restarting = False
+        return self.on_run_end()

     def reset(self) -> None:
         """Resets the internal state of the loop for a new run."""
         if self.restarting:

@@ -31,7 +31,7 @@ from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signatu
 log = logging.getLogger(__name__)


-class FitLoop(Loop[None]):
+class FitLoop(Loop):
     """This Loop iterates over the epochs to run the training.

     Args:

@@ -169,6 +169,22 @@ class FitLoop(Loop[None]):
         # until `on_run_start`, we use `limit_train_batches` instead
         return self.done or self.trainer.limit_train_batches == 0
+
+    def run(self) -> None:
+        if self.skip:
+            return
+        self.reset()
+        self.on_run_start()
+        while not self.done:
+            try:
+                self.on_advance_start()
+                self.advance()
+                self.on_advance_end()
+                self._restarting = False
+            except StopIteration:
+                break
+        self._restarting = False
+        self.on_run_end()

     def reset(self) -> None:
         """Resets the internal state of this loop."""
         if self.restarting:
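
`FitLoop.run` short-circuits on `skip`, which is true when `limit_train_batches == 0` (or the loop is already done). A minimal sketch, assuming an otherwise default-configured `Trainer`:

```python
import pytorch_lightning as pl

# With limit_train_batches=0 the fit loop's `skip` property is True, so run() returns immediately
# and no training batches are processed.
trainer = pl.Trainer(limit_train_batches=0, max_epochs=1, logger=False, enable_checkpointing=False)
```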

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, Optional, TypeVar
+from typing import Any, Dict, Optional

 from torchmetrics import Metric

@@ -21,28 +21,13 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import _Result
 from pytorch_lightning.trainer.progress import BaseProgress
 from pytorch_lightning.utilities.imports import _fault_tolerant_training

-T = TypeVar("T")  # the output type of `run`
-
-
-class Loop(ABC, Generic[T]):
+class Loop(ABC):
     """Basic Loops interface. All classes derived from this must implement the following properties and methods:

-    * :attr:`done` (property): Condition to break the loop
-    * :attr:`reset` (method): Resets the internal state between multiple calls of :attr:`run`
-    * :attr:`advance` (method): Implements one step of the loop
-
-    This class implements the following loop structure:
-
-    .. code-block:: python
-
-        on_run_start()
-
-        while not done:
-            on_advance_start()
-            advance()
-            on_advance_end()
-
-        on_run_end()
+    * :attr:`done` (property): Condition to break the loop
+    * :attr:`reset` (method): Resets the internal state between multiple calls of :attr:`run`
+    * :attr:`advance` (method): Implements one step of the loop
     """

     def __init__(self) -> None:

@@ -100,58 +85,9 @@ class Loop(ABC, Generic[T]):
         """
         return False

-    def on_skip(self) -> T:
-        """The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`.
-
-        Returns:
-            the default output value of :meth:`on_run_end`
-        """
-
-    def run(self, *args: Any, **kwargs: Any) -> T:
-        """The main entry point to the loop.
-
-        Will frequently check the :attr:`done` condition and calls :attr:`advance`
-        until :attr:`done` evaluates to ``True``.
-
-        Override this if you wish to change the default behavior. The default implementation is:
-
-        Example::
-
-            def run(self, *args, **kwargs):
-                if self.skip:
-                    return self.on_skip()
-
-                self.reset()
-                self.on_run_start(*args, **kwargs)
-
-                while not self.done:
-                    self.advance(*args, **kwargs)
-
-                output = self.on_run_end()
-                return output
-
-        Returns:
-            The output of :attr:`on_run_end` (often outputs collected from each step of the loop)
-        """
-        if self.skip:
-            return self.on_skip()
-
-        self.reset()
-
-        self.on_run_start(*args, **kwargs)
-
-        while not self.done:
-            try:
-                self.on_advance_start(*args, **kwargs)
-                self.advance(*args, **kwargs)
-                self.on_advance_end()
-                self._restarting = False
-            except StopIteration:
-                break
-        self._restarting = False
-
-        output = self.on_run_end()
-        return output
+    @abstractmethod
+    def run(self, *args: Any, **kwargs: Any) -> Any:
+        """The main entry point to the loop."""

     @abstractmethod
     def reset(self) -> None:
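
Since `run` is now declared with `@abstractmethod`, a subclass that forgets to define it cannot even be instantiated. A small sketch (the class name is illustrative):

```python
from pytorch_lightning.loops.loop import Loop


class IncompleteLoop(Loop):
    """Defines done/reset/advance but not run(), which is now abstract."""

    @property
    def done(self) -> bool:
        return True

    def reset(self) -> None:
        pass

    def advance(self) -> None:
        pass


# IncompleteLoop()  # raises TypeError because the abstract method `run` is missing
```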

@@ -195,7 +131,7 @@ class Loop(ABC, Generic[T]):
     def on_advance_end(self) -> None:
         """Hook to be called each time after :attr:`advance` is called."""

-    def on_run_end(self) -> T:
+    def on_run_end(self) -> Any:
         """Hook to be called at the end of the run.

         Its return argument is returned from :attr:`run`.
|
@ -64,7 +64,7 @@ class ManualResult(OutputResult):
|
|||
_OUTPUTS_TYPE = Dict[str, Any]
|
||||
|
||||
|
||||
class ManualOptimization(Loop[_OUTPUTS_TYPE]):
|
||||
class ManualOptimization(Loop):
|
||||
"""A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens
|
||||
entirely in the :meth:`~pytorch_lightning.core.module.LightningModule.training_step` and therefore the user is
|
||||
responsible for back-propagating gradients and making calls to the optimizers.
|
||||
|
@ -88,10 +88,22 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
|
|||
def done(self) -> bool:
|
||||
return self._done
|
||||
|
||||
def run(self, kwargs: OrderedDict) -> _OUTPUTS_TYPE:
|
||||
self.reset()
|
||||
self.on_run_start()
|
||||
while not self.done:
|
||||
try:
|
||||
self.advance(kwargs)
|
||||
self._restarting = False
|
||||
except StopIteration:
|
||||
break
|
||||
self._restarting = False
|
||||
return self.on_run_end()
|
||||
|
||||
def reset(self) -> None:
|
||||
self._done = False
|
||||
|
||||
def on_run_start(self, *_: Any, **__: Any) -> None:
|
||||
def on_run_start(self) -> None:
|
||||
# inject logic around the optimizer step
|
||||
for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
|
||||
lightning_optimizer._on_before_step = self._on_before_step
|
||||
|
|
|

@@ -146,7 +146,7 @@ class Closure(AbstractClosure[ClosureResult]):
 _OUTPUTS_TYPE = Dict[int, Dict[str, Any]]


-class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
+class OptimizerLoop(Loop):
     """Runs over a sequence of optimizers.

     This loop implements what is known in Lightning as Automatic Optimization.

@@ -172,6 +172,18 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
         """Returns ``True`` when the last optimizer in the sequence has run."""
         return self.optim_progress.optimizer_position >= len(self._indices)
+
+    def run(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> _OUTPUTS_TYPE:
+        self.reset()
+        self.on_run_start(optimizers)
+        while not self.done:
+            try:
+                self.advance(kwargs)
+                self._restarting = False
+            except StopIteration:
+                break
+        self._restarting = False
+        return self.on_run_end()

     def reset(self) -> None:
         if not self.restarting:
             # when reset() is called from outside (manually), we reset the loop progress

@@ -180,12 +192,12 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             self.optim_progress.reset_on_restart()
         self._outputs = {}

-    def on_run_start(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> None:
+    def on_run_start(self, optimizers: List[Tuple[int, Optimizer]]) -> None:
         self._indices, self._optimizers = zip(*optimizers)
         if self.done:
             self.optim_progress.optimizer_position = 0

-    def advance(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> None:
+    def advance(self, kwargs: OrderedDict) -> None:
         kwargs = self._build_kwargs(kwargs, self.optimizer_idx)

         result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
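
The signature changes make the data flow explicit: `run()` receives both the `(position, optimizer)` pairs and the step kwargs, then forwards the optimizers only to `on_run_start` and the kwargs only to `advance`. A hypothetical call-site sketch (the wrapper function and its arguments are illustrative, not Lightning API):

```python
from collections import OrderedDict
from typing import Any, Dict, List, Tuple

from torch.optim import Optimizer


def run_automatic_optimization(optimizer_loop, optimizer: Optimizer, batch: Any, batch_idx: int) -> Dict[int, Dict[str, Any]]:
    """Hypothetical wrapper around OptimizerLoop.run() with the flattened signature."""
    optimizers: List[Tuple[int, Optimizer]] = [(0, optimizer)]  # (position, optimizer) pairs
    kwargs = OrderedDict(batch=batch, batch_idx=batch_idx)      # later extended by _build_kwargs
    return optimizer_loop.run(optimizers, kwargs)
```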

@@ -43,6 +43,7 @@ from torch.utils.data import DataLoader
 from typing_extensions import Literal

 import pytorch_lightning as pl
+from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars
 from lightning_fabric.utilities.cloud_io import get_filesystem
 from lightning_fabric.utilities.data import _auto_add_worker_init_fn
 from lightning_fabric.utilities.types import _PATH

@@ -1113,13 +1114,7 @@ class Trainer:
         eval_loop_results = self._evaluation_loop.run()

-        # remove the tensors from the eval results
-        for result in eval_loop_results:
-            if isinstance(result, dict):
-                for k, v in result.items():
-                    if isinstance(v, Tensor):
-                        result[k] = v.cpu().item()
-
-        return eval_loop_results
+        return convert_tensors_to_scalars(eval_loop_results)

     def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
         self.reset_predict_dataloader(self.lightning_module)
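
The hand-written tensor-to-scalar conversion is replaced with `convert_tensors_to_scalars` from `lightning_fabric`, which walks the result collection and calls `.item()` on every tensor it finds. A small usage sketch with made-up values:

```python
import torch

from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars

eval_loop_results = [{"val_loss": torch.tensor(0.25), "val_acc": torch.tensor(0.75)}]
print(convert_tensors_to_scalars(eval_loop_results))  # [{'val_loss': 0.25, 'val_acc': 0.75}]
```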

@@ -46,6 +46,9 @@ def test_restarting_loops_recursive():
         def advance(self, *args, **kwargs):
             pass
+
+        def run(self):
+            pass

     loop = MyLoop(MyLoop(MyLoop()))

     assert not loop.restarting

@@ -68,14 +71,21 @@ def test_loop_restore():
             self.iteration_count = 0
             self.dataset = dataset

         @property
         def skip(self) -> bool:
             return False

         @property
         def done(self) -> bool:
             return self.iteration_count > len(self.dataset)
+
+        def run(self):
+            self.reset()
+            while not self.done:
+                try:
+                    self.advance()
+                    self.on_advance_end()
+                    self._restarting = False
+                except StopIteration:
+                    break
+            self._restarting = False

         def reset(self) -> None:
             self.iter_dataset = iter(self.dataset)
             if self.restarting:

@@ -135,6 +145,16 @@ def test_loop_hierarchy():
             self.a = a
             self.progress = SimpleProgress()
+
+        def run(self):
+            while not self.done:
+                try:
+                    self.advance()
+                    self.on_advance_end()
+                    self._restarting = False
+                except StopIteration:
+                    break
+            self._restarting = False

         def advance(self, *args: Any, **kwargs: Any) -> None:
             loop = getattr(self, "loop_child", None)
             if not loop: