Supporting Adding DDP Communication Hooks (#6736)
* Fix some test errors Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * checkpoint consolidation * Update ddp_spawn.py * Update test_metric_result_integration.py * Update test_results.py * Update utils.py * Update utils.py * Update test_all_gather_grad.py * Update test_all_gather_grad.py * Update test_results.py * Revert "Update test_results.py" This reverts commit 9d4a2b891d
. * Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkpoint_consolidate" This reverts commit c5053da789
, reversing changes made to 0d23d75bc9
. * Revert "Update test_all_gather_grad.py" This reverts commit 0d23d75bc9
. * Revert "Update utils.py" This reverts commit 70fe5da9c6
. * Revert "Update utils.py" This reverts commit a9aae99f6e
. * Revert "Update test_results.py" This reverts commit ea74906878
. * Revert "Update test_metric_result_integration.py" This reverts commit bf70e431b3
. * Revert "Update ddp_spawn.py" This reverts commit f17210183b
. * Revert "checkpoint consolidation" This reverts commit 536c1323b0
. * Revert "Revert "checkpoint consolidation"" This reverts commit 3a9fde915a
. * Revert "Revert "Revert "checkpoint consolidation""" This reverts commit 7a369f47e1
. * Revert "Revert "Update ddp_spawn.py"" This reverts commit 8222dc98ea
. * Revert "Revert "Update test_metric_result_integration.py"" This reverts commit 6c095b2370
. * Revert "Revert "Update test_results.py"" This reverts commit 250d0aaaa2
. * Revert "Revert "Update utils.py"" This reverts commit 8651d54d79
. * Revert "Revert "Update test_all_gather_grad.py"" This reverts commit dcdcd29731
. * modify distributed environment to make test pass * add DDP communication hook * remove test related setting * remove more test related setting * fix ddp comm hook util import issue * comments * one more fix for test_custom_plugin * fix ddp spwan * fix sgd * address comments and add tests * 1. add is gpu checking 2. modify test a bit 3. formatting * formatting nit * fix conda 3.7 1.7 issue for no torch.distributed.algorithms module * need at least 1.8.0 * minor fix * modify changelog * changelog should link to PR number instead of issue number * refine a bit on doc for register_ddp_comm_hook function, like ddp_comm_wrapper explanation and add hyperparameter for power sgd states in example usge * move single device checking before call register_ddp_comm_hook * formatting * comments * typo * pre-commit formatting
This commit is contained in:
parent
86e1d9f759
commit
313e81638d
|
@ -81,6 +81,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595))
|
||||
|
||||
- Added support for DDP communication hooks ([#6736](https://github.com/PyTorchLightning/pytorch-lightning/pull/6736))
|
||||
|
||||
- Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](https://github.com/PyTorchLightning/pytorch-lightning/pull/6677))
|
||||
|
||||
|
|
|
@ -29,7 +29,12 @@ from pytorch_lightning.overrides import LightningDistributedModule
|
|||
from pytorch_lightning.overrides.distributed import prepare_for_backward
|
||||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
|
||||
from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn
|
||||
from pytorch_lightning.utilities import (
|
||||
_HYDRA_AVAILABLE,
|
||||
_TORCH_GREATER_EQUAL_1_7,
|
||||
_TORCH_GREATER_EQUAL_1_8,
|
||||
rank_zero_warn,
|
||||
)
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.seed import seed_everything
|
||||
|
@ -37,6 +42,8 @@ from pytorch_lightning.utilities.seed import seed_everything
|
|||
if _HYDRA_AVAILABLE:
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from hydra.utils import get_original_cwd, to_absolute_path
|
||||
if _TORCH_GREATER_EQUAL_1_8:
|
||||
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -58,6 +65,9 @@ class DDPPlugin(ParallelPlugin):
|
|||
num_nodes: int = 1,
|
||||
cluster_environment: ClusterEnvironment = None,
|
||||
sync_batchnorm: bool = False,
|
||||
ddp_comm_state: Optional[object] = None,
|
||||
ddp_comm_hook: Optional[callable] = None,
|
||||
ddp_comm_wrapper: Optional[callable] = None,
|
||||
**kwargs: Union[Any, Dict[str, Any]],
|
||||
) -> None:
|
||||
super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
|
||||
|
@ -70,6 +80,9 @@ class DDPPlugin(ParallelPlugin):
|
|||
self.task_idx = None
|
||||
self.node_rank = 0
|
||||
self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
|
||||
self._ddp_comm_state = ddp_comm_state
|
||||
self._ddp_comm_hook = ddp_comm_hook
|
||||
self._ddp_comm_wrapper = ddp_comm_wrapper
|
||||
|
||||
@property
|
||||
def root_device(self):
|
||||
|
@ -80,6 +93,10 @@ class DDPPlugin(ParallelPlugin):
|
|||
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
|
||||
return distributed_sampler_kwargs
|
||||
|
||||
@property
def _is_single_process_single_device(self) -> bool:
    """Flag consulted by ``_register_ddp_hooks`` before a DDP communication hook is registered.

    This plugin always reports ``True``.
    """
    return True
|
||||
|
||||
def setup_environment(self):
|
||||
# start the other scripts
|
||||
if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
|
||||
|
@ -218,6 +235,17 @@ class DDPPlugin(ParallelPlugin):
|
|||
)
|
||||
self._ddp_kwargs["find_unused_parameters"] = True
|
||||
|
||||
def _register_ddp_hooks(self) -> None:
    """Register the communication hook configured on this plugin with the DDP-wrapped model.

    Nothing is registered unless torch>=1.8 is installed, the plugin runs on GPU
    and it operates in single-process-single-device mode.
    """
    # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
    # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
    hooks_supported = _TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device
    if not hooks_supported:
        return

    register_ddp_comm_hook(
        model=self._model,
        ddp_comm_state=self._ddp_comm_state,
        ddp_comm_hook=self._ddp_comm_hook,
        ddp_comm_wrapper=self._ddp_comm_wrapper,
    )
