Support Adding DDP Communication Hooks (#6736)

* Fix some test errors

* checkpoint consolidation

* Update ddp_spawn.py

* Update test_metric_result_integration.py

* Update test_results.py

* Update utils.py

* Update utils.py

* Update test_all_gather_grad.py

* Update test_all_gather_grad.py

* Update test_results.py

* Revert "Update test_results.py"

This reverts commit 9d4a2b891d.

* Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkpoint_consolidate"

This reverts commit c5053da789, reversing
changes made to 0d23d75bc9.

* Revert "Update test_all_gather_grad.py"

This reverts commit 0d23d75bc9.

* Revert "Update utils.py"

This reverts commit 70fe5da9c6.

* Revert "Update utils.py"

This reverts commit a9aae99f6e.

* Revert "Update test_results.py"

This reverts commit ea74906878.

* Revert "Update test_metric_result_integration.py"

This reverts commit bf70e431b3.

* Revert "Update ddp_spawn.py"

This reverts commit f17210183b.

* Revert "checkpoint consolidation"

This reverts commit 536c1323b0.

* Revert "Revert "checkpoint consolidation""

This reverts commit 3a9fde915a.

* Revert "Revert "Revert "checkpoint consolidation"""

This reverts commit 7a369f47e1.

* Revert "Revert "Update ddp_spawn.py""

This reverts commit 8222dc98ea.

* Revert "Revert "Update test_metric_result_integration.py""

This reverts commit 6c095b2370.

* Revert "Revert "Update test_results.py""

This reverts commit 250d0aaaa2.

* Revert "Revert "Update utils.py""

This reverts commit 8651d54d79.

* Revert "Revert "Update test_all_gather_grad.py""

This reverts commit dcdcd29731.

* modify distributed environment to make test pass

* add DDP communication hook

* remove test related setting

* remove more test related setting

* fix ddp comm hook util import issue

* comments

* one more fix for test_custom_plugin

* fix ddp spawn

* fix sgd

* address comments and add tests

* 1. add on_gpu check 2. modify test a bit 3. formatting

* formatting nit

* fix conda (Python 3.7, PyTorch 1.7) issue: no torch.distributed.algorithms module

* need at least 1.8.0

* minor fix

* modify changelog

* changelog should link to PR number instead of issue number

* refine the register_ddp_comm_hook docstring: clarify the ddp_comm_wrapper explanation and add PowerSGD state hyperparameters to the example usage

* move single-device check before calling register_ddp_comm_hook

* formatting

* comments

* typo

* pre-commit formatting
Authored by shuyingsunshine21 (2021-04-07 04:35:57 -07:00), committed by GitHub as 313e81638d (parent 86e1d9f759).
9 changed files with 309 additions and 5 deletions.

@@ -81,6 +81,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595))
- Added support for DDP communication hooks ([#6736](https://github.com/PyTorchLightning/pytorch-lightning/pull/6736))
- Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](https://github.com/PyTorchLightning/pytorch-lightning/pull/6677))
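
For context on the changelog entry above, here is a minimal sketch of how the new DDPPlugin arguments added in this PR are meant to be used from the Trainer. It mirrors the tests added later in this commit, assumes at least two GPUs and PyTorch >= 1.8, and uses BoringModel from the test helpers purely as a stand-in for any LightningModule.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from tests.helpers import BoringModel  # stand-in for any LightningModule

# Register the built-in FP16 gradient compression hook through the new
# ddp_comm_hook argument (NCCL backend, single-process-single-device only).
model = BoringModel()
trainer = Trainer(
    gpus=2,
    plugins=[DDPPlugin(ddp_comm_hook=default.fp16_compress_hook)],
    fast_dev_run=True,
)
trainer.fit(model)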


@@ -29,7 +29,12 @@ from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn
from pytorch_lightning.utilities import (
_HYDRA_AVAILABLE,
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
rank_zero_warn,
)
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
@@ -37,6 +42,8 @@ from pytorch_lightning.utilities.seed import seed_everything
if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path
if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
log = logging.getLogger(__name__)
@@ -58,6 +65,9 @@ class DDPPlugin(ParallelPlugin):
num_nodes: int = 1,
cluster_environment: ClusterEnvironment = None,
sync_batchnorm: bool = False,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
**kwargs: Union[Any, Dict[str, Any]],
) -> None:
super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
@@ -70,6 +80,9 @@ class DDPPlugin(ParallelPlugin):
self.task_idx = None
self.node_rank = 0
self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
@property
def root_device(self):
@@ -80,6 +93,10 @@ class DDPPlugin(ParallelPlugin):
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs
@property
def _is_single_process_single_device(self) -> bool:
return True
def setup_environment(self):
# start the other scripts
if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
@@ -218,6 +235,17 @@ class DDPPlugin(ParallelPlugin):
)
self._ddp_kwargs["find_unused_parameters"] = True
def _register_ddp_hooks(self) -> None:
# currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
if (_TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device):
register_ddp_comm_hook(
model=self._model,
ddp_comm_state=self._ddp_comm_state,
ddp_comm_hook=self._ddp_comm_hook,
ddp_comm_wrapper=self._ddp_comm_wrapper,
)
def configure_ddp(self):
self.pre_configure_ddp()
self._model = DistributedDataParallel(
@@ -225,6 +253,7 @@ class DDPPlugin(ParallelPlugin):
device_ids=self.determine_ddp_device_ids(),
**self._ddp_kwargs,
)
self._register_ddp_hooks()
def determine_ddp_device_ids(self):
if self.root_device.type == "cpu":


@@ -59,6 +59,10 @@ class DDP2Plugin(DDPPlugin):
distributed_sampler_kwargs = dict(num_replicas=self.num_nodes, rank=self.global_rank)
return distributed_sampler_kwargs
@property
def _is_single_process_single_device(self) -> bool:
return False
def set_world_ranks(self):
self.local_rank = self.task_idx
self.node_rank = self.cluster_environment.node_rank()


@@ -28,12 +28,15 @@ from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.seed import seed_everything
if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
log = logging.getLogger(__name__)
@@ -47,6 +50,9 @@ class DDPSpawnPlugin(ParallelPlugin):
num_nodes: int = 1,
cluster_environment: ClusterEnvironment = None,
sync_batchnorm: bool = False,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
**kwargs: Union[Any, Dict[str, Any]],
):
super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
@@ -54,9 +60,12 @@ class DDPSpawnPlugin(ParallelPlugin):
self.sync_batchnorm = sync_batchnorm
self._ddp_kwargs = kwargs
self.dist = LightningDistributed()
self.num_processes = len(parallel_devices)
self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
self.node_rank = 0
self.mp_queue = None
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
def __getstate__(self):
""" Makes this plugin pickleable without destroying the queue in the current process. """
@@ -76,9 +85,12 @@ class DDPSpawnPlugin(ParallelPlugin):
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs
@property
def _is_single_process_single_device(self):
return True
def setup(self, model):
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
# pass in a state q
smp = mp.get_context("spawn")
self.mp_queue = smp.SimpleQueue()
@@ -181,6 +193,17 @@ class DDPSpawnPlugin(ParallelPlugin):
)
self._ddp_kwargs["find_unused_parameters"] = True
def _register_ddp_hooks(self) -> None:
# currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
if (_TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device):
register_ddp_comm_hook(
model=self._model,
ddp_comm_state=self._ddp_comm_state,
ddp_comm_hook=self._ddp_comm_hook,
ddp_comm_wrapper=self._ddp_comm_wrapper,
)
def configure_ddp(self):
self.pre_configure_ddp()
self._model = DistributedDataParallel(
@@ -188,6 +211,7 @@ class DDPSpawnPlugin(ParallelPlugin):
device_ids=self.determine_ddp_device_ids(),
**self._ddp_kwargs,
)
self._register_ddp_hooks()
def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
# TODO: this code is duplicated in DDP and DDPSpawn, make this a function


@@ -47,6 +47,8 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401
_RPC_AVAILABLE,
_TORCH_GREATER_EQUAL_1_6,
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
_TORCH_LOWER_EQUAL_1_4,
_TORCH_QUANTIZE_AVAILABLE,
_TORCHTEXT_AVAILABLE,


@@ -17,9 +17,15 @@ import os
import warnings
from functools import partial, wraps
from typing import Any, Optional, Union
from pytorch_lightning.utilities.imports import (
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
)
import torch
from torch.nn.parallel.distributed import DistributedDataParallel
log = logging.getLogger(__name__)
if torch.distributed.is_available():
@@ -208,3 +214,107 @@ def all_gather_ddp_if_available(
with torch.no_grad():
return AllGatherGrad.apply(tensor, group)
return tensor
def register_ddp_comm_hook(
model: DistributedDataParallel,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
) -> None:
"""
Function to register communication hook for DDP model
https://pytorch.org/docs/master/ddp_comm_hooks.html
Args:
model:
DDP model
ddp_comm_state:
state is passed to the hook and can be used to maintain
and update any state information that users would like to
maintain as part of the training process. Examples: error
feedback in gradient compression, peers to communicate with
next in GossipGrad etc.
ddp_comm_hook:
hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future
This callable function is called once the bucket is ready. The
hook can perform whatever processing is needed and return
a Future indicating completion of any async work (ex: allreduce).
If the hook doesn't perform any communication, it can also
just return a completed Future. The Future should hold the
new value of grad bucket's tensors. Once a bucket is ready,
the c10d reducer calls this hook, takes the tensors returned
by the Future, and copies them back into the individual parameter gradients.
ddp_comm_wrapper:
communication hook wrapper to support a communication hook such
as FP16 compression as a wrapper, which can be combined with
ddp_comm_hook
.. warning::
DDP communication hook requires PyTorch 1.8.0 or later.
.. warning::
DDP communication wrapper requires PyTorch 1.9.0 or later.
Example:
from torch.distributed.algorithms.ddp_comm_hooks import (
default_hooks as default,
powerSGD_hook as powerSGD,
)
# fp16_compress_hook for compressing gradients
register_ddp_comm_hook(
model=ddp_model,
ddp_comm_hook=default.fp16_compress_hook,
)
# powerSGD_hook
register_ddp_comm_hook(
model=ddp_model,
ddp_comm_state=powerSGD.PowerSGDState(
process_group=None,
matrix_approximation_rank=1,
start_powerSGD_iter=5000,
),
ddp_comm_hook=powerSGD.powerSGD_hook,
)
# fp16_compress_wrapper combined with other communication hook
register_ddp_comm_hook(
model=ddp_model,
ddp_comm_state=powerSGD.PowerSGDState(
process_group=None,
matrix_approximation_rank=1,
start_powerSGD_iter=5000,
),
ddp_comm_hook=powerSGD.powerSGD_hook,
ddp_comm_wrapper=default.fp16_compress_wrapper,
)
"""
if not _TORCH_GREATER_EQUAL_1_8:
rank_zero_warn(
"Not registering DDP comm hook. "
"To use communication hooks, please use pytorch>=1.8.0."
)
return
if ddp_comm_hook is None:
return
if ddp_comm_wrapper is not None:
if not _TORCH_GREATER_EQUAL_1_9:
rank_zero_warn(
"Not applying DDP comm wrapper. "
"To use communication wrapper, please use pytorch>=1.9.0."
)
else:
rank_zero_info(
f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
)
ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)
rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
model.register_comm_hook(
state=ddp_comm_state,
hook=ddp_comm_hook,
)
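
To complement the built-in hooks used in the examples above, the sketch below shows a custom hook satisfying the hook(state, bucket) -> torch.futures.Future contract described in the docstring. It is illustrative only and not part of this PR: it assumes the PyTorch 1.8 _GradBucket accessor bucket.get_tensors() (renamed in later releases) and simply performs an averaging all-reduce.

import torch.distributed as dist

def averaging_allreduce_hook(state, bucket):
    # All-reduce the bucket's flattened gradients asynchronously and return a
    # Future; the reducer copies the returned tensors back into the param grads.
    tensor = bucket.get_tensors()[0]  # PyTorch 1.8 accessor (assumed; later versions differ)
    fut = dist.all_reduce(tensor, async_op=True).get_future()

    def average(fut):
        # all_reduce sums across ranks, so divide by the world size to get the mean.
        return [fut.value()[0].div_(dist.get_world_size())]

    return fut.then(average)

Such a hook could then be passed as DDPPlugin(ddp_comm_hook=averaging_allreduce_hook) or given directly to register_ddp_comm_hook above.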


@@ -69,6 +69,7 @@ _TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0")
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
_TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0")
_TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
_APEX_AVAILABLE = _module_available("apex.amp")
_BOLTS_AVAILABLE = _module_available('pl_bolts')


@@ -18,7 +18,6 @@ from tests.helpers.runif import RunIf
class CustomParallelPlugin(DDPPlugin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Set to None so it will be overwritten by the accelerator connector.


@@ -0,0 +1,134 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_8:
from torch.distributed.algorithms.ddp_comm_hooks import (
default_hooks as default,
powerSGD_hook as powerSGD,
)
@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2, special=True)
def test_ddp_fp16_compress_comm_hook(tmpdir):
"""Test for DDP FP16 compress hook."""
model = BoringModel()
training_type_plugin = DDPPlugin(
ddp_comm_hook=default.fp16_compress_hook,
sync_batchnorm=True,
)
trainer = Trainer(
max_epochs=1,
gpus=2,
plugins=[training_type_plugin],
default_root_dir=tmpdir,
sync_batchnorm=True,
fast_dev_run=True,
)
trainer.fit(model)
trainer_comm_hook = (
trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
)
expected_comm_hook = default.fp16_compress_hook.__qualname__
assert trainer_comm_hook == expected_comm_hook
assert (
trainer.state == TrainerState.FINISHED
), f"Training failed with {trainer.state}"
@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2, special=True)
def test_ddp_sgd_comm_hook(tmpdir):
"""Test for DDP FP16 compress hook."""
model = BoringModel()
training_type_plugin = DDPPlugin(
ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
ddp_comm_hook=powerSGD.powerSGD_hook,
sync_batchnorm=True,
)
trainer = Trainer(
max_epochs=1,
gpus=2,
plugins=[training_type_plugin],
default_root_dir=tmpdir,
sync_batchnorm=True,
fast_dev_run=True,
)
trainer.fit(model)
trainer_comm_hook = (
trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
)
expected_comm_hook = powerSGD.powerSGD_hook.__qualname__
assert trainer_comm_hook == expected_comm_hook
assert (
trainer.state == TrainerState.FINISHED
), f"Training failed with {trainer.state}"
@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True)
def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
"""Test for DDP FP16 compress wrapper for SGD hook."""
model = BoringModel()
training_type_plugin = DDPPlugin(
ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
ddp_comm_hook=powerSGD.powerSGD_hook,
ddp_comm_wrapper=default.fp16_compress_wrapper,
sync_batchnorm=True,
)
trainer = Trainer(
max_epochs=1,
gpus=2,
plugins=[training_type_plugin],
default_root_dir=tmpdir,
sync_batchnorm=True,
fast_dev_run=True,
)
trainer.fit(model)
trainer_comm_hook = (
trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
)
expected_comm_hook = default.fp16_compress_wrapper(
powerSGD.powerSGD_hook
).__qualname__
assert trainer_comm_hook == expected_comm_hook
assert (
trainer.state == TrainerState.FINISHED
), f"Training failed with {trainer.state}"
@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2, special=True)
def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
"""Test for DDP Spawn FP16 compress hook."""
model = BoringModel()
training_type_plugin = DDPSpawnPlugin(
ddp_comm_hook=default.fp16_compress_hook,
sync_batchnorm=True,
)
trainer = Trainer(
max_epochs=1,
gpus=2,
plugins=[training_type_plugin],
default_root_dir=tmpdir,
sync_batchnorm=True,
fast_dev_run=True,
)
trainer.fit(model)
assert (
trainer.state == TrainerState.FINISHED
), f"Training failed with {trainer.state}"