Fix mypy errors in `pytorch_lightning/strategies/sharded.py` (#14184)

Co-authored-by: otaj <ota@lightning.ai>
JongMok Lee authored on 2022-08-28 07:07:36 +09:00, committed by GitHub
parent af688dee69
commit 03f2f32445
2 changed files with 12 additions and 8 deletions

pyproject.toml

@@ -52,7 +52,6 @@ module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",
-    "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.trainer.callback_hook",
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",

pytorch_lightning/strategies/sharded.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, List, Tuple, Union
+from typing import Dict, Generator, List, Tuple

 from torch import Tensor
 from torch.nn import Module
@@ -20,7 +20,7 @@ from torch.optim import Optimizer

 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.enums import PrecisionType
@@ -51,10 +51,11 @@ class DDPShardedStrategy(DDPStrategy):
     def setup(self, trainer: "pl.Trainer") -> None:
         # share ddp pids to all processes
-        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
+        self._rank_0_will_call_children_scripts: bool = self.broadcast(self._rank_0_will_call_children_scripts)
         if self._should_run_deadlock_detection():
             self._share_information_to_prevent_deadlock()
+        assert self.accelerator is not None
         self.accelerator.setup(trainer)
         # move the model to the correct device
@@ -64,6 +65,7 @@ class DDPShardedStrategy(DDPStrategy):
         trainer_fn = trainer.state.fn
         if trainer_fn == TrainerFn.FITTING:
             if self._layer_sync:
+                assert self.model is not None
                 self.model = self._layer_sync.apply(self.model)
         self.setup_precision_plugin()
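
The recurring fix in `setup` is an `assert <attribute> is not None` placed just before the attribute is used: mypy narrows `Optional[X]` to `X` for the statements that follow, while at runtime the assert only restates an invariant that already holds once the trainer has attached the accelerator and model. A self-contained sketch of the pattern, with illustrative names rather than the real Lightning API:

```python
from typing import Optional


class Accelerator:
    def setup(self, trainer: object) -> None:
        print(f"setting up with {trainer!r}")


class Strategy:
    def __init__(self) -> None:
        self.accelerator: Optional[Accelerator] = None

    def setup(self, trainer: object) -> None:
        # Without the assert, mypy reports that "None" has no attribute "setup".
        assert self.accelerator is not None
        # From here on, mypy treats self.accelerator as Accelerator, not Optional[Accelerator].
        self.accelerator.setup(trainer)


strategy = Strategy()
strategy.accelerator = Accelerator()
strategy.setup(trainer="trainer")
```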
@@ -73,7 +75,9 @@ class DDPShardedStrategy(DDPStrategy):
     def configure_ddp(self) -> None:
         self._set_ddp_kwargs()
-        self.setup_optimizers(self.model.trainer)
+        assert self.lightning_module is not None
+        self.setup_optimizers(self.lightning_module.trainer)
+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         self.model, self.optimizers = self._setup_model_and_optimizers(
             model=_LightningModuleWrapperBase(self.model),
             optimizers=self.optimizers,
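
`configure_ddp` needs the stronger `assert isinstance(...)` form because `self.model` is declared with a broader Optional type than `_LightningModuleWrapperBase` accepts; the isinstance check narrows it to an accepted type before the wrapper is built. A sketch of the same narrowing with placeholder classes (not the real Lightning types):

```python
from typing import Optional


class PlainModule:
    """Placeholder for a generic nn.Module-like type."""


class LightningLikeModule(PlainModule):
    """Placeholder for the narrower type the wrapper accepts."""


class Wrapper:
    def __init__(self, module: LightningLikeModule) -> None:
        self.module = module


def wrap(model: Optional[PlainModule]) -> Wrapper:
    # isinstance narrows Optional[PlainModule] to LightningLikeModule;
    # without it, mypy rejects passing `model` to Wrapper.
    assert isinstance(model, LightningLikeModule)
    return Wrapper(model)


print(type(wrap(LightningLikeModule()).module).__name__)
```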
@@ -97,12 +101,13 @@ class DDPShardedStrategy(DDPStrategy):
         return model, optimizers

     def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+        assert self.lightning_module is not None
+        if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
             return optimizers
         return self._reinit_optimizers_with_oss(optimizers)

-    def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
+    def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
                 optimizer = optimizer._optimizer
@@ -135,7 +140,7 @@ class DDPShardedStrategy(DDPStrategy):
         else:
             yield None

-    def post_training_step(self):
+    def post_training_step(self) -> None:
         pass

     @classmethod
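
The final hunk only adds a return annotation. Assuming the project's mypy configuration disallows untyped definitions once a module is checked, a `def` with no annotations at all is flagged, so even a no-op hook needs an explicit `-> None`:

```python
# A minimal sketch; "disallow_untyped_defs" is an assumption about the project's
# mypy settings, not something visible in this diff.
def untyped_hook():  # mypy: "Function is missing a return type annotation"
    pass


def typed_hook() -> None:  # explicit "-> None" marks the no-op hook as typed
    pass
```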