diff --git a/CHANGELOG.md b/CHANGELOG.md index 3913552145..35ef6fb59b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,7 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Loop customization: * Added `Closure` and `AbstractClosure` classes ([#8642](https://github.com/PyTorchLightning/pytorch-lightning/pull/8642)) - + * Refactored `TrainingBatchLoop` and extracted `OptimizerLoop`, splitting off automatic optimization into its own loop ([#9191](https://github.com/PyTorchLightning/pytorch-lightning/pull/9191)) - Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187)) diff --git a/pytorch_lightning/loops/__init__.py b/pytorch_lightning/loops/__init__.py index b7eb47167d..3886a21c65 100644 --- a/pytorch_lightning/loops/__init__.py +++ b/pytorch_lightning/loops/__init__.py @@ -17,3 +17,4 @@ from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401 from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401 from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401 from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401 +from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop # noqa: F401 diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 1b2f26383d..46426a94f6 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -13,7 +13,7 @@ # limitations under the License. from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple import numpy as np import torch @@ -21,21 +21,16 @@ from deprecate import void from torch import Tensor from torch.optim import Optimizer -from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.closure import Closure, ClosureResult +from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop from pytorch_lightning.loops.utilities import ( - _block_parallel_sync_behavior, _build_training_step_kwargs, _check_training_step_output, _process_training_step_output, - check_finite_loss, ) -from pytorch_lightning.trainer.progress import OptimizationProgress from pytorch_lightning.trainer.supporters import TensorRunningAccum -from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _TPU_AVAILABLE +from pytorch_lightning.utilities import AttributeDict from pytorch_lightning.utilities.types import STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache @@ -50,13 +45,12 @@ class TrainingBatchLoop(Loop): self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20) # the current split index when the batch gets split into chunks in truncated backprop through time self.split_idx: Optional[int] = None - self.optim_progress = OptimizationProgress() + self.optimizer_loop = OptimizerLoop() self._warning_cache: WarningCache = WarningCache() self._hiddens: Optional[Tensor] = None self._optimizer_freq_cumsum: Optional[int] = None self._remaining_splits: Optional[List[Any]] = 
None - self._skip_backward: bool = False @property def done(self) -> bool: @@ -70,8 +64,8 @@ class TrainingBatchLoop(Loop): self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) return self._optimizer_freq_cumsum - def connect(self, **kwargs: "Loop") -> None: - raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") + def connect(self, optimizer_loop: "Loop") -> None: + self.optimizer_loop = optimizer_loop def run(self, batch: Any, batch_idx: int) -> AttributeDict: """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks @@ -132,17 +126,12 @@ class TrainingBatchLoop(Loop): self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch) if self.trainer.lightning_module.automatic_optimization: - for opt_idx, optimizer in self.get_active_optimizers(batch_idx): - # handle optimization restart - if self.restarting: - if opt_idx < self.optim_progress.optimizer_idx: - continue - - self.optim_progress.optimizer_idx = opt_idx - - result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer) - if result: - self.batch_outputs[opt_idx].append(deepcopy(result.result_collection)) + # in automatic optimization, hand over execution to the OptimizerLoop + optimizers = [optimizer for _, optimizer in self.get_active_optimizers(batch_idx)] + batch_outputs, self._hiddens = self.optimizer_loop.run(split_batch, self._hiddens, optimizers, batch_idx) + # combine outputs from each optimizer + for k in range(len(batch_outputs)): + self.batch_outputs[k].extend(batch_outputs[k]) else: # in manual optimization, there is no looping over optimizers result = self._run_optimization(batch_idx, split_batch) @@ -161,42 +150,16 @@ class TrainingBatchLoop(Loop): self, batch_idx: int, split_batch: Any, - opt_idx: Optional[int] = None, - optimizer: Optional[torch.optim.Optimizer] = None, ) -> Optional[ClosureResult]: """Runs closure (train step + backward) together with optimization if necessary. 
Args: batch_idx: the index of the current batch split_batch: the current tbptt split of the whole batch - opt_idx: the index of the current optimizer or `None` in case of manual optimization - optimizer: the current optimizer or `None` in case of manual optimization """ - # toggle model params - self._run_optimization_start(opt_idx, optimizer) - - closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens) - - if self.trainer.fit_loop.should_accumulate(): - # For gradient accumulation - - # ------------------- - # calculate loss (train step + train step end) - # ------------------- - # automatic_optimization: perform ddp sync only when performing optimizer_step - with _block_parallel_sync_behavior(self._trainer): - closure() - - # ------------------------------ - # BACKWARD PASS - # ------------------------------ - # gradient update with accumulated gradients - else: - if self.trainer.lightning_module.automatic_optimization: - self._optimizer_step(optimizer, opt_idx, batch_idx, closure) - else: - closure() - + # TODO: replace call through closure by direct call (manual optimization) + closure = self._make_closure(split_batch, batch_idx, self._hiddens) + closure() result = closure.get_result() if result: @@ -204,25 +167,21 @@ class TrainingBatchLoop(Loop): # otherwise update running loss + reset accumulated loss self._update_running_loss(result.loss) - # untoggle model params - self._run_optimization_end(opt_idx) return result def _make_closure( self, split_batch: Any, batch_idx: int, - opt_idx: int, - optimizer: Optimizer, hiddens: Any, ) -> Closure: """ Build a closure object that captures the given arguments and runs the `training_step` function and optionally other functions such as `backward` and `zero_grad`. """ - step_fn = self._make_step_fn(split_batch, batch_idx, opt_idx, hiddens) - backward_fn = self._make_backward_fn(optimizer, opt_idx) - zero_grad_fn = self._make_zero_grad_fn(batch_idx, opt_idx, optimizer) + step_fn = self._make_step_fn(split_batch, batch_idx, hiddens) + backward_fn = None + zero_grad_fn = None return Closure( step_fn=step_fn, @@ -231,66 +190,27 @@ class TrainingBatchLoop(Loop): profiler=self.trainer.profiler, ) - def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Any) -> Callable[[], dict]: + def _make_step_fn(self, split_batch: Any, batch_idx: int, hiddens: Any) -> Callable[[], dict]: """Build the step function that runs the `training_step` and processes its output.""" - return partial(self._training_step, split_batch, batch_idx, opt_idx, hiddens) + return partial(self._training_step, split_batch, batch_idx, hiddens) - def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]: - """ - Build a `zero_grad` function that zeroes the gradients before back-propagation. - Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on. 
- """ - - def zero_grad_fn(): - self._on_before_zero_grad(optimizer) - self._optimizer_zero_grad(batch_idx, optimizer, opt_idx) - - is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 - if ( - not self._skip_backward - and self.trainer.lightning_module.automatic_optimization - and is_first_batch_to_accumulate - ): - return zero_grad_fn - - def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Callable[[Tensor], Tensor]]: - """ - Build a `backward` function that handles back-propagation through the output produced by the `training_step` - function. Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on. - """ - - def backward_fn(loss: Tensor): - self.backward(loss, optimizer, opt_idx) - - # check if loss or model weights are nan - if self.trainer.terminate_on_nan: - check_finite_loss(self.trainer.lightning_module, loss) - - return loss - - if not self._skip_backward and self.trainer.lightning_module.automatic_optimization: - return backward_fn - - def _training_step( - self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor - ) -> Optional[AttributeDict]: - """Performs the actual train step with the tied hooks. + def _training_step(self, split_batch: Any, batch_idx: int, hiddens: Tensor) -> Optional[AttributeDict]: + """Performs the training step for manual optimization. Args: split_batch: the current tbptt split of the current batch batch_idx: the index of the current batch - opt_idx: the index of the current optimizer hiddens: the model's hidden state of the previous iteration Returns: - an AttributeDict containing the loss value and the training step output. + an AttributeDict containing the training step output. """ # give the PL module a result for logging model_ref = self.trainer.lightning_module with self.trainer.profiler.profile("model_forward"): step_kwargs = _build_training_step_kwargs( - model_ref, self.trainer.optimizers, split_batch, batch_idx, opt_idx, hiddens + model_ref, self.trainer.optimizers, split_batch, batch_idx, opt_idx=None, hiddens=hiddens ) # manually capture logged metrics @@ -309,97 +229,7 @@ class TrainingBatchLoop(Loop): if result_collection is None: return - closure_loss = None - loss = None - if self.trainer.lightning_module.automatic_optimization: - # accumulate loss. if accumulate_grad_batches==1, no effect - closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches - # the loss will get scaled for amp. avoid any modifications to it - loss = closure_loss.detach().clone() - return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection) - - def _optimizer_step( - self, optimizer: torch.optim.Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable - ) -> None: - """Performs the optimizer step and some sanity checking. - - Args: - optimizer: the optimizer to perform the step with - opt_idx: the index of the current :param:`optimizer` - batch_idx: the index of the current batch - train_step_and_backward_closure: the closure function performing the train step and computing the - gradients. By default called by the optimizer (if possible) - """ - model_ref = self.trainer.lightning_module - - is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) - using_native_amp = self.trainer.amp_backend == AMPType.NATIVE - - # native amp + lbfgs is a no go right now - if using_native_amp and is_lbfgs: - raise MisconfigurationException( - "native PyTorch amp and lbfgs are not compatible." 
- " To request, please file a Github issue in PyTorch and tag @mcarilli" - ) - - # wraps into LightningOptimizer only for running step - optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx) - - self.optim_progress.optimizer.step.increment_ready() - - # model hook - model_ref.optimizer_step( - self.trainer.current_epoch, - batch_idx, - optimizer, - opt_idx, - train_step_and_backward_closure, - on_tpu=(self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE), - using_native_amp=using_native_amp, - using_lbfgs=is_lbfgs, - ) - - self.optim_progress.optimizer.step.increment_completed() - - def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: - """Calls the ``on_before_zero_grad`` hook. - - Args: - optimizer: the current optimizer - """ - self.optim_progress.optimizer.zero_grad.increment_ready() - self.trainer.call_hook("on_before_zero_grad", optimizer) - self.optim_progress.optimizer.zero_grad.increment_started() - - def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None: - """Zeroes out all gradients of parameters optimized by the current optimizer. - - Args: - batch_idx: the index of the current batch - optimizer: the current optimizer - opt_idx: the index of the current optimizer - """ - self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) - self.optim_progress.optimizer.zero_grad.increment_completed() - - def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]: - """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer. - - Args: - optimizer: the current optimizer - """ - # track gradient norms - grad_norm_dict = {} - can_log = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 - should_track = float(self.trainer.track_grad_norm) > 0 - if should_track and can_log: - grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm) - - # clip gradients - self.trainer.accelerator.clip_gradients( - optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm - ) - return grad_norm_dict + return AttributeDict(closure_loss=None, loss=None, result_collection=result_collection) def _tbptt_split_batch(self, batch: Any) -> List[Any]: """Splits a single batch into a list of sequence steps for tbptt. @@ -416,25 +246,7 @@ class TrainingBatchLoop(Loop): splits = model_ref.tbptt_split_batch(batch, tbptt_steps) return splits - def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None: - """Toggles the optimizer to ensure the correct one is used and prevend dangling grads. - - Args: - opt_idx: the index of the optimizer to use - optimizer: the optimizer to use - - """ - # make sure only the gradients of the current optimizer's parameters are calculated - # in the training step to prevent dangling gradients in multiple-optimizer setup. 
- if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1: - model = self.trainer.lightning_module - model.toggle_optimizer(optimizer, opt_idx) - - def _run_optimization_end(self, opt_idx: int) -> None: - if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1: - model = self.trainer.lightning_module - model.untoggle_optimizer(opt_idx) - + # TODO: remove this method and update tests def backward( self, loss: Tensor, @@ -451,13 +263,6 @@ class TrainingBatchLoop(Loop): opt_idx: Index of the current optimizer being used. ``None`` if using manual optimization. """ self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs) - - if not self.trainer.fit_loop.should_accumulate(): - # track gradients - grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer) - if grad_norm_dict: - self.trainer.lightning_module._current_fx_name = "on_after_backward" - self.trainer.lightning_module.log_grad_norm(grad_norm_dict) return loss def _update_running_loss(self, current_loss: Tensor) -> None: diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 43d51fe002..2020ac6cc6 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -95,7 +95,7 @@ class TrainingEpochLoop(loops.Loop): if not self.restarting: self.batch_progress.current.reset() self.scheduler_progress.current.reset() - self.batch_loop.optim_progress.reset_on_epoch() + self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch() def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None: # hook diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 8e86d1b722..47eb50a2ab 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -108,12 +108,12 @@ class FitLoop(Loop): @property def _skip_backward(self) -> bool: """Determines whether the loop will skip backward during automatic optimization.""" - return self.epoch_loop.batch_loop._skip_backward + return self.epoch_loop.batch_loop.optimizer_loop._skip_backward @_skip_backward.setter def _skip_backward(self, value: bool) -> None: """Determines whether the loop will skip backward during automatic optimization.""" - self.epoch_loop.batch_loop._skip_backward = value + self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value @property def _results(self) -> ResultCollection: diff --git a/pytorch_lightning/loops/optimizer/__init__.py b/pytorch_lightning/loops/optimizer/__init__.py new file mode 100644 index 0000000000..2f74442741 --- /dev/null +++ b/pytorch_lightning/loops/optimizer/__init__.py @@ -0,0 +1,15 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop # noqa: F401 diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py new file mode 100644 index 0000000000..fb2440bd58 --- /dev/null +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -0,0 +1,376 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torch import Tensor +from torch.optim import Optimizer + +from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.loops import Loop +from pytorch_lightning.loops.closure import Closure, ClosureResult +from pytorch_lightning.loops.utilities import ( + _block_parallel_sync_behavior, + _build_training_step_kwargs, + _check_training_step_output, + _process_training_step_output, + check_finite_loss, +) +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.trainer.progress import OptimizationProgress +from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _TPU_AVAILABLE + +_OUTPUTS_TYPE = List[List[Optional[ResultCollection]]] + + +class OptimizerLoop(Loop): + """Runs over a sequence of optimizers. 
This loop implements what is known in Lightning as Automatic Optimization.""" + + def __init__(self): + super().__init__() + # TODO: use default dict here to simplify logic in loop + self.outputs: _OUTPUTS_TYPE = [] + self.optim_progress: OptimizationProgress = OptimizationProgress() + + self._skip_backward: bool = False + self._batch_idx: Optional[int] = None + self._optimizers: Optional[List[Optimizer]] = None + self._hiddens: Optional[Any] = None + + @property + def done(self) -> bool: + """Returns ``True`` when the last optimizer in the sequence has run.""" + return self.optim_progress.optimizer_idx >= len(self._optimizers) + + def connect(self, **kwargs: "Loop") -> None: + raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") + + def reset(self) -> None: + if not self.restarting: + self.optim_progress.optimizer_idx = 0 + self.outputs = [[] for _ in range(len(self.trainer.optimizers))] + + def on_run_start(self, batch: Any, hiddens: Any, optimizers: List[Optimizer], batch_idx: int) -> None: + self._batch_idx = batch_idx + self._optimizers = optimizers + + def advance(self, batch: Any, hiddens: Any, *args, **kwargs) -> None: + self._hiddens = hiddens + result = self._run_optimization( + self._batch_idx, + batch, + self.optim_progress.optimizer_idx, + self._optimizers[self.optim_progress.optimizer_idx], + ) + if result: + self.outputs[self.optim_progress.optimizer_idx].append(deepcopy(result.result_collection)) + + self.optim_progress.optimizer_idx += 1 + + def on_run_end(self) -> Tuple[_OUTPUTS_TYPE, Optional[Any]]: + outputs = self.outputs + hiddens = self._hiddens + # free memory + self.outputs = [] + self._hiddens = None + return outputs, hiddens + + def backward( + self, + loss: Tensor, + optimizer: Optional[torch.optim.Optimizer], + opt_idx: Optional[int] = None, + *args: Any, + **kwargs: Any, + ) -> Tensor: + """Performs the backward step. + + Args: + loss: The loss value to back-propagate on + optimizer: Current optimizer being used. ``None`` if using manual optimization. + opt_idx: Index of the current optimizer being used. ``None`` if using manual optimization. + """ + self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs) + + if not self.trainer.fit_loop.should_accumulate(): + # track gradients + grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer) + if grad_norm_dict: + self.trainer.lightning_module._current_fx_name = "on_after_backward" + self.trainer.lightning_module.log_grad_norm(grad_norm_dict) + return loss + + def _run_optimization( + self, + batch_idx: int, + split_batch: Any, + opt_idx: Optional[int] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + ) -> Optional[ClosureResult]: + """Runs closure (train step + backward) together with optimization if necessary. 
+ + Args: + batch_idx: the index of the current batch + split_batch: the current tbptt split of the whole batch + opt_idx: the index of the current optimizer or `None` in case of manual optimization + optimizer: the current optimizer or `None` in case of manual optimization + """ + # toggle model params + self._run_optimization_start(opt_idx, optimizer) + + closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens) + + if self.trainer.fit_loop.should_accumulate(): + # For gradient accumulation + + # ------------------- + # calculate loss (train step + train step end) + # ------------------- + # automatic_optimization=True: perform ddp sync only when performing optimizer_step + with _block_parallel_sync_behavior(self.trainer, block=True): + closure() + + # ------------------------------ + # BACKWARD PASS + # ------------------------------ + # gradient update with accumulated gradients + else: + self._optimizer_step(optimizer, opt_idx, batch_idx, closure) + + result = closure.get_result() + + if result: + # if no result, user decided to skip optimization + # otherwise update running loss + reset accumulated loss + # TODO: find proper way to handle updating running loss + self.trainer.fit_loop.epoch_loop.batch_loop._update_running_loss(result.loss) + + # untoggle model params + self._run_optimization_end(opt_idx) + return result + + def _make_closure( + self, + split_batch: Any, + batch_idx: int, + opt_idx: int, + optimizer: Optimizer, + hiddens: Any, + ) -> Closure: + """ + Build a closure object that captures the given arguments and runs the `training_step` function and optionally + other functions such as `backward` and `zero_grad`. + """ + step_fn = self._make_step_fn(split_batch, batch_idx, opt_idx, hiddens) + backward_fn = self._make_backward_fn(optimizer, opt_idx) + zero_grad_fn = self._make_zero_grad_fn(batch_idx, opt_idx, optimizer) + + return Closure( + step_fn=step_fn, + backward_fn=backward_fn, + zero_grad_fn=zero_grad_fn, + profiler=self.trainer.profiler, + ) + + def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Any) -> Callable[[], dict]: + """Build the step function that runs the `training_step` and processes its output.""" + return partial(self._training_step, split_batch, batch_idx, opt_idx, hiddens) + + def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]: + """ + Build a `zero_grad` function that zeroes the gradients before back-propagation. + Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on. + """ + + def zero_grad_fn(): + self._on_before_zero_grad(optimizer) + self._optimizer_zero_grad(batch_idx, optimizer, opt_idx) + + is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 + if not self._skip_backward and is_first_batch_to_accumulate: + return zero_grad_fn + + def _make_backward_fn( + self, + optimizer: Optimizer, + opt_idx: int, + ) -> Optional[Callable[[Tensor], Tensor]]: + """ + Build a `backward` function that handles back-propagation through the output produced by the `training_step` + function. Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on. 
+ """ + + def backward_fn(loss: Tensor): + self.backward(loss, optimizer, opt_idx) + + # check if loss or model weights are nan + if self.trainer.terminate_on_nan: + check_finite_loss(self.trainer.lightning_module, loss) + + return loss + + if not self._skip_backward: + return backward_fn + + def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None: + """Toggles the optimizer to ensure the correct one is used and prevend dangling grads. + + Args: + opt_idx: the index of the optimizer to use + optimizer: the optimizer to use + + """ + # make sure only the gradients of the current optimizer's parameters are calculated + # in the training step to prevent dangling gradients in multiple-optimizer setup. + if len(self.trainer.optimizers) > 1: + model = self.trainer.lightning_module + model.toggle_optimizer(optimizer, opt_idx) + + def _run_optimization_end(self, opt_idx: int) -> None: + if len(self.trainer.optimizers) > 1: + model = self.trainer.lightning_module + model.untoggle_optimizer(opt_idx) + + def _optimizer_step( + self, optimizer: torch.optim.Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable + ) -> None: + """Performs the optimizer step and some sanity checking. + + Args: + optimizer: the optimizer to perform the step with + opt_idx: the index of the current :param:`optimizer` + batch_idx: the index of the current batch + train_step_and_backward_closure: the closure function performing the train step and computing the + gradients. By default called by the optimizer (if possible) + """ + model_ref = self.trainer.lightning_module + + is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) + using_native_amp = self.trainer.amp_backend == AMPType.NATIVE + + # native amp + lbfgs is a no go right now + if using_native_amp and is_lbfgs: + raise MisconfigurationException( + "native PyTorch amp and lbfgs are not compatible." + " To request, please file a Github issue in PyTorch and tag @mcarilli" + ) + + # wraps into LightningOptimizer only for running step + optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx) + + self.optim_progress.optimizer.step.increment_ready() + + # model hook + model_ref.optimizer_step( + self.trainer.current_epoch, + batch_idx, + optimizer, + opt_idx, + train_step_and_backward_closure, + on_tpu=(self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE), + using_native_amp=using_native_amp, + using_lbfgs=is_lbfgs, + ) + + self.optim_progress.optimizer.step.increment_completed() + + def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: + """Calls the ``on_before_zero_grad`` hook. + + Args: + optimizer: the current optimizer + """ + self.optim_progress.optimizer.zero_grad.increment_ready() + self.trainer.call_hook("on_before_zero_grad", optimizer) + self.optim_progress.optimizer.zero_grad.increment_started() + + def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None: + """Zeroes out all gradients of parameters optimized by the current optimizer. 
+ + Args: + batch_idx: the index of the current batch + optimizer: the current optimizer + opt_idx: the index of the current optimizer + """ + self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) + self.optim_progress.optimizer.zero_grad.increment_completed() + + def _training_step( + self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor + ) -> Optional[AttributeDict]: + """Performs the actual train step with the tied hooks. + + Args: + split_batch: the current tbptt split of the current batch + batch_idx: the index of the current batch + opt_idx: the index of the current optimizer + hiddens: the model's hidden state of the previous iteration + + Returns: + an AttributeDict containing the loss value and the training step output. + """ + # give the PL module a result for logging + model_ref = self.trainer.lightning_module + + with self.trainer.profiler.profile("model_forward"): + + step_kwargs = _build_training_step_kwargs( + self.trainer.lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, hiddens + ) + + # manually capture logged metrics + model_ref._current_fx_name = "training_step" + with self.trainer.profiler.profile("training_step"): + training_step_output = self.trainer.accelerator.training_step(step_kwargs) + self.trainer.accelerator.post_training_step() + + del step_kwargs + + training_step_output = self.trainer.call_hook("training_step_end", training_step_output) + + _check_training_step_output(self.trainer.lightning_module, training_step_output) + + result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output) + if result_collection is None: + return + + # accumulate loss. if accumulate_grad_batches==1, no effect + closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches + # the loss will get scaled for amp. avoid any modifications to it + loss = closure_loss.detach().clone() + return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection) + + def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, float]: + """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer. 
+ + Args: + optimizer: the current optimizer + """ + # track gradient norms + grad_norm_dict = {} + can_log = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 + should_track = float(self.trainer.track_grad_norm) > 0 + if should_track and can_log: + grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm) + + # clip gradients + self.trainer.accelerator.clip_gradients( + optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm + ) + return grad_norm_dict diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 5fdd09d5fd..4cff47f2f5 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -56,8 +56,9 @@ def test_loops_state_dict_structure(): "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, + "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, "epoch_loop.batch_loop.state_dict": {}, - "epoch_loop.batch_loop.optim_progress": { + "epoch_loop.batch_loop.optimizer_loop.optim_progress": { "optimizer": { "step": { "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 200b2daae9..c20a0b6261 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -382,7 +382,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch assert os.path.exists(ckpt_path) checkpoint = torch.load(ckpt_path) - optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optim_progress + optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress # `nbe_`: non-breaking epoch, as in, no exception will be raised. 
`be_`: breaking epoch @@ -461,7 +461,8 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch "current": {"ready": be_sch_steps, "started": None, "processed": None, "completed": be_sch_steps}, }, "epoch_loop.batch_loop.state_dict": ANY, - "epoch_loop.batch_loop.optim_progress": { + "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, + "epoch_loop.batch_loop.optimizer_loop.optim_progress": { "optimizer_idx": stop_optimizer, "optimizer": { "step": { diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index 78dd4bce72..6af5df081c 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -78,7 +78,7 @@ def test__eval_step__flow(tmpdir): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure( + opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) opt_closure_result = opt_closure() @@ -145,7 +145,7 @@ def test__eval_step__eval_step_end__flow(tmpdir): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure( + opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) opt_closure_result = opt_closure() diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 692d2420bf..43a8e561f8 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -157,7 +157,7 @@ def test__training_step__epoch_end__flow_scalar(tmpdir): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure( + opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) opt_closure_result = opt_closure() @@ -231,7 +231,7 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir): assert train_step_out.minimize.item() == 171 # make sure the optimizer closure returns the correct things - opt_closure = trainer.fit_loop.epoch_loop.batch_loop._make_closure( + opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure( batch, batch_idx, 0, trainer.optimizers[0], hiddens=None ) opt_closure_result = opt_closure()
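
The standalone sketch below (not part of the patch) mirrors the control-flow hand-off this refactor introduces: `TrainingBatchLoop` no longer iterates over optimizers itself but connects to and delegates to an `OptimizerLoop` with its own `reset`/`advance`/`done` life cycle and `optim_progress` state. The `Toy*` classes, the minimal `Progress` counter, and the simplified `run()` signature are illustrative stand-ins only; the real `OptimizerLoop.run()` additionally threads `hiddens` and `batch_idx` through `on_run_start`/`on_run_end` and returns `(outputs, hiddens)` as shown in the diff.

```python
from typing import Any, List


class Progress:
    """Toy stand-in for Lightning's OptimizationProgress (illustrative only)."""

    def __init__(self) -> None:
        self.optimizer_idx = 0
        self.steps_completed = 0


class ToyOptimizerLoop:
    """Mirrors the OptimizerLoop shape: run() = reset() + advance() until done."""

    def __init__(self) -> None:
        self.optim_progress = Progress()
        self.outputs: List[List[Any]] = []
        self._optimizers: List[Any] = []

    @property
    def done(self) -> bool:
        # finished once the last optimizer in the sequence has run
        return self.optim_progress.optimizer_idx >= len(self._optimizers)

    def reset(self) -> None:
        self.optim_progress.optimizer_idx = 0
        self.outputs = [[] for _ in self._optimizers]

    def advance(self, split_batch: Any) -> None:
        opt_idx = self.optim_progress.optimizer_idx
        optimizer = self._optimizers[opt_idx]
        # placeholder for closure construction + optimizer step; just record what ran
        self.outputs[opt_idx].append(f"step(batch={split_batch}, opt={optimizer})")
        self.optim_progress.steps_completed += 1
        self.optim_progress.optimizer_idx += 1

    def run(self, split_batch: Any, optimizers: List[Any]) -> List[List[Any]]:
        self._optimizers = optimizers
        self.reset()
        while not self.done:
            self.advance(split_batch)
        return self.outputs


class ToyTrainingBatchLoop:
    """After the refactor, the batch loop delegates automatic optimization."""

    def __init__(self) -> None:
        self.optimizer_loop = ToyOptimizerLoop()

    def connect(self, optimizer_loop: ToyOptimizerLoop) -> None:
        # counterpart of the new TrainingBatchLoop.connect(optimizer_loop=...)
        self.optimizer_loop = optimizer_loop

    def run(self, batch: Any, optimizers: List[Any]) -> List[List[Any]]:
        # a real batch loop would also handle tbptt splits and manual optimization
        return self.optimizer_loop.run(batch, optimizers)


if __name__ == "__main__":
    batch_loop = ToyTrainingBatchLoop()
    outputs = batch_loop.run("batch0", optimizers=["adam", "sgd"])
    print(outputs)                                            # one output list per optimizer
    print(batch_loop.optimizer_loop.optim_progress.steps_completed)  # -> 2
```

Correspondingly, code that previously reached into `trainer.fit_loop.epoch_loop.batch_loop.optim_progress` or toggled `_skip_backward` on the batch loop now goes through the nested loop, e.g. `trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress`, as the updated state-dict and test expectations in this patch reflect.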