From e53c4e8e6c14c92968df9bed8861e578bfe731aa Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Thu, 11 Aug 2022 22:10:05 +0100
Subject: [PATCH] Fix mypy errors attributed to `pytorch_lightning.strategies.sharded_spawn` (#14102)

Co-authored-by: rohitgr7
Co-authored-by: Jirka Borovec
Co-authored-by: awaelchli
---
 pyproject.toml                                    |  1 -
 src/pytorch_lightning/overrides/base.py           |  1 +
 src/pytorch_lightning/strategies/sharded_spawn.py | 14 +++++++++-----
 3 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index b5e806bc69..9f7cc28d0b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,7 +58,6 @@ module = [
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",
     "pytorch_lightning.strategies.sharded",
-    "pytorch_lightning.strategies.sharded_spawn",
     "pytorch_lightning.trainer.callback_hook",
     "pytorch_lightning.trainer.connectors.data_connector",
     "pytorch_lightning.trainer.supporters",
diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py
index 26c2837bda..3e9fda2f96 100644
--- a/src/pytorch_lightning/overrides/base.py
+++ b/src/pytorch_lightning/overrides/base.py
@@ -75,6 +75,7 @@ class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
         trainer = pl_module._trainer

         if trainer is not None:
+            assert isinstance(self.module, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
             if trainer.training:
                 output = self.module.training_step(*inputs, **kwargs)
                 # In manual_optimization, we need to prevent DDP reducer as
diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py
index 4550e397de..882302e101 100644
--- a/src/pytorch_lightning/strategies/sharded_spawn.py
+++ b/src/pytorch_lightning/strategies/sharded_spawn.py
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Optional, Tuple

 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer

 import pytorch_lightning as pl
+from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.states import TrainerFn
@@ -42,7 +43,9 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):

     def configure_ddp(self) -> None:
         # set up optimizers after the wrapped module has been moved to the device
+        assert self.lightning_module is not None
         self.setup_optimizers(self.lightning_module.trainer)
+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         self.model, self.optimizers = self._setup_model_and_optimizers(
             model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
         )
@@ -69,12 +72,13 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
         return optimizers

     def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+        assert self.lightning_module
+        if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
             return optimizers

         return self._reinit_optimizers_with_oss(optimizers)

-    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
+    def optimizer_state(self, optimizer: "OSS") -> Dict[str, Any]:
         if isinstance(optimizer, OSS):
             optimizer.consolidate_state_dict()
         return self._optim_state_dict(optimizer)
@@ -93,7 +97,7 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
             yield None

     @rank_zero_only
-    def _optim_state_dict(self, optimizer):
+    def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]:
         """
         Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
         :meth:`consolidate_state_dict`.
@@ -112,7 +116,7 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
     def pre_backward(self, closure_loss: Tensor) -> None:
         pass

-    def post_training_step(self):
+    def post_training_step(self) -> None:
         pass

     @classmethod
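
Not part of the patch itself: the fix relies on mypy's type narrowing, where an `assert x is not None` or `assert isinstance(x, T)` lets the checker treat `x` as non-Optional (or as `T`) on the following lines, together with explicit return annotations such as `Dict[str, Any]`. Below is a minimal, self-contained sketch of that pattern; `Wrapped` and `Strategy` are hypothetical stand-ins, not the real Lightning classes.

# Minimal sketch of the assert-based narrowing used in the patch.
# `Wrapped` and `Strategy` are illustrative stand-ins, not Lightning types.
from typing import Any, Dict, Optional


class Wrapped:
    def state(self) -> Dict[str, Any]:
        return {"step": 1}


class Strategy:
    def __init__(self, module: Optional[Wrapped]) -> None:
        self.module: Optional[Wrapped] = module

    def optimizer_state(self) -> Dict[str, Any]:
        # Without this assert, mypy rejects the attribute access below
        # because `self.module` may be `None`.
        assert self.module is not None
        # After the assert, mypy narrows `self.module` to `Wrapped`,
        # so the call type-checks.
        return self.module.state()


if __name__ == "__main__":
    print(Strategy(Wrapped()).optimizer_state())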