Replace PostLocalSGDOptimizer with a dedicated model averaging component (#12378)
This commit is contained in:
parent
ec7fa1e2d8
commit
6329be60be
|
@ -402,6 +402,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- Removed `Timer.on_save_checkpoint` and `Timer.on_load_checkpoint` in favor of `Timer.state_dict` and `Timer.load_state_dict` ([#11887](https://github.com/PyTorchLightning/pytorch-lightning/pull/11887))
|
- Removed `Timer.on_save_checkpoint` and `Timer.on_load_checkpoint` in favor of `Timer.state_dict` and `Timer.load_state_dict` ([#11887](https://github.com/PyTorchLightning/pytorch-lightning/pull/11887))
|
||||||
|
|
||||||
|
|
||||||
|
- Replaced PostLocalSGDOptimizer with a dedicated model averaging component ([#12378](https://github.com/PyTorchLightning/pytorch-lightning/pull/12378))
|
||||||
|
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
||||||
- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
|
- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
|
||||||
|
|
|
@ -749,10 +749,7 @@ Enable `FP16 Compress Hook for multi-node throughput improvement <https://pytorc
|
||||||
|
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from pytorch_lightning.strategies import DDPStrategy
|
from pytorch_lightning.strategies import DDPStrategy
|
||||||
from torch.distributed.algorithms.ddp_comm_hooks import (
|
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
|
||||||
default_hooks as default,
|
|
||||||
powerSGD_hook as powerSGD,
|
|
||||||
)
|
|
||||||
|
|
||||||
model = MyModel()
|
model = MyModel()
|
||||||
trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPStrategy(ddp_comm_hook=default.fp16_compress_hook))
|
trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPStrategy(ddp_comm_hook=default.fp16_compress_hook))
|
||||||
|
@ -816,6 +813,33 @@ Combine hooks for accumulated benefit:
|
||||||
)
|
)
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
|
|
||||||
|
|
||||||
|
When using Post-localSGD, you must also pass ``model_averaging_period`` to allow for model parameter averaging:
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
Post-localSGD support requires PyTorch>=1.10.0
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from pytorch_lightning import Trainer
|
||||||
|
from pytorch_lightning.strategies import DDPStrategy
|
||||||
|
from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD
|
||||||
|
|
||||||
|
model = MyModel()
|
||||||
|
trainer = Trainer(
|
||||||
|
gpus=4,
|
||||||
|
strategy=DDPStrategy(
|
||||||
|
ddp_comm_state=post_localSGD.PostLocalSGDState(
|
||||||
|
process_group=None,
|
||||||
|
subgroup=None,
|
||||||
|
start_localSGD_iter=8,
|
||||||
|
),
|
||||||
|
ddp_comm_hook=post_localSGD.post_localSGD_hook,
|
||||||
|
model_averaging_period=4,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
DDP Static Graph
|
DDP Static Graph
|
||||||
""""""""""""""""
|
""""""""""""""""
|
||||||
|
|
||||||
|
|
|
@ -18,12 +18,13 @@ import signal
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||||
|
from torch.optim.optimizer import Optimizer
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from pytorch_lightning.core.optimizer import LightningOptimizer
|
from pytorch_lightning.core.optimizer import LightningOptimizer
|
||||||
|
@ -59,6 +60,8 @@ if _FAIRSCALE_AVAILABLE:
|
||||||
from fairscale.optim import OSS
|
from fairscale.optim import OSS
|
||||||
if _TORCH_GREATER_EQUAL_1_8:
|
if _TORCH_GREATER_EQUAL_1_8:
|
||||||
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
|
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
|
||||||
|
if _TORCH_GREATER_EQUAL_1_10:
|
||||||
|
from torch.distributed.algorithms.model_averaging.averagers import ModelAverager
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -97,6 +100,7 @@ class DDPStrategy(ParallelStrategy):
|
||||||
self._ddp_comm_hook = ddp_comm_hook
|
self._ddp_comm_hook = ddp_comm_hook
|
||||||
self._ddp_comm_wrapper = ddp_comm_wrapper
|
self._ddp_comm_wrapper = ddp_comm_wrapper
|
||||||
self._model_averaging_period = model_averaging_period
|
self._model_averaging_period = model_averaging_period
|
||||||
|
self._model_averager: Optional[ModelAverager] = None
|
||||||
self._pids: Optional[List[int]] = None
|
self._pids: Optional[List[int]] = None
|
||||||
self._sync_dir: Optional[str] = None
|
self._sync_dir: Optional[str] = None
|
||||||
self._rank_0_will_call_children_scripts: bool = False
|
self._rank_0_will_call_children_scripts: bool = False
|
||||||
|
@ -223,23 +227,18 @@ class DDPStrategy(ParallelStrategy):
|
||||||
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
|
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
|
||||||
|
|
||||||
if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
|
if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
|
||||||
self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
|
self._enable_model_averaging()
|
||||||
|
|
||||||
def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
|
def _enable_model_averaging(self) -> None:
|
||||||
|
# Only called when PyTorch version >= 1.10
|
||||||
log.detail(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")
|
log.detail(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")
|
||||||
optimizers = self.optimizers
|
|
||||||
if self._model_averaging_period is None:
|
if self._model_averaging_period is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy."
|
"Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy."
|
||||||
)
|
)
|
||||||
if _TORCH_GREATER_EQUAL_1_10:
|
from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
|
||||||
if not _IS_WINDOWS:
|
|
||||||
from torch.distributed.optim import DistributedOptimizer
|
|
||||||
import torch.distributed.algorithms.model_averaging.averagers as averagers
|
|
||||||
from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
|
|
||||||
|
|
||||||
averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
|
for optimizer in self.optimizers:
|
||||||
for x, optimizer in enumerate(optimizers):
|
|
||||||
if isinstance(optimizer, LightningOptimizer):
|
if isinstance(optimizer, LightningOptimizer):
|
||||||
optimizer = optimizer._optimizer
|
optimizer = optimizer._optimizer
|
||||||
|
|
||||||
|
@ -248,24 +247,46 @@ class DDPStrategy(ParallelStrategy):
|
||||||
is_distributed_optimizer
|
is_distributed_optimizer
|
||||||
or isinstance(optimizer, ZeroRedundancyOptimizer)
|
or isinstance(optimizer, ZeroRedundancyOptimizer)
|
||||||
or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
|
or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
|
||||||
|
or isinstance(optimizer, PostLocalSGDOptimizer)
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Cannot wrap a distributed optimizer of type {optimizer.__name__} by PostLocalSGDOptimizer."
|
f"Currently model averaging cannot work with a distributed optimizer of type "
|
||||||
|
f"{optimizer.__class__.__name__}."
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(optimizer, PostLocalSGDOptimizer):
|
self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager(
|
||||||
continue
|
period=self._model_averaging_period, warmup_steps=self._ddp_comm_state.start_localSGD_iter
|
||||||
|
)
|
||||||
|
|
||||||
optim_class = type(optimizer)
|
def optimizer_step(
|
||||||
post_localSGD_optimizer = PostLocalSGDOptimizer(
|
self,
|
||||||
params=optimizer.param_groups,
|
optimizer: Optimizer,
|
||||||
optimizer_class=optim_class,
|
opt_idx: int,
|
||||||
averager=averager,
|
closure: Callable[[], Any],
|
||||||
**optimizer.defaults,
|
model: Optional[Union["pl.LightningModule", Module]] = None,
|
||||||
)
|
**kwargs: Any,
|
||||||
optimizers[x] = post_localSGD_optimizer
|
) -> Any:
|
||||||
del optimizer
|
"""Performs the actual optimizer step.
|
||||||
self.optimizers = optimizers
|
|
||||||
|
Args:
|
||||||
|
optimizer: the optimizer performing the step
|
||||||
|
opt_idx: index of the current optimizer
|
||||||
|
closure: closure calculating the loss value
|
||||||
|
model: reference to the model, optionally defining optimizer step related hooks
|
||||||
|
**kwargs: Any extra arguments to ``optimizer.step``
|
||||||
|
"""
|
||||||
|
optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
|
||||||
|
|
||||||
|
if not _TORCH_GREATER_EQUAL_1_10 or self._model_averager is None:
|
||||||
|
return optimizer_output
|
||||||
|
|
||||||
|
for group in optimizer.param_groups:
|
||||||
|
for param in group["params"]:
|
||||||
|
if param.grad is None:
|
||||||
|
continue
|
||||||
|
self._model_averager.average_parameters(iter(param))
|
||||||
|
|
||||||
|
return optimizer_output
|
||||||
|
|
||||||
def configure_ddp(self) -> None:
|
def configure_ddp(self) -> None:
|
||||||
log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
|
log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
|
||||||
|
|
|
@ -181,7 +181,7 @@ class Strategy(ABC):
|
||||||
model: Optional[Union["pl.LightningModule", Module]] = None,
|
model: Optional[Union["pl.LightningModule", Module]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""performs the actual optimizer step.
|
"""Performs the actual optimizer step.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
optimizer: the optimizer performing the step
|
optimizer: the optimizer performing the step
|
||||||
|
|
|
@ -11,6 +11,9 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
|
@ -136,3 +139,80 @@ def test_ddp_post_local_sgd_comm_hook(tmpdir):
|
||||||
expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
|
expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
|
||||||
assert trainer_comm_hook == expected_comm_hook
|
assert trainer_comm_hook == expected_comm_hook
|
||||||
assert trainer.state.finished, f"Training failed with {trainer.state}"
|
assert trainer.state.finished, f"Training failed with {trainer.state}"
|
||||||
|
|
||||||
|
|
||||||
|
@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, standalone=True)
|
||||||
|
@mock.patch("torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager.average_parameters")
|
||||||
|
def test_post_local_sgd_model_averaging(average_parameters_mock, tmpdir):
|
||||||
|
"""Test that when using DDP with post-localSGD, model averaging is called."""
|
||||||
|
model = BoringModel()
|
||||||
|
|
||||||
|
# test regular ddp does not call model averaging
|
||||||
|
trainer = Trainer(
|
||||||
|
fast_dev_run=True,
|
||||||
|
gpus=2,
|
||||||
|
strategy="ddp",
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
sync_batchnorm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.fit(model)
|
||||||
|
average_parameters_mock.assert_not_called()
|
||||||
|
|
||||||
|
# test ddp with post-localSGD does call model averaging
|
||||||
|
ddp_strategy = DDPStrategy(
|
||||||
|
ddp_comm_state=post_localSGD.PostLocalSGDState(
|
||||||
|
process_group=None,
|
||||||
|
subgroup=None,
|
||||||
|
start_localSGD_iter=8,
|
||||||
|
),
|
||||||
|
ddp_comm_hook=post_localSGD.post_localSGD_hook,
|
||||||
|
model_averaging_period=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
fast_dev_run=True,
|
||||||
|
gpus=2,
|
||||||
|
strategy=ddp_strategy,
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
sync_batchnorm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.fit(model)
|
||||||
|
average_parameters_mock.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
@RunIf(skip_windows=True, min_torch="1.10.0", min_gpus=2, standalone=True)
|
||||||
|
@mock.patch("torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager.average_parameters")
|
||||||
|
def test_post_local_sgd_model_averaging_value_error(average_parameters_mock, tmpdir):
|
||||||
|
"""Test that when using DDP with post-localSGD a ValueError is thrown when the optmizer is
|
||||||
|
ZeroRedundancyOptimizer."""
|
||||||
|
from torch.distributed.optim import ZeroRedundancyOptimizer
|
||||||
|
|
||||||
|
class OptimizerModel(BoringModel):
|
||||||
|
def configure_optimizers(self):
|
||||||
|
return ZeroRedundancyOptimizer(params=self.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
|
||||||
|
|
||||||
|
model = OptimizerModel()
|
||||||
|
strategy = DDPStrategy(
|
||||||
|
ddp_comm_state=post_localSGD.PostLocalSGDState(
|
||||||
|
process_group=None,
|
||||||
|
subgroup=None,
|
||||||
|
start_localSGD_iter=8,
|
||||||
|
),
|
||||||
|
ddp_comm_hook=post_localSGD.post_localSGD_hook,
|
||||||
|
model_averaging_period=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
fast_dev_run=True,
|
||||||
|
gpus=2,
|
||||||
|
strategy=strategy,
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
sync_batchnorm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Currently model averaging cannot work with a distributed optimizer"):
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
average_parameters_mock.assert_not_called()
|
||||||
|
|
Loading…
Reference in New Issue