diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst
index 497ac6d69f..3be22f3b54 100644
--- a/docs/source-pytorch/api_references.rst
+++ b/docs/source-pytorch/api_references.rst
@@ -243,7 +243,6 @@ utilities
     combined_loader
     data
     deepspeed
-    distributed
     memory
     model_summary
     parsing
diff --git a/src/lightning/pytorch/overrides/distributed.py b/src/lightning/pytorch/overrides/distributed.py
index 75657b060f..1531e80e22 100644
--- a/src/lightning/pytorch/overrides/distributed.py
+++ b/src/lightning/pytorch/overrides/distributed.py
@@ -12,14 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, cast, Dict, Iterable, Iterator, List, Optional, Sized, Union
+from typing import Any, Callable, cast, Dict, Iterable, Iterator, List, Optional, Sized, Union

 import torch
 from torch import Tensor
-from torch.nn.parallel import DistributedDataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.utils.data import BatchSampler, DistributedSampler, Sampler

 from lightning.fabric.utilities.distributed import _DatasetSamplerWrapper
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info


 def _find_tensors(
@@ -37,7 +39,7 @@ def _find_tensors(

 # In manual_optimization, we need to call reducer prepare_for_backward.
 # Note: Keep track of PyTorch DDP and update if there is a change
-# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
+# https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/parallel/distributed.py#L1163-L1178
 def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
     # `prepare_for_backward` is `DistributedDataParallel` specific.
     if torch.is_grad_enabled() and model.require_backward_grad_sync:
@@ -47,7 +49,7 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
         # because we need to figure out which parameters were used during
         # this forward pass, to ensure we short circuit reduction for any
         # unused parameters. Only if `find_unused_parameters` is set.
-        args = list(_find_tensors(output)) if model.find_unused_parameters else []
+        args = list(_find_tensors(output)) if model.find_unused_parameters and not model.static_graph else []
         reducer = cast(torch._C._distributed_c10d.Reducer, model.reducer)
         reducer._rebuild_buckets()  # avoids "INTERNAL ASSERT FAILED" with `find_unused_parameters=False`
         reducer.prepare_for_backward(args)
@@ -55,6 +57,135 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
         model.require_forward_param_sync = False


+def _register_ddp_comm_hook(
+    model: DistributedDataParallel,
+    ddp_comm_state: Optional[object] = None,
+    ddp_comm_hook: Optional[Callable] = None,
+    ddp_comm_wrapper: Optional[Callable] = None,
+) -> None:
+    """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html.
+
+    Args:
+        model:
+            DDP model
+        ddp_comm_state:
+            state is passed to the hook and can be used to maintain
+            and update any state information that users would like to
+            maintain as part of the training process. Examples: error
+            feedback in gradient compression, peers to communicate with
+            next in GossipGrad etc.
+        ddp_comm_hook:
+            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future
+
+            This callable function is called once the bucket is ready. The
+            hook can perform whatever processing is needed and return
+            a Future indicating completion of any async work (ex: allreduce).
+            If the hook doesn't perform any communication, it can also
+            just return a completed Future. The Future should hold the
+            new value of grad bucket's tensors. Once a bucket is ready,
+            c10d reducer would call this hook and use the tensors returned
+            by the Future and copy grads to individual parameters.
+        ddp_comm_wrapper:
+            communication hook wrapper to support a communication hook such
+            as FP16 compression as wrapper, which could be combined with
+            ddp_comm_hook
+
+    Examples:
+
+        >>> from torch.distributed.algorithms.ddp_comm_hooks import (  # doctest: +SKIP
+        ...     default_hooks as default,
+        ...     powerSGD_hook as powerSGD,
+        ...     post_localSGD_hook as post_localSGD,
+        ... )
+        >>> # fp16_compress_hook for compressing gradients
+        >>> ddp_model = ...
+        >>> _register_ddp_comm_hook(  # doctest: +SKIP
+        ...     model=ddp_model,
+        ...     ddp_comm_hook=default.fp16_compress_hook,
+        ... )
+        >>> # powerSGD_hook
+        >>> ddp_model = ...
+        >>> _register_ddp_comm_hook(  # doctest: +SKIP
+        ...     model=ddp_model,
+        ...     ddp_comm_state=powerSGD.PowerSGDState(
+        ...         process_group=None,
+        ...         matrix_approximation_rank=1,
+        ...         start_powerSGD_iter=5000,
+        ...     ),
+        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
+        ... )
+        >>> # post_localSGD_hook
+        >>> subgroup, _ = torch.distributed.new_subgroups()  # doctest: +SKIP
+        >>> ddp_model = ...
+        >>> _register_ddp_comm_hook(  # doctest: +SKIP
+        ...     model=ddp_model,
+        ...     ddp_comm_state=post_localSGD.PostLocalSGDState(
+        ...         process_group=None,
+        ...         subgroup=subgroup,
+        ...         start_localSGD_iter=1_000,
+        ...     ),
+        ...     ddp_comm_hook=post_localSGD.post_localSGD_hook,
+        ... )
+        >>> # fp16_compress_wrapper combined with other communication hook
+        >>> ddp_model = ...
+        >>> _register_ddp_comm_hook(  # doctest: +SKIP
+        ...     model=ddp_model,
+        ...     ddp_comm_state=powerSGD.PowerSGDState(
+        ...         process_group=None,
+        ...         matrix_approximation_rank=1,
+        ...         start_powerSGD_iter=5000,
+        ...     ),
+        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
+        ...     ddp_comm_wrapper=default.fp16_compress_wrapper,
+        ... )
+    """
+    if ddp_comm_hook is None:
+        return
+    # inform mypy that ddp_comm_hook is callable
+    ddp_comm_hook: Callable = ddp_comm_hook
+
+    if ddp_comm_wrapper is not None:
+        rank_zero_info(
+            f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
+        )
+        ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)
+
+    rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
+    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
+
+
+def _sync_module_states(module: torch.nn.Module) -> None:
+    """Taken from https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/parallel/distributed.py#L675-L682."""
+    parameters_to_ignore = (
+        set(module._ddp_params_and_buffers_to_ignore)  # type: ignore[arg-type]
+        if hasattr(module, "_ddp_params_and_buffers_to_ignore")
+        else set()
+    )
+    from torch.distributed.distributed_c10d import _get_default_group
+
+    if not _TORCH_GREATER_EQUAL_1_12:
+        module_states = []
+        for name, param in module.named_parameters():
+            if name not in parameters_to_ignore:
+                module_states.append(param.detach())
+        for name, buffer in module.named_buffers():
+            if name not in parameters_to_ignore:
+                module_states.append(buffer.detach())
+        if len(module_states) > 0:
+            torch.distributed._broadcast_coalesced(_get_default_group(), module_states, 250 * 1024 * 1024, 0)
+        return
+
+    from torch.distributed.utils import _sync_module_states as torch_sync_module_states
+
+    torch_sync_module_states(
+        module,
+        _get_default_group(),
+        250 * 1024 * 1024,
+        src=0,
+        params_and_buffers_to_ignore=parameters_to_ignore,
+    )
+
+
 class UnrepeatedDistributedSampler(DistributedSampler):
     """A fork of the PyTorch DistributedSampler that doesn't repeat data, instead allowing the number of batches
     per process to be off-by-one from each other. This makes this sampler usable for predictions (it's
diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index 5b783ccdc3..dffb364449 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -38,13 +38,12 @@ from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import ReduceOp
 from lightning.pytorch.core.optimizer import LightningOptimizer
 from lightning.pytorch.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
-from lightning.pytorch.overrides.distributed import prepare_for_backward
+from lightning.pytorch.overrides.distributed import _register_ddp_comm_hook, _sync_module_states, prepare_for_backward
 from lightning.pytorch.plugins.precision import PrecisionPlugin
 from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher
 from lightning.pytorch.strategies.parallel import ParallelStrategy
 from lightning.pytorch.strategies.strategy import TBroadcast
 from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.distributed import _register_ddp_comm_hook
 from lightning.pytorch.utilities.exceptions import _augment_message
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only
 from lightning.pytorch.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
@@ -153,25 +152,28 @@ class DDPStrategy(ParallelStrategy):
         # skip wrapping the model if we are not fitting as no gradients need to be exchanged
         trainer_fn = trainer.state.fn

-        if trainer_fn == TrainerFn.FITTING:
-            if self._layer_sync:
-                assert self.model is not None
-                self.model = self._layer_sync.apply(self.model)
+        if trainer_fn == TrainerFn.FITTING and self._layer_sync:
+            assert self.model is not None
+            self.model = self._layer_sync.apply(self.model)

         self.setup_precision_plugin()

         if trainer_fn == TrainerFn.FITTING:
+            # do not wrap with DDP if not fitting as there are no gradients to reduce
             self.configure_ddp()

             # set up optimizers after the wrapped module has been moved to the device
             self.setup_optimizers(trainer)
             _optimizers_to_device(self.optimizers, self.root_device)

-        if trainer_fn == TrainerFn.FITTING:
             import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

             if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
                 self._enable_model_averaging()
+        else:
+            # we need to manually synchronize the module's states since we aren't using the DDP wrapper
+            assert self.model is not None
+            _sync_module_states(self.model)

     def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
diff --git a/src/lightning/pytorch/utilities/distributed.py b/src/lightning/pytorch/utilities/distributed.py
deleted file mode 100644
index 2e7afffb7b..0000000000
--- a/src/lightning/pytorch/utilities/distributed.py
+++ /dev/null
@@ -1,144 +0,0 @@
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Utilities that can be used with distributed training."""
-
-from typing import Any, Callable, Dict, Optional
-
-import torch
-from torch.nn.parallel.distributed import DistributedDataParallel
-
-from lightning.fabric.utilities.distributed import _distributed_available as new_distributed_available
-from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info
-
-
-def _register_ddp_comm_hook(
-    model: DistributedDataParallel,
-    ddp_comm_state: Optional[object] = None,
-    ddp_comm_hook: Optional[Callable] = None,
-    ddp_comm_wrapper: Optional[Callable] = None,
-) -> None:
-    """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html.
-
-    Args:
-        model:
-            DDP model
-        ddp_comm_state:
-            state is passed to the hook and can be used to maintain
-            and update any state information that users would like to
-            maintain as part of the training process. Examples: error
-            feedback in gradient compression, peers to communicate with
-            next in GossipGrad etc.
-        ddp_comm_hook:
-            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future
-
-            This callable function is called once the bucket is ready. The
-            hook can perform whatever processing is needed and return
-            a Future indicating completion of any async work (ex: allreduce).
-            If the hook doesn't perform any communication, it can also
-            just return a completed Future. The Future should hold the
-            new value of grad bucket's tensors. Once a bucket is ready,
-            c10d reducer would call this hook and use the tensors returned
-            by the Future and copy grads to individual parameters.
-        ddp_comm_wrapper:
-            communication hook wrapper to support a communication hook such
-            as FP16 compression as wrapper, which could be combined with
-            ddp_comm_hook
-
-    Examples:
-
-        >>> from torch.distributed.algorithms.ddp_comm_hooks import (  # doctest: +SKIP
-        ...     default_hooks as default,
-        ...     powerSGD_hook as powerSGD,
-        ...     post_localSGD_hook as post_localSGD,
-        ... )
-        >>>
-        >>> # fp16_compress_hook for compress gradients
-        >>> ddp_model = ...
-        >>> _register_ddp_comm_hook(  # doctest: +SKIP
-        ...     model=ddp_model,
-        ...     ddp_comm_hook=default.fp16_compress_hook,
-        ... )
-        >>>
-        >>> # powerSGD_hook
-        >>> ddp_model = ...
-        >>> _register_ddp_comm_hook(  # doctest: +SKIP
-        ...     model=ddp_model,
-        ...     ddp_comm_state=powerSGD.PowerSGDState(
-        ...         process_group=None,
-        ...         matrix_approximation_rank=1,
-        ...         start_powerSGD_iter=5000,
-        ...     ),
-        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
-        ... )
-        >>>
-        >>> # post_localSGD_hook
-        >>> subgroup, _ = torch.distributed.new_subgroups()  # doctest: +SKIP
-        >>> ddp_model = ...
-        >>> _register_ddp_comm_hook(  # doctest: +SKIP
-        ...     model=ddp_model,
-        ...     state=post_localSGD.PostLocalSGDState(
-        ...         process_group=None,
-        ...         subgroup=subgroup,
-        ...         start_localSGD_iter=1_000,
-        ...     ),
-        ...     ddp_comm_hook=post_localSGD.post_localSGD_hook,
-        ... )
-        >>>
-        >>> # fp16_compress_wrapper combined with other communication hook
-        >>> ddp_model = ...
-        >>> _register_ddp_comm_hook(  # doctest: +SKIP
-        ...     model=ddp_model,
-        ...     ddp_comm_state=powerSGD.PowerSGDState(
-        ...         process_group=None,
-        ...         matrix_approximation_rank=1,
-        ...         start_powerSGD_iter=5000,
-        ...     ),
-        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
-        ...     ddp_comm_wrapper=default.fp16_compress_wrapper,
-        ... )
-    """
-    if ddp_comm_hook is None:
-        return
-    # inform mypy that ddp_comm_hook is callable
-    ddp_comm_hook: Callable = ddp_comm_hook
-
-    if ddp_comm_wrapper is not None:
-        rank_zero_info(
-            f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
-        )
-        ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)
-
-    rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
-    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
-
-
-def _broadcast_object_list(obj: Any, rank: int) -> Any:
-    objects = [obj if torch.distributed.get_rank() == rank else None]
-    torch.distributed.broadcast_object_list(objects, src=rank)
-    return objects[0]
-
-
-# TODO: Refactor with the Strategy Collectives once finalized.
-def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]:
-    """This distributed utility collects dictionary state across all processes.
-
-    Args:
-        state: Dictionary containing the state of the current process
-
-    Returns:
-        states: On global rank 0, a dictionary where the primary keys are
-            the process rank and the values their associated states. Otherwise, returns None.
-    """
-    if not new_distributed_available():
-        return {0: state}
-    return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())}
diff --git a/tests/tests_pytorch/overrides/test_distributed.py b/tests/tests_pytorch/overrides/test_distributed.py
index 4a1057b1b0..1189ca9c23 100644
--- a/tests/tests_pytorch/overrides/test_distributed.py
+++ b/tests/tests_pytorch/overrides/test_distributed.py
@@ -14,11 +14,39 @@
 from typing import Iterable

 import pytest
+import torch
 from torch.utils.data import BatchSampler, SequentialSampler

 from lightning.fabric.utilities.data import has_len
-from lightning.pytorch import seed_everything
+from lightning.pytorch import LightningModule, seed_everything, Trainer
 from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
+from tests_pytorch.helpers.runif import RunIf
+
+
+class MyModel(LightningModule):
+    def setup(self, stage: str) -> None:
+        self.layer = torch.nn.Linear(1, 1)
+        weights = self.layer.weight.item(), self.layer.bias.item()
+        self.rank_0_weights = self.trainer.strategy.broadcast(weights)
+
+    def test_step(self, batch, batch_idx):
+        current = self.layer.weight.item(), self.layer.bias.item()
+        assert self.rank_0_weights == current
+        gathered = self.all_gather(current)
+        # the weights have been synced
+        assert all(torch.all(t == t[0]) for t in gathered), gathered
+
+
+@RunIf(standalone=True)
+def test_params_synced_during_nonfit():
+    model = MyModel()
+    trainer = Trainer(
+        barebones=True,
+        devices=2,
+        accelerator="cpu",
+        strategy="ddp",
+    )
+    trainer.test(model, [0])


 @pytest.mark.parametrize("shuffle", [False, True])
diff --git a/tests/tests_pytorch/utilities/test_distributed.py b/tests/tests_pytorch/utilities/test_distributed.py
deleted file mode 100644
index 15667cae50..0000000000
--- a/tests/tests_pytorch/utilities/test_distributed.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.distributed
-
-from lightning.pytorch.utilities.distributed import _collect_states_on_rank_zero
-from tests_pytorch.core.test_results import spawn_launch
-from tests_pytorch.helpers.runif import RunIf
-
-
-def collect_states_fn(strategy):
-    rank = strategy.local_rank
-    state = {"something": torch.tensor([rank])}
-    collected_state = _collect_states_on_rank_zero(state)
-    assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}}
-
-
-@RunIf(min_cuda_gpus=2, skip_windows=True)
-def test_collect_states():
-    """This test ensures state are properly collected across processes.
-
-    This would be used to collect dataloader states as an example.
-    """
-    spawn_launch(collect_states_fn, [torch.device("cuda:0"), torch.device("cuda:1")])
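
Usage note (not part of the diff above): the relocated `_register_ddp_comm_hook` is normally reached through the public `DDPStrategy(ddp_comm_state=..., ddp_comm_hook=..., ddp_comm_wrapper=...)` arguments rather than called directly, and the new `_sync_module_states` call only matters for `validate`/`test`/`predict`, where the module is no longer wrapped in DistributedDataParallel. Below is a minimal sketch of that user-facing path, assuming a machine with 2 GPUs; `ToyModel` and the random dataset are illustrative placeholders, not code from this change.

    import torch
    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
    from torch.utils.data import DataLoader, TensorDataset

    from lightning.pytorch import LightningModule, Trainer
    from lightning.pytorch.strategies import DDPStrategy


    class ToyModel(LightningModule):
        # placeholder module for illustration only
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def test_step(self, batch, batch_idx):
            x, y = batch
            self.log("test_loss", torch.nn.functional.mse_loss(self.layer(x), y))

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    if __name__ == "__main__":
        data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 2)), batch_size=8)
        # DDPStrategy forwards these arguments to the relocated _register_ddp_comm_hook when it wraps the
        # module for fitting; fp16_compress_hook casts gradient buckets to half precision for the all-reduce.
        trainer = Trainer(
            accelerator="gpu",
            devices=2,
            strategy=DDPStrategy(ddp_comm_hook=default.fp16_compress_hook),
            max_epochs=1,
            logger=False,
        )
        trainer.fit(ToyModel(), data)
        # During test/validate/predict the module is not wrapped in DistributedDataParallel; with this change,
        # DDPStrategy.setup() calls _sync_module_states so every rank evaluates rank 0's parameters and buffers.
        trainer.test(ToyModel(), data)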