Fix mypy errors attributed to `pytorch_lightning.utilities.distributed` (#13678)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Krishna Kalyan 2022-07-16 20:08:03 +02:00 committed by GitHub
parent e23756b15d
commit 2845e7565d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 3 deletions

View File

@ -84,7 +84,6 @@ module = [
"pytorch_lightning.tuner.batch_size_scaling",
"pytorch_lightning.utilities.auto_restart",
"pytorch_lightning.utilities.data",
"pytorch_lightning.utilities.distributed",
"pytorch_lightning.utilities.meta",
]
ignore_errors = "True"

View File

@ -145,6 +145,7 @@ def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Un
if group is None:
group = torch.distributed.group.WORLD
op: Optional[ReduceOp]
if isinstance(reduce_op, str):
if reduce_op.lower() in ("avg", "mean"):
op = ReduceOp.SUM
@ -174,7 +175,7 @@ def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Un
class AllGatherGrad(torch.autograd.Function):
@staticmethod
def forward(
def forward( # type: ignore[override]
ctx: Any,
tensor: Tensor,
group: Optional["torch.distributed.ProcessGroup"] = group.WORLD,
@ -317,7 +318,7 @@ def register_ddp_comm_hook(
ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)
new_rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook) # type: ignore[operator]
def tpu_distributed() -> bool: