diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index c6c1867534..481a44f1c2 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -64,6 +64,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * env_prefix
   * env_parse
 
+- Mark the `forward_module` argument as required ([#16386](https://github.com/Lightning-AI/lightning/pull/16386))
+  * Removed the deprecated `pl_module` argument from the distributed module wrappers
+  * Removed the deprecated `pytorch_lightning.overrides.base.unwrap_lightning_module` function
+  * Removed the `pytorch_lightning.overrides.distributed.LightningDistributedModule` class
+  * Removed the deprecated `pytorch_lightning.overrides.fairscale.unwrap_lightning_module_sharded` function
+  * Removed the `pytorch_lightning.overrides.fairscale.LightningShardedDataParallel` class
+
 - Removed the deprecated automatic GPU selection ([#16184](https://github.com/Lightning-AI/lightning/pull/16184))
   * Removed the `Trainer(auto_select_gpus=...)` argument
   * Removed the `pytorch_lightning.tuner.auto_gpu_select.{pick_single_gpu,pick_multiple_gpus}` functions
diff --git a/src/pytorch_lightning/overrides/__init__.py b/src/pytorch_lightning/overrides/__init__.py
index ca97a63649..e69de29bb2 100644
--- a/src/pytorch_lightning/overrides/__init__.py
+++ b/src/pytorch_lightning/overrides/__init__.py
@@ -1,2 +0,0 @@
-from pytorch_lightning.overrides.data_parallel import LightningParallelModule  # noqa: F401
-from pytorch_lightning.overrides.distributed import LightningDistributedModule  # noqa: F401
diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py
index 9c0f6bf048..684c4fe8c1 100644
--- a/src/pytorch_lightning/overrides/base.py
+++ b/src/pytorch_lightning/overrides/base.py
@@ -11,16 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Union
+from typing import Any, Union
 
 import torch
-import torch.nn as nn
-from torch.nn import DataParallel
-from torch.nn.parallel import DistributedDataParallel
 
 import pytorch_lightning as pl
 from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
 
 
 class _LightningPrecisionModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
@@ -55,9 +51,7 @@ class _LightningPrecisionModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Mod
 
 
 class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
-    def __init__(
-        self, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]]
-    ) -> None:
+    def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
         """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
         ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
 
@@ -75,8 +69,6 @@ class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module): "`forward_module` must be a `LightningModule` instance or have an attribute `.module` pointing to one," f" got: {forward_module.__class__.__qualname__}" ) - # TODO: In v2.0.0, remove the Optional type from forward_module and remove the assertion - assert forward_module is not None self._forward_module = forward_module # set the parameters_to_ignore from LightningModule. @@ -111,47 +103,3 @@ class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module): if trainer.predicting: return self._forward_module.predict_step(*inputs, **kwargs) return self._forward_module(*inputs, **kwargs) - - @classmethod - def _validate_init_arguments( - cls, - pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, - forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, - ) -> None: - # TODO: In v2.0.0, remove this method and mark the forward_module init argument in all subclasses as required - if pl_module is not None: - rank_zero_deprecation( - f"The argument `pl_module` in `{cls.__name__}` is deprecated in v1.8.0 and will be removed in" - " v2.0.0. Please use `forward_module` instead." - ) - elif forward_module is None: - raise ValueError("Argument `forward_module` is required.") - - -def unwrap_lightning_module(wrapped_model: nn.Module, _suppress_warning: bool = False) -> "pl.LightningModule": - """Recursively unwraps a :class:`~pytorch_lightning.core.module.LightningModule` by following the ``.module`` - attributes on the wrapper. - - .. deprecated:: v1.8.0 - The function ``unwrap_lightning_module`` is deprecated in v1.8.0 and will be removed in v2.0.0. Access the - ``LightningModule`` directly through the strategy attribute ``Strategy.lightning_module``. - - Raises: - TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped - further. - """ - if not _suppress_warning: - rank_zero_deprecation( - "The function `unwrap_lightning_module` is deprecated in v1.8.0 and will be removed in v2.0.0. Access the" - " `LightningModule` directly through the strategy attribute `Strategy.lightning_module`." - ) - model = wrapped_model - if isinstance(model, (DistributedDataParallel, DataParallel)): - model = unwrap_lightning_module(model.module) - if isinstance(model, _LightningModuleWrapperBase): - model = model.lightning_module - if isinstance(model, _LightningPrecisionModuleWrapperBase): - model = model.module - if not isinstance(model, pl.LightningModule): - raise TypeError(f"Unwrapping the module did not yield a `LightningModule`, got {type(model)} instead.") - return model diff --git a/src/pytorch_lightning/overrides/data_parallel.py b/src/pytorch_lightning/overrides/data_parallel.py index 1e70893584..eb93b29bbe 100644 --- a/src/pytorch_lightning/overrides/data_parallel.py +++ b/src/pytorch_lightning/overrides/data_parallel.py @@ -13,7 +13,7 @@ # limitations under the License. import numbers import warnings -from typing import Any, Optional, Union +from typing import Any, Union import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -52,23 +52,15 @@ class LightningParallelModule(_LightningModuleWrapperBase): ) Args: - pl_module: The module to wrap. See description for `forward_module`. - - .. deprecated:: v1.8.0 - The argument ``pl_module`` is deprecated in v1.8.0 and will be removed in v2.0.0. Please use - ``forward_module`` instead. 
- forward_module: The module to wrap. If it's not a ``LightningModule``, it must have an attribute ``.module`` pointing to a ``LightningModule`` reference. """ def __init__( self, - forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, - pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], ) -> None: - self._validate_init_arguments(pl_module, forward_module) - super().__init__(forward_module=(pl_module or forward_module)) + super().__init__(forward_module=forward_module) _ignore_scalar_return_in_dp() def forward(self, *inputs: Any, **kwargs: Any) -> Any: diff --git a/src/pytorch_lightning/overrides/distributed.py b/src/pytorch_lightning/overrides/distributed.py index 212f7b7a41..7494d5bbc7 100644 --- a/src/pytorch_lightning/overrides/distributed.py +++ b/src/pytorch_lightning/overrides/distributed.py @@ -12,26 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from typing import Any, cast, Iterable, Iterator, List, Optional, Sized, Union +from typing import Any, cast, Iterable, Iterator, List, Sized, Union import torch from torch import Tensor from torch.nn.parallel import DistributedDataParallel from torch.utils.data import BatchSampler, DistributedSampler, Sampler -import pytorch_lightning as pl from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper -from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase - - -class LightningDistributedModule(_LightningModuleWrapperBase): - def __init__( - self, - forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, - pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, - ) -> None: - self._validate_init_arguments(pl_module, forward_module) - super().__init__(forward_module=(pl_module or forward_module)) def _find_tensors( diff --git a/src/pytorch_lightning/overrides/fairscale.py b/src/pytorch_lightning/overrides/fairscale.py index f818792e57..93b100f9e3 100644 --- a/src/pytorch_lightning/overrides/fairscale.py +++ b/src/pytorch_lightning/overrides/fairscale.py @@ -11,21 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional, Union +from typing import List -import torch.nn as nn from lightning_utilities.core.imports import package_available from torch.optim import Optimizer -import pytorch_lightning as pl from lightning_fabric.plugins import Precision from lightning_fabric.utilities.imports import _IS_WINDOWS -from pytorch_lightning.overrides.base import ( - _LightningModuleWrapperBase, - _LightningPrecisionModuleWrapperBase, - unwrap_lightning_module, -) -from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and package_available("fairscale") @@ -35,37 +27,6 @@ else: OSS = object -class LightningShardedDataParallel(_LightningModuleWrapperBase): - def __init__( - self, - forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, - pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, - ) -> None: - rank_zero_deprecation( - "PyTorch Lightning's sharded implementation using FairScale has been deprecated in v1.9.0 and will be" - " removed in v2.0.0. You can try using the `Trainer(strategy='fsdp_native')` instead." - " The difference is that native FSDP uses PyTorch's implementation and the current strategy uses" - " FairScale's implementation (which was upstreamed to PyTorch). After removal, `strategy='fsdp'` will use" - " the native version by default." - ) - self._validate_init_arguments(pl_module, forward_module) - super().__init__(forward_module=(pl_module or forward_module)) - - -def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule": - from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel - - rank_zero_deprecation( - "The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0 and will be removed in v2.0.0." - " Access the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`." 
- ) - model = wrapped_model - if isinstance(model, ShardedDataParallel): - model = model.module - - return unwrap_lightning_module(model, _suppress_warning=True) - - def _reinit_optimizers_with_oss(optimizers: List[Optimizer], precision: Precision, num_nodes: int) -> List["OSS"]: for x, optimizer in enumerate(optimizers): if not isinstance(optimizer, OSS): diff --git a/src/pytorch_lightning/strategies/bagua.py b/src/pytorch_lightning/strategies/bagua.py index b018b197e9..5a1c3ce476 100644 --- a/src/pytorch_lightning/strategies/bagua.py +++ b/src/pytorch_lightning/strategies/bagua.py @@ -64,11 +64,8 @@ log = logging.getLogger(__name__) class LightningBaguaModule(_LightningModuleWrapperBase): def __init__( self, - forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, - pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], ) -> None: - self._validate_init_arguments(pl_module, forward_module) - forward_module = pl_module or forward_module super().__init__(forward_module=forward_module) # Bagua use `bagua_module_name` to distinguish different modules self._bagua_module_name = f"{forward_module.__class__.__name__}{id(forward_module)}" diff --git a/src/pytorch_lightning/strategies/ddp.py b/src/pytorch_lightning/strategies/ddp.py index 6823815078..cff71f89d7 100644 --- a/src/pytorch_lightning/strategies/ddp.py +++ b/src/pytorch_lightning/strategies/ddp.py @@ -43,8 +43,7 @@ from lightning_fabric.utilities.optimizer import _optimizers_to_device from lightning_fabric.utilities.seed import reset_seed from lightning_fabric.utilities.types import ReduceOp from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.overrides import LightningDistributedModule -from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins.precision import PrecisionPlugin @@ -291,7 +290,7 @@ class DDPStrategy(ParallelStrategy): log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel") self.pre_configure_ddp() assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) - self.model = self._setup_model(LightningDistributedModule(self.model)) + self.model = self._setup_model(_LightningModuleWrapperBase(self.model)) self._register_ddp_hooks() def determine_ddp_device_ids(self) -> Optional[List[int]]: diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index e0f4ca42c1..949cdf3813 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -36,8 +36,7 @@ from lightning_fabric.utilities.distributed import group as _group from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_11 from lightning_fabric.utilities.optimizer import _optimizers_to_device from lightning_fabric.utilities.types import ReduceOp -from pytorch_lightning.overrides import LightningDistributedModule -from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase 
from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher @@ -211,7 +210,7 @@ class DDPSpawnStrategy(ParallelStrategy): def configure_ddp(self) -> None: self.pre_configure_ddp() assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) - self.model = self._setup_model(LightningDistributedModule(self.model)) + self.model = self._setup_model(_LightningModuleWrapperBase(self.model)) self._register_ddp_hooks() # set up optimizers after the wrapped module has been moved to the device diff --git a/src/pytorch_lightning/strategies/hpu_parallel.py b/src/pytorch_lightning/strategies/hpu_parallel.py index 1ef6ee1cc3..f6605b9089 100644 --- a/src/pytorch_lightning/strategies/hpu_parallel.py +++ b/src/pytorch_lightning/strategies/hpu_parallel.py @@ -22,7 +22,7 @@ from torch.optim.optimizer import Optimizer import pytorch_lightning as pl from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment from lightning_fabric.utilities.distributed import group as _group -from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.overrides.torch_distributed import broadcast_object_list from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO @@ -123,7 +123,7 @@ class HPUParallelStrategy(DDPStrategy): if _TORCH_LESSER_EQUAL_1_10_2: log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel") self._pre_configure_ddp() - self.model = self._setup_model(LightningDistributedModule(self.model)) # type: ignore + self.model = self._setup_model(_LightningModuleWrapperBase(self.model)) # type: ignore if self.root_device.type == "hpu" and self._static_graph: self._model._set_static_graph() # type: ignore self._register_ddp_hooks() diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py index 167a572181..4fd09eb6ae 100644 --- a/src/pytorch_lightning/strategies/tpu_spawn.py +++ b/src/pytorch_lightning/strategies/tpu_spawn.py @@ -28,7 +28,7 @@ from lightning_fabric.plugins.environments import XLAEnvironment from lightning_fabric.utilities.data import has_len from lightning_fabric.utilities.optimizer import _optimizers_to_device from lightning_fabric.utilities.types import _PATH, ReduceOp -from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy @@ -128,7 +128,7 @@ class TPUSpawnStrategy(DDPSpawnStrategy): TPUSpawnStrategy._validate_patched_dataloaders(model) import torch_xla.distributed.xla_multiprocessing as xmp - self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model)) + self.wrapped_model = xmp.MpModelWrapper(_LightningModuleWrapperBase(model)) return super().connect(model) def _configure_launcher(self) -> None: diff --git a/tests/tests_pytorch/deprecated_api/test_remove_2-0.py b/tests/tests_pytorch/deprecated_api/test_remove_2-0.py index a70a55ff65..31336eadd9 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_2-0.py +++ 
b/tests/tests_pytorch/deprecated_api/test_remove_2-0.py @@ -17,17 +17,12 @@ from unittest import mock import numpy import pytest import torch -from lightning_utilities.test.warning import no_warning_call from torch.utils.data import DataLoader from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.core.mixins.device_dtype_mixin import DeviceDtypeModuleMixin -from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset -from pytorch_lightning.overrides import LightningDistributedModule, LightningParallelModule -from pytorch_lightning.overrides.base import unwrap_lightning_module -from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded +from pytorch_lightning.demos.boring_classes import RandomDataset from pytorch_lightning.plugins.environments import LightningEnvironment -from pytorch_lightning.strategies.bagua import LightningBaguaModule from pytorch_lightning.strategies.utils import on_colab_kaggle from pytorch_lightning.utilities.apply_func import ( apply_to_collection, @@ -62,51 +57,6 @@ from pytorch_lightning.utilities.distributed import ( from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device from pytorch_lightning.utilities.seed import pl_worker_init_function, reset_seed, seed_everything from pytorch_lightning.utilities.xla_device import inner_f, pl_multi_process, XLADeviceUtils -from tests_pytorch.helpers.runif import RunIf - - -@pytest.mark.parametrize( - "wrapper_class", - [ - LightningParallelModule, - LightningDistributedModule, - LightningBaguaModule, - ], -) -def test_v1_10_deprecated_pl_module_init_parameter(wrapper_class): - with no_warning_call( - DeprecationWarning, match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8.0" - ): - wrapper_class(BoringModel()) - - with pytest.deprecated_call( - match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8.0" - ): - wrapper_class(pl_module=BoringModel()) - - -@RunIf(fairscale=True) -def test_v1_10_deprecated_fairscale_pl_module_init_parameter(): - with no_warning_call( - DeprecationWarning, match=r"The argument `pl_module` in `LightningShardedDataParallel` is deprecated in v1.8.0" - ), pytest.deprecated_call(match="FairScale has been deprecated in v1.9.0"): - LightningShardedDataParallel(BoringModel()) - - with pytest.deprecated_call( - match=r"The argument `pl_module` in `LightningShardedDataParallel` is deprecated in v1.8.0" - ): - LightningShardedDataParallel(pl_module=BoringModel()) - - -def test_v1_10_deprecated_unwrap_lightning_module(): - with pytest.deprecated_call(match=r"The function `unwrap_lightning_module` is deprecated in v1.8.0"): - unwrap_lightning_module(BoringModel()) - - -@RunIf(fairscale=True) -def test_v1_10_deprecated_unwrap_lightning_module_sharded(): - with pytest.deprecated_call(match=r"The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0"): - unwrap_lightning_module_sharded(BoringModel()) def test_v1_10_deprecated_on_colab_kaggle_func(): diff --git a/tests/tests_pytorch/overrides/test_base.py b/tests/tests_pytorch/overrides/test_base.py index 27d2db688d..101cf41571 100644 --- a/tests/tests_pytorch/overrides/test_base.py +++ b/tests/tests_pytorch/overrides/test_base.py @@ -13,14 +13,9 @@ # limitations under the License. 
import pytest import torch -from torch.nn import DataParallel from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.overrides.base import ( - _LightningModuleWrapperBase, - _LightningPrecisionModuleWrapperBase, - unwrap_lightning_module, -) +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase @pytest.mark.parametrize("wrapper_class", [_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase]) @@ -30,13 +25,3 @@ def test_wrapper_device_dtype(wrapper_class): wrapped_model.to(dtype=torch.float16) assert model.dtype == torch.float16 - - -def test_unwrap_lightning_module(): - model = BoringModel() - wrapped_model = _LightningPrecisionModuleWrapperBase(model) - wrapped_model = _LightningModuleWrapperBase(wrapped_model) - wrapped_model = DataParallel(wrapped_model) - - with pytest.deprecated_call(match="The function `unwrap_lightning_module` is deprecated in v1.8.0"): - assert unwrap_lightning_module(wrapped_model) == model diff --git a/tests/tests_pytorch/overrides/test_data_parallel.py b/tests/tests_pytorch/overrides/test_data_parallel.py index 68f625a427..64fd229991 100644 --- a/tests/tests_pytorch/overrides/test_data_parallel.py +++ b/tests/tests_pytorch/overrides/test_data_parallel.py @@ -20,7 +20,7 @@ from torch.nn import DataParallel from pytorch_lightning import LightningModule from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.overrides.data_parallel import ( LightningParallelModule, python_scalar_to_tensor, @@ -30,7 +30,7 @@ from pytorch_lightning.trainer.states import RunningStage from tests_pytorch.helpers.runif import RunIf -@pytest.mark.parametrize("wrapper_class", [LightningParallelModule, LightningDistributedModule]) +@pytest.mark.parametrize("wrapper_class", [LightningParallelModule, _LightningModuleWrapperBase]) @pytest.mark.parametrize( "stage", [