Add `@override` for `src/lightning/fabric/wrappers.py` (#19292)
This commit is contained in:
parent
628ee0cb61
commit
4996965d11
|
@ -22,6 +22,7 @@ from torch import nn as nn
|
|||
from torch.nn.modules.module import _IncompatibleKeys
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
from typing_extensions import override
|
||||
|
||||
from lightning.fabric.plugins import Precision
|
||||
from lightning.fabric.strategies import Strategy
|
||||
|
@ -111,6 +112,7 @@ class _FabricModule(_DeviceDtypeModuleMixin):
|
|||
def module(self) -> nn.Module:
|
||||
return self._original_module or self._forward_module
|
||||
|
||||
@override
|
||||
def forward(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Casts all inputs to the right precision and handles autocast for operations in the module forward method."""
|
||||
args, kwargs = self._precision.convert_input((args, kwargs))
|
||||
|
@ -129,6 +131,7 @@ class _FabricModule(_DeviceDtypeModuleMixin):
|
|||
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
@override
|
||||
def state_dict(
|
||||
self, destination: Optional[T_destination] = None, prefix: str = "", keep_vars: bool = False
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
|
@ -138,6 +141,7 @@ class _FabricModule(_DeviceDtypeModuleMixin):
|
|||
keep_vars=keep_vars,
|
||||
)
|
||||
|
||||
@override
|
||||
def load_state_dict( # type: ignore[override]
|
||||
self, state_dict: Mapping[str, Any], strict: bool = True, **kwargs: Any
|
||||
) -> _IncompatibleKeys:
|
||||
|
@ -194,6 +198,7 @@ class _FabricModule(_DeviceDtypeModuleMixin):
|
|||
|
||||
return _wrapped_method
|
||||
|
||||
@override
|
||||
def __getattr__(self, item: Any) -> Any:
|
||||
if item in _LIGHTNING_MODULE_STEP_METHODS and self._forward_module != self._original_module:
|
||||
# Special support for `LightningModule`, to prevent bypassing DDP's forward
|
||||
|
@ -212,6 +217,7 @@ class _FabricModule(_DeviceDtypeModuleMixin):
|
|||
attr = self._wrap_method_with_module_call_tracker(attr, item)
|
||||
return attr
|
||||
|
||||
@override
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if not getattr(self, "_fabric_module_initialized", False):
|
||||
super().__setattr__(name, value)
|
||||
|
|
Loading…
Reference in New Issue