fix mypy typing errors in pytorch_lightning/strategies/single_device.py (#13532)

* fix typing in strategies/single_device.py
* Make assert statement more explicit

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
This commit is contained in:
Cyprien Ricque 2022-07-04 18:28:41 +02:00 committed by GitHub
parent cf189bdc9b
commit f116c2f72c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 4 deletions

View File

@ -74,7 +74,6 @@ module = [
"pytorch_lightning.strategies.parallel",
"pytorch_lightning.strategies.sharded",
"pytorch_lightning.strategies.sharded_spawn",
"pytorch_lightning.strategies.single_device",
"pytorch_lightning.strategies.single_tpu",
"pytorch_lightning.strategies.tpu_spawn",
"pytorch_lightning.strategies.strategy",

View File

@ -21,7 +21,7 @@ from torch import Tensor
import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.strategies.strategy import Strategy, TBroadcast
from pytorch_lightning.utilities.types import _DEVICE
@ -66,6 +66,7 @@ class SingleDeviceStrategy(Strategy):
return self._root_device
def model_to_device(self) -> None:
assert self.model is not None, "self.model must be set before self.model.to()"
self.model.to(self.root_device)
def setup(self, trainer: pl.Trainer) -> None:
@ -76,10 +77,10 @@ class SingleDeviceStrategy(Strategy):
def is_global_zero(self) -> bool:
return True
def barrier(self, *args, **kwargs) -> None:
def barrier(self, *args: Any, **kwargs: Any) -> None:
pass
def broadcast(self, obj: object, src: int = 0) -> object:
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
return obj
@classmethod