Minor Fabric backward refactor (#17433)
This commit is contained in:
parent
0ee71d6a7a
commit
877d95f8d7
|
@ -363,7 +363,7 @@ class Fabric:
|
||||||
# requires to attach the current `DeepSpeedEngine` for the `_FabricOptimizer.step` call.
|
# requires to attach the current `DeepSpeedEngine` for the `_FabricOptimizer.step` call.
|
||||||
self._strategy._deepspeed_engine = module
|
self._strategy._deepspeed_engine = module
|
||||||
|
|
||||||
self._precision.backward(tensor, module, *args, **kwargs)
|
self._strategy.backward(tensor, module, *args, **kwargs)
|
||||||
|
|
||||||
def clip_gradients(
|
def clip_gradients(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -568,10 +568,10 @@ def test_rank_properties():
|
||||||
def test_backward():
|
def test_backward():
|
||||||
"""Test that backward() calls into the precision plugin."""
|
"""Test that backward() calls into the precision plugin."""
|
||||||
fabric = Fabric()
|
fabric = Fabric()
|
||||||
fabric._precision = Mock(spec=Precision)
|
fabric._strategy = Mock(spec=Precision)
|
||||||
loss = Mock()
|
loss = Mock()
|
||||||
fabric.backward(loss, "arg", keyword="kwarg")
|
fabric.backward(loss, "arg", keyword="kwarg")
|
||||||
fabric._precision.backward.assert_called_with(loss, None, "arg", keyword="kwarg")
|
fabric._strategy.backward.assert_called_with(loss, None, "arg", keyword="kwarg")
|
||||||
|
|
||||||
|
|
||||||
@RunIf(deepspeed=True, mps=False)
|
@RunIf(deepspeed=True, mps=False)
|
||||||
|
|
Loading…
Reference in New Issue