Support post-localSGD in Lightning DDP plugin (#8967)

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
Yi Wang 2021-08-26 00:24:49 -07:00 committed by GitHub
parent 69f66fd6bb
commit 366fb39d2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 105 additions and 1 deletions

View File

@ -29,17 +29,21 @@ import torch
import torch.distributed import torch.distributed
from torch.nn.parallel.distributed import DistributedDataParallel from torch.nn.parallel.distributed import DistributedDataParallel
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import ( from pytorch_lightning.utilities import (
_FAIRSCALE_AVAILABLE,
_HYDRA_AVAILABLE, _HYDRA_AVAILABLE,
_TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9, _TORCH_GREATER_EQUAL_1_9,
_TORCH_GREATER_EQUAL_1_10,
rank_zero_deprecation, rank_zero_deprecation,
rank_zero_warn, rank_zero_warn,
) )
@ -53,11 +57,19 @@ from pytorch_lightning.utilities.distributed import (
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.seed import reset_seed
if _TORCH_GREATER_EQUAL_1_10:
from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
if _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
if _HYDRA_AVAILABLE: if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path from hydra.utils import get_original_cwd, to_absolute_path
if _TORCH_GREATER_EQUAL_1_8: if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
if _TORCH_GREATER_EQUAL_1_10:
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
import torch.distributed.algorithms.model_averaging.averagers as averagers
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -83,6 +95,7 @@ class DDPPlugin(ParallelPlugin):
ddp_comm_state: Optional[object] = None, ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None, ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None,
model_averaging_period: Optional[int] = None,
**kwargs: Union[Any, Dict[str, Any]], **kwargs: Union[Any, Dict[str, Any]],
) -> None: ) -> None:
super().__init__( super().__init__(
@ -110,6 +123,7 @@ class DDPPlugin(ParallelPlugin):
self._ddp_comm_state = ddp_comm_state self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper self._ddp_comm_wrapper = ddp_comm_wrapper
self._model_averaging_period = model_averaging_period
self._pids: Optional[List[int]] = None self._pids: Optional[List[int]] = None
self._sync_dir: Optional[str] = None self._sync_dir: Optional[str] = None
self.set_world_ranks() self.set_world_ranks()
@ -302,6 +316,51 @@ class DDPPlugin(ParallelPlugin):
ddp_comm_wrapper=self._ddp_comm_wrapper, ddp_comm_wrapper=self._ddp_comm_wrapper,
) )
# Post-localSGD is only available after 1.9,
# and `torch.distributed.optim` package currently is not available on Windows.
if (
_TORCH_GREATER_EQUAL_1_10
and isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState)
and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
):
self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int) -> None:
    """Replace the trainer's optimizers with ``PostLocalSGDOptimizer`` wrappers.

    Each optimizer held by the trainer is rebuilt as a ``PostLocalSGDOptimizer`` of the same
    optimizer class, using the original ``param_groups`` and ``defaults``. A single
    ``PeriodicModelAverager`` is created and shared by all wrapped optimizers. Optimizers that
    already are ``PostLocalSGDOptimizer`` instances are kept as they are. Finally the trainer's
    optimizers are converted back to Lightning optimizers.

    Args:
        warmup_steps: Passed to ``PeriodicModelAverager`` as ``warmup_steps``.

    Raises:
        ValueError: If no model averaging period was provided to the plugin, or if one of the
            optimizers is a ``DistributedOptimizer``, a ``ZeroRedundancyOptimizer`` or (when
            fairscale is available) an ``OSS`` optimizer.
    """
    trainer = self.lightning_module.trainer
    optimizers = trainer.optimizers
    if self._model_averaging_period is None:
        raise ValueError(
            "Post-localSGD algorithm is used, but model averaging period is not provided to DDP plugin."
        )
    averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)

    # Distributed optimizer types that must not be wrapped by PostLocalSGDOptimizer.
    # `OSS` is only importable (and therefore only checked) when fairscale is installed.
    unsupported_types = (DistributedOptimizer, ZeroRedundancyOptimizer)
    if _FAIRSCALE_AVAILABLE:
        unsupported_types += (OSS,)

    for idx, optimizer in enumerate(optimizers):
        if isinstance(optimizer, LightningOptimizer):
            # Inspect and re-wrap the underlying optimizer, not the Lightning wrapper.
            optimizer = optimizer._optimizer

        if isinstance(optimizer, unsupported_types):
            # An optimizer *instance* has no `__name__`; the name must come from its class.
            raise ValueError(
                f"Cannot wrap a distributed optimizer of type {type(optimizer).__name__} by PostLocalSGDOptimizer."
            )

        if isinstance(optimizer, PostLocalSGDOptimizer):
            continue

        optimizers[idx] = PostLocalSGDOptimizer(
            params=optimizer.param_groups,
            optimizer_class=type(optimizer),
            averager=averager,
            **optimizer.defaults,
        )

    trainer.optimizers = optimizers
    trainer.convert_to_lightning_optimizers()