|
||||
|
||||
def configure_ddp(self):
|
||||
self.pre_configure_ddp()
|
||||
self._model = DistributedDataParallel(
|
||||
|
@ -225,6 +253,7 @@ class DDPPlugin(ParallelPlugin):
|
|||
device_ids=self.determine_ddp_device_ids(),
|
||||
**self._ddp_kwargs,
|
||||
)
|
||||
self._register_ddp_hooks()
|
||||
|
||||
def determine_ddp_device_ids(self):
|
||||
if self.root_device.type == "cpu":
|
||||
|
|
|
@ -59,6 +59,10 @@ class DDP2Plugin(DDPPlugin):
|
|||
distributed_sampler_kwargs = dict(num_replicas=self.num_nodes, rank=self.global_rank)
|
||||
return distributed_sampler_kwargs
|
||||
|
||||
@property
def _is_single_process_single_device(self) -> bool:
    """Flag consulted by the inherited ``_register_ddp_hooks`` before registering a DDP communication hook.

    This plugin always reports ``False``, so that check never passes here.
    """
    return False
|
||||
|
||||
def set_world_ranks(self):
|
||||
self.local_rank = self.task_idx
|
||||
self.node_rank = self.cluster_environment.node_rank()
|
||||
|
|
|
@ -28,12 +28,15 @@ from pytorch_lightning.overrides.distributed import prepare_for_backward
|
|||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
|
||||
from pytorch_lightning.trainer.states import TrainerState
|
||||
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
|
||||
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8
|
||||
from pytorch_lightning.utilities.cloud_io import atomic_save
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available
|
||||
from pytorch_lightning.utilities.seed import seed_everything
|
||||
|
||||
if _TORCH_GREATER_EQUAL_1_8:
|
||||
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -47,6 +50,9 @@ class DDPSpawnPlugin(ParallelPlugin):
|
|||
num_nodes: int = 1,
|
||||
cluster_environment: ClusterEnvironment = None,
|
||||
sync_batchnorm: bool = False,
|
||||
ddp_comm_state: Optional[object] = None,
|
||||
ddp_comm_hook: Optional[callable] = None,
|
||||
ddp_comm_wrapper: Optional[callable] = None,
|
||||
**kwargs: Union[Any, Dict[str, Any]],
|
||||
):
|
||||
super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
|
||||
|
@ -54,9 +60,12 @@ class DDPSpawnPlugin(ParallelPlugin):
|
|||
self.sync_batchnorm = sync_batchnorm
|
||||
self._ddp_kwargs = kwargs
|
||||
self.dist = LightningDistributed()
|
||||
self.num_processes = len(parallel_devices)
|
||||
self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
|
||||
self.node_rank = 0
|
||||
self.mp_queue = None
|
||||
self._ddp_comm_state = ddp_comm_state
|
||||
self._ddp_comm_hook = ddp_comm_hook
|
||||
self._ddp_comm_wrapper = ddp_comm_wrapper
|
||||
|
||||
def __getstate__(self):
|
||||
""" Makes this plugin pickleable without destroying the queue in the current process. """
|
||||
|
@ -76,9 +85,12 @@ class DDPSpawnPlugin(ParallelPlugin):
|
|||
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
|
||||
return distributed_sampler_kwargs
|
||||
|
||||
@property
def _is_single_process_single_device(self) -> bool:
    """Flag consulted by ``_register_ddp_hooks`` before a DDP communication hook is registered.

    This plugin always reports ``True``.

    Returns:
        ``True``.
    """
    return True
|
||||
|
||||
def setup(self, model):
|
||||
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
|
||||
|
||||
# pass in a state q
|
||||
smp = mp.get_context("spawn")
|
||||
self.mp_queue = smp.SimpleQueue()
|
||||
|
@ -181,6 +193,17 @@ class DDPSpawnPlugin(ParallelPlugin):
|
|||
)
|
||||
self._ddp_kwargs["find_unused_parameters"] = True
|
||||
|
||||
def _register_ddp_hooks(self) -> None:
    """Register the communication hook configured on this plugin with the DDP-wrapped model.

    Nothing is registered unless torch>=1.8 is installed, the plugin runs on GPU
    and it operates in single-process-single-device mode.
    """
    # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
    # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
    hooks_supported = _TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device
    if not hooks_supported:
        return

    register_ddp_comm_hook(
        model=self._model,
        ddp_comm_state=self._ddp_comm_state,
        ddp_comm_hook=self._ddp_comm_hook,
        ddp_comm_wrapper=self._ddp_comm_wrapper,
    )
|
||||
|
||||
def configure_ddp(self):
|
||||
self.pre_configure_ddp()
|
||||
self._model = DistributedDataParallel(
|
||||
|
@ -188,6 +211,7 @@ class DDPSpawnPlugin(ParallelPlugin):
|
|||
device_ids=self.determine_ddp_device_ids(),
|
||||
**self._ddp_kwargs,
|
||||
)
|
||||
self._register_ddp_hooks()
|
||||
|
||||
def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
|
||||
# TODO: this code is duplicated in DDP and DDPSpawn, make this a function
|
||||
|
|
|
@ -47,6 +47,8 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401
|
|||
_RPC_AVAILABLE,
|
||||
_TORCH_GREATER_EQUAL_1_6,
|
||||
_TORCH_GREATER_EQUAL_1_7,
|
||||
_TORCH_GREATER_EQUAL_1_8,
|
||||
_TORCH_GREATER_EQUAL_1_9,
|
||||
_TORCH_LOWER_EQUAL_1_4,
|
||||
_TORCH_QUANTIZE_AVAILABLE,
|
||||
_TORCHTEXT_AVAILABLE,
|
||||
|
|
|
@ -17,9 +17,15 @@ import os
|
|||
import warnings
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Optional, Union
|
||||
from pytorch_lightning.utilities.imports import (
|
||||
_TORCH_GREATER_EQUAL_1_8,
|
||||
_TORCH_GREATER_EQUAL_1_9,
|
||||
)
|
||||
|
||||
import torch
|
||||
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
if torch.distributed.is_available():
|
||||
|
@ -208,3 +214,107 @@ def all_gather_ddp_if_available(
|
|||
with torch.no_grad():
|
||||
return AllGatherGrad.apply(tensor, group)
|
||||
return tensor
|
||||
|
||||
|
||||
def register_ddp_comm_hook(
    model: DistributedDataParallel,
    ddp_comm_state: Optional[object] = None,
    ddp_comm_hook: Optional[callable] = None,
    ddp_comm_wrapper: Optional[callable] = None,
) -> None:
    """
    Register a communication hook on a DDP model.

    See https://pytorch.org/docs/master/ddp_comm_hooks.html

    Args:
        model:
            DDP model
        ddp_comm_state:
            state is passed to the hook and can be used to maintain
            and update any state information that users would like to
            maintain as part of the training process. Examples: error
            feedback in gradient compression, peers to communicate with
            next in GossipGrad etc.
        ddp_comm_hook:
            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future

            This callable function is called once the bucket is ready. The
            hook can perform whatever processing is needed and return
            a Future indicating completion of any async work (ex: allreduce).
            If the hook doesn't perform any communication, it can also
            just return a completed Future. The Future should hold the
            new value of grad bucket's tensors. Once a bucket is ready,
            c10d reducer would call this hook and use the tensors returned
            by the Future and copy grads to individual parameters.
        ddp_comm_wrapper:
            communication hook wrapper to support a communication hook such
            as FP16 compression as wrapper, which could be combined with
            ddp_comm_hook

    .. warning ::
        DDP communication hook needs pytorch version at least 1.8.0

    .. warning ::
        DDP communication wrapper needs pytorch version at least 1.9.0

    Example:

        from torch.distributed.algorithms.ddp_comm_hooks import (
            default_hooks as default,
            powerSGD_hook as powerSGD,
        )

        # fp16_compress_hook for compress gradients
        register_ddp_comm_hook(
            model=ddp_model,
            ddp_comm_hook=default.fp16_compress_hook,
        )

        # powerSGD_hook
        register_ddp_comm_hook(
            model=ddp_model,
            ddp_comm_state=powerSGD.PowerSGDState(
                process_group=None,
                matrix_approximation_rank=1,
                start_powerSGD_iter=5000,
            ),
            ddp_comm_hook=powerSGD.powerSGD_hook,
        )

        # fp16_compress_wrapper combined with other communication hook
        register_ddp_comm_hook(
            model=ddp_model,
            ddp_comm_state=powerSGD.PowerSGDState(
                process_group=None,
                matrix_approximation_rank=1,
                start_powerSGD_iter=5000,
            ),
            ddp_comm_hook=powerSGD.powerSGD_hook,
            ddp_comm_wrapper=default.fp16_compress_wrapper,
        )
    """
    if not _TORCH_GREATER_EQUAL_1_8:
        rank_zero_warn("Not registering DDP comm hook. " "To use communication hooks, please use pytorch>=1.8.0.")
        return

    # Without a hook there is nothing to register, even if a state or wrapper was passed.
    if ddp_comm_hook is None:
        return

    hook = ddp_comm_hook
    if ddp_comm_wrapper is not None:
        if _TORCH_GREATER_EQUAL_1_9:
            rank_zero_info(
                f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({hook.__qualname__})."
            )
            hook = ddp_comm_wrapper(hook)
        else:
            # The wrapper is dropped and the bare hook is registered instead.
            rank_zero_warn(
                "Not applying DDP comm wrapper. " "To use communication wrapper, please use pytorch>=1.9.0."
            )

    rank_zero_debug(f"Registering DDP comm hook: {hook.__qualname__}.")
    model.register_comm_hook(state=ddp_comm_state, hook=hook)
|
||||
|
|
|
@ -69,6 +69,7 @@ _TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0")
|
|||
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
|
||||
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
|
||||
_TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0")
|
||||
_TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
|
||||
|
||||
_APEX_AVAILABLE = _module_available("apex.amp")
|
||||
_BOLTS_AVAILABLE = _module_available('pl_bolts')
|
||||
|
|
|
@ -18,7 +18,6 @@ from tests.helpers.runif import RunIf
|
|||
|
||||
|
||||
class CustomParallelPlugin(DDPPlugin):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
# Set to None so it will be overwritten by the accelerator connector.
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
|
||||
from pytorch_lightning.trainer.states import TrainerState
|
||||
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
|
||||
from tests.helpers import BoringModel
|
||||
from tests.helpers.runif import RunIf
|
||||
|
||||
if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_8:
|
||||
from torch.distributed.algorithms.ddp_comm_hooks import (
|
||||
default_hooks as default,
|
||||
powerSGD_hook as powerSGD,
|
||||
)
|
||||
|
||||
|
||||
@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2, special=True)
def test_ddp_fp16_compress_comm_hook(tmpdir):
    """Fit with ``DDPPlugin`` and the FP16 compress hook, then check the hook reported by the DDP model."""
    model = BoringModel()
    plugin = DDPPlugin(ddp_comm_hook=default.fp16_compress_hook, sync_batchnorm=True)
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        plugins=[plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)

    # The DDP logging data is expected to report the hook by its qualified name.
    ddp_model = trainer.accelerator.training_type_plugin._model
    registered_hook = ddp_model.get_ddp_logging_data().comm_hook
    assert registered_hook == default.fp16_compress_hook.__qualname__
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
|
||||
|
||||
|
||||
@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2, special=True)
def test_ddp_sgd_comm_hook(tmpdir):
    """Fit with ``DDPPlugin`` and the PowerSGD hook, then check the hook reported by the DDP model."""
    model = BoringModel()
    plugin = DDPPlugin(
        ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
        ddp_comm_hook=powerSGD.powerSGD_hook,
        sync_batchnorm=True,
    )
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        plugins=[plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)

    # The DDP logging data is expected to report the hook by its qualified name.
    ddp_model = trainer.accelerator.training_type_plugin._model
    registered_hook = ddp_model.get_ddp_logging_data().comm_hook
    assert registered_hook == powerSGD.powerSGD_hook.__qualname__
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
|
||||
|
||||
|
||||
@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True)
def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
    """Fit with ``DDPPlugin`` using the FP16 compress wrapper around the PowerSGD hook and check the reported hook."""
    model = BoringModel()
    plugin = DDPPlugin(
        ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
        ddp_comm_hook=powerSGD.powerSGD_hook,
        ddp_comm_wrapper=default.fp16_compress_wrapper,
        sync_batchnorm=True,
    )
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        plugins=[plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)

    ddp_model = trainer.accelerator.training_type_plugin._model
    registered_hook = ddp_model.get_ddp_logging_data().comm_hook
    # The expected name is that of the wrapped hook, i.e. wrapper(hook).
    wrapped_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook)
    assert registered_hook == wrapped_hook.__qualname__
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
|
||||
|
||||
|
||||
@RunIf(skip_windows=True, min_torch="1.8.0", min_gpus=2, special=True)
def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
    """Fit with ``DDPSpawnPlugin`` and the FP16 compress hook and check that training finishes."""
    model = BoringModel()
    plugin = DDPSpawnPlugin(ddp_comm_hook=default.fp16_compress_hook, sync_batchnorm=True)
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        plugins=[plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
|
Loading…
Reference in New Issue