Minor Fabric backward refactor (#17433)

This commit is contained in:
Adrian Wälchli 2023-04-21 21:36:46 +02:00 committed by GitHub
parent 0ee71d6a7a
commit 877d95f8d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 3 deletions

View File

@ -363,7 +363,7 @@ class Fabric:
# requires to attach the current `DeepSpeedEngine` for the `_FabricOptimizer.step` call.
self._strategy._deepspeed_engine = module
self._precision.backward(tensor, module, *args, **kwargs)
self._strategy.backward(tensor, module, *args, **kwargs)
def clip_gradients(
self,

View File

@ -568,10 +568,10 @@ def test_rank_properties():
def test_backward():
"""Test that backward() calls into the precision plugin."""
fabric = Fabric()
fabric._precision = Mock(spec=Precision)
fabric._strategy = Mock(spec=Precision)
loss = Mock()
fabric.backward(loss, "arg", keyword="kwarg")
fabric._precision.backward.assert_called_with(loss, None, "arg", keyword="kwarg")
fabric._strategy.backward.assert_called_with(loss, None, "arg", keyword="kwarg")
@RunIf(deepspeed=True, mps=False)