diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index c1d7878e4e..9d99da26b9 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -331,7 +331,12 @@ class Accelerator:
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
         """clips all the optimizer parameters to the given value"""
-        self.precision_plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
+        self.precision_plugin.clip_gradients(
+            optimizer,
+            clip_val,
+            gradient_clip_algorithm=gradient_clip_algorithm,
+            model=self.model,
+        )
 
     def on_train_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
         """Hook to do something on the end of an training epoch
diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py
index dc29a5cee4..f05fd4d54b 100644
--- a/pytorch_lightning/plugins/precision/deepspeed_precision.py
+++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Union
+from typing import Any, Callable, Optional, Union
 
 from torch import Tensor
+from torch.nn import Module
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
@@ -79,8 +80,9 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
         optimizer: Optimizer,
         clip_val: Union[int, float],
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
+        model: Optional[Module] = None,
     ) -> None:
         """
-        DeepSpeed handles clipping gradients via the training type plugin.
+        DeepSpeed handles clipping gradients internally via the training type plugin.
         """
         pass
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index c1ea328796..f324b21732 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -104,6 +104,7 @@ class PrecisionPlugin(Plugin):
         optimizer: Optimizer,
         clip_val: Union[int, float],
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
+        model: Optional[Module] = None
     ) -> None:
         """Clips the gradients"""
         if clip_val is None: