extract optimizer loop (#9191)

Adrian Wälchli 2021-09-02 13:40:05 +02:00 committed by GitHub
parent c86e6cf1c4
commit 75350938ca
11 changed files with 431 additions and 232 deletions

View File

@ -65,7 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Loop customization:
* Added `Closure` and `AbstractClosure` classes ([#8642](https://github.com/PyTorchLightning/pytorch-lightning/pull/8642))
* Refactored `TrainingBatchLoop` and extracted `OptimizerLoop`, splitting off automatic optimization into its own loop ([#9191](https://github.com/PyTorchLightning/pytorch-lightning/pull/9191))
- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))
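To make the extracted loop concrete, here is a hedged usage sketch of the loop-customization API these entries describe: a custom optimizer loop is attached to the training batch loop through its new `connect(optimizer_loop=...)` hook. The `MyOptimizerLoop` subclass and the `print` call are illustrative assumptions, not code from this commit, and the exact attribute path to the batch loop is inferred from the loop structure shown in the diff below.

# Sketch only: assumes a pytorch_lightning build that contains #8642 and #9191.
import pytorch_lightning as pl
from pytorch_lightning.loops import OptimizerLoop

class MyOptimizerLoop(OptimizerLoop):
    """Illustrative subclass: report which optimizer is about to step, then delegate."""

    def advance(self, batch, hiddens, *args, **kwargs):
        print(f"advancing optimizer {self.optim_progress.optimizer_idx}")
        super().advance(batch, hiddens, *args, **kwargs)

trainer = pl.Trainer(max_epochs=1)
# After this commit, TrainingBatchLoop.connect accepts exactly one child: the optimizer loop.
trainer.fit_loop.epoch_loop.batch_loop.connect(optimizer_loop=MyOptimizerLoop())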

View File

@ -17,3 +17,4 @@ from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop # noqa: F401

View File

@ -13,7 +13,7 @@
# limitations under the License.
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple
import numpy as np
import torch
@ -21,21 +21,16 @@ from deprecate import void
from torch import Tensor
from torch.optim import Optimizer
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.closure import Closure, ClosureResult
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop
from pytorch_lightning.loops.utilities import (
_block_parallel_sync_behavior,
_build_training_step_kwargs,
_check_training_step_output,
_process_training_step_output,
check_finite_loss,
)
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache
@ -50,13 +45,12 @@ class TrainingBatchLoop(Loop):
self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20)
# the current split index when the batch gets split into chunks in truncated backprop through time
self.split_idx: Optional[int] = None
self.optim_progress = OptimizationProgress()
self.optimizer_loop = OptimizerLoop()
self._warning_cache: WarningCache = WarningCache()
self._hiddens: Optional[Tensor] = None
self._optimizer_freq_cumsum: Optional[int] = None
self._remaining_splits: Optional[List[Any]] = None
self._skip_backward: bool = False
@property
def done(self) -> bool:
@ -70,8 +64,8 @@ class TrainingBatchLoop(Loop):
self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
return self._optimizer_freq_cumsum
def connect(self, **kwargs: "Loop") -> None:
raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
def connect(self, optimizer_loop: "Loop") -> None:
self.optimizer_loop = optimizer_loop
def run(self, batch: Any, batch_idx: int) -> AttributeDict:
"""Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks
@ -132,17 +126,12 @@ class TrainingBatchLoop(Loop):
self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch)
if self.trainer.lightning_module.automatic_optimization:
for opt_idx, optimizer in self.get_active_optimizers(batch_idx):
# handle optimization restart
if self.restarting:
if opt_idx < self.optim_progress.optimizer_idx:
continue
self.optim_progress.optimizer_idx = opt_idx
result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
if result:
self.batch_outputs[opt_idx].append(deepcopy(result.result_collection))
# in automatic optimization, hand over execution to the OptimizerLoop
optimizers = [optimizer for _, optimizer in self.get_active_optimizers(batch_idx)]
batch_outputs, self._hiddens = self.optimizer_loop.run(split_batch, self._hiddens, optimizers, batch_idx)
# combine outputs from each optimizer
for k in range(len(batch_outputs)):
self.batch_outputs[k].extend(batch_outputs[k])
else:
# in manual optimization, there is no looping over optimizers
result = self._run_optimization(batch_idx, split_batch)
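To illustrate the handover in the hunk above outside of Lightning, here is a minimal self-contained sketch of the pattern: an inner loop runs once per optimizer and returns one output list per optimizer, and the caller merges those lists into its own per-optimizer accumulators. `ToyOptimizerLoop` and `run_batch` are invented names for illustration, not Lightning APIs.

from typing import Any, List, Tuple

class ToyOptimizerLoop:
    # Minimal stand-in for the extracted loop: iterate over the optimizers and
    # return one list of results per optimizer plus the carried-over hiddens.
    def run(self, split_batch: Any, hiddens: Any, optimizers: List[Any], batch_idx: int) -> Tuple[List[List[Any]], Any]:
        outputs: List[List[Any]] = [[] for _ in optimizers]
        for opt_idx, _optimizer in enumerate(optimizers):
            result = f"step(batch={batch_idx}, opt={opt_idx})"  # placeholder for closure + optimizer step
            outputs[opt_idx].append(result)
        return outputs, hiddens

def run_batch(split_batch: Any, batch_idx: int, optimizers: List[Any]) -> List[List[Any]]:
    # Mirrors the batch loop's side of the handover: delegate to the optimizer
    # loop, then extend the per-optimizer accumulators with what it produced.
    batch_outputs: List[List[Any]] = [[] for _ in optimizers]
    loop_outputs, _hiddens = ToyOptimizerLoop().run(split_batch, None, optimizers, batch_idx)
    for k in range(len(loop_outputs)):
        batch_outputs[k].extend(loop_outputs[k])
    return batch_outputs

print(run_batch(split_batch=None, batch_idx=0, optimizers=["sgd", "adam"]))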
@ -161,42 +150,16 @@ class TrainingBatchLoop(Loop):
self,
batch_idx: int,
split_batch: Any,
opt_idx: Optional[int] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
) -> Optional[ClosureResult]:
"""Runs closure (train step + backward) together with optimization if necessary.
Args:
batch_idx: the index of the current batch
split_batch: the current tbptt split of the whole batch
opt_idx: the index of the current optimizer or `None` in case of manual optimization
optimizer: the current optimizer or `None` in case of manual optimization
"""
# toggle model params
self._run_optimization_start(opt_idx, optimizer)
closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens)
if self.trainer.fit_loop.should_accumulate():
# For gradient accumulation
# -------------------
# calculate loss (train step + train step end)
# -------------------
# automatic_optimization: perform ddp sync only when performing optimizer_step
with _block_parallel_sync_behavior(self._trainer):
closure()
# ------------------------------
# BACKWARD PASS
# ------------------------------
# gradient update with accumulated gradients
else:
if self.trainer.lightning_module.automatic_optimization:
self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
else:
closure()
# TODO: replace call through closure by direct call (manual optimization)
closure = self._make_closure(split_batch, batch_idx, self._hiddens)
closure()
result = closure.get_result()
if result:
@ -204,25 +167,21 @@ class TrainingBatchLoop(Loop):
# otherwise update running loss + reset accumulated loss
self._update_running_loss(result.loss)
# untoggle model params
self._run_optimization_end(opt_idx)
return result
def _make_closure(
self,
split_batch: Any,
batch_idx: int,
opt_idx: int,
optimizer: Optimizer,
hiddens: Any,
) -> Closure:
"""
Build a closure object that captures the given arguments and runs the `training_step` function and optionally
other functions such as `backward` and `zero_grad`.
"""
step_fn = self._make_step_fn(split_batch, batch_idx, opt_idx, hiddens)
backward_fn = self._make_backward_fn(optimizer, opt_idx)
zero_grad_fn = self._make_zero_grad_fn(batch_idx, opt_idx, optimizer)
step_fn = self._make_step_fn(split_batch, batch_idx, hiddens)
backward_fn = None
zero_grad_fn = None
return Closure(
step_fn=step_fn,
@ -231,66 +190,27 @@ class TrainingBatchLoop(Loop):
profiler=self.trainer.profiler,
)
def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Any) -> Callable[[], dict]:
def _make_step_fn(self, split_batch: Any, batch_idx: int, hiddens: Any) -> Callable[[], dict]:
"""Build the step function that runs the `training_step` and processes its output."""
return partial(self._training_step, split_batch, batch_idx, opt_idx, hiddens)
return partial(self._training_step, split_batch, batch_idx, hiddens)
def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]:
"""
Build a `zero_grad` function that zeroes the gradients before back-propagation.
Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on.
"""
def zero_grad_fn():
self._on_before_zero_grad(optimizer)
self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)
is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0
if (
not self._skip_backward
and self.trainer.lightning_module.automatic_optimization
and is_first_batch_to_accumulate
):
return zero_grad_fn
def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Callable[[Tensor], Tensor]]:
"""
Build a `backward` function that handles back-propagation through the output produced by the `training_step`
function. Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on.
"""
def backward_fn(loss: Tensor):
self.backward(loss, optimizer, opt_idx)
# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
check_finite_loss(self.trainer.lightning_module, loss)
return loss
if not self._skip_backward and self.trainer.lightning_module.automatic_optimization:
return backward_fn
def _training_step(
self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor
) -> Optional[AttributeDict]:
"""Performs the actual train step with the tied hooks.
def _training_step(self, split_batch: Any, batch_idx: int, hiddens: Tensor) -> Optional[AttributeDict]:
"""Performs the training step for manual optimization.
Args:
split_batch: the current tbptt split of the current batch
batch_idx: the index of the current batch
opt_idx: the index of the current optimizer
hiddens: the model's hidden state of the previous iteration
Returns:
an AttributeDict containing the loss value and the training step output.
an AttributeDict containing the training step output.
"""
# give the PL module a result for logging
model_ref = self.trainer.lightning_module
with self.trainer.profiler.profile("model_forward"):
step_kwargs = _build_training_step_kwargs(
model_ref, self.trainer.optimizers, split_batch, batch_idx, opt_idx, hiddens
model_ref, self.trainer.optimizers, split_batch, batch_idx, opt_idx=None, hiddens=hiddens
)
# manually capture logged metrics
@ -309,97 +229,7 @@ class TrainingBatchLoop(Loop):
if result_collection is None:
return
closure_loss = None
loss = None
if self.trainer.lightning_module.automatic_optimization:
# accumulate loss. if accumulate_grad_batches==1, no effect
closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches
# the loss will get scaled for amp. avoid any modifications to it
loss = closure_loss.detach().clone()
return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection)
def _optimizer_step(
self, optimizer: torch.optim.Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable
) -> None:
"""Performs the optimizer step and some sanity checking.
Args:
optimizer: the optimizer to perform the step with
opt_idx: the index of the current :param:`optimizer`
batch_idx: the index of the current batch
train_step_and_backward_closure: the closure function performing the train step and computing the
gradients. By default called by the optimizer (if possible)
"""
model_ref = self.trainer.lightning_module
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
using_native_amp = self.trainer.amp_backend == AMPType.NATIVE
# native amp + lbfgs is a no go right now
if using_native_amp and is_lbfgs:
raise MisconfigurationException(
"native PyTorch amp and lbfgs are not compatible."
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)
# wraps into LightningOptimizer only for running step
optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)
self.optim_progress.optimizer.step.increment_ready()
# model hook
model_ref.optimizer_step(
self.trainer.current_epoch,
batch_idx,
optimizer,
opt_idx,
train_step_and_backward_closure,
on_tpu=(self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE),
using_native_amp=using_native_amp,
using_lbfgs=is_lbfgs,
)
self.optim_progress.optimizer.step.increment_completed()
def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
"""Calls the ``on_before_zero_grad`` hook.
Args:
optimizer: the current optimizer
"""
self.optim_progress.optimizer.zero_grad.increment_ready()
self.trainer.call_hook("on_before_zero_grad", optimizer)
self.optim_progress.optimizer.zero_grad.increment_started()
def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
"""Zeroes out all gradients of parameters optimized by the current optimizer.
Args:
batch_idx: the index of the current batch
optimizer: the current optimizer
opt_idx: the index of the current optimizer
"""
self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
self.optim_progress.optimizer.zero_grad.increment_completed()
def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]:
"""Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.
Args:
optimizer: the current optimizer
"""
# track gradient norms
grad_norm_dict = {}
can_log = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
should_track = float(self.trainer.track_grad_norm) > 0
if should_track and can_log:
grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm)
# clip gradients
self.trainer.accelerator.clip_gradients(
optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
)
return grad_norm_dict
return AttributeDict(closure_loss=None, loss=None, result_collection=result_collection)
def _tbptt_split_batch(self, batch: Any) -> List[Any]:
"""Splits a single batch into a list of sequence steps for tbptt.
@ -416,25 +246,7 @@ class TrainingBatchLoop(Loop):
splits = model_ref.tbptt_split_batch(batch, tbptt_steps)
return splits
def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
"""Toggles the optimizer to ensure the correct one is used and prevend dangling grads.
Args:
opt_idx: the index of the optimizer to use
optimizer: the optimizer to use
"""
# make sure only the gradients of the current optimizer's parameters are calculated
# in the training step to prevent dangling gradients in multiple-optimizer setup.
if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
model = self.trainer.lightning_module
model.toggle_optimizer(optimizer, opt_idx)
def _run_optimization_end(self, opt_idx: int) -> None:
if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
model = self.trainer.lightning_module
model.untoggle_optimizer(opt_idx)
# TODO: remove this method and update tests
def backward(
self,
loss: Tensor,
@ -451,13 +263,6 @@ class TrainingBatchLoop(Loop):
opt_idx: Index of the current optimizer being used. ``None`` if using manual optimization.
"""
self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)
if not self.trainer.fit_loop.should_accumulate():
# track gradients
grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
if grad_norm_dict:
self.trainer.lightning_module._current_fx_name = "on_after_backward"
self.trainer.lightning_module.log_grad_norm(grad_norm_dict)
return loss
def _update_running_loss(self, current_loss: Tensor) -> None:

View File

@ -95,7 +95,7 @@ class TrainingEpochLoop(loops.Loop):
if not self.restarting:
self.batch_progress.current.reset()
self.scheduler_progress.current.reset()
self.batch_loop.optim_progress.reset_on_epoch()
self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()
def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
# hook

View File

@ -108,12 +108,12 @@ class FitLoop(Loop):
@property
def _skip_backward(self) -> bool:
"""Determines whether the loop will skip backward during automatic optimization."""
return self.epoch_loop.batch_loop._skip_backward
return self.epoch_loop.batch_loop.optimizer_loop._skip_backward
@_skip_backward.setter
def _skip_backward(self, value: bool) -> None:
"""Determines whether the loop will skip backward during automatic optimization."""
self.epoch_loop.batch_loop._skip_backward = value
self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value
@property
def _results(self) -> ResultCollection:
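Since the optimization state now lives on the nested loop, call sites reach it through `batch_loop.optimizer_loop`, as the two properties above show for `_skip_backward`. A hedged inspection sketch, assuming a pytorch_lightning build with this commit and the attribute paths shown in this diff:

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1)
batch_loop = trainer.fit_loop.epoch_loop.batch_loop

# old path (before this commit): batch_loop.optim_progress
optim_progress = batch_loop.optimizer_loop.optim_progress
print(optim_progress.optimizer.step.total.completed)  # 0 before any training

# `_skip_backward` is likewise proxied by FitLoop onto the optimizer loop.
trainer.fit_loop._skip_backward = True
assert batch_loop.optimizer_loop._skip_backward is True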

View File

@ -0,0 +1,15 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop # noqa: F401

View File

@ -0,0 +1,376 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
from torch.optim import Optimizer
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.closure import Closure, ClosureResult
from pytorch_lightning.loops.utilities import (
_block_parallel_sync_behavior,
_build_training_step_kwargs,
_check_training_step_output,
_process_training_step_output,
check_finite_loss,
)
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
_OUTPUTS_TYPE = List[List[Optional[ResultCollection]]]
class OptimizerLoop(Loop):
"""Runs over a sequence of optimizers. This loop implements what is known in Lightning as Automatic Optimization."""
def __init__(self):
super().__init__()
# TODO: use default dict here to simplify logic in loop
self.outputs: _OUTPUTS_TYPE = []
self.optim_progress: OptimizationProgress = OptimizationProgress()
self._skip_backward: bool = False
self._batch_idx: Optional[int] = None
self._optimizers: Optional[List[Optimizer]] = None
self._hiddens: Optional[Any] = None
@property
def done(self) -> bool:
"""Returns ``True`` when the last optimizer in the sequence has run."""
return self.optim_progress.optimizer_idx >= len(self._optimizers)
def connect(self, **kwargs: "Loop") -> None:
raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
def reset(self) -> None:
if not self.restarting:
self.optim_progress.optimizer_idx = 0
self.outputs = [[] for _ in range(len(self.trainer.optimizers))]
def on_run_start(self, batch: Any, hiddens: Any, optimizers: List[Optimizer], batch_idx: int) -> None:
self._batch_idx = batch_idx
self._optimizers = optimizers
def advance(self, batch: Any, hiddens: Any, *args, **kwargs) -> None:
self._hiddens = hiddens
result = self._run_optimization(
self._batch_idx,
batch,
self.optim_progress.optimizer_idx,
self._optimizers[self.optim_progress.optimizer_idx],
)
if result:
self.outputs[self.optim_progress.optimizer_idx].append(deepcopy(result.result_collection))
self.optim_progress.optimizer_idx += 1
def on_run_end(self) -> Tuple[_OUTPUTS_TYPE, Optional[Any]]:
outputs = self.outputs
hiddens = self._hiddens
# free memory
self.outputs = []
self._hiddens = None
return outputs, hiddens
def backward(
self,
loss: Tensor,
optimizer: Optional[torch.optim.Optimizer],
opt_idx: Optional[int] = None,
*args: Any,
**kwargs: Any,
) -> Tensor:
"""Performs the backward step.
Args:
loss: The loss value to back-propagate on
optimizer: Current optimizer being used. ``None`` if using manual optimization.
opt_idx: Index of the current optimizer being used. ``None`` if using manual optimization.
"""
self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)
if not self.trainer.fit_loop.should_accumulate():
# track gradients
grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
if grad_norm_dict:
self.trainer.lightning_module._current_fx_name = "on_after_backward"
self.trainer.lightning_module.log_grad_norm(grad_norm_dict)
return loss
def _run_optimization(
self,
batch_idx: int,
split_batch: Any,
opt_idx: Optional[int] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
) -> Optional[ClosureResult]:
"""Runs closure (train step + backward) together with optimization if necessary.
Args:
batch_idx: the index of the current batch
split_batch: the current tbptt split of the whole batch
opt_idx: the index of the current optimizer or `None` in case of manual optimization
optimizer: the current optimizer or `None` in case of manual optimization
"""
# toggle model params
self._run_optimization_start(opt_idx, optimizer)
closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens)
if self.trainer.fit_loop.should_accumulate():
# For gradient accumulation
# -------------------
# calculate loss (train step + train step end)
# -------------------
# automatic_optimization=True: perform ddp sync only when performing optimizer_step
with _block_parallel_sync_behavior(self.trainer, block=True):
closure()
# ------------------------------
# BACKWARD PASS
# ------------------------------
# gradient update with accumulated gradients
else:
self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
result = closure.get_result()
if result:
# if no result, user decided to skip optimization
# otherwise update running loss + reset accumulated loss
# TODO: find proper way to handle updating running loss
self.trainer.fit_loop.epoch_loop.batch_loop._update_running_loss(result.loss)
# untoggle model params
self._run_optimization_end(opt_idx)
return result
def _make_closure(
self,
split_batch: Any,
batch_idx: int,
opt_idx: int,
optimizer: Optimizer,
hiddens: Any,
) -> Closure:
"""
Build a closure object that captures the given arguments and runs the `training_step` function and optionally
other functions such as `backward` and `zero_grad`.
"""
step_fn = self._make_step_fn(split_batch, batch_idx, opt_idx, hiddens)
backward_fn = self._make_backward_fn(optimizer, opt_idx)
zero_grad_fn = self._make_zero_grad_fn(batch_idx, opt_idx, optimizer)
return Closure(
step_fn=step_fn,
backward_fn=backward_fn,
zero_grad_fn=zero_grad_fn,
profiler=self.trainer.profiler,
)
def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Any) -> Callable[[], dict]:
"""Build the step function that runs the `training_step` and processes its output."""
return partial(self._training_step, split_batch, batch_idx, opt_idx, hiddens)
def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]:
"""
Build a `zero_grad` function that zeroes the gradients before back-propagation.
Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on.
"""
def zero_grad_fn():
self._on_before_zero_grad(optimizer)
self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)
is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0
if not self._skip_backward and is_first_batch_to_accumulate:
return zero_grad_fn
def _make_backward_fn(
self,
optimizer: Optimizer,
opt_idx: int,
) -> Optional[Callable[[Tensor], Tensor]]:
"""
Build a `backward` function that handles back-propagation through the output produced by the `training_step`
function. Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on.
"""
def backward_fn(loss: Tensor):
self.backward(loss, optimizer, opt_idx)
# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
check_finite_loss(self.trainer.lightning_module, loss)
return loss
if not self._skip_backward:
return backward_fn
def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
"""Toggles the optimizer to ensure the correct one is used and prevend dangling grads.
Args:
opt_idx: the index of the optimizer to use
optimizer: the optimizer to use
"""
# make sure only the gradients of the current optimizer's parameters are calculated
# in the training step to prevent dangling gradients in multiple-optimizer setup.
if len(self.trainer.optimizers) > 1:
model = self.trainer.lightning_module
model.toggle_optimizer(optimizer, opt_idx)
def _run_optimization_end(self, opt_idx: int) -> None:
if len(self.trainer.optimizers) > 1:
model = self.trainer.lightning_module
model.untoggle_optimizer(opt_idx)
def _optimizer_step(
self, optimizer: torch.optim.Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable
) -> None:
"""Performs the optimizer step and some sanity checking.
Args:
optimizer: the optimizer to perform the step with
opt_idx: the index of the current :param:`optimizer`
batch_idx: the index of the current batch
train_step_and_backward_closure: the closure function performing the train step and computing the
gradients. By default called by the optimizer (if possible)
"""
model_ref = self.trainer.lightning_module
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
using_native_amp = self.trainer.amp_backend == AMPType.NATIVE
# native amp + lbfgs is a no go right now
if using_native_amp and is_lbfgs:
raise MisconfigurationException(
"native PyTorch amp and lbfgs are not compatible."
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)
# wraps into LightningOptimizer only for running step
optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)
self.optim_progress.optimizer.step.increment_ready()
# model hook
model_ref.optimizer_step(
self.trainer.current_epoch,
batch_idx,
optimizer,
opt_idx,
train_step_and_backward_closure,
on_tpu=(self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE),
using_native_amp=using_native_amp,
using_lbfgs=is_lbfgs,
)
self.optim_progress.optimizer.step.increment_completed()
def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
"""Calls the ``on_before_zero_grad`` hook.
Args:
optimizer: the current optimizer
"""
self.optim_progress.optimizer.zero_grad.increment_ready()
self.trainer.call_hook("on_before_zero_grad", optimizer)
self.optim_progress.optimizer.zero_grad.increment_started()
def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
"""Zeroes out all gradients of parameters optimized by the current optimizer.
Args:
batch_idx: the index of the current batch
optimizer: the current optimizer
opt_idx: the index of the current optimizer
"""
self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
self.optim_progress.optimizer.zero_grad.increment_completed()
def _training_step(
self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor
) -> Optional[AttributeDict]:
"""Performs the actual train step with the tied hooks.
Args:
split_batch: the current tbptt split of the current batch
batch_idx: the index of the current batch
opt_idx: the index of the current optimizer
hiddens: the model's hidden state of the previous iteration
Returns:
an AttributeDict containing the loss value and the training step output.
"""
# give the PL module a result for logging
model_ref = self.trainer.lightning_module
with self.trainer.profiler.profile("model_forward"):
step_kwargs = _build_training_step_kwargs(
self.trainer.lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, hiddens
)
# manually capture logged metrics
model_ref._current_fx_name = "training_step"
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
self.trainer.accelerator.post_training_step()
del step_kwargs
training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
_check_training_step_output(self.trainer.lightning_module, training_step_output)
result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
if result_collection is None:
return
# accumulate loss. if accumulate_grad_batches==1, no effect
closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches
# the loss will get scaled for amp. avoid any modifications to it
loss = closure_loss.detach().clone()
return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection)
def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, float]:
"""Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.
Args:
optimizer: the current optimizer
"""
# track gradient norms
grad_norm_dict = {}
can_log = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
should_track = float(self.trainer.track_grad_norm) > 0
if should_track and can_log:
grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm)
# clip gradients
self.trainer.accelerator.clip_gradients(
optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
)
return grad_norm_dict
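`OptimizerLoop` plugs into Lightning's generic loop state machine: `reset`, then `advance` repeatedly until `done`, with `on_run_start`/`on_run_end` around it, and `done` defined as having stepped the last optimizer in the sequence. The following self-contained toy (`MiniOptimizerLoop`, not the real `pytorch_lightning.loops.base.Loop`) sketches that driver under those assumptions:

from typing import Any, List

class MiniOptimizerLoop:
    def __init__(self) -> None:
        self.optimizer_idx = 0             # mirrors optim_progress.optimizer_idx
        self.outputs: List[List[str]] = []
        self._optimizers: List[Any] = []

    @property
    def done(self) -> bool:
        # finished once the last optimizer in the sequence has run
        return self.optimizer_idx >= len(self._optimizers)

    def reset(self) -> None:
        self.optimizer_idx = 0
        self.outputs = []

    def advance(self, batch: Any) -> None:
        # one closure + optimizer step per advance() call
        self.outputs.append([f"opt{self.optimizer_idx} stepped on {batch!r}"])
        self.optimizer_idx += 1

    def run(self, batch: Any, optimizers: List[Any]) -> List[List[str]]:
        # the base-class driver: reset, then advance until `done`
        self._optimizers = optimizers
        self.reset()
        while not self.done:
            self.advance(batch)
        return self.outputs

print(MiniOptimizerLoop().run("batch0", optimizers=["sgd", "adam"]))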

View File

@ -56,8 +56,9 @@ def test_loops_state_dict_structure():
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.state_dict": {},
"epoch_loop.batch_loop.optim_progress": {
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer": {
"step": {
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},

View File

@ -382,7 +382,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
assert os.path.exists(ckpt_path)
checkpoint = torch.load(ckpt_path)
optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optim_progress
optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress
sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress
# `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
@ -461,7 +461,8 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
"current": {"ready": be_sch_steps, "started": None, "processed": None, "completed": be_sch_steps},
},
"epoch_loop.batch_loop.state_dict": ANY,
"epoch_loop.batch_loop.optim_progress": {
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer_idx": stop_optimizer,
"optimizer": {
"step": {

View File

@ -78,7 +78,7 @@ def test__eval_step__flow(tmpdir):
assert train_step_out.minimize.item() == 171
# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
opt_closure_result = opt_closure()
@ -145,7 +145,7 @@ def test__eval_step__eval_step_end__flow(tmpdir):
assert train_step_out.minimize.item() == 171
# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
opt_closure_result = opt_closure()

View File

@ -157,7 +157,7 @@ def test__training_step__epoch_end__flow_scalar(tmpdir):
assert train_step_out.minimize.item() == 171
# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
opt_closure_result = opt_closure()
@ -231,7 +231,7 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):
assert train_step_out.minimize.item() == 171
# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
opt_closure_result = opt_closure()