Add `@override` for `src/lightning/fabric/wrappers.py` (#19292)

This commit is contained in:
Victor Prins 2024-01-16 12:30:56 +01:00 committed by GitHub
parent 628ee0cb61
commit 4996965d11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 0 deletions

View File

@ -22,6 +22,7 @@ from torch import nn as nn
from torch.nn.modules.module import _IncompatibleKeys
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from typing_extensions import override
from lightning.fabric.plugins import Precision
from lightning.fabric.strategies import Strategy
@ -111,6 +112,7 @@ class _FabricModule(_DeviceDtypeModuleMixin):
def module(self) -> nn.Module:
return self._original_module or self._forward_module
@override
def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Casts all inputs to the right precision and handles autocast for operations in the module forward method."""
args, kwargs = self._precision.convert_input((args, kwargs))
@ -129,6 +131,7 @@ class _FabricModule(_DeviceDtypeModuleMixin):
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
...
@override
def state_dict(
self, destination: Optional[T_destination] = None, prefix: str = "", keep_vars: bool = False
) -> Optional[Dict[str, Any]]:
@ -138,6 +141,7 @@ class _FabricModule(_DeviceDtypeModuleMixin):
keep_vars=keep_vars,
)
@override
def load_state_dict( # type: ignore[override]
self, state_dict: Mapping[str, Any], strict: bool = True, **kwargs: Any
) -> _IncompatibleKeys:
@ -194,6 +198,7 @@ class _FabricModule(_DeviceDtypeModuleMixin):
return _wrapped_method
@override
def __getattr__(self, item: Any) -> Any:
if item in _LIGHTNING_MODULE_STEP_METHODS and self._forward_module != self._original_module:
# Special support for `LightningModule`, to prevent bypassing DDP's forward
@ -212,6 +217,7 @@ class _FabricModule(_DeviceDtypeModuleMixin):
attr = self._wrap_method_with_module_call_tracker(attr, item)
return attr
@override
def __setattr__(self, name: str, value: Any) -> None:
if not getattr(self, "_fabric_module_initialized", False):
super().__setattr__(name, value)