diff --git a/pl_examples/loop_examples/yielding_training_step.py b/pl_examples/loop_examples/yielding_training_step.py
index e787c8bd98..52abf768fe 100644
--- a/pl_examples/loop_examples/yielding_training_step.py
+++ b/pl_examples/loop_examples/yielding_training_step.py
@@ -22,7 +22,6 @@ from pl_examples.domain_templates.generative_adversarial_net import MNISTDataMod
 from pytorch_lightning import Trainer
 from pytorch_lightning.loops import OptimizerLoop
 from pytorch_lightning.loops.optimization.optimizer_loop import ClosureResult
-from pytorch_lightning.loops.utilities import _build_training_step_kwargs
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 #############################################################################################
@@ -56,28 +55,25 @@ class YieldLoop(OptimizerLoop):
     def connect(self, **kwargs):
         raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
 
-    def on_run_start(self, batch, optimizers, batch_idx):
-        super().on_run_start(batch, optimizers, batch_idx)
+    def on_run_start(self, optimizers, kwargs):
+        super().on_run_start(optimizers, kwargs)
         if not inspect.isgeneratorfunction(self.trainer.lightning_module.training_step):
             raise MisconfigurationException("The `LightningModule` does not yield anything in the `training_step`.")
         assert self.trainer.lightning_module.automatic_optimization
 
-        # We request the generator once and save it for later
-        # so we can call next() on it.
-        self._generator = self._get_generator(batch, batch_idx, opt_idx=0)
+        # We request the generator once and save it for later so we can call next() on it.
+        self._generator = self._get_generator(kwargs)
 
-    def _make_step_fn(self, split_batch, batch_idx, opt_idx):
+    def _make_step_fn(self, *_):
         return partial(self._training_step, self._generator)
 
-    def _get_generator(self, split_batch, batch_idx, opt_idx):
-        step_kwargs = _build_training_step_kwargs(
-            self.trainer.lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, hiddens=None
-        )
+    def _get_generator(self, kwargs, opt_idx=0):
+        kwargs = self._build_kwargs(kwargs, opt_idx, hiddens=None)
 
         # Here we are basically calling `lightning_module.training_step()`
-        # and this returns a generator! The `training_step` is handled by the
-        # accelerator to enable distributed training.
-        return self.trainer.strategy.training_step(*step_kwargs.values())
+        # and this returns a generator! The `training_step` is handled by
+        # the accelerator to enable distributed training.
+        return self.trainer.strategy.training_step(*kwargs.values())
 
     def _training_step(self, generator):
         # required for logging
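For context, `YieldLoop` relies on `training_step` being a generator that yields one loss per optimizer: `on_run_start` creates the generator once, and each step function resumes it with `next()`. A minimal sketch of such a module (the class name and placeholder losses are illustrative, not part of this diff):

```python
import torch
from pytorch_lightning import LightningModule


class YieldingGAN(LightningModule):
    """Hypothetical module whose training_step yields one loss per optimizer."""

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        # resumed via next() for optimizer_idx=0
        g_loss = torch.tensor(0.0, requires_grad=True)  # placeholder generator loss
        yield g_loss
        # resumed again via next() for optimizer_idx=1
        d_loss = torch.tensor(0.0, requires_grad=True)  # placeholder discriminator loss
        yield d_loss
```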
diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py
index c01df94378..26ef742ee1 100644
--- a/pytorch_lightning/callbacks/finetuning.py
+++ b/pytorch_lightning/callbacks/finetuning.py
@@ -279,7 +279,7 @@ class BaseFinetuning(Callback):
         # import is here to avoid circular imports
         from pytorch_lightning.loops.utilities import _get_active_optimizers
 
-        for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies):
+        for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies, 0):
            num_param_groups = len(optimizer.param_groups)
             self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
             current_param_groups = optimizer.param_groups
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index d88387dfeb..0198e57a21 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, OrderedDict, Tuple, Union
 
-from deprecate import void
 from torch import Tensor
 
 from pytorch_lightning.loops.base import Loop
@@ -59,35 +58,35 @@ class TrainingBatchLoop(Loop[_OUTPUTS_TYPE]):
         """Resets the loop state."""
         self._outputs = []
 
-    def on_run_start(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
+    def on_run_start(self, kwargs: OrderedDict) -> None:  # type: ignore[override]
         """Splits the data into tbptt splits.
 
         Args:
-            batch: the current batch to run the trainstep on
-            batch_idx: the index of the current batch
+            kwargs: the kwargs passed down to the hooks.
         """
-        void(batch_idx)
+        batch = kwargs["batch"]
         self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))
 
-    def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
+    def advance(self, kwargs: OrderedDict) -> None:  # type: ignore[override]
         """Runs the train step together with optimization (if necessary) on the current batch split.
 
         Args:
-            batch: the current batch to run the training on (this is not the split!)
-            batch_idx: the index of the current batch
+            kwargs: the kwargs passed down to the hooks.
         """
-        void(batch)
-        self.split_idx, split_batch = self._remaining_splits.pop(0)
+        # replace the batch with the split batch
+        self.split_idx, kwargs["batch"] = self._remaining_splits.pop(0)
 
         self.trainer._logger_connector.on_train_split_start(self.split_idx)
 
         outputs: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] = None  # for mypy
         # choose which loop will run the optimization
         if self.trainer.lightning_module.automatic_optimization:
-            optimizers = _get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx)
-            outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
+            optimizers = _get_active_optimizers(
+                self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0)
+            )
+            outputs = self.optimizer_loop.run(optimizers, kwargs)
         else:
-            outputs = self.manual_loop.run(split_batch, batch_idx)
+            outputs = self.manual_loop.run(kwargs)
         if outputs:
             # automatic: can be empty if all optimizers skip their batches
             # manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens,
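Worth spelling out the contract these changes rely on: the `OrderedDict` built upstream is eventually unpacked positionally via `*kwargs.values()`, so insertion order has to mirror the `training_step` signature (`batch`, then `batch_idx`, then `optimizer_idx`/`hiddens`). A toy illustration (names are illustrative, not Lightning APIs):

```python
from collections import OrderedDict


def training_step(batch, batch_idx, optimizer_idx):
    return f"batch={batch}, batch_idx={batch_idx}, optimizer_idx={optimizer_idx}"


kwargs = OrderedDict(batch="x", batch_idx=3)
kwargs["optimizer_idx"] = 0  # appended last, matching the signature order
assert training_step(*kwargs.values()) == "batch=x, batch_idx=3, optimizer_idx=0"
```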
""" - void(batch) - self.split_idx, split_batch = self._remaining_splits.pop(0) + # replace the batch with the split batch + self.split_idx, kwargs["batch"] = self._remaining_splits.pop(0) self.trainer._logger_connector.on_train_split_start(self.split_idx) outputs: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] = None # for mypy # choose which loop will run the optimization if self.trainer.lightning_module.automatic_optimization: - optimizers = _get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx) - outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx) + optimizers = _get_active_optimizers( + self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0) + ) + outputs = self.optimizer_loop.run(optimizers, kwargs) else: - outputs = self.manual_loop.run(split_batch, batch_idx) + outputs = self.manual_loop.run(kwargs) if outputs: # automatic: can be empty if all optimizers skip their batches # manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens, diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index a45766b1eb..8c631bf23f 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -264,7 +264,7 @@ class EvaluationEpochLoop(Loop): self.trainer._logger_connector.on_batch_end() def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> OrderedDict: - """Helper function to build the arguments for the current step. + """Helper method to build the arguments for the current step. Args: kwargs: The kwargs passed down to the hooks. @@ -273,7 +273,8 @@ class EvaluationEpochLoop(Loop): Returns: The kwargs passed down to the hooks. """ - kwargs.update({"batch": batch, "batch_idx": batch_idx}) + kwargs.update(batch=batch, batch_idx=batch_idx) + # `dataloader_idx` should be last so we need to push these to the front kwargs.move_to_end("batch_idx", last=False) kwargs.move_to_end("batch", last=False) return kwargs diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index ec02a099f6..e059446bd4 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
diff --git a/pytorch_lightning/loops/optimization/manual_loop.py b/pytorch_lightning/loops/optimization/manual_loop.py
index 70ada6a239..7d17ed33e2 100644
--- a/pytorch_lightning/loops/optimization/manual_loop.py
+++ b/pytorch_lightning/loops/optimization/manual_loop.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import OrderedDict
 from dataclasses import dataclass, field
 from typing import Any, Dict, Optional
 
@@ -97,30 +98,25 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
             lightning_optimizer._on_before_step = self._on_before_step
             lightning_optimizer._on_after_step = self._on_after_step
 
-    def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
+    def advance(self, kwargs: OrderedDict) -> None:  # type: ignore[override]
         """Performs the training step for manual optimization.
 
         Args:
-            batch: the current tbptt split of the current batch
-            batch_idx: the index of the current batch
+            kwargs: The kwargs passed down to the hooks.
         """
         assert self.trainer is not None
-        lightning_module = self.trainer.lightning_module
 
-        step_kwargs = _build_training_step_kwargs(
-            lightning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=self._hiddens
-        )
+        kwargs = self._build_kwargs(kwargs, self._hiddens)
 
         # manually capture logged metrics
-        training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
+        training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
+        del kwargs  # release the batch from memory
         self.trainer.strategy.post_training_step()
 
-        del step_kwargs
-
         model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
         strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
         training_step_output = strategy_output if model_output is None else model_output
 
-        self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
+        self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps)
 
         result = self.output_result_cls.from_training_step_output(training_step_output)
 
@@ -149,3 +145,17 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
     def _on_after_step(self) -> None:
         self.trainer.profiler.stop("optimizer_step")
         self.optim_step_progress.increment_completed()
+
+    def _build_kwargs(self, kwargs: OrderedDict, hiddens: Optional[Any]) -> OrderedDict:
+        """Helper method to build the arguments for the current step.
+
+        Args:
+            kwargs: The kwargs passed down to the hooks.
+            hiddens: the hidden state of the previous RNN iteration.
+
+        Returns:
+            The kwargs passed down to the hooks.
+        """
+        return _build_training_step_kwargs(
+            kwargs, self.trainer.lightning_module, self.trainer.optimizers, None, hiddens
+        )
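The `hiddens` entry that `_build_kwargs` adds here is whatever `_extract_hiddens` pulled out of the previous tbptt split's output. As a rough, simplified stand-in for that round trip (the real helper also validates the output and detaches recursively):

```python
from typing import Any, Optional

import torch


def extract_hiddens(training_step_output: Any, truncated_bptt_steps: int) -> Optional[Any]:
    """Simplified sketch of Lightning's `_extract_hiddens`, not the actual implementation."""
    if truncated_bptt_steps < 1 or not isinstance(training_step_output, dict):
        return None
    hiddens = training_step_output.get("hiddens")
    # detach so the next split's backward pass does not reach into this split's graph
    return hiddens.detach() if isinstance(hiddens, torch.Tensor) else hiddens
```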
""" assert self.trainer is not None - lightning_module = self.trainer.lightning_module - step_kwargs = _build_training_step_kwargs( - lightning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=self._hiddens - ) + kwargs = self._build_kwargs(kwargs, self._hiddens) # manually capture logged metrics - training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values()) + training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values()) + del kwargs # release the batch from memory self.trainer.strategy.post_training_step() - del step_kwargs - model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output) strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output) training_step_output = strategy_output if model_output is None else model_output - self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps) + self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps) result = self.output_result_cls.from_training_step_output(training_step_output) @@ -149,3 +145,17 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]): def _on_after_step(self) -> None: self.trainer.profiler.stop("optimizer_step") self.optim_step_progress.increment_completed() + + def _build_kwargs(self, kwargs: OrderedDict, hiddens: Optional[Any]) -> OrderedDict: + """Helper method to build the arguments for the current step. + + Args: + kwargs: The kwargs passed down to the hooks. + hiddens: the hidden state of the previous RNN iteration. + + Returns: + The kwargs passed down to the hooks. + """ + return _build_training_step_kwargs( + kwargs, self.trainer.lightning_module, self.trainer.optimizers, None, hiddens + ) diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index 5cd81aa30f..95c072bf28 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, OrderedDict, Tuple, Union import torch from torch import Tensor @@ -164,7 +164,6 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]): self._outputs: _OUTPUTS_TYPE = {} self._skip_backward: bool = False - self._batch_idx: int = 0 self._optimizers: Tuple[Optimizer, ...] = tuple() self._indices: Tuple[int, ...] 
 
-    def _make_closure(self, split_batch: Any, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Closure:
+    def _make_closure(self, kwargs: OrderedDict, optimizer: Optimizer) -> Closure:
         """Build a closure object that captures the given arguments and runs the `training_step` function and
         optionally other functions such as `backward` and `zero_grad`."""
-        step_fn = self._make_step_fn(split_batch, batch_idx, opt_idx)
+        opt_idx = kwargs.get("optimizer_idx", 0)
+        step_fn = self._make_step_fn(kwargs)
         backward_fn = self._make_backward_fn(optimizer, opt_idx)
-        zero_grad_fn = self._make_zero_grad_fn(batch_idx, opt_idx, optimizer)
+        zero_grad_fn = self._make_zero_grad_fn(kwargs.get("batch_idx", 0), opt_idx, optimizer)
 
         return Closure(step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn)
 
-    def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Callable[[], ClosureResult]:
+    def _make_step_fn(self, kwargs: OrderedDict) -> Callable[[], ClosureResult]:
         """Build the step function that runs the `training_step` and processes its output."""
-        return partial(self._training_step, split_batch, batch_idx, opt_idx)
+        return partial(self._training_step, kwargs)
 
     def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]:
         """Build a `zero_grad` function that zeroes the gradients before back-propagation.
@@ -399,33 +394,24 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
         )
         self.optim_progress.optimizer.zero_grad.increment_completed()
 
-    def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
+    def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
         """Performs the actual train step with the tied hooks.
 
         Args:
-            split_batch: the current tbptt split of the current batch
-            batch_idx: the index of the current batch
-            opt_idx: the index of the current optimizer
+            kwargs: the kwargs passed down to the hooks.
 
         Returns:
             A ``ClosureResult`` containing the training step output.
         """
-        # give the PL module a result for logging
-        lightning_module = self.trainer.lightning_module
-
-        step_kwargs = _build_training_step_kwargs(
-            lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
-        )
-
         # manually capture logged metrics
-        training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
+        training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
         self.trainer.strategy.post_training_step()
 
         model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
         strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
         training_step_output = strategy_output if model_output is None else model_output
 
-        self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
+        self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps)
 
         result = self.output_result_cls.from_training_step_output(
             training_step_output, self.trainer.accumulate_grad_batches
@@ -437,3 +423,18 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             self.trainer._results.cpu()
 
         return result
+
+    def _build_kwargs(self, kwargs: OrderedDict, opt_idx: int, hiddens: Optional[Any]) -> OrderedDict:
+        """Helper method to build the arguments for the current step.
+
+        Args:
+            kwargs: The kwargs passed down to the hooks.
+            opt_idx: the index of the current optimizer.
+            hiddens: the hidden state of the previous RNN iteration.
+
+        Returns:
+            The kwargs passed down to the hooks.
+        """
+        return _build_training_step_kwargs(
+            kwargs, self.trainer.lightning_module, self.trainer.optimizers, opt_idx, hiddens
+        )
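The `partial(self._training_step, kwargs)` step function above gets wrapped into a `Closure` together with `backward` and `zero_grad`, producing a single callable that re-entrant optimizers such as LBFGS can invoke repeatedly. A stripped-down sketch of the pattern (not Lightning's actual `Closure` class):

```python
from typing import Callable, Optional


class SimpleClosure:
    """Toy stand-in for the optimizer loop's Closure object."""

    def __init__(
        self,
        step_fn: Callable[[], float],
        zero_grad_fn: Optional[Callable[[], None]] = None,
        backward_fn: Optional[Callable[[float], None]] = None,
    ) -> None:
        self._step_fn = step_fn
        self._zero_grad_fn = zero_grad_fn
        self._backward_fn = backward_fn

    def __call__(self) -> float:
        loss = self._step_fn()  # e.g. partial(self._training_step, kwargs)
        if self._zero_grad_fn is not None:
            self._zero_grad_fn()
        if self._backward_fn is not None:
            self._backward_fn(loss)
        return loss
```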
diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py
index d84c195d75..8b0f98efe6 100644
--- a/pytorch_lightning/loops/utilities.py
+++ b/pytorch_lightning/loops/utilities.py
@@ -106,34 +106,25 @@ def _parse_loop_limits(
 
 
 def _build_training_step_kwargs(
+    kwargs: OrderedDict,
     lightning_module: "pl.LightningModule",
     optimizers: Sequence[Optimizer],
-    batch: Any,
-    batch_idx: int,
     opt_idx: Optional[int],
     hiddens: Optional[Any],
-) -> Dict[str, Any]:
+) -> OrderedDict:
     """Builds the keyword arguments for training_step.
 
     Args:
+        kwargs: The kwargs passed down to the hooks.
         lightning_module: the LightningModule with a `training_step` hook implementation
         optimizers: the list of optimizers from the Trainer
-        batch: the batch to train on
-        batch_idx: the index of the current batch
         opt_idx: the index of the current optimizer
         hiddens: the hidden state of the previous RNN iteration
 
     Returns:
         the keyword arguments for the training step
     """
-    # enable not needing to add opt_idx to training_step
-    step_kwargs = OrderedDict([("batch", batch)])
     training_step_fx = getattr(lightning_module, "training_step")
-
-    if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2):
-        step_kwargs["batch_idx"] = batch_idx
-
     if len(optimizers) > 1:
         has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
         if has_opt_idx_in_train_step:
@@ -143,7 +134,7 @@ def _build_training_step_kwargs(
                 " in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
                 " argument or set `self.automatic_optimization = True`."
             )
-            step_kwargs["optimizer_idx"] = opt_idx
+            kwargs["optimizer_idx"] = opt_idx
         elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
             raise ValueError(
                 f"Your LightningModule defines {len(optimizers)} optimizers but"
@@ -152,9 +143,9 @@ def _build_training_step_kwargs(
 
     # pass hiddens if using tbptt
     if lightning_module.truncated_bptt_steps > 0:
-        step_kwargs["hiddens"] = hiddens
+        kwargs["hiddens"] = hiddens
 
-    return step_kwargs
+    return kwargs
 
 
 @contextmanager
@@ -182,7 +173,7 @@ def _cumulative_optimizer_frequencies(frequencies: Tuple[int]) -> np.ndarray:
 
 
 def _get_active_optimizers(
-    optimizers: List[Optimizer], frequencies: List[int], batch_idx: Optional[int] = None
+    optimizers: List[Optimizer], frequencies: List[int], batch_idx: int
 ) -> List[Tuple[int, Optimizer]]:
     """Returns the currently active optimizers. When multiple optimizers are used with different frequencies, only
     one of the optimizers is active at a time.
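`_get_active_optimizers` (its body is not shown in this diff) cycles through optimizers by frequency, which is why callers with no batch index, such as `BaseFinetuning`, now pass an explicit `0`. Roughly, with `frequencies=[2, 1]`, batches 0-1 use optimizer 0 and batch 2 uses optimizer 1, then the pattern repeats. A sketch of that selection logic under those assumptions:

```python
from typing import Any, List, Sequence, Tuple

import numpy as np


def get_active_optimizers(optimizers: Sequence[Any], frequencies: List[int], batch_idx: int) -> List[Tuple[int, Any]]:
    """Approximation of the selection logic; without frequencies, every optimizer is active."""
    if not frequencies:
        return list(enumerate(optimizers))
    freq_cumsum = np.cumsum(frequencies)
    # the position inside the repeating frequency cycle determines the active optimizer
    opt_idx = int(np.searchsorted(freq_cumsum, batch_idx % freq_cumsum[-1], side="right"))
    return [(opt_idx, optimizers[opt_idx])]


# with frequencies [2, 1]: batches 0-1 -> optimizer 0, batch 2 -> optimizer 1, then repeat
assert [get_active_optimizers(["g", "d"], [2, 1], b)[0][0] for b in range(4)] == [0, 0, 1, 0]
```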
diff --git a/tests/loops/batch/test_truncated_bptt.py b/tests/loops/batch/test_truncated_bptt.py
index 55adbc618b..a43d15909f 100644
--- a/tests/loops/batch/test_truncated_bptt.py
+++ b/tests/loops/batch/test_truncated_bptt.py
@@ -170,3 +170,36 @@ def test_tbptt_logging(tmpdir, model_class):
     )
     trainer.fit(model)
     assert set(trainer.logged_metrics) == {"loss_step", "loss_epoch"}
+
+
+def test_hiddens_multiple_optimizers(tmpdir):
+    class TBPTTModel(LSTMModel):
+        # TODO: `optimizer_idx=n` gets the hiddens from `optimizer_idx=n-1` instead of the hiddens from
+        # `optimizer_idx=n`, `split_idx=m-1`. This is unexpected and should be changed.
+        test_hiddens = None
+
+        def training_step(self, batch, batch_idx, optimizer_idx, hiddens):
+            if hiddens is None:
+                assert self.test_hiddens is None
+            else:
+                assert all(torch.equal(h, th) for h, th in zip(hiddens, self.test_hiddens))
+            out = super().training_step(batch, batch_idx, hiddens)
+            self.test_hiddens = out["hiddens"]
+            return out
+
+        def configure_optimizers(self):
+            return [super().configure_optimizers(), super().configure_optimizers()]
+
+    model = TBPTTModel(truncated_bptt_steps=2, input_size=1, hidden_size=1)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=1,
+        limit_val_batches=0,
+        enable_model_summary=False,
+        logger=False,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+    )
+    trainer.fit(model)
+    assert trainer.global_step == 8 / 2 * 2  # time_dim_length / tbptt_steps * num_optimizers
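The `global_step` assertion in the new test follows from the split arithmetic; assuming the LSTM fixture produces batches with a time dimension of 8, truncating at 2 yields 4 splits, and each split steps both optimizers:

```python
time_dim_length = 8  # sequence length of the test batch (assumed from the fixture)
tbptt_steps = 2      # truncated_bptt_steps passed to TBPTTModel
num_optimizers = 2   # configure_optimizers returns two optimizers

splits_per_batch = time_dim_length // tbptt_steps  # 4 tbptt splits
assert splits_per_batch * num_optimizers == 8      # one optimizer step per (split, optimizer) pair
```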
diff --git a/tests/loops/epoch/test_training_epoch_loop.py b/tests/loops/epoch/test_training_epoch_loop.py
index ed3a853644..d6f6a906cc 100644
--- a/tests/loops/epoch/test_training_epoch_loop.py
+++ b/tests/loops/epoch/test_training_epoch_loop.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from unittest import mock
-from unittest.mock import Mock, patch
+from unittest.mock import patch
 
 import pytest
 
+from pytorch_lightning import LightningModule
 from pytorch_lightning.loops import TrainingEpochLoop
 from pytorch_lightning.trainer.trainer import Trainer
 from tests.deprecated_api import no_deprecated_call
@@ -33,7 +34,8 @@ _out13 = {"loss": 1.3}
 
 class TestPrepareOutputs:
     def prepare_outputs(self, fn, tbptt_splits, new_format, batch_outputs, num_optimizers, automatic_optimization):
-        lightning_module = Mock()
+        lightning_module = LightningModule()
+        lightning_module.on_train_batch_end = lambda *_: None  # override to trigger the deprecation message
         lightning_module.automatic_optimization = automatic_optimization
         lightning_module.truncated_bptt_steps = tbptt_splits
         match = "will change in version v1.8.*new_format=True"
diff --git a/tests/loops/test_evaluation_loop_flow.py b/tests/loops/test_evaluation_loop_flow.py
index 0fe90557b3..20f966e6c3 100644
--- a/tests/loops/test_evaluation_loop_flow.py
+++ b/tests/loops/test_evaluation_loop_flow.py
@@ -63,8 +63,8 @@ def test__eval_step__flow(tmpdir):
 
     # simulate training manually
     trainer.state.stage = RunningStage.TRAINING
-    batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
+    kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
 
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
@@ -72,9 +72,7 @@ def test__eval_step__flow(tmpdir):
     assert train_step_out["loss"].item() == 171
 
     # make sure the optimizer closure returns the correct things
-    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
-        batch, batch_idx, 0, trainer.optimizers[0]
-    )
+    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
     opt_closure_result = opt_closure()
     assert opt_closure_result.item() == 171
 
@@ -126,8 +124,8 @@ def test__eval_step__eval_step_end__flow(tmpdir):
     trainer.state.stage = RunningStage.TRAINING
 
     # make sure training outputs what is expected
-    batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
+    kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
 
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
@@ -135,9 +133,7 @@ def test__eval_step__eval_step_end__flow(tmpdir):
     assert train_step_out["loss"].item() == 171
 
     # make sure the optimizer closure returns the correct things
-    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
-        batch, batch_idx, 0, trainer.optimizers[0]
-    )
+    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
     opt_closure_result = opt_closure()
     assert opt_closure_result.item() == 171
diff --git a/tests/loops/test_training_loop_flow_scalar.py b/tests/loops/test_training_loop_flow_scalar.py
index 8493de4db0..29e3d3b3a0 100644
--- a/tests/loops/test_training_loop_flow_scalar.py
+++ b/tests/loops/test_training_loop_flow_scalar.py
@@ -146,8 +146,8 @@ def test__training_step__epoch_end__flow_scalar(tmpdir):
     trainer.state.stage = RunningStage.TRAINING
 
     # make sure training outputs what is expected
-    batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
+    kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
 
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
@@ -155,9 +155,7 @@ def test__training_step__epoch_end__flow_scalar(tmpdir):
     assert train_step_out["loss"].item() == 171
 
     # make sure the optimizer closure returns the correct things
-    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
-        batch, batch_idx, 0, trainer.optimizers[0]
-    )
+    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
     opt_closure_result = opt_closure()
     assert opt_closure_result.item() == 171
 
@@ -218,8 +216,8 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):
     trainer.state.stage = RunningStage.TRAINING
 
     # make sure training outputs what is expected
-    batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
+    kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
 
     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
@@ -227,9 +225,7 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):
     assert train_step_out["loss"].item() == 171
 
     # make sure the optimizer closure returns the correct things
-    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
-        batch, batch_idx, 0, trainer.optimizers[0]
-    )
+    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
     opt_closure_result = opt_closure()
     assert opt_closure_result.item() == 171
 
@@ -239,7 +235,7 @@ def test_train_step_no_return(tmpdir):
     automatic_optimization."""
 
     class TestModel(BoringModel):
-        def training_step(self, batch, batch_idx):
+        def training_step(self, batch):
             self.training_step_called = True
             loss = self.step(batch[0])
             self.log("a", loss, on_step=True, on_epoch=True)
@@ -305,7 +301,7 @@ def test_training_step_no_return_when_even(tmpdir):
 
     # manually check a few batches
     for batch_idx, batch in enumerate(model.train_dataloader()):
-        out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
+        out = trainer.fit_loop.epoch_loop.batch_loop.run({"batch": batch, "batch_idx": batch_idx})
         if not batch_idx % 2:
             assert out == []
diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py
index 0cf05e47a0..e985dff214 100644
--- a/tests/trainer/logging_/test_train_loop_logging.py
+++ b/tests/trainer/logging_/test_train_loop_logging.py
@@ -141,7 +141,7 @@ def test__training_step__step_end__epoch_end__log(tmpdir, batches, log_interval,
     """Tests that training_step_end and training_epoch_end can log."""
 
     class TestModel(BoringModel):
-        def training_step(self, batch, batch_idx):
+        def training_step(self, batch):
             loss = self.step(batch[0])
             self.log("a", loss, on_step=True, on_epoch=True)
             return loss
diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py
index 010e2e5714..843462ad22 100644
--- a/tests/utilities/test_fetching.py
+++ b/tests/utilities/test_fetching.py
@@ -375,9 +375,10 @@ def test_stop_iteration(trigger_stop_iteration, tmpdir):
                 super().__init__()
                 self.trigger_stop_iteration = trigger_stop_iteration
 
-            def training_step(self, dataloader_iter: Iterator, *args) -> STEP_OUTPUT:
+            def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
                 output = super().training_step(dataloader_iter)
-                if self.trigger_stop_iteration and args[0] == EXPECT_NUM_BATCHES_PROCESSED:
+                batch_idx = self.trainer.fit_loop.epoch_loop.batch_idx
+                if self.trigger_stop_iteration and batch_idx == EXPECT_NUM_BATCHES_PROCESSED:
                     raise StopIteration
                 return output
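Taken together, the test updates show the user-facing upshot of this refactor: `batch_idx` is now genuinely optional in `training_step`, and `dataloader_iter`-style steps, which have no positional index at all, can read the running index from the loop. A hedged sketch of both signatures (the models and placeholder losses are illustrative):

```python
from pytorch_lightning import LightningModule


class NoBatchIdxModel(LightningModule):
    def training_step(self, batch):
        # the loop inspects the signature and omits `batch_idx` from the kwargs
        return batch.sum()  # placeholder loss


class IterStepModel(LightningModule):
    def training_step(self, dataloader_iter):
        batch = next(dataloader_iter)
        # with inter-batch parallelism there is no positional `batch_idx`;
        # the running index is available on the loop instead
        batch_idx = self.trainer.fit_loop.epoch_loop.batch_idx
        return batch.sum()  # placeholder loss
```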