Replace PostLocalSGDOptimizer with a dedicated model averaging component (#12378)
parent ec7fa1e2d8
commit 6329be60be
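At a high level, `DDPStrategy` previously re-created each user optimizer as a `torch.distributed.optim.PostLocalSGDOptimizer`, which performed the periodic parameter averaging inside its `step()`. With this commit the user's optimizer is left untouched and the strategy keeps its own `torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`, invoking it after every optimizer step. A minimal sketch of that flow outside of Lightning (the function and variable names below are illustrative, not the strategy's internals; assumes PyTorch >= 1.10 and an already initialized process group):

```python
from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager


def training_step_with_averaging(model, optimizer, averager, loss):
    # Plain backward/step with the optimizer the user configured;
    # no PostLocalSGDOptimizer wrapping is involved anymore.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Average parameters across workers after the step. The averager keeps its own
    # step counter and only averages every `period` steps once `warmup_steps`
    # (start_localSGD_iter) have passed.
    averager.average_parameters(model.parameters())


# e.g. averager = PeriodicModelAverager(period=4, warmup_steps=8)
```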
@@ -402,6 +402,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed `Timer.on_save_checkpoint` and `Timer.on_load_checkpoint` in favor of `Timer.state_dict` and `Timer.load_state_dict` ([#11887](https://github.com/PyTorchLightning/pytorch-lightning/pull/11887))

- Replaced PostLocalSGDOptimizer with a dedicated model averaging component ([#12378](https://github.com/PyTorchLightning/pytorch-lightning/pull/12378))

### Deprecated

- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
@@ -749,10 +749,7 @@ Enable `FP16 Compress Hook for multi-node throughput improvement <https://pytorc

    from pytorch_lightning import Trainer
    from pytorch_lightning.strategies import DDPStrategy
    from torch.distributed.algorithms.ddp_comm_hooks import (
        default_hooks as default,
        powerSGD_hook as powerSGD,
    )
    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

    model = MyModel()
    trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPStrategy(ddp_comm_hook=default.fp16_compress_hook))
@@ -816,6 +813,33 @@ Combine hooks for accumulated benefit:

    )
    trainer.fit(model)


When using Post-localSGD, you must also pass ``model_averaging_period`` to allow for model parameter averaging:

.. note::
    Post-localSGD support requires PyTorch>=1.10.0

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.strategies import DDPStrategy
    from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD

    model = MyModel()
    trainer = Trainer(
        gpus=4,
        strategy=DDPStrategy(
            ddp_comm_state=post_localSGD.PostLocalSGDState(
                process_group=None,
                subgroup=None,
                start_localSGD_iter=8,
            ),
            ddp_comm_hook=post_localSGD.post_localSGD_hook,
            model_averaging_period=4,
        ),
    )
    trainer.fit(model)
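For context, the Lightning example above corresponds roughly to the following raw-PyTorch setup (a sketch only; it assumes an already initialized process group, and `make_model`, `loader`, `device`, the SGD optimizer, and the MSE loss are placeholders, not anything defined by Lightning or this commit):

```python
import torch
from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD
from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager
from torch.nn.parallel import DistributedDataParallel


def train_post_localsgd(make_model, loader, device):
    model = DistributedDataParallel(make_model().to(device), device_ids=[device])
    # After `start_localSGD_iter` steps the hook no longer all-reduces gradients across
    # every worker (by default only within each machine's subgroup), so parameters have
    # to be kept in sync by periodic averaging instead.
    state = post_localSGD.PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=8)
    model.register_comm_hook(state, post_localSGD.post_localSGD_hook)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    averager = PeriodicModelAverager(period=4, warmup_steps=8)

    for batch, target in loader:
        loss = torch.nn.functional.mse_loss(model(batch.to(device)), target.to(device))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Averages every `period` steps once the warmup window has passed.
        averager.average_parameters(model.parameters())
```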
DDP Static Graph
""""""""""""""""
@@ -18,12 +18,13 @@ import signal
import tempfile
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.distributed
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim.optimizer import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -59,6 +60,8 @@ if _FAIRSCALE_AVAILABLE:
    from fairscale.optim import OSS
if _TORCH_GREATER_EQUAL_1_8:
    from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
if _TORCH_GREATER_EQUAL_1_10:
    from torch.distributed.algorithms.model_averaging.averagers import ModelAverager


log = logging.getLogger(__name__)
@@ -97,6 +100,7 @@ class DDPStrategy(ParallelStrategy):
        self._ddp_comm_hook = ddp_comm_hook
        self._ddp_comm_wrapper = ddp_comm_wrapper
        self._model_averaging_period = model_averaging_period
        self._model_averager: Optional[ModelAverager] = None
        self._pids: Optional[List[int]] = None
        self._sync_dir: Optional[str] = None
        self._rank_0_will_call_children_scripts: bool = False
@@ -223,23 +227,18 @@ class DDPStrategy(ParallelStrategy):
                import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

                if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
                    self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
                    self._enable_model_averaging()

    def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
    def _enable_model_averaging(self) -> None:
        # Only called when PyTorch version >= 1.10
        log.detail(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")
        optimizers = self.optimizers
        if self._model_averaging_period is None:
            raise ValueError(
                "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy."
            )
        if _TORCH_GREATER_EQUAL_1_10:
            if not _IS_WINDOWS:
                from torch.distributed.optim import DistributedOptimizer
            import torch.distributed.algorithms.model_averaging.averagers as averagers
            from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
        from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer

        averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
        for x, optimizer in enumerate(optimizers):
        for optimizer in self.optimizers:
            if isinstance(optimizer, LightningOptimizer):
                optimizer = optimizer._optimizer

@@ -248,24 +247,46 @@ class DDPStrategy(ParallelStrategy):
                is_distributed_optimizer
                or isinstance(optimizer, ZeroRedundancyOptimizer)
                or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
                or isinstance(optimizer, PostLocalSGDOptimizer)
            ):
                raise ValueError(
                    f"Cannot wrap a distributed optimizer of type {optimizer.__name__} by PostLocalSGDOptimizer."
                    f"Currently model averaging cannot work with a distributed optimizer of type "
                    f"{optimizer.__class__.__name__}."
                )

            if isinstance(optimizer, PostLocalSGDOptimizer):
                continue
        self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager(
            period=self._model_averaging_period, warmup_steps=self._ddp_comm_state.start_localSGD_iter
        )

            optim_class = type(optimizer)
            post_localSGD_optimizer = PostLocalSGDOptimizer(
                params=optimizer.param_groups,
                optimizer_class=optim_class,
                averager=averager,
                **optimizer.defaults,
            )
            optimizers[x] = post_localSGD_optimizer
            del optimizer
        self.optimizers = optimizers

    def optimizer_step(
        self,
        optimizer: Optimizer,
        opt_idx: int,
        closure: Callable[[], Any],
        model: Optional[Union["pl.LightningModule", Module]] = None,
        **kwargs: Any,
    ) -> Any:
        """Performs the actual optimizer step.

        Args:
            optimizer: the optimizer performing the step
            opt_idx: index of the current optimizer
            closure: closure calculating the loss value
            model: reference to the model, optionally defining optimizer step related hooks
            **kwargs: Any extra arguments to ``optimizer.step``
        """
        optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)

        if not _TORCH_GREATER_EQUAL_1_10 or self._model_averager is None:
            return optimizer_output

        for group in optimizer.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                self._model_averager.average_parameters(iter(param))

        return optimizer_output

    def configure_ddp(self) -> None:
        log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
@@ -181,7 +181,7 @@ class Strategy(ABC):
        model: Optional[Union["pl.LightningModule", Module]] = None,
        **kwargs: Any,
    ) -> Any:
        """performs the actual optimizer step.
        """Performs the actual optimizer step.

        Args:
            optimizer: the optimizer performing the step
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

import pytest
import torch

from pytorch_lightning import Trainer
@@ -136,3 +139,80 @@ def test_ddp_post_local_sgd_comm_hook(tmpdir):
    expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state.finished, f"Training failed with {trainer.state}"


@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, standalone=True)
@mock.patch("torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager.average_parameters")
def test_post_local_sgd_model_averaging(average_parameters_mock, tmpdir):
    """Test that when using DDP with post-localSGD, model averaging is called."""
    model = BoringModel()

    # test regular ddp does not call model averaging
    trainer = Trainer(
        fast_dev_run=True,
        gpus=2,
        strategy="ddp",
        default_root_dir=tmpdir,
        sync_batchnorm=True,
    )

    trainer.fit(model)
    average_parameters_mock.assert_not_called()

    # test ddp with post-localSGD does call model averaging
    ddp_strategy = DDPStrategy(
        ddp_comm_state=post_localSGD.PostLocalSGDState(
            process_group=None,
            subgroup=None,
            start_localSGD_iter=8,
        ),
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
    )

    trainer = Trainer(
        fast_dev_run=True,
        gpus=2,
        strategy=ddp_strategy,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
    )

    trainer.fit(model)
    average_parameters_mock.assert_called()


@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, standalone=True)
@mock.patch("torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager.average_parameters")
def test_post_local_sgd_model_averaging_value_error(average_parameters_mock, tmpdir):
"""Test that when using DDP with post-localSGD a ValueError is thrown when the optmizer is
|
||||
ZeroRedundancyOptimizer."""
|
||||
    from torch.distributed.optim import ZeroRedundancyOptimizer

    class OptimizerModel(BoringModel):
        def configure_optimizers(self):
            return ZeroRedundancyOptimizer(params=self.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)

    model = OptimizerModel()
    strategy = DDPStrategy(
        ddp_comm_state=post_localSGD.PostLocalSGDState(
            process_group=None,
            subgroup=None,
            start_localSGD_iter=8,
        ),
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
    )

    trainer = Trainer(
        fast_dev_run=True,
        gpus=2,
        strategy=strategy,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
    )

    with pytest.raises(ValueError, match="Currently model averaging cannot work with a distributed optimizer"):
        trainer.fit(model)

    average_parameters_mock.assert_not_called()