Fix mypy errors attributed to `pytorch_lightning.strategies.sharded_spawn` (#14102)
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
parent 31ecf9bfac
commit e53c4e8e6c
@@ -58,7 +58,6 @@ module = [
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",
     "pytorch_lightning.strategies.sharded",
-    "pytorch_lightning.strategies.sharded_spawn",
     "pytorch_lightning.trainer.callback_hook",
     "pytorch_lightning.trainer.connectors.data_connector",
     "pytorch_lightning.trainer.supporters",
@@ -75,6 +75,7 @@ class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
         trainer = pl_module._trainer

         if trainer is not None:
+            assert isinstance(self.module, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
             if trainer.training:
                 output = self.module.training_step(*inputs, **kwargs)
                 # In manual_optimization, we need to prevent DDP reducer as
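Note: the `assert isinstance(...)` added above is the usual mypy narrowing idiom — the wrapped attribute is statically typed as a plain `torch.nn.Module`, so the assert lets the checker see the narrower type before `training_step` is called. A minimal, self-contained sketch of the idiom (the `Stepper` and `Wrapper` names below are illustrative only, not part of this commit):

    from typing import Any

    import torch
    from torch import nn


    class Stepper(nn.Module):
        """Toy module exposing a `training_step`, standing in for a LightningModule."""

        def training_step(self, batch: torch.Tensor) -> torch.Tensor:
            return batch.sum()


    class Wrapper(nn.Module):
        def __init__(self, module: nn.Module) -> None:
            super().__init__()
            # Statically typed as nn.Module, so mypy does not know about `training_step`.
            self.module = module

        def forward(self, *inputs: Any, **kwargs: Any) -> Any:
            # Without this assert, mypy would flag `training_step` as an unknown
            # attribute of `Module`; the assert narrows `self.module` to `Stepper`.
            assert isinstance(self.module, Stepper)
            return self.module.training_step(*inputs, **kwargs)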
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Optional, Tuple

 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer

 import pytorch_lightning as pl
+from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.states import TrainerFn
@@ -42,7 +43,9 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):

     def configure_ddp(self) -> None:
         # set up optimizers after the wrapped module has been moved to the device
+        assert self.lightning_module is not None
         self.setup_optimizers(self.lightning_module.trainer)
+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         self.model, self.optimizers = self._setup_model_and_optimizers(
             model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
         )
@@ -69,12 +72,13 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
         return optimizers

     def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+        assert self.lightning_module
+        if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
             return optimizers

         return self._reinit_optimizers_with_oss(optimizers)

-    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
+    def optimizer_state(self, optimizer: "OSS") -> Dict[str, Any]:
         if isinstance(optimizer, OSS):
             optimizer.consolidate_state_dict()
         return self._optim_state_dict(optimizer)
@@ -93,7 +97,7 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
         yield None

     @rank_zero_only
-    def _optim_state_dict(self, optimizer):
+    def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]:
         """
         Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
         :meth:`consolidate_state_dict`.
@@ -112,7 +116,7 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
     def pre_backward(self, closure_loss: Tensor) -> None:
         pass

-    def post_training_step(self):
+    def post_training_step(self) -> None:
         pass

     @classmethod
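Note: the `assert self.lightning_module is not None` / `assert self.lightning_module` lines follow the same pattern for `Optional` attributes, and the tightened signatures (`-> None`, `Dict[str, Any]` in place of `Optional[dict]`) give mypy concrete return types to check callers against. A small illustrative sketch of the Optional-narrowing pattern (the `ExampleStrategy` class below is hypothetical, not the library's API):

    from typing import Any, Dict, Optional

    import torch
    from torch.optim import SGD, Optimizer


    class ExampleStrategy:
        def __init__(self) -> None:
            # Populated later during setup, hence Optional.
            self.optimizer: Optional[Optimizer] = None

        def setup(self) -> None:
            self.optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

        def optimizer_state(self) -> Dict[str, Any]:
            # `self.optimizer` is Optional; without the assert, mypy flags the
            # `.state_dict()` access because `None` has no such attribute.
            assert self.optimizer is not None
            return self.optimizer.state_dict()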