Handle collision of user argument when using ShardedDDP (#9512)

* Handle collision of user argument

* Add CHANGELOG.md
Sean Naren, 2021-09-14 13:20:36 +01:00 (committed by GitHub)
parent c784092013
commit adaa2347f1
3 changed files with 32 additions and 2 deletions
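
Background on the collision: before this change, configure_ddp passed reduce_buffer_size explicitly to ShardedDataParallel while also forwarding **self._ddp_kwargs, so a user who supplied reduce_buffer_size through the plugin's kwargs hit a duplicate-keyword TypeError. A minimal sketch of the failure mode and the fix, using plain-Python stand-ins rather than the actual plugin or fairscale code:

    def sharded_data_parallel_stub(module, sharded_optimizer, reduce_buffer_size=0, **extra):
        # Stand-in for fairscale's ShardedDataParallel constructor.
        return reduce_buffer_size

    user_kwargs = {"reduce_buffer_size": 128}  # e.g. what DDPShardedPlugin(reduce_buffer_size=128) stores

    try:
        # Old behaviour: the default was passed explicitly AND the user kwargs were forwarded.
        sharded_data_parallel_stub(None, sharded_optimizer=[], reduce_buffer_size=0, **user_kwargs)
    except TypeError as err:
        print(err)  # got multiple values for keyword argument 'reduce_buffer_size'

    # New behaviour: inject the default into the kwargs only if the user did not provide one.
    if "reduce_buffer_size" not in user_kwargs:
        user_kwargs["reduce_buffer_size"] = 0
    print(sharded_data_parallel_stub(None, sharded_optimizer=[], **user_kwargs))  # 128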


@@ -358,6 +358,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386))
+- Fixed collision of user argument when using ShardedDDP ([#9512](https://github.com/PyTorchLightning/pytorch-lightning/pull/9512))
 ## [1.4.5] - 2021-08-31
 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))


@@ -36,11 +36,14 @@ class DDPShardedPlugin(DDPPlugin):
     def configure_ddp(self) -> None:
         self._wrap_optimizers()
+
+        if "reduce_buffer_size" not in self._ddp_kwargs:
+            # For multi-node training, enabling bucketing will improve performance.
+            self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
+
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model),
             sharded_optimizer=self.lightning_module.trainer.optimizers,
-            # For multi-node training, enabling bucketing will improve performance.
-            reduce_buffer_size=self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0,
             **self._ddp_kwargs
         )
         setattr(self._model, "require_backward_grad_sync", False)
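
With this change, a user-supplied reduce_buffer_size reaches ShardedDataParallel exactly once, via self._ddp_kwargs. A hedged usage sketch against the Lightning 1.4.x API (the 2-GPU setup and MyModel are placeholders, not part of this commit):

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPShardedPlugin

    # Previously this combination raised "got multiple values for keyword argument
    # 'reduce_buffer_size'" inside configure_ddp; now the user value simply overrides the default.
    plugin = DDPShardedPlugin(reduce_buffer_size=128)
    trainer = Trainer(gpus=2, plugins=[plugin])
    # trainer.fit(MyModel())  # MyModel: any LightningModule (placeholder)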


@@ -285,3 +285,27 @@ def test_custom_kwargs_sharded(tmpdir, cls):
     args, kwargs = mock_sharded.call_args
     assert "reduce_fp16" in kwargs
     assert kwargs["reduce_fp16"]
+
+
+@RunIf(skip_windows=True, fairscale=True)
+@mock.patch("pytorch_lightning.plugins.DDPShardedPlugin._wrap_optimizers", autospec=True)
+@pytest.mark.parametrize(["params", "expected_buffer_size"], [(dict(), 0), (dict(reduce_buffer_size=128), 128)])
+@pytest.mark.parametrize("num_nodes", [1, 2])
+def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffer_size, num_nodes):
+    """Tests to ensure that ``reduce_buffer_size`` is correctly set based on user kwargs."""
+    plugin = DDPShardedPlugin(**params)
+    plugin.num_nodes = num_nodes
+
+    with mock.patch.object(plugin, "_model", autospec=True):
+        with mock.patch(
+            "pytorch_lightning.plugins.training_type.sharded.ShardedDataParallel", autospec=True
+        ) as mock_sharded:
+            plugin.configure_ddp()
+
+    args, kwargs = mock_sharded.call_args
+    assert "reduce_buffer_size" in kwargs
+
+    if num_nodes > 1 and len(params) == 0:
+        # If user has not specified a buffer size and we're using multiple nodes, check to see if default is set
+        assert kwargs["reduce_buffer_size"] == DDPShardedPlugin._REDUCE_BUFFER_SIZE_DEFAULT
+    else:
+        assert kwargs["reduce_buffer_size"] == expected_buffer_size