From 24fc54f07bb11e1955f31c1e3d3034806660f37f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 29 Nov 2021 12:28:28 +0100 Subject: [PATCH] Fix typing in `pl.overrides.fairscale` (#10799) * update typing in fairscale * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pyproject.toml | 1 - pytorch_lightning/overrides/fairscale.py | 6 ++++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0315bb0373..54586addf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,6 @@ module = [ "pytorch_lightning.loops.fit_loop", "pytorch_lightning.loops.utilities", "pytorch_lightning.overrides.distributed", - "pytorch_lightning.overrides.fairscale", "pytorch_lightning.plugins.environments.lightning_environment", "pytorch_lightning.plugins.environments.lsf_environment", "pytorch_lightning.plugins.environments.slurm_environment", diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 7860377b1a..c33bed6090 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch.nn as nn + import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE @@ -19,11 +21,11 @@ LightningShardedDataParallel = None if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel - class LightningShardedDataParallel(_LightningModuleWrapperBase): + class LightningShardedDataParallel(_LightningModuleWrapperBase): # type: ignore[no-redef] # Just do this for later docstrings pass - def unwrap_lightning_module_sharded(wrapped_model) -> "pl.LightningModule": + def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule": model = wrapped_model if isinstance(model, ShardedDataParallel): model = model.module