Replace PostLocalSGDOptimizer with a dedicated model averaging component (#12378)

Danielle Pintz 2022-03-24 20:33:19 -04:00 committed by GitHub
parent ec7fa1e2d8
commit 6329be60be
5 changed files with 157 additions and 29 deletions

View File

@@ -402,6 +402,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `Timer.on_save_checkpoint` and `Timer.on_load_checkpoint` in favor of `Timer.state_dict` and `Timer.load_state_dict` ([#11887](https://github.com/PyTorchLightning/pytorch-lightning/pull/11887))
- Replaced PostLocalSGDOptimizer with a dedicated model averaging component ([#12378](https://github.com/PyTorchLightning/pytorch-lightning/pull/12378))
### Deprecated
- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))

View File

@@ -749,10 +749,7 @@ Enable `FP16 Compress Hook for multi-node throughput improvement <https://pytorc
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import (
default_hooks as default,
powerSGD_hook as powerSGD,
)
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPStrategy(ddp_comm_hook=default.fp16_compress_hook))
@@ -816,6 +813,33 @@ Combine hooks for accumulated benefit:
)
trainer.fit(model)
When using Post-localSGD, you must also pass ``model_averaging_period`` to allow for model parameter averaging:
.. note::
Post-localSGD support requires PyTorch>=1.10.0
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD
model = MyModel()
trainer = Trainer(
gpus=4,
strategy=DDPStrategy(
ddp_comm_state=post_localSGD.PostLocalSGDState(
process_group=None,
subgroup=None,
start_localSGD_iter=8,
),
ddp_comm_hook=post_localSGD.post_localSGD_hook,
model_averaging_period=4,
),
)
trainer.fit(model)
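Under the hood, the strategy no longer wraps your optimizer in a ``PostLocalSGDOptimizer``; it instead keeps a dedicated ``PeriodicModelAverager`` and averages the model parameters after each ``optimizer.step()`` once the ``start_localSGD_iter`` warm-up has passed. The snippet below is a minimal plain-PyTorch sketch of that pattern (not Lightning code); it assumes a process group is already initialized and uses placeholder ``model``, ``optimizer`` and ``dataloader`` objects:

.. code-block:: python

    import torch.distributed.algorithms.model_averaging.averagers as averagers

    # Average every 4 steps once 8 warm-up iterations have run (PyTorch >= 1.10).
    averager = averagers.PeriodicModelAverager(period=4, warmup_steps=8)

    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(batch).sum()  # placeholder loss computation
        loss.backward()
        optimizer.step()
        # No-op during warm-up, then an all-reduce based average every ``period`` steps.
        averager.average_parameters(model.parameters())

Compared with wrapping the optimizer, driving a standalone averager from the strategy leaves your configured optimizer untouched.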
DDP Static Graph
""""""""""""""""

View File

@@ -18,12 +18,13 @@ import signal
import tempfile
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
import torch
import torch.distributed
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim.optimizer import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -59,6 +60,8 @@ if _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
if _TORCH_GREATER_EQUAL_1_10:
from torch.distributed.algorithms.model_averaging.averagers import ModelAverager
log = logging.getLogger(__name__)
@@ -97,6 +100,7 @@ class DDPStrategy(ParallelStrategy):
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
self._model_averaging_period = model_averaging_period
self._model_averager: Optional[ModelAverager] = None
self._pids: Optional[List[int]] = None
self._sync_dir: Optional[str] = None
self._rank_0_will_call_children_scripts: bool = False
@@ -223,23 +227,18 @@
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
self._enable_model_averaging()
def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
def _enable_model_averaging(self) -> None:
# Only called when PyTorch version >= 1.10
log.detail(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")
optimizers = self.optimizers
if self._model_averaging_period is None:
raise ValueError(
"Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy."
)
if _TORCH_GREATER_EQUAL_1_10:
if not _IS_WINDOWS:
from torch.distributed.optim import DistributedOptimizer
import torch.distributed.algorithms.model_averaging.averagers as averagers
from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
for x, optimizer in enumerate(optimizers):
for optimizer in self.optimizers:
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer
@@ -248,24 +247,46 @@
is_distributed_optimizer
or isinstance(optimizer, ZeroRedundancyOptimizer)
or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
or isinstance(optimizer, PostLocalSGDOptimizer)
):
raise ValueError(
f"Cannot wrap a distributed optimizer of type {optimizer.__name__} by PostLocalSGDOptimizer."
f"Currently model averaging cannot work with a distributed optimizer of type "
f"{optimizer.__class__.__name__}."
)
if isinstance(optimizer, PostLocalSGDOptimizer):
continue
self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager(
period=self._model_averaging_period, warmup_steps=self._ddp_comm_state.start_localSGD_iter
)
optim_class = type(optimizer)
post_localSGD_optimizer = PostLocalSGDOptimizer(
params=optimizer.param_groups,
optimizer_class=optim_class,
averager=averager,
**optimizer.defaults,
)
optimizers[x] = post_localSGD_optimizer
del optimizer
self.optimizers = optimizers
def optimizer_step(
self,
optimizer: Optimizer,
opt_idx: int,
closure: Callable[[], Any],
model: Optional[Union["pl.LightningModule", Module]] = None,
**kwargs: Any,
) -> Any:
"""Performs the actual optimizer step.
Args:
optimizer: the optimizer performing the step
opt_idx: index of the current optimizer
closure: closure calculating the loss value
model: reference to the model, optionally defining optimizer step related hooks
**kwargs: Any extra arguments to ``optimizer.step``
"""
optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
if not _TORCH_GREATER_EQUAL_1_10 or self._model_averager is None:
return optimizer_output
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is None:
continue
self._model_averager.average_parameters(iter(param))
return optimizer_output
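For reference, the averaging performed after the optimizer step can be reproduced outside the strategy roughly as follows. This is an illustrative sketch rather than part of the file above: it assumes PyTorch >= 1.10 with an initialized process group, takes placeholder optimizer and averager objects, and averages the gradient-carrying parameters in a single call instead of once per parameter.
import torch
from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager
def average_after_step(optimizer: torch.optim.Optimizer, averager: PeriodicModelAverager) -> None:
    # Skip parameters that received no gradient this step (e.g. frozen layers).
    params = [p for group in optimizer.param_groups for p in group["params"] if p.grad is not None]
    # A no-op during the warm-up phase; afterwards an all-reduce average every `period` steps.
    averager.average_parameters(iter(params))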
def configure_ddp(self) -> None:
log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")

View File

@@ -181,7 +181,7 @@ class Strategy(ABC):
model: Optional[Union["pl.LightningModule", Module]] = None,
**kwargs: Any,
) -> Any:
"""performs the actual optimizer step.
"""Performs the actual optimizer step.
Args:
optimizer: the optimizer performing the step

View File

@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
import pytest
import torch
from pytorch_lightning import Trainer
@@ -136,3 +139,80 @@ def test_ddp_post_local_sgd_comm_hook(tmpdir):
expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
assert trainer_comm_hook == expected_comm_hook
assert trainer.state.finished, f"Training failed with {trainer.state}"
@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, standalone=True)
@mock.patch("torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager.average_parameters")
def test_post_local_sgd_model_averaging(average_parameters_mock, tmpdir):
"""Test that when using DDP with post-localSGD, model averaging is called."""
model = BoringModel()
# test regular ddp does not call model averaging
trainer = Trainer(
fast_dev_run=True,
gpus=2,
strategy="ddp",
default_root_dir=tmpdir,
sync_batchnorm=True,
)
trainer.fit(model)
average_parameters_mock.assert_not_called()
# test ddp with post-localSGD does call model averaging
ddp_strategy = DDPStrategy(
ddp_comm_state=post_localSGD.PostLocalSGDState(
process_group=None,
subgroup=None,
start_localSGD_iter=8,
),
ddp_comm_hook=post_localSGD.post_localSGD_hook,
model_averaging_period=4,
)
trainer = Trainer(
fast_dev_run=True,
gpus=2,
strategy=ddp_strategy,
default_root_dir=tmpdir,
sync_batchnorm=True,
)
trainer.fit(model)
average_parameters_mock.assert_called()
@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, standalone=True)
@mock.patch("torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager.average_parameters")
def test_post_local_sgd_model_averaging_value_error(average_parameters_mock, tmpdir):
"""Test that when using DDP with post-localSGD a ValueError is thrown when the optmizer is
ZeroRedundancyOptimizer."""
from torch.distributed.optim import ZeroRedundancyOptimizer
class OptimizerModel(BoringModel):
def configure_optimizers(self):
return ZeroRedundancyOptimizer(params=self.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
model = OptimizerModel()
strategy = DDPStrategy(
ddp_comm_state=post_localSGD.PostLocalSGDState(
process_group=None,
subgroup=None,
start_localSGD_iter=8,
),
ddp_comm_hook=post_localSGD.post_localSGD_hook,
model_averaging_period=4,
)
trainer = Trainer(
fast_dev_run=True,
gpus=2,
strategy=strategy,
default_root_dir=tmpdir,
sync_batchnorm=True,
)
with pytest.raises(ValueError, match="Currently model averaging cannot work with a distributed optimizer"):
trainer.fit(model)
average_parameters_mock.assert_not_called()