496 lines
20 KiB
Python
496 lines
20 KiB
Python
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
from copy import copy
|
|
from functools import partial
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
from deprecate import void
|
|
from torch import Tensor
|
|
from torch.optim import Optimizer
|
|
|
|
from pytorch_lightning.core.optimizer import LightningOptimizer
|
|
from pytorch_lightning.loops.base import Loop
|
|
from pytorch_lightning.loops.closure import Closure, ClosureResult
|
|
from pytorch_lightning.loops.utilities import (
|
|
_block_parallel_sync_behavior,
|
|
_build_training_step_kwargs,
|
|
_check_training_step_output,
|
|
_process_training_step_output,
|
|
check_finite_loss,
|
|
)
|
|
from pytorch_lightning.trainer.progress import OptimizationProgress
|
|
from pytorch_lightning.trainer.supporters import TensorRunningAccum
|
|
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
|
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
|
|
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
|
from pytorch_lightning.utilities.warnings import WarningCache
|
|
|
|
|
|
class TrainingBatchLoop(Loop):
    """Runs over a single batch of data.

    The batch is first split into truncated-backprop-through-time (tbptt) splits; each call to
    :meth:`advance` then consumes one split and runs the training step (and, when required, the
    optimizer step) for every active optimizer.
    """

    def __init__(self) -> None:
        super().__init__()
        # Loss accumulator used by ``_update_running_loss`` (``append``/``mean``/``reset`` are called on it).
        # It is never assigned inside this class, so it must be set by the owner of this loop before
        # the first optimization runs -- NOTE(review): confirm where this happens.
        self.accumulated_loss: Optional[TensorRunningAccum] = None
        # one list of step outputs per optimizer, (re)created in ``reset`` and handed out by ``run``
        self.batch_outputs: Optional[List[List[STEP_OUTPUT]]] = None
        self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20)
        # the current split index when the batch gets split into chunks in truncated backprop through time
        self.split_idx: Optional[int] = None
        self.optim_progress = OptimizationProgress()

        self._warning_cache: WarningCache = WarningCache()
        self._hiddens: Optional[Tensor] = None
        # lazily computed cumulative sum of ``trainer.optimizer_frequencies`` (see ``optimizer_freq_cumsum``)
        self._optimizer_freq_cumsum: Optional[np.ndarray] = None
        # ``(split_idx, split_batch)`` pairs still to be processed for the current batch
        self._remaining_splits: Optional[List[Any]] = None
        self._skip_backward: bool = False

    @property
    def done(self) -> bool:
        """Returns if all batch splits have been processed already.

        Also ``True`` while no splits are queued at all (``_remaining_splits`` is ``None``, which is its
        state after ``__init__`` and after :meth:`teardown`) instead of failing on ``len(None)``.
        """
        return not self._remaining_splits

    @property
    def optimizer_freq_cumsum(self) -> np.ndarray:
        """Returns the cumulated sum of optimizer frequencies (computed once, then cached)."""
        if self._optimizer_freq_cumsum is None:
            self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
        return self._optimizer_freq_cumsum

    def connect(self, **kwargs: "Loop") -> None:
        """Not supported for this loop.

        Raises:
            NotImplementedError: always, this loop has no child loops to connect.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

    def run(self, batch: Any, batch_idx: int) -> AttributeDict:
        """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks

        Args:
            batch: the current batch to run the train step on
            batch_idx: the index of the current batch

        Returns:
            An ``AttributeDict`` with a ``signal`` entry (``0`` on a normal run, ``-1`` when one of the start
            hooks returned ``-1``). On a normal run it also carries ``training_step_output``; the early
            ``-1`` exits do not.
        """
        if batch is None:
            self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
            return AttributeDict(signal=0, training_step_output=[[]])

        # hook
        self.trainer.logger_connector.on_batch_start()
        response = self.trainer.call_hook("on_batch_start")
        if response == -1:
            return AttributeDict(signal=-1)

        # hook
        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0)
        if response == -1:
            return AttributeDict(signal=-1)

        self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()

        # the base class ``run`` drives this loop; its return value is not used here
        super().run(batch, batch_idx)
        output = AttributeDict(signal=0, training_step_output=self.batch_outputs)
        self.batch_outputs = None  # free memory
        return output

    def reset(self) -> None:
        """Resets the loop state"""
        self._hiddens = None
        # one (initially empty) output list per configured optimizer
        self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))]

    def on_run_start(self, batch: Any, batch_idx: int) -> None:
        """Splits the data into tbptt splits

        Args:
            batch: the current batch to run the trainstep on
            batch_idx: the index of the current batch
        """
        # ``batch_idx`` is part of the hook signature but not needed here
        void(batch_idx)
        self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))

    def advance(self, batch: Any, batch_idx: int) -> None:
        """Runs the train step together with optimization (if necessary) on the current batch split

        Args:
            batch: the current batch to run the training on (this is not the split!)
            batch_idx: the index of the current batch
        """
        # the full batch is not needed, only the next queued split is processed
        void(batch)
        split_idx, split_batch = self._remaining_splits.pop(0)
        self.split_idx = split_idx

        # let logger connector extract current batch size
        self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch)

        if self.trainer.lightning_module.automatic_optimization:
            for opt_idx, optimizer in self.get_active_optimizers(batch_idx):
                # handle optimization restart: skip optimizers that come before the recorded index
                if self.restarting and opt_idx < self.optim_progress.optimizer_idx:
                    continue

                self.optim_progress.optimizer_idx = opt_idx

                result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
                if result:
                    self.batch_outputs[opt_idx].append(copy(result.result_collection))
        else:
            # in manual optimization, there is no looping over optimizers
            result = self._run_optimization(batch_idx, split_batch)
            if result:
                self.batch_outputs[0].append(copy(result.result_collection))

    def teardown(self) -> None:
        """Drops the reference to the queued batch splits."""
        # release memory
        self._remaining_splits = None

    def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
        """Gets the number of active optimizers based on their frequency"""
        return len(self.get_active_optimizers(batch_idx))

    def _run_optimization(
        self,
        batch_idx: int,
        split_batch: Any,
        opt_idx: Optional[int] = None,
        optimizer: Optional[Optimizer] = None,
    ) -> Optional[ClosureResult]:
        """Runs closure (train step + backward) together with optimization if necessary.

        Args:
            batch_idx: the index of the current batch
            split_batch: the current tbptt split of the whole batch
            opt_idx: the index of the current optimizer or `None` in case of manual optimization
            optimizer: the current optimizer or `None` in case of manual optimization

        Returns:
            Whatever ``closure.get_result()`` yields after the closure ran; a falsy result means the
            running loss is not updated.
        """
        # toggle model params
        self._run_optimization_start(opt_idx, optimizer)

        closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens)

        if self.trainer.fit_loop.should_accumulate():
            # For gradient accumulation

            # -------------------
            # calculate loss (train step + train step end)
            # -------------------
            # automatic_optimization: perform ddp sync only when performing optimizer_step
            with _block_parallel_sync_behavior(self.trainer):
                closure()

        # ------------------------------
        # BACKWARD PASS
        # ------------------------------
        # gradient update with accumulated gradients
        else:
            if self.trainer.lightning_module.automatic_optimization:
                self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
            else:
                closure()

        result = closure.get_result()

        if result:
            # if no result, user decided to skip optimization
            # otherwise update running loss + reset accumulated loss
            self._update_running_loss(result.loss)

        # untoggle model params
        self._run_optimization_end(opt_idx)
        return result

    def _make_closure(
        self,
        split_batch: Any,
        batch_idx: int,
        opt_idx: Optional[int],
        optimizer: Optional[Optimizer],
        hiddens: Any,
    ) -> Closure:
        """
        Build a closure object that captures the given arguments and runs the `training_step` function and optionally
        other functions such as `backward` and `zero_grad`.

        Args:
            split_batch: the current tbptt split of the whole batch
            batch_idx: the index of the current batch
            opt_idx: the index of the current optimizer or `None` in case of manual optimization
            optimizer: the current optimizer or `None` in case of manual optimization
            hiddens: the hidden state forwarded to the training step
        """
        step_fn = self._make_step_fn(split_batch, batch_idx, opt_idx, hiddens)
        backward_fn = self._make_backward_fn(optimizer, opt_idx)
        zero_grad_fn = self._make_zero_grad_fn(batch_idx, opt_idx, optimizer)

        return Closure(
            step_fn=step_fn,
            backward_fn=backward_fn,
            zero_grad_fn=zero_grad_fn,
            profiler=self.trainer.profiler,
        )

    def _make_step_fn(
        self, split_batch: Any, batch_idx: int, opt_idx: Optional[int], hiddens: Any
    ) -> Callable[[], dict]:
        """Build the step function that runs the `training_step` and processes its output."""
        return partial(self._training_step, split_batch, batch_idx, opt_idx, hiddens)

    def _make_zero_grad_fn(
        self, batch_idx: int, opt_idx: Optional[int], optimizer: Optional[Optimizer]
    ) -> Optional[Callable[[], None]]:
        """
        Build a `zero_grad` function that zeroes the gradients before back-propagation.
        Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on.
        """

        def zero_grad_fn() -> None:
            self._on_before_zero_grad(optimizer)
            self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)

        # gradients are only zeroed on the first batch of an accumulation window
        is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0
        if (
            not self._skip_backward
            and self.trainer.lightning_module.automatic_optimization
            and is_first_batch_to_accumulate
        ):
            return zero_grad_fn
        return None

    def _make_backward_fn(
        self, optimizer: Optional[Optimizer], opt_idx: Optional[int]
    ) -> Optional[Callable[[Tensor], Tensor]]:
        """
        Build a `backward` function that handles back-propagation through the output produced by the `training_step`
        function. Returns ``None`` in the case backward needs to be skipped, e.g., when manual optimization is on.
        """

        def backward_fn(loss: Tensor) -> Tensor:
            self.backward(loss, optimizer, opt_idx)

            # check if loss or model weights are nan
            if self.trainer.terminate_on_nan:
                check_finite_loss(self.trainer.lightning_module, loss)

            return loss

        if not self._skip_backward and self.trainer.lightning_module.automatic_optimization:
            return backward_fn
        return None

    def _training_step(
        self, split_batch: Any, batch_idx: int, opt_idx: Optional[int], hiddens: Tensor
    ) -> Optional[AttributeDict]:
        """Performs the actual train step with the tied hooks.

        Args:
            split_batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            opt_idx: the index of the current optimizer
            hiddens: the model's hidden state of the previous iteration

        Returns:
            an AttributeDict containing the loss value and the training step output, or ``None`` when the
            processed training step output is ``None``.
        """
        # give the PL module a result for logging
        model_ref = self.trainer.lightning_module

        with self.trainer.profiler.profile("model_forward"):
            step_kwargs = _build_training_step_kwargs(
                model_ref, self.trainer.optimizers, split_batch, batch_idx, opt_idx, hiddens
            )

            # manually capture logged metrics
            model_ref._current_fx_name = "training_step"
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator.training_step(step_kwargs)
                self.trainer.accelerator.post_training_step()

            del step_kwargs

            training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

            _check_training_step_output(self.trainer.lightning_module, training_step_output)

            # the hidden state is stored on the loop so the next split's closure can pick it up
            result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
            if result_collection is None:
                return None

        closure_loss = None
        loss = None
        if self.trainer.lightning_module.automatic_optimization:
            # accumulate loss. if accumulate_grad_batches==1, no effect
            closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches
            # the loss will get scaled for amp. avoid any modifications to it
            loss = closure_loss.detach().clone()
        return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection)

    def _optimizer_step(
        self, optimizer: Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable
    ) -> None:
        """Performs the optimizer step and some sanity checking.

        Args:
            optimizer: the optimizer to perform the step with
            opt_idx: the index of the current ``optimizer``
            batch_idx: the index of the current batch
            train_step_and_backward_closure: the closure function performing the train step and computing the
                gradients. By default called by the optimizer (if possible)

        Raises:
            MisconfigurationException: if native AMP is combined with an LBFGS optimizer.
        """
        model_ref = self.trainer.lightning_module

        is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
        using_native_amp = self.trainer.amp_backend == AMPType.NATIVE

        # native amp + lbfgs is a no go right now
        if using_native_amp and is_lbfgs:
            raise MisconfigurationException(
                "native PyTorch amp and lbfgs are not compatible."
                " To request, please file a Github issue in PyTorch and tag @mcarilli"
            )

        # wraps into LightningOptimizer only for running step
        optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)

        self.optim_progress.optimizer.step.increment_ready()

        # model hook
        model_ref.optimizer_step(
            self.trainer.current_epoch,
            batch_idx,
            optimizer,
            opt_idx,
            train_step_and_backward_closure,
            on_tpu=(self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE),
            using_native_amp=using_native_amp,
            using_lbfgs=is_lbfgs,
        )

        self.optim_progress.optimizer.step.increment_completed()

    def _on_before_zero_grad(self, optimizer: Optimizer) -> None:
        """Calls the ``on_before_zero_grad`` hook.

        Args:
            optimizer: the current optimizer
        """
        self.optim_progress.optimizer.zero_grad.increment_ready()
        self.trainer.call_hook("on_before_zero_grad", optimizer)
        self.optim_progress.optimizer.zero_grad.increment_started()

    def _optimizer_zero_grad(self, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
        """Zeroes out all gradients of parameters optimized by the current optimizer.

        Args:
            batch_idx: the index of the current batch
            optimizer: the current optimizer
            opt_idx: the index of the current optimizer
        """
        self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
        self.optim_progress.optimizer.zero_grad.increment_completed()

    def _track_and_norm_grad(self, optimizer: Optimizer) -> Dict[str, Tensor]:
        """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.

        Args:
            optimizer: the current optimizer

        Returns:
            The gradient norms when tracking is enabled and the current step is a logging step,
            otherwise an empty dict. Gradient clipping is requested from the accelerator in either case.
        """
        # track gradient norms
        grad_norm_dict = {}
        can_log = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
        should_track = float(self.trainer.track_grad_norm) > 0
        if should_track and can_log:
            grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm)

        # clip gradients
        self.trainer.accelerator.clip_gradients(
            optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
        )
        return grad_norm_dict

    def _tbptt_split_batch(self, batch: Any) -> List[Any]:
        """Splits a single batch into a list of sequence steps for tbptt.

        Args:
            batch: the current batch to split

        Returns:
            ``[batch]`` when ``truncated_bptt_steps`` is ``0``, otherwise the splits produced by the
            module's ``tbptt_split_batch`` hook.
        """
        model_ref = self.trainer.lightning_module
        tbptt_steps = model_ref.truncated_bptt_steps
        if tbptt_steps == 0:
            return [batch]

        with self.trainer.profiler.profile("tbptt_split_batch"):
            splits = model_ref.tbptt_split_batch(batch, tbptt_steps)
        return splits

    def _run_optimization_start(self, opt_idx: Optional[int], optimizer: Optional[Optimizer]) -> None:
        """Toggles the optimizer to ensure the correct one is used and prevent dangling grads.

        Args:
            opt_idx: the index of the optimizer to use
            optimizer: the optimizer to use
        """
        # make sure only the gradients of the current optimizer's parameters are calculated
        # in the training step to prevent dangling gradients in multiple-optimizer setup.
        if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
            model = self.trainer.lightning_module
            model.toggle_optimizer(optimizer, opt_idx)

    def _run_optimization_end(self, opt_idx: Optional[int]) -> None:
        """Untoggles the optimizer again; the counterpart of ``_run_optimization_start``.

        Args:
            opt_idx: the index of the optimizer that was used
        """
        if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
            model = self.trainer.lightning_module
            model.untoggle_optimizer(opt_idx)

    def backward(
        self,
        loss: Tensor,
        optimizer: Optional[Optimizer],
        opt_idx: Optional[int] = None,
        *args: Any,
        **kwargs: Any,
    ) -> Tensor:
        """Performs the backward step.

        Args:
            loss: The loss value to back-propagate on
            optimizer: Current optimizer being used. ``None`` if using manual optimization.
            opt_idx: Index of the current optimizer being used. ``None`` if using manual optimization.
            *args: additional positional arguments forwarded to the accelerator's ``backward``
            **kwargs: additional keyword arguments forwarded to the accelerator's ``backward``

        Returns:
            The ``loss`` that was passed in.
        """
        self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)

        # norms are tracked and gradients clipped only when this step is not an accumulation step
        if not self.trainer.fit_loop.should_accumulate():
            # track gradients
            grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
            if grad_norm_dict:
                self.trainer.lightning_module._current_fx_name = "on_after_backward"
                self.trainer.lightning_module.log_grad_norm(grad_norm_dict)
        return loss

    def _update_running_loss(self, current_loss: Tensor) -> None:
        """Updates the running loss value with the current value

        Args:
            current_loss: the loss of the step that just ran
        """
        if self.trainer.lightning_module.automatic_optimization:
            # track total loss for logging (avoid mem leaks)
            self.accumulated_loss.append(current_loss)

        accumulated_loss = self.accumulated_loss.mean()

        if accumulated_loss is not None:
            # calculate running loss for display
            # (the step loss was divided by ``accumulate_grad_batches`` in ``_training_step``; undo that here)
            self.running_loss.append(accumulated_loss * self.trainer.accumulate_grad_batches)

        # reset for next set of accumulated grads
        self.accumulated_loss.reset()

    def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]:
        """
        Returns the currently active optimizers. When multiple optimizers are used with different frequencies,
        only one of the optimizers is active at a time.

        Args:
            batch_idx: the index of the current batch. Only needed (and then required) when optimizer
                frequencies are configured.

        Returns:
            A list of tuples (opt_idx, optimizer) of currently active optimizers.

        Raises:
            TypeError: if optimizer frequencies are configured but ``batch_idx`` is ``None``.
        """
        if not self.trainer.optimizer_frequencies:
            # call training_step once per optimizer
            return list(enumerate(self.trainer.optimizers))

        if batch_idx is None:
            # fail with a clear message instead of the opaque ``None % int`` error further down
            raise TypeError("`batch_idx` is required to select the active optimizer when optimizer frequencies are set.")

        freq_cumsum = self.optimizer_freq_cumsum
        optimizers_loop_length = freq_cumsum[-1]
        current_place_in_loop = batch_idx % optimizers_loop_length

        # find optimizer index by looking for the first {item > current_place} in the cumsum list
        opt_idx = int(np.argmax(freq_cumsum > current_place_in_loop))
        return [(opt_idx, self.trainer.optimizers[opt_idx])]
|