From a71be50297cb9855b8076927bee96471bb24eb6f Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Wed, 22 Sep 2021 01:56:38 -0700
Subject: [PATCH] Fix gradient accumulation for `ShardedDataParallel` (#9122)

* Fix gradient accumulation for `ShardedDataParallel`

* Update changelog

* Update pytorch_lightning/plugins/training_type/sharded.py

* add test

* Update test_sharded_plugin.py

* Update test_sharded_plugin.py

* Update test_sharded_plugin.py
---
 CHANGELOG.md                                |  3 +++
 .../plugins/training_type/sharded.py        | 16 +++++++++++++++-
 .../plugins/training_type/sharded_spawn.py  | 16 +++++++++++++++-
 tests/plugins/test_sharded_plugin.py        | 10 ++++++++++
 4 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 437d57a933..48895cf606 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -363,6 +363,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `add_argparse_args` raising `TypeError` when args are typed as `typing.Generic` in Python 3.6 ([#9554](https://github.com/PyTorchLightning/pytorch-lightning/pull/9554))
 
+- Fixed gradient accumulation for `DDPShardedPlugin` ([#9122](https://github.com/PyTorchLightning/pytorch-lightning/pull/9122))
+
+
 ## [1.4.7] - 2021-09-14
 
 - Fixed logging of nan parameters ([#9364](https://github.com/PyTorchLightning/pytorch-lightning/pull/9364))
diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index 8ff960531e..d684a34784 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional
+from contextlib import contextmanager
+from typing import Dict, Generator, Optional
 
 import torch
 
@@ -100,6 +101,19 @@ class DDPShardedPlugin(DDPPlugin):
     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         pass
 
+    @contextmanager
+    def block_backward_sync(self) -> Generator:
+        """Blocks syncing gradients behaviour on backwards pass.
+
+        This is useful for skipping sync when accumulating gradients, reducing communication overhead
+        Returns: context manager with sync behaviour off
+        """
+        if isinstance(self.model, ShardedDataParallel):
+            with self.model.no_sync():
+                yield None
+        else:
+            yield None
+
     def post_training_step(self):
         pass
 
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index e73fcd43cf..921f897820 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional
+from contextlib import contextmanager
+from typing import Dict, Generator, Optional
 
 import torch
 
@@ -63,6 +64,19 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
         optimizer.consolidate_state_dict()
         return self._optim_state_dict(optimizer)
 
+    @contextmanager
+    def block_backward_sync(self) -> Generator:
+        """Blocks syncing gradients behaviour on backwards pass.
+
+        This is useful for skipping sync when accumulating gradients, reducing communication overhead
+        Returns: context manager with sync behaviour off
+        """
+        if isinstance(self.model, ShardedDataParallel):
+            with self.model.no_sync():
+                yield None
+        else:
+            yield None
+
     @rank_zero_only
     def _optim_state_dict(self, optimizer):
         """
diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py
index adee6e3ba2..6926e07c32 100644
--- a/tests/plugins/test_sharded_plugin.py
+++ b/tests/plugins/test_sharded_plugin.py
@@ -309,3 +309,13 @@ def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffe
         assert kwargs["reduce_buffer_size"] == DDPShardedPlugin._REDUCE_BUFFER_SIZE_DEFAULT
     else:
         assert kwargs["reduce_buffer_size"] == expected_buffer_size
+
+
+@RunIf(skip_windows=True, fairscale=True)
+def test_block_backward_sync(tmpdir):
+    plugin = DDPShardedPlugin()
+    model = mock.MagicMock(spec=ShardedDataParallel)
+    with mock.patch.object(plugin, "_model", model):
+        with plugin.block_backward_sync():
+            pass
+    model.no_sync.assert_called_once()
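
For context, the behaviour this patch enables looks roughly like the sketch below when written directly against fairscale: backward passes for intermediate micro-batches run under `ShardedDataParallel.no_sync()`, so the gradient all-reduce only fires on the last micro-batch of the accumulation window. That is the context the new `block_backward_sync()` hands to the training loop while it is still accumulating. The single-process `gloo` group, the toy `Linear` model, and the accumulation factor of 4 are assumptions made for illustration only; they are not part of this patch.

import contextlib
import os

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Single-process group so the sketch runs without a distributed launcher
# (illustrative assumption; a real run would use multiple ranks).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

module = torch.nn.Linear(32, 2)
optimizer = OSS(params=module.parameters(), optim=torch.optim.SGD, lr=0.1)
model = ShardedDataParallel(module, optimizer)

accumulate_grad_batches = 4  # mirrors Trainer(accumulate_grad_batches=4)

for step in range(8):
    x, y = torch.randn(16, 32), torch.randint(0, 2, (16,))
    is_last_micro_batch = (step + 1) % accumulate_grad_batches == 0

    # Intermediate micro-batches: no_sync() skips the gradient all-reduce,
    # which is exactly what block_backward_sync() wraps for the sharded plugins.
    sync_context = contextlib.nullcontext() if is_last_micro_batch else model.no_sync()
    with sync_context:
        loss = torch.nn.functional.cross_entropy(model(x), y) / accumulate_grad_batches
        loss.backward()

    if is_last_micro_batch:
        # Only this backward ran with syncing enabled, so the accumulated
        # gradients have been reduced and it is safe to step.
        optimizer.step()
        optimizer.zero_grad()

dist.destroy_process_group()

From the user side nothing changes: a run such as `Trainer(gpus=2, plugins="ddp_sharded", accumulate_grad_batches=4)` should now skip the reduce on the intermediate batches instead of syncing after every backward, matching the behaviour of the plain DDP plugin.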