fix mypy typing errors in pytorch_lightning/strategies/single_device.py (#13532)
* fix typing in strategies/single_device.py * Make assert statement more explicit Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
This commit is contained in:
parent
cf189bdc9b
commit
f116c2f72c
|
@ -74,7 +74,6 @@ module = [
|
||||||
"pytorch_lightning.strategies.parallel",
|
"pytorch_lightning.strategies.parallel",
|
||||||
"pytorch_lightning.strategies.sharded",
|
"pytorch_lightning.strategies.sharded",
|
||||||
"pytorch_lightning.strategies.sharded_spawn",
|
"pytorch_lightning.strategies.sharded_spawn",
|
||||||
"pytorch_lightning.strategies.single_device",
|
|
||||||
"pytorch_lightning.strategies.single_tpu",
|
"pytorch_lightning.strategies.single_tpu",
|
||||||
"pytorch_lightning.strategies.tpu_spawn",
|
"pytorch_lightning.strategies.tpu_spawn",
|
||||||
"pytorch_lightning.strategies.strategy",
|
"pytorch_lightning.strategies.strategy",
|
||||||
|
|
|
@ -21,7 +21,7 @@ from torch import Tensor
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
||||||
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
||||||
from pytorch_lightning.strategies.strategy import Strategy
|
from pytorch_lightning.strategies.strategy import Strategy, TBroadcast
|
||||||
from pytorch_lightning.utilities.types import _DEVICE
|
from pytorch_lightning.utilities.types import _DEVICE
|
||||||
|
|
||||||
|
|
||||||
|
@ -66,6 +66,7 @@ class SingleDeviceStrategy(Strategy):
|
||||||
return self._root_device
|
return self._root_device
|
||||||
|
|
||||||
def model_to_device(self) -> None:
|
def model_to_device(self) -> None:
|
||||||
|
assert self.model is not None, "self.model must be set before self.model.to()"
|
||||||
self.model.to(self.root_device)
|
self.model.to(self.root_device)
|
||||||
|
|
||||||
def setup(self, trainer: pl.Trainer) -> None:
|
def setup(self, trainer: pl.Trainer) -> None:
|
||||||
|
@ -76,10 +77,10 @@ class SingleDeviceStrategy(Strategy):
|
||||||
def is_global_zero(self) -> bool:
|
def is_global_zero(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def barrier(self, *args, **kwargs) -> None:
|
def barrier(self, *args: Any, **kwargs: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def broadcast(self, obj: object, src: int = 0) -> object:
|
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
Loading…
Reference in New Issue