fix mypy typing errors in pytorch_lightning/strategies/horovod.py (#13570)

Cyprien Ricque 2022-07-14 11:30:37 +02:00 committed by GitHub
parent c1cc112b52
commit e179a58397
4 changed files with 22 additions and 9 deletions

View File

@@ -72,7 +72,6 @@ module = [
    "pytorch_lightning.strategies.ddp_spawn",
    "pytorch_lightning.strategies.deepspeed",
    "pytorch_lightning.strategies.fully_sharded",
-    "pytorch_lightning.strategies.horovod",
    "pytorch_lightning.strategies.ipu",
    "pytorch_lightning.strategies.parallel",
    "pytorch_lightning.strategies.sharded",

View File

@@ -310,6 +310,8 @@ class HiveMindScheduler:
    This code ensures that we only step when the HiveMind optimizer reaches the global step.
    """

+    base_lrs: List[float]
+
    def __init__(self, optimizer: "hivemind.Optimizer", scheduler: _LRScheduler) -> None:
        # copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has
        # implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler`
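The class-level `base_lrs: List[float]` annotation is presumably needed because `HiveMindScheduler` copies the wrapped scheduler's attributes over dynamically, so mypy cannot otherwise see that the attribute exists. A small sketch of that pattern (toy classes, not Lightning's implementation):

```python
from typing import List


class WrappedScheduler:
    def __init__(self) -> None:
        self.base_lrs = [0.1, 0.01]


class SchedulerWrapper:
    # Declared but not assigned at class level: this tells mypy the attribute
    # exists even though it is only copied over dynamically in __init__.
    base_lrs: List[float]

    def __init__(self, inner: WrappedScheduler) -> None:
        # Mirror the wrapped object's attributes onto the wrapper.
        for name, value in vars(inner).items():
            setattr(self, name, value)


wrapper = SchedulerWrapper(WrappedScheduler())
# Without the class-level annotation, mypy reports
# 'error: "SchedulerWrapper" has no attribute "base_lrs"' on the next line.
wrapper.base_lrs = [lr * 2 for lr in wrapper.base_lrs]
print(wrapper.base_lrs)
```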

View File

@@ -24,12 +24,14 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
+from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.distributed import group as dist_group
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only
+from pytorch_lightning.utilities.types import _LRScheduler

if _HOROVOD_AVAILABLE:
    import horovod.torch as hvd
@@ -70,11 +72,11 @@ class HorovodStrategy(ParallelStrategy):
        return hvd.size()

    @property
-    def root_device(self):
+    def root_device(self) -> torch.device:
        return self.parallel_devices[self.local_rank]

    @property
-    def distributed_sampler_kwargs(self):
+    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
        return distributed_sampler_kwargs
@@ -95,7 +97,7 @@ class HorovodStrategy(ParallelStrategy):
            # no need to setup optimizers
            return

-        def _unpack_lightning_optimizer(opt):
+        def _unpack_lightning_optimizer(opt: Optimizer) -> Optimizer:
            return opt._optimizer if isinstance(opt, LightningOptimizer) else opt

        optimizers = self.optimizers
@@ -111,8 +113,10 @@ class HorovodStrategy(ParallelStrategy):
        lr_scheduler_configs = self.lr_scheduler_configs
        for config in lr_scheduler_configs:
            scheduler = config.scheduler
+            assert isinstance(scheduler, _LRScheduler)
            scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]

+        assert self.lightning_module is not None
        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
        for optimizer in optimizers:
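The two new `assert` statements above are presumably pure type-narrowing aids: `config.scheduler` has a broader declared type and `self.lightning_module` is Optional, and an assert is the lightest way to let mypy narrow them before the following attribute access. A minimal sketch of both forms of narrowing (toy types, not Lightning's):

```python
from typing import List, Optional, Protocol, runtime_checkable


@runtime_checkable
class SupportsBaseLrs(Protocol):
    base_lrs: List[float]


class ToyScheduler:
    def __init__(self) -> None:
        self.base_lrs = [0.1]


def scale_lrs(scheduler: object, module_name: Optional[str]) -> None:
    # Narrow a broadly typed value to a protocol: after the assert, mypy
    # accepts the `.base_lrs` access below.
    assert isinstance(scheduler, SupportsBaseLrs)
    scheduler.base_lrs = [lr * 4 for lr in scheduler.base_lrs]

    # Narrow an Optional: after the assert, `module_name` is `str`.
    assert module_name is not None
    print(module_name.upper(), scheduler.base_lrs)


scale_lrs(ToyScheduler(), "horovod")
```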
@@ -129,27 +133,33 @@ class HorovodStrategy(ParallelStrategy):
            # Synchronization will be performed explicitly following backward()
            self._exit_stack.enter_context(optimizer.skip_synchronize())

-    def barrier(self, *args, **kwargs):
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
        if distributed_available():
            self.join()

-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        obj = hvd.broadcast_object(obj, src)
        return obj

-    def model_to_device(self):
+    def model_to_device(self) -> None:
        if self.root_device.type == "cuda":
            # this can potentially be removed after #8312. Not done due to lack of horovod testing
            torch.cuda.set_device(self.root_device)
+        assert self.model is not None
        self.model.to(self.root_device)

-    def join(self):
+    def join(self) -> None:
        if self.root_device.type == "cuda":
            hvd.join(self.local_rank)
        else:
            hvd.join()

-    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
+    def reduce(
+        self,
+        tensor: Union[Any, Tensor],
+        group: Optional[Any] = None,
+        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
+    ) -> Union[Any, Tensor]:
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
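The `TBroadcast` annotation used in `broadcast` (imported above from `pytorch_lightning.strategies.strategy`) is presumably a plain `TypeVar`, so the return type stays tied to whatever the caller passes in instead of collapsing to `object`. A standalone sketch of that signature pattern (no Horovod involved):

```python
from typing import List, TypeVar

TBroadcast = TypeVar("TBroadcast")


def broadcast(obj: TBroadcast, src: int = 0) -> TBroadcast:
    # A real strategy would ship `obj` from rank `src` to every process; here
    # we simply return it to demonstrate the generic signature.
    return obj


# mypy preserves the concrete types at the call site, so no casts are needed
# (with an `object -> object` signature, both lines would require one).
message: str = broadcast("ready")
sizes: List[int] = broadcast([1, 2, 3])
print(message, sizes)
```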
@@ -196,6 +206,7 @@ class HorovodStrategy(ParallelStrategy):
        self, optimizers: List[Optimizer], accumulate_grad_batches: int
    ) -> List["hvd.DistributedOptimizer"]:
        """Wraps optimizers to perform gradient aggregation via allreduce."""
+        assert self.lightning_module is not None
        return [
            hvd.DistributedOptimizer(
                opt,

View File

@@ -65,6 +65,7 @@ class _Stateful(Protocol):
@runtime_checkable
class _LRScheduler(_Stateful, Protocol):
    optimizer: Optimizer
+    base_lrs: List[float]

    def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
        ...
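Declaring `base_lrs` on this protocol is presumably what lets the `assert isinstance(scheduler, _LRScheduler)` narrowing in the Horovod strategy type-check the subsequent `scheduler.base_lrs` access. Note also that `_LRScheduler` extends another protocol: a protocol subclass is itself a protocol only if `Protocol` is repeated in its bases, otherwise it is treated as an ordinary implementation class. A toy illustration of that composition (hypothetical members, not necessarily `_Stateful`'s real interface):

```python
from typing import Any, Dict, List, Protocol, runtime_checkable


class Stateful(Protocol):
    def state_dict(self) -> Dict[str, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...


@runtime_checkable
class SchedulerProtocol(Stateful, Protocol):  # `Protocol` must be listed again here
    base_lrs: List[float]


class ToyScheduler:
    # Matches SchedulerProtocol structurally: it provides every inherited and
    # newly declared member without subclassing either protocol.
    def __init__(self) -> None:
        self.base_lrs = [0.1]
        self._state: Dict[str, Any] = {}

    def state_dict(self) -> Dict[str, Any]:
        return dict(self._state, base_lrs=self.base_lrs)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._state = dict(state_dict)


# isinstance against a runtime_checkable protocol only checks member presence.
print(isinstance(ToyScheduler(), SchedulerProtocol))  # True
```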