Fix mypy errors attributed to `pytorch_lightning.utilities.distributed` (#13678)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
parent e23756b15d
commit 2845e7565d
@@ -84,7 +84,6 @@ module = [
     "pytorch_lightning.tuner.batch_size_scaling",
     "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
-    "pytorch_lightning.utilities.distributed",
     "pytorch_lightning.utilities.meta",
 ]
 ignore_errors = "True"
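With `"pytorch_lightning.utilities.distributed"` dropped from this `ignore_errors` override list, mypy type-checks the module again; the remaining hunks add the annotation and the two targeted ignores needed to make it pass. As a local sanity check, a minimal sketch using mypy's programmatic API (assuming mypy is installed; this check is not part of the commit itself):

    # Hypothetical verification step: run mypy on the re-enabled module,
    # roughly as CI would, and fail loudly if errors remain.
    from mypy import api

    stdout, stderr, exit_status = api.run(["-m", "pytorch_lightning.utilities.distributed"])
    print(stdout)
    assert exit_status == 0, "mypy still reports errors in the re-enabled module"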
@@ -145,6 +145,7 @@ def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Un
     if group is None:
         group = torch.distributed.group.WORLD

+    op: Optional[ReduceOp]
     if isinstance(reduce_op, str):
         if reduce_op.lower() in ("avg", "mean"):
             op = ReduceOp.SUM
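The new `op: Optional[ReduceOp]` declaration is the usual fix when a variable is only ever assigned inside branches: without it, mypy pins `op` to the type of the first assignment (`ReduceOp`) and then rejects any branch that can leave it as `None`. A standalone sketch of the pattern (the helper name and the `getattr` fallthrough are illustrative, not the full `sync_ddp` body):

    from typing import Optional, Union

    from torch.distributed import ReduceOp

    def resolve_reduce_op(reduce_op: Optional[Union[ReduceOp, str]]) -> Optional[ReduceOp]:
        op: Optional[ReduceOp]  # declared once, so every branch checks against one type
        if isinstance(reduce_op, str):
            if reduce_op.lower() in ("avg", "mean"):
                op = ReduceOp.SUM  # averaging is done by dividing the sum afterwards
            else:
                op = getattr(ReduceOp, reduce_op.upper())
        else:
            op = reduce_op  # already a ReduceOp, or None
        return op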
@@ -174,7 +175,7 @@ def sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Un

 class AllGatherGrad(torch.autograd.Function):
     @staticmethod
-    def forward(
+    def forward(  # type: ignore[override]
         ctx: Any,
         tensor: Tensor,
         group: Optional["torch.distributed.ProcessGroup"] = group.WORLD,
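The `# type: ignore[override]` is needed because the base `torch.autograd.Function.forward` is declared with a broad `(ctx, *args, **kwargs)` signature, so any concrete override with named, typed parameters fails mypy's signature-compatibility check for overrides. A toy `Function` showing the same pattern (illustrative class, not `AllGatherGrad`):

    from typing import Any

    import torch
    from torch import Tensor

    class Square(torch.autograd.Function):
        # Narrowing the inherited *args/**kwargs signature trips mypy's
        # override check; the scoped ignore silences only that diagnostic.
        @staticmethod
        def forward(ctx: Any, x: Tensor) -> Tensor:  # type: ignore[override]
            ctx.save_for_backward(x)
            return x * x

        @staticmethod
        def backward(ctx: Any, grad_output: Tensor) -> Tensor:  # type: ignore[override]
            (x,) = ctx.saved_tensors
            return 2 * x * grad_output

    x = torch.tensor([3.0], requires_grad=True)
    Square.apply(x).backward()
    print(x.grad)  # tensor([6.])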
@@ -317,7 +318,7 @@ def register_ddp_comm_hook(
         ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)

     new_rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
-    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
+    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)  # type: ignore[operator]


 def tpu_distributed() -> bool:
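Rather than a blanket `# type: ignore`, this hunk suppresses only mypy's `operator` error code on the `register_comm_hook` call (exactly what the torch stubs of the time reported there is my assumption, not stated in the diff). A minimal sketch of the same narrowly scoped ignore on an unambiguous `[operator]` error:

    from typing import Optional

    def bump(value: Optional[int]) -> int:
        # mypy: Unsupported operand types for + ("None" and "int")  [operator]
        # The scoped ignore suppresses exactly this error code and nothing
        # else, so unrelated diagnostics on the same line would still surface.
        return value + 1  # type: ignore[operator]

    print(bump(41))  # 42 -- callers are expected to pass a real int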