Fix gradient clipping (#1438)

* Fix gradient clipping

* Relax accuracy constraint
Alex Sergeev 2020-04-09 18:08:28 -07:00 committed by GitHub
parent b2707c9b2e
commit 8dd9b80d7a
2 changed files with 28 additions and 1 deletion


@@ -41,7 +41,7 @@ class TrainerTrainingTricksMixin(ABC):
    total_norm = torch.zeros([], device=device if parameters else None)
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type) ** norm_type
        total_norm.add_(param_norm)
    total_norm = (total_norm ** (1. / norm_type))
    eps = EPSILON_FP16 if self.precision == 16 else EPSILON
    clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
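
For context, here is a minimal standalone sketch (not part of the diff) of the norm-based clipping scheme this hunk touches, in the spirit of torch.nn.utils.clip_grad_norm_: accumulate each parameter's gradient norm raised to norm_type, take the 1/norm_type root to get the global norm, then rescale gradients by max_norm / total_norm only when that ratio is below 1. The toy_params name and the eps value below are illustrative placeholders, not taken from the codebase.

    import torch

    norm_type = 2.0
    max_norm = 1.0
    eps = 1e-6  # illustrative stand-in; the real constant depends on precision

    # two toy parameters with gradients already populated
    toy_params = [torch.nn.Parameter(torch.randn(3)) for _ in range(2)]
    for p in toy_params:
        p.grad = torch.randn_like(p)

    device = toy_params[0].device
    total_norm = torch.zeros([], device=device)
    for p in toy_params:
        total_norm.add_(p.grad.data.norm(norm_type) ** norm_type)
    total_norm = total_norm ** (1. / norm_type)

    # scale every gradient by max_norm / total_norm, but never scale up
    clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
    for p in toy_params:
        p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))

    # the global 2-norm of all gradients is now at most max_norm
    global_norm = torch.norm(torch.stack([p.grad.norm(2) for p in toy_params]), 2)
    assert global_norm <= max_norm + 1e-4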


@@ -658,3 +658,30 @@ def test_trainer_interrupted_flag(tmpdir):
    assert not trainer.interrupted
    trainer.fit(model)
    assert trainer.interrupted


def test_gradient_clipping(tmpdir):
    """
    Test gradient clipping
    """
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    # test that gradient is clipped correctly
    def _optimizer_step(*args, **kwargs):
        parameters = model.parameters()
        grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
        assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm)

    trainer = Trainer(max_steps=1,
                      max_epochs=1,
                      gradient_clip_val=1.0,
                      default_save_path=tmpdir)

    # for the test
    model.optimizer_step = _optimizer_step
    model.prev_called_batch_idx = 0

    trainer.fit(model)
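
As a usage note, the clipping the test above exercises is enabled purely through the Trainer flag; a minimal sketch, assuming a LightningModule named MyModel exists (MyModel is hypothetical, only gradient_clip_val comes from the diff):

    from pytorch_lightning import Trainer

    # gradient_clip_val=1.0 asks the trainer to clip the global gradient norm
    # to 1.0 before each optimizer step; MyModel is a hypothetical LightningModule
    trainer = Trainer(max_epochs=1, gradient_clip_val=1.0)
    # trainer.fit(MyModel())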