Sync module states during non-fit (#17370)
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
parent 9becc15ddf
commit 97a61868fb
@@ -243,7 +243,6 @@ utilities
     combined_loader
     data
     deepspeed
-    distributed
     memory
     model_summary
     parsing
@@ -12,14 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, cast, Dict, Iterable, Iterator, List, Optional, Sized, Union
+from typing import Any, Callable, cast, Dict, Iterable, Iterator, List, Optional, Sized, Union

 import torch
 from torch import Tensor
-from torch.nn.parallel import DistributedDataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.utils.data import BatchSampler, DistributedSampler, Sampler

 from lightning.fabric.utilities.distributed import _DatasetSamplerWrapper
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info


 def _find_tensors(
@@ -37,7 +39,7 @@ def _find_tensors(

 # In manual_optimization, we need to call reducer prepare_for_backward.
 # Note: Keep track of PyTorch DDP and update if there is a change
-# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
+# https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/parallel/distributed.py#L1163-L1178
 def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
     # `prepare_for_backward` is `DistributedDataParallel` specific.
     if torch.is_grad_enabled() and model.require_backward_grad_sync:
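For context on the `manual_optimization` comment above: when a LightningModule disables automatic optimization, the DDP strategy calls `prepare_for_backward` itself before each `manual_backward`. A minimal, hypothetical module exercising that path (not part of this commit) might look like:

```python
import torch
from lightning.pytorch import LightningModule


class ManualOptModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # switch Lightning to manual optimization
        self.layer = torch.nn.Linear(2, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        opt.zero_grad()
        # under DDP, the strategy runs `prepare_for_backward` before this backward pass
        self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```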
@@ -47,7 +49,7 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
         # because we need to figure out which parameters were used during
         # this forward pass, to ensure we short circuit reduction for any
         # unused parameters. Only if `find_unused_parameters` is set.
-        args = list(_find_tensors(output)) if model.find_unused_parameters else []
+        args = list(_find_tensors(output)) if model.find_unused_parameters and not model.static_graph else []
         reducer = cast(torch._C._distributed_c10d.Reducer, model.reducer)
         reducer._rebuild_buckets()  # avoids "INTERNAL ASSERT FAILED" with `find_unused_parameters=False`
         reducer.prepare_for_backward(args)
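The new `and not model.static_graph` guard mirrors DDP's own behavior: with a static graph the reducer does not need to search the output for unused parameters, even when `find_unused_parameters` is set. A small, hypothetical single-process sketch of the two flags the guard reads (assumes a reachable MASTER_ADDR/MASTER_PORT; not part of the commit):

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DistributedDataParallel(
    torch.nn.Linear(2, 2),
    find_unused_parameters=True,  # alone, this would trigger the `_find_tensors` search
    static_graph=True,            # but a static graph makes that search unnecessary
)
# the guard added above evaluates to False here, so the search is skipped
print(model.find_unused_parameters and not model.static_graph)

dist.destroy_process_group()
```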
@@ -55,6 +57,135 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
         model.require_forward_param_sync = False


+def _register_ddp_comm_hook(
+    model: DistributedDataParallel,
+    ddp_comm_state: Optional[object] = None,
+    ddp_comm_hook: Optional[Callable] = None,
+    ddp_comm_wrapper: Optional[Callable] = None,
+) -> None:
+    """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html.
+
+    Args:
+        model:
+            DDP model
+        ddp_comm_state:
+            state is passed to the hook and can be used to maintain
+            and update any state information that users would like to
+            maintain as part of the training process. Examples: error
+            feedback in gradient compression, peers to communicate with
+            next in GossipGrad etc.
+        ddp_comm_hook:
+            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future
+
+            This callable function is called once the bucket is ready. The
+            hook can perform whatever processing is needed and return
+            a Future indicating completion of any async work (ex: allreduce).
+            If the hook doesn't perform any communication, it can also
+            just return a completed Future. The Future should hold the
+            new value of grad bucket's tensors. Once a bucket is ready,
+            c10d reducer would call this hook and use the tensors returned
+            by the Future and copy grads to individual parameters.
+        ddp_comm_wrapper:
+            communication hook wrapper to support a communication hook such
+            as FP16 compression as wrapper, which could be combined with
+            ddp_comm_hook
+
+    Examples:
+
+        >>> from torch.distributed.algorithms.ddp_comm_hooks import ( # doctest: +SKIP
+        ...     default_hooks as default,
+        ...     powerSGD_hook as powerSGD,
+        ...     post_localSGD_hook as post_localSGD,
+        ... )
+        >>> # fp16_compress_hook for compressing gradients
+        >>> ddp_model = ...
+        >>> _register_ddp_comm_hook( # doctest: +SKIP
+        ...     model=ddp_model,
+        ...     ddp_comm_hook=default.fp16_compress_hook,
+        ... )
+        >>> # powerSGD_hook
+        >>> ddp_model = ...
+        >>> _register_ddp_comm_hook( # doctest: +SKIP
+        ...     model=ddp_model,
+        ...     ddp_comm_state=powerSGD.PowerSGDState(
+        ...         process_group=None,
+        ...         matrix_approximation_rank=1,
+        ...         start_powerSGD_iter=5000,
+        ...     ),
+        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
+        ... )
+        >>> # post_localSGD_hook
+        >>> subgroup, _ = torch.distributed.new_subgroups() # doctest: +SKIP
+        >>> ddp_model = ...
+        >>> _register_ddp_comm_hook( # doctest: +SKIP
+        ...     model=ddp_model,
+        ...     ddp_comm_state=post_localSGD.PostLocalSGDState(
+        ...         process_group=None,
+        ...         subgroup=subgroup,
+        ...         start_localSGD_iter=1_000,
+        ...     ),
+        ...     ddp_comm_hook=post_localSGD.post_localSGD_hook,
+        ... )
+        >>> # fp16_compress_wrapper combined with other communication hook
+        >>> ddp_model = ...
+        >>> _register_ddp_comm_hook( # doctest: +SKIP
+        ...     model=ddp_model,
+        ...     ddp_comm_state=powerSGD.PowerSGDState(
+        ...         process_group=None,
+        ...         matrix_approximation_rank=1,
+        ...         start_powerSGD_iter=5000,
+        ...     ),
+        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
+        ...     ddp_comm_wrapper=default.fp16_compress_wrapper,
+        ... )
+    """
+    if ddp_comm_hook is None:
+        return
+    # inform mypy that ddp_comm_hook is callable
+    ddp_comm_hook: Callable = ddp_comm_hook
+
+    if ddp_comm_wrapper is not None:
+        rank_zero_info(
+            f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
+        )
+        ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)
+
+    rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
+    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
+
+
+def _sync_module_states(module: torch.nn.Module) -> None:
+    """Taken from https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/parallel/distributed.py#L675-L682."""
+    parameters_to_ignore = (
+        set(module._ddp_params_and_buffers_to_ignore)  # type: ignore[arg-type]
+        if hasattr(module, "_ddp_params_and_buffers_to_ignore")
+        else set()
+    )
+    from torch.distributed.distributed_c10d import _get_default_group
+
+    if not _TORCH_GREATER_EQUAL_1_12:
+        module_states = []
+        for name, param in module.named_parameters():
+            if name not in parameters_to_ignore:
+                module_states.append(param.detach())
+        for name, buffer in module.named_buffers():
+            if name not in parameters_to_ignore:
+                module_states.append(buffer.detach())
+        if len(module_states) > 0:
+            torch.distributed._broadcast_coalesced(_get_default_group(), module_states, 250 * 1024 * 1024, 0)
+        return
+
+    from torch.distributed.utils import _sync_module_states as torch_sync_module_states
+
+    torch_sync_module_states(
+        module,
+        _get_default_group(),
+        250 * 1024 * 1024,
+        src=0,
+        params_and_buffers_to_ignore=parameters_to_ignore,
+    )
+
+
 class UnrepeatedDistributedSampler(DistributedSampler):
     """A fork of the PyTorch DistributedSampler that doesn't repeat data, instead allowing the number of batches
     per process to be off-by-one from each other. This makes this sampler usable for predictions (it's
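`_sync_module_states` is what lets the strategy hand out rank 0's weights when the model is not wrapped in `DistributedDataParallel`. A rough conceptual equivalent is sketched below; this is an illustration only, since the real helper defers to the `torch.distributed` utilities above and honors `_ddp_params_and_buffers_to_ignore`, which the sketch omits.

```python
import torch
import torch.distributed as dist


def broadcast_module_from_rank_zero(module: torch.nn.Module) -> None:
    """Overwrite every rank's parameters and buffers with rank 0's values."""
    # assumes the default process group is already initialized
    for tensor in list(module.parameters()) + list(module.buffers()):
        dist.broadcast(tensor.detach(), src=0)
```

After such a broadcast, every process evaluates numerically identical weights, which is exactly what the new non-fit path in the DDP strategy relies on.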
@@ -38,13 +38,12 @@ from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import ReduceOp
 from lightning.pytorch.core.optimizer import LightningOptimizer
 from lightning.pytorch.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
-from lightning.pytorch.overrides.distributed import prepare_for_backward
+from lightning.pytorch.overrides.distributed import _register_ddp_comm_hook, _sync_module_states, prepare_for_backward
 from lightning.pytorch.plugins.precision import PrecisionPlugin
 from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher
 from lightning.pytorch.strategies.parallel import ParallelStrategy
 from lightning.pytorch.strategies.strategy import TBroadcast
 from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.distributed import _register_ddp_comm_hook
 from lightning.pytorch.utilities.exceptions import _augment_message
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only
 from lightning.pytorch.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
@@ -153,25 +152,28 @@ class DDPStrategy(ParallelStrategy):
         # skip wrapping the model if we are not fitting as no gradients need to be exchanged
         trainer_fn = trainer.state.fn

-        if trainer_fn == TrainerFn.FITTING:
-            if self._layer_sync:
-                assert self.model is not None
-                self.model = self._layer_sync.apply(self.model)
+        if trainer_fn == TrainerFn.FITTING and self._layer_sync:
+            assert self.model is not None
+            self.model = self._layer_sync.apply(self.model)

         self.setup_precision_plugin()

         if trainer_fn == TrainerFn.FITTING:
             # do not wrap with DDP if not fitting as there's no gradients to reduce
             self.configure_ddp()

         # set up optimizers after the wrapped module has been moved to the device
         self.setup_optimizers(trainer)
         _optimizers_to_device(self.optimizers, self.root_device)

         if trainer_fn == TrainerFn.FITTING:
             import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

             if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
                 self._enable_model_averaging()
+        else:
+            # we need to manually synchronize the module's states since we aren't using the DDP wrapper
+            assert self.model is not None
+            _sync_module_states(self.model)

     def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
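From the user's side, the change above means that `validate`/`test`/`predict` with the DDP strategy no longer run on differently initialized copies of the model per rank. A hypothetical end-to-end illustration (not part of the commit; it mirrors the new test added further below):

```python
import torch
from lightning.pytorch import LightningModule, Trainer


class EvalOnlyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # randomly initialized in every process; rank 0's values win after setup
        self.layer = torch.nn.Linear(1, 1)

    def test_step(self, batch, batch_idx):
        # all ranks now see identical weights here
        print(self.layer.weight.item(), self.layer.bias.item())


if __name__ == "__main__":
    trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp")
    trainer.test(EvalOnlyModel(), [0])  # `[0]` serves as a dummy one-batch dataloader, as in the new test
```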
@@ -1,144 +0,0 @@
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Utilities that can be used with distributed training."""
-
-from typing import Any, Callable, Dict, Optional
-
-import torch
-from torch.nn.parallel.distributed import DistributedDataParallel
-
-from lightning.fabric.utilities.distributed import _distributed_available as new_distributed_available
-from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info
-
-
-def _register_ddp_comm_hook(
-    model: DistributedDataParallel,
-    ddp_comm_state: Optional[object] = None,
-    ddp_comm_hook: Optional[Callable] = None,
-    ddp_comm_wrapper: Optional[Callable] = None,
-) -> None:
-    """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html.
-
-    Args:
-        model:
-            DDP model
-        ddp_comm_state:
-            state is passed to the hook and can be used to maintain
-            and update any state information that users would like to
-            maintain as part of the training process. Examples: error
-            feedback in gradient compression, peers to communicate with
-            next in GossipGrad etc.
-        ddp_comm_hook:
-            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future
-
-            This callable function is called once the bucket is ready. The
-            hook can perform whatever processing is needed and return
-            a Future indicating completion of any async work (ex: allreduce).
-            If the hook doesn't perform any communication, it can also
-            just return a completed Future. The Future should hold the
-            new value of grad bucket's tensors. Once a bucket is ready,
-            c10d reducer would call this hook and use the tensors returned
-            by the Future and copy grads to individual parameters.
-        ddp_comm_wrapper:
-            communication hook wrapper to support a communication hook such
-            as FP16 compression as wrapper, which could be combined with
-            ddp_comm_hook
-
-    Examples:
-
-        >>> from torch.distributed.algorithms.ddp_comm_hooks import ( # doctest: +SKIP
-        ...     default_hooks as default,
-        ...     powerSGD_hook as powerSGD,
-        ...     post_localSGD_hook as post_localSGD,
-        ... )
-        >>>
-        >>> # fp16_compress_hook for compress gradients
-        >>> ddp_model = ...
-        >>> _register_ddp_comm_hook( # doctest: +SKIP
-        ...     model=ddp_model,
-        ...     ddp_comm_hook=default.fp16_compress_hook,
-        ... )
-        >>>
-        >>> # powerSGD_hook
-        >>> ddp_model = ...
-        >>> _register_ddp_comm_hook( # doctest: +SKIP
-        ...     model=ddp_model,
-        ...     ddp_comm_state=powerSGD.PowerSGDState(
-        ...         process_group=None,
-        ...         matrix_approximation_rank=1,
-        ...         start_powerSGD_iter=5000,
-        ...     ),
-        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
-        ... )
-        >>>
-        >>> # post_localSGD_hook
-        >>> subgroup, _ = torch.distributed.new_subgroups() # doctest: +SKIP
-        >>> ddp_model = ...
-        >>> _register_ddp_comm_hook( # doctest: +SKIP
-        ...     model=ddp_model,
-        ...     state=post_localSGD.PostLocalSGDState(
-        ...         process_group=None,
-        ...         subgroup=subgroup,
-        ...         start_localSGD_iter=1_000,
-        ...     ),
-        ...     ddp_comm_hook=post_localSGD.post_localSGD_hook,
-        ... )
-        >>>
-        >>> # fp16_compress_wrapper combined with other communication hook
-        >>> ddp_model = ...
-        >>> _register_ddp_comm_hook( # doctest: +SKIP
-        ...     model=ddp_model,
-        ...     ddp_comm_state=powerSGD.PowerSGDState(
-        ...         process_group=None,
-        ...         matrix_approximation_rank=1,
-        ...         start_powerSGD_iter=5000,
-        ...     ),
-        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
-        ...     ddp_comm_wrapper=default.fp16_compress_wrapper,
-        ... )
-    """
-    if ddp_comm_hook is None:
-        return
-    # inform mypy that ddp_comm_hook is callable
-    ddp_comm_hook: Callable = ddp_comm_hook
-
-    if ddp_comm_wrapper is not None:
-        rank_zero_info(
-            f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
-        )
-        ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)
-
-    rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
-    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
-
-
-def _broadcast_object_list(obj: Any, rank: int) -> Any:
-    objects = [obj if torch.distributed.get_rank() == rank else None]
-    torch.distributed.broadcast_object_list(objects, src=rank)
-    return objects[0]
-
-
-# TODO: Refactor with the Strategy Collectives once finalized.
-def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]:
-    """This distributed utility collects dictionary state across all processes.
-
-    Args:
-        state: Dictionary containing the state of the current process
-
-    Returns:
-        states: On global rank 0, a dictionary where the primary keys are
-            the process rank and the values their associated states. Otherwise, returns None.
-    """
-    if not new_distributed_available():
-        return {0: state}
-    return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())}
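For reference, the removed `_collect_states_on_rank_zero` helper was used roughly like the hypothetical two-rank sketch below; the removed test at the end of this diff asserts the same result.

```python
import torch
import torch.distributed as dist

# import as it existed before this commit
from lightning.pytorch.utilities.distributed import _collect_states_on_rank_zero

# inside an initialized process group with world_size == 2
rank = dist.get_rank()
state = {"something": torch.tensor([rank])}
collected = _collect_states_on_rank_zero(state)
# collected == {0: {"something": tensor([0])}, 1: {"something": tensor([1])}}
```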
@@ -14,11 +14,39 @@
 from typing import Iterable

 import pytest
+import torch
 from torch.utils.data import BatchSampler, SequentialSampler

 from lightning.fabric.utilities.data import has_len
-from lightning.pytorch import seed_everything
+from lightning.pytorch import LightningModule, seed_everything, Trainer
 from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
+from tests_pytorch.helpers.runif import RunIf
+
+
+class MyModel(LightningModule):
+    def setup(self, stage: str) -> None:
+        self.layer = torch.nn.Linear(1, 1)
+        weights = self.layer.weight.item(), self.layer.bias.item()
+        self.rank_0_weights = self.trainer.strategy.broadcast(weights)
+
+    def test_step(self, batch, batch_idx):
+        current = self.layer.weight.item(), self.layer.bias.item()
+        assert self.rank_0_weights == current
+        gathered = self.all_gather(current)
+        # the weights have been synced
+        assert all(torch.all(t == t[0]) for t in gathered), gathered
+
+
+@RunIf(standalone=True)
+def test_params_synced_during_nonfit():
+    model = MyModel()
+    trainer = Trainer(
+        barebones=True,
+        devices=2,
+        accelerator="cpu",
+        strategy="ddp",
+    )
+    trainer.test(model, [0])


 @pytest.mark.parametrize("shuffle", [False, True])
@@ -1,36 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.distributed
-
-from lightning.pytorch.utilities.distributed import _collect_states_on_rank_zero
-from tests_pytorch.core.test_results import spawn_launch
-from tests_pytorch.helpers.runif import RunIf
-
-
-def collect_states_fn(strategy):
-    rank = strategy.local_rank
-    state = {"something": torch.tensor([rank])}
-    collected_state = _collect_states_on_rank_zero(state)
-    assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}}
-
-
-@RunIf(min_cuda_gpus=2, skip_windows=True)
-def test_collect_states():
-    """This test ensures state are properly collected across processes.
-
-    This would be used to collect dataloader states as an example.
-    """
-    spawn_launch(collect_states_fn, [torch.device("cuda:0"), torch.device("cuda:1")])