Improve typing for plugins (#10742)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
parent
81a0a44d8f
commit
038c151b6e
|
@ -86,9 +86,6 @@ module = [
|
|||
"pytorch_lightning.plugins.environments.lsf_environment",
|
||||
"pytorch_lightning.plugins.environments.slurm_environment",
|
||||
"pytorch_lightning.plugins.environments.torchelastic_environment",
|
||||
"pytorch_lightning.plugins.precision.deepspeed",
|
||||
"pytorch_lightning.plugins.precision.native_amp",
|
||||
"pytorch_lightning.plugins.precision.precision_plugin",
|
||||
"pytorch_lightning.plugins.training_type.ddp",
|
||||
"pytorch_lightning.plugins.training_type.ddp2",
|
||||
"pytorch_lightning.plugins.training_type.ddp_spawn",
|
||||
|
|
|
@ -49,7 +49,9 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
|
|||
deepspeed_engine: DeepSpeedEngine = model.trainer.model
|
||||
deepspeed_engine.backward(closure_loss, *args, **kwargs)
|
||||
|
||||
def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
|
||||
def _run_backward(self, tensor: Tensor, model: Optional["DeepSpeedEngine"], *args: Any, **kwargs: Any) -> None:
|
||||
if model is None:
|
||||
raise ValueError("Please provide the model as input to `backward`.")
|
||||
model.backward(tensor, *args, **kwargs)
|
||||
|
||||
def optimizer_step(
|
||||
|
|
|
@ -25,9 +25,9 @@ from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType
|
|||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
if _TORCH_GREATER_EQUAL_1_10:
|
||||
from torch import autocast
|
||||
from torch import autocast as new_autocast
|
||||
else:
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.cuda.amp import autocast as old_autocast
|
||||
|
||||
|
||||
class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
|
||||
|
@ -62,7 +62,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
|
|||
closure_loss = self.scaler.scale(closure_loss)
|
||||
return super().pre_backward(model, closure_loss)
|
||||
|
||||
def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
|
||||
def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
|
||||
if self.scaler is not None:
|
||||
tensor = self.scaler.scale(tensor)
|
||||
super()._run_backward(tensor, model, *args, **kwargs)
|
||||
|
@ -93,12 +93,12 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
|
|||
self.scaler.step(optimizer, **kwargs)
|
||||
self.scaler.update()
|
||||
|
||||
def autocast_context_manager(self) -> autocast:
|
||||
def autocast_context_manager(self) -> Union["old_autocast", "new_autocast"]:
|
||||
if _TORCH_GREATER_EQUAL_1_10:
|
||||
# the dtype could be automatically inferred but we need to manually set it due to a bug upstream
|
||||
# https://github.com/pytorch/pytorch/issues/67233
|
||||
return autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
|
||||
return autocast()
|
||||
return new_autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
|
||||
return old_autocast()
|
||||
|
||||
@contextmanager
|
||||
def forward_context(self) -> Generator[None, None, None]:
|
||||
|
|
|
@ -147,7 +147,7 @@ class PrecisionPlugin(CheckpointHooks):
|
|||
"""Hook to run the optimizer step."""
|
||||
if isinstance(model, pl.LightningModule):
|
||||
closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
|
||||
optimizer.step(closure=closure, **kwargs)
|
||||
optimizer.step(closure=closure, **kwargs) # type: ignore[call-arg]
|
||||
|
||||
def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
|
||||
if trainer.track_grad_norm == -1:
|
||||
|
|
Loading…
Reference in New Issue