Added gradient clip test for native AMP (#3754)

* added gradient clip test for fp16

* pep8
This commit is contained in:
Teddy Koker 2020-10-01 01:36:34 -04:00 committed by GitHub
parent a38d108a68
commit 5ec00ccd28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 31 additions and 0 deletions

View File

@ -23,6 +23,7 @@ from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
from tests.base import EvalModelTemplate from tests.base import EvalModelTemplate
@ -867,6 +868,36 @@ def test_gradient_clipping(tmpdir):
trainer.fit(model) trainer.fit(model)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="test requires native AMP.")
def test_gradient_clipping_fp16(tmpdir):
"""
Test gradient clipping with fp16
"""
model = EvalModelTemplate()
# test that gradient is clipped correctly
def _optimizer_step(*args, **kwargs):
parameters = model.parameters()
grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm)
trainer = Trainer(
max_steps=1,
max_epochs=1,
precision=16,
gpus=1,
gradient_clip_val=1.0,
default_root_dir=tmpdir,
)
# for the test
model.optimizer_step = _optimizer_step
model.prev_called_batch_idx = 0
trainer.fit(model)
def test_gpu_choice(tmpdir): def test_gpu_choice(tmpdir):
trainer_options = dict( trainer_options = dict(
default_root_dir=tmpdir, default_root_dir=tmpdir,