2020-03-12 16:41:37 +00:00
|
|
|
import math
|
2020-03-19 13:24:45 +00:00
|
|
|
import sys
|
2019-12-04 15:57:32 +00:00
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
2019-10-22 01:16:51 +00:00
|
|
|
import torch
|
2020-03-19 13:24:45 +00:00
|
|
|
from torch import Tensor
|
2020-01-20 19:50:31 +00:00
|
|
|
|
2020-03-17 22:44:00 +00:00
|
|
|
from pytorch_lightning import _logger as log
|
2019-10-22 01:16:51 +00:00
|
|
|
from pytorch_lightning.callbacks import GradientAccumulationScheduler
|
|
|
|
|
2020-02-27 20:46:47 +00:00
|
|
|
# Small constants added to the total gradient norm before it is used as a
# divisor in `TrainerTrainingTricksMixin.clip_gradients`, keeping the division
# finite when the norm is zero. The larger value is chosen when precision == 16.
EPSILON = 1.0e-6
EPSILON_FP16 = 1.0e-5
|
|
|
|
|
2019-10-22 01:16:51 +00:00
|
|
|
|
2019-12-04 15:57:32 +00:00
|
|
|
class TrainerTrainingTricksMixin(ABC):
    """Mixin with gradient clipping, NaN/inf detection and accumulation scheduling.

    The concrete class must provide ``get_model()``, ``gradient_clip_val`` and
    ``precision``; ``configure_accumulated_gradients`` sets
    ``self.accumulation_scheduler``.
    """

    # this is just a summary on variables used in this abstract class,
    # the proper values/initialisation should be done in child class
    gradient_clip_val: ...
    precision: ...

    @abstractmethod
    def get_model(self):
        """Return the model whose parameters these helpers inspect.

        Warning: this is just an empty shell for code implemented in another class.
        """

    def clip_gradients(self):
        """Rescale all parameter gradients in place so their total 2-norm is at most
        ``self.gradient_clip_val``.

        This is a modification of ``torch.nn.utils.clip_grad_norm_`` with TPU
        support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md:
        the clip coefficient is capped with ``torch.where`` rather than a Python
        ``if`` on a tensor value.

        Does nothing when ``gradient_clip_val`` is not positive or when no
        parameter currently has a gradient.
        """
        # written as `not (... > 0)` so a NaN clip value is skipped, exactly like
        # the original `if self.gradient_clip_val > 0:` guard
        if not (self.gradient_clip_val > 0):
            return

        model = self.get_model()
        parameters = model.parameters()
        max_norm = float(self.gradient_clip_val)
        norm_type = float(2.0)

        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        parameters = [p for p in parameters if p.grad is not None]
        if not parameters:
            # nothing to clip; indexing parameters[0] below would raise IndexError
            return

        # needed by both norm branches and by the clip coefficient, so it is
        # resolved before branching (the inf branch previously left it unset)
        device = parameters[0].device

        if norm_type == math.inf:
            total_norm = max(p.grad.data.abs().max() for p in parameters)
        else:
            total_norm = torch.zeros([], device=device)
            for p in parameters:
                total_norm.add_(p.grad.data.norm(norm_type) ** norm_type)
            total_norm = total_norm ** (1. / norm_type)

        # eps keeps the division finite when total_norm is zero
        eps = EPSILON_FP16 if self.precision == 16 else EPSILON
        clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
        # cap the coefficient at 1 with a tensor op (no host-side branch on a
        # tensor value); it is loop-invariant, so compute it once
        clip_coef = torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device))
        for p in parameters:
            p.grad.data.mul_(clip_coef)

    def print_nan_gradients(self) -> None:
        """Log each model parameter whose gradient contains a NaN, with its gradient."""
        model = self.get_model()
        for param in model.parameters():
            if (param.grad is not None) and torch.isnan(param.grad.float()).any():
                # logging's info(msg, *args) %-formats msg with args; passing the
                # tensors as args (not as msg) lets both actually be logged
                log.info('%s %s', param, param.grad)

    def detect_nan_tensors(self, loss: Tensor) -> None:
        """Raise if the loss or any model weight contains NaN or inf values.

        Args:
            loss: value returned from ``training_step``.

        Raises:
            ValueError: if ``loss`` is not finite, or if any named parameter of the
                model is not finite (NaN gradients are logged first in that case).
        """
        model = self.get_model()

        # check if loss is nan
        if not torch.isfinite(loss).all():
            raise ValueError(
                'The loss returned in `training_step` is nan or inf.'
            )
        # check if a network weight is nan
        for name, param in model.named_parameters():
            if not torch.isfinite(param).all():
                self.print_nan_gradients()
                raise ValueError(
                    f'Detected nan and/or inf values in `{name}`.'
                    ' Check your forward pass for numerically unstable operations.'
                )

    def configure_accumulated_gradients(self, accumulate_grad_batches):
        """Create ``self.accumulation_scheduler`` from ``accumulate_grad_batches``.

        Args:
            accumulate_grad_batches: either a dict, passed unchanged to
                ``GradientAccumulationScheduler`` as its schedule, or an int ``n``,
                turned into the constant schedule ``{1: n}``.

        Raises:
            TypeError: for any other type.
        """
        if isinstance(accumulate_grad_batches, dict):
            self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
        elif isinstance(accumulate_grad_batches, int):
            schedule = {1: accumulate_grad_batches}
            self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
        else:
            raise TypeError("Gradient accumulation supports only int and dict types")
|