Improve typing for plugins (#10742)

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Author: Adrian Wälchli, 2021-11-26 21:14:58 +01:00 (committed by GitHub)
Parent: 81a0a44d8f
Commit: 038c151b6e
4 changed files with 10 additions and 11 deletions

File: pyproject.toml (per-module mypy overrides)

@@ -86,9 +86,6 @@ module = [
     "pytorch_lightning.plugins.environments.lsf_environment",
     "pytorch_lightning.plugins.environments.slurm_environment",
     "pytorch_lightning.plugins.environments.torchelastic_environment",
-    "pytorch_lightning.plugins.precision.deepspeed",
-    "pytorch_lightning.plugins.precision.native_amp",
-    "pytorch_lightning.plugins.precision.precision_plugin",
     "pytorch_lightning.plugins.training_type.ddp",
     "pytorch_lightning.plugins.training_type.ddp2",
     "pytorch_lightning.plugins.training_type.ddp_spawn",

File: pytorch_lightning/plugins/precision/deepspeed.py

@@ -49,7 +49,9 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
         deepspeed_engine: DeepSpeedEngine = model.trainer.model
         deepspeed_engine.backward(closure_loss, *args, **kwargs)
 
-    def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
+    def _run_backward(self, tensor: Tensor, model: Optional["DeepSpeedEngine"], *args: Any, **kwargs: Any) -> None:
+        if model is None:
+            raise ValueError("Please provide the model as input to `backward`.")
         model.backward(tensor, *args, **kwargs)
 
     def optimizer_step(
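The widened parameter makes the missing-model case explicit: `model` is now `Optional["DeepSpeedEngine"]`, so the `None` guard both narrows the type for mypy and raises a clearer error than an `AttributeError` on `None`. A standalone sketch of the same narrowing pattern (the `Engine` class below is a stand-in, not the real `DeepSpeedEngine`):

```python
from typing import Any, Optional

from torch import Tensor


class Engine:
    """Stand-in for the engine object that owns the backward call."""

    def backward(self, tensor: Tensor, *args: Any, **kwargs: Any) -> None:
        ...


def _run_backward(tensor: Tensor, model: Optional[Engine], *args: Any, **kwargs: Any) -> None:
    # Without this guard, mypy flags the attribute access below because `model` may be None.
    if model is None:
        raise ValueError("Please provide the model as input to `backward`.")
    model.backward(tensor, *args, **kwargs)  # type narrowed to Engine from here on
```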

File: pytorch_lightning/plugins/precision/native_amp.py

@@ -25,9 +25,9 @@ from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _TORCH_GREATER_EQUAL_1_10:
-    from torch import autocast
+    from torch import autocast as new_autocast
 else:
-    from torch.cuda.amp import autocast
+    from torch.cuda.amp import autocast as old_autocast
 
 
 class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
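Aliasing the two imports to distinct names is what allows `autocast_context_manager` (further down in this file) to annotate its return type as a `Union` of both, rather than a single `autocast` name whose meaning depends on the installed torch version. A minimal sketch of the pattern outside Lightning; the crude version check is an assumption for illustration only:

```python
from typing import Union

import torch

# Crude version gate for the sketch; Lightning uses its own _TORCH_GREATER_EQUAL_1_10 flag.
_TORCH_GE_1_10 = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2]) >= (1, 10)

if _TORCH_GE_1_10:
    from torch import autocast as new_autocast  # device-agnostic, torch >= 1.10
else:
    from torch.cuda.amp import autocast as old_autocast  # CUDA-only, torch < 1.10


def make_autocast(device: str, dtype: torch.dtype) -> Union["old_autocast", "new_autocast"]:
    # Both names appear only as string annotations, so it is fine that just one import ran.
    if _TORCH_GE_1_10:
        return new_autocast(device, dtype=dtype)
    return old_autocast()
```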
@@ -62,7 +62,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
             closure_loss = self.scaler.scale(closure_loss)
         return super().pre_backward(model, closure_loss)
 
-    def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
+    def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
         if self.scaler is not None:
             tensor = self.scaler.scale(tensor)
         super()._run_backward(tensor, model, *args, **kwargs)
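This hunk only widens `model` to `Optional[Module]`; the scaling logic is unchanged: the loss tensor is scaled before the parent class performs the backward pass. For context, a plain-torch sketch of the `GradScaler` flow this wraps (assumes a CUDA device is available):

```python
import torch

device = "cuda"
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 8, device=device)
with torch.cuda.amp.autocast():
    loss = model(x).sum()

scaler.scale(loss).backward()  # scale the loss, then backprop (the _run_backward step)
scaler.step(optimizer)         # unscales gradients and skips the step on inf/nan
scaler.update()                # adjusts the scale factor for the next iteration
```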
@@ -93,12 +93,12 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
             self.scaler.step(optimizer, **kwargs)
             self.scaler.update()
 
-    def autocast_context_manager(self) -> autocast:
+    def autocast_context_manager(self) -> Union["old_autocast", "new_autocast"]:
         if _TORCH_GREATER_EQUAL_1_10:
             # the dtype could be automatically inferred but we need to manually set it due to a bug upstream
             # https://github.com/pytorch/pytorch/issues/67233
-            return autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
-        return autocast()
+            return new_autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
+        return old_autocast()
 
     @contextmanager
     def forward_context(self) -> Generator[None, None, None]:
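At the call site the two context managers behave the same; only their construction differs, which is why the explicit `dtype` (the workaround for pytorch/pytorch#67233 referenced in the comment) is only passed on the torch >= 1.10 branch. A small plain-torch usage sketch, assuming torch >= 1.10 and a CUDA device:

```python
import torch

device = "cuda"
dtype = torch.bfloat16  # or torch.half for 16-bit mixed precision

model = torch.nn.Linear(8, 4).to(device)
x = torch.randn(2, 8, device=device)

# torch >= 1.10: device-agnostic autocast with an explicit dtype.
with torch.autocast(device, dtype=dtype):
    y = model(x)

# torch < 1.10 equivalent (CUDA only): torch.cuda.amp.autocast()
```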

File: pytorch_lightning/plugins/precision/precision_plugin.py

@@ -147,7 +147,7 @@ class PrecisionPlugin(CheckpointHooks):
         """Hook to run the optimizer step."""
         if isinstance(model, pl.LightningModule):
             closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        optimizer.step(closure=closure, **kwargs)
+        optimizer.step(closure=closure, **kwargs)  # type: ignore[call-arg]
 
     def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
         if trainer.track_grad_norm == -1:
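The targeted `# type: ignore[call-arg]` is presumably needed because the base `torch.optim.Optimizer.step` stub only declares an optional `closure` parameter, while this hook also forwards plugin-specific `**kwargs`; at runtime the call is fine for optimizers that accept them. A small plain-torch sketch of the call shape being silenced:

```python
import torch

param = torch.nn.Parameter(torch.zeros(3))
optimizer = torch.optim.SGD([param], lr=0.1)


def closure() -> torch.Tensor:
    optimizer.zero_grad()
    loss = (param ** 2).sum()
    loss.backward()
    return loss


# Valid at runtime; the plugin adds the ignore only where extra **kwargs are forwarded.
optimizer.step(closure=closure)
```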