Minor Fabric backward refactor (#17433)
This commit is contained in:
parent
0ee71d6a7a
commit
877d95f8d7
|
@ -363,7 +363,7 @@ class Fabric:
|
|||
# requires to attach the current `DeepSpeedEngine` for the `_FabricOptimizer.step` call.
|
||||
self._strategy._deepspeed_engine = module
|
||||
|
||||
self._precision.backward(tensor, module, *args, **kwargs)
|
||||
self._strategy.backward(tensor, module, *args, **kwargs)
|
||||
|
||||
def clip_gradients(
|
||||
self,
|
||||
|
|
|
@ -568,10 +568,10 @@ def test_rank_properties():
|
|||
def test_backward():
|
||||
"""Test that backward() calls into the precision plugin."""
|
||||
fabric = Fabric()
|
||||
fabric._precision = Mock(spec=Precision)
|
||||
fabric._strategy = Mock(spec=Precision)
|
||||
loss = Mock()
|
||||
fabric.backward(loss, "arg", keyword="kwarg")
|
||||
fabric._precision.backward.assert_called_with(loss, None, "arg", keyword="kwarg")
|
||||
fabric._strategy.backward.assert_called_with(loss, None, "arg", keyword="kwarg")
|
||||
|
||||
|
||||
@RunIf(deepspeed=True, mps=False)
|
||||
|
|
Loading…
Reference in New Issue