# lightning/pytorch_lightning/plugins/precision/precision_plugin.py

# 150 lines
# 5.2 KiB
# Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.core.hooks import CheckpointHooks
from pytorch_lightning.plugins.base_plugin import Plugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.types import _PARAMETERS
class PrecisionPlugin(Plugin, CheckpointHooks):
    """
    Base class for all plugins handling the precision-specific parts of the training.

    The class attribute ``precision`` must be overwritten in child classes.
    The default value reflects fp32 training.
    """

    precision: Union[str, int] = 32

    def master_params(self, optimizer: Optimizer) -> _PARAMETERS:
        """
        The master params of the model. Returns the plain model params here.
        Maybe different in other precision plugins.

        Args:
            optimizer: the optimizer whose ``param_groups`` are walked

        Yields:
            each entry of every group's ``"params"`` collection, in group order
        """
        for group in optimizer.param_groups:
            yield from group["params"]

    def connect(
        self,
        model: Module,
        optimizers: List[Optimizer],
        lr_schedulers: List[Any],
    ) -> Tuple[Module, List[Optimizer], List[Any]]:
        """Connects this plugin to the accelerator and the training process

        Args:
            model: the model to connect
            optimizers: the optimizers to connect
            lr_schedulers: the learning-rate schedulers to connect

        Returns:
            The three arguments, unchanged, as a ``(model, optimizers, lr_schedulers)`` tuple.
        """
        return model, optimizers, lr_schedulers

    def pre_backward(
        self,
        model: "pl.LightningModule",
        closure_loss: Tensor,
    ) -> Tensor:
        """Run before precision plugin executes backward

        Calls the ``on_before_backward`` hook through ``model.trainer`` with the loss.

        Args:
            model: the model to be optimized
            closure_loss: the loss value obtained from the closure

        Returns:
            ``closure_loss``, unchanged.
        """
        model.trainer.call_hook("on_before_backward", closure_loss)
        return closure_loss

    def backward(
        self,
        model: "pl.LightningModule",
        closure_loss: Tensor,
        optimizer: Optional[Optimizer],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Performs the actual backpropagation

        Args:
            model: the model to be optimized
            closure_loss: the loss value obtained from the closure
            optimizer: current optimizer being used. ``None`` if using manual optimization
            *args: extra positional arguments forwarded to the backward call
            **kwargs: extra keyword arguments forwarded to the backward call
        """
        # `isinstance` is already False for `None`, so no separate None check is needed.
        if isinstance(model, pl.LightningModule):
            # let the LightningModule run its own ``backward``
            model.backward(closure_loss, optimizer, *args, **kwargs)
        else:
            # anything else: call ``backward`` directly on the loss tensor
            closure_loss.backward(*args, **kwargs)

    def post_backward(
        self,
        model: "pl.LightningModule",
        closure_loss: Tensor,
    ) -> Tensor:
        """Run after precision plugin executes backward

        Args:
            model: the model to be optimized
            closure_loss: the loss value obtained from the closure

        Returns:
            The detached loss.
        """
        # once backward has been applied, release graph
        closure_loss = closure_loss.detach()
        # the hook is called after the loss has been detached
        model.trainer.call_hook("on_after_backward")
        return closure_loss

    def pre_optimizer_step(
        self,
        model: "pl.LightningModule",
        optimizer: Optimizer,
        optimizer_idx: int,
        lambda_closure: Callable,
        **kwargs: Any,
    ) -> bool:
        """Hook to do something before each optimizer step.

        Calls the ``on_before_optimizer_step`` hook through ``model.trainer``.

        Args:
            model: the model to be optimized
            optimizer: the optimizer about to step
            optimizer_idx: index of ``optimizer``
            lambda_closure: not used by this base implementation
            **kwargs: not used by this base implementation

        Returns:
            Always ``True`` here.
            NOTE(review): presumably tells the caller to go ahead with the optimizer
            step -- confirm against the calling accelerator.
        """
        model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
        return True

    def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
        """Hook to do something after each optimizer step."""

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Optional[Union[int, float]],
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
        model: Optional[Module] = None,
    ) -> None:
        """Clips the gradients

        Does nothing when ``clip_val`` is ``None`` or not strictly positive.

        Args:
            optimizer: the optimizer whose parameters get their gradients clipped
            clip_val: clipping threshold; ``None`` or a value ``<= 0`` disables clipping
            gradient_clip_algorithm: ``VALUE`` clips by value, ``NORM`` clips by norm.
                Any other value results in no clipping.
            model: not used by this base implementation
        """
        if clip_val is None:
            return

        clip_val = float(clip_val)
        if clip_val <= 0:
            return

        if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
            self.clip_grad_by_value(optimizer, clip_val)
        elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
            # TODO: there should be a mechanism to set `norm_type`
            self.clip_grad_by_norm(optimizer, clip_val)

    def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        """Clip gradients by value

        Args:
            optimizer: the optimizer whose ``master_params`` are clipped
            clip_val: passed to ``torch.nn.utils.clip_grad_value_`` as ``clip_value``
        """
        parameters = self.master_params(optimizer)
        torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

    def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        """Clip gradients by norm

        No ``norm_type`` is passed, so ``torch.nn.utils.clip_grad_norm_`` uses its default.

        Args:
            optimizer: the optimizer whose ``master_params`` are clipped
            clip_val: passed to ``torch.nn.utils.clip_grad_norm_`` as the maximum norm
        """
        parameters = self.master_params(optimizer)
        torch.nn.utils.clip_grad_norm_(parameters, clip_val)