Fix mypy errors in `pytorch_lightning/strategies/sharded.py` (#14184)
Co-authored-by: otaj <ota@lightning.ai>
This commit is contained in:
parent
af688dee69
commit
03f2f32445
|
@ -52,7 +52,6 @@ module = [
|
|||
"pytorch_lightning.callbacks.progress.rich_progress",
|
||||
"pytorch_lightning.profilers.base",
|
||||
"pytorch_lightning.profilers.pytorch",
|
||||
"pytorch_lightning.strategies.sharded",
|
||||
"pytorch_lightning.trainer.callback_hook",
|
||||
"pytorch_lightning.trainer.supporters",
|
||||
"pytorch_lightning.trainer.trainer",
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, Generator, List, Tuple, Union
|
||||
from typing import Dict, Generator, List, Tuple
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
@ -20,7 +20,7 @@ from torch.optim import Optimizer
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.core.optimizer import LightningOptimizer
|
||||
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
|
||||
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
|
||||
from pytorch_lightning.strategies.ddp import DDPStrategy
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities.enums import PrecisionType
|
||||
|
@ -51,10 +51,11 @@ class DDPShardedStrategy(DDPStrategy):
|
|||
|
||||
def setup(self, trainer: "pl.Trainer") -> None:
|
||||
# share ddp pids to all processes
|
||||
self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
|
||||
self._rank_0_will_call_children_scripts: bool = self.broadcast(self._rank_0_will_call_children_scripts)
|
||||
if self._should_run_deadlock_detection():
|
||||
self._share_information_to_prevent_deadlock()
|
||||
|
||||
assert self.accelerator is not None
|
||||
self.accelerator.setup(trainer)
|
||||
|
||||
# move the model to the correct device
|
||||
|
@ -64,6 +65,7 @@ class DDPShardedStrategy(DDPStrategy):
|
|||
trainer_fn = trainer.state.fn
|
||||
if trainer_fn == TrainerFn.FITTING:
|
||||
if self._layer_sync:
|
||||
assert self.model is not None
|
||||
self.model = self._layer_sync.apply(self.model)
|
||||
|
||||
self.setup_precision_plugin()
|
||||
|
@ -73,7 +75,9 @@ class DDPShardedStrategy(DDPStrategy):
|
|||
|
||||
def configure_ddp(self) -> None:
|
||||
self._set_ddp_kwargs()
|
||||
self.setup_optimizers(self.model.trainer)
|
||||
assert self.lightning_module is not None
|
||||
self.setup_optimizers(self.lightning_module.trainer)
|
||||
assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
|
||||
self.model, self.optimizers = self._setup_model_and_optimizers(
|
||||
model=_LightningModuleWrapperBase(self.model),
|
||||
optimizers=self.optimizers,
|
||||
|
@ -97,12 +101,13 @@ class DDPShardedStrategy(DDPStrategy):
|
|||
return model, optimizers
|
||||
|
||||
def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
|
||||
if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
|
||||
assert self.lightning_module is not None
|
||||
if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
|
||||
return optimizers
|
||||
|
||||
return self._reinit_optimizers_with_oss(optimizers)
|
||||
|
||||
def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
|
||||
def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
|
||||
for x, optimizer in enumerate(optimizers):
|
||||
if isinstance(optimizer, LightningOptimizer):
|
||||
optimizer = optimizer._optimizer
|
||||
|
@ -135,7 +140,7 @@ class DDPShardedStrategy(DDPStrategy):
|
|||
else:
|
||||
yield None
|
||||
|
||||
def post_training_step(self):
|
||||
def post_training_step(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
|
|
Loading…
Reference in New Issue