Construct the hook kwargs inside each loop (#12100)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
parent cd01856ffc
commit f4505ce6b2
@@ -22,7 +22,6 @@ from pl_examples.domain_templates.generative_adversarial_net import MNISTDataMod
 from pytorch_lightning import Trainer
 from pytorch_lightning.loops import OptimizerLoop
 from pytorch_lightning.loops.optimization.optimizer_loop import ClosureResult
-from pytorch_lightning.loops.utilities import _build_training_step_kwargs
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 #############################################################################################
@@ -56,28 +55,25 @@ class YieldLoop(OptimizerLoop):
     def connect(self, **kwargs):
         raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

-    def on_run_start(self, batch, optimizers, batch_idx):
-        super().on_run_start(batch, optimizers, batch_idx)
+    def on_run_start(self, optimizers, kwargs):
+        super().on_run_start(optimizers, kwargs)
         if not inspect.isgeneratorfunction(self.trainer.lightning_module.training_step):
             raise MisconfigurationException("The `LightningModule` does not yield anything in the `training_step`.")
         assert self.trainer.lightning_module.automatic_optimization

-        # We request the generator once and save it for later
-        # so we can call next() on it.
-        self._generator = self._get_generator(batch, batch_idx, opt_idx=0)
+        # We request the generator once and save it for later so we can call next() on it.
+        self._generator = self._get_generator(kwargs)

-    def _make_step_fn(self, split_batch, batch_idx, opt_idx):
+    def _make_step_fn(self, *_):
         return partial(self._training_step, self._generator)

-    def _get_generator(self, split_batch, batch_idx, opt_idx):
-        step_kwargs = _build_training_step_kwargs(
-            self.trainer.lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, hiddens=None
-        )
+    def _get_generator(self, kwargs, opt_idx=0):
+        kwargs = self._build_kwargs(kwargs, opt_idx, hiddens=None)

         # Here we are basically calling `lightning_module.training_step()`
-        # and this returns a generator! The `training_step` is handled by the
-        # accelerator to enable distributed training.
-        return self.trainer.strategy.training_step(*step_kwargs.values())
+        # and this returns a generator! The `training_step` is handled by
+        # the accelerator to enable distributed training.
+        return self.trainer.strategy.training_step(*kwargs.values())

     def _training_step(self, generator):
         # required for logging
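For readers unfamiliar with this example, the kind of generator-style `training_step` the `YieldLoop` consumes looks roughly like the sketch below (modelled on the GAN template; `generator_loss` and `discriminator_loss` are illustrative names, not the example's actual helpers). The loop invokes the step once to obtain the generator, then calls `next()` on it once per optimizer:

def training_step(self, batch, batch_idx, optimizer_idx):
    # first resumption: yield the generator loss, consumed for optimizer 0
    yield self.generator_loss(batch)
    # second resumption: yield the discriminator loss, consumed for optimizer 1
    yield self.discriminator_loss(batch)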
@@ -279,7 +279,7 @@ class BaseFinetuning(Callback):
         # import is here to avoid circular imports
         from pytorch_lightning.loops.utilities import _get_active_optimizers

-        for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies):
+        for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies, 0):
             num_param_groups = len(optimizer.param_groups)
             self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
             current_param_groups = optimizer.param_groups
@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, OrderedDict, Tuple, Union

-from deprecate import void
 from torch import Tensor

 from pytorch_lightning.loops.base import Loop
@@ -59,35 +58,35 @@ class TrainingBatchLoop(Loop[_OUTPUTS_TYPE]):
         """Resets the loop state."""
         self._outputs = []

-    def on_run_start(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
+    def on_run_start(self, kwargs: OrderedDict) -> None:  # type: ignore[override]
         """Splits the data into tbptt splits.

         Args:
-            batch: the current batch to run the trainstep on
-            batch_idx: the index of the current batch
+            kwargs: the kwargs passed down to the hooks.
         """
-        void(batch_idx)
+        batch = kwargs["batch"]
         self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))

-    def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
+    def advance(self, kwargs: OrderedDict) -> None:  # type: ignore[override]
         """Runs the train step together with optimization (if necessary) on the current batch split.

         Args:
-            batch: the current batch to run the training on (this is not the split!)
-            batch_idx: the index of the current batch
+            kwargs: the kwargs passed down to the hooks.
         """
-        void(batch)
-        self.split_idx, split_batch = self._remaining_splits.pop(0)
+        # replace the batch with the split batch
+        self.split_idx, kwargs["batch"] = self._remaining_splits.pop(0)

         self.trainer._logger_connector.on_train_split_start(self.split_idx)

         outputs: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] = None  # for mypy
         # choose which loop will run the optimization
         if self.trainer.lightning_module.automatic_optimization:
-            optimizers = _get_active_optimizers(self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx)
-            outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
+            optimizers = _get_active_optimizers(
+                self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0)
+            )
+            outputs = self.optimizer_loop.run(optimizers, kwargs)
         else:
-            outputs = self.manual_loop.run(split_batch, batch_idx)
+            outputs = self.manual_loop.run(kwargs)
         if outputs:
             # automatic: can be empty if all optimizers skip their batches
             # manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens,
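A minimal sketch (illustrative values, not Lightning code) of the hand-off above: `_remaining_splits` holds `(split_idx, split_batch)` pairs, and `advance` swaps the full batch for the current split in place before delegating to the optimizer or manual loop:

from collections import OrderedDict

kwargs = OrderedDict(batch=[1, 2, 3, 4], batch_idx=0)
remaining_splits = list(enumerate([[1, 2], [3, 4]]))  # e.g. tbptt splits of length 2
split_idx, kwargs["batch"] = remaining_splits.pop(0)
assert split_idx == 0 and kwargs["batch"] == [1, 2]  # the child loops only ever see the split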
@@ -264,7 +264,7 @@ class EvaluationEpochLoop(Loop):
         self.trainer._logger_connector.on_batch_end()

     def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> OrderedDict:
-        """Helper function to build the arguments for the current step.
+        """Helper method to build the arguments for the current step.

         Args:
             kwargs: The kwargs passed down to the hooks.

@@ -273,7 +273,8 @@ class EvaluationEpochLoop(Loop):
         Returns:
             The kwargs passed down to the hooks.
         """
-        kwargs.update({"batch": batch, "batch_idx": batch_idx})
+        kwargs.update(batch=batch, batch_idx=batch_idx)
+        # `dataloader_idx` should be last so we need to push these to the front
         kwargs.move_to_end("batch_idx", last=False)
         kwargs.move_to_end("batch", last=False)
         return kwargs
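The `move_to_end(..., last=False)` calls deserve a note: they prepend rather than append, which restores the positional order `(batch, batch_idx, dataloader_idx)` before the kwargs are splatted into the hooks. A standalone sketch with illustrative values:

from collections import OrderedDict

kwargs = OrderedDict(dataloader_idx=1)  # pre-populated by the outer loop
kwargs.update(batch="<batch>", batch_idx=3)
kwargs.move_to_end("batch_idx", last=False)  # push to the front
kwargs.move_to_end("batch", last=False)      # now first
assert list(kwargs) == ["batch", "batch_idx", "dataloader_idx"]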
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
 from typing import Any, Dict, Generator, List, Optional, overload, Tuple, Union

 import numpy as np
@@ -173,6 +173,8 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
             batch_idx, batch = next(data_fetcher)
         self.batch_progress.is_last_batch = data_fetcher.done

+        kwargs = self._build_kwargs(OrderedDict(), batch, batch_idx)
+
         self.batch_progress.increment_ready()

         self.trainer._logger_connector.on_batch_start(batch, batch_idx)
@@ -205,7 +207,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
             self.batch_progress.increment_started()

             with self.trainer.profiler.profile("run_training_batch"):
-                batch_output = self.batch_loop.run(batch, batch_idx)
+                batch_output = self.batch_loop.run(kwargs)

         self.batch_progress.increment_processed()

@@ -356,6 +358,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
         if (
             num_optimizers > 1
+            and lightning_module.truncated_bptt_steps > 0
             and is_overridden("on_train_batch_end", lightning_module)
             and not _v1_8_output_format(lightning_module.on_train_batch_end)
         ):
             rank_zero_deprecation(
@@ -546,6 +549,25 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
             data_fetcher.dataloader.load_state_dict(self._dataloader_state_dict)
             self._dataloader_state_dict = None

+    def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> OrderedDict:
+        """Helper method to build the arguments for the current step.
+
+        Args:
+            kwargs: The kwargs passed down to the hooks.
+            batch: The current batch to run through the step.
+            batch_idx: The current batch idx.
+
+        Returns:
+            The kwargs passed down to the hooks.
+        """
+        kwargs["batch"] = batch
+        training_step_fx = getattr(self.trainer.lightning_module, "training_step")
+        # the `batch_idx` is optional, however, when there's more than 1 argument we cannot differentiate whether the
+        # user wants the `batch_idx` or another key like `optimizer_idx` as we are not strict about the argument names
+        if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2):
+            kwargs["batch_idx"] = batch_idx
+        return kwargs
+

 def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) -> List[Optional[Dict[str, Any]]]:
     """Converts an optimizer dict to a list in which the key of the dict determines the position of the element.
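The comment in the new `_build_kwargs` is the crux of the refactor: `batch_idx` is only forwarded when `training_step` declares at least two parameters, because a single extra argument is ambiguous. A sketch of the check; `has_param` below is an illustrative stand-in for `is_param_in_hook_signature`, not Lightning's implementation:

import inspect

def has_param(fn, name, min_args=2):
    # inspect the hook's signature for a named parameter
    params = inspect.signature(fn).parameters
    return name in params and len(params) >= min_args

def step_with_idx(batch, batch_idx): ...
def step_without_idx(batch): ...

assert has_param(step_with_idx, "batch_idx")
assert not has_param(step_without_idx, "batch_idx")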
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import OrderedDict
 from dataclasses import dataclass, field
 from typing import Any, Dict, Optional

@@ -97,30 +98,25 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
         lightning_optimizer._on_before_step = self._on_before_step
         lightning_optimizer._on_after_step = self._on_after_step

-    def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
+    def advance(self, kwargs: OrderedDict) -> None:  # type: ignore[override]
         """Performs the training step for manual optimization.

         Args:
-            batch: the current tbptt split of the current batch
-            batch_idx: the index of the current batch
+            kwargs: The kwargs passed down to the hooks.
         """
-        assert self.trainer is not None
-        lightning_module = self.trainer.lightning_module
-
-        step_kwargs = _build_training_step_kwargs(
-            lightning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=self._hiddens
-        )
+        kwargs = self._build_kwargs(kwargs, self._hiddens)

         # manually capture logged metrics
-        training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
+        training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
+        del kwargs  # release the batch from memory
         self.trainer.strategy.post_training_step()

-        del step_kwargs
-
         model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
         strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
         training_step_output = strategy_output if model_output is None else model_output
-        self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
+        self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps)

         result = self.output_result_cls.from_training_step_output(training_step_output)
@@ -149,3 +145,17 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
     def _on_after_step(self) -> None:
         self.trainer.profiler.stop("optimizer_step")
         self.optim_step_progress.increment_completed()
+
+    def _build_kwargs(self, kwargs: OrderedDict, hiddens: Optional[Any]) -> OrderedDict:
+        """Helper method to build the arguments for the current step.
+
+        Args:
+            kwargs: The kwargs passed down to the hooks.
+            hiddens: the hidden state of the previous RNN iteration.
+
+        Returns:
+            The kwargs passed down to the hooks.
+        """
+        return _build_training_step_kwargs(
+            kwargs, self.trainer.lightning_module, self.trainer.optimizers, None, hiddens
+        )
@@ -13,7 +13,7 @@
 # limitations under the License.
 from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, OrderedDict, Tuple, Union

 import torch
 from torch import Tensor
@@ -164,7 +164,6 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):

         self._outputs: _OUTPUTS_TYPE = {}
         self._skip_backward: bool = False
-        self._batch_idx: int = 0
         self._optimizers: Tuple[Optimizer, ...] = tuple()
         self._indices: Tuple[int, ...] = tuple()
         self._hiddens: Optional[Any] = None
@@ -190,20 +189,16 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
         self._outputs = {}

     def on_run_start(  # type: ignore[override]
-        self, batch: Any, optimizers: List[Tuple[int, Optimizer]], batch_idx: int
+        self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict
     ) -> None:
-        self._batch_idx = batch_idx
         self._indices, self._optimizers = zip(*optimizers)
         if self.done:
             self.optim_progress.optimizer_position = 0

-    def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None:  # type: ignore[override]
-        result = self._run_optimization(
-            batch,
-            self._batch_idx,
-            self._optimizers[self.optim_progress.optimizer_position],
-            self.optimizer_idx,
-        )
+    def advance(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> None:  # type: ignore[override]
+        kwargs = self._build_kwargs(kwargs, self.optimizer_idx, self._hiddens)
+
+        result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
         if result.loss is not None:
             # automatic optimization assumes a loss needs to be returned for extras to be considered as the batch
             # would be skipped otherwise
@@ -216,21 +211,19 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             self._optimizers = tuple()
         return outputs

-    def _run_optimization(
-        self, split_batch: Any, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int
-    ) -> ClosureResult:
+    def _run_optimization(self, kwargs: OrderedDict, optimizer: torch.optim.Optimizer) -> ClosureResult:
         """Runs closure (train step + backward) together with optimization if necessary.

         Args:
-            split_batch: the current tbptt split of the whole batch
-            batch_idx: the index of the current batch
+            kwargs: the kwargs passed down to the hooks.
             optimizer: the current optimizer
-            opt_idx: the index of the current optimizer
         """
+        opt_idx = kwargs.get("optimizer_idx", 0)
+
         # toggle model params
         self._run_optimization_start(opt_idx, optimizer)

-        closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer)
+        closure = self._make_closure(kwargs, optimizer)

         if (
             # when the strategy handles accumulation, we want to always call the optimizer step
@@ -251,7 +244,8 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             # ------------------------------
             # gradient update with accumulated gradients
             else:
-                self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
+                # the `batch_idx` is optional with inter-batch parallelism
+                self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)

         result = closure.consume_result()

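The `kwargs.get("batch_idx", 0)` fallback exists because, as the new comment says, the key is optional with inter-batch parallelism: a `training_step` that consumes a `dataloader_iter` directly never gets a `batch_idx` entry built for it. Illustrative only:

from collections import OrderedDict

kwargs = OrderedDict(batch=object())  # no "batch_idx" key was built
assert kwargs.get("batch_idx", 0) == 0

kwargs["batch_idx"] = 7  # the regular flow adds it
assert kwargs.get("batch_idx", 0) == 7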
@@ -265,17 +259,18 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
         self._run_optimization_end(opt_idx)
         return result

-    def _make_closure(self, split_batch: Any, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Closure:
+    def _make_closure(self, kwargs: OrderedDict, optimizer: Optimizer) -> Closure:
         """Build a closure object that captures the given arguments and runs the `training_step` function and
         optionally other functions such as `backward` and `zero_grad`."""
-        step_fn = self._make_step_fn(split_batch, batch_idx, opt_idx)
+        opt_idx = kwargs.get("optimizer_idx", 0)
+        step_fn = self._make_step_fn(kwargs)
         backward_fn = self._make_backward_fn(optimizer, opt_idx)
-        zero_grad_fn = self._make_zero_grad_fn(batch_idx, opt_idx, optimizer)
+        zero_grad_fn = self._make_zero_grad_fn(kwargs.get("batch_idx", 0), opt_idx, optimizer)
         return Closure(step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn)

-    def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Callable[[], ClosureResult]:
+    def _make_step_fn(self, kwargs: OrderedDict) -> Callable[[], ClosureResult]:
         """Build the step function that runs the `training_step` and processes its output."""
-        return partial(self._training_step, split_batch, batch_idx, opt_idx)
+        return partial(self._training_step, kwargs)

     def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]:
         """Build a `zero_grad` function that zeroes the gradients before back-propagation.
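For orientation, the closure assembled here follows this generic pattern (a simplified sketch under the assumption that profiling and result bookkeeping are stripped away; the real `Closure` class does more):

def make_closure(step_fn, backward_fn, zero_grad_fn):
    def closure():
        result = step_fn()   # runs `training_step` with the captured kwargs
        zero_grad_fn()       # clear stale gradients before accumulating new ones
        backward_fn(result)  # backward on the step's loss
        return result
    return closure           # handed to `optimizer.step(closure)`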
@@ -399,33 +394,24 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             )
         self.optim_progress.optimizer.zero_grad.increment_completed()

-    def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
+    def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
         """Performs the actual train step with the tied hooks.

         Args:
-            split_batch: the current tbptt split of the current batch
-            batch_idx: the index of the current batch
-            opt_idx: the index of the current optimizer
+            kwargs: the kwargs passed down to the hooks.

         Returns:
             A ``ClosureResult`` containing the training step output.
         """
-        # give the PL module a result for logging
-        lightning_module = self.trainer.lightning_module
-
-        step_kwargs = _build_training_step_kwargs(
-            lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
-        )
-
         # manually capture logged metrics
-        training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
+        training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
         self.trainer.strategy.post_training_step()

         model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
         strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
         training_step_output = strategy_output if model_output is None else model_output

-        self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
+        self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps)

         result = self.output_result_cls.from_training_step_output(
             training_step_output, self.trainer.accumulate_grad_batches
@@ -437,3 +423,18 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             self.trainer._results.cpu()

         return result
+
+    def _build_kwargs(self, kwargs: OrderedDict, opt_idx: int, hiddens: Optional[Any]) -> OrderedDict:
+        """Helper method to build the arguments for the current step.
+
+        Args:
+            kwargs: The kwargs passed down to the hooks.
+            opt_idx: the index of the current optimizer.
+            hiddens: the hidden state of the previous RNN iteration.
+
+        Returns:
+            The kwargs passed down to the hooks.
+        """
+        return _build_training_step_kwargs(
+            kwargs, self.trainer.lightning_module, self.trainer.optimizers, opt_idx, hiddens
+        )
@@ -106,34 +106,25 @@ def _parse_loop_limits(


 def _build_training_step_kwargs(
+    kwargs: OrderedDict,
     lightning_module: "pl.LightningModule",
     optimizers: Sequence[Optimizer],
-    batch: Any,
-    batch_idx: int,
     opt_idx: Optional[int],
     hiddens: Optional[Any],
-) -> Dict[str, Any]:
+) -> OrderedDict:
     """Builds the keyword arguments for training_step.

     Args:
+        kwargs: The kwargs passed down to the hooks.
         lightning_module: the LightningModule with a `training_step` hook implementation
         optimizers: the list of optimizers from the Trainer
-        batch: the batch to train on
-        batch_idx: the index of the current batch
         opt_idx: the index of the current optimizer
         hiddens: the hidden state of the previous RNN iteration

     Returns:
         the keyword arguments for the training step
     """
-    # enable not needing to add opt_idx to training_step
-    step_kwargs = OrderedDict([("batch", batch)])
-
     training_step_fx = getattr(lightning_module, "training_step")
-
-    if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2):
-        step_kwargs["batch_idx"] = batch_idx
-
     if len(optimizers) > 1:
         has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
         if has_opt_idx_in_train_step:
@@ -143,7 +134,7 @@ def _build_training_step_kwargs(
                 " in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
                 " argument or set `self.automatic_optimization = True`."
             )
-            step_kwargs["optimizer_idx"] = opt_idx
+            kwargs["optimizer_idx"] = opt_idx
         elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
             raise ValueError(
                 f"Your LightningModule defines {len(optimizers)} optimizers but"
@@ -152,9 +143,9 @@ def _build_training_step_kwargs(

     # pass hiddens if using tbptt
     if lightning_module.truncated_bptt_steps > 0:
-        step_kwargs["hiddens"] = hiddens
+        kwargs["hiddens"] = hiddens

-    return step_kwargs
+    return kwargs


 @contextmanager
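Putting the refactor together: the helper no longer creates the `OrderedDict` or inserts `batch`/`batch_idx` (the calling loop does that); it only appends `optimizer_idx` and `hiddens` when the signature and tbptt settings call for them. A rough walk-through with illustrative values, not executable against a real `Trainer`:

from collections import OrderedDict

kwargs = OrderedDict(batch="<split>", batch_idx=0)  # built by the epoch loop
# _build_training_step_kwargs(kwargs, module, optimizers, opt_idx=1, hiddens=h)
# would leave it as, e.g.:
#   OrderedDict(batch="<split>", batch_idx=0, optimizer_idx=1, hiddens=h)
# which each loop finally splats positionally:
#   self.trainer.strategy.training_step(*kwargs.values())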
@@ -182,7 +173,7 @@ def _cumulative_optimizer_frequencies(frequencies: Tuple[int]) -> np.ndarray:


 def _get_active_optimizers(
-    optimizers: List[Optimizer], frequencies: List[int], batch_idx: Optional[int] = None
+    optimizers: List[Optimizer], frequencies: List[int], batch_idx: int
 ) -> List[Tuple[int, Optimizer]]:
     """Returns the currently active optimizers. When multiple optimizers are used with different frequencies, only
     one of the optimizers is active at a time.
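Making `batch_idx` required here matches the call sites above, which now pass `kwargs.get("batch_idx", 0)` or an explicit `0`. A simplified sketch (an assumption about the selection logic, not the helper's exact code) of how the index picks the active optimizer when frequencies are set:

import numpy as np

def active_optimizer_idx(frequencies, batch_idx):
    # position of the current batch within one full frequency cycle
    cumulative = np.cumsum(frequencies)
    return int(np.searchsorted(cumulative, batch_idx % cumulative[-1], side="right"))

# with frequencies [2, 1]: batches 0-1 use optimizer 0, batch 2 uses optimizer 1, repeating
assert [active_optimizer_idx([2, 1], i) for i in range(6)] == [0, 0, 1, 0, 0, 1]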
@@ -170,3 +170,36 @@ def test_tbptt_logging(tmpdir, model_class):
     )
     trainer.fit(model)
     assert set(trainer.logged_metrics) == {"loss_step", "loss_epoch"}
+
+
+def test_hiddens_multiple_optimizers(tmpdir):
+    class TBPTTModel(LSTMModel):
+        # TODO: `optimizer_idx=n` gets the hiddens from `optimizer_idx=n-1` instead of the hiddens from
+        # `optimizer_idx=n`, `split_idx=m-1`. This is unexpected and should be changed
+        test_hiddens = None
+
+        def training_step(self, batch, batch_idx, optimizer_idx, hiddens):
+            if hiddens is None:
+                assert self.test_hiddens is None
+            else:
+                assert all(torch.equal(h, th) for h, th in zip(hiddens, self.test_hiddens))
+            out = super().training_step(batch, batch_idx, hiddens)
+            self.test_hiddens = out["hiddens"]
+            return out
+
+        def configure_optimizers(self):
+            return [super().configure_optimizers(), super().configure_optimizers()]
+
+    model = TBPTTModel(truncated_bptt_steps=2, input_size=1, hidden_size=1)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=1,
+        limit_val_batches=0,
+        enable_model_summary=False,
+        logger=False,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+    )
+    trainer.fit(model)
+    assert trainer.global_step == 8 / 2 * 2  # time_dim_length / tbptt_steps * num_optimizers
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from unittest import mock
-from unittest.mock import Mock, patch
+from unittest.mock import patch

 import pytest

+from pytorch_lightning import LightningModule
 from pytorch_lightning.loops import TrainingEpochLoop
 from pytorch_lightning.trainer.trainer import Trainer
 from tests.deprecated_api import no_deprecated_call
@@ -33,7 +34,8 @@ _out13 = {"loss": 1.3}

 class TestPrepareOutputs:
     def prepare_outputs(self, fn, tbptt_splits, new_format, batch_outputs, num_optimizers, automatic_optimization):
-        lightning_module = Mock()
+        lightning_module = LightningModule()
+        lightning_module.on_train_batch_end = lambda *_: None  # override to trigger the deprecation message
         lightning_module.automatic_optimization = automatic_optimization
         lightning_module.truncated_bptt_steps = tbptt_splits
         match = "will change in version v1.8.*new_format=True"
@@ -63,8 +63,8 @@ def test__eval_step__flow(tmpdir):

     # simulate training manually
     trainer.state.stage = RunningStage.TRAINING
-    batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
+    kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)

     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
@@ -72,9 +72,7 @@ def test__eval_step__flow(tmpdir):
     assert train_step_out["loss"].item() == 171

     # make sure the optimizer closure returns the correct things
-    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
-        batch, batch_idx, 0, trainer.optimizers[0]
-    )
+    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
     opt_closure_result = opt_closure()
     assert opt_closure_result.item() == 171

@@ -126,8 +124,8 @@ def test__eval_step__eval_step_end__flow(tmpdir):

     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
-    batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
+    kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)

     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
@@ -135,9 +133,7 @@ def test__eval_step__eval_step_end__flow(tmpdir):
     assert train_step_out["loss"].item() == 171

     # make sure the optimizer closure returns the correct things
-    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
-        batch, batch_idx, 0, trainer.optimizers[0]
-    )
+    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
     opt_closure_result = opt_closure()
     assert opt_closure_result.item() == 171

@@ -146,8 +146,8 @@ def test__training_step__epoch_end__flow_scalar(tmpdir):

     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
-    batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
+    kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)

     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
@@ -155,9 +155,7 @@ def test__training_step__epoch_end__flow_scalar(tmpdir):
     assert train_step_out["loss"].item() == 171

     # make sure the optimizer closure returns the correct things
-    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
-        batch, batch_idx, 0, trainer.optimizers[0]
-    )
+    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
     opt_closure_result = opt_closure()
     assert opt_closure_result.item() == 171

@@ -218,8 +216,8 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):

     trainer.state.stage = RunningStage.TRAINING
     # make sure training outputs what is expected
-    batch_idx, batch = 0, next(iter(model.train_dataloader()))
-    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
+    kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
+    train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)

     assert len(train_step_out) == 1
     train_step_out = train_step_out[0][0]
@@ -227,9 +225,7 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):
     assert train_step_out["loss"].item() == 171

     # make sure the optimizer closure returns the correct things
-    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
-        batch, batch_idx, 0, trainer.optimizers[0]
-    )
+    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
     opt_closure_result = opt_closure()
     assert opt_closure_result.item() == 171

@@ -239,7 +235,7 @@ def test_train_step_no_return(tmpdir):
     automatic_optimization."""

     class TestModel(BoringModel):
-        def training_step(self, batch, batch_idx):
+        def training_step(self, batch):
             self.training_step_called = True
             loss = self.step(batch[0])
             self.log("a", loss, on_step=True, on_epoch=True)
@@ -305,7 +301,7 @@ def test_training_step_no_return_when_even(tmpdir):

     # manually check a few batches
     for batch_idx, batch in enumerate(model.train_dataloader()):
-        out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
+        out = trainer.fit_loop.epoch_loop.batch_loop.run({"batch": batch, "batch_idx": batch_idx})
         if not batch_idx % 2:
             assert out == []

@@ -141,7 +141,7 @@ def test__training_step__step_end__epoch_end__log(tmpdir, batches, log_interval,
     """Tests that training_step_end and training_epoch_end can log."""

     class TestModel(BoringModel):
-        def training_step(self, batch, batch_idx):
+        def training_step(self, batch):
             loss = self.step(batch[0])
             self.log("a", loss, on_step=True, on_epoch=True)
             return loss
@@ -375,9 +375,10 @@ def test_stop_iteration(trigger_stop_iteration, tmpdir):
         super().__init__()
         self.trigger_stop_iteration = trigger_stop_iteration

-    def training_step(self, dataloader_iter: Iterator, *args) -> STEP_OUTPUT:
+    def training_step(self, dataloader_iter: Iterator) -> STEP_OUTPUT:
         output = super().training_step(dataloader_iter)
-        if self.trigger_stop_iteration and args[0] == EXPECT_NUM_BATCHES_PROCESSED:
+        batch_idx = self.trainer.fit_loop.epoch_loop.batch_idx
+        if self.trigger_stop_iteration and batch_idx == EXPECT_NUM_BATCHES_PROCESSED:
             raise StopIteration
         return output