From 2845e7565dbe6b765ae32870e7d2bc456529c30a Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Sat, 16 Jul 2022 20:08:03 +0200 Subject: [PATCH] Fix mypy errors attributed to `pytorch_lightning.utilities.distributed` (#13678) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pyproject.toml | 1 - src/pytorch_lightning/utilities/distributed.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0ddadd2b29..989e63122f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,6 @@ module = [ "pytorch_lightning.tuner.batch_size_scaling", "pytorch_lightning.utilities.auto_restart", "pytorch_lightning.utilities.data", - "pytorch_lightning.utilities.distributed", "pytorch_lightning.utilities.meta", ] ignore_errors = "True" diff --git a/src/pytorch_lightning/utilities/distributed.py b/src/pytorch_lightning/utilities/distributed.py index 9bc6389ae6..bc7ed3deba 100644 --- a/src/pytorch_lightning/utilities/distributed.py +++ b/src/pytorch_lightning/utilities/distributed.py @@ -145,6 +145,7 @@ def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Un if group is None: group = torch.distributed.group.WORLD + op: Optional[ReduceOp] if isinstance(reduce_op, str): if reduce_op.lower() in ("avg", "mean"): op = ReduceOp.SUM @@ -174,7 +175,7 @@ def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Un class AllGatherGrad(torch.autograd.Function): @staticmethod - def forward( + def forward( # type: ignore[override] ctx: Any, tensor: Tensor, group: Optional["torch.distributed.ProcessGroup"] = group.WORLD, @@ -317,7 +318,7 @@ def register_ddp_comm_hook( ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook) new_rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.") - model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook) + model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook) # type: ignore[operator] def tpu_distributed() -> bool: