Construct the hook kwargs inside each loop (#12100)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
Carlos Mocholí 2022-05-03 17:08:02 +02:00 committed by GitHub
parent cd01856ffc
commit f4505ce6b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 172 additions and 124 deletions

View File

@ -22,7 +22,6 @@ from pl_examples.domain_templates.generative_adversarial_net import MNISTDataMod
from pytorch_lightning import Trainer
from pytorch_lightning.loops import OptimizerLoop
from pytorch_lightning.loops.optimization.optimizer_loop import ClosureResult
from pytorch_lightning.loops.utilities import _build_training_step_kwargs
from pytorch_lightning.utilities.exceptions import MisconfigurationException
#############################################################################################
@ -56,28 +55,25 @@ class YieldLoop(OptimizerLoop):
def connect(self, **kwargs):
    """Refuse to connect child loops.

    Raises:
        NotImplementedError: always, naming the concrete loop class.
    """
    loop_name = self.__class__.__name__
    raise NotImplementedError(f"{loop_name} does not connect any child loops.")
def on_run_start(self, batch, optimizers, batch_idx):
super().on_run_start(batch, optimizers, batch_idx)
def on_run_start(self, optimizers, kwargs):
super().on_run_start(optimizers, kwargs)
if not inspect.isgeneratorfunction(self.trainer.lightning_module.training_step):
raise MisconfigurationException("The `LightningModule` does not yield anything in the `training_step`.")
assert self.trainer.lightning_module.automatic_optimization
# We request the generator once and save it for later
# so we can call next() on it.
self._generator = self._get_generator(batch, batch_idx, opt_idx=0)
# We request the generator once and save it for later so we can call next() on it.
self._generator = self._get_generator(kwargs)
def _make_step_fn(self, split_batch, batch_idx, opt_idx):
def _make_step_fn(self, *_):
return partial(self._training_step, self._generator)
def _get_generator(self, split_batch, batch_idx, opt_idx):
step_kwargs = _build_training_step_kwargs(
self.trainer.lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, hiddens=None
)
def _get_generator(self, kwargs, opt_idx=0):
kwargs = self._build_kwargs(kwargs, opt_idx, hiddens=None)
# Here we are basically calling `lightning_module.training_step()`
# and this returns a generator! The `training_step` is handled by the
# accelerator to enable distributed training.
return self.trainer.strategy.training_step(*step_kwargs.values())
# and this returns a generator! The `training_step` is handled by
# the accelerator to enable distributed training.
return self.trainer.strategy.training_step(*kwargs.values())
def _training_step(self, generator):
# required for logging

View File

@ -279,7 +279,7 @@ class BaseFinetuning(Callback):
# import is here to avoid circular imports
from pytorch_lightning.loops.utilities import _get_active_optimizers
for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies):
for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies, 0):
num_param_groups = len(optimizer.param_groups)
self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
current_param_groups = optimizer.param_groups

View File

@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union
from typing import Any, List, Optional, OrderedDict, Tuple, Union
from deprecate import void
from torch import Tensor
from pytorch_lightning.loops.base import Loop
@ -59,35 +58,35 @@ class TrainingBatchLoop(Loop[_OUTPUTS_TYPE]):
"""Resets the loop state."""
self._outputs = []
def on_run_start(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]
def on_run_start(self, kwargs: OrderedDict) -> None: # type: ignore[override]
"""Splits the data into tbptt splits.
Args:
batch: the current batch to run the trainstep on
batch_idx: the index of the current batch
kwargs: the kwargs passed down to the hooks.
"""
void(batch_idx)
batch = kwargs["batch"]
self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))
def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]
def advance(self, kwargs: OrderedDict) -> None: # type: ignore[override]
"""Runs the train step together with optimization (if necessary) on the current batch split.
Args:
batch: the current batch to run the training on (this is not the split!)
batch_idx: the index of the current batch
kwargs: the kwargs passed down to the hooks.
"""
void(batch)
self.split_idx, split_batch = self._remaining_splits.pop(0)
# replace the batch with the split batch
self.split_idx, kwargs["batch"] = self._remaining_splits.pop(0)
self.trainer._logger_connector.on_train_split_start(self.split_idx)
outputs: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] = None # for mypy
# choose which loop will run the optimization
if self.trainer.lightning_module.automatic_optimization:
optimizers = _get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx)
outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
optimizers = _get_active_optimizers(
self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0)
)
outputs = self.optimizer_loop.run(optimizers, kwargs)
else:
outputs = self.manual_loop.run(split_batch, batch_idx)
outputs = self.manual_loop.run(kwargs)
if outputs:
# automatic: can be empty if all optimizers skip their batches
# manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens,