def configure_ddp(self): def configure_ddp(self):
self.pre_configure_ddp() self.pre_configure_ddp()
self._model = DistributedDataParallel( self._model = DistributedDataParallel(

View File

@ -284,12 +284,14 @@ def register_ddp_comm_hook(
.. warning :: .. warning ::
DDP communication wrapper needs pytorch version at least 1.9.0 DDP communication wrapper needs pytorch version at least 1.9.0
Post-localSGD hook needs pytorch version at least 1.9.0
Example: Example:
from torch.distributed.algorithms.ddp_comm_hooks import ( from torch.distributed.algorithms.ddp_comm_hooks import (
default_hooks as default, default_hooks as default,
powerSGD_hook as powerSGD, powerSGD_hook as powerSGD,
post_localSGD_hook as post_localSGD,
) )
# fp16_compress_hook for compress gradients # fp16_compress_hook for compress gradients
@ -309,6 +311,18 @@ def register_ddp_comm_hook(
ddp_comm_hook=powerSGD.powerSGD_hook, ddp_comm_hook=powerSGD.powerSGD_hook,
) )
# post_localSGD_hook
subgroup, _ = torch.distributed.new_subgroups()
register_ddp_comm_hook(
model=ddp_model,
state=post_localSGD.PostLocalSGDState(
process_group=None,
subgroup=subgroup,
start_localSGD_iter=1_000,
),
ddp_comm_hook=post_localSGD.post_localSGD_hook,
)
# fp16_compress_wrapper combined with other communication hook # fp16_compress_wrapper combined with other communication hook
register_ddp_comm_hook( register_ddp_comm_hook(
model=ddp_model, model=ddp_model,

View File

@ -15,13 +15,15 @@ import torch
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_10
from tests.helpers import BoringModel from tests.helpers import BoringModel
from tests.helpers.runif import RunIf from tests.helpers.runif import RunIf
if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_8: if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_8:
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_10:
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True) @RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True)
@ -108,3 +110,32 @@ def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
) )
trainer.fit(model) trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.state.finished, f"Training failed with {trainer.state}"
@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, special=True)
def test_ddp_post_local_sgd_comm_hook(tmpdir):
    """Check that the post-localSGD hook is registered on the DDP model after fitting.

    A ``BoringModel`` is fitted on two GPUs in ``fast_dev_run`` mode with a ``DDPPlugin`` that is
    configured with a ``PostLocalSGDState``, the ``post_localSGD_hook`` and a model averaging
    period. The hook name reported by the DDP logging data must match the hook's qualified name,
    and the trainer must report a finished state.
    """
    model = BoringModel()

    # Communication state handed to the plugin together with the hook itself.
    comm_state = post_localSGD.PostLocalSGDState(
        process_group=None,
        subgroup=None,
        start_localSGD_iter=8,
    )
    ddp_plugin = DDPPlugin(
        ddp_comm_state=comm_state,
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
        sync_batchnorm=True,
    )

    trainer = Trainer(
        fast_dev_run=True,
        gpus=2,
        plugins=[ddp_plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
    )
    trainer.fit(model)

    # Read the registered hook name back from the wrapped DDP model.
    ddp_model = trainer.accelerator.training_type_plugin._model
    registered_hook = ddp_model.get_ddp_logging_data().comm_hook
    assert registered_hook == post_localSGD.post_localSGD_hook.__qualname__
    assert trainer.state.finished, f"Training failed with {trainer.state}"