restrict public interface of training loop (#8024)

* active optimizers
* check checkpoint callback
* epoch loop properties
* epoch loop methods
* training_batch_loop
* changelog
* update chlog
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* unused imports
* yapf
* backward
* fix missing string reference
* is_last_batch remains public
* remove dead code

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Parent: a45ab00b30
Commit: fe48203111
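In practice the PR does three things: it underscore-prefixes loop methods that are implementation details (`check_checkpoint_callback` → `_check_checkpoint_callback`, `optimizer_step` → `_optimizer_step`, and so on), it deletes the `FitLoop.get_active_optimizers` convenience wrapper along with the dead `build_train_args` helper, and it keeps `is_last_batch` public by exposing it as a `Trainer` property. Code that enumerated active optimizers through the old wrapper must now go through the batch loop, as the `BaseFinetuning` hunk below shows. A minimal sketch of the migrated call (the `active_optimizers` helper name is ours, not part of the PR):

```python
from typing import List, Optional, Tuple

import pytorch_lightning as pl
from torch.optim import Optimizer


def active_optimizers(trainer: "pl.Trainer", batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]:
    """Enumerate (opt_idx, optimizer) pairs the way BaseFinetuning does after this PR.

    Before #8024 this was available directly as ``trainer.fit_loop.get_active_optimizers()``.
    """
    return trainer.fit_loop.training_loop.batch_loop.get_active_optimizers(batch_idx)
```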
@@ -147,6 +147,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Removed `pytorch_lightning/trainer/training_loop.py` ([#7985](https://github.com/PyTorchLightning/pytorch-lightning/pull/7985))
 * Refactored evaluation loop interface; added new classes `DataLoaderLoop`, `EvaluationDataLoaderLoop`, `EvaluationEpochLoop` ([#7990](https://github.com/PyTorchLightning/pytorch-lightning/pull/7990))
 * Removed `pytorch_lightning/trainer/evaluation_loop.py` ([#8056](https://github.com/PyTorchLightning/pytorch-lightning/pull/8056))
+* Restricted public access to several internal functions ([#8024](https://github.com/PyTorchLightning/pytorch-lightning/pull/8024))
 * Refactored trainer `_run_*` functions and separate evaluation loops ([#8065](https://github.com/PyTorchLightning/pytorch-lightning/pull/8065))
 * Refactored prediction loop interface; added new classes `PredictionDataLoaderLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700))
@@ -285,7 +285,7 @@ class BaseFinetuning(Callback):

     def on_train_epoch_start(self, trainer, pl_module):
         """Called when the epoch begins."""
-        for opt_idx, optimizer in trainer.fit_loop.get_active_optimizers():
+        for opt_idx, optimizer in trainer.fit_loop.training_loop.batch_loop.get_active_optimizers():
             num_param_groups = len(optimizer.param_groups)
             self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
             current_param_groups = optimizer.param_groups
@@ -14,9 +14,7 @@
 import logging
 from contextlib import suppress
-from typing import Any, List, Optional, Tuple
+from typing import Any, Optional

-from torch.optim import Optimizer
-
 import pytorch_lightning as pl
 from pytorch_lightning.loops.base import Loop
@@ -230,7 +228,7 @@ class FitLoop(Loop):
         did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.skip
         if did_train_only:
             self.global_step -= 1
-            self.check_checkpoint_callback(True)
+            self._check_checkpoint_callback(True)
             self.global_step += 1

     def on_run_end(self) -> None:
@@ -245,7 +243,7 @@ class FitLoop(Loop):
             # when a checkpoint was saved at the last step
             self.training_loop.global_step -= 1
             # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406
-            self.check_checkpoint_callback(should_update=True, is_last=True)
+            self._check_checkpoint_callback(should_update=True, is_last=True)
             self.training_loop.global_step += 1

         # hook
@@ -270,11 +268,7 @@ class FitLoop(Loop):
         """Whether the gradients should be accumulated"""
         return self.training_loop.batch_loop.should_accumulate()

-    def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]:
-        """Generates a list of active optimizers"""
-        return self.training_loop.batch_loop.get_active_optimizers(batch_idx)
-
-    def check_checkpoint_callback(self, should_update: bool, is_last: bool = False):
+    def _check_checkpoint_callback(self, should_update: bool, is_last: bool = False):
         """Checks if checkpointing needs to be done"""
         # TODO: bake this logic into the ModelCheckpoint callback
         if should_update and self.trainer.checkpoint_connector.has_trained:
@@ -109,7 +109,7 @@ class TrainingBatchLoop(Loop):
             dataloader_idx: the index of the dataloader producing the current batch
         """
         void(batch_idx, dataloader_idx)
-        self._remaining_splits = list(enumerate(self.tbptt_split_batch(batch)))
+        self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))

     def advance(self, batch, batch_idx, dataloader_idx):
         """Runs the train step together with optimization (if necessary) on the current batch split
@@ -157,10 +157,10 @@ class TrainingBatchLoop(Loop):
         # opt_idx=0 to opt_idx=None in the signature here

         # toggle model params
-        self.run_optimization_start(opt_idx, optimizer)
+        self._run_optimization_start(opt_idx, optimizer)

         result = AttributeDict()
-        closure = self.make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result)
+        closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result)

         if self.should_accumulate():
             # For gradient accumulation
@@ -179,24 +179,24 @@ class TrainingBatchLoop(Loop):
         # gradient update with accumulated gradients
         else:
             if self.trainer.lightning_module.automatic_optimization:
-                self.optimizer_step(optimizer, opt_idx, batch_idx, closure)
+                self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
                 if len(self.trainer.optimizers) > 1:
                     # revert back to previous state
                     self.trainer.lightning_module.untoggle_optimizer(opt_idx)
             else:
-                result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)
+                result = self._training_step(split_batch, batch_idx, opt_idx, self._hiddens)

             if not result:
                 # user decided to skip optimization
                 return result

             # update running loss + reset accumulated loss
-            self.update_running_loss(result.loss)
+            self._update_running_loss(result.loss)

         self._process_closure_result(result)
         return result

-    def training_step_and_backward_closure(
+    def _training_step_and_backward_closure(
         self,
         split_batch: Any,
         batch_idx: int,
@@ -221,10 +221,10 @@ class TrainingBatchLoop(Loop):
         return_result.update(result)
         return return_result.loss

-    def make_closure(self, *closure_args: Any, **closure_kwargs: Any) -> Callable:
+    def _make_closure(self, *closure_args: Any, **closure_kwargs: Any) -> Callable:
         """ Wraps the training step closure into a partial object which will be called within ``optimizer.step``. """
-        partial_func = partial(self.training_step_and_backward_closure, *closure_args, **closure_kwargs)
-        return update_wrapper(partial_func, self.training_step_and_backward_closure)
+        partial_func = partial(self._training_step_and_backward_closure, *closure_args, **closure_kwargs)
+        return update_wrapper(partial_func, self._training_step_and_backward_closure)

     def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -> None:
         """Checks if the closure result is finite and optionally breaks if it is not
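The closure rename above is visible to user code that inspects `optimizer_closure.__name__` inside `LightningModule.optimizer_step`, which is why the test further down updates its assertion. The mechanism is plain `functools`: `update_wrapper` copies the wrapped function's metadata, including `__name__`, onto the `partial` object. A standalone sketch (names mirror the diff; none of Lightning's machinery is involved):

```python
from functools import partial, update_wrapper


def _training_step_and_backward_closure(split_batch, batch_idx):
    """Stand-in for the real closure body."""


# Bind arguments now; the result is called later, e.g. inside ``optimizer.step(closure)``.
closure = partial(_training_step_and_backward_closure, "split_batch", 0)
# Copy ``__name__``, ``__doc__``, etc. from the wrapped function onto the partial.
closure = update_wrapper(closure, _training_step_and_backward_closure)

assert closure.__name__ == "_training_step_and_backward_closure"
```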
@@ -239,7 +239,7 @@ class TrainingBatchLoop(Loop):
         if self.trainer.terminate_on_nan:
             self._check_finite(opt_closure_result.loss)

-    def on_after_backward(self, batch_idx: int, untouched_loss: Tensor) -> None:
+    def _on_after_backward(self, batch_idx: int, untouched_loss: Tensor) -> None:
         """Calls ``on_after_backward`` hook and tracks loss history

         Args:
@@ -276,7 +276,13 @@ class TrainingBatchLoop(Loop):
                 "a dict with key 'loss' or None (where the step will be skipped)."
             )

-    def training_step(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor) -> Optional[AttributeDict]:
+    def _training_step(
+        self,
+        split_batch: Any,
+        batch_idx: int,
+        opt_idx: int,
+        hiddens: Tensor,
+    ) -> Optional[AttributeDict]:
         """Performs the actual train step with the tied hooks.

         Args:
@@ -355,7 +361,7 @@ class TrainingBatchLoop(Loop):
             results.cpu()
         return results

-    def optimizer_step(
+    def _optimizer_step(
         self, optimizer: torch.optim.Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable
     ) -> None:
         """Performs the optimizer step and some sanity checking.
@@ -394,7 +400,7 @@ class TrainingBatchLoop(Loop):
             using_lbfgs=is_lbfgs,
         )

-    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
+    def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
         """Calls the ``on_before_zero_grad`` hook.

         Args:
@@ -402,7 +408,7 @@ class TrainingBatchLoop(Loop):
         """
         self.trainer.call_hook('on_before_zero_grad', optimizer)

-    def optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
+    def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
         """Zeroes out all gradients of parameters optimized by the current optimizer.

         Args:
@@ -412,7 +418,7 @@ class TrainingBatchLoop(Loop):
         """
         self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

-    def track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]:
+    def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]:
         """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.

         Args:
@@ -452,7 +458,7 @@ class TrainingBatchLoop(Loop):
         is_final_batch = self._num_training_batches_reached()
         return not (accumulation_done or is_final_batch)

-    def tbptt_split_batch(self, batch: Any) -> List[Any]:
+    def _tbptt_split_batch(self, batch: Any) -> List[Any]:
         """Splits a single batch into a list of sequence steps for tbptt.

         Args:
@@ -465,45 +471,7 @@ class TrainingBatchLoop(Loop):
         splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
         return splits

-    def build_train_args(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor) -> List[Any]:
-        """Builds arguments for train step
-
-        Args:
-            batch: the current batch to train on
-            batch_idx: the index of the current batch
-            opt_idx: the index of the current optimizer
-            hiddens: the hidden state of the previous RNN iteration
-
-        Returns:
-            the positional arguments for training
-        """
-        # enable not needing to add opt_idx to training_step
-        args = [batch, batch_idx]
-
-        if len(self.trainer.optimizers) > 1:
-            if self.trainer.has_arg("training_step", "optimizer_idx"):
-                if not self.trainer.lightning_module.automatic_optimization:
-                    self.warning_cache.deprecation(
-                        "`training_step` hook signature has changed in v1.3."
-                        " `optimizer_idx` argument has been removed in case of manual optimization. Support for"
-                        " the old signature will be removed in v1.5"
-                    )
-                args.append(opt_idx)
-            elif not self.trainer.has_arg(
-                "training_step", "optimizer_idx"
-            ) and self.trainer.lightning_module.automatic_optimization:
-                raise ValueError(
-                    f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
-                    ' `training_step` is missing the `optimizer_idx` argument.'
-                )
-
-        # pass hiddens if using tbptt
-        if self.trainer.truncated_bptt_steps is not None:
-            args.append(hiddens)
-
-        return args
-
-    def run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
+    def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
         """Toggles the optimizer to ensure the correct one is used and prevent dangling grads.

         Args:
@@ -551,14 +519,14 @@ class TrainingBatchLoop(Loop):
         """Wrap forward, zero_grad and backward in a closure so second order methods work"""
         with self.trainer.profiler.profile("training_step_and_backward"):
             # lightning module hook
-            result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
+            result = self._training_step(split_batch, batch_idx, opt_idx, hiddens)

             if not self._skip_backward and self.trainer.lightning_module.automatic_optimization:
                 is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0

                 if is_first_batch_to_accumulate:
-                    self.on_before_zero_grad(optimizer)
-                    self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
+                    self._on_before_zero_grad(optimizer)
+                    self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)

                 # backward pass
                 if result is not None:
@@ -568,7 +536,7 @@ class TrainingBatchLoop(Loop):
                     # hook - call this hook only
                     # when gradients have finished to accumulate
                     if not self.should_accumulate():
-                        self.on_after_backward(batch_idx, result.loss)
+                        self._on_after_backward(batch_idx, result.loss)

                     # check if loss or model weights are nan
                     if self.trainer.terminate_on_nan:
@@ -616,12 +584,12 @@ class TrainingBatchLoop(Loop):

         if not self.should_accumulate():
             # track gradients
-            grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer)
+            grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
             if grad_norm_dict:
                 self.trainer.lightning_module._current_fx_name = "on_after_backward"
                 self.trainer.lightning_module.log_grad_norm(grad_norm_dict)

-    def update_running_loss(self, current_loss: Tensor) -> None:
+    def _update_running_loss(self, current_loss: Tensor) -> None:
         """Updates the running loss value with the current value"""
         if self.trainer.lightning_module.automatic_optimization:
             # track total loss for logging (avoid mem leaks)
@@ -43,17 +43,15 @@ class TrainingEpochLoop(Loop):
         self.iteration_count: int = 0
         # the current split index when the batch gets split into chunks in truncated backprop through time
         self.split_idx: Optional[int] = None
-
-        self._dataloader_idx: Optional[int] = None
-        self._should_stop: bool = False
-
-        self.is_last_batch: Optional[bool] = None
         # the number of batches seen this run, updates immediately after batch_loop.run()
         self.batches_seen: int = 0
-        self.warning_cache: WarningCache = WarningCache()
-        self.epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
+        self.is_last_batch: Optional[bool] = None

         self.batch_loop: Optional[TrainingBatchLoop] = None

+        self._dataloader_idx: Optional[int] = None
+        self._warning_cache: WarningCache = WarningCache()
+        self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
         self._results = ResultCollection(training=True)

     @property
@@ -86,10 +84,9 @@ class TrainingEpochLoop(Loop):
         self.batches_seen = 0
         self.is_last_batch = False
         self._dataloader_idx = 0
-        self._should_stop = False

         # track epoch output
-        self.epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]
+        self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]

     def on_run_start(self, *args: Any, **kwargs: Any) -> None:
         # hook
@@ -137,7 +134,7 @@ class TrainingEpochLoop(Loop):
         self.trainer.logger_connector.on_batch_end()

         # figure out what to track for epoch end
-        self.track_epoch_end_reduce_metrics(self.epoch_output, batch_end_outputs)
+        self._track_epoch_end_reduce_metrics(self._epoch_output, batch_end_outputs)

         # -----------------------------------------
         # SAVE METRICS TO LOGGERS AND PROGRESS_BAR
@@ -153,7 +150,7 @@ class TrainingEpochLoop(Loop):
         # -----------------------------------------
         # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
         # -----------------------------------------
-        should_check_val = self.should_check_val_fx(self.iteration_count, self.is_last_batch)
+        should_check_val = self._should_check_val_fx(self.iteration_count, self.is_last_batch)
         if should_check_val:
             self.trainer.validating = True
             self._run_validation()
@@ -162,7 +159,7 @@ class TrainingEpochLoop(Loop):
         # -----------------------------------------
         # SAVE LOGGERS (ie: Tensorboard, etc...)
         # -----------------------------------------
-        self.save_loggers_on_train_batch_end()
+        self._save_loggers_on_train_batch_end()

         # update plateau LR scheduler after metrics are logged
         self.update_lr_schedulers('step', update_plateau_schedulers=True)
@@ -171,7 +168,7 @@ class TrainingEpochLoop(Loop):
         self.total_batch_idx += 1

         # progress global step according to grads progress
-        self.increment_accumulated_grad_global_step()
+        self._increment_accumulated_grad_global_step()

         if self.done:
             raise StopIteration
@@ -200,7 +197,7 @@ class TrainingEpochLoop(Loop):
         self.trainer.logger_connector.epoch_end_reached()

         # prepare epoch output
-        processed_outputs = self._prepare_outputs(self.epoch_output, batch_mode=False)
+        processed_outputs = self._prepare_outputs(self._epoch_output, batch_mode=False)

         # get the model and call model.training_epoch_end
         model = self.trainer.lightning_module
@@ -223,7 +220,7 @@ class TrainingEpochLoop(Loop):
         self._on_train_epoch_end_hook(processed_outputs)
         self.trainer.call_hook('on_epoch_end')
         self.trainer.logger_connector.on_epoch_end()
-        return self.epoch_output
+        return self._epoch_output

     def teardown(self) -> None:
         """Frees memory of tracked epoch outputs."""
@@ -253,7 +250,7 @@ class TrainingEpochLoop(Loop):
         if is_overridden(hook_name, model_ref):
             hook_fx = getattr(model_ref, hook_name)
             if is_param_in_hook_signature(hook_fx, "outputs"):
-                self.warning_cache.deprecation(
+                self._warning_cache.deprecation(
                     "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."
                     " `outputs` parameter has been deprecated."
                     " Support for the old signature will be removed in v1.5",
@@ -276,7 +273,7 @@ class TrainingEpochLoop(Loop):
         # TODO: Can we combine this with training_batch_loop's arg that does a similar check?
         return self.batches_seen == self.trainer.num_training_batches or is_last_batch

-    def track_epoch_end_reduce_metrics(
+    def _track_epoch_end_reduce_metrics(
         self, epoch_output: List[List[STEP_OUTPUT]], batch_end_outputs: STEP_OUTPUT
     ) -> None:
         """Adds the batch outputs to the epoch outputs and prepares reduction"""
@@ -379,7 +376,7 @@ class TrainingEpochLoop(Loop):
             opt_indices=[opt_idx for opt_idx, _ in self.batch_loop.get_active_optimizers(self.total_batch_idx)],
         )

-    def increment_accumulated_grad_global_step(self) -> None:
+    def _increment_accumulated_grad_global_step(self) -> None:
         """increments global step"""
         num_accumulated_batches_reached = self.batch_loop._accumulated_batches_reached()
         num_training_batches_reached = self._num_training_batches_reached()
@@ -390,7 +387,7 @@ class TrainingEpochLoop(Loop):
             self.total_batch_idx, self.trainer.global_step
         )

-    def should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
+    def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         """ Decide if we should run validation. """
         if not self.trainer.enable_validation:
             return False
@@ -415,7 +412,7 @@ class TrainingEpochLoop(Loop):
         is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
         return is_val_check_batch

-    def save_loggers_on_train_batch_end(self) -> None:
+    def _save_loggers_on_train_batch_end(self) -> None:
         """Flushes loggers to disk"""
        # when loggers should save to disk
         should_flush_logs = self.trainer.logger_connector.should_flush_logs
@@ -522,6 +522,10 @@ class TrainerProperties(ABC):
     def min_steps(self) -> Optional[int]:
         return self.fit_loop.min_steps

+    @property
+    def is_last_batch(self) -> bool:
+        return self.fit_loop.training_loop.is_last_batch
+
     @property
     def _active_loop(self) -> Optional[Union[FitLoop, EvaluationDataLoaderLoop]]:
         if self.training:
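With the loop attributes now private, this new read-only property is the sanctioned way to ask whether the current batch is the epoch's last; the commit message calls it out ("is_last_batch remains public"). A hypothetical callback using it (the `LogOnLastBatch` class is illustrative, not from the PR):

```python
import logging

from pytorch_lightning import Callback

log = logging.getLogger(__name__)


class LogOnLastBatch(Callback):
    """Hypothetical callback relying on the public property added here."""

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        # Goes through ``Trainer.is_last_batch`` instead of reaching into
        # ``trainer.fit_loop.training_loop.is_last_batch`` directly.
        if trainer.is_last_batch:
            log.info("processed the last batch of the epoch")
```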
@@ -242,7 +242,7 @@ def test_lightning_optimizer_automatic_optimization_optimizer_step(tmpdir):
             ...

         def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **_):
-            assert optimizer_closure.__name__ == "training_step_and_backward_closure"
+            assert optimizer_closure.__name__ == "_training_step_and_backward_closure"
             # not passing the closure to the optimizer because step is mocked
             # zero_grad is called inside the closure
             if isinstance(optimizer, SGD) and batch_idx % 2 == 0:
@@ -244,7 +244,7 @@ def test_v1_5_0_old_on_train_epoch_end(tmpdir):
     with pytest.deprecated_call(match="old signature will be removed in v1.5"):
         trainer.fit(model)

-    trainer.fit_loop.training_loop.warning_cache.clear()
+    trainer.fit_loop.training_loop._warning_cache.clear()

     class NewSignature(Callback):