View File

@ -264,7 +264,7 @@ class EvaluationEpochLoop(Loop):
self.trainer._logger_connector.on_batch_end()
def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> OrderedDict:
"""Helper function to build the arguments for the current step.
"""Helper method to build the arguments for the current step.
Args:
kwargs: The kwargs passed down to the hooks.
@ -273,7 +273,8 @@ class EvaluationEpochLoop(Loop):
Returns:
The kwargs passed down to the hooks.
"""
kwargs.update({"batch": batch, "batch_idx": batch_idx})
kwargs.update(batch=batch, batch_idx=batch_idx)
# `dataloader_idx` should be last so we need to push these to the front
kwargs.move_to_end("batch_idx", last=False)
kwargs.move_to_end("batch", last=False)
return kwargs

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import defaultdict
from collections import defaultdict, OrderedDict
from typing import Any, Dict, Generator, List, Optional, overload, Tuple, Union
import numpy as np
@ -173,6 +173,8 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
batch_idx, batch = next(data_fetcher)
self.batch_progress.is_last_batch = data_fetcher.done
kwargs = self._build_kwargs(OrderedDict(), batch, batch_idx)
self.batch_progress.increment_ready()
self.trainer._logger_connector.on_batch_start(batch, batch_idx)
@ -205,7 +207,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
self.batch_progress.increment_started()
with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, batch_idx)
batch_output = self.batch_loop.run(kwargs)
self.batch_progress.increment_processed()
@ -356,6 +358,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
if (
num_optimizers > 1
and lightning_module.truncated_bptt_steps > 0
and is_overridden("on_train_batch_end", lightning_module)
and not _v1_8_output_format(lightning_module.on_train_batch_end)
):
rank_zero_deprecation(
@ -546,6 +549,25 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
data_fetcher.dataloader.load_state_dict(self._dataloader_state_dict)
self._dataloader_state_dict = None
def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> OrderedDict:
    """Fill the hook kwargs with the current batch and, if requested, its index.

    Args:
        kwargs: The kwargs passed down to the hooks. Updated in place.
        batch: The current batch to run through the step.
        batch_idx: The current batch idx.

    Returns:
        The same ``kwargs`` object that was passed in.
    """
    kwargs["batch"] = batch
    step_fn = self.trainer.lightning_module.training_step
    # `batch_idx` is optional. With more than one argument we cannot tell whether the user wants
    # `batch_idx` or some other key such as `optimizer_idx`, because argument names are not
    # strictly enforced.
    wants_batch_idx = is_param_in_hook_signature(step_fn, "batch_idx", min_args=2)
    if wants_batch_idx:
        kwargs["batch_idx"] = batch_idx
    return kwargs
def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) -> List[Optional[Dict[str, Any]]]:
"""Converts an optimizer dict to a list in which the key of the dict determines the position of the element.

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
@ -97,30 +98,25 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
lightning_optimizer._on_before_step = self._on_before_step
lightning_optimizer._on_after_step = self._on_after_step
def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]
def advance(self, kwargs: OrderedDict) -> None: # type: ignore[override]
"""Performs the training step for manual optimization.
Args:
batch: the current tbptt split of the current batch
batch_idx: the index of the current batch
kwargs: The kwargs passed down to the hooks.
"""
assert self.trainer is not None
lightning_module = self.trainer.lightning_module
step_kwargs = _build_training_step_kwargs(
lightning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=self._hiddens
)
kwargs = self._build_kwargs(kwargs, self._hiddens)
# manually capture logged metrics
training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
del kwargs # release the batch from memory
self.trainer.strategy.post_training_step()
del step_kwargs
model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
training_step_output = strategy_output if model_output is None else model_output
self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps)
result = self.output_result_cls.from_training_step_output(training_step_output)
@ -149,3 +145,17 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
def _on_after_step(self) -> None:
self.trainer.profiler.stop("optimizer_step")
self.optim_step_progress.increment_completed()
def _build_kwargs(self, kwargs: OrderedDict, hiddens: Optional[Any]) -> OrderedDict:
    """Delegate to ``_build_training_step_kwargs`` to complete the arguments for the current step.

    Args:
        kwargs: The kwargs passed down to the hooks.
        hiddens: the hidden state of the previous RNN iteration.

    Returns:
        The kwargs passed down to the hooks.
    """
    trainer = self.trainer
    module = trainer.lightning_module
    optimizers = trainer.optimizers
    # no optimizer index is supplied from this loop, so `None` is forwarded in its place
    return _build_training_step_kwargs(kwargs, module, optimizers, None, hiddens)

