Fix typing in `pl.overrides.fairscale` (#10799)

* update typing in fairscale

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2021-11-29 12:28:28 +01:00 committed by GitHub
parent bd3fb2e66e
commit 24fc54f07b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 3 deletions

View File

@@ -78,7 +78,6 @@ module = [
"pytorch_lightning.loops.fit_loop",
"pytorch_lightning.loops.utilities",
"pytorch_lightning.overrides.distributed",
"pytorch_lightning.overrides.fairscale",
"pytorch_lightning.plugins.environments.lightning_environment",
"pytorch_lightning.plugins.environments.lsf_environment",
"pytorch_lightning.plugins.environments.slurm_environment",

View File

@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
@@ -19,11 +21,11 @@ LightningShardedDataParallel = None
if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
class LightningShardedDataParallel(_LightningModuleWrapperBase):  # type: ignore[no-redef]
    """Marker subclass used when fairscale is available.

    Re-binds the module-level ``LightningShardedDataParallel`` placeholder
    (initialised to ``None`` above) to a concrete wrapper type, so callers
    can reference a real class with its own name and docstring. The
    ``no-redef`` ignore silences mypy's complaint about that re-binding.
    """
def unwrap_lightning_module_sharded(wrapped_model) -> "pl.LightningModule":
def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule":
model = wrapped_model
if isinstance(model, ShardedDataParallel):
model = model.module