From f116c2f72c0771acc9f8a7a7de4b09dafbc77d58 Mon Sep 17 00:00:00 2001 From: Cyprien Ricque <48893621+Cyprien-Ricque@users.noreply.github.com> Date: Mon, 4 Jul 2022 18:28:41 +0200 Subject: [PATCH] fix mypy typing errors in pytorch_lightning/strategies/single_device.py (#13532) * fix typing in strategies/single_device.py * Make assert statement more explicit Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- pyproject.toml | 1 - src/pytorch_lightning/strategies/single_device.py | 7 ++++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 51781d4953..55543c9142 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,6 @@ module = [ "pytorch_lightning.strategies.parallel", "pytorch_lightning.strategies.sharded", "pytorch_lightning.strategies.sharded_spawn", - "pytorch_lightning.strategies.single_device", "pytorch_lightning.strategies.single_tpu", "pytorch_lightning.strategies.tpu_spawn", "pytorch_lightning.strategies.strategy", diff --git a/src/pytorch_lightning/strategies/single_device.py b/src/pytorch_lightning/strategies/single_device.py index 82681ad423..cb436fded8 100644 --- a/src/pytorch_lightning/strategies/single_device.py +++ b/src/pytorch_lightning/strategies/single_device.py @@ -21,7 +21,7 @@ from torch import Tensor import pytorch_lightning as pl from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.strategy import Strategy +from pytorch_lightning.strategies.strategy import Strategy, TBroadcast from pytorch_lightning.utilities.types import _DEVICE @@ -66,6 +66,7 @@ class SingleDeviceStrategy(Strategy): return self._root_device def model_to_device(self) -> None: + assert self.model is not None, "self.model must be set before self.model.to()" self.model.to(self.root_device) def setup(self, trainer: pl.Trainer) -> None: @@ -76,10 +77,10 @@ class SingleDeviceStrategy(Strategy): def is_global_zero(self) -> bool: return True - def barrier(self, *args, **kwargs) -> None: + def barrier(self, *args: Any, **kwargs: Any) -> None: pass - def broadcast(self, obj: object, src: int = 0) -> object: + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: return obj @classmethod