Minor Fabric backward refactor (#17433)

This commit is contained in:
Adrian Wälchli 2023-04-21 21:36:46 +02:00 committed by GitHub
parent 0ee71d6a7a
commit 877d95f8d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 3 deletions

View File

@ -363,7 +363,7 @@ class Fabric:
# requires to attach the current `DeepSpeedEngine` for the `_FabricOptimizer.step` call. # requires to attach the current `DeepSpeedEngine` for the `_FabricOptimizer.step` call.
self._strategy._deepspeed_engine = module self._strategy._deepspeed_engine = module
self._precision.backward(tensor, module, *args, **kwargs) self._strategy.backward(tensor, module, *args, **kwargs)
def clip_gradients( def clip_gradients(
self, self,

View File

@ -568,10 +568,10 @@ def test_rank_properties():
def test_backward(): def test_backward():
"""Test that backward() calls into the precision plugin.""" """Test that backward() calls into the precision plugin."""
fabric = Fabric() fabric = Fabric()
fabric._precision = Mock(spec=Precision) fabric._strategy = Mock(spec=Precision)
loss = Mock() loss = Mock()
fabric.backward(loss, "arg", keyword="kwarg") fabric.backward(loss, "arg", keyword="kwarg")
fabric._precision.backward.assert_called_with(loss, None, "arg", keyword="kwarg") fabric._strategy.backward.assert_called_with(loss, None, "arg", keyword="kwarg")
@RunIf(deepspeed=True, mps=False) @RunIf(deepspeed=True, mps=False)