Fix mypy errors attributed to `pytorch_lightning.strategies.sharded_spawn` (#14102)

Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
This commit is contained in:
Krishna Kalyan 2022-08-11 22:10:05 +01:00 committed by GitHub
parent 31ecf9bfac
commit e53c4e8e6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 6 deletions

View File

@ -58,7 +58,6 @@ module = [
"pytorch_lightning.profilers.base",
"pytorch_lightning.profilers.pytorch",
"pytorch_lightning.strategies.sharded",
"pytorch_lightning.strategies.sharded_spawn",
"pytorch_lightning.trainer.callback_hook",
"pytorch_lightning.trainer.connectors.data_connector",
"pytorch_lightning.trainer.supporters",

View File

@ -75,6 +75,7 @@ class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
trainer = pl_module._trainer
if trainer is not None:
assert isinstance(self.module, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
if trainer.training:
output = self.module.training_step(*inputs, **kwargs)
# In manual_optimization, we need to prevent DDP reducer as

View File

@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Dict, Generator, List, Optional, Tuple
from typing import Any, Dict, Generator, List, Optional, Tuple
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.trainer.states import TrainerFn
@ -42,7 +43,9 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
def configure_ddp(self) -> None:
    """Set up sharded data-parallel training.

    Re-creates the optimizers after the wrapped module has been moved to the
    target device, then wraps the module in ``LightningShardedDataParallel``
    together with the (re-)sharded optimizers.
    """
    # set up optimizers after the wrapped module has been moved to the device
    assert self.lightning_module is not None
    self.setup_optimizers(self.lightning_module.trainer)
    # Narrow the type of ``self.model`` for mypy before handing it to the wrapper.
    assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
    self.model, self.optimizers = self._setup_model_and_optimizers(
        model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
    )
@ -69,12 +72,13 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
return optimizers
def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
    """Wrap the given optimizers in fairscale ``OSS``, except outside of fitting.

    Only the FITTING trainer stage needs sharded optimizer state; for
    validate/test/predict the optimizers are returned unchanged.
    """
    # Access the trainer via ``lightning_module`` (not ``self.model``) so mypy
    # can see the ``trainer`` attribute on the narrowed type.
    assert self.lightning_module
    if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
        return optimizers
    return self._reinit_optimizers_with_oss(optimizers)
def optimizer_state(self, optimizer: "OSS") -> Dict[str, Any]:
    """Return the full optimizer state dict.

    For a fairscale ``OSS`` optimizer, each rank holds only a shard of the
    state, so it is consolidated onto rank 0 before the state dict is read.
    """
    if isinstance(optimizer, OSS):
        # Gather every shard's state so rank 0 holds the complete state dict.
        optimizer.consolidate_state_dict()
    return self._optim_state_dict(optimizer)
@ -93,7 +97,7 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
yield None
@rank_zero_only
def _optim_state_dict(self, optimizer):
def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]:
"""
Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
:meth:`consolidate_state_dict`.
@ -112,7 +116,7 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
def pre_backward(self, closure_loss: Tensor) -> None:
pass
def post_training_step(self):
def post_training_step(self) -> None:
pass
@classmethod