Remove warning on `no_backward_sync` with XLA strategy (#17761)

This commit is contained in:
Carlos Mocholí 2023-06-07 16:07:03 +02:00 committed by GitHub
parent 420eb6f248
commit f3c49b8e77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 2 deletions

View File

@ -126,6 +126,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed computing the next version folder in `CSVLogger` ([#17139](https://github.com/Lightning-AI/lightning/pull/17139))
- Removed false positive warning when using `fabric.no_backward_sync` with XLA strategies ([#17761](https://github.com/Lightning-AI/lightning/pull/17761))
## [2.0.1.post0] - 2023-04-11
No changes

View File

@ -573,7 +573,7 @@ class Fabric:
"You need to set up the model first before you can call `self.no_backward_sync()`:"
" `model = self.setup(model, ...)`"
)
if not enabled or isinstance(self._strategy, SingleDeviceStrategy):
if not enabled or isinstance(self._strategy, (SingleDeviceStrategy, XLAStrategy)):
context = nullcontext()
elif self._strategy._backward_sync_control is None:
rank_zero_warn(

View File

@ -637,6 +637,11 @@ def test_no_backward_sync():
with fabric.no_backward_sync(model):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
# same for XLA
fabric._strategy = Mock(spec=XLAStrategy, _backward_sync_control=MagicMock())
with fabric.no_backward_sync(model):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
# pretend that the strategy supports skipping backward sync
fabric._strategy = Mock(_backward_sync_control=MagicMock())
@ -644,7 +649,7 @@ def test_no_backward_sync():
with fabric.no_backward_sync(model, enabled=False):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
# when enabld, the wrapped module gets passed down
# when enabled, the wrapped module gets passed down
with fabric.no_backward_sync(model):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module)