From e179a58397d55a0c906a3a9172f5d7c8e947fcfc Mon Sep 17 00:00:00 2001
From: Cyprien Ricque <48893621+Cyprien-Ricque@users.noreply.github.com>
Date: Thu, 14 Jul 2022 11:30:37 +0200
Subject: [PATCH] fix mypy typing errors in pytorch_lightning/strategies/horovod.py (#13570)

---
 pyproject.toml                               |  1 -
 src/pytorch_lightning/strategies/hivemind.py |  2 ++
 src/pytorch_lightning/strategies/horovod.py  | 27 ++++++++++++++------
 src/pytorch_lightning/utilities/types.py     |  1 +
 4 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 898699875e..5d7a4511f7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -72,7 +72,6 @@ module = [
     "pytorch_lightning.strategies.ddp_spawn",
     "pytorch_lightning.strategies.deepspeed",
     "pytorch_lightning.strategies.fully_sharded",
-    "pytorch_lightning.strategies.horovod",
     "pytorch_lightning.strategies.ipu",
     "pytorch_lightning.strategies.parallel",
     "pytorch_lightning.strategies.sharded",
diff --git a/src/pytorch_lightning/strategies/hivemind.py b/src/pytorch_lightning/strategies/hivemind.py
index 34e2f40b2e..b274856bb6 100644
--- a/src/pytorch_lightning/strategies/hivemind.py
+++ b/src/pytorch_lightning/strategies/hivemind.py
@@ -310,6 +310,8 @@ class HiveMindScheduler:
     This code ensures that we only step when the HiveMind optimizer reaches the global step.
     """
 
+    base_lrs: List[float]
+
     def __init__(self, optimizer: "hivemind.Optimizer", scheduler: _LRScheduler) -> None:
         # copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has
         # implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler`
diff --git a/src/pytorch_lightning/strategies/horovod.py b/src/pytorch_lightning/strategies/horovod.py
index 40fdd6f112..19075cbbb0 100644
--- a/src/pytorch_lightning/strategies/horovod.py
+++ b/src/pytorch_lightning/strategies/horovod.py
@@ -24,12 +24,14 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
+from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as dist_group
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
+from pytorch_lightning.utilities.types import _LRScheduler
 
 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
@@ -70,11 +72,11 @@ class HorovodStrategy(ParallelStrategy):
         return hvd.size()
 
     @property
-    def root_device(self):
+    def root_device(self) -> torch.device:
         return self.parallel_devices[self.local_rank]
 
     @property
-    def distributed_sampler_kwargs(self):
+    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
         distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs
 
@@ -95,7 +97,7 @@ class HorovodStrategy(ParallelStrategy):
             # no need to setup optimizers
             return
 
-        def _unpack_lightning_optimizer(opt):
+        def _unpack_lightning_optimizer(opt: Optimizer) -> Optimizer:
             return opt._optimizer if isinstance(opt, LightningOptimizer) else opt
 
         optimizers = self.optimizers
@@ -111,8 +113,10 @@ class HorovodStrategy(ParallelStrategy):
         lr_scheduler_configs = self.lr_scheduler_configs
         for config in lr_scheduler_configs:
             scheduler = config.scheduler
+            assert isinstance(scheduler, _LRScheduler)
             scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]
 
+        assert self.lightning_module is not None
         # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
         hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
         for optimizer in optimizers:
@@ -129,27 +133,33 @@ class HorovodStrategy(ParallelStrategy):
             # Synchronization will be performed explicitly following backward()
             self._exit_stack.enter_context(optimizer.skip_synchronize())
 
-    def barrier(self, *args, **kwargs):
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
         if distributed_available():
             self.join()
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         obj = hvd.broadcast_object(obj, src)
         return obj
 
-    def model_to_device(self):
+    def model_to_device(self) -> None:
         if self.root_device.type == "cuda":
             # this can potentially be removed after #8312. Not done due to lack of horovod testing
             torch.cuda.set_device(self.root_device)
+        assert self.model is not None
         self.model.to(self.root_device)
 
-    def join(self):
+    def join(self) -> None:
         if self.root_device.type == "cuda":
             hvd.join(self.local_rank)
         else:
             hvd.join()
 
-    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
+    def reduce(
+        self,
+        tensor: Union[Any, Tensor],
+        group: Optional[Any] = None,
+        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
+    ) -> Union[Any, Tensor]:
         """Reduces a tensor from several distributed processes to one aggregated tensor.
 
         Args:
@@ -196,6 +206,7 @@ class HorovodStrategy(ParallelStrategy):
         self, optimizers: List[Optimizer], accumulate_grad_batches: int
     ) -> List["hvd.DistributedOptimizer"]:
         """Wraps optimizers to perform gradient aggregation via allreduce."""
+        assert self.lightning_module is not None
         return [
             hvd.DistributedOptimizer(
                 opt,
diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py
index 0b10b5eebc..5010c6ff05 100644
--- a/src/pytorch_lightning/utilities/types.py
+++ b/src/pytorch_lightning/utilities/types.py
@@ -65,6 +65,7 @@ class _Stateful(Protocol):
 @runtime_checkable
 class _LRScheduler(_Stateful, Protocol):
     optimizer: Optimizer
+    base_lrs: List[float]
 
     def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
         ...
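
Note (illustration only, not part of the patch): the changes above rely on two standard mypy
narrowing patterns, `assert isinstance(obj, SomeRuntimeCheckableProtocol)` to expose a structural
attribute such as `base_lrs`, and `assert obj is not None` to strip `Optional` from attributes
like `self.lightning_module` and `self.model`. A minimal sketch of both follows; the
`_SupportsBaseLrs` protocol and `_FakeScheduler` class are hypothetical stand-ins, not the real
Lightning types:

    # sketch.py -- illustration only; all names here are hypothetical stand-ins
    from typing import List, Optional, Protocol, runtime_checkable


    @runtime_checkable
    class _SupportsBaseLrs(Protocol):
        # Declaring the data member on the protocol lets mypy (and a runtime
        # isinstance() check) see `base_lrs` on any structurally matching object.
        base_lrs: List[float]


    class _FakeScheduler:
        def __init__(self) -> None:
            self.base_lrs = [0.1, 0.01]


    def scale_base_lrs(scheduler: object, world_size: int) -> None:
        # Narrow `scheduler` from `object` to the protocol so the attribute
        # access below passes the type check.
        assert isinstance(scheduler, _SupportsBaseLrs)
        scheduler.base_lrs = [lr * world_size for lr in scheduler.base_lrs]


    def require(scheduler: Optional[_FakeScheduler]) -> _FakeScheduler:
        # The same trick for Optional values: the assert narrows away `None`.
        assert scheduler is not None
        return scheduler


    if __name__ == "__main__":
        sched = _FakeScheduler()
        scale_base_lrs(sched, world_size=4)
        print(require(sched).base_lrs)  # [0.4, 0.04]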