Support post-localSGD in Lightning DDP plugin (#8967)
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
parent 69f66fd6bb
commit 366fb39d2e
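In user code, the feature added by this commit is enabled roughly as follows. This is only a sketch assembled from the new plugin arguments and the test added below; the model and the concrete values (`start_localSGD_iter=100`, `model_averaging_period=4`, `gpus=2`) are placeholders, not part of the change.

import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# Requires torch >= 1.10. Vanilla allreduce DDP runs for the first `start_localSGD_iter`
# steps; after that the post-localSGD hook takes over gradient communication and the
# re-wrapped optimizers average model parameters every `model_averaging_period` steps.
plugin = DDPPlugin(
    ddp_comm_state=post_localSGD.PostLocalSGDState(
        process_group=None,
        subgroup=None,  # torch.distributed.new_subgroups() can supply per-machine subgroups instead
        start_localSGD_iter=100,
    ),
    ddp_comm_hook=post_localSGD.post_localSGD_hook,
    model_averaging_period=4,
)
trainer = Trainer(gpus=2, plugins=[plugin])
trainer.fit(model)  # `model` is any LightningModule (placeholder)
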
@@ -29,17 +29,21 @@ import torch
import torch.distributed
from torch.nn.parallel.distributed import DistributedDataParallel

from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import (
    _FAIRSCALE_AVAILABLE,
    _HYDRA_AVAILABLE,
    _TORCH_GREATER_EQUAL_1_7,
    _TORCH_GREATER_EQUAL_1_8,
    _TORCH_GREATER_EQUAL_1_9,
    _TORCH_GREATER_EQUAL_1_10,
    rank_zero_deprecation,
    rank_zero_warn,
)
@@ -53,11 +57,19 @@ from pytorch_lightning.utilities.distributed import (
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed

if _TORCH_GREATER_EQUAL_1_10:
    from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer

if _FAIRSCALE_AVAILABLE:
    from fairscale.optim import OSS
if _HYDRA_AVAILABLE:
    from hydra.core.hydra_config import HydraConfig
    from hydra.utils import get_original_cwd, to_absolute_path
if _TORCH_GREATER_EQUAL_1_8:
    from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
if _TORCH_GREATER_EQUAL_1_10:
    import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
    import torch.distributed.algorithms.model_averaging.averagers as averagers

log = logging.getLogger(__name__)
@@ -83,6 +95,7 @@ class DDPPlugin(ParallelPlugin):
        ddp_comm_state: Optional[object] = None,
        ddp_comm_hook: Optional[callable] = None,
        ddp_comm_wrapper: Optional[callable] = None,
        model_averaging_period: Optional[int] = None,
        **kwargs: Union[Any, Dict[str, Any]],
    ) -> None:
        super().__init__(
@@ -110,6 +123,7 @@ class DDPPlugin(ParallelPlugin):
        self._ddp_comm_state = ddp_comm_state
        self._ddp_comm_hook = ddp_comm_hook
        self._ddp_comm_wrapper = ddp_comm_wrapper
        self._model_averaging_period = model_averaging_period
        self._pids: Optional[List[int]] = None
        self._sync_dir: Optional[str] = None
        self.set_world_ranks()
@@ -302,6 +316,51 @@ class DDPPlugin(ParallelPlugin):
                ddp_comm_wrapper=self._ddp_comm_wrapper,
            )

            # Post-localSGD is only available after 1.9,
            # and `torch.distributed.optim` package currently is not available on Windows.
            if (
                _TORCH_GREATER_EQUAL_1_10
                and isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState)
                and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
            ):
                self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)

    def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
        optimizers = self.lightning_module.trainer.optimizers
        if self._model_averaging_period is None:
            raise ValueError(
                "Post-localSGD algorithm is used, but model averaging period is not provided to DDP plugin."
            )
        averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
        for x, optimizer in enumerate(optimizers):
            if isinstance(optimizer, LightningOptimizer):
                optimizer = optimizer._optimizer

            if (
                isinstance(optimizer, DistributedOptimizer)
                or isinstance(optimizer, ZeroRedundancyOptimizer)
                or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
            ):
                raise ValueError(
                    f"Cannot wrap a distributed optimizer of type {optimizer.__class__.__name__} by PostLocalSGDOptimizer."
                )

            if isinstance(optimizer, PostLocalSGDOptimizer):
                continue

            optim_class = type(optimizer)
            post_localSGD_optimizer = PostLocalSGDOptimizer(
                params=optimizer.param_groups,
                optimizer_class=optim_class,
                averager=averager,
                **optimizer.defaults,
            )
            optimizers[x] = post_localSGD_optimizer
            del optimizer
        trainer = self.lightning_module.trainer
        trainer.optimizers = optimizers
        trainer.convert_to_lightning_optimizers()

    def configure_ddp(self):
        self.pre_configure_ddp()
        self._model = DistributedDataParallel(
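The optimizer re-wrapping in `_reinit_optimizers_with_post_localSGD` above amounts, roughly, to the following plain-PyTorch sketch (torch >= 1.10). The `Linear` module, the SGD optimizer, and the numeric values are illustrative stand-ins, not part of the commit.

import torch
from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager
from torch.distributed.optim import PostLocalSGDOptimizer

module = torch.nn.Linear(32, 2)  # stand-in for the LightningModule's parameters
base_optimizer = torch.optim.SGD(module.parameters(), lr=0.01)  # whatever optimizer the user configured

# period maps to `model_averaging_period`, warmup_steps to `start_localSGD_iter`
averager = PeriodicModelAverager(period=4, warmup_steps=100)

optimizer = PostLocalSGDOptimizer(
    params=base_optimizer.param_groups,    # reuse the existing parameter groups
    optimizer_class=type(base_optimizer),  # keep the same local optimizer class
    averager=averager,
    **base_optimizer.defaults,             # and its defaults (lr, momentum, ...)
)
# optimizer.step() now performs the local optimizer step and, once past the warmup
# iterations, periodically averages the parameters across the process group.
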
@@ -284,12 +284,14 @@ def register_ddp_comm_hook(

    .. warning ::
        DDP communication wrapper needs pytorch version at least 1.9.0
        Post-localSGD hook needs pytorch version at least 1.9.0

    Example:

        from torch.distributed.algorithms.ddp_comm_hooks import (
            default_hooks as default,
            powerSGD_hook as powerSGD,
            post_localSGD_hook as post_localSGD,
        )

        # fp16_compress_hook for compress gradients
@@ -309,6 +311,18 @@ def register_ddp_comm_hook(
            ddp_comm_hook=powerSGD.powerSGD_hook,
        )

        # post_localSGD_hook
        subgroup, _ = torch.distributed.new_subgroups()
        register_ddp_comm_hook(
            model=ddp_model,
            ddp_comm_state=post_localSGD.PostLocalSGDState(
                process_group=None,
                subgroup=subgroup,
                start_localSGD_iter=1_000,
            ),
            ddp_comm_hook=post_localSGD.post_localSGD_hook,
        )

        # fp16_compress_wrapper combined with other communication hook
        register_ddp_comm_hook(
            model=ddp_model,
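For reference, the `register_ddp_comm_hook` helper documented above ultimately forwards to DDP's own `register_comm_hook` method; with the post-localSGD state, the underlying call looks roughly like this (assuming `ddp_model` is an already constructed `DistributedDataParallel` module in an initialized process group):

import torch.distributed
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

subgroup, _ = torch.distributed.new_subgroups()  # intra-machine subgroups, as in the docstring example
state = post_localSGD.PostLocalSGDState(process_group=None, subgroup=subgroup, start_localSGD_iter=1_000)
ddp_model.register_comm_hook(state, post_localSGD.post_localSGD_hook)
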
@@ -15,13 +15,15 @@ import torch

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_10
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_8:
    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
    from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_10:
    import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD


@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True)
@@ -108,3 +110,32 @@ def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
    )
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"


@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, special=True)
def test_ddp_post_local_sgd_comm_hook(tmpdir):
    """Test for DDP post-localSGD hook."""
    model = BoringModel()

    training_type_plugin = DDPPlugin(
        ddp_comm_state=post_localSGD.PostLocalSGDState(
            process_group=None,
            subgroup=None,
            start_localSGD_iter=8,
        ),
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
        sync_batchnorm=True,
    )
    trainer = Trainer(
        fast_dev_run=True,
        gpus=2,
        plugins=[training_type_plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
    )
    trainer.fit(model)
    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
    expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state.finished, f"Training failed with {trainer.state}"