restrict public interface of training loop (#8024)

* active optimizers

* check checkpoint callback

* epoch loop properties

* epoch loop methods

* training_batch_loop

* changelog

* update chlog

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unused imports

* yapf

* backward

* fix missing string reference

* is_last_batch remains public

* remove dead code

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Adrian Wälchli 2021-06-23 12:25:29 +02:00 committed by GitHub
parent a45ab00b30
commit fe48203111
8 changed files with 59 additions and 95 deletions

View File

@@ -147,6 +147,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Removed `pytorch_lightning/trainer/training_loop.py` ([#7985](https://github.com/PyTorchLightning/pytorch-lightning/pull/7985))
 * Refactored evaluation loop interface; added new classes `DataLoaderLoop`, `EvaluationDataLoaderLoop`, `EvaluationEpochLoop` ([#7990](https://github.com/PyTorchLightning/pytorch-lightning/pull/7990))
 * Removed `pytorch_lightning/trainer/evaluation_loop.py` ([#8056](https://github.com/PyTorchLightning/pytorch-lightning/pull/8056))
+* Restricted public access to several internal functions ([#8024](https://github.com/PyTorchLightning/pytorch-lightning/pull/8024))
 * Refactored trainer `_run_*` functions and separate evaluation loops ([#8065](https://github.com/PyTorchLightning/pytorch-lightning/pull/8065))
 * Refactored prediction loop interface; added new classes `PredictionDataLoaderLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700))

View File

@@ -285,7 +285,7 @@ class BaseFinetuning(Callback):
     def on_train_epoch_start(self, trainer, pl_module):
         """Called when the epoch begins."""
-        for opt_idx, optimizer in trainer.fit_loop.get_active_optimizers():
+        for opt_idx, optimizer in trainer.fit_loop.training_loop.batch_loop.get_active_optimizers():
             num_param_groups = len(optimizer.param_groups)
             self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
             current_param_groups = optimizer.param_groups
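The hunk above shows the migration path for code that relied on the removed `FitLoop.get_active_optimizers` wrapper: the call now goes through the training and batch loops explicitly. A minimal sketch of the same migration in a user callback (the callback itself is hypothetical; the access path mirrors the change above):

```python
from pytorch_lightning import Callback


class InspectOptimizers(Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        # before #8024: trainer.fit_loop.get_active_optimizers()
        # after #8024:  go through the training and batch loops explicitly
        for opt_idx, optimizer in trainer.fit_loop.training_loop.batch_loop.get_active_optimizers():
            print(opt_idx, len(optimizer.param_groups))
```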

View File

@@ -14,9 +14,7 @@
 import logging
 from contextlib import suppress
-from typing import Any, List, Optional, Tuple
-from torch.optim import Optimizer
+from typing import Any, Optional

 import pytorch_lightning as pl
 from pytorch_lightning.loops.base import Loop

@@ -230,7 +228,7 @@ class FitLoop(Loop):
         did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.skip
         if did_train_only:
             self.global_step -= 1
-            self.check_checkpoint_callback(True)
+            self._check_checkpoint_callback(True)
             self.global_step += 1

     def on_run_end(self) -> None:

@@ -245,7 +243,7 @@ class FitLoop(Loop):
             # when a checkpoint was saved at the last step
             self.training_loop.global_step -= 1
             # TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406
-            self.check_checkpoint_callback(should_update=True, is_last=True)
+            self._check_checkpoint_callback(should_update=True, is_last=True)
             self.training_loop.global_step += 1

         # hook

@@ -270,11 +268,7 @@ class FitLoop(Loop):
         """Whether the gradients should be accumulated"""
         return self.training_loop.batch_loop.should_accumulate()

-    def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]:
-        """Generates a list of active optimizers"""
-        return self.training_loop.batch_loop.get_active_optimizers(batch_idx)
-
-    def check_checkpoint_callback(self, should_update: bool, is_last: bool = False):
+    def _check_checkpoint_callback(self, should_update: bool, is_last: bool = False):
         """Checks if checkpointing needs to be done"""
         # TODO: bake this logic into the ModelCheckpoint callback
         if should_update and self.trainer.checkpoint_connector.has_trained:
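Both call sites wrap `_check_checkpoint_callback` in a decrement/restore of `global_step` because the counter has already advanced one past the step the checkpoint should be attributed to. A standalone illustration of that rewind-save-restore pattern (all names here are illustrative, not Lightning's API):

```python
from typing import List


class StepCounter:
    def __init__(self) -> None:
        self.global_step = 0
        self.saved_at: List[int] = []

    def save_checkpoint(self) -> None:
        self.saved_at.append(self.global_step)

    def end_of_training(self) -> None:
        # the counter already points one past the last completed step,
        # so rewind it for the save and restore it afterwards
        self.global_step -= 1
        self.save_checkpoint()
        self.global_step += 1


counter = StepCounter()
counter.global_step = 100
counter.end_of_training()
assert counter.saved_at == [99]
```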

View File

@@ -109,7 +109,7 @@ class TrainingBatchLoop(Loop):
             dataloader_idx: the index of the dataloader producing the current batch
         """
         void(batch_idx, dataloader_idx)
-        self._remaining_splits = list(enumerate(self.tbptt_split_batch(batch)))
+        self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))

     def advance(self, batch, batch_idx, dataloader_idx):
         """Runs the train step together with optimization (if necessary) on the current batch split

@@ -157,10 +157,10 @@ class TrainingBatchLoop(Loop):
         # opt_idx=0 to opt_idx=None in the signature here

         # toggle model params
-        self.run_optimization_start(opt_idx, optimizer)
+        self._run_optimization_start(opt_idx, optimizer)

         result = AttributeDict()
-        closure = self.make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result)
+        closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result)

         if self.should_accumulate():
             # For gradient accumulation
@@ -179,24 +179,24 @@ class TrainingBatchLoop(Loop):
         # gradient update with accumulated gradients
         else:
             if self.trainer.lightning_module.automatic_optimization:
-                self.optimizer_step(optimizer, opt_idx, batch_idx, closure)
+                self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
                 if len(self.trainer.optimizers) > 1:
                     # revert back to previous state
                     self.trainer.lightning_module.untoggle_optimizer(opt_idx)
             else:
-                result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)
+                result = self._training_step(split_batch, batch_idx, opt_idx, self._hiddens)

             if not result:
                 # user decided to skip optimization
                 return result

             # update running loss + reset accumulated loss
-            self.update_running_loss(result.loss)
+            self._update_running_loss(result.loss)

         self._process_closure_result(result)
         return result

-    def training_step_and_backward_closure(
+    def _training_step_and_backward_closure(
         self,
         split_batch: Any,
         batch_idx: int,
@@ -221,10 +221,10 @@ class TrainingBatchLoop(Loop):
             return_result.update(result)
         return return_result.loss

-    def make_closure(self, *closure_args: Any, **closure_kwargs: Any) -> Callable:
+    def _make_closure(self, *closure_args: Any, **closure_kwargs: Any) -> Callable:
         """ Wraps the training step closure into a partial object which will be called within ``optimizer.step``. """
-        partial_func = partial(self.training_step_and_backward_closure, *closure_args, **closure_kwargs)
-        return update_wrapper(partial_func, self.training_step_and_backward_closure)
+        partial_func = partial(self._training_step_and_backward_closure, *closure_args, **closure_kwargs)
+        return update_wrapper(partial_func, self._training_step_and_backward_closure)

     def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -> None:
         """Checks if the closure result is finite and optionally breaks if it is not
@@ -239,7 +239,7 @@ class TrainingBatchLoop(Loop):
             if self.trainer.terminate_on_nan:
                 self._check_finite(opt_closure_result.loss)

-    def on_after_backward(self, batch_idx: int, untouched_loss: Tensor) -> None:
+    def _on_after_backward(self, batch_idx: int, untouched_loss: Tensor) -> None:
         """Calls ``on_after_backward`` hook and tracks loss history

         Args:

@@ -276,7 +276,13 @@ class TrainingBatchLoop(Loop):
                 "a dict with key 'loss' or None (where the step will be skipped)."
             )

-    def training_step(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor) -> Optional[AttributeDict]:
+    def _training_step(
+        self,
+        split_batch: Any,
+        batch_idx: int,
+        opt_idx: int,
+        hiddens: Tensor,
+    ) -> Optional[AttributeDict]:
         """Performs the actual train step with the tied hooks.

         Args:
@@ -355,7 +361,7 @@ class TrainingBatchLoop(Loop):
             results.cpu()
         return results

-    def optimizer_step(
+    def _optimizer_step(
         self, optimizer: torch.optim.Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable
     ) -> None:
         """Performs the optimizer step and some sanity checking.

@@ -394,7 +400,7 @@ class TrainingBatchLoop(Loop):
             using_lbfgs=is_lbfgs,
         )

-    def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
+    def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
         """Calls the ``on_before_zero_grad`` hook.

         Args:

@@ -402,7 +408,7 @@ class TrainingBatchLoop(Loop):
         """
         self.trainer.call_hook('on_before_zero_grad', optimizer)

-    def optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
+    def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
         """Zeroes out all gradients of parameters optimized by the current optimizer.

         Args:

@@ -412,7 +418,7 @@ class TrainingBatchLoop(Loop):
         """
         self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

-    def track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]:
+    def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]:
         """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.

         Args:

@@ -452,7 +458,7 @@ class TrainingBatchLoop(Loop):
         is_final_batch = self._num_training_batches_reached()
         return not (accumulation_done or is_final_batch)

-    def tbptt_split_batch(self, batch: Any) -> List[Any]:
+    def _tbptt_split_batch(self, batch: Any) -> List[Any]:
         """Splits a single batch into a list of sequence steps for tbptt.

         Args:
@@ -465,45 +471,7 @@ class TrainingBatchLoop(Loop):
         splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
         return splits

-    def build_train_args(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor) -> List[Any]:
-        """Builds arguments for train step
-
-        Args:
-            batch: the current batch to train on
-            batch_idx: the index of the current batch
-            opt_idx: the index of the current optimizer
-            hiddens: the hidden state of the previous RNN iteration
-
-        Returns:
-            the positional arguments for training
-        """
-        # enable not needing to add opt_idx to training_step
-        args = [batch, batch_idx]
-
-        if len(self.trainer.optimizers) > 1:
-            if self.trainer.has_arg("training_step", "optimizer_idx"):
-                if not self.trainer.lightning_module.automatic_optimization:
-                    self.warning_cache.deprecation(
-                        "`training_step` hook signature has changed in v1.3."
-                        " `optimizer_idx` argument has been removed in case of manual optimization. Support for"
-                        " the old signature will be removed in v1.5"
-                    )
-                args.append(opt_idx)
-            elif not self.trainer.has_arg(
-                "training_step", "optimizer_idx"
-            ) and self.trainer.lightning_module.automatic_optimization:
-                raise ValueError(
-                    f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
-                    ' `training_step` is missing the `optimizer_idx` argument.'
-                )
-
-        # pass hiddens if using tbptt
-        if self.trainer.truncated_bptt_steps is not None:
-            args.append(hiddens)
-
-        return args
-
-    def run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
+    def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
         """Toggles the optimizer to ensure the correct one is used and prevent dangling grads.

         Args:
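For context, `_tbptt_split_batch` defers to the LightningModule's `tbptt_split_batch` hook, which by default chunks the batch along the time dimension. A rough sketch of that default-style splitting, assuming a `(batch, time, features)` tensor layout:

```python
import torch


def tbptt_split_batch(batch: torch.Tensor, split_size: int):
    # default-style splitting: chunk the time dimension (dim=1)
    # into windows of `split_size` steps
    return [batch[:, t:t + split_size] for t in range(0, batch.size(1), split_size)]


x = torch.randn(4, 10, 8)         # (batch, time, features)
splits = tbptt_split_batch(x, 3)  # chunks of 3, 3, 3, 1 time steps
assert [s.size(1) for s in splits] == [3, 3, 3, 1]
```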
@@ -551,14 +519,14 @@ class TrainingBatchLoop(Loop):
         """Wrap forward, zero_grad and backward in a closure so second order methods work"""
         with self.trainer.profiler.profile("training_step_and_backward"):
             # lightning module hook
-            result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
+            result = self._training_step(split_batch, batch_idx, opt_idx, hiddens)

             if not self._skip_backward and self.trainer.lightning_module.automatic_optimization:
                 is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0

                 if is_first_batch_to_accumulate:
-                    self.on_before_zero_grad(optimizer)
-                    self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
+                    self._on_before_zero_grad(optimizer)
+                    self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)

             # backward pass
             if result is not None:

@@ -568,7 +536,7 @@ class TrainingBatchLoop(Loop):
                 # hook - call this hook only
                 # when gradients have finished to accumulate
                 if not self.should_accumulate():
-                    self.on_after_backward(batch_idx, result.loss)
+                    self._on_after_backward(batch_idx, result.loss)

                 # check if loss or model weights are nan
                 if self.trainer.terminate_on_nan:
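The `batch_idx % self.trainer.accumulate_grad_batches == 0` check above decides when gradients are zeroed: only on the first batch of each accumulation window. A standalone illustration of the boundary arithmetic in plain PyTorch (not Lightning's loop code):

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulate_grad_batches = 4

for batch_idx in range(8):
    # zero grads only at the start of each accumulation window
    if batch_idx % accumulate_grad_batches == 0:
        optimizer.zero_grad()

    loss = model(torch.randn(2, 4)).sum()
    loss.backward()  # grads accumulate across the window

    # step once the window is complete
    if (batch_idx + 1) % accumulate_grad_batches == 0:
        optimizer.step()
```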
@@ -616,12 +584,12 @@ class TrainingBatchLoop(Loop):
         if not self.should_accumulate():
             # track gradients
-            grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer)
+            grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
             if grad_norm_dict:
                 self.trainer.lightning_module._current_fx_name = "on_after_backward"
                 self.trainer.lightning_module.log_grad_norm(grad_norm_dict)

-    def update_running_loss(self, current_loss: Tensor) -> None:
+    def _update_running_loss(self, current_loss: Tensor) -> None:
         """Updates the running loss value with the current value"""
         if self.trainer.lightning_module.automatic_optimization:
             # track total loss for logging (avoid mem leaks)

View File

@@ -43,17 +43,15 @@ class TrainingEpochLoop(Loop):
         self.iteration_count: int = 0
         # the current split index when the batch gets split into chunks in truncated backprop through time
         self.split_idx: Optional[int] = None
-        self._dataloader_idx: Optional[int] = None
-        self._should_stop: bool = False
-        self.is_last_batch: Optional[bool] = None
         # the number of batches seen this run, updates immediately after batch_loop.run()
         self.batches_seen: int = 0
-        self.warning_cache: WarningCache = WarningCache()
-        self.epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
+        self.is_last_batch: Optional[bool] = None
         self.batch_loop: Optional[TrainingBatchLoop] = None

+        self._dataloader_idx: Optional[int] = None
+        self._warning_cache: WarningCache = WarningCache()
+        self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
         self._results = ResultCollection(training=True)

     @property
@@ -86,10 +84,9 @@ class TrainingEpochLoop(Loop):
         self.batches_seen = 0
         self.is_last_batch = False
         self._dataloader_idx = 0
-        self._should_stop = False

         # track epoch output
-        self.epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]
+        self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]

     def on_run_start(self, *args: Any, **kwargs: Any) -> None:
         # hook
@@ -137,7 +134,7 @@ class TrainingEpochLoop(Loop):
         self.trainer.logger_connector.on_batch_end()

         # figure out what to track for epoch end
-        self.track_epoch_end_reduce_metrics(self.epoch_output, batch_end_outputs)
+        self._track_epoch_end_reduce_metrics(self._epoch_output, batch_end_outputs)

         # -----------------------------------------
         # SAVE METRICS TO LOGGERS AND PROGRESS_BAR

@@ -153,7 +150,7 @@ class TrainingEpochLoop(Loop):
         # -----------------------------------------
         # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
         # -----------------------------------------
-        should_check_val = self.should_check_val_fx(self.iteration_count, self.is_last_batch)
+        should_check_val = self._should_check_val_fx(self.iteration_count, self.is_last_batch)
         if should_check_val:
             self.trainer.validating = True
             self._run_validation()
@@ -162,7 +159,7 @@ class TrainingEpochLoop(Loop):
         # -----------------------------------------
         # SAVE LOGGERS (ie: Tensorboard, etc...)
         # -----------------------------------------
-        self.save_loggers_on_train_batch_end()
+        self._save_loggers_on_train_batch_end()

         # update plateau LR scheduler after metrics are logged
         self.update_lr_schedulers('step', update_plateau_schedulers=True)

@@ -171,7 +168,7 @@ class TrainingEpochLoop(Loop):
         self.total_batch_idx += 1

         # progress global step according to grads progress
-        self.increment_accumulated_grad_global_step()
+        self._increment_accumulated_grad_global_step()

         if self.done:
             raise StopIteration
@@ -200,7 +197,7 @@ class TrainingEpochLoop(Loop):
         self.trainer.logger_connector.epoch_end_reached()

         # prepare epoch output
-        processed_outputs = self._prepare_outputs(self.epoch_output, batch_mode=False)
+        processed_outputs = self._prepare_outputs(self._epoch_output, batch_mode=False)

         # get the model and call model.training_epoch_end
         model = self.trainer.lightning_module

@@ -223,7 +220,7 @@ class TrainingEpochLoop(Loop):
         self._on_train_epoch_end_hook(processed_outputs)
         self.trainer.call_hook('on_epoch_end')
         self.trainer.logger_connector.on_epoch_end()
-        return self.epoch_output
+        return self._epoch_output

     def teardown(self) -> None:
         """Frees memory of tracked epoch outputs."""
@@ -253,7 +250,7 @@ class TrainingEpochLoop(Loop):
         if is_overridden(hook_name, model_ref):
             hook_fx = getattr(model_ref, hook_name)
             if is_param_in_hook_signature(hook_fx, "outputs"):
-                self.warning_cache.deprecation(
+                self._warning_cache.deprecation(
                     "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."
                     " `outputs` parameter has been deprecated."
                     " Support for the old signature will be removed in v1.5",

@@ -276,7 +273,7 @@ class TrainingEpochLoop(Loop):
         # TODO: Can we combine this with training_batch_loop's arg that does a similar check?
         return self.batches_seen == self.trainer.num_training_batches or is_last_batch

-    def track_epoch_end_reduce_metrics(
+    def _track_epoch_end_reduce_metrics(
         self, epoch_output: List[List[STEP_OUTPUT]], batch_end_outputs: STEP_OUTPUT
     ) -> None:
         """Adds the batch outputs to the epoch outputs and prepares reduction"""
@@ -379,7 +376,7 @@ class TrainingEpochLoop(Loop):
             opt_indices=[opt_idx for opt_idx, _ in self.batch_loop.get_active_optimizers(self.total_batch_idx)],
         )

-    def increment_accumulated_grad_global_step(self) -> None:
+    def _increment_accumulated_grad_global_step(self) -> None:
         """increments global step"""
         num_accumulated_batches_reached = self.batch_loop._accumulated_batches_reached()
         num_training_batches_reached = self._num_training_batches_reached()

@@ -390,7 +387,7 @@ class TrainingEpochLoop(Loop):
                 self.total_batch_idx, self.trainer.global_step
             )

-    def should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
+    def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
         """ Decide if we should run validation. """
         if not self.trainer.enable_validation:
             return False
@@ -415,7 +412,7 @@ class TrainingEpochLoop(Loop):
         is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
         return is_val_check_batch

-    def save_loggers_on_train_batch_end(self) -> None:
+    def _save_loggers_on_train_batch_end(self) -> None:
         """Flushes loggers to disk"""
         # when loggers should save to disk
         should_flush_logs = self.trainer.logger_connector.should_flush_logs
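`_should_check_val_fx` ultimately reduces to modular arithmetic on the batch index, as in the `(batch_idx + 1) % self.trainer.val_check_batch == 0` line above. A simplified toy version showing which batches trigger validation (the real method also consults flags such as `enable_validation`; `val_check_batch` is assumed to be already resolved from `val_check_interval`):

```python
def should_check_val(batch_idx: int, val_check_batch: int, is_last_batch: bool = False) -> bool:
    # validate every `val_check_batch` batches, and on the last batch
    return is_last_batch or (batch_idx + 1) % val_check_batch == 0


# with val_check_batch=50, batches 49, 99, 149 trigger validation
assert [i for i in range(150) if should_check_val(i, 50)] == [49, 99, 149]
```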

View File

@@ -522,6 +522,10 @@ class TrainerProperties(ABC):
     def min_steps(self) -> Optional[int]:
         return self.fit_loop.min_steps

+    @property
+    def is_last_batch(self) -> bool:
+        return self.fit_loop.training_loop.is_last_batch
+
     @property
     def _active_loop(self) -> Optional[Union[FitLoop, EvaluationDataLoaderLoop]]:
         if self.training:
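Because `is_last_batch` remains public (see the commit message), this new `Trainer` property gives user code a stable access path without reaching into loop internals. A hedged usage sketch; the callback is hypothetical, only the `trainer.is_last_batch` property comes from this diff:

```python
import pytorch_lightning as pl


class LastBatchLogger(pl.Callback):
    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        # the trainer facade forwards to fit_loop.training_loop.is_last_batch
        if trainer.is_last_batch:
            print(f"epoch {trainer.current_epoch}: reached the last batch")
```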

View File

@@ -242,7 +242,7 @@ def test_lightning_optimizer_automatic_optimization_optimizer_step(tmpdir):
            ...

        def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **_):
-            assert optimizer_closure.__name__ == "training_step_and_backward_closure"
+            assert optimizer_closure.__name__ == "_training_step_and_backward_closure"
            # not passing the closure to the optimizer because step is mocked
            # zero_grad is called inside the closure
            if isinstance(optimizer, SGD) and batch_idx % 2 == 0:

View File

@@ -244,7 +244,7 @@ def test_v1_5_0_old_on_train_epoch_end(tmpdir):
     with pytest.deprecated_call(match="old signature will be removed in v1.5"):
         trainer.fit(model)

-    trainer.fit_loop.training_loop.warning_cache.clear()
+    trainer.fit_loop.training_loop._warning_cache.clear()

     class NewSignature(Callback):