ref: inner train loop (intermediate step) 5/n (#3365)
parent c7ef5ee874
commit 38b9677638
@@ -3,6 +3,16 @@ from typing import Any
 from pytorch_lightning.utilities.apply_func import move_data_to_device
+from pytorch_lightning.utilities import AMPType, rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+import math
 
 
+try:
+    from apex import amp
+except ImportError:
+    amp = None
+
+EPSILON = 1e-6
+EPSILON_FP16 = 1e-5
 
 
 class Accelerator(object):

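Note (not part of the diff): the new imports keep apex strictly optional — the module still imports when apex is missing, and `amp` is simply left as `None`. A minimal sketch of how such a guard is typically consumed downstream; `require_apex` is a hypothetical helper used only for illustration:

from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    from apex import amp
except ImportError:
    amp = None  # apex is optional; fall back to native AMP or full precision


def require_apex(amp_backend: AMPType) -> None:
    # Fail with a clear configuration error instead of an AttributeError the
    # first time `amp` is actually touched.
    if amp_backend == AMPType.APEX and amp is None:
        raise MisconfigurationException(
            'amp_backend was set to "apex", but apex is not installed.'
        )
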
@@ -96,3 +106,46 @@ class Accelerator(object):
     def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
         model_ref = self.trainer.get_model()
         model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
+
+    def clip_gradients(self, optimizer):
+
+        if self.trainer.amp_backend == AMPType.NATIVE:
+            self.trainer.scaler.unscale_(optimizer)
+
+        # apply clip gradients
+        # TODO: separate TPU case from here
+        self._clip_gradients(optimizer)
+
+    def _clip_gradients(self, optimizer):
+        # this code is a modification of torch.nn.utils.clip_grad_norm_
+        # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
+        if self.trainer.gradient_clip_val <= 0:
+            return
+
+        model = self.trainer.get_model()
+        if self.trainer.amp_backend == AMPType.APEX:
+            parameters = amp.master_params(optimizer)
+        else:
+            parameters = model.parameters()
+
+        max_norm = float(self.trainer.gradient_clip_val)
+        norm_type = float(2.0)
+
+        if isinstance(parameters, torch.Tensor):
+            parameters = [parameters]
+        parameters = list(filter(lambda p: p.grad is not None, parameters))
+
+        if norm_type == math.inf:
+            total_norm = max(p.grad.data.abs().max() for p in parameters)
+        else:
+            device = parameters[0].device
+            out = torch.empty(len(parameters), device=device)
+            for i, p in enumerate(parameters):
+                torch.norm(p.grad.data.to(device), norm_type, out=out[i])
+            total_norm = torch.norm(out, norm_type)
+
+        eps = EPSILON_FP16 if self.trainer.precision == 16 else EPSILON
+        clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
+        clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
+        for p in parameters:
+            p.grad.data.mul_(clip_coef.to(p.grad.data.device))

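Note (not part of the diff): the math in `_clip_gradients` reduces to scaling every gradient in place by min(1, max_norm / (total_norm + eps)), where total_norm is the 2-norm over all parameter gradients. A self-contained sketch in plain torch (the per-parameter norm buffer is simplified to torch.stack here):

import torch

max_norm, eps = 1.0, 1e-6

p = torch.zeros(3, requires_grad=True)
p.grad = torch.tensor([3.0, 4.0, 0.0])    # gradient with 2-norm 5.0
parameters = [p]

# 2-norm of the vector of per-parameter gradient norms
total_norm = torch.norm(
    torch.stack([torch.norm(q.grad.detach(), 2.0) for q in parameters]), 2.0)

# scale factor, capped at 1 so small gradients are never scaled up
clip_coef = torch.clamp(max_norm / (total_norm + eps), max=1.0)
for q in parameters:
    q.grad.detach().mul_(clip_coef)

print(p.grad)   # tensor([0.6000, 0.8000, 0.0000]) -> rescaled to norm ~max_norm
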
@@ -19,7 +19,7 @@ import torch.multiprocessing as mp
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.accelerators.base_backend import Accelerator
 

@@ -245,3 +245,8 @@ class TPUBackend(Accelerator):
             on_tpu=True,
             using_lbfgs=is_lbfgs
         )
+
+    def clip_gradients(self, optimizer):
+        # apply clip gradients
+        # TODO: separate TPU case from here
+        self._clip_gradients(optimizer)

@@ -298,10 +298,6 @@ class TrainerTrainLoopMixin(ABC):
     def transfer_batch_to_gpu(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
-    @abstractmethod
-    def clip_gradients(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def detect_nan_tensors(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

@@ -870,12 +866,8 @@ class TrainerTrainLoopMixin(ABC):
                         grad_norm_dic = model.grad_norm(
                             self.track_grad_norm)
 
-                # ------------------
-                # CLIP GRADS
-                # ------------------
-                if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
-                    self.scaler.unscale_(optimizer)
-                self.clip_gradients(optimizer)
+                # training trick
+                self.accelerator_backend.clip_gradients(optimizer)
 
                 # ------------------
                 # .STEP + ZERO_GRAD

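Note (not part of the diff): with this change the train loop no longer special-cases native AMP or TPU itself; it makes a single call and the accelerator backend decides what clipping means on its hardware. A stripped-down illustration of that dispatch (the *Sketch classes below are illustrative stand-ins, not the real pytorch_lightning.accelerators classes):

class BaseBackendSketch:
    def clip_gradients(self, optimizer):
        self._unscale_if_native_amp(optimizer)   # GPU/CPU path: unscale first when native AMP is active
        self._clip_gradients(optimizer)

    def _unscale_if_native_amp(self, optimizer):
        print('unscale native-AMP gradients')

    def _clip_gradients(self, optimizer):
        print('clip gradients')


class TPUBackendSketch(BaseBackendSketch):
    # TPU override: there is no GradScaler to unscale, go straight to clipping
    def clip_gradients(self, optimizer):
        self._clip_gradients(optimizer)


backend = TPUBackendSketch()
backend.clip_gradients(optimizer=None)   # train-loop side: one call, the backend decides
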
@@ -42,12 +42,9 @@ class TrainerTrainingTricksMixin(ABC):
 
     # this is just a summary on variables used in this abstract class,
     #  the proper values/initialisation should be done in child class
-    gradient_clip_val: ...
-    precision: int
     default_root_dir: str
     progress_bar_callback: ...
     on_gpu: bool
-    amp_backend: AMPType
 
     @abstractmethod
     def get_model(self) -> LightningModule:

@@ -65,37 +62,6 @@ class TrainerTrainingTricksMixin(ABC):
     def fit(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
-    def clip_gradients(self, optimizer):
-
-        # this code is a modification of torch.nn.utils.clip_grad_norm_
-        # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
-        if self.gradient_clip_val <= 0:
-            return
-        model = self.get_model()
-        if self.amp_backend == AMPType.APEX:
-            parameters = amp.master_params(optimizer)
-        else:
-            parameters = model.parameters()
-        max_norm = float(self.gradient_clip_val)
-        norm_type = float(2.0)
-        if isinstance(parameters, torch.Tensor):
-            parameters = [parameters]
-        parameters = list(filter(lambda p: p.grad is not None, parameters))
-        if norm_type == math.inf:
-            total_norm = max(p.grad.data.abs().max() for p in parameters)
-        else:
-            device = parameters[0].device
-            out = torch.empty(len(parameters), device=device)
-            for i, p in enumerate(parameters):
-                torch.norm(p.grad.data.to(device), norm_type, out=out[i])
-            total_norm = torch.norm(out, norm_type)
-
-        eps = EPSILON_FP16 if self.precision == 16 else EPSILON
-        clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
-        clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
-        for p in parameters:
-            p.grad.data.mul_(clip_coef.to(p.grad.data.device))
-
     def print_nan_gradients(self) -> None:
         model = self.get_model()
         for param in model.parameters():