Fix gradient accumulation for `ShardedDataParallel` (#9122)
* Fix gradient accumulation for `ShardedDataParallel`
* Update changelog
* Update pytorch_lightning/plugins/training_type/sharded.py
* add test
* Update test_sharded_plugin.py
* Update test_sharded_plugin.py
* Update test_sharded_plugin.py
commit a71be50297
parent 73e53e5b82
CHANGELOG.md
@@ -363,6 +363,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `add_argparse_args` raising `TypeError` when args are typed as `typing.Generic` in Python 3.6 ([#9554](https://github.com/PyTorchLightning/pytorch-lightning/pull/9554))
 
+- Fixed gradient accumulation for `DDPShardedPlugin` ([#9122](https://github.com/PyTorchLightning/pytorch-lightning/pull/9122))
+
+
 ## [1.4.7] - 2021-09-14
 
 - Fixed logging of nan parameters ([#9364](https://github.com/PyTorchLightning/pytorch-lightning/pull/9364))
pytorch_lightning/plugins/training_type/sharded.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional
+from contextlib import contextmanager
+from typing import Dict, Generator, Optional
 
 import torch
 
@@ -100,6 +101,19 @@ class DDPShardedPlugin(DDPPlugin):
     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         pass
 
+    @contextmanager
+    def block_backward_sync(self) -> Generator:
+        """Blocks syncing gradients behaviour on backwards pass.
+
+        This is useful for skipping sync when accumulating gradients, reducing communication overhead
+        Returns: context manager with sync behaviour off
+        """
+        if isinstance(self.model, ShardedDataParallel):
+            with self.model.no_sync():
+                yield None
+        else:
+            yield None
+
     def post_training_step(self):
         pass
pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional
+from contextlib import contextmanager
+from typing import Dict, Generator, Optional
 
 import torch
 
@@ -63,6 +64,19 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
             optimizer.consolidate_state_dict()
         return self._optim_state_dict(optimizer)
 
+    @contextmanager
+    def block_backward_sync(self) -> Generator:
+        """Blocks syncing gradients behaviour on backwards pass.
+
+        This is useful for skipping sync when accumulating gradients, reducing communication overhead
+        Returns: context manager with sync behaviour off
+        """
+        if isinstance(self.model, ShardedDataParallel):
+            with self.model.no_sync():
+                yield None
+        else:
+            yield None
+
     @rank_zero_only
     def _optim_state_dict(self, optimizer):
         """
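From the user's side, both plugins are exercised the same way. The snippet below is a sketch of a 1.4-era Trainer configuration under which the fixed path is expected to be hit; the flag names are recalled from that release's documentation and may differ in newer versions, and `ddp_sharded_spawn` would select the spawn variant above instead.

```python
# Sketch only: requires >= 2 GPUs and fairscale installed (PyTorch Lightning ~1.4 API).
import pytorch_lightning as pl

trainer = pl.Trainer(
    gpus=2,
    plugins="ddp_sharded",        # "ddp_sharded_spawn" selects DDPSpawnShardedPlugin instead
    accumulate_grad_batches=4,    # sync is skipped on the intermediate micro-batches
)
# trainer.fit(model, datamodule=dm)  # model / dm stand for the user's LightningModule / DataModule
```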
tests/plugins/test_sharded_plugin.py
@@ -309,3 +309,13 @@ def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffer_size):
         assert kwargs["reduce_buffer_size"] == DDPShardedPlugin._REDUCE_BUFFER_SIZE_DEFAULT
     else:
         assert kwargs["reduce_buffer_size"] == expected_buffer_size
+
+
+@RunIf(skip_windows=True, fairscale=True)
+def test_block_backward_sync(tmpdir):
+    plugin = DDPShardedPlugin()
+    model = mock.MagicMock(spec=ShardedDataParallel)
+    with mock.patch.object(plugin, "_model", model):
+        with plugin.block_backward_sync():
+            pass
+    model.no_sync.assert_called_once()
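The added test only covers the sharded branch. A companion check for the fallback branch, not part of this commit, could look like the sketch below; the test name is hypothetical, it reuses the same imports and `RunIf` guard as the file above, and it asserts that `no_sync` is never entered when the wrapped model is not a `ShardedDataParallel`.

```python
@RunIf(skip_windows=True, fairscale=True)
def test_block_backward_sync_unwrapped_model(tmpdir):
    # Hypothetical companion test: with a plain mock (not a ShardedDataParallel),
    # block_backward_sync should fall through to the bare ``yield`` branch.
    plugin = DDPShardedPlugin()
    model = mock.MagicMock()
    with mock.patch.object(plugin, "_model", model):
        with plugin.block_backward_sync():
            pass
    model.no_sync.assert_not_called()
```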