From 03f2f324457cfcb1a492db46cdeeaa992687cc4b Mon Sep 17 00:00:00 2001 From: JongMok Lee Date: Sun, 28 Aug 2022 07:07:36 +0900 Subject: [PATCH] Fix mypy errors in `pytorch_lightning/strategies/sharded.py` (#14184) Co-authored-by: otaj --- pyproject.toml | 1 - src/pytorch_lightning/strategies/sharded.py | 19 ++++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 399e1e83ef..1f704e7aa2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,6 @@ module = [ "pytorch_lightning.callbacks.progress.rich_progress", "pytorch_lightning.profilers.base", "pytorch_lightning.profilers.pytorch", - "pytorch_lightning.strategies.sharded", "pytorch_lightning.trainer.callback_hook", "pytorch_lightning.trainer.supporters", "pytorch_lightning.trainer.trainer", diff --git a/src/pytorch_lightning/strategies/sharded.py b/src/pytorch_lightning/strategies/sharded.py index 3b77bc6cee..6bf8e47022 100644 --- a/src/pytorch_lightning/strategies/sharded.py +++ b/src/pytorch_lightning/strategies/sharded.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Dict, Generator, List, Tuple, Union +from typing import Dict, Generator, List, Tuple from torch import Tensor from torch.nn import Module @@ -20,7 +20,7 @@ from torch.optim import Optimizer import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.enums import PrecisionType @@ -51,10 +51,11 @@ class DDPShardedStrategy(DDPStrategy): def setup(self, trainer: "pl.Trainer") -> None: # share ddp pids to all processes - self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts) + self._rank_0_will_call_children_scripts: bool = self.broadcast(self._rank_0_will_call_children_scripts) if self._should_run_deadlock_detection(): self._share_information_to_prevent_deadlock() + assert self.accelerator is not None self.accelerator.setup(trainer) # move the model to the correct device @@ -64,6 +65,7 @@ class DDPShardedStrategy(DDPStrategy): trainer_fn = trainer.state.fn if trainer_fn == TrainerFn.FITTING: if self._layer_sync: + assert self.model is not None self.model = self._layer_sync.apply(self.model) self.setup_precision_plugin() @@ -73,7 +75,9 @@ class DDPShardedStrategy(DDPStrategy): def configure_ddp(self) -> None: self._set_ddp_kwargs() - self.setup_optimizers(self.model.trainer) + assert self.lightning_module is not None + self.setup_optimizers(self.lightning_module.trainer) + assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) self.model, self.optimizers = self._setup_model_and_optimizers( model=_LightningModuleWrapperBase(self.model), optimizers=self.optimizers, @@ -97,12 +101,13 @@ class DDPShardedStrategy(DDPStrategy): return model, optimizers def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]: - if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING: + assert self.lightning_module is not None + if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING: return optimizers return self._reinit_optimizers_with_oss(optimizers) - def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]: + def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]: for x, optimizer in enumerate(optimizers): if isinstance(optimizer, LightningOptimizer): optimizer = optimizer._optimizer @@ -135,7 +140,7 @@ class DDPShardedStrategy(DDPStrategy): else: yield None - def post_training_step(self): + def post_training_step(self) -> None: pass @classmethod