Sharded Accelerator 1/n: Expose clip gradients to plugins via abstract class (#4639)
* Added abstract precision plugin to expose clip_gradients function, use within accelerator to clip gradients
* Exclude model from override, keep optimizer (needed for sharded clip gradients), add override for O2 support apex
* Fix doc
* Applied codereview changes
* Refactored clip function to encapsulate tpu changes with tpu accelerator. Default to standard clip function for vanilla torch
* Pass correct grad clip val
* Moved var to property
* Apply code review suggestions
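In short: the trainer-facing clip_gradients entry point stays on the Accelerator, which resolves the clip value and skips clipping when it is zero or negative, while the actual clipping is delegated to the active precision backend (native AMP, Apex, or later a sharded plugin). The snippet below is a self-contained toy re-creation of that delegation pattern, not Lightning's actual classes; every name in it (ToyPrecisionPlugin, VanillaPlugin, clip_with_plugin) is illustrative only.

# Toy sketch of the pattern this commit introduces: an abstract precision plugin
# owns gradient clipping, and the accelerator-level helper only resolves the clip
# value and delegates. Purely illustrative; not Lightning's API.
import abc
from typing import Union

import torch
from torch import nn
from torch.optim import Optimizer


class ToyPrecisionPlugin(abc.ABC):
    @abc.abstractmethod
    def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
        ...


class VanillaPlugin(ToyPrecisionPlugin):
    """Plain fp32 clipping: defer to torch.nn.utils.clip_grad_norm_."""

    def __init__(self, model: nn.Module):
        self.model = model

    def clip_gradients(self, grad_clip_val, optimizer, norm_type):
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)


def clip_with_plugin(plugin: ToyPrecisionPlugin, optimizer: Optimizer, clip_val: float, norm_type: float = 2.0):
    # mirrors the accelerator-side logic: resolve the value, skip if disabled, then delegate
    grad_clip_val = float(clip_val)
    if grad_clip_val <= 0:
        return
    plugin.clip_gradients(grad_clip_val, optimizer, norm_type)


if __name__ == "__main__":
    model = nn.Linear(4, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    clip_with_plugin(VanillaPlugin(model), opt, clip_val=0.5)
    opt.step()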
parent 4a01fd048c
commit bacabaebaf
@@ -12,33 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import math
 from enum import Enum
 from typing import Any, Optional, Union
 
 import torch
+from torch.optim import Optimizer
 
-from pytorch_lightning.utilities import AMPType, rank_zero_warn
+from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict
 import torch.distributed as torch_distrib
 from pytorch_lightning import _logger as log
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
 if torch.distributed.is_available():
     from torch.distributed import ReduceOp
 else:
     class ReduceOp:
         SUM = None
 
-EPSILON = 1e-6
-EPSILON_FP16 = 1e-5
-
 
 class Accelerator(object):
 

@@ -139,48 +131,22 @@ class Accelerator(object):
         model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
 
     def clip_gradients(self, optimizer, clip_val=None):
-        # TODO: separate TPU case from here
-        self._clip_gradients(optimizer, clip_val)
-
-    def _clip_gradients(self, optimizer, clip_val=None):
         # use the trainer's clip val if none passed
         grad_clip_val = self.trainer.gradient_clip_val
         if clip_val is not None:
             grad_clip_val = clip_val
         grad_clip_val = float(grad_clip_val)
 
-        # this code is a modification of torch.nn.utils.clip_grad_norm_
-        # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
         if grad_clip_val <= 0:
             return
+        self._clip_gradients(optimizer, grad_clip_val)
 
-        model = self.trainer.get_model()
-        if self.trainer.amp_backend == AMPType.APEX:
-            parameters = amp.master_params(optimizer)
+    def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0):
+        if self.trainer.amp_backend:
+            self.trainer.precision_connector.backend.clip_gradients(grad_clip_val, optimizer, norm_type)
         else:
-            parameters = model.parameters()
-
-        max_norm = grad_clip_val
-        norm_type = float(2.0)
-
-        if isinstance(parameters, torch.Tensor):
-            parameters = [parameters]
-        parameters = list(filter(lambda p: p.grad is not None, parameters))
-
-        if norm_type == math.inf:
-            total_norm = max(p.grad.data.abs().max() for p in parameters)
-        else:
-            device = parameters[0].device
-            out = torch.empty(len(parameters), device=device)
-            for i, p in enumerate(parameters):
-                torch.norm(p.grad.data.to(device), norm_type, out=out[i])
-            total_norm = torch.norm(out, norm_type)
-
-        eps = EPSILON_FP16 if self.trainer.precision == 16 else EPSILON
-        clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
-        clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
-        for p in parameters:
-            p.grad.data.mul_(clip_coef.to(p.grad.data.device))
+            model = self.trainer.get_model()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)
 
     def on_train_epoch_end(self, outputs):
         pass

@@ -201,7 +167,7 @@ class Accelerator(object):
         self.trainer.optimizer_frequencies = optimizer_frequencies
 
     def init_ddp_connection(
         self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True
     ) -> None:
         os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import io
+import math
 import os
 import re
 from typing import Optional, Union, Any
 
 import torch
 import torch.multiprocessing as mp
+from torch.optim import Optimizer
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp

@@ -261,10 +263,27 @@ class TPUAccelerator(Accelerator):
                 using_lbfgs=is_lbfgs
             )
 
-    def clip_gradients(self, optimizer, clip_val=None):
-        # apply clip gradients
-        # TODO: separate TPU case from here
-        self._clip_gradients(optimizer, clip_val)
+    def _clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0):
+        # this code is a modification of torch.nn.utils.clip_grad_norm_
+        # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
+        model = self.trainer.get_model()
+        parameters = model.parameters()
+        max_norm = grad_clip_val
+
+        if isinstance(parameters, torch.Tensor):
+            parameters = [parameters]
+        parameters = list(filter(lambda p: p.grad is not None, parameters))
+
+        device = parameters[0].device
+        out = torch.empty(len(parameters), device=device)
+        for i, p in enumerate(parameters):
+            torch.norm(p.grad.data.to(device), norm_type, out=out[i])
+        total_norm = torch.norm(out, norm_type)
+
+        clip_coef = torch.tensor(max_norm, device=device) / (total_norm + self.norm_clipping_epsilon)
+        clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
+        for p in parameters:
+            p.grad.data.mul_(clip_coef.to(p.grad.data.device))
 
     def barrier(self, name: Optional[str] = None):
         torch_xla.core.xla_model.rendezvous(f"pl.Trainer.{name}")

@@ -343,3 +362,7 @@ class TPUAccelerator(Accelerator):
                     group: Optional[Any] = None,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
         return tensor
+
+    @property
+    def norm_clipping_epsilon(self):
+        return 1e-6
@@ -11,11 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Tuple
+import math
+from typing import List, Tuple, Union
 
 import torch
 from torch.optim.optimizer import Optimizer
 
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities.distributed import rank_zero_warn
 from pytorch_lightning.utilities import AMPType

@@ -25,7 +28,7 @@ except ImportError:
     amp = None
 
 
-class ApexPlugin:
+class ApexPlugin(PrecisionPlugin):
 
     def __init__(self, trainer=None):
         self.trainer = trainer

@@ -98,3 +101,35 @@ class ApexPlugin:
         """
         model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
         return model, optimizers
+
+    def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
+        """
+        This code is a modification of :meth:`torch.nn.utils.clip_grad_norm_` using a higher epsilon for fp16 weights.
+        This is important when setting amp_level to O2, and the master weights are in fp16.
+        Args:
+            grad_clip_val: Maximum norm of gradients.
+            optimizer: Optimizer with gradients that will be clipped.
+            norm_type: (float or int): type of the used p-norm. Can be ``'inf'`` for
+                infinity norm.
+        """
+        model = self.trainer.get_model()
+        parameters = model.parameters()
+        max_norm = float(grad_clip_val)
+
+        if isinstance(parameters, torch.Tensor):
+            parameters = [parameters]
+        parameters = [p for p in parameters if p.grad is not None]
+
+        if len(parameters) == 0:
+            return torch.tensor(0.)
+        device = parameters[0].grad.device
+        total_norm = torch.norm(
+            torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+        clip_coef = max_norm / (total_norm + self.norm_clipping_epsilon)
+        if clip_coef < 1:
+            for p in parameters:
+                p.grad.detach().mul_(clip_coef.to(p.grad.device))
+
+    @property
+    def norm_clipping_epsilon(self):
+        return 1e-5
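For context on when the ApexPlugin.clip_gradients override above is exercised: the refactored Accelerator._clip_gradients delegates to the precision backend whenever an AMP backend is configured on the trainer, and the docstring notes that O2 keeps master weights in fp16, which is why this plugin clips with a larger epsilon. A rough configuration sketch follows; the Trainer argument names are assumed from the Lightning release this commit targets, and running it needs a CUDA GPU with NVIDIA Apex installed, so treat it as illustrative rather than verified against this exact revision.

# Illustrative wiring only: routes per-step gradient clipping through ApexPlugin.
# `MyLightningModule` is a hypothetical stand-in for any LightningModule.
import pytorch_lightning as pl

trainer = pl.Trainer(
    gpus=1,
    precision=16,
    amp_backend="apex",     # selects the Apex precision backend
    amp_level="O2",         # O2 keeps master weights in fp16
    gradient_clip_val=0.5,  # non-zero value enables clipping on every optimizer step
)
# trainer.fit(MyLightningModule())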
@@ -11,11 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Union
+
 import torch
+from torch.optim import Optimizer
+
+from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin
 
 
-class NativeAMPPlugin:
+class NativeAMPPlugin(PrecisionPlugin):
 
     def __init__(self, trainer=None):
         """

@@ -51,3 +55,7 @@ class NativeAMPPlugin:
         with torch.cuda.amp.autocast():
             output = fx(*args)
         return output
+
+    def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
+        model = self.trainer.get_model()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)
@@ -0,0 +1,38 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+from typing import Union
+
+from torch.optim import Optimizer
+
+
+class PrecisionPlugin(abc.ABC):
+    """
+    Abstract class to extend for precision support (32/16 etc).
+
+    This is extended to cover any specific logic required for precision support such as AMP/APEX or sharded
+    training.
+    """
+
+    def connect(self, model, optimizers):
+        raise NotImplementedError
+
+    def training_step(self, fx, args):
+        raise NotImplementedError
+
+    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
+        raise NotImplementedError
+
+    def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
+        raise NotImplementedError
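Since this abstract class is the extension point the commit title refers to, a hedged sketch of a custom subclass follows. The hook bodies are plain fp32 defaults modelled on the NativeAMPPlugin changes above; they are illustrative, not part of this commit, and the trainer wiring is assumed to match the built-in plugins.

# Hypothetical subclass of the PrecisionPlugin ABC added by this commit.
# A real sharded or mixed-precision plugin would replace these bodies with
# backend-specific logic; only the hook names and signatures come from the diff above.
from typing import Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin


class MyPrecisionPlugin(PrecisionPlugin):

    def __init__(self, trainer=None):
        self.trainer = trainer

    def connect(self, model, optimizers):
        # nothing to wrap for plain fp32; AMP/Apex plugins patch model/optimizers here
        return model, optimizers

    def training_step(self, fx, args):
        # run the LightningModule's training_step unmodified
        # (NativeAMPPlugin wraps this call in torch.cuda.amp.autocast)
        return fx(*args)

    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
        # plain backward; Apex would scale the loss before calling backward
        closure_loss.backward(*args, **kwargs)
        return closure_loss

    def clip_gradients(self, grad_clip_val: Union[int, float], optimizer: Optimizer, norm_type: float):
        # the hook this commit exposes: clip however the precision backend requires
        model = self.trainer.get_model()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val, norm_type=norm_type)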