fix mypy typing errors in pytorch_lightning/strategies/single_tpu.py (#13534)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
This commit is contained in:
Cyprien Ricque 2022-07-05 09:27:27 +02:00 committed by GitHub
parent 03cbca430d
commit 61473c2290
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 10 deletions

View File

@ -72,7 +72,6 @@ module = [
"pytorch_lightning.strategies.parallel",
"pytorch_lightning.strategies.sharded",
"pytorch_lightning.strategies.sharded_spawn",
"pytorch_lightning.strategies.single_tpu",
"pytorch_lightning.strategies.tpu_spawn",
"pytorch_lightning.strategies.strategy",
"pytorch_lightning.profilers.advanced",

View File

@ -19,7 +19,6 @@ from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities.model_helpers import is_overridden
if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
@ -55,13 +54,10 @@ class SingleTPUStrategy(SingleDeviceStrategy):
return False
def setup(self, trainer: "pl.Trainer") -> None:
assert self.model, "self.model must be set before find_shared_parameters(self.model)"
shared_params = find_shared_parameters(self.model)
self.model_to_device()
if is_overridden("on_post_move_to_device", self.lightning_module):
self.model.on_post_move_to_device()
else:
set_shared_parameters(self.model, shared_params)
set_shared_parameters(self.model, shared_params)
super().setup(trainer)
if self.debug:
@ -70,9 +66,6 @@ class SingleTPUStrategy(SingleDeviceStrategy):
self.tpu_local_core_rank = xm.get_local_ordinal()
self.tpu_global_core_rank = xm.get_ordinal()
def model_to_device(self) -> None:
self.model.to(self.root_device)
@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(