Handle collision of user argument when using ShardedDDP (#9512)
* Handle collision of user argument
* Add CHANGELOG.md
This commit is contained in:
parent c784092013
commit adaa2347f1
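Note on the bug: before this change, `configure_ddp` passed `reduce_buffer_size=...` to `ShardedDataParallel` explicitly while also expanding the user's `**self._ddp_kwargs`, so a user who supplied `reduce_buffer_size` themselves hit Python's duplicate-keyword `TypeError`. A minimal sketch of that failure mode and of the fix's precedence rule (the names `build_sharded_model` and `user_kwargs` are illustrative only, not part of the plugin's API):

# Illustrative stand-in for fairscale's ShardedDataParallel constructor.
def build_sharded_model(module, reduce_buffer_size=0, **ddp_kwargs):
    return {"module": module, "reduce_buffer_size": reduce_buffer_size, **ddp_kwargs}

user_kwargs = {"reduce_buffer_size": 128}  # user-supplied plugin kwargs

# Pre-fix call pattern: the explicit keyword collides with the expanded user kwargs.
try:
    build_sharded_model("model", reduce_buffer_size=0, **user_kwargs)
except TypeError as err:
    print(err)  # "... got multiple values for keyword argument 'reduce_buffer_size'"

# Post-fix pattern: only fill in the default when the user did not set one.
if "reduce_buffer_size" not in user_kwargs:
    user_kwargs["reduce_buffer_size"] = 0  # the real plugin chooses this based on num_nodes
build_sharded_model("model", **user_kwargs)  # no collision; the user's value wins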
@@ -358,6 +358,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386))

+- Fixed collision of user argument when using ShardedDDP ([#9512](https://github.com/PyTorchLightning/pytorch-lightning/pull/9512))

 ## [1.4.5] - 2021-08-31

 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
@@ -36,11 +36,14 @@ class DDPShardedPlugin(DDPPlugin):

     def configure_ddp(self) -> None:
         self._wrap_optimizers()
+
+        if "reduce_buffer_size" not in self._ddp_kwargs:
+            # For multi-node training, enabling bucketing will improve performance.
+            self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
+
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model),
             sharded_optimizer=self.lightning_module.trainer.optimizers,
-            # For multi-node training, enabling bucketing will improve performance.
-            reduce_buffer_size=self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0,
             **self._ddp_kwargs
         )
         setattr(self._model, "require_backward_grad_sync", False)
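With the fix, a user-supplied `reduce_buffer_size` is forwarded to fairscale's `ShardedDataParallel` untouched, and the `num_nodes`-based default is applied only when the key is absent from `self._ddp_kwargs`. A hedged usage sketch; the `Trainer(gpus=..., plugins=...)` wiring below is the typical 1.4-era pattern and is an assumption, not part of this diff:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPShardedPlugin

# The user's buffer size now takes precedence over the plugin's default
# (_REDUCE_BUFFER_SIZE_DEFAULT when num_nodes > 1, otherwise 0).
plugin = DDPShardedPlugin(reduce_buffer_size=128)
trainer = Trainer(gpus=2, plugins=plugin)  # assumed 1.4-style Trainer arguments

The new test below exercises both the default and the user-override path.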
@@ -285,3 +285,27 @@ def test_custom_kwargs_sharded(tmpdir, cls):
     args, kwargs = mock_sharded.call_args
     assert "reduce_fp16" in kwargs
     assert kwargs["reduce_fp16"]
+
+
+@RunIf(skip_windows=True, fairscale=True)
+@mock.patch("pytorch_lightning.plugins.DDPShardedPlugin._wrap_optimizers", autospec=True)
+@pytest.mark.parametrize(["params", "expected_buffer_size"], [(dict(), 0), (dict(reduce_buffer_size=128), 128)])
+@pytest.mark.parametrize("num_nodes", [1, 2])
+def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffer_size, num_nodes):
+    """Tests to ensure that ``reduce_buffer_size`` is correctly set based on user kwargs."""
+    plugin = DDPShardedPlugin(**params)
+    plugin.num_nodes = num_nodes
+
+    with mock.patch.object(plugin, "_model", autospec=True):
+        with mock.patch(
+            "pytorch_lightning.plugins.training_type.sharded.ShardedDataParallel", autospec=True
+        ) as mock_sharded:
+            plugin.configure_ddp()
+    args, kwargs = mock_sharded.call_args
+    assert "reduce_buffer_size" in kwargs
+
+    if num_nodes > 1 and len(params) == 0:
+        # If user has not specified a buffer size and we're using multiple nodes, check to see if default is set
+        assert kwargs["reduce_buffer_size"] == DDPShardedPlugin._REDUCE_BUFFER_SIZE_DEFAULT
+    else:
+        assert kwargs["reduce_buffer_size"] == expected_buffer_size