Support post-localSGD in Lightning DDP plugin (#8967)
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
69f66fd6bb
commit
366fb39d2e
|
@ -29,17 +29,21 @@ import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||||
|
|
||||||
|
from pytorch_lightning.core.optimizer import LightningOptimizer
|
||||||
from pytorch_lightning.distributed import LightningDistributed
|
from pytorch_lightning.distributed import LightningDistributed
|
||||||
from pytorch_lightning.overrides import LightningDistributedModule
|
from pytorch_lightning.overrides import LightningDistributedModule
|
||||||
from pytorch_lightning.overrides.distributed import prepare_for_backward
|
from pytorch_lightning.overrides.distributed import prepare_for_backward
|
||||||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
||||||
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||||
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
|
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
|
||||||
|
from pytorch_lightning.trainer.states import TrainerFn
|
||||||
from pytorch_lightning.utilities import (
|
from pytorch_lightning.utilities import (
|
||||||
|
_FAIRSCALE_AVAILABLE,
|
||||||
_HYDRA_AVAILABLE,
|
_HYDRA_AVAILABLE,
|
||||||
_TORCH_GREATER_EQUAL_1_7,
|
_TORCH_GREATER_EQUAL_1_7,
|
||||||
_TORCH_GREATER_EQUAL_1_8,
|
_TORCH_GREATER_EQUAL_1_8,
|
||||||
_TORCH_GREATER_EQUAL_1_9,
|
_TORCH_GREATER_EQUAL_1_9,
|
||||||
|
_TORCH_GREATER_EQUAL_1_10,
|
||||||
rank_zero_deprecation,
|
rank_zero_deprecation,
|
||||||
rank_zero_warn,
|
rank_zero_warn,
|
||||||
)
|
)
|
||||||
|
@ -53,11 +57,19 @@ from pytorch_lightning.utilities.distributed import (
|
||||||
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
|
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
|
||||||
from pytorch_lightning.utilities.seed import reset_seed
|
from pytorch_lightning.utilities.seed import reset_seed
|
||||||
|
|
||||||
|
if _TORCH_GREATER_EQUAL_1_10:
|
||||||
|
from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
|
||||||
|
|
||||||
|
if _FAIRSCALE_AVAILABLE:
|
||||||
|
from fairscale.optim import OSS
|
||||||
if _HYDRA_AVAILABLE:
|
if _HYDRA_AVAILABLE:
|
||||||
from hydra.core.hydra_config import HydraConfig
|
from hydra.core.hydra_config import HydraConfig
|
||||||
from hydra.utils import get_original_cwd, to_absolute_path
|
from hydra.utils import get_original_cwd, to_absolute_path
|
||||||
if _TORCH_GREATER_EQUAL_1_8:
|
if _TORCH_GREATER_EQUAL_1_8:
|
||||||
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
|
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
|
||||||
|
if _TORCH_GREATER_EQUAL_1_10:
|
||||||
|
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
|
||||||
|
import torch.distributed.algorithms.model_averaging.averagers as averagers
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -83,6 +95,7 @@ class DDPPlugin(ParallelPlugin):
|
||||||
ddp_comm_state: Optional[object] = None,
|
ddp_comm_state: Optional[object] = None,
|
||||||
ddp_comm_hook: Optional[callable] = None,
|
ddp_comm_hook: Optional[callable] = None,
|
||||||
ddp_comm_wrapper: Optional[callable] = None,
|
ddp_comm_wrapper: Optional[callable] = None,
|
||||||
|
model_averaging_period: Optional[int] = None,
|
||||||
**kwargs: Union[Any, Dict[str, Any]],
|
**kwargs: Union[Any, Dict[str, Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
@ -110,6 +123,7 @@ class DDPPlugin(ParallelPlugin):
|
||||||
self._ddp_comm_state = ddp_comm_state
|
self._ddp_comm_state = ddp_comm_state
|
||||||
self._ddp_comm_hook = ddp_comm_hook
|
self._ddp_comm_hook = ddp_comm_hook
|
||||||
self._ddp_comm_wrapper = ddp_comm_wrapper
|
self._ddp_comm_wrapper = ddp_comm_wrapper
|
||||||
|
self._model_averaging_period = model_averaging_period
|
||||||
self._pids: Optional[List[int]] = None
|
self._pids: Optional[List[int]] = None
|
||||||
self._sync_dir: Optional[str] = None
|
self._sync_dir: Optional[str] = None
|
||||||
self.set_world_ranks()
|
self.set_world_ranks()
|
||||||
|
@ -302,6 +316,51 @@ class DDPPlugin(ParallelPlugin):
|
||||||
ddp_comm_wrapper=self._ddp_comm_wrapper,
|
ddp_comm_wrapper=self._ddp_comm_wrapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Post-localSGD is only available in PyTorch 1.10 and later,
|
||||||
|
# and `torch.distributed.optim` package currently is not available on Windows.
|
||||||
|
if (
|
||||||
|
_TORCH_GREATER_EQUAL_1_10
|
||||||
|
and isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState)
|
||||||
|
and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
|
||||||
|
):
|
||||||
|
self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
|
||||||
|
|
||||||
|
def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
|
||||||
|
optimizers = self.lightning_module.trainer.optimizers
|
||||||
|
if self._model_averaging_period is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Post-localSGD algorithm is used, " "but model averaging period is not provided to DDP plugin."
|
||||||
|
)
|
||||||
|
averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
|
||||||
|
for x, optimizer in enumerate(optimizers):
|
||||||
|
if isinstance(optimizer, LightningOptimizer):
|
||||||
|
optimizer = optimizer._optimizer
|
||||||
|
|
||||||
|
if (
|
||||||
|
isinstance(optimizer, DistributedOptimizer)
|
||||||
|
or isinstance(optimizer, ZeroRedundancyOptimizer)
|
||||||
|
or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot wrap a distributed optimizer of type {optimizer.__name__} by PostLocalSGDOptimizer."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(optimizer, PostLocalSGDOptimizer):
|
||||||
|
continue
|
||||||
|
|
||||||
|
optim_class = type(optimizer)
|
||||||
|
post_localSGD_optimizer = PostLocalSGDOptimizer(
|
||||||
|
params=optimizer.param_groups,
|
||||||
|
optimizer_class=optim_class,
|
||||||
|
averager=averager,
|
||||||
|
**optimizer.defaults,
|
||||||
|
)
|
||||||
|
optimizers[x] = post_localSGD_optimizer
|
||||||
|
del optimizer
|
||||||
|
trainer = self.lightning_module.trainer
|
||||||
|
trainer.optimizers = optimizers
|
||||||
|
trainer.convert_to_lightning_optimizers()
|
||||||
|
|
||||||
def configure_ddp(self):
|
def configure_ddp(self):
|
||||||
self.pre_configure_ddp()
|
self.pre_configure_ddp()
|
||||||
self._model = DistributedDataParallel(
|
self._model = DistributedDataParallel(
|
||||||
|
|
|
@ -284,12 +284,14 @@ def register_ddp_comm_hook(
|
||||||
|
|
||||||
.. warning ::
|
.. warning ::
|
||||||
DDP communication wrapper needs pytorch version at least 1.9.0
|
DDP communication wrapper needs pytorch version at least 1.9.0
|
||||||
|
Post-localSGD hook needs pytorch version at least 1.9.0
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
from torch.distributed.algorithms.ddp_comm_hooks import (
|
from torch.distributed.algorithms.ddp_comm_hooks import (
|
||||||
default_hooks as default,
|
default_hooks as default,
|
||||||
powerSGD_hook as powerSGD,
|
powerSGD_hook as powerSGD,
|
||||||
|
post_localSGD_hook as post_localSGD,
|
||||||
)
|
)
|
||||||
|
|
||||||
# fp16_compress_hook for compress gradients
|
# fp16_compress_hook for compress gradients
|
||||||
|
@ -309,6 +311,18 @@ def register_ddp_comm_hook(
|
||||||
ddp_comm_hook=powerSGD.powerSGD_hook,
|
ddp_comm_hook=powerSGD.powerSGD_hook,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# post_localSGD_hook
|
||||||
|
subgroup, _ = torch.distributed.new_subgroups()
|
||||||
|
register_ddp_comm_hook(
|
||||||
|
model=ddp_model,
|
||||||
|
state=post_localSGD.PostLocalSGDState(
|
||||||
|
process_group=None,
|
||||||
|
subgroup=subgroup,
|
||||||
|
start_localSGD_iter=1_000,
|
||||||
|
),
|
||||||
|
ddp_comm_hook=post_localSGD.post_localSGD_hook,
|
||||||
|
)
|
||||||
|
|
||||||
# fp16_compress_wrapper combined with other communication hook
|
# fp16_compress_wrapper combined with other communication hook
|
||||||
register_ddp_comm_hook(
|
register_ddp_comm_hook(
|
||||||
model=ddp_model,
|
model=ddp_model,
|
||||||
|
|
|
@ -15,13 +15,15 @@ import torch
|
||||||
|
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
|
from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
|
||||||
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
|
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_10
|
||||||
from tests.helpers import BoringModel
|
from tests.helpers import BoringModel
|
||||||
from tests.helpers.runif import RunIf
|
from tests.helpers.runif import RunIf
|
||||||
|
|
||||||
if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_8:
|
if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_8:
|
||||||
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
|
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
|
||||||
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
|
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
|
||||||
|
if torch.distributed.is_available() and _TORCH_GREATER_EQUAL_1_10:
|
||||||
|
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
|
||||||
|
|
||||||
|
|
||||||
@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True)
|
@RunIf(skip_windows=True, min_torch="1.9.0", min_gpus=2, special=True)
|
||||||
|
@ -108,3 +110,32 @@ def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
|
||||||
)
|
)
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
assert trainer.state.finished, f"Training failed with {trainer.state}"
|
assert trainer.state.finished, f"Training failed with {trainer.state}"
|
||||||
|
|
||||||
|
|
||||||
|
@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, special=True)
def test_ddp_post_local_sgd_comm_hook(tmpdir):
    """Fit a ``BoringModel`` with the post-localSGD hook and check that DDP reports that hook."""
    boring_model = BoringModel()

    # Hook state: local SGD starts at iteration 8, model averaging runs with period 4.
    local_sgd_state = post_localSGD.PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=8)
    ddp_plugin = DDPPlugin(
        ddp_comm_state=local_sgd_state,
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
        sync_batchnorm=True,
    )

    trainer_kwargs = dict(
        fast_dev_run=True,
        gpus=2,
        plugins=[ddp_plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(boring_model)

    # The hook name recorded in the DDP logging data must match the hook that was passed in.
    ddp_logging_data = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data()
    assert ddp_logging_data.comm_hook == post_localSGD.post_localSGD_hook.__qualname__
    assert trainer.state.finished, f"Training failed with {trainer.state}"
|
||||||
|
|
Loading…
Reference in New Issue