fix mypy typing errors in pytorch_lightning/strategies/single_tpu.py (#13534)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
03cbca430d
commit
61473c2290
|
@ -72,7 +72,6 @@ module = [
|
|||
"pytorch_lightning.strategies.parallel",
|
||||
"pytorch_lightning.strategies.sharded",
|
||||
"pytorch_lightning.strategies.sharded_spawn",
|
||||
"pytorch_lightning.strategies.single_tpu",
|
||||
"pytorch_lightning.strategies.tpu_spawn",
|
||||
"pytorch_lightning.strategies.strategy",
|
||||
"pytorch_lightning.profilers.advanced",
|
||||
|
|
|
@ -19,7 +19,6 @@ from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
|
|||
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
||||
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
|
||||
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
|
||||
if _TPU_AVAILABLE:
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
@ -55,13 +54,10 @@ class SingleTPUStrategy(SingleDeviceStrategy):
|
|||
return False
|
||||
|
||||
def setup(self, trainer: "pl.Trainer") -> None:
|
||||
assert self.model, "self.model must be set before find_shared_parameters(self.model)"
|
||||
shared_params = find_shared_parameters(self.model)
|
||||
self.model_to_device()
|
||||
if is_overridden("on_post_move_to_device", self.lightning_module):
|
||||
self.model.on_post_move_to_device()
|
||||
else:
|
||||
set_shared_parameters(self.model, shared_params)
|
||||
|
||||
set_shared_parameters(self.model, shared_params)
|
||||
super().setup(trainer)
|
||||
|
||||
if self.debug:
|
||||
|
@ -70,9 +66,6 @@ class SingleTPUStrategy(SingleDeviceStrategy):
|
|||
self.tpu_local_core_rank = xm.get_local_ordinal()
|
||||
self.tpu_global_core_rank = xm.get_ordinal()
|
||||
|
||||
def model_to_device(self) -> None:
|
||||
self.model.to(self.root_device)
|
||||
|
||||
@classmethod
|
||||
def register_strategies(cls, strategy_registry: Dict) -> None:
|
||||
strategy_registry.register(
|
||||
|
|
Loading…
Reference in New Issue