fix mypy typing errors in pytorch_lightning/strategies/horovod.py (#13570)
parent c1cc112b52
commit e179a58397
pyproject.toml

@@ -72,7 +72,6 @@ module = [
     "pytorch_lightning.strategies.ddp_spawn",
     "pytorch_lightning.strategies.deepspeed",
     "pytorch_lightning.strategies.fully_sharded",
-    "pytorch_lightning.strategies.horovod",
     "pytorch_lightning.strategies.ipu",
     "pytorch_lightning.strategies.parallel",
     "pytorch_lightning.strategies.sharded",
pytorch_lightning/strategies/hivemind.py

@@ -310,6 +310,8 @@ class HiveMindScheduler:
     This code ensures that we only step when the HiveMind optimizer reaches the global step.
     """
 
+    base_lrs: List[float]
+
     def __init__(self, optimizer: "hivemind.Optimizer", scheduler: _LRScheduler) -> None:
         # copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has
         # implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler`
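Note (not part of the commit): a class-level annotation like `base_lrs: List[float]` is the usual way to tell mypy about an attribute that is only ever assigned dynamically, as happens here when the wrapper copies the wrapped scheduler's attributes in `__init__`. A minimal standalone sketch of the pattern; `SchedulerWrapper` and `_FakeScheduler` are hypothetical names, not Lightning classes:

from typing import Any, List


class SchedulerWrapper:
    # Class-level annotation: mypy learns the attribute's type even though
    # the only assignment happens dynamically via setattr() in __init__.
    base_lrs: List[float]

    def __init__(self, scheduler: Any) -> None:
        # Copy selected attributes from the wrapped scheduler.
        for name in ("base_lrs",):
            setattr(self, name, getattr(scheduler, name))


class _FakeScheduler:
    def __init__(self) -> None:
        self.base_lrs = [0.1, 0.01]


wrapped = SchedulerWrapper(_FakeScheduler())
print(wrapped.base_lrs)  # [0.1, 0.01] -- mypy types this as List[float]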
pytorch_lightning/strategies/horovod.py

@@ -24,12 +24,14 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
+from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as dist_group
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
+from pytorch_lightning.utilities.types import _LRScheduler
 
 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
@@ -70,11 +72,11 @@ class HorovodStrategy(ParallelStrategy):
         return hvd.size()
 
     @property
-    def root_device(self):
+    def root_device(self) -> torch.device:
         return self.parallel_devices[self.local_rank]
 
     @property
-    def distributed_sampler_kwargs(self):
+    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
         distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs
 
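Note (not part of the commit): return annotations on properties are what let mypy check downstream uses such as the `self.root_device.type == "cuda"` comparisons later in this file. A minimal sketch with a hypothetical `DeviceHolder` class:

import torch


class DeviceHolder:
    @property
    def root_device(self) -> torch.device:
        # The return annotation gives call sites a concrete type to check.
        return torch.device("cpu")


holder = DeviceHolder()
print(holder.root_device.type)  # "cpu" -- mypy knows `.type` is a str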
@@ -95,7 +97,7 @@ class HorovodStrategy(ParallelStrategy):
             # no need to setup optimizers
             return
 
-        def _unpack_lightning_optimizer(opt):
+        def _unpack_lightning_optimizer(opt: Optimizer) -> Optimizer:
             return opt._optimizer if isinstance(opt, LightningOptimizer) else opt
 
         optimizers = self.optimizers
@@ -111,8 +113,10 @@ class HorovodStrategy(ParallelStrategy):
         lr_scheduler_configs = self.lr_scheduler_configs
         for config in lr_scheduler_configs:
             scheduler = config.scheduler
+            assert isinstance(scheduler, _LRScheduler)
             scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]
 
+        assert self.lightning_module is not None
         # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
         hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
         for optimizer in optimizers:
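Note (not part of the commit): `assert x is not None` and `assert isinstance(x, SomeType)` are the standard way to narrow `Optional` or loosely typed values for mypy before using them, which is what the two added asserts above do. A minimal standalone sketch with hypothetical names:

from typing import Optional


class Widget:
    def ping(self) -> str:
        return "pong"


def use(widget: Optional[object]) -> str:
    # Without narrowing, mypy rejects both the None case and the
    # attribute access on a plain `object`.
    assert widget is not None          # narrows Optional[object] -> object
    assert isinstance(widget, Widget)  # narrows object -> Widget
    return widget.ping()


print(use(Widget()))  # pong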
@@ -129,27 +133,33 @@ class HorovodStrategy(ParallelStrategy):
             # Synchronization will be performed explicitly following backward()
             self._exit_stack.enter_context(optimizer.skip_synchronize())
 
-    def barrier(self, *args, **kwargs):
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
         if distributed_available():
             self.join()
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         obj = hvd.broadcast_object(obj, src)
         return obj
 
-    def model_to_device(self):
+    def model_to_device(self) -> None:
         if self.root_device.type == "cuda":
             # this can potentially be removed after #8312. Not done due to lack of horovod testing
             torch.cuda.set_device(self.root_device)
+        assert self.model is not None
         self.model.to(self.root_device)
 
-    def join(self):
+    def join(self) -> None:
         if self.root_device.type == "cuda":
             hvd.join(self.local_rank)
         else:
             hvd.join()
 
-    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
+    def reduce(
+        self,
+        tensor: Union[Any, Tensor],
+        group: Optional[Any] = None,
+        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
+    ) -> Union[Any, Tensor]:
         """Reduces a tensor from several distributed processes to one aggregated tensor.
 
         Args:
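Note (not part of the commit): `TBroadcast` is a `TypeVar`, so typing `broadcast` as `(obj: TBroadcast, src: int = 0) -> TBroadcast` tells mypy that the returned object has the same type as the one passed in, instead of collapsing everything to `object`. A minimal standalone sketch of the idea; the echo body is only a stand-in for the real collective call:

from typing import Dict, TypeVar

TBroadcast = TypeVar("TBroadcast")


def broadcast(obj: TBroadcast, src: int = 0) -> TBroadcast:
    # Stand-in for a real collective (e.g. hvd.broadcast_object): the point
    # is only that the return type matches the argument type for mypy.
    return obj


state: Dict[str, int] = broadcast({"epoch": 3})  # mypy keeps Dict[str, int]
flag: bool = broadcast(True)                     # mypy keeps bool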
@@ -196,6 +206,7 @@ class HorovodStrategy(ParallelStrategy):
         self, optimizers: List[Optimizer], accumulate_grad_batches: int
     ) -> List["hvd.DistributedOptimizer"]:
         """Wraps optimizers to perform gradient aggregation via allreduce."""
+        assert self.lightning_module is not None
         return [
             hvd.DistributedOptimizer(
                 opt,
pytorch_lightning/utilities/types.py

@@ -65,6 +65,7 @@ class _Stateful(Protocol):
 @runtime_checkable
 class _LRScheduler(_Stateful, Protocol):
     optimizer: Optimizer
+    base_lrs: List[float]
 
     def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
         ...
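Note (not part of the commit): because `_LRScheduler` is a `@runtime_checkable` `Protocol`, the `assert isinstance(scheduler, _LRScheduler)` added in horovod.py also works at runtime; `isinstance` against a runtime-checkable protocol only checks that the listed members exist, not their types. A minimal standalone sketch of the mechanism with hypothetical names:

from typing import List, Protocol, runtime_checkable


@runtime_checkable
class HasBaseLrs(Protocol):
    base_lrs: List[float]


class FakeScheduler:
    def __init__(self) -> None:
        self.base_lrs = [0.1, 0.01]


sched = FakeScheduler()
# Structural check: passes because the instance has a `base_lrs` attribute,
# even though FakeScheduler never inherits from the protocol.
assert isinstance(sched, HasBaseLrs)
print(sched.base_lrs)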