From 38b96776380977163a6df664c2226eee8c349b93 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 5 Sep 2020 18:27:28 -0400
Subject: [PATCH] ref: inner train loop (intermediate step) 5/n (#3365)

---
 .../accelerators/base_backend.py              | 53 +++++++++++++++++++
 pytorch_lightning/accelerators/tpu_backend.py |  7 ++-
 pytorch_lightning/trainer/training_loop.py    | 12 +----
 pytorch_lightning/trainer/training_tricks.py  | 34 ------------
 4 files changed, 61 insertions(+), 45 deletions(-)

diff --git a/pytorch_lightning/accelerators/base_backend.py b/pytorch_lightning/accelerators/base_backend.py
index a01dc3a1ad..92580a600f 100644
--- a/pytorch_lightning/accelerators/base_backend.py
+++ b/pytorch_lightning/accelerators/base_backend.py
@@ -3,6 +3,16 @@ from typing import Any
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities import AMPType, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+import math
+
+
+try:
+    from apex import amp
+except ImportError:
+    amp = None
+
+EPSILON = 1e-6
+EPSILON_FP16 = 1e-5
 
 
 class Accelerator(object):
@@ -96,3 +106,46 @@ class Accelerator(object):
     def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
         model_ref = self.trainer.get_model()
         model_ref.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
+
+    def clip_gradients(self, optimizer):
+
+        if self.trainer.amp_backend == AMPType.NATIVE:
+            self.trainer.scaler.unscale_(optimizer)
+
+        # apply clip gradients
+        # TODO: separate TPU case from here
+        self._clip_gradients(optimizer)
+
+    def _clip_gradients(self, optimizer):
+        # this code is a modification of torch.nn.utils.clip_grad_norm_
+        # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
+        if self.trainer.gradient_clip_val <= 0:
+            return
+
+        model = self.trainer.get_model()
+        if self.trainer.amp_backend == AMPType.APEX:
+            parameters = amp.master_params(optimizer)
+        else:
+            parameters = model.parameters()
+
+        max_norm = float(self.trainer.gradient_clip_val)
+        norm_type = float(2.0)
+
+        if isinstance(parameters, torch.Tensor):
+            parameters = [parameters]
+        parameters = list(filter(lambda p: p.grad is not None, parameters))
+
+        if norm_type == math.inf:
+            total_norm = max(p.grad.data.abs().max() for p in parameters)
+        else:
+            device = parameters[0].device
+            out = torch.empty(len(parameters), device=device)
+            for i, p in enumerate(parameters):
+                torch.norm(p.grad.data.to(device), norm_type, out=out[i])
+            total_norm = torch.norm(out, norm_type)
+
+        eps = EPSILON_FP16 if self.trainer.precision == 16 else EPSILON
+        clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
+        clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
+        for p in parameters:
+            p.grad.data.mul_(clip_coef.to(p.grad.data.device))
diff --git a/pytorch_lightning/accelerators/tpu_backend.py b/pytorch_lightning/accelerators/tpu_backend.py
index adb9c0711f..786babb957 100644
--- a/pytorch_lightning/accelerators/tpu_backend.py
+++ b/pytorch_lightning/accelerators/tpu_backend.py
@@ -19,7 +19,7 @@ import torch.multiprocessing as mp
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.accelerators.base_backend import Accelerator
@@ -245,3 +245,8 @@ class TPUBackend(Accelerator):
             on_tpu=True,
             using_lbfgs=is_lbfgs
         )
+
+    def clip_gradients(self, optimizer):
+        # apply clip gradients
+        # TODO: separate TPU case from here
+        self._clip_gradients(optimizer)
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 564f9512d0..fee29c4f9e 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -298,10 +298,6 @@ class TrainerTrainLoopMixin(ABC):
     def transfer_batch_to_gpu(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
-    @abstractmethod
-    def clip_gradients(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def detect_nan_tensors(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
@@ -870,12 +866,8 @@ class TrainerTrainLoopMixin(ABC):
                 grad_norm_dic = model.grad_norm(
                     self.track_grad_norm)
 
-            # ------------------
-            # CLIP GRADS
-            # ------------------
-            if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
-                self.scaler.unscale_(optimizer)
-            self.clip_gradients(optimizer)
+            # training trick
+            self.accelerator_backend.clip_gradients(optimizer)
 
             # ------------------
             # .STEP + ZERO_GRAD
diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 705abc6343..4f9cb7c920 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -42,12 +42,9 @@ class TrainerTrainingTricksMixin(ABC):
 
     # this is just a summary on variables used in this abstract class,
     # the proper values/initialisation should be done in child class
-    gradient_clip_val: ...
-    precision: int
     default_root_dir: str
     progress_bar_callback: ...
     on_gpu: bool
-    amp_backend: AMPType
 
     @abstractmethod
     def get_model(self) -> LightningModule:
@@ -65,37 +62,6 @@ class TrainerTrainingTricksMixin(ABC):
     def fit(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
-    def clip_gradients(self, optimizer):
-
-        # this code is a modification of torch.nn.utils.clip_grad_norm_
-        # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
-        if self.gradient_clip_val <= 0:
-            return
-        model = self.get_model()
-        if self.amp_backend == AMPType.APEX:
-            parameters = amp.master_params(optimizer)
-        else:
-            parameters = model.parameters()
-        max_norm = float(self.gradient_clip_val)
-        norm_type = float(2.0)
-        if isinstance(parameters, torch.Tensor):
-            parameters = [parameters]
-        parameters = list(filter(lambda p: p.grad is not None, parameters))
-        if norm_type == math.inf:
-            total_norm = max(p.grad.data.abs().max() for p in parameters)
-        else:
-            device = parameters[0].device
-            out = torch.empty(len(parameters), device=device)
-            for i, p in enumerate(parameters):
-                torch.norm(p.grad.data.to(device), norm_type, out=out[i])
-            total_norm = torch.norm(out, norm_type)
-
-        eps = EPSILON_FP16 if self.precision == 16 else EPSILON
-        clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
-        clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
-        for p in parameters:
-            p.grad.data.mul_(clip_coef.to(p.grad.data.device))
-
     def print_nan_gradients(self) -> None:
         model = self.get_model()
         for param in model.parameters():
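
Note for reviewers: the clipping math that this patch moves from TrainerTrainingTricksMixin into Accelerator._clip_gradients can be read in isolation as norm-based gradient clipping. Below is a rough standalone sketch of the same computation outside the Trainer/Accelerator wiring, under the following assumptions: the function name clip_grad_norm, its signature, and the toy check at the bottom are illustrative only and are not part of this commit; the sketch also resolves the target device before branching so the infinity-norm path is covered as well, whereas the patched method only sets it in the finite-norm branch.

import math

import torch

# Constants mirroring the patch; the epsilon keeps the division finite when the norm is ~0.
EPSILON = 1e-6
EPSILON_FP16 = 1e-5


def clip_grad_norm(parameters, max_norm, norm_type=2.0, precision=32):
    # A value <= 0 disables clipping, matching the gradient_clip_val check in the patch.
    if max_norm <= 0:
        return

    # Only parameters that actually received gradients participate.
    parameters = [p for p in parameters if p.grad is not None]
    if not parameters:
        return

    # Resolve the device once so both norm branches can build tensors on it.
    device = parameters[0].device

    if norm_type == math.inf:
        # Infinity norm: the largest absolute gradient entry across all tensors.
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        # Per-parameter norms reduced into a single total norm.
        norms = torch.stack([torch.norm(p.grad.detach().to(device), norm_type) for p in parameters])
        total_norm = torch.norm(norms, norm_type)

    # Scale every gradient by max_norm / total_norm, but never scale up.
    eps = EPSILON_FP16 if precision == 16 else EPSILON
    clip_coef = torch.tensor(float(max_norm), device=device) / (total_norm + eps)
    clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
    for p in parameters:
        p.grad.detach().mul_(clip_coef.to(p.grad.device))


# Toy check: a gradient with L2 norm 200 is rescaled to roughly norm 1.
w = torch.nn.Parameter(torch.ones(4))
w.grad = torch.full((4,), 100.0)
clip_grad_norm([w], max_norm=1.0)
print(w.grad.norm())  # ~tensor(1.0000), up to the epsilon

In the patch itself, the new Accelerator.clip_gradients wrapper additionally unscales gradients first when the native AMP backend is active, and the TPU backend overrides it to skip that unscale step; the core scaling logic above is shared.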