View File

@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, OrderedDict, Tuple, Union
import torch
from torch import Tensor
@ -164,7 +164,6 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
self._outputs: _OUTPUTS_TYPE = {}
self._skip_backward: bool = False
self._batch_idx: int = 0
self._optimizers: Tuple[Optimizer, ...] = tuple()
self._indices: Tuple[int, ...] = tuple()
self._hiddens: Optional[Any] = None
@ -190,20 +189,16 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
self._outputs = {}
def on_run_start( # type: ignore[override]
self, batch: Any, optimizers: List[Tuple[int, Optimizer]], batch_idx: int
self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict
) -> None:
self._batch_idx = batch_idx
self._indices, self._optimizers = zip(*optimizers)
if self.done:
self.optim_progress.optimizer_position = 0
def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
result = self._run_optimization(
batch,
self._batch_idx,
self._optimizers[self.optim_progress.optimizer_position],
self.optimizer_idx,
)
def advance(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> None: # type: ignore[override]
kwargs = self._build_kwargs(kwargs, self.optimizer_idx, self._hiddens)
result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
if result.loss is not None:
# automatic optimization assumes a loss needs to be returned for extras to be considered as the batch
# would be skipped otherwise
@ -216,21 +211,19 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
self._optimizers = tuple()
return outputs
def _run_optimization(
self, split_batch: Any, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int
) -> ClosureResult:
def _run_optimization(self, kwargs: OrderedDict, optimizer: torch.optim.Optimizer) -> ClosureResult:
"""Runs closure (train step + backward) together with optimization if necessary.
Args:
split_batch: the current tbptt split of the whole batch
batch_idx: the index of the current batch
kwargs: the kwargs passed down to the hooks.
optimizer: the current optimizer
opt_idx: the index of the current optimizer
"""
opt_idx = kwargs.get("optimizer_idx", 0)
# toggle model params
self._run_optimization_start(opt_idx, optimizer)
closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer)
closure = self._make_closure(kwargs, optimizer)
if (
# when the strategy handles accumulation, we want to always call the optimizer step
@ -251,7 +244,8 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
# ------------------------------
# gradient update with accumulated gradients
else:
self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
# the `batch_idx` is optional with inter-batch parallelism
self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
result = closure.consume_result()
@ -265,17 +259,18 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
self._run_optimization_end(opt_idx)
return result
def _make_closure(self, split_batch: Any, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Closure:
def _make_closure(self, kwargs: OrderedDict, optimizer: Optimizer) -> Closure:
"""Build a closure object that captures the given arguments and runs the `training_step` function and
optionally other functions such as `backward` and `zero_grad`."""
step_fn = self._make_step_fn(split_batch, batch_idx, opt_idx)
opt_idx = kwargs.get("optimizer_idx", 0)
step_fn = self._make_step_fn(kwargs)
backward_fn = self._make_backward_fn(optimizer, opt_idx)
zero_grad_fn = self._make_zero_grad_fn(batch_idx, opt_idx, optimizer)
zero_grad_fn = self._make_zero_grad_fn(kwargs.get("batch_idx", 0), opt_idx, optimizer)
return Closure(step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn)
def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Callable[[], ClosureResult]:
def _make_step_fn(self, kwargs: OrderedDict) -> Callable[[], ClosureResult]:
"""Build the step function that runs the `training_step` and processes its output."""
return partial(self._training_step, split_batch, batch_idx, opt_idx)
return partial(self._training_step, kwargs)
def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]:
"""Build a `zero_grad` function that zeroes the gradients before back-propagation.
@ -399,33 +394,24 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
)
self.optim_progress.optimizer.zero_grad.increment_completed()
def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
"""Performs the actual train step with the tied hooks.
Args:
split_batch: the current tbptt split of the current batch
batch_idx: the index of the current batch
opt_idx: the index of the current optimizer
kwargs: the kwargs passed down to the hooks.
Returns:
A ``ClosureResult`` containing the training step output.
"""
# give the PL module a result for logging
lightning_module = self.trainer.lightning_module
step_kwargs = _build_training_step_kwargs(
lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
)
# manually capture logged metrics
training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
self.trainer.strategy.post_training_step()
model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
training_step_output = strategy_output if model_output is None else model_output
self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps)
result = self.output_result_cls.from_training_step_output(
training_step_output, self.trainer.accumulate_grad_batches
@ -437,3 +423,18 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
self.trainer._results.cpu()
return result
def _build_kwargs(self, kwargs: OrderedDict, opt_idx: int, hiddens: Optional[Any]) -> OrderedDict:
    """Delegate to ``_build_training_step_kwargs`` to complete the arguments for the current step.

    Args:
        kwargs: The kwargs passed down to the hooks.
        opt_idx: the index of the current optimizer.
        hiddens: the hidden state of the previous RNN iteration.

    Returns:
        The kwargs passed down to the hooks.
    """
    trainer = self.trainer
    module = trainer.lightning_module
    optimizers = trainer.optimizers
    return _build_training_step_kwargs(kwargs, module, optimizers, opt_idx, hiddens)

View File

@ -106,34 +106,25 @@ def _parse_loop_limits(
def _build_training_step_kwargs(
kwargs: OrderedDict,
lightning_module: "pl.LightningModule",
optimizers: Sequence[Optimizer],
batch: Any,
batch_idx: int,
opt_idx: Optional[int],
hiddens: Optional[Any],
) -> Dict[str, Any]:
) -> OrderedDict:
"""Builds the keyword arguments for training_step.
Args:
kwargs: The kwargs passed down to the hooks.
lightning_module: the LightningModule with a `training_step` hook implementation
optimizers: the list of optimizers from the Trainer
batch: the batch to train on
batch_idx: the index of the current batch
opt_idx: the index of the current optimizer
hiddens: the hidden state of the previous RNN iteration
Returns:
the keyword arguments for the training step
"""
# enable not needing to add opt_idx to training_step
step_kwargs = OrderedDict([("batch", batch)])
training_step_fx = getattr(lightning_module, "training_step")
if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2):
step_kwargs["batch_idx"] = batch_idx
if len(optimizers) > 1:
has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
if has_opt_idx_in_train_step:
@ -143,7 +134,7 @@ def _build_training_step_kwargs(
" in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
" argument or set `self.automatic_optimization = True`."
)
step_kwargs["optimizer_idx"] = opt_idx
kwargs["optimizer_idx"] = opt_idx
elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
raise ValueError(
f"Your LightningModule defines {len(optimizers)} optimizers but"
@ -152,9 +143,9 @@ def _build_training_step_kwargs(
# pass hiddens if using tbptt
if lightning_module.truncated_bptt_steps > 0:
step_kwargs["hiddens"] = hiddens
kwargs["hiddens"] = hiddens
return step_kwargs
return kwargs
@contextmanager
@ -182,7 +173,7 @@ def _cumulative_optimizer_frequencies(frequencies: Tuple[int]) -> np.ndarray:
def _get_active_optimizers(
optimizers: List[Optimizer], frequencies: List[int], batch_idx: Optional[int] = None
optimizers: List[Optimizer], frequencies: List[int], batch_idx: int
) -> List[Tuple[int, Optimizer]]:
"""Returns the currently active optimizers. When multiple optimizers are used with different frequencies, only
one of the optimizers is active at a time.

View File

@ -170,3 +170,36 @@ def test_tbptt_logging(tmpdir, model_class):
)
trainer.fit(model)
assert set(trainer.logged_metrics) == {"loss_step", "loss_epoch"}
def test_hiddens_multiple_optimizers(tmpdir):
    """Fit a TBPTT model with two optimizers and check the hiddens handed to ``training_step``."""
    time_dim_length, tbptt_steps, num_optimizers = 8, 2, 2

    class TBPTTModel(LSTMModel):
        # TODO: `optimizer_idx=n` gets the hiddens from `optimizer_idx=n-1` instead of the hidden from
        # `optimizer_idx=n`, `split_idx=m-1`. This is unexpected and should be changed
        test_hiddens = None

        def training_step(self, batch, batch_idx, optimizer_idx, hiddens):
            if hiddens is not None:
                # the hiddens received must match the ones stored by the previous call
                pairs = zip(hiddens, self.test_hiddens)
                assert all(torch.equal(received, stored) for received, stored in pairs)
            else:
                assert self.test_hiddens is None
            out = super().training_step(batch, batch_idx, hiddens)
            self.test_hiddens = out["hiddens"]
            return out

        def configure_optimizers(self):
            first = super().configure_optimizers()
            second = super().configure_optimizers()
            return [first, second]

    model = TBPTTModel(truncated_bptt_steps=tbptt_steps, input_size=1, hidden_size=1)
    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=0,
        enable_model_summary=False,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
    )
    trainer = Trainer(**trainer_options)
    trainer.fit(model)
    assert trainer.global_step == time_dim_length / tbptt_steps * num_optimizers

