Fix gradient accumulation for `ShardedDataParallel` (#9122)

* Fix gradient accumulation for `ShardedDataParallel`

* Update changelog

* Update pytorch_lightning/plugins/training_type/sharded.py

* add test

* Update test_sharded_plugin.py

* Update test_sharded_plugin.py

* Update test_sharded_plugin.py
ananthsub 2021-09-22 01:56:38 -07:00 committed by GitHub
parent 73e53e5b82
commit a71be50297
4 changed files with 43 additions and 2 deletions

CHANGELOG.md

@@ -363,6 +363,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `add_argparse_args` raising `TypeError` when args are typed as `typing.Generic` in Python 3.6 ([#9554](https://github.com/PyTorchLightning/pytorch-lightning/pull/9554))
- Fixed gradient accumulation for `DDPShardedPlugin` ([#9122](https://github.com/PyTorchLightning/pytorch-lightning/pull/9122))

## [1.4.7] - 2021-09-14

- Fixed logging of nan parameters ([#9364](https://github.com/PyTorchLightning/pytorch-lightning/pull/9364))

pytorch_lightning/plugins/training_type/sharded.py

@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional
from contextlib import contextmanager
from typing import Dict, Generator, Optional

import torch
@@ -100,6 +101,19 @@ class DDPShardedPlugin(DDPPlugin):
    def pre_backward(self, closure_loss: torch.Tensor) -> None:
        pass

    @contextmanager
    def block_backward_sync(self) -> Generator:
        """Blocks syncing gradients behaviour on backwards pass.

        This is useful for skipping sync when accumulating gradients, reducing communication overhead.
        Returns: context manager with sync behaviour off
        """
        if isinstance(self.model, ShardedDataParallel):
            with self.model.no_sync():
                yield None
        else:
            yield None

    def post_training_step(self):
        pass
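For context, a minimal sketch of how a training loop could consume the hook added above: on intermediate accumulation steps the backward pass runs inside `block_backward_sync()`, so `ShardedDataParallel.no_sync()` suppresses the cross-rank gradient reduction, and only the final micro-batch of the window synchronizes. Everything below is illustrative; `backward_with_optional_sync`, `plugin`, `run_backward`, and `should_sync` are placeholder names, not Lightning internals.

```python
from contextlib import nullcontext
from typing import Callable


def backward_with_optional_sync(plugin, run_backward: Callable[[], None], should_sync: bool) -> None:
    # Placeholder sketch: `plugin` stands in for a DDPShardedPlugin instance and
    # `run_backward` for the closure that ends up calling loss.backward().
    ctx = nullcontext() if should_sync else plugin.block_backward_sync()
    with ctx:
        run_backward()
```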

pytorch_lightning/plugins/training_type/sharded_spawn.py

@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional
from contextlib import contextmanager
from typing import Dict, Generator, Optional

import torch
@@ -63,6 +64,19 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
        optimizer.consolidate_state_dict()
        return self._optim_state_dict(optimizer)

    @contextmanager
    def block_backward_sync(self) -> Generator:
        """Blocks syncing gradients behaviour on backwards pass.

        This is useful for skipping sync when accumulating gradients, reducing communication overhead.
        Returns: context manager with sync behaviour off
        """
        if isinstance(self.model, ShardedDataParallel):
            with self.model.no_sync():
                yield None
        else:
            yield None

    @rank_zero_only
    def _optim_state_dict(self, optimizer):
        """

tests/plugins/test_sharded_plugin.py

@@ -309,3 +309,13 @@ def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffe
        assert kwargs["reduce_buffer_size"] == DDPShardedPlugin._REDUCE_BUFFER_SIZE_DEFAULT
    else:
        assert kwargs["reduce_buffer_size"] == expected_buffer_size


@RunIf(skip_windows=True, fairscale=True)
def test_block_backward_sync(tmpdir):
    plugin = DDPShardedPlugin()
    model = mock.MagicMock(spec=ShardedDataParallel)
    with mock.patch.object(plugin, "_model", model):
        with plugin.block_backward_sync():
            pass
    model.no_sync.assert_called_once()
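The test drives the context manager directly against a mocked model. As a usage note, the behaviour this fix restores is what a user opts into by combining `accumulate_grad_batches` with the sharded plugin; a hedged example of such a configuration on this release line is below (the `BoringModel` import path and the exact Trainer flags may differ between versions).

```python
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel


def run_sharded_accumulation() -> None:
    # Illustrative configuration: with the fix, intermediate accumulation steps
    # now execute their backward pass under ShardedDataParallel.no_sync().
    model = BoringModel()
    trainer = Trainer(
        gpus=2,
        accelerator="ddp",
        plugins="ddp_sharded",
        accumulate_grad_batches=4,
        fast_dev_run=True,
    )
    trainer.fit(model)
```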