From 877d95f8d7d20134bb7a9b5cdf1a9b0d51aa82c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 21 Apr 2023 21:36:46 +0200 Subject: [PATCH] Minor Fabric backward refactor (#17433) --- src/lightning/fabric/fabric.py | 2 +- tests/tests_fabric/test_fabric.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index d9f5e9821b..47c000ac28 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -363,7 +363,7 @@ class Fabric: # requires to attach the current `DeepSpeedEngine` for the `_FabricOptimizer.step` call. self._strategy._deepspeed_engine = module - self._precision.backward(tensor, module, *args, **kwargs) + self._strategy.backward(tensor, module, *args, **kwargs) def clip_gradients( self, diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index 3fded110f3..b274241819 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -568,10 +568,10 @@ def test_rank_properties(): def test_backward(): """Test that backward() calls into the precision plugin.""" fabric = Fabric() - fabric._precision = Mock(spec=Precision) + fabric._strategy = Mock(spec=Precision) loss = Mock() fabric.backward(loss, "arg", keyword="kwarg") - fabric._precision.backward.assert_called_with(loss, None, "arg", keyword="kwarg") + fabric._strategy.backward.assert_called_with(loss, None, "arg", keyword="kwarg") @RunIf(deepspeed=True, mps=False)