Sync module states during non-fit (#17370)

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Author: Carlos Mocholí, 2023-04-15 04:35:51 +02:00 (committed by GitHub)
Parent: 9becc15ddf
Commit: 97a61868fb
6 changed files with 173 additions and 193 deletions

@@ -243,7 +243,6 @@ utilities
combined_loader
data
deepspeed
distributed
memory
model_summary
parsing

@@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any, cast, Dict, Iterable, Iterator, List, Optional, Sized, Union
from typing import Any, Callable, cast, Dict, Iterable, Iterator, List, Optional, Sized, Union
import torch
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import BatchSampler, DistributedSampler, Sampler
from lightning.fabric.utilities.distributed import _DatasetSamplerWrapper
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info
def _find_tensors(
@@ -37,7 +39,7 @@ def _find_tensors(
# In manual_optimization, we need to call reducer prepare_for_backward.
# Note: Keep track of PyTorch DDP and update if there is a change
# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
# https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/parallel/distributed.py#L1163-L1178
def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
# `prepare_for_backward` is `DistributedDataParallel` specific.
if torch.is_grad_enabled() and model.require_backward_grad_sync:
@@ -47,7 +49,7 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
args = list(_find_tensors(output)) if model.find_unused_parameters else []
args = list(_find_tensors(output)) if model.find_unused_parameters and not model.static_graph else []
reducer = cast(torch._C._distributed_c10d.Reducer, model.reducer)
reducer._rebuild_buckets() # avoids "INTERNAL ASSERT FAILED" with `find_unused_parameters=False`
reducer.prepare_for_backward(args)
@@ -55,6 +57,135 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
model.require_forward_param_sync = False
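The `prepare_for_backward` helper above is only needed under manual optimization, where DDP's own forward/backward hooks are bypassed. A minimal, hedged sketch of a caller (the hook name and its placement are illustrative assumptions, not part of this diff):

# Illustrative sketch: how a strategy-level pre-backward hook might use
# `prepare_for_backward` when the LightningModule runs manual optimization.
def _pre_backward_sketch(model: DistributedDataParallel, output: Tensor, automatic_optimization: bool) -> None:
    if not automatic_optimization:
        # Prime the reducer so gradient buckets are rebuilt and unused
        # parameters are short-circuited before the manual backward call.
        prepare_for_backward(model, output)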
def _register_ddp_comm_hook(
model: DistributedDataParallel,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[Callable] = None,
ddp_comm_wrapper: Optional[Callable] = None,
) -> None:
"""Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html.
Args:
model:
DDP model
ddp_comm_state:
state is passed to the hook and can be used to maintain
and update any state information that users would like to
maintain as part of the training process. Examples: error
feedback in gradient compression, peers to communicate with
next in GossipGrad etc.
ddp_comm_hook:
hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future
This callable function is called once the bucket is ready. The
hook can perform whatever processing is needed and return
a Future indicating completion of any async work (ex: allreduce).
If the hook doesn't perform any communication, it can also
just return a completed Future. The Future should hold the
new value of grad bucket's tensors. Once a bucket is ready,
c10d reducer would call this hook and use the tensors returned
by the Future and copy grads to individual parameters.
ddp_comm_wrapper:
communication hook wrapper to support a communication hook such
as FP16 compression as wrapper, which could be combined with
ddp_comm_hook
Examples:
>>> from torch.distributed.algorithms.ddp_comm_hooks import ( # doctest: +SKIP
... default_hooks as default,
... powerSGD_hook as powerSGD,
... post_localSGD_hook as post_localSGD,
... )
>>> # fp16_compress_hook for compressing gradients
>>> ddp_model = ...
>>> _register_ddp_comm_hook( # doctest: +SKIP
... model=ddp_model,
... ddp_comm_hook=default.fp16_compress_hook,
... )
>>> # powerSGD_hook
>>> ddp_model = ...
>>> _register_ddp_comm_hook( # doctest: +SKIP
... model=ddp_model,
... ddp_comm_state=powerSGD.PowerSGDState(
... process_group=None,
... matrix_approximation_rank=1,
... start_powerSGD_iter=5000,
... ),
... ddp_comm_hook=powerSGD.powerSGD_hook,
... )
>>> # post_localSGD_hook
>>> subgroup, _ = torch.distributed.new_subgroups() # doctest: +SKIP
>>> ddp_model = ...
>>> _register_ddp_comm_hook( # doctest: +SKIP
... model=ddp_model,
... ddp_comm_state=post_localSGD.PostLocalSGDState(
... process_group=None,
... subgroup=subgroup,
... start_localSGD_iter=1_000,
... ),
... ddp_comm_hook=post_localSGD.post_localSGD_hook,
... )
>>> # fp16_compress_wrapper combined with other communication hook
>>> ddp_model = ...
>>> _register_ddp_comm_hook( # doctest: +SKIP
... model=ddp_model,
... ddp_comm_state=powerSGD.PowerSGDState(
... process_group=None,
... matrix_approximation_rank=1,
... start_powerSGD_iter=5000,
... ),
... ddp_comm_hook=powerSGD.powerSGD_hook,
... ddp_comm_wrapper=default.fp16_compress_wrapper,
... )
"""
if ddp_comm_hook is None:
return
# inform mypy that ddp_comm_hook is callable
ddp_comm_hook: Callable = ddp_comm_hook
if ddp_comm_wrapper is not None:
rank_zero_info(
f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
)
ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)
rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
def _sync_module_states(module: torch.nn.Module) -> None:
"""Taken from https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/parallel/distributed.py#L675-L682."""
parameters_to_ignore = (
set(module._ddp_params_and_buffers_to_ignore) # type: ignore[arg-type]
if hasattr(module, "_ddp_params_and_buffers_to_ignore")
else set()
)
from torch.distributed.distributed_c10d import _get_default_group
if not _TORCH_GREATER_EQUAL_1_12:
module_states = []
for name, param in module.named_parameters():
if name not in parameters_to_ignore:
module_states.append(param.detach())
for name, buffer in module.named_buffers():
if name not in parameters_to_ignore:
module_states.append(buffer.detach())
if len(module_states) > 0:
torch.distributed._broadcast_coalesced(_get_default_group(), module_states, 250 * 1024 * 1024, 0)
return
from torch.distributed.utils import _sync_module_states as torch_sync_module_states
torch_sync_module_states(
module,
_get_default_group(),
250 * 1024 * 1024,
src=0,
params_and_buffers_to_ignore=parameters_to_ignore,
)
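A brief usage sketch for `_sync_module_states` (hedged: it assumes `torch.distributed` was already initialized by the launcher, e.g. via `init_process_group`):

# Illustrative usage: every rank builds its own randomly initialized module,
# then adopts rank 0's parameters and buffers. This mirrors what the DDP
# strategy now does for non-fit stages, where the module is never wrapped in
# DistributedDataParallel.
def _sync_example() -> None:
    module = torch.nn.Linear(4, 2)
    _sync_module_states(module)  # afterwards, all ranks hold rank 0's weights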
class UnrepeatedDistributedSampler(DistributedSampler):
"""A fork of the PyTorch DistributedSampler that doesn't repeat data, instead allowing the number of batches
per process to be off-by-one from each other. This makes this sampler usable for predictions (it's

@@ -38,13 +38,12 @@ from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import ReduceOp
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from lightning.pytorch.overrides.distributed import prepare_for_backward
from lightning.pytorch.overrides.distributed import _register_ddp_comm_hook, _sync_module_states, prepare_for_backward
from lightning.pytorch.plugins.precision import PrecisionPlugin
from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher
from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.strategy import TBroadcast
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.distributed import _register_ddp_comm_hook
from lightning.pytorch.utilities.exceptions import _augment_message
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only
from lightning.pytorch.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
@@ -153,25 +152,28 @@ class DDPStrategy(ParallelStrategy):
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = trainer.state.fn
if trainer_fn == TrainerFn.FITTING:
if self._layer_sync:
assert self.model is not None
self.model = self._layer_sync.apply(self.model)
if trainer_fn == TrainerFn.FITTING and self._layer_sync:
assert self.model is not None
self.model = self._layer_sync.apply(self.model)
self.setup_precision_plugin()
if trainer_fn == TrainerFn.FITTING:
# do not wrap with DDP if not fitting as there's no gradients to reduce
self.configure_ddp()
# set up optimizers after the wrapped module has been moved to the device
self.setup_optimizers(trainer)
_optimizers_to_device(self.optimizers, self.root_device)
if trainer_fn == TrainerFn.FITTING:
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
self._enable_model_averaging()
else:
# we need to manually synchronize the module's states since we aren't using the DDP wrapper
assert self.model is not None
_sync_module_states(self.model)
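For context, the non-fit branch above is what any multi-device evaluation or prediction run goes through. A hedged end-to-end sketch (the module, dummy batch, and Trainer flags are illustrative only):

# Illustrative sketch of a run that exercises the non-fit path: the module is
# never wrapped in DistributedDataParallel, so `_sync_module_states` broadcasts
# rank 0's parameters and buffers before evaluation starts.
import torch
from lightning.pytorch import LightningModule, Trainer

class _TinyModule(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def validation_step(self, batch, batch_idx):
        # every rank sees rank 0's initial weights here
        return self.layer(batch).sum()

trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp")
trainer.validate(_TinyModule(), [torch.zeros(1, 1)])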
def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""

@@ -1,144 +0,0 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities that can be used with distributed training."""
from typing import Any, Callable, Dict, Optional
import torch
from torch.nn.parallel.distributed import DistributedDataParallel
from lightning.fabric.utilities.distributed import _distributed_available as new_distributed_available
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info
def _register_ddp_comm_hook(
model: DistributedDataParallel,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[Callable] = None,
ddp_comm_wrapper: Optional[Callable] = None,
) -> None:
"""Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html.
Args:
model:
DDP model
ddp_comm_state:
state is passed to the hook and can be used to maintain
and update any state information that users would like to
maintain as part of the training process. Examples: error
feedback in gradient compression, peers to communicate with
next in GossipGrad etc.
ddp_comm_hook:
hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future
This callable function is called once the bucket is ready. The
hook can perform whatever processing is needed and return
a Future indicating completion of any async work (ex: allreduce).
If the hook doesn't perform any communication, it can also
just return a completed Future. The Future should hold the
new value of grad bucket's tensors. Once a bucket is ready,
c10d reducer would call this hook and use the tensors returned
by the Future and copy grads to individual parameters.
ddp_comm_wrapper:
communication hook wrapper to support a communication hook such
as FP16 compression as wrapper, which could be combined with
ddp_comm_hook
Examples:
>>> from torch.distributed.algorithms.ddp_comm_hooks import ( # doctest: +SKIP
... default_hooks as default,
... powerSGD_hook as powerSGD,
... post_localSGD_hook as post_localSGD,
... )
>>>
>>> # fp16_compress_hook for compress gradients
>>> ddp_model = ...
>>> _register_ddp_comm_hook( # doctest: +SKIP
... model=ddp_model,
... ddp_comm_hook=default.fp16_compress_hook,
... )
>>>
>>> # powerSGD_hook
>>> ddp_model = ...
>>> _register_ddp_comm_hook( # doctest: +SKIP
... model=ddp_model,
... ddp_comm_state=powerSGD.PowerSGDState(
... process_group=None,
... matrix_approximation_rank=1,
... start_powerSGD_iter=5000,
... ),
... ddp_comm_hook=powerSGD.powerSGD_hook,
... )
>>>
>>> # post_localSGD_hook
>>> subgroup, _ = torch.distributed.new_subgroups() # doctest: +SKIP
>>> ddp_model = ...
>>> _register_ddp_comm_hook( # doctest: +SKIP
... model=ddp_model,
... state=post_localSGD.PostLocalSGDState(
... process_group=None,
... subgroup=subgroup,
... start_localSGD_iter=1_000,
... ),
... ddp_comm_hook=post_localSGD.post_localSGD_hook,
... )
>>>
>>> # fp16_compress_wrapper combined with other communication hook
>>> ddp_model = ...
>>> _register_ddp_comm_hook( # doctest: +SKIP
... model=ddp_model,
... ddp_comm_state=powerSGD.PowerSGDState(
... process_group=None,
... matrix_approximation_rank=1,
... start_powerSGD_iter=5000,
... ),
... ddp_comm_hook=powerSGD.powerSGD_hook,
... ddp_comm_wrapper=default.fp16_compress_wrapper,
... )
"""
if ddp_comm_hook is None:
return
# inform mypy that ddp_comm_hook is callable
ddp_comm_hook: Callable = ddp_comm_hook
if ddp_comm_wrapper is not None:
rank_zero_info(
f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
)
ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)
rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
def _broadcast_object_list(obj: Any, rank: int) -> Any:
objects = [obj if torch.distributed.get_rank() == rank else None]
torch.distributed.broadcast_object_list(objects, src=rank)
return objects[0]
# TODO: Refactor with the Strategy Collectives once finalized.
def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]:
"""This distributed utility collects dictionary state across all processes.
Args:
state: Dictionary containing the state of the current process
Returns:
states: On global rank 0, a dictionary where the primary keys are
the process rank and the values their associated states. Otherwise, returns None.
"""
if not new_distributed_available():
return {0: state}
return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())}
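The removed helper has no direct replacement in this commit. If equivalent behaviour is still needed, a hedged sketch built on stock `torch.distributed` object collectives (assuming an initialized process group) could look like:

# Illustrative alternative to the removed `_collect_states_on_rank_zero`.
import torch.distributed as dist

def _collect_states_sketch(state):
    gathered = [None] * dist.get_world_size()
    # gathers every rank's state on all ranks; `dist.gather_object` could be
    # used instead if only rank 0 needs the result
    dist.all_gather_object(gathered, state)
    return dict(enumerate(gathered))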

@@ -14,11 +14,39 @@
from typing import Iterable
import pytest
import torch
from torch.utils.data import BatchSampler, SequentialSampler
from lightning.fabric.utilities.data import has_len
from lightning.pytorch import seed_everything
from lightning.pytorch import LightningModule, seed_everything, Trainer
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
from tests_pytorch.helpers.runif import RunIf
class MyModel(LightningModule):
def setup(self, stage: str) -> None:
self.layer = torch.nn.Linear(1, 1)
weights = self.layer.weight.item(), self.layer.bias.item()
self.rank_0_weights = self.trainer.strategy.broadcast(weights)
def test_step(self, batch, batch_idx):
current = self.layer.weight.item(), self.layer.bias.item()
assert self.rank_0_weights == current
gathered = self.all_gather(current)
# the weights have been synced
assert all(torch.all(t == t[0]) for t in gathered), gathered
@RunIf(standalone=True)
def test_params_synced_during_nonfit():
model = MyModel()
trainer = Trainer(
barebones=True,
devices=2,
accelerator="cpu",
strategy="ddp",
)
trainer.test(model, [0])
@pytest.mark.parametrize("shuffle", [False, True])

@@ -1,36 +0,0 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from lightning.pytorch.utilities.distributed import _collect_states_on_rank_zero
from tests_pytorch.core.test_results import spawn_launch
from tests_pytorch.helpers.runif import RunIf
def collect_states_fn(strategy):
rank = strategy.local_rank
state = {"something": torch.tensor([rank])}
collected_state = _collect_states_on_rank_zero(state)
assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}}
@RunIf(min_cuda_gpus=2, skip_windows=True)
def test_collect_states():
"""This test ensures state are properly collected across processes.
This would be used to collect dataloader states as an example.
"""
spawn_launch(collect_states_fn, [torch.device("cuda:0"), torch.device("cuda:1")])