Update Gradient Clipping for TPU Accelerator (#6576)
parent 983a888f49, commit 87c03b1038
@@ -167,6 +167,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587))

+- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576))


 ## [1.2.3] - 2021-03-09

 ### Fixed
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union

 import torch
 from torch.optim import Optimizer
@@ -12,6 +12,9 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException

 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
+    from torch_xla._patched_functions import clip_grad_norm_
+
+    xla_clip_grad_norm_ = clip_grad_norm_

 if TYPE_CHECKING:
     from pytorch_lightning.core.lightning import LightningModule
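For context, the conditional block above follows the optional-dependency pattern: the XLA symbols are bound only when torch_xla is importable, and the patched clipping function is re-exported under the stable module-level name `xla_clip_grad_norm_`, which the new test further down patches. A minimal illustration of the same guard-and-alias pattern, using a hypothetical availability check (`importlib.util.find_spec` here, not Lightning's actual `_XLA_AVAILABLE` helper):

import importlib.util

# Hypothetical stand-in for Lightning's _XLA_AVAILABLE flag: a plain importability check.
_xla_available = importlib.util.find_spec("torch_xla") is not None

if _xla_available:
    import torch_xla.core.xla_model as xm
    from torch_xla._patched_functions import clip_grad_norm_

    # Re-export under one module-level name so callers (and tests) have a stable target.
    xla_clip_grad_norm_ = clip_grad_norm_
else:
    xm = None
    xla_clip_grad_norm_ = None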
@@ -55,3 +58,16 @@ class TPUAccelerator(Accelerator):
         if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
             return xm.all_gather(tensor).view(-1, *tensor.shape)
         return tensor
+
+    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):
+
+        model = self.lightning_module
+        parameters = model.parameters()
+
+        grad_clip_val = float(clip_val)
+        if grad_clip_val <= 0:
+            return
+
+        max_norm = grad_clip_val
+
+        xla_clip_grad_norm_(parameters, max_norm, norm_type)
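With this change, setting `gradient_clip_val` on a TPU run is handled by `TPUAccelerator.clip_gradients`, which defers to torch_xla's patched `clip_grad_norm_` and skips clipping entirely for values <= 0. A hedged usage sketch, assuming a TPU runtime is available; `TinyModule` is a throwaway placeholder and not part of this PR:

import torch
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    # Minimal stand-in model, for illustration only.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def train_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=4)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = pl.Trainer(
    tpu_cores=1,
    precision=16,
    gradient_clip_val=0.5,  # values <= 0 would skip clipping in TPUAccelerator.clip_gradients
    max_epochs=1,
)
trainer.fit(TinyModule())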
@@ -100,7 +100,6 @@ class PrecisionPlugin(Plugin):

     def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
         """Clips the gradients to a specific value"""
-        # TODO: separate TPU case from here
         if clip_val is None:
             return

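The rest of `PrecisionPlugin.clip_gradients` lies outside this hunk; for comparison with the XLA path above, conventional CPU/GPU clipping goes through `torch.nn.utils.clip_grad_norm_`. A minimal reference sketch, not the plugin's actual body:

import torch
from torch.nn.utils import clip_grad_norm_


def clip_gradients_reference(parameters, clip_val, norm_type=2.0):
    # Conventional (non-XLA) norm clipping, with the same "skip when unset or <= 0" guard.
    if clip_val is None or float(clip_val) <= 0:
        return
    clip_grad_norm_(parameters, float(clip_val), norm_type=norm_type)


# Usage: accumulate a gradient, then clip it in place.
model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()
clip_gradients_reference(model.parameters(), clip_val=1.0)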
@@ -355,3 +355,31 @@ def test_tpu_reduce():
         assert result.item() == 8

     xmp.spawn(test_reduce, nprocs=8, start_method='fork')
+
+
+@pytest.mark.parametrize("clip_val", [0, 10])
+@RunIf(tpu=True)
+@pl_multi_process_test
+@mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_")
+def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
+    """
+    Ensure that clip gradients is only called if the value is greater than 0.
+    """
+    tutils.reset_seed()
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=1,
+        tpu_cores=1,
+        precision=16,
+        limit_train_batches=4,
+        limit_val_batches=4,
+        gradient_clip_val=clip_val,
+    )
+    model = BoringModel()
+    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
+
+    if clip_val > 0:
+        mock_clip_grad_norm.assert_called()
+    else:
+        mock_clip_grad_norm.assert_not_called()
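The test patches the module-level alias `pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_` rather than torch_xla itself, so it can assert whether clipping was invoked without touching real XLA internals. A self-contained sketch of that patch-and-assert idea; `xla_clip_stub` and `clip_gradients` below are illustrative names, not code from the PR:

import sys
from unittest import mock


def xla_clip_stub(parameters, max_norm, norm_type=2.0):
    """Stand-in for the XLA clipping call; does nothing."""


def clip_gradients(clip_val):
    # Mirrors the guard in TPUAccelerator.clip_gradients: values <= 0 skip clipping.
    if float(clip_val) <= 0:
        return
    xla_clip_stub(None, float(clip_val))


for clip_val, should_clip in [(0, False), (10, True)]:
    with mock.patch.object(sys.modules[__name__], "xla_clip_stub") as mocked:
        clip_gradients(clip_val)
        assert mocked.called == should_clip  # same check the parametrized test performs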