Mark forward_module as required (#16386)

parent 46246c3336 · commit 886ad49a55

@@ -64,6 +64,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * env_prefix
   * env_parse

+- Mark the `forward_module` argument as required ([#16386](https://github.com/Lightning-AI/lightning/pull/16386))
+  * Removed the deprecated `pl_module` argument from the distributed module wrappers
+  * Removed the deprecated `pytorch_lightning.overrides.base.unwrap_lightning_module` function
+  * Removed the `pytorch_lightning.overrides.distributed.LightningDistributedModule` class
+  * Removed the deprecated `pytorch_lightning.overrides.fairscale.unwrap_lightning_module_sharded` function
+  * Removed the `pytorch_lightning.overrides.fairscale.LightningShardedDataParallel` class
+
 - Removed the deprecated automatic GPU selection ([#16184](https://github.com/Lightning-AI/lightning/pull/16184))
   * Removed the `Trainer(auto_select_gpus=...)` argument
   * Removed the `pytorch_lightning.tuner.auto_gpu_select.{pick_single_gpu,pick_multiple_gpus}` functions

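For context on what the entries above mean in practice, here is a minimal before/after sketch (not part of the commit; `BoringModel` stands in for any user `LightningModule`):

```python
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase

model = BoringModel()

# Before (deprecated since v1.8.0, removed by this PR):
#   wrapped = LightningDistributedModule(pl_module=model)
#   lightning_module = unwrap_lightning_module(wrapped)

# After: `forward_module` is the wrapper's only, required argument,
wrapped = _LightningModuleWrapperBase(forward_module=model)
# and the LightningModule is reached through the strategy instead of unwrap_lightning_module:
#   trainer.strategy.lightning_module
```
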
@@ -1,2 +0,0 @@
-from pytorch_lightning.overrides.data_parallel import LightningParallelModule  # noqa: F401
-from pytorch_lightning.overrides.distributed import LightningDistributedModule  # noqa: F401

@@ -11,16 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Union
+from typing import Any, Union

 import torch
-import torch.nn as nn
-from torch.nn import DataParallel
-from torch.nn.parallel import DistributedDataParallel

 import pytorch_lightning as pl
 from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


 class _LightningPrecisionModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):

@@ -55,9 +51,7 @@ class _LightningPrecisionModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):


 class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
-    def __init__(
-        self, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]]
-    ) -> None:
+    def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
         """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
         ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.

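A short sketch of the new signature in use (illustrative, not part of the commit; it relies on the fall-through behaviour shown in the `forward` hunk below, where a module without an attached `Trainer` is simply called directly):

```python
import torch

from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase

wrapped = _LightningModuleWrapperBase(BoringModel())  # forward_module is now required, no Optional/None
out = wrapped(torch.randn(2, 32))  # no Trainer attached, so this dispatches to BoringModel.forward
```
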
@@ -75,8 +69,6 @@ class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
                 "`forward_module` must be a `LightningModule` instance or have an attribute `.module` pointing to one,"
                 f" got: {forward_module.__class__.__qualname__}"
             )
-        # TODO: In v2.0.0, remove the Optional type from forward_module and remove the assertion
-        assert forward_module is not None
         self._forward_module = forward_module

         # set the parameters_to_ignore from LightningModule.

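The error message kept above states the contract: `forward_module` must be a `LightningModule`, or expose one through `.module` (e.g. a precision wrapper). A minimal sketch of both accepted forms (not part of the commit):

```python
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.overrides.base import (
    _LightningModuleWrapperBase,
    _LightningPrecisionModuleWrapperBase,
)

model = BoringModel()

# Accepted: a LightningModule directly
_LightningModuleWrapperBase(model)

# Also accepted: a wrapper whose `.module` attribute is the LightningModule
_LightningModuleWrapperBase(_LightningPrecisionModuleWrapperBase(model))

# Anything else raises the ValueError shown above
```
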
@@ -111,47 +103,3 @@ class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
         if trainer.predicting:
             return self._forward_module.predict_step(*inputs, **kwargs)
         return self._forward_module(*inputs, **kwargs)
-
-    @classmethod
-    def _validate_init_arguments(
-        cls,
-        pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-        forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-    ) -> None:
-        # TODO: In v2.0.0, remove this method and mark the forward_module init argument in all subclasses as required
-        if pl_module is not None:
-            rank_zero_deprecation(
-                f"The argument `pl_module` in `{cls.__name__}` is deprecated in v1.8.0 and will be removed in"
-                " v2.0.0. Please use `forward_module` instead."
-            )
-        elif forward_module is None:
-            raise ValueError("Argument `forward_module` is required.")
-
-
-def unwrap_lightning_module(wrapped_model: nn.Module, _suppress_warning: bool = False) -> "pl.LightningModule":
-    """Recursively unwraps a :class:`~pytorch_lightning.core.module.LightningModule` by following the ``.module``
-    attributes on the wrapper.
-
-    .. deprecated:: v1.8.0
-        The function ``unwrap_lightning_module`` is deprecated in v1.8.0 and will be removed in v2.0.0. Access the
-        ``LightningModule`` directly through the strategy attribute ``Strategy.lightning_module``.
-
-    Raises:
-        TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped
-            further.
-    """
-    if not _suppress_warning:
-        rank_zero_deprecation(
-            "The function `unwrap_lightning_module` is deprecated in v1.8.0 and will be removed in v2.0.0. Access the"
-            " `LightningModule` directly through the strategy attribute `Strategy.lightning_module`."
-        )
-    model = wrapped_model
-    if isinstance(model, (DistributedDataParallel, DataParallel)):
-        model = unwrap_lightning_module(model.module)
-    if isinstance(model, _LightningModuleWrapperBase):
-        model = model.lightning_module
-    if isinstance(model, _LightningPrecisionModuleWrapperBase):
-        model = model.module
-    if not isinstance(model, pl.LightningModule):
-        raise TypeError(f"Unwrapping the module did not yield a `LightningModule`, got {type(model)} instead.")
-    return model

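The removed docstring above already points to the replacement for `unwrap_lightning_module`: go through the strategy. A hedged sketch of that pattern (not part of the commit):

```python
import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
trainer = pl.Trainer(accelerator="cpu", devices=1, fast_dev_run=True)
trainer.fit(model)

# Old (removed): unwrap_lightning_module(trainer.strategy.model)
# New: access the LightningModule directly through the strategy attribute
assert trainer.strategy.lightning_module is model
```
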
@@ -13,7 +13,7 @@
 # limitations under the License.
 import numbers
 import warnings
-from typing import Any, Optional, Union
+from typing import Any, Union

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection

@@ -52,23 +52,15 @@ class LightningParallelModule(_LightningModuleWrapperBase):
         )

     Args:
-        pl_module: The module to wrap. See description for `forward_module`.
-
-            .. deprecated:: v1.8.0
-                The argument ``pl_module`` is deprecated in v1.8.0 and will be removed in v2.0.0. Please use
-                ``forward_module`` instead.
-
         forward_module: The module to wrap. If it's not a ``LightningModule``, it must have an attribute ``.module``
             pointing to a ``LightningModule`` reference.
     """

     def __init__(
         self,
-        forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-        pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
+        forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase],
     ) -> None:
-        self._validate_init_arguments(pl_module, forward_module)
-        super().__init__(forward_module=(pl_module or forward_module))
+        super().__init__(forward_module=forward_module)
         _ignore_scalar_return_in_dp()

     def forward(self, *inputs: Any, **kwargs: Any) -> Any:

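A usage sketch for the surviving wrapper, mirroring the docstring's `DataParallel` example (illustrative only; assumes at least two CUDA devices are available):

```python
from torch.nn import DataParallel

from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.overrides.data_parallel import LightningParallelModule

model = BoringModel()
# `forward_module` is now the only accepted argument; `pl_module=` no longer exists
dp_model = DataParallel(LightningParallelModule(forward_module=model), device_ids=[0, 1])
```
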
@@ -12,26 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, cast, Iterable, Iterator, List, Optional, Sized, Union
+from typing import Any, cast, Iterable, Iterator, List, Sized, Union

 import torch
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import BatchSampler, DistributedSampler, Sampler

-import pytorch_lightning as pl
 from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
-from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
-
-
-class LightningDistributedModule(_LightningModuleWrapperBase):
-    def __init__(
-        self,
-        forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-        pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-    ) -> None:
-        self._validate_init_arguments(pl_module, forward_module)
-        super().__init__(forward_module=(pl_module or forward_module))


 def _find_tensors(

@@ -11,21 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Union
+from typing import List

-import torch.nn as nn
 from lightning_utilities.core.imports import package_available
 from torch.optim import Optimizer

-import pytorch_lightning as pl
 from lightning_fabric.plugins import Precision
 from lightning_fabric.utilities.imports import _IS_WINDOWS
-from pytorch_lightning.overrides.base import (
-    _LightningModuleWrapperBase,
-    _LightningPrecisionModuleWrapperBase,
-    unwrap_lightning_module,
-)
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation

 _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and package_available("fairscale")

@@ -35,37 +27,6 @@ else:
     OSS = object


-class LightningShardedDataParallel(_LightningModuleWrapperBase):
-    def __init__(
-        self,
-        forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-        pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-    ) -> None:
-        rank_zero_deprecation(
-            "PyTorch Lightning's sharded implementation using FairScale has been deprecated in v1.9.0 and will be"
-            " removed in v2.0.0. You can try using the `Trainer(strategy='fsdp_native')` instead."
-            " The difference is that native FSDP uses PyTorch's implementation and the current strategy uses"
-            " FairScale's implementation (which was upstreamed to PyTorch). After removal, `strategy='fsdp'` will use"
-            " the native version by default."
-        )
-        self._validate_init_arguments(pl_module, forward_module)
-        super().__init__(forward_module=(pl_module or forward_module))
-
-
-def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule":
-    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
-
-    rank_zero_deprecation(
-        "The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0 and will be removed in v2.0.0."
-        " Access the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`."
-    )
-    model = wrapped_model
-    if isinstance(model, ShardedDataParallel):
-        model = model.module
-
-    return unwrap_lightning_module(model, _suppress_warning=True)
-
-
 def _reinit_optimizers_with_oss(optimizers: List[Optimizer], precision: Precision, num_nodes: int) -> List["OSS"]:
     for x, optimizer in enumerate(optimizers):
         if not isinstance(optimizer, OSS):

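The deleted deprecation message documents the migration away from the FairScale-backed sharded strategy; a minimal sketch of that suggestion (the exact strategy flag depends on the installed Lightning version):

```python
import pytorch_lightning as pl

# Per the removed message: prefer PyTorch's native FSDP over the FairScale-based sharded strategy.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="fsdp_native")
# Once the FairScale code is gone, `strategy="fsdp"` selects the native implementation by default.
```
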
@@ -64,11 +64,8 @@ log = logging.getLogger(__name__)
 class LightningBaguaModule(_LightningModuleWrapperBase):
     def __init__(
         self,
-        forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-        pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
+        forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase],
     ) -> None:
-        self._validate_init_arguments(pl_module, forward_module)
-        forward_module = pl_module or forward_module
         super().__init__(forward_module=forward_module)
         # Bagua use `bagua_module_name` to distinguish different modules
         self._bagua_module_name = f"{forward_module.__class__.__name__}{id(forward_module)}"

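For illustration, the comment kept in the Bagua wrapper above describes its naming scheme; the resulting string looks like this (purely illustrative, not part of the commit):

```python
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
# Mirrors `f"{forward_module.__class__.__name__}{id(forward_module)}"` from the hunk above
print(f"{model.__class__.__name__}{id(model)}")  # e.g. "BoringModel140291838414928"
```
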
@@ -43,8 +43,7 @@ from lightning_fabric.utilities.optimizer import _optimizers_to_device
 from lightning_fabric.utilities.seed import reset_seed
 from lightning_fabric.utilities.types import ReduceOp
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.overrides import LightningDistributedModule
-from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins.precision import PrecisionPlugin

@@ -291,7 +290,7 @@ class DDPStrategy(ParallelStrategy):
         log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
         self.pre_configure_ddp()
         assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
-        self.model = self._setup_model(LightningDistributedModule(self.model))
+        self.model = self._setup_model(_LightningModuleWrapperBase(self.model))
         self._register_ddp_hooks()

     def determine_ddp_device_ids(self) -> Optional[List[int]]:

@@ -36,8 +36,7 @@ from lightning_fabric.utilities.distributed import group as _group
 from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_11
 from lightning_fabric.utilities.optimizer import _optimizers_to_device
 from lightning_fabric.utilities.types import ReduceOp
-from pytorch_lightning.overrides import LightningDistributedModule
-from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher

@@ -211,7 +210,7 @@ class DDPSpawnStrategy(ParallelStrategy):
     def configure_ddp(self) -> None:
         self.pre_configure_ddp()
         assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
-        self.model = self._setup_model(LightningDistributedModule(self.model))
+        self.model = self._setup_model(_LightningModuleWrapperBase(self.model))
         self._register_ddp_hooks()

         # set up optimizers after the wrapped module has been moved to the device

@@ -22,7 +22,7 @@ from torch.optim.optimizer import Optimizer
 import pytorch_lightning as pl
 from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning_fabric.utilities.distributed import group as _group
-from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
 from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO

@@ -123,7 +123,7 @@ class HPUParallelStrategy(DDPStrategy):
         if _TORCH_LESSER_EQUAL_1_10_2:
             log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
             self._pre_configure_ddp()
-            self.model = self._setup_model(LightningDistributedModule(self.model))  # type: ignore
+            self.model = self._setup_model(_LightningModuleWrapperBase(self.model))  # type: ignore
             if self.root_device.type == "hpu" and self._static_graph:
                 self._model._set_static_graph()  # type: ignore
             self._register_ddp_hooks()

@@ -28,7 +28,7 @@ from lightning_fabric.plugins.environments import XLAEnvironment
 from lightning_fabric.utilities.data import has_len
 from lightning_fabric.utilities.optimizer import _optimizers_to_device
 from lightning_fabric.utilities.types import _PATH, ReduceOp
-from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy

@@ -128,7 +128,7 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
         TPUSpawnStrategy._validate_patched_dataloaders(model)
         import torch_xla.distributed.xla_multiprocessing as xmp

-        self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
+        self.wrapped_model = xmp.MpModelWrapper(_LightningModuleWrapperBase(model))
         return super().connect(model)

     def _configure_launcher(self) -> None:

@@ -17,17 +17,12 @@ from unittest import mock
 import numpy
 import pytest
 import torch
-from lightning_utilities.test.warning import no_warning_call
 from torch.utils.data import DataLoader

 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.core.mixins.device_dtype_mixin import DeviceDtypeModuleMixin
-from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
-from pytorch_lightning.overrides import LightningDistributedModule, LightningParallelModule
-from pytorch_lightning.overrides.base import unwrap_lightning_module
-from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded
+from pytorch_lightning.demos.boring_classes import RandomDataset
 from pytorch_lightning.plugins.environments import LightningEnvironment
-from pytorch_lightning.strategies.bagua import LightningBaguaModule
 from pytorch_lightning.strategies.utils import on_colab_kaggle
 from pytorch_lightning.utilities.apply_func import (
     apply_to_collection,

@@ -62,51 +57,6 @@ from pytorch_lightning.utilities.distributed import (
 from pytorch_lightning.utilities.optimizer import optimizer_to_device, optimizers_to_device
 from pytorch_lightning.utilities.seed import pl_worker_init_function, reset_seed, seed_everything
 from pytorch_lightning.utilities.xla_device import inner_f, pl_multi_process, XLADeviceUtils
-from tests_pytorch.helpers.runif import RunIf
-
-
-@pytest.mark.parametrize(
-    "wrapper_class",
-    [
-        LightningParallelModule,
-        LightningDistributedModule,
-        LightningBaguaModule,
-    ],
-)
-def test_v1_10_deprecated_pl_module_init_parameter(wrapper_class):
-    with no_warning_call(
-        DeprecationWarning, match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8.0"
-    ):
-        wrapper_class(BoringModel())
-
-    with pytest.deprecated_call(
-        match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8.0"
-    ):
-        wrapper_class(pl_module=BoringModel())
-
-
-@RunIf(fairscale=True)
-def test_v1_10_deprecated_fairscale_pl_module_init_parameter():
-    with no_warning_call(
-        DeprecationWarning, match=r"The argument `pl_module` in `LightningShardedDataParallel` is deprecated in v1.8.0"
-    ), pytest.deprecated_call(match="FairScale has been deprecated in v1.9.0"):
-        LightningShardedDataParallel(BoringModel())
-
-    with pytest.deprecated_call(
-        match=r"The argument `pl_module` in `LightningShardedDataParallel` is deprecated in v1.8.0"
-    ):
-        LightningShardedDataParallel(pl_module=BoringModel())
-
-
-def test_v1_10_deprecated_unwrap_lightning_module():
-    with pytest.deprecated_call(match=r"The function `unwrap_lightning_module` is deprecated in v1.8.0"):
-        unwrap_lightning_module(BoringModel())
-
-
-@RunIf(fairscale=True)
-def test_v1_10_deprecated_unwrap_lightning_module_sharded():
-    with pytest.deprecated_call(match=r"The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0"):
-        unwrap_lightning_module_sharded(BoringModel())


 def test_v1_10_deprecated_on_colab_kaggle_func():

@@ -13,14 +13,9 @@
 # limitations under the License.
 import pytest
 import torch
-from torch.nn import DataParallel

 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.overrides.base import (
-    _LightningModuleWrapperBase,
-    _LightningPrecisionModuleWrapperBase,
-    unwrap_lightning_module,
-)
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase


 @pytest.mark.parametrize("wrapper_class", [_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase])

@@ -30,13 +25,3 @@ def test_wrapper_device_dtype(wrapper_class):

     wrapped_model.to(dtype=torch.float16)
     assert model.dtype == torch.float16
-
-
-def test_unwrap_lightning_module():
-    model = BoringModel()
-    wrapped_model = _LightningPrecisionModuleWrapperBase(model)
-    wrapped_model = _LightningModuleWrapperBase(wrapped_model)
-    wrapped_model = DataParallel(wrapped_model)
-
-    with pytest.deprecated_call(match="The function `unwrap_lightning_module` is deprecated in v1.8.0"):
-        assert unwrap_lightning_module(wrapped_model) == model

@@ -20,7 +20,7 @@ from torch.nn import DataParallel

 from pytorch_lightning import LightningModule
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.overrides.data_parallel import (
     LightningParallelModule,
     python_scalar_to_tensor,

@@ -30,7 +30,7 @@ from pytorch_lightning.trainer.states import RunningStage
 from tests_pytorch.helpers.runif import RunIf


-@pytest.mark.parametrize("wrapper_class", [LightningParallelModule, LightningDistributedModule])
+@pytest.mark.parametrize("wrapper_class", [LightningParallelModule, _LightningModuleWrapperBase])
 @pytest.mark.parametrize(
     "stage",
     [
