Remove the deprecated `AllGatherGrad` class (#16360)
parent cf0952b25e
commit 5d648e4d77
```diff
@@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Removed the `Trainer(ipus=...)` argument
 * Removed the `Trainer(num_processes=...)` argument
 
+- Removed the deprecated `pytorch_lightning.utilities.AllGatherGrad` class ([#16360](https://github.com/Lightning-AI/lightning/pull/16360))
+
 - Removed the deprecated `resume_from_checkpoint` Trainer argument ([#16167](https://github.com/Lightning-AI/lightning/pull/16167))
 
 - Removed the deprecated automatic GPU selection ([#16184](https://github.com/Lightning-AI/lightning/pull/16184))
```
```diff
@@ -17,7 +17,6 @@ import numpy
 from lightning_fabric.utilities import LightningEnum  # noqa: F401
 from lightning_fabric.utilities import move_data_to_device  # noqa: F401
-from pytorch_lightning.utilities.distributed import AllGatherGrad  # noqa: F401
 from pytorch_lightning.utilities.enums import GradClipAlgorithmType  # noqa: F401
 from pytorch_lightning.utilities.grads import grad_norm  # noqa: F401
 from pytorch_lightning.utilities.imports import (  # noqa: F401
```
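With the re-export gone, downstream code has to switch to the replacement that the deprecation message pointed to. A minimal sketch of the import-level migration (the old path is shown only as a comment; `all_gather` here is PyTorch's differentiable collective, not a Lightning symbol):

```python
# Before (removed by this commit):
#   from pytorch_lightning.utilities import AllGatherGrad
# After, per the deprecation message:
from torch.distributed.nn.functional import all_gather
```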
```diff
@@ -12,10 +12,9 @@
 # limitations under the License.
 """Utilities that can be used with distributed training."""
 
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional
 
 import torch
-from torch import Tensor
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 from lightning_fabric.utilities.distributed import _all_gather_ddp_if_available as new_all_gather_ddp_if_available
```
```diff
@@ -177,40 +176,6 @@ def gather_all_tensors(*args: Any, **kwargs: Any) -> Any:
     return new_gather_all_tensors(*args, **kwargs)
 
 
-class AllGatherGrad(torch.autograd.Function):
-    """Gathers tensors from the whole group and stacks them.
-
-    This implementation is copied from PyTorch.
-
-    .. deprecated:: v1.8.0
-        This function has been deprecated in v1.8.0 in favor of :func:`torch.distributed.nn.functional.all_gather` and
-        will be removed in v2.0.0.
-    """
-
-    @staticmethod
-    def forward(  # type: ignore[override]
-        ctx: Any,
-        tensor: Tensor,
-        group: Optional["torch.distributed.ProcessGroup"] = None,
-    ) -> Tensor:
-        rank_zero_deprecation(
-            "`AllGatherGrad` has been deprecated in v1.8.0 and will be removed in v2.0.0."
-            " Use `torch.distributed.nn.functional.all_gather` instead.",
-            stacklevel=6,
-        )
-        ctx.group = group
-        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
-        torch.distributed.all_gather(gathered_tensor, tensor, group=group)
-        gathered_tensor = torch.stack(gathered_tensor, dim=0)
-        return gathered_tensor
-
-    @staticmethod
-    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None]:
-        grad_output = torch.cat(grad_output)
-        torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
-        return grad_output[torch.distributed.get_rank()], None
-
-
 def get_default_process_group_backend_for_device(*args: Any, **kwargs: Any) -> Any:
     rank_zero_deprecation(
         "`pytorch_lightning.utilities.distributed.get_default_process_group_backend_for_device` has been deprecated"
```
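For callers migrating off the removed class: `AllGatherGrad.apply(tensor, group)` returned a `(world_size, *tensor.shape)` stack and all-reduced the incoming gradients in `backward`. Below is a minimal sketch of an equivalent helper built on the replacement named in the deprecation message; the name `all_gather_grad` is hypothetical, and an initialized default process group is assumed:

```python
from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.distributed.nn.functional import all_gather  # autograd-aware collective


def all_gather_grad(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
    """Hypothetical stand-in for the removed ``AllGatherGrad.apply``.

    Gathers ``tensor`` from every rank and stacks the results along a new
    leading dimension, so the output shape is ``(world_size, *tensor.shape)``.
    Gradients flow back to the local ``tensor`` through the collective.
    """
    gathered = all_gather(tensor, group=group)  # tuple of per-rank tensors
    return torch.stack(gathered, dim=0)
```

Under DDP this should behave like the removed class for typical uses (e.g. gathering per-rank embeddings for a contrastive loss), though the exact backward implementation differs from the hand-rolled all-reduce shown above.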
```diff
@@ -315,16 +315,6 @@ def test_tuning_trainer_property():
     trainer.tuning = True
 
 
-@RunIf(skip_windows=True)
-def test_v1_8_0_deprecated_all_gather_grad():
-    tensor1 = torch.ones(1, requires_grad=True)
-    with mock.patch("torch.distributed.all_gather"), mock.patch("torch.distributed.get_world_size", return_value=1):
-        from pytorch_lightning.utilities import AllGatherGrad
-
-        with pytest.deprecated_call(match="`AllGatherGrad` has been deprecated in v1.8"):
-            AllGatherGrad.apply(tensor1)
-
-
 def test_v1_8_1_deprecated_rank_zero_only():
     from pytorch_lightning.utilities.distributed import rank_zero_only
 
```