diff --git a/pytorch_lightning/plugins/apex.py b/pytorch_lightning/plugins/apex.py
index 085d0e729d..18c8a15692 100644
--- a/pytorch_lightning/plugins/apex.py
+++ b/pytorch_lightning/plugins/apex.py
@@ -104,11 +104,12 @@ class ApexPlugin(PrecisionPlugin):
         """
         This code is a modification of :meth:`torch.nn.utils.clip_grad_norm_` using a higher epsilon for fp16 weights.
         This is important when setting amp_level to O2, and the master weights are in fp16.
+
         Args:
             grad_clip_val: Maximum norm of gradients.
             optimizer: Optimizer with gradients that will be clipped.
             norm_type: (float or int): type of the used p-norm. Can be ``'inf'`` for
-            infinity norm.
+                infinity norm.
         """
         model = self.trainer.get_model()
         parameters = model.parameters()
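For context, below is a minimal sketch of the clipping behaviour the docstring describes: standard norm-based gradient clipping in the style of `torch.nn.utils.clip_grad_norm_`, but with a larger epsilon when the master weights are fp16 (as with Apex `amp_level='O2'`). The helper name `clip_grad_norm_fp16` and the epsilon values are illustrative assumptions, not the plugin's exact implementation.

```python
import torch

# Assumed epsilon values: a larger one for fp16 master weights to avoid
# numerical issues when dividing by a very small total norm.
EPSILON = 1e-6
EPSILON_FP16 = 1e-5


def clip_grad_norm_fp16(parameters, max_norm, norm_type=2.0, fp16=True):
    """Clip gradient norm, using a higher epsilon when weights are fp16."""
    parameters = [p for p in parameters if p.grad is not None]
    if not parameters:
        return torch.tensor(0.0)

    eps = EPSILON_FP16 if fp16 else EPSILON

    if norm_type == float('inf'):
        # Infinity norm: largest absolute gradient value across all parameters.
        total_norm = max(p.grad.detach().abs().max() for p in parameters)
    else:
        # p-norm of the per-parameter gradient norms.
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]),
            norm_type,
        )

    # Scale all gradients down in place if the total norm exceeds max_norm.
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1:
        for p in parameters:
            p.grad.detach().mul_(clip_coef)
    return total_norm
```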