Remove the deprecated `AllGatherGrad` class (#16360)

Carlos Mocholí 2023-01-16 16:31:30 +01:00 committed by Luca Antiga
parent cf0952b25e
commit 5d648e4d77
4 changed files with 3 additions and 47 deletions


@@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
  * Removed the `Trainer(ipus=...)` argument
  * Removed the `Trainer(num_processes=...)` argument
- Removed the deprecated `pytorch_lightning.utilities.AllGatherGrad` class ([#16360](https://github.com/Lightning-AI/lightning/pull/16360))
- Removed the deprecated `resume_from_checkpoint` Trainer argument ([#16167](https://github.com/Lightning-AI/lightning/pull/16167))
- Removed the deprecated automatic GPU selection ([#16184](https://github.com/Lightning-AI/lightning/pull/16184))


@@ -17,7 +17,6 @@ import numpy
from lightning_fabric.utilities import LightningEnum # noqa: F401
from lightning_fabric.utilities import move_data_to_device # noqa: F401
from pytorch_lightning.utilities.distributed import AllGatherGrad # noqa: F401
from pytorch_lightning.utilities.enums import GradClipAlgorithmType # noqa: F401
from pytorch_lightning.utilities.grads import grad_norm # noqa: F401
from pytorch_lightning.utilities.imports import ( # noqa: F401


@@ -12,10 +12,9 @@
# limitations under the License.
"""Utilities that can be used with distributed training."""
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional
import torch
from torch import Tensor
from torch.nn.parallel.distributed import DistributedDataParallel
from lightning_fabric.utilities.distributed import _all_gather_ddp_if_available as new_all_gather_ddp_if_available
@@ -177,40 +176,6 @@ def gather_all_tensors(*args: Any, **kwargs: Any) -> Any:
    return new_gather_all_tensors(*args, **kwargs)


class AllGatherGrad(torch.autograd.Function):
    """Gathers tensors from the whole group and stacks them.

    This implementation is copied from PyTorch.

    .. deprecated:: v1.8.0
        This function has been deprecated in v1.8.0 in favor of :func:`torch.distributed.nn.functional.all_gather` and
        will be removed in v2.0.0.
    """

    @staticmethod
    def forward(  # type: ignore[override]
        ctx: Any,
        tensor: Tensor,
        group: Optional["torch.distributed.ProcessGroup"] = None,
    ) -> Tensor:
        rank_zero_deprecation(
            "`AllGatherGrad` has been deprecated in v1.8.0 and will be removed in v2.0.0."
            " Use `torch.distributed.nn.functional.all_gather` instead.",
            stacklevel=6,
        )
        ctx.group = group

        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]

        torch.distributed.all_gather(gathered_tensor, tensor, group=group)
        gathered_tensor = torch.stack(gathered_tensor, dim=0)

        return gathered_tensor

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None]:
        grad_output = torch.cat(grad_output)

        torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)

        return grad_output[torch.distributed.get_rank()], None


def get_default_process_group_backend_for_device(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "`pytorch_lightning.utilities.distributed.get_default_process_group_backend_for_device` has been deprecated"

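For reference, the replacement named in the deprecation message, `torch.distributed.nn.functional.all_gather`, keeps autograd history on the gathered tensors, so the manual all_reduce of gradients done in `AllGatherGrad.backward` is no longer needed. A minimal migration sketch, assuming an initialized process group; the helper name `all_gather_and_stack` is illustrative and not part of this commit:

import torch
from torch.distributed.nn.functional import all_gather  # autograd-aware collective


def all_gather_and_stack(tensor: torch.Tensor, group=None) -> torch.Tensor:
    # One tensor per rank, with gradient history preserved.
    gathered = all_gather(tensor, group=group)
    # Stack along a new leading dimension to match the output shape of AllGatherGrad.apply(tensor, group).
    return torch.stack(gathered, dim=0)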

@@ -315,16 +315,6 @@ def test_tuning_trainer_property():
        trainer.tuning = True


@RunIf(skip_windows=True)
def test_v1_8_0_deprecated_all_gather_grad():
    tensor1 = torch.ones(1, requires_grad=True)
    with mock.patch("torch.distributed.all_gather"), mock.patch("torch.distributed.get_world_size", return_value=1):
        from pytorch_lightning.utilities import AllGatherGrad

        with pytest.deprecated_call(match="`AllGatherGrad` has been deprecated in v1.8"):
            AllGatherGrad.apply(tensor1)


def test_v1_8_1_deprecated_rank_zero_only():
    from pytorch_lightning.utilities.distributed import rank_zero_only
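The removed test above mocked the collectives so it could run in a single process. A roughly equivalent single-rank smoke test against the replacement API might look as follows; this is a sketch under the assumption that the gloo backend is available and the address/port values are free to use (the test name is ours, not part of this commit):

import os

import torch
import torch.distributed as dist
from torch.distributed.nn.functional import all_gather


def test_functional_all_gather_single_rank():
    # One-process gloo group so no distributed launcher is needed.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    try:
        tensor1 = torch.ones(1, requires_grad=True)
        gathered = torch.stack(all_gather(tensor1), dim=0)  # shape: (world_size, 1)
        assert torch.equal(gathered, torch.ones(1, 1))
        gathered.sum().backward()  # gradients flow back through the collective
        assert torch.equal(tensor1.grad, torch.ones(1))
    finally:
        dist.destroy_process_group()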