From 4996965d11077f6bf8ab54e60e34099360ace7dc Mon Sep 17 00:00:00 2001 From: Victor Prins Date: Tue, 16 Jan 2024 12:30:56 +0100 Subject: [PATCH] Add `@override` for `src/lightning/fabric/wrappers.py` (#19292) --- src/lightning/fabric/wrappers.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index 6dda42e987..16611eb4c7 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -22,6 +22,7 @@ from torch import nn as nn from torch.nn.modules.module import _IncompatibleKeys from torch.optim import Optimizer from torch.utils.data import DataLoader +from typing_extensions import override from lightning.fabric.plugins import Precision from lightning.fabric.strategies import Strategy @@ -111,6 +112,7 @@ class _FabricModule(_DeviceDtypeModuleMixin): def module(self) -> nn.Module: return self._original_module or self._forward_module + @override def forward(self, *args: Any, **kwargs: Any) -> Any: """Casts all inputs to the right precision and handles autocast for operations in the module forward method.""" args, kwargs = self._precision.convert_input((args, kwargs)) @@ -129,6 +131,7 @@ class _FabricModule(_DeviceDtypeModuleMixin): def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: ... + @override def state_dict( self, destination: Optional[T_destination] = None, prefix: str = "", keep_vars: bool = False ) -> Optional[Dict[str, Any]]: @@ -138,6 +141,7 @@ class _FabricModule(_DeviceDtypeModuleMixin): keep_vars=keep_vars, ) + @override def load_state_dict( # type: ignore[override] self, state_dict: Mapping[str, Any], strict: bool = True, **kwargs: Any ) -> _IncompatibleKeys: @@ -194,6 +198,7 @@ class _FabricModule(_DeviceDtypeModuleMixin): return _wrapped_method + @override def __getattr__(self, item: Any) -> Any: if item in _LIGHTNING_MODULE_STEP_METHODS and self._forward_module != self._original_module: # Special support for `LightningModule`, to prevent bypassing DDP's forward @@ -212,6 +217,7 @@ class _FabricModule(_DeviceDtypeModuleMixin): attr = self._wrap_method_with_module_call_tracker(attr, item) return attr + @override def __setattr__(self, name: str, value: Any) -> None: if not getattr(self, "_fabric_module_initialized", False): super().__setattr__(name, value)