fix mypy typing errors in pytorch_lightning/strategies/single_device.py (#13532)
* fix typing in strategies/single_device.py * Make assert statement more explicit Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
This commit is contained in:
parent
cf189bdc9b
commit
f116c2f72c
|
@ -74,7 +74,6 @@ module = [
|
|||
"pytorch_lightning.strategies.parallel",
|
||||
"pytorch_lightning.strategies.sharded",
|
||||
"pytorch_lightning.strategies.sharded_spawn",
|
||||
"pytorch_lightning.strategies.single_device",
|
||||
"pytorch_lightning.strategies.single_tpu",
|
||||
"pytorch_lightning.strategies.tpu_spawn",
|
||||
"pytorch_lightning.strategies.strategy",
|
||||
|
|
|
@ -21,7 +21,7 @@ from torch import Tensor
|
|||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
||||
from pytorch_lightning.strategies.strategy import Strategy
|
||||
from pytorch_lightning.strategies.strategy import Strategy, TBroadcast
|
||||
from pytorch_lightning.utilities.types import _DEVICE
|
||||
|
||||
|
||||
|
@ -66,6 +66,7 @@ class SingleDeviceStrategy(Strategy):
|
|||
return self._root_device
|
||||
|
||||
def model_to_device(self) -> None:
|
||||
assert self.model is not None, "self.model must be set before self.model.to()"
|
||||
self.model.to(self.root_device)
|
||||
|
||||
def setup(self, trainer: pl.Trainer) -> None:
|
||||
|
@ -76,10 +77,10 @@ class SingleDeviceStrategy(Strategy):
|
|||
def is_global_zero(self) -> bool:
|
||||
return True
|
||||
|
||||
def barrier(self, *args, **kwargs) -> None:
|
||||
def barrier(self, *args: Any, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def broadcast(self, obj: object, src: int = 0) -> object:
|
||||
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
|
|
Loading…
Reference in New Issue