From f3c49b8e7736401445ee62181c20b2a2c3950e7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 7 Jun 2023 16:07:03 +0200 Subject: [PATCH] Remove warning on `no_backward_sync` with XLA strategy (#17761) --- src/lightning/fabric/CHANGELOG.md | 3 +++ src/lightning/fabric/fabric.py | 2 +- tests/tests_fabric/test_fabric.py | 7 ++++++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index ad291fec16..9281b177f2 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -126,6 +126,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed computing the next version folder in `CSVLogger` ([#17139](https://github.com/Lightning-AI/lightning/pull/17139)) +- Removed false positive warning when using `fabric.no_backward_sync` with XLA strategies ([#17761](https://github.com/Lightning-AI/lightning/pull/17761)) + + ## [2.0.1.post0] - 2023-04-11 No changes diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index 91389b4dc0..04f331486c 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -573,7 +573,7 @@ class Fabric: "You need to set up the model first before you can call `self.no_backward_sync()`:" " `model = self.setup(model, ...)`" ) - if not enabled or isinstance(self._strategy, SingleDeviceStrategy): + if not enabled or isinstance(self._strategy, (SingleDeviceStrategy, XLAStrategy)): context = nullcontext() elif self._strategy._backward_sync_control is None: rank_zero_warn( diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index 01ff49170c..7ff7aec69a 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -637,6 +637,11 @@ def test_no_backward_sync(): with fabric.no_backward_sync(model): pass fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called() + # same for XLA + fabric._strategy = Mock(spec=XLAStrategy, _backward_sync_control=MagicMock()) + with fabric.no_backward_sync(model): + pass + fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called() # pretend that the strategy supports skipping backward sync fabric._strategy = Mock(_backward_sync_control=MagicMock()) @@ -644,7 +649,7 @@ def test_no_backward_sync(): with fabric.no_backward_sync(model, enabled=False): pass fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called() - # when enabld, the wrapped module gets passed down + # when enabled, the wrapped module gets passed down with fabric.no_backward_sync(model): pass fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module)