Update Gradient Clipping for TPU Accelerator (#6576)

This commit is contained in:
Kaushik B 2021-03-20 01:02:57 +05:30 committed by GitHub
parent 983a888f49
commit 87c03b1038
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 48 additions and 2 deletions

View File

@ -167,6 +167,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587))
- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576))
## [1.2.3] - 2021-03-09
### Fixed

View File

@ -1,4 +1,4 @@
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
from torch.optim import Optimizer
@ -12,6 +12,9 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm
from torch_xla._patched_functions import clip_grad_norm_
xla_clip_grad_norm_ = clip_grad_norm_
if TYPE_CHECKING:
from pytorch_lightning.core.lightning import LightningModule
@ -55,3 +58,16 @@ class TPUAccelerator(Accelerator):
if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
return xm.all_gather(tensor).view(-1, *tensor.shape)
return tensor
def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):
model = self.lightning_module
parameters = model.parameters()
grad_clip_val = float(clip_val)
if grad_clip_val <= 0:
return
max_norm = grad_clip_val
xla_clip_grad_norm_(parameters, max_norm, norm_type)

View File

@ -100,7 +100,6 @@ class PrecisionPlugin(Plugin):
def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
"""Clips the gradients to a specific value"""
# TODO: separate TPU case from here
if clip_val is None:
return

View File

@ -355,3 +355,31 @@ def test_tpu_reduce():
assert result.item() == 8
xmp.spawn(test_reduce, nprocs=8, start_method='fork')
@pytest.mark.parametrize("clip_val", [0, 10])
@RunIf(tpu=True)
@pl_multi_process_test
@mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_")
def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
"""
Ensure that clip gradients is only called if the value is greater than 0.
"""
tutils.reset_seed()
trainer_options = dict(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=1,
tpu_cores=1,
precision=16,
limit_train_batches=4,
limit_val_batches=4,
gradient_clip_val=clip_val,
)
model = BoringModel()
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
if clip_val > 0:
mock_clip_grad_norm.assert_called()
else:
mock_clip_grad_norm.assert_not_called()