diff --git a/CHANGELOG.md b/CHANGELOG.md
index e8b948b26d..266d789765 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -358,6 +358,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386))
 
+- Fixed collision of user argument when using ShardedDDP ([#9512](https://github.com/PyTorchLightning/pytorch-lightning/pull/9512))
+
+
 ## [1.4.5] - 2021-08-31
 
 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index 817dfd65ce..8ff960531e 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -36,11 +36,14 @@ class DDPShardedPlugin(DDPPlugin):
 
     def configure_ddp(self) -> None:
         self._wrap_optimizers()
+
+        if "reduce_buffer_size" not in self._ddp_kwargs:
+            # For multi-node training, enabling bucketing will improve performance.
+            self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
+
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model),
             sharded_optimizer=self.lightning_module.trainer.optimizers,
-            # For multi-node training, enabling bucketing will improve performance.
-            reduce_buffer_size=self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0,
             **self._ddp_kwargs
         )
         setattr(self._model, "require_backward_grad_sync", False)
diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py
index f3ff6d8e5c..adee6e3ba2 100644
--- a/tests/plugins/test_sharded_plugin.py
+++ b/tests/plugins/test_sharded_plugin.py
@@ -285,3 +285,27 @@ def test_custom_kwargs_sharded(tmpdir, cls):
     args, kwargs = mock_sharded.call_args
     assert "reduce_fp16" in kwargs
     assert kwargs["reduce_fp16"]
+
+
+@RunIf(skip_windows=True, fairscale=True)
+@mock.patch("pytorch_lightning.plugins.DDPShardedPlugin._wrap_optimizers", autospec=True)
+@pytest.mark.parametrize(["params", "expected_buffer_size"], [(dict(), 0), (dict(reduce_buffer_size=128), 128)])
+@pytest.mark.parametrize("num_nodes", [1, 2])
+def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffer_size, num_nodes):
+    """Tests to ensure that ``reduce_buffer_size`` is correctly set based on user kwargs."""
+    plugin = DDPShardedPlugin(**params)
+    plugin.num_nodes = num_nodes
+
+    with mock.patch.object(plugin, "_model", autospec=True):
+        with mock.patch(
+            "pytorch_lightning.plugins.training_type.sharded.ShardedDataParallel", autospec=True
+        ) as mock_sharded:
+            plugin.configure_ddp()
+    args, kwargs = mock_sharded.call_args
+    assert "reduce_buffer_size" in kwargs
+
+    if num_nodes > 1 and len(params) == 0:
+        # If user has not specified a buffer size and we're using multiple nodes, check to see if default is set
+        assert kwargs["reduce_buffer_size"] == DDPShardedPlugin._REDUCE_BUFFER_SIZE_DEFAULT
+    else:
+        assert kwargs["reduce_buffer_size"] == expected_buffer_size
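
For context, a minimal usage sketch of what this change enables (not part of the diff): a user-supplied `reduce_buffer_size` passed to `DDPShardedPlugin` is now forwarded to fairscale's `ShardedDataParallel` via `self._ddp_kwargs`, instead of colliding with the keyword that `configure_ddp` previously passed explicitly (a duplicate-keyword `TypeError`). It assumes PyTorch Lightning 1.4.x with fairscale installed and at least two GPUs; `TinyModule` and the toy dataset are hypothetical stand-ins, not part of the library.

```python
# Minimal sketch, assuming PyTorch Lightning 1.4.x, fairscale, and >= 2 GPUs.
# `TinyModule` is a hypothetical stand-in for any user LightningModule.
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPShardedPlugin


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        # Use a trivial scalar as the training loss.
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    trainer = pl.Trainer(
        max_epochs=1,
        gpus=2,
        # Before this fix, `configure_ddp` also passed `reduce_buffer_size`
        # explicitly, so this user value collided with it; now it is respected.
        plugins=DDPShardedPlugin(reduce_buffer_size=128),
    )
    trainer.fit(TinyModule(), train_data)
```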