Fix gradient clipping (#1438)
* Fix gradient clipping
* Relax accuracy constraint
parent b2707c9b2e
commit 8dd9b80d7a
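For context, gradient clipping in Lightning is switched on through the Trainer's `gradient_clip_val` argument, which is exactly what the new test in this commit exercises. A minimal usage sketch, not part of the diff itself:

import pytorch_lightning as pl

# Ask the trainer to scale all gradients so their total 2-norm does not
# exceed 1.0 before each optimizer step; trainer.fit(model) with any
# LightningModule would then go through the clipping path fixed here.
trainer = pl.Trainer(max_epochs=1, gradient_clip_val=1.0)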
@@ -41,7 +41,7 @@ class TrainerTrainingTricksMixin(ABC):
    total_norm = torch.zeros([], device=device if parameters else None)
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type) ** norm_type
        total_norm.add_(param_norm)
    total_norm = (total_norm ** (1. / norm_type))
    eps = EPSILON_FP16 if self.precision == 16 else EPSILON
    clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
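The hunk above accumulates the total gradient norm on the gradients' device: each parameter's gradient norm is raised to `norm_type` before being summed, and the p-th root of the sum gives the norm of all gradients taken together, which is the same quantity the new test recomputes by stacking per-parameter norms. A rough standalone check of that equivalence (tensor names here are illustrative, not taken from the PR):

import torch

norm_type = 2.0
grads = [torch.randn(3, 3), torch.randn(5)]  # stand-in gradient tensors

# accumulate sum of ||g||^p, then take the p-th root, as in the hunk above
total_norm = torch.zeros([], device=grads[0].device)
for g in grads:
    total_norm.add_(g.norm(norm_type) ** norm_type)
total_norm = total_norm ** (1. / norm_type)

# the same norm computed by stacking per-tensor norms, as the new test does
reference = torch.norm(torch.stack([torch.norm(g, 2) for g in grads]), 2)
assert torch.allclose(total_norm, reference)

The clip coefficient is then computed as `max_norm / (total_norm + eps)`, with a larger epsilon when training in 16-bit precision.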
@@ -658,3 +658,30 @@ def test_trainer_interrupted_flag(tmpdir):
    assert not trainer.interrupted
    trainer.fit(model)
    assert trainer.interrupted


def test_gradient_clipping(tmpdir):
    """
    Test gradient clipping
    """
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    # test that gradient is clipped correctly
    def _optimizer_step(*args, **kwargs):
        parameters = model.parameters()
        grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
        assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm)

    trainer = Trainer(max_steps=1,
                      max_epochs=1,
                      gradient_clip_val=1.0,
                      default_save_path=tmpdir)

    # for the test
    model.optimizer_step = _optimizer_step
    model.prev_called_batch_idx = 0

    trainer.fit(model)
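The test replaces the model's `optimizer_step` hook, so the assertion runs once gradients have been clipped, in place of the actual optimizer update, while the clipped gradients are still attached to the parameters. The check allows a small tolerance rather than exact equality, in line with the "Relax accuracy constraint" note above: because of the epsilon in the denominator (and the coarser resolution in 16-bit precision) the clipped norm lands slightly below `max_norm`. A back-of-the-envelope illustration, where the epsilon is a stand-in rather than the actual EPSILON constant:

# clipped norm = total_norm * clip_coef = max_norm * total_norm / (total_norm + eps)
max_norm = 1.0
eps = 1e-6          # stand-in value; the real EPSILON constants are defined elsewhere in the repo
total_norm = 5.0    # pretend the unclipped gradient norm is 5
clipped = total_norm * (max_norm / (total_norm + eps))
assert abs(clipped - 1.0) < 0.01  # inside the test's tolerance, though not exactly 1.0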