Replace PostLocalSGDOptimizer with a dedicated model averaging component (#12378)

This commit is contained in:
Danielle Pintz 2022-03-24 20:33:19 -04:00 committed by GitHub
parent ec7fa1e2d8
commit 6329be60be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 157 additions and 29 deletions

View File

@ -402,6 +402,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `Timer.on_save_checkpoint` and `Timer.on_load_checkpoint` in favor of `Timer.state_dict` and `Timer.load_state_dict` ([#11887](https://github.com/PyTorchLightning/pytorch-lightning/pull/11887)) - Removed `Timer.on_save_checkpoint` and `Timer.on_load_checkpoint` in favor of `Timer.state_dict` and `Timer.load_state_dict` ([#11887](https://github.com/PyTorchLightning/pytorch-lightning/pull/11887))
- Replaced `PostLocalSGDOptimizer` with a dedicated model averaging component ([#12378](https://github.com/PyTorchLightning/pytorch-lightning/pull/12378))
### Deprecated ### Deprecated
- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141)) - Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))

View File

@ -749,10 +749,7 @@ Enable `FP16 Compress Hook for multi-node throughput improvement <https://pytorc
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy from pytorch_lightning.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import ( from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
default_hooks as default,
powerSGD_hook as powerSGD,
)
model = MyModel() model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPStrategy(ddp_comm_hook=default.fp16_compress_hook)) trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPStrategy(ddp_comm_hook=default.fp16_compress_hook))
@ -816,6 +813,33 @@ Combine hooks for accumulated benefit:
) )
trainer.fit(model) trainer.fit(model)
When using Post-localSGD, you must also pass ``model_averaging_period`` to ``DDPStrategy`` to enable periodic model parameter averaging; omitting it raises a ``ValueError``:
.. note::
Post-localSGD support requires PyTorch>=1.10.0
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD
model = MyModel()
trainer = Trainer(
gpus=4,
strategy=DDPStrategy(
ddp_comm_state=post_localSGD.PostLocalSGDState(
process_group=None,
subgroup=None,
start_localSGD_iter=8,
),
ddp_comm_hook=post_localSGD.post_localSGD_hook,
model_averaging_period=4,
),
)
trainer.fit(model)
DDP Static Graph DDP Static Graph
"""""""""""""""" """"""""""""""""

View File

@ -18,12 +18,13 @@ import signal
import tempfile import tempfile
import time import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
import torch.distributed import torch.distributed
from torch.nn import Module from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim.optimizer import Optimizer
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.optimizer import LightningOptimizer
@ -59,6 +60,8 @@ if _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS from fairscale.optim import OSS
if _TORCH_GREATER_EQUAL_1_8: if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
if _TORCH_GREATER_EQUAL_1_10:
from torch.distributed.algorithms.model_averaging.averagers import ModelAverager
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -97,6 +100,7 @@ class DDPStrategy(ParallelStrategy):
self._ddp_comm_hook = ddp_comm_hook self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper self._ddp_comm_wrapper = ddp_comm_wrapper
self._model_averaging_period = model_averaging_period self._model_averaging_period = model_averaging_period
self._model_averager: Optional[ModelAverager] = None
self._pids: Optional[List[int]] = None self._pids: Optional[List[int]] = None
self._sync_dir: Optional[str] = None self._sync_dir: Optional[str] = None
self._rank_0_will_call_children_scripts: bool = False self._rank_0_will_call_children_scripts: bool = False
@ -223,23 +227,18 @@ class DDPStrategy(ParallelStrategy):
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState): if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter) self._enable_model_averaging()
def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int): def _enable_model_averaging(self) -> None:
# Only called when PyTorch version >= 1.10
log.detail(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD") log.detail(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")
optimizers = self.optimizers
if self._model_averaging_period is None: if self._model_averaging_period is None:
raise ValueError( raise ValueError(
"Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy." "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy."
) )
if _TORCH_GREATER_EQUAL_1_10: from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
if not _IS_WINDOWS:
from torch.distributed.optim import DistributedOptimizer
import torch.distributed.algorithms.model_averaging.averagers as averagers
from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps) for optimizer in self.optimizers:
for x, optimizer in enumerate(optimizers):
if isinstance(optimizer, LightningOptimizer): if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer optimizer = optimizer._optimizer
@ -248,24 +247,46 @@ class DDPStrategy(ParallelStrategy):
is_distributed_optimizer is_distributed_optimizer
or isinstance(optimizer, ZeroRedundancyOptimizer) or isinstance(optimizer, ZeroRedundancyOptimizer)
or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS)) or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
or isinstance(optimizer, PostLocalSGDOptimizer)
): ):
raise ValueError( raise ValueError(
f"Cannot wrap a distributed optimizer of type {optimizer.__name__} by PostLocalSGDOptimizer." f"Currently model averaging cannot work with a distributed optimizer of type "
f"{optimizer.__class__.__name__}."
) )
if isinstance(optimizer, PostLocalSGDOptimizer): self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager(
period=self._model_averaging_period, warmup_steps=self._ddp_comm_state.start_localSGD_iter
)
def optimizer_step(
    self,
    optimizer: Optimizer,
    opt_idx: int,
    closure: Callable[[], Any],
    model: Optional[Union["pl.LightningModule", Module]] = None,
    **kwargs: Any,
) -> Any:
    """Performs the actual optimizer step and, if model averaging is enabled, averages the parameters.

    Args:
        optimizer: the optimizer performing the step
        opt_idx: index of the current optimizer
        closure: closure calculating the loss value
        model: reference to the model, optionally defining optimizer step related hooks
        **kwargs: Any extra arguments to ``optimizer.step``

    Returns:
        Whatever the parent strategy's ``optimizer_step`` returned.
    """
    optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)

    # No averager is configured unless post-localSGD was enabled (PyTorch >= 1.10 only).
    if not _TORCH_GREATER_EQUAL_1_10 or self._model_averager is None:
        return optimizer_output

    # Collect every parameter that took part in this step and hand them to the averager in ONE call.
    # Passing ``iter(param)`` for a single tensor would iterate over the tensor's first dimension
    # (and fail on 0-dim tensors), and calling ``average_parameters`` once per parameter would advance
    # the averager's step counter once per parameter instead of once per optimizer step, which skews
    # both ``warmup_steps`` and ``period``.
    params = [
        param for group in optimizer.param_groups for param in group["params"] if param.grad is not None
    ]
    # As before, nothing is averaged when no parameter received a gradient.
    if params:
        self._model_averager.average_parameters(iter(params))

    return optimizer_output
post_localSGD_optimizer = PostLocalSGDOptimizer(
params=optimizer.param_groups,
optimizer_class=optim_class,
averager=averager,
**optimizer.defaults,
)
optimizers[x] = post_localSGD_optimizer
del optimizer
self.optimizers = optimizers
def configure_ddp(self) -> None: def configure_ddp(self) -> None:
log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel") log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")

View File

@ -181,7 +181,7 @@ class Strategy(ABC):
model: Optional[Union["pl.LightningModule", Module]] = None, model: Optional[Union["pl.LightningModule", Module]] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""performs the actual optimizer step. """Performs the actual optimizer step.
Args: Args:
optimizer: the optimizer performing the step optimizer: the optimizer performing the step

View File

@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from unittest import mock
import pytest
import torch import torch
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
@ -136,3 +139,80 @@ def test_ddp_post_local_sgd_comm_hook(tmpdir):
expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__ expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
assert trainer_comm_hook == expected_comm_hook assert trainer_comm_hook == expected_comm_hook
assert trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.state.finished, f"Training failed with {trainer.state}"
@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, standalone=True)
@mock.patch("torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager.average_parameters")
def test_post_local_sgd_model_averaging(average_parameters_mock, tmpdir):
    """Check that model averaging runs with DDP + post-localSGD and does not run with plain DDP."""
    model = BoringModel()
    # Trainer arguments shared by both runs; only the strategy differs.
    common_trainer_kwargs = dict(fast_dev_run=True, gpus=2, default_root_dir=tmpdir, sync_batchnorm=True)

    # Plain DDP: the averager must never be invoked.
    plain_ddp_trainer = Trainer(strategy="ddp", **common_trainer_kwargs)
    plain_ddp_trainer.fit(model)
    average_parameters_mock.assert_not_called()

    # DDP with the post-localSGD hook and an averaging period: the averager must be invoked.
    local_sgd_state = post_localSGD.PostLocalSGDState(
        process_group=None,
        subgroup=None,
        start_localSGD_iter=8,
    )
    local_sgd_strategy = DDPStrategy(
        ddp_comm_state=local_sgd_state,
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
    )
    local_sgd_trainer = Trainer(strategy=local_sgd_strategy, **common_trainer_kwargs)
    local_sgd_trainer.fit(model)
    average_parameters_mock.assert_called()
@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, standalone=True)
@mock.patch("torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager.average_parameters")
def test_post_local_sgd_model_averaging_value_error(average_parameters_mock, tmpdir):
    """Check that DDP + post-localSGD raises a ValueError when the optimizer is a
    ZeroRedundancyOptimizer, and that no model averaging happens."""
    from torch.distributed.optim import ZeroRedundancyOptimizer

    class OptimizerModel(BoringModel):
        def configure_optimizers(self):
            # A distributed optimizer, which the strategy is expected to reject.
            return ZeroRedundancyOptimizer(params=self.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)

    model = OptimizerModel()

    local_sgd_state = post_localSGD.PostLocalSGDState(
        process_group=None,
        subgroup=None,
        start_localSGD_iter=8,
    )
    local_sgd_strategy = DDPStrategy(
        ddp_comm_state=local_sgd_state,
        ddp_comm_hook=post_localSGD.post_localSGD_hook,
        model_averaging_period=4,
    )
    trainer = Trainer(
        fast_dev_run=True,
        gpus=2,
        strategy=local_sgd_strategy,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
    )

    expected_message = "Currently model averaging cannot work with a distributed optimizer"
    with pytest.raises(ValueError, match=expected_message):
        trainer.fit(model)

    # The failure must happen before any averaging is attempted.
    average_parameters_mock.assert_not_called()