extract optimizer loop (#9191)
parent c86e6cf1c4
commit 75350938ca
@@ -65,7 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Loop customization:
    * Added `Closure` and `AbstractClosure` classes ([#8642](https://github.com/PyTorchLightning/pytorch-lightning/pull/8642))

    * Refactored `TrainingBatchLoop` and extracted `OptimizerLoop`, splitting off automatic optimization into its own loop ([#9191](https://github.com/PyTorchLightning/pytorch-lightning/pull/9191))

- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))
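For orientation, here is a rough, editor-added sketch (not part of this commit) of what the refactor means from user code: automatic optimization now runs in a dedicated `OptimizerLoop` nested under the batch loop, and `TrainingBatchLoop.connect` accepts a replacement optimizer loop, as the hunks below show. `MyOptimizerLoop` is a hypothetical subclass used only for illustration.

```python
# Hedged sketch of the post-refactor loop hierarchy; `MyOptimizerLoop` is hypothetical.
from pytorch_lightning import Trainer
from pytorch_lightning.loops import OptimizerLoop


class MyOptimizerLoop(OptimizerLoop):
    """Hypothetical subclass that could customize parts of automatic optimization."""


trainer = Trainer()  # default arguments assumed

# automatic optimization now lives in a child loop of the batch loop
assert isinstance(trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop, OptimizerLoop)

# the batch loop accepts a replacement optimizer loop through its new `connect` signature
# (illustrative only; a real replacement would mirror the built-in loop's behavior)
trainer.fit_loop.epoch_loop.batch_loop.connect(optimizer_loop=MyOptimizerLoop())
```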
@@ -17,3 +17,4 @@ from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop # noqa: F401
@@ -13,7 +13,7 @@
# limitations under the License.
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple

import numpy as np
import torch
@@ -21,21 +21,16 @@ from deprecate import void
from torch import Tensor
from torch.optim import Optimizer

from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.closure import Closure, ClosureResult
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop
from pytorch_lightning.loops.utilities import (
    _block_parallel_sync_behavior,
    _build_training_step_kwargs,
    _check_training_step_output,
    _process_training_step_output,
    check_finite_loss,
)
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

@@ -50,13 +45,12 @@ class TrainingBatchLoop(Loop):
        self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20)
        # the current split index when the batch gets split into chunks in truncated backprop through time
        self.split_idx: Optional[int] = None
        self.optim_progress = OptimizationProgress()
        self.optimizer_loop = OptimizerLoop()

        self._warning_cache: WarningCache = WarningCache()
        self._hiddens: Optional[Tensor] = None
        self._optimizer_freq_cumsum: Optional[int] = None
        self._remaining_splits: Optional[List[Any]] = None
        self._skip_backward: bool = False

    @property
    def done(self) -> bool:
@@ -70,8 +64,8 @@ class TrainingBatchLoop(Loop):
            self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
        return self._optimizer_freq_cumsum

    def connect(self, **kwargs: "Loop") -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
    def connect(self, optimizer_loop: "Loop") -> None:
        self.optimizer_loop = optimizer_loop

    def run(self, batch: Any, batch_idx: int) -> AttributeDict:
        """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks
@@ -132,17 +126,12 @@ class TrainingBatchLoop(Loop):
        self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch)

        if self.trainer.lightning_module.automatic_optimization:
            for opt_idx, optimizer in self.get_active_optimizers(batch_idx):
                # handle optimization restart
                if self.restarting:
                    if opt_idx < self.optim_progress.optimizer_idx:
                        continue

                self.optim_progress.optimizer_idx = opt_idx

                result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
                if result:
                    self.batch_outputs[opt_idx].append(deepcopy(result.result_collection))
            # in automatic optimization, hand over execution to the OptimizerLoop
            optimizers = [optimizer for _, optimizer in self.get_active_optimizers(batch_idx)]
            batch_outputs, self._hiddens = self.optimizer_loop.run(split_batch, self._hiddens, optimizers, batch_idx)
            # combine outputs from each optimizer
            for k in range(len(batch_outputs)):
                self.batch_outputs[k].extend(batch_outputs[k])
        else:
            # in manual optimization, there is no looping over optimizers
            result = self._run_optimization(batch_idx, split_batch)
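In the hunk above, the batch loop hands the whole automatic-optimization pass to `self.optimizer_loop.run(...)`. For context, the `Loop` base class drives its subclasses roughly as sketched below (a simplified, editor-added outline of the loop API in this release line, not part of this diff): `reset`/`on_run_start` prime the state, `advance` repeats until `done`, and `on_run_end` returns the collected output.

```python
# Simplified sketch of the generic Loop.run contract that OptimizerLoop plugs into.
# The real implementation lives in pytorch_lightning.loops.base.Loop; details may differ.
def run(self, *args, **kwargs):
    if self.skip:
        return self.on_skip()

    self.reset()
    self.on_run_start(*args, **kwargs)

    while not self.done:
        try:
            self.on_advance_start(*args, **kwargs)
            self.advance(*args, **kwargs)  # one optimizer per call in OptimizerLoop
            self.on_advance_end()
            self.restarting = False
        except StopIteration:
            break

    return self.on_run_end()  # OptimizerLoop returns (outputs, hiddens)
```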
@@ -161,42 +150,16 @@ class TrainingBatchLoop(Loop):
        self,
        batch_idx: int,
        split_batch: Any,
        opt_idx: Optional[int] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
    ) -> Optional[ClosureResult]:
        """Runs closure (train step + backward) together with optimization if necessary.

        Args:
            batch_idx: the index of the current batch
            split_batch: the current tbptt split of the whole batch
            opt_idx: the index of the current optimizer or `None` in case of manual optimization
            optimizer: the current optimizer or `None` in case of manual optimization
        """
        # toggle model params
        self._run_optimization_start(opt_idx, optimizer)

        closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens)

        if self.trainer.fit_loop.should_accumulate():
            # For gradient accumulation

            # -------------------
            # calculate loss (train step + train step end)
            # -------------------
            # automatic_optimization: perform ddp sync only when performing optimizer_step
            with _block_parallel_sync_behavior(self._trainer):
                closure()

            # ------------------------------
            # BACKWARD PASS
            # ------------------------------
            # gradient update with accumulated gradients
        else:
            if self.trainer.lightning_module.automatic_optimization:
                self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
            else:
                closure()

        # TODO: replace call through closure by direct call (manual optimization)
        closure = self._make_closure(split_batch, batch_idx, self._hiddens)
        closure()
        result = closure.get_result()

        if result:
@@ -204,25 +167,21 @@ class TrainingBatchLoop(Loop):
            # otherwise update running loss + reset accumulated loss
            self._update_running_loss(result.loss)

        # untoggle model params
        self._run_optimization_end(opt_idx)
        return result

    def _make_closure(
        self,
        split_batch: Any,
        batch_idx: int,
        opt_idx: int,
        optimizer: Optimizer,
        hiddens: Any,
    ) -> Closure:
        """
        Build a closure object that captures the given arguments and runs the `training_step` function and optionally
        other functions such as `backward` and `zero_grad`.
        """
        step_fn = self._make_step_fn(split_batch, batch_idx, opt_idx, hiddens)
        backward_fn = self._make_backward_fn(optimizer, opt_idx)
        zero_grad_fn = self._make_zero_grad_fn(batch_idx, opt_idx, optimizer)
        step_fn = self._make_step_fn(split_batch, batch_idx, hiddens)
        backward_fn = None
        zero_grad_fn = None

        return Closure(
            step_fn=step_fn,
@@ -231,66 +190,27 @@ class TrainingBatchLoop(Loop):
            profiler=self.trainer.profiler,
        )

    def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Any) -> Callable[[], dict]:
    def _make_step_fn(self, split_batch: Any, batch_idx: int, hiddens: Any) -> Callable[[], dict]:
        """Build the step function that runs the `training_step` and processes its output."""
        return partial(self._training_step, split_batch, batch_idx, opt_idx, hiddens)
        return partial(self._training_step, split_batch, batch_idx, hiddens)

    def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]:
        """
        Build a `zero_grad` function that zeroes the gradients before back-propagation.
        Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on.
        """

        def zero_grad_fn():
            self._on_before_zero_grad(optimizer)
            self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)

        is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0
        if (
            not self._skip_backward
            and self.trainer.lightning_module.automatic_optimization
            and is_first_batch_to_accumulate
        ):
            return zero_grad_fn

    def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Callable[[Tensor], Tensor]]:
        """
        Build a `backward` function that handles back-propagation through the output produced by the `training_step`
        function. Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on.
        """

        def backward_fn(loss: Tensor):
            self.backward(loss, optimizer, opt_idx)

            # check if loss or model weights are nan
            if self.trainer.terminate_on_nan:
                check_finite_loss(self.trainer.lightning_module, loss)

            return loss

        if not self._skip_backward and self.trainer.lightning_module.automatic_optimization:
            return backward_fn

    def _training_step(
        self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor
    ) -> Optional[AttributeDict]:
        """Performs the actual train step with the tied hooks.
    def _training_step(self, split_batch: Any, batch_idx: int, hiddens: Tensor) -> Optional[AttributeDict]:
        """Performs the training step for manual optimization.

        Args:
            split_batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            opt_idx: the index of the current optimizer
            hiddens: the model's hidden state of the previous iteration

        Returns:
            an AttributeDict containing the loss value and the training step output.
            an AttributeDict containing the training step output.
        """
        # give the PL module a result for logging
        model_ref = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):
            step_kwargs = _build_training_step_kwargs(
                model_ref, self.trainer.optimizers, split_batch, batch_idx, opt_idx, hiddens
                model_ref, self.trainer.optimizers, split_batch, batch_idx, opt_idx=None, hiddens=hiddens
            )

            # manually capture logged metrics
@@ -309,97 +229,7 @@ class TrainingBatchLoop(Loop):
            if result_collection is None:
                return

        closure_loss = None
        loss = None
        if self.trainer.lightning_module.automatic_optimization:
            # accumulate loss. if accumulate_grad_batches==1, no effect
            closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches
            # the loss will get scaled for amp. avoid any modifications to it
            loss = closure_loss.detach().clone()
        return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection)

    def _optimizer_step(
        self, optimizer: torch.optim.Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable
    ) -> None:
        """Performs the optimizer step and some sanity checking.

        Args:
            optimizer: the optimizer to perform the step with
            opt_idx: the index of the current :param:`optimizer`
            batch_idx: the index of the current batch
            train_step_and_backward_closure: the closure function performing the train step and computing the
                gradients. By default called by the optimizer (if possible)
        """
        model_ref = self.trainer.lightning_module

        is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
        using_native_amp = self.trainer.amp_backend == AMPType.NATIVE

        # native amp + lbfgs is a no go right now
        if using_native_amp and is_lbfgs:
            raise MisconfigurationException(
                "native PyTorch amp and lbfgs are not compatible."
                " To request, please file a Github issue in PyTorch and tag @mcarilli"
            )

        # wraps into LightningOptimizer only for running step
        optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)

        self.optim_progress.optimizer.step.increment_ready()

        # model hook
        model_ref.optimizer_step(
            self.trainer.current_epoch,
            batch_idx,
            optimizer,
            opt_idx,
            train_step_and_backward_closure,
            on_tpu=(self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE),
            using_native_amp=using_native_amp,
            using_lbfgs=is_lbfgs,
        )

        self.optim_progress.optimizer.step.increment_completed()

    def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Calls the ``on_before_zero_grad`` hook.

        Args:
            optimizer: the current optimizer
        """
        self.optim_progress.optimizer.zero_grad.increment_ready()
        self.trainer.call_hook("on_before_zero_grad", optimizer)
        self.optim_progress.optimizer.zero_grad.increment_started()

    def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
        """Zeroes out all gradients of parameters optimized by the current optimizer.

        Args:
            batch_idx: the index of the current batch
            optimizer: the current optimizer
            opt_idx: the index of the current optimizer
        """
        self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
        self.optim_progress.optimizer.zero_grad.increment_completed()

    def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]:
        """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.

        Args:
            optimizer: the current optimizer
        """
        # track gradient norms
        grad_norm_dict = {}
        can_log = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
        should_track = float(self.trainer.track_grad_norm) > 0
        if should_track and can_log:
            grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm)

        # clip gradients
        self.trainer.accelerator.clip_gradients(
            optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
        )
        return grad_norm_dict
        return AttributeDict(closure_loss=None, loss=None, result_collection=result_collection)

    def _tbptt_split_batch(self, batch: Any) -> List[Any]:
        """Splits a single batch into a list of sequence steps for tbptt.
@@ -416,25 +246,7 @@ class TrainingBatchLoop(Loop):
        splits = model_ref.tbptt_split_batch(batch, tbptt_steps)
        return splits

    def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
        """Toggles the optimizer to ensure the correct one is used and prevend dangling grads.

        Args:
            opt_idx: the index of the optimizer to use
            optimizer: the optimizer to use

        """
        # make sure only the gradients of the current optimizer's parameters are calculated
        # in the training step to prevent dangling gradients in multiple-optimizer setup.
        if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
            model = self.trainer.lightning_module
            model.toggle_optimizer(optimizer, opt_idx)

    def _run_optimization_end(self, opt_idx: int) -> None:
        if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
            model = self.trainer.lightning_module
            model.untoggle_optimizer(opt_idx)

    # TODO: remove this method and update tests
    def backward(
        self,
        loss: Tensor,
@@ -451,13 +263,6 @@ class TrainingBatchLoop(Loop):
            opt_idx: Index of the current optimizer being used. ``None`` if using manual optimization.
        """
        self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)

        if not self.trainer.fit_loop.should_accumulate():
            # track gradients
            grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
            if grad_norm_dict:
                self.trainer.lightning_module._current_fx_name = "on_after_backward"
                self.trainer.lightning_module.log_grad_norm(grad_norm_dict)
        return loss

    def _update_running_loss(self, current_loss: Tensor) -> None:
@@ -95,7 +95,7 @@ class TrainingEpochLoop(loops.Loop):
        if not self.restarting:
            self.batch_progress.current.reset()
            self.scheduler_progress.current.reset()
            self.batch_loop.optim_progress.reset_on_epoch()
            self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()

    def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
        # hook
@@ -108,12 +108,12 @@ class FitLoop(Loop):
    @property
    def _skip_backward(self) -> bool:
        """Determines whether the loop will skip backward during automatic optimization."""
        return self.epoch_loop.batch_loop._skip_backward
        return self.epoch_loop.batch_loop.optimizer_loop._skip_backward

    @_skip_backward.setter
    def _skip_backward(self, value: bool) -> None:
        """Determines whether the loop will skip backward during automatic optimization."""
        self.epoch_loop.batch_loop._skip_backward = value
        self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value

    @property
    def _results(self) -> ResultCollection:
@@ -0,0 +1,15 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop # noqa: F401
@@ -0,0 +1,376 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor
from torch.optim import Optimizer

from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.closure import Closure, ClosureResult
from pytorch_lightning.loops.utilities import (
    _block_parallel_sync_behavior,
    _build_training_step_kwargs,
    _check_training_step_output,
    _process_training_step_output,
    check_finite_loss,
)
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE

_OUTPUTS_TYPE = List[List[Optional[ResultCollection]]]


class OptimizerLoop(Loop):
    """Runs over a sequence of optimizers. This loop implements what is known in Lightning as Automatic Optimization."""

    def __init__(self):
        super().__init__()
        # TODO: use default dict here to simplify logic in loop
        self.outputs: _OUTPUTS_TYPE = []
        self.optim_progress: OptimizationProgress = OptimizationProgress()

        self._skip_backward: bool = False
        self._batch_idx: Optional[int] = None
        self._optimizers: Optional[List[Optimizer]] = None
        self._hiddens: Optional[Any] = None

    @property
    def done(self) -> bool:
        """Returns ``True`` when the last optimizer in the sequence has run."""
        return self.optim_progress.optimizer_idx >= len(self._optimizers)

    def connect(self, **kwargs: "Loop") -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

    def reset(self) -> None:
        if not self.restarting:
            self.optim_progress.optimizer_idx = 0
        self.outputs = [[] for _ in range(len(self.trainer.optimizers))]

    def on_run_start(self, batch: Any, hiddens: Any, optimizers: List[Optimizer], batch_idx: int) -> None:
        self._batch_idx = batch_idx
        self._optimizers = optimizers

    def advance(self, batch: Any, hiddens: Any, *args, **kwargs) -> None:
        self._hiddens = hiddens
        result = self._run_optimization(
            self._batch_idx,
            batch,
            self.optim_progress.optimizer_idx,
            self._optimizers[self.optim_progress.optimizer_idx],
        )
        if result:
            self.outputs[self.optim_progress.optimizer_idx].append(deepcopy(result.result_collection))

        self.optim_progress.optimizer_idx += 1

    def on_run_end(self) -> Tuple[_OUTPUTS_TYPE, Optional[Any]]:
        outputs = self.outputs
        hiddens = self._hiddens
        # free memory
        self.outputs = []
        self._hiddens = None
        return outputs, hiddens

    def backward(
        self,
        loss: Tensor,
        optimizer: Optional[torch.optim.Optimizer],
        opt_idx: Optional[int] = None,
        *args: Any,
        **kwargs: Any,
    ) -> Tensor:
        """Performs the backward step.

        Args:
            loss: The loss value to back-propagate on
            optimizer: Current optimizer being used. ``None`` if using manual optimization.
            opt_idx: Index of the current optimizer being used. ``None`` if using manual optimization.
        """
        self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)

        if not self.trainer.fit_loop.should_accumulate():
            # track gradients
            grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
            if grad_norm_dict:
                self.trainer.lightning_module._current_fx_name = "on_after_backward"
                self.trainer.lightning_module.log_grad_norm(grad_norm_dict)
        return loss

    def _run_optimization(
        self,
        batch_idx: int,
        split_batch: Any,
        opt_idx: Optional[int] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
    ) -> Optional[ClosureResult]:
        """Runs closure (train step + backward) together with optimization if necessary.

        Args:
            batch_idx: the index of the current batch
            split_batch: the current tbptt split of the whole batch
            opt_idx: the index of the current optimizer or `None` in case of manual optimization
            optimizer: the current optimizer or `None` in case of manual optimization
        """
        # toggle model params
        self._run_optimization_start(opt_idx, optimizer)

        closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens)

        if self.trainer.fit_loop.should_accumulate():
            # For gradient accumulation

            # -------------------
            # calculate loss (train step + train step end)
            # -------------------
            # automatic_optimization=True: perform ddp sync only when performing optimizer_step
            with _block_parallel_sync_behavior(self.trainer, block=True):
                closure()

            # ------------------------------
            # BACKWARD PASS
            # ------------------------------
            # gradient update with accumulated gradients
        else:
            self._optimizer_step(optimizer, opt_idx, batch_idx, closure)

        result = closure.get_result()

        if result:
            # if no result, user decided to skip optimization
            # otherwise update running loss + reset accumulated loss
            # TODO: find proper way to handle updating running loss
            self.trainer.fit_loop.epoch_loop.batch_loop._update_running_loss(result.loss)

        # untoggle model params
        self._run_optimization_end(opt_idx)
        return result

    def _make_closure(
        self,
        split_batch: Any,
        batch_idx: int,
        opt_idx: int,
        optimizer: Optimizer,
        hiddens: Any,
    ) -> Closure:
        """
        Build a closure object that captures the given arguments and runs the `training_step` function and optionally
        other functions such as `backward` and `zero_grad`.
        """
        step_fn = self._make_step_fn(split_batch, batch_idx, opt_idx, hiddens)
        backward_fn = self._make_backward_fn(optimizer, opt_idx)
        zero_grad_fn = self._make_zero_grad_fn(batch_idx, opt_idx, optimizer)

        return Closure(
            step_fn=step_fn,
            backward_fn=backward_fn,
            zero_grad_fn=zero_grad_fn,
            profiler=self.trainer.profiler,
        )

    def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Any) -> Callable[[], dict]:
        """Build the step function that runs the `training_step` and processes its output."""
        return partial(self._training_step, split_batch, batch_idx, opt_idx, hiddens)

    def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]:
        """
        Build a `zero_grad` function that zeroes the gradients before back-propagation.
        Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on.
        """

        def zero_grad_fn():
            self._on_before_zero_grad(optimizer)
            self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)

        is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0
        if not self._skip_backward and is_first_batch_to_accumulate:
            return zero_grad_fn

    def _make_backward_fn(
        self,
        optimizer: Optimizer,
        opt_idx: int,
    ) -> Optional[Callable[[Tensor], Tensor]]:
        """
        Build a `backward` function that handles back-propagation through the output produced by the `training_step`
        function. Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on.
        """

        def backward_fn(loss: Tensor):
            self.backward(loss, optimizer, opt_idx)

            # check if loss or model weights are nan
            if self.trainer.terminate_on_nan:
                check_finite_loss(self.trainer.lightning_module, loss)

            return loss

        if not self._skip_backward:
            return backward_fn

    def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
        """Toggles the optimizer to ensure the correct one is used and prevend dangling grads.

        Args:
            opt_idx: the index of the optimizer to use
            optimizer: the optimizer to use

        """
        # make sure only the gradients of the current optimizer's parameters are calculated
        # in the training step to prevent dangling gradients in multiple-optimizer setup.
        if len(self.trainer.optimizers) > 1:
            model = self.trainer.lightning_module
            model.toggle_optimizer(optimizer, opt_idx)

    def _run_optimization_end(self, opt_idx: int) -> None:
        if len(self.trainer.optimizers) > 1:
            model = self.trainer.lightning_module
            model.untoggle_optimizer(opt_idx)

    def _optimizer_step(
        self, optimizer: torch.optim.Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable
    ) -> None:
        """Performs the optimizer step and some sanity checking.

        Args:
            optimizer: the optimizer to perform the step with
            opt_idx: the index of the current :param:`optimizer`
            batch_idx: the index of the current batch
            train_step_and_backward_closure: the closure function performing the train step and computing the
                gradients. By default called by the optimizer (if possible)
        """
        model_ref = self.trainer.lightning_module

        is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
        using_native_amp = self.trainer.amp_backend == AMPType.NATIVE

        # native amp + lbfgs is a no go right now
        if using_native_amp and is_lbfgs:
            raise MisconfigurationException(
                "native PyTorch amp and lbfgs are not compatible."
                " To request, please file a Github issue in PyTorch and tag @mcarilli"
            )

        # wraps into LightningOptimizer only for running step
        optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)

        self.optim_progress.optimizer.step.increment_ready()

        # model hook
        model_ref.optimizer_step(
            self.trainer.current_epoch,
            batch_idx,
            optimizer,
            opt_idx,
            train_step_and_backward_closure,
            on_tpu=(self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE),
            using_native_amp=using_native_amp,
            using_lbfgs=is_lbfgs,
        )

        self.optim_progress.optimizer.step.increment_completed()

    def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Calls the ``on_before_zero_grad`` hook.

        Args:
            optimizer: the current optimizer
        """
        self.optim_progress.optimizer.zero_grad.increment_ready()
        self.trainer.call_hook("on_before_zero_grad", optimizer)
        self.optim_progress.optimizer.zero_grad.increment_started()

    def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
        """Zeroes out all gradients of parameters optimized by the current optimizer.

        Args:
            batch_idx: the index of the current batch
            optimizer: the current optimizer
            opt_idx: the index of the current optimizer
        """
        self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
        self.optim_progress.optimizer.zero_grad.increment_completed()

    def _training_step(
        self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor
    ) -> Optional[AttributeDict]:
        """Performs the actual train step with the tied hooks.

        Args:
            split_batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            opt_idx: the index of the current optimizer
            hiddens: the model's hidden state of the previous iteration

        Returns:
            an AttributeDict containing the loss value and the training step output.
        """
        # give the PL module a result for logging
        model_ref = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):

            step_kwargs = _build_training_step_kwargs(
                self.trainer.lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, hiddens
            )

            # manually capture logged metrics
            model_ref._current_fx_name = "training_step"
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator.training_step(step_kwargs)
                self.trainer.accelerator.post_training_step()

            del step_kwargs

            training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

            _check_training_step_output(self.trainer.lightning_module, training_step_output)

            result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
            if result_collection is None:
                return

        # accumulate loss. if accumulate_grad_batches==1, no effect
        closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches
        # the loss will get scaled for amp. avoid any modifications to it
        loss = closure_loss.detach().clone()
        return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection)

    def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, float]:
        """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.

        Args:
            optimizer: the current optimizer
        """
        # track gradient norms
        grad_norm_dict = {}
        can_log = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
        should_track = float(self.trainer.track_grad_norm) > 0
        if should_track and can_log:
            grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm)

        # clip gradients
        self.trainer.accelerator.clip_gradients(
            optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
        )
        return grad_norm_dict
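Concretely, the sequence the new `OptimizerLoop` iterates over is whatever a `LightningModule` returns from `configure_optimizers`; in automatic optimization, `training_step` then runs once per active optimizer with the matching `optimizer_idx`, as in this editor-added, hypothetical module (the GAN-style split is only an illustration, not part of this commit).

```python
import torch
from pytorch_lightning import LightningModule


class TwoOptimizerModule(LightningModule):
    """Hypothetical module with two optimizers; OptimizerLoop calls advance() once per optimizer."""

    def __init__(self):
        super().__init__()
        self.generator = torch.nn.Linear(32, 32)
        self.discriminator = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # optimizer_idx == 0 -> generator step, optimizer_idx == 1 -> discriminator step
        if optimizer_idx == 0:
            return self.generator(batch).sum()
        return self.discriminator(batch).sum()

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=1e-3)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-3)
        return opt_g, opt_d
```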
@@ -56,8 +56,9 @@ def test_loops_state_dict_structure():
            "total": {"ready": 0, "started": None, "processed": None, "completed": 0},
            "current": {"ready": 0, "started": None, "processed": None, "completed": 0},
        },
        "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
        "epoch_loop.batch_loop.state_dict": {},
        "epoch_loop.batch_loop.optim_progress": {
        "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
            "optimizer": {
                "step": {
                    "total": {"ready": 0, "started": None, "processed": None, "completed": 0},
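As the updated expectations above show, optimizer progress tracking moves from the batch loop onto the nested optimizer loop. A quick, editor-added way to observe the new key layout, assuming `FitLoop.state_dict()` behaves as exercised by `test_loops_state_dict_structure`:

```python
# Hypothetical check mirroring the expectations in test_loops_state_dict_structure.
from pytorch_lightning import Trainer

trainer = Trainer()  # default arguments assumed
state = trainer.fit_loop.state_dict()

# progress now lives under the nested optimizer loop ...
assert "epoch_loop.batch_loop.optimizer_loop.optim_progress" in state
# ... and no longer directly on the batch loop
assert "epoch_loop.batch_loop.optim_progress" not in state
```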
@@ -382,7 +382,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
    assert os.path.exists(ckpt_path)
    checkpoint = torch.load(ckpt_path)

    optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optim_progress
    optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress
    sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress

    # `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
@@ -461,7 +461,8 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
            "current": {"ready": be_sch_steps, "started": None, "processed": None, "completed": be_sch_steps},
        },
        "epoch_loop.batch_loop.state_dict": ANY,
        "epoch_loop.batch_loop.optim_progress": {
        "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
        "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
            "optimizer_idx": stop_optimizer,
            "optimizer": {
                "step": {
@@ -78,7 +78,7 @@ def test__eval_step__flow(tmpdir):
    assert train_step_out.minimize.item() == 171

    # make sure the optimizer closure returns the correct things
    opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
        batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
    )
    opt_closure_result = opt_closure()
@@ -145,7 +145,7 @@ def test__eval_step__eval_step_end__flow(tmpdir):
    assert train_step_out.minimize.item() == 171

    # make sure the optimizer closure returns the correct things
    opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
        batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
    )
    opt_closure_result = opt_closure()
@@ -157,7 +157,7 @@ def test__training_step__epoch_end__flow_scalar(tmpdir):
    assert train_step_out.minimize.item() == 171

    # make sure the optimizer closure returns the correct things
    opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
        batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
    )
    opt_closure_result = opt_closure()
@@ -231,7 +231,7 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):
    assert train_step_out.minimize.item() == 171

    # make sure the optimizer closure returns the correct things
    opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure(
    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
        batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
    )
    opt_closure_result = opt_closure()