Add back `clip_gradients(model)` (#7231)
This commit is contained in:
parent
3b36d81c03
commit
ca6c87ffbe
|
@ -331,7 +331,12 @@ class Accelerator:
|
|||
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
|
||||
) -> None:
|
||||
"""clips all the optimizer parameters to the given value"""
|
||||
self.precision_plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
|
||||
self.precision_plugin.clip_gradients(
|
||||
optimizer,
|
||||
clip_val,
|
||||
gradient_clip_algorithm=gradient_clip_algorithm,
|
||||
model=self.model,
|
||||
)
|
||||
|
||||
def on_train_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
|
||||
"""Hook to do something on the end of an training epoch
|
||||
|
|
|
@ -11,9 +11,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Callable, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from torch.optim import Optimizer
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
@ -79,8 +80,9 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
|
|||
optimizer: Optimizer,
|
||||
clip_val: Union[int, float],
|
||||
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
|
||||
model: Optional[Module] = None,
|
||||
) -> None:
|
||||
"""
|
||||
DeepSpeed handles clipping gradients via the training type plugin.
|
||||
DeepSpeed handles clipping gradients internally via the training type plugin.
|
||||
"""
|
||||
pass
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Any, Callable, List, Tuple, Union
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
@ -104,6 +104,7 @@ class PrecisionPlugin(Plugin):
|
|||
optimizer: Optimizer,
|
||||
clip_val: Union[int, float],
|
||||
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
|
||||
model: Optional[Module] = None
|
||||
) -> None:
|
||||
"""Clips the gradients"""
|
||||
if clip_val is None:
|
||||
|
|
Loading…
Reference in New Issue