View File

@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import Mock, patch
from unittest.mock import patch
import pytest
from pytorch_lightning import LightningModule
from pytorch_lightning.loops import TrainingEpochLoop
from pytorch_lightning.trainer.trainer import Trainer
from tests.deprecated_api import no_deprecated_call
@ -33,7 +34,8 @@ _out13 = {"loss": 1.3}
class TestPrepareOutputs:
def prepare_outputs(self, fn, tbptt_splits, new_format, batch_outputs, num_optimizers, automatic_optimization):
lightning_module = Mock()
lightning_module = LightningModule()
lightning_module.on_train_batch_end = lambda *_: None # override to trigger the deprecation message
lightning_module.automatic_optimization = automatic_optimization
lightning_module.truncated_bptt_steps = tbptt_splits
match = "will change in version v1.8.*new_format=True"

View File

@ -63,8 +63,8 @@ def test__eval_step__flow(tmpdir):
# simulate training manually
trainer.state.stage = RunningStage.TRAINING
batch_idx, batch = 0, next(iter(model.train_dataloader()))
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
@ -72,9 +72,7 @@ def test__eval_step__flow(tmpdir):
assert train_step_out["loss"].item() == 171
# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0]
)
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171
@ -126,8 +124,8 @@ def test__eval_step__eval_step_end__flow(tmpdir):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
batch_idx, batch = 0, next(iter(model.train_dataloader()))
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
@ -135,9 +133,7 @@ def test__eval_step__eval_step_end__flow(tmpdir):
assert train_step_out["loss"].item() == 171
# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0]
)
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171

View File

@ -146,8 +146,8 @@ def test__training_step__epoch_end__flow_scalar(tmpdir):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
batch_idx, batch = 0, next(iter(model.train_dataloader()))
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
@ -155,9 +155,7 @@ def test__training_step__epoch_end__flow_scalar(tmpdir):
assert train_step_out["loss"].item() == 171
# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0]
)
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171
@ -218,8 +216,8 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
batch_idx, batch = 0, next(iter(model.train_dataloader()))
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
@ -227,9 +225,7 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):
assert train_step_out["loss"].item() == 171
# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0]
)
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171
@ -239,7 +235,7 @@ def test_train_step_no_return(tmpdir):
automatic_optimization."""
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
def training_step(self, batch):
self.training_step_called = True
loss = self.step(batch[0])
self.log("a", loss, on_step=True, on_epoch=True)
@ -305,7 +301,7 @@ def test_training_step_no_return_when_even(tmpdir):
# manually check a few batches
for batch_idx, batch in enumerate(model.train_dataloader()):
out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
out = trainer.fit_loop.epoch_loop.batch_loop.run({"batch": batch, "batch_idx": batch_idx})
if not batch_idx % 2:
assert out == []

View File

@ -141,7 +141,7 @@ def test__training_step__step_end__epoch_end__log(tmpdir, batches, log_interval,
"""Tests that training_step_end and training_epoch_end can log."""
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
def training_step(self, batch):
loss = self.step(batch[0])
self.log("a", loss, on_step=True, on_epoch=True)
return loss

View File

@ -375,9 +375,10 @@ def test_stop_iteration(trigger_stop_iteration, tmpdir):
super().__init__()
self.trigger_stop_iteration = trigger_stop_iteration
def training_step(self, dataloader_iter: Iterator, *args) -> STEP_OUTPUT:
def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
output = super().training_step(dataloader_iter)
if self.trigger_stop_iteration and args[0] == EXPECT_NUM_BATCHES_PROCESSED:
batch_idx = self.trainer.fit_loop.epoch_loop.batch_idx
if self.trigger_stop_iteration and batch_idx == EXPECT_NUM_BATCHES_PROCESSED:
raise StopIteration
return output