From 0dfb3d28ce858e5d709cba468b374a3d41329655 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 11 Nov 2022 14:36:59 +0100 Subject: [PATCH] Support individual setup of model and optimizer in Lite (#15185) --- src/lightning_lite/CHANGELOG.md | 2 +- src/lightning_lite/lite.py | 126 +++++++++++++++--- src/lightning_lite/strategies/deepspeed.py | 31 +++-- src/lightning_lite/strategies/fairscale.py | 29 ++++ src/lightning_lite/strategies/strategy.py | 8 +- tests/tests_lite/strategies/test_deepspeed.py | 35 +++++ tests/tests_lite/strategies/test_fairscale.py | 16 +++ tests/tests_lite/test_lite.py | 108 +++++++++++++-- 8 files changed, 315 insertions(+), 40 deletions(-) diff --git a/src/lightning_lite/CHANGELOG.md b/src/lightning_lite/CHANGELOG.md index ba128a6793..b7c64c315a 100644 --- a/src/lightning_lite/CHANGELOG.md +++ b/src/lightning_lite/CHANGELOG.md @@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - -- +- Added `LightningLite.setup_module()` and `LightningLite.setup_optimizers()` to support strategies that need to set up the model before an optimizer can be created ([#15185](https://github.com/Lightning-AI/lightning/pull/15185)) ### Changed diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index b058dd446f..55c76ed227 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -31,7 +31,14 @@ from torch.utils.data import BatchSampler, DataLoader, DistributedSampler from lightning_lite.plugins import Precision # avoid circular imports: # isort: split from lightning_lite.accelerators.accelerator import Accelerator from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT -from lightning_lite.strategies import DeepSpeedStrategy, SingleDeviceStrategy, Strategy, XLAStrategy +from lightning_lite.strategies import ( + DDPShardedStrategy, + DDPSpawnShardedStrategy, + DeepSpeedStrategy, + SingleDeviceStrategy, + Strategy, + XLAStrategy, +) from lightning_lite.strategies.strategy import _Sharded, TBroadcast from lightning_lite.utilities import move_data_to_device from lightning_lite.utilities.apply_func import convert_to_tensors @@ -139,42 +146,100 @@ class LightningLite(ABC): def setup( self, - model: nn.Module, + module: nn.Module, *optimizers: Optimizer, move_to_device: bool = True, ) -> Any: # no specific return because the way we want our API to look does not play well with mypy """Set up a model and its optimizers for accelerated training. Args: - model: A model to set up + module: A :class:`torch.nn.Module` to set up *optimizers: The optimizer(s) to set up (no optimizers is also possible) move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False`` and alternatively use :meth:`to_device` manually. Returns: - The tuple of the wrapped model and list of optimizers, in the same order they were passed in. + The tuple containing wrapped module and the optimizers, in the same order they were passed in. 
""" - self._validate_setup(model, optimizers) - original_model = model + self._validate_setup(module, optimizers) + original_module = module - model = self._precision.convert_module(model) + module = self._precision.convert_module(module) if move_to_device: - model = self._move_model_to_device(model=model, optimizers=list(optimizers)) + module = self._move_model_to_device(model=module, optimizers=list(optimizers)) # Let accelerator/plugin wrap and connect the models and optimizers - model, optimizers = self._strategy.setup_module_and_optimizers(model, list(optimizers)) - model = _LiteModule(model, self._precision, original_module=original_model) + if optimizers: + module, optimizers = self._strategy.setup_module_and_optimizers( # type: ignore[assignment] + module, list(optimizers) + ) + else: + module = self._strategy.setup_module(module) + + module = _LiteModule(module, self._precision, original_module=original_module) # Update the _DeviceDtypeModuleMixin's device parameter - model.to(self.device if move_to_device else next(model.parameters()).device) + module.to(self.device if move_to_device else next(module.parameters()).device) optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers] + self._models_setup += 1 + if optimizers: - # join both types in a list for API convenience - return [model] + optimizers - return model + # join both types in a tuple for API convenience + return tuple((module, *optimizers)) + return module + + def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _LiteModule: + """Set up a model for accelerated training or inference. + + This is the same as calling ``.setup(model)`` with no optimizers. It is useful for inference or for certain + strategies like `FSDP` that require setting up the module before the optimizer can be created and set up. + See also :meth:`setup_optimizers`. + + Args: + module: A :class:`torch.nn.Module` to set up + move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False`` + and alternatively use :meth:`to_device` manually. + + Returns: + The wrapped model. + """ + self._validate_setup_module(module) + original_module = module + + module = self._precision.convert_module(module) + + if move_to_device: + module = self._move_model_to_device(model=module, optimizers=[]) + + # Let strategy wrap and connect the module alone + module = self._strategy.setup_module(module) + module = _LiteModule(module, self._precision, original_module=original_module) + + # Update the _DeviceDtypeModuleMixin's device parameter + module.to(self.device if move_to_device else next(module.parameters()).device) + + self._models_setup += 1 + return module + + def setup_optimizers(self, *optimizers: Optimizer) -> Union[_LiteOptimizer, Tuple[_LiteOptimizer, ...]]: + """Set up one or more optimizers for accelerated training. + + Some strategies do not allow setting up model and optimizer independently. For them, you should call + ``.setup(model, optimizer, ...)`` instead to jointly set them up. + + Args: + *optimizers: One or more optmizers to set up. + + Returns: + The wrapped optimizer(s). 
+ """ + self._validate_setup_optimizers(optimizers) + optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers] + optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers] + return optimizers[0] if len(optimizers) == 1 else tuple(optimizers) def setup_dataloaders( self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True @@ -529,17 +594,44 @@ class LightningLite(ABC): setattr(self, "run", partial(self._run_impl, self.run)) @staticmethod - def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None: - if isinstance(model, _LiteModule): + def _validate_setup(module: nn.Module, optimizers: Sequence[Optimizer]) -> None: + if isinstance(module, _LiteModule): raise ValueError("A model should be passed only once to the `setup` method.") if any(isinstance(opt, _LiteOptimizer) for opt in optimizers): raise ValueError("An optimizer should be passed only once to the `setup` method.") + def _validate_setup_module(self, module: nn.Module) -> None: + if isinstance(module, _LiteModule): + raise ValueError("A model should be passed only once to the `setup_module` method.") + + if isinstance(self._strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)): + raise RuntimeError( + f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly" + " through `.setup(model, optimizer, ...)`. For inference, choose a different strategy, for example" + " `ddp`." + ) + + def _validate_setup_optimizers(self, optimizers: Sequence[Optimizer]) -> None: + if isinstance(self._strategy, (DeepSpeedStrategy, DDPShardedStrategy, DDPSpawnShardedStrategy, XLAStrategy)): + raise RuntimeError( + f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly" + " through `.setup(model, optimizer, ...)`." 
+ ) + + if not optimizers: + raise ValueError("`setup_optimizers` requires at least one optimizer as input.") + + if any(isinstance(opt, _LiteOptimizer) for opt in optimizers): + raise ValueError("An optimizer should be passed only once to the `setup_optimizers` method.") + @staticmethod def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None: + if not dataloaders: + raise ValueError("`setup_dataloaders` requires at least one dataloader as input.") + if any(isinstance(dl, _LiteDataLoader) for dl in dataloaders): - raise ValueError("A dataloader should be passed only once to the `setup_dataloaders` method") + raise ValueError("A dataloader should be passed only once to the `setup_dataloaders` method.") if any(not isinstance(dl, DataLoader) for dl in dataloaders): raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.") diff --git a/src/lightning_lite/strategies/deepspeed.py b/src/lightning_lite/strategies/deepspeed.py index 835305ef43..57920aa8a9 100644 --- a/src/lightning_lite/strategies/deepspeed.py +++ b/src/lightning_lite/strategies/deepspeed.py @@ -35,7 +35,7 @@ from lightning_lite.utilities.distributed import log from lightning_lite.utilities.enums import AMPType, PrecisionType from lightning_lite.utilities.rank_zero import rank_zero_info from lightning_lite.utilities.seed import reset_seed -from lightning_lite.utilities.types import _LRScheduler, _PATH, ReduceLROnPlateau +from lightning_lite.utilities.types import _PATH _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed") if TYPE_CHECKING and _DEEPSPEED_AVAILABLE: @@ -305,11 +305,11 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): return self._deepspeed_engine def setup_module_and_optimizers( - self, model: Module, optimizers: List[Optimizer] + self, module: Module, optimizers: List[Optimizer] ) -> Tuple["deepspeed.DeepSpeedEngine", List[Optimizer]]: - """Setup a model and multiple optimizers together. + """Set up a model and multiple optimizers together. - Currently only a single optimizer is supported. + Currently, only a single optimizer is supported. Return: The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single @@ -321,10 +321,25 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): f" Got {len(optimizers)} optimizers instead." ) - self._deepspeed_engine, optimizer = self._setup_module_and_optimizer(model, optimizers[0]) + self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0]) self._set_deepspeed_activation_checkpointing() return self._deepspeed_engine, [optimizer] + def setup_module(self, module: Module) -> "deepspeed.DeepSpeedEngine": + """Set up a module for inference (no optimizers). + + For training, see :meth:`setup_module_and_optimizers`. + """ + self._deepspeed_engine, _ = self._initialize_engine(module) + return self._deepspeed_engine + + def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + """Optimizers can only be set up jointly with the model in this strategy. + + Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer together. 
+ """ + raise NotImplementedError(self._err_msg_joint_setup_required()) + @contextmanager def module_sharded_context(self) -> Generator[None, None, None]: # Current limitation in Lite: The config needs to be fully determined at the time of calling the @@ -401,11 +416,10 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): offload_optimizer_device="nvme", ) - def _setup_module_and_optimizer( + def _initialize_engine( self, model: Module, - optimizer: Optional[Optimizer], - lr_scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None, + optimizer: Optional[Optimizer] = None, ) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]: """Initialize one model and one optimizer with an optional learning rate scheduler. @@ -420,7 +434,6 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): model=model, model_parameters=model_parameters, optimizer=optimizer, - lr_scheduler=lr_scheduler, dist_init_required=False, ) return deepspeed_engine, deepspeed_optimizer diff --git a/src/lightning_lite/strategies/fairscale.py b/src/lightning_lite/strategies/fairscale.py index 12895bcee5..93e6957e69 100644 --- a/src/lightning_lite/strategies/fairscale.py +++ b/src/lightning_lite/strategies/fairscale.py @@ -18,6 +18,7 @@ from typing import Any, Dict, Generator, List, Optional, Tuple import torch from lightning_utilities.core.imports import module_available from torch.nn import Module +from torch.nn.parallel import DistributedDataParallel from torch.optim import Optimizer from lightning_lite.accelerators import Accelerator @@ -89,6 +90,20 @@ class DDPShardedStrategy(DDPStrategy): model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs) return model, optimizers + def setup_module(self, module: Module) -> DistributedDataParallel: + """Setting up the module without optimizers in this strategy is not supported. + + Please use :meth:`setup_module_and_optimizers` instead. + """ + raise NotImplementedError(self._err_msg_joint_setup_required()) + + def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + """Optimizers can only be set up jointly with the model in this strategy. + + Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer(s) together. + """ + raise NotImplementedError(self._err_msg_joint_setup_required()) + @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: strategy_registry.register( @@ -153,6 +168,20 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy): model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs) return model, optimizers + def setup_module(self, module: Module) -> DistributedDataParallel: + """Setting up the module without optimizers in this strategy is not supported. + + Please use :meth:`setup_module_and_optimizers` instead. + """ + raise NotImplementedError(self._err_msg_joint_setup_required()) + + def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + """Optimizers can only be set up jointly with the model in this strategy. + + Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer(s) together. 
+ """ + raise NotImplementedError(self._err_msg_joint_setup_required()) + @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: strategy_registry.register( diff --git a/src/lightning_lite/strategies/strategy.py b/src/lightning_lite/strategies/strategy.py index 0021894188..d013753473 100644 --- a/src/lightning_lite/strategies/strategy.py +++ b/src/lightning_lite/strategies/strategy.py @@ -118,7 +118,7 @@ class Strategy(ABC): """Set up a model and multiple optimizers together. The returned objects are expected to be in the same order they were passed in. The default implementation will - call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs. + call :meth:`setup_module` and :meth:`setup_optimizer` on the inputs. """ module = self.setup_module(module) optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers] @@ -288,6 +288,12 @@ class Strategy(ABC): def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None: pass + def _err_msg_joint_setup_required(self) -> str: + return ( + f"The `{type(self).__name__}` does not support setting up the module and optimizer(s) independently." + " Please call `setup_module_and_optimizers(model, [optimizer, ...])` to jointly set them up." + ) + class _BackwardSyncControl(ABC): """Interface for any :class:`Strategy` that wants to offer a functionality to enable or disable gradient diff --git a/tests/tests_lite/strategies/test_deepspeed.py b/tests/tests_lite/strategies/test_deepspeed.py index bea2096013..f1cac3fd4e 100644 --- a/tests/tests_lite/strategies/test_deepspeed.py +++ b/tests/tests_lite/strategies/test_deepspeed.py @@ -13,8 +13,12 @@ # limitations under the License. import json import os +from re import escape +from unittest import mock +from unittest.mock import ANY, Mock import pytest +import torch from tests_lite.helpers.runif import RunIf from lightning_lite.accelerators import CPUAccelerator @@ -116,3 +120,34 @@ def test_deepspeed_config_zero_offload(deepspeed_zero_config): deepspeed_zero_config["zero_optimization"]["offload_optimizer"] = False strategy = DeepSpeedStrategy(config=deepspeed_zero_config) assert strategy.config["zero_optimization"]["offload_optimizer"] is False + + +@RunIf(deepspeed=True) +@mock.patch("lightning_lite.strategies.deepspeed.deepspeed.initialize") +def test_deepspeed_setup_module(init_mock): + """Test that the DeepSpeed strategy can set up the model for inference (no optimizer required).""" + model = Mock() + model.parameters.return_value = [] + strategy = DeepSpeedStrategy() + strategy.parallel_devices = [torch.device("cuda", 1)] + init_mock.return_value = [Mock()] * 4 # mock to make tuple unpacking work + + strategy.setup_module(model) + init_mock.assert_called_with( + args=ANY, + config=strategy.config, + model=model, + model_parameters=ANY, + optimizer=None, + dist_init_required=False, + ) + + +@RunIf(deepspeed=True) +def test_deepspeed_requires_joint_setup(): + """Test that the DeepSpeed strategy does not support setting up model and optimizer independently.""" + strategy = DeepSpeedStrategy() + with pytest.raises( + NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently") + ): + strategy.setup_optimizer(Mock()) diff --git a/tests/tests_lite/strategies/test_fairscale.py b/tests/tests_lite/strategies/test_fairscale.py index 31857e0da1..4029402e19 100644 --- a/tests/tests_lite/strategies/test_fairscale.py +++ b/tests/tests_lite/strategies/test_fairscale.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR 
CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from re import escape from unittest import mock from unittest.mock import MagicMock, Mock @@ -89,3 +90,18 @@ def test_fairscale_no_backward_sync(cls): pass module.no_sync.assert_called_once() + + +@pytest.mark.parametrize("cls", [DDPShardedStrategy, DDPSpawnShardedStrategy]) +def test_fairscale_requires_joint_setup(cls): + """Test that the fairscale sharded strategy does not support setting up model and optimizer independently.""" + strategy = cls() + with pytest.raises( + NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently") + ): + strategy.setup_module(Mock()) + + with pytest.raises( + NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently") + ): + strategy.setup_optimizer(Mock()) diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py index 0070e0893d..ba28837ff5 100644 --- a/tests/tests_lite/test_lite.py +++ b/tests/tests_lite/test_lite.py @@ -27,7 +27,16 @@ from torch.utils.data import DataLoader, DistributedSampler, Sampler from lightning_lite.lite import LightningLite from lightning_lite.plugins import Precision -from lightning_lite.strategies import DDPStrategy, ParallelStrategy, SingleDeviceStrategy, Strategy +from lightning_lite.strategies import ( + DDPShardedStrategy, + DDPSpawnShardedStrategy, + DDPStrategy, + DeepSpeedStrategy, + ParallelStrategy, + SingleDeviceStrategy, + Strategy, + XLAStrategy, +) from lightning_lite.strategies.strategy import _Sharded from lightning_lite.utilities import _StrategyType from lightning_lite.utilities.exceptions import MisconfigurationException @@ -72,11 +81,13 @@ def test_run_input_output(): @mock.patch("lightning_lite.strategies.ddp.DistributedDataParallel") -def test_setup_model(ddp_mock): +@pytest.mark.parametrize("setup_method", ["setup", "setup_module"]) +def test_setup_module(ddp_mock, setup_method): """Test that the setup method lets the strategy wrap the model, but keeps a reference to the original model.""" lite = EmptyLite(accelerator="cpu", strategy="ddp", devices=2) model = nn.Linear(1, 2) - lite_model = lite.setup(model) + setup_method = getattr(lite, setup_method) + lite_model = setup_method(model) ddp_mock.assert_called_with(module=model, device_ids=ANY) assert lite_model.module == model assert lite_model.weight is model.weight @@ -95,7 +106,8 @@ def test_setup_model(ddp_mock): ], ) @pytest.mark.parametrize("move_to_device", [True, False]) -def test_setup_model_move_to_device(move_to_device, accelerator, initial_device, target_device): +@pytest.mark.parametrize("setup_method", ["setup", "setup_module"]) +def test_setup_module_move_to_device(setup_method, move_to_device, accelerator, initial_device, target_device): """Test that `move_to_device` leads to parameters being moved to the correct device and that the device attributes on the wrapper are updated.""" initial_device = torch.device(initial_device) @@ -105,7 +117,8 @@ def test_setup_model_move_to_device(move_to_device, accelerator, initial_device, lite = EmptyLite(accelerator=accelerator, devices=1) model = nn.Linear(1, 2) model.to(initial_device) - lite_model = lite.setup(model, move_to_device=move_to_device) + setup_method = getattr(lite, setup_method) + lite_model = setup_method(model, move_to_device=move_to_device) # all parameters on the expected device assert all(param.device == expected_device 
for param in model.parameters()) @@ -117,7 +130,8 @@ def test_setup_model_move_to_device(move_to_device, accelerator, initial_device, @RunIf(min_cuda_gpus=1) @pytest.mark.parametrize("move_to_device", [True, False]) -def test_setup_model_parameters_on_different_devices(move_to_device): +@pytest.mark.parametrize("setup_method", ["setup", "setup_module"]) +def test_setup_module_parameters_on_different_devices(setup_method, move_to_device): """Test that a warning is emitted when model parameters are on a different device prior to calling `setup()`.""" device0 = torch.device("cpu") @@ -129,9 +143,11 @@ def test_setup_model_parameters_on_different_devices(move_to_device): module1 = nn.Linear(1, 2).to(device1) model = nn.Sequential(module0, module1) + setup_method = getattr(lite, setup_method) + if move_to_device: with pytest.warns(PossibleUserWarning, match="has parameters on different devices"): - lite_model = lite.setup(model, move_to_device=move_to_device) + lite_model = setup_method(model, move_to_device=move_to_device) # both have the same device now assert lite_model.device == device1 @@ -139,11 +155,11 @@ def test_setup_model_parameters_on_different_devices(move_to_device): assert module1.weight.device == module1.bias.device == device1 else: with no_warning_call(expected_warning=PossibleUserWarning, match="has parameters on different devices"): - lite.setup(model, move_to_device=move_to_device) + setup_method(model, move_to_device=move_to_device) -def test_setup_optimizers(): - """Test that setup_optimizers can handle no optimizers, one optimizer, or multiple optimizers.""" +def test_setup_module_and_optimizers(): + """Test that `setup()` can handle no optimizers, one optimizer, or multiple optimizers.""" lite = EmptyLite() model = nn.Linear(1, 2) optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1) @@ -171,8 +187,28 @@ def test_setup_optimizers(): assert lite_optimizer1.optimizer is optimizer1 +def test_setup_optimizers(): + """Test that `setup_optimizers()` can handle one or more optimizers.""" + lite = EmptyLite() + model = nn.Linear(1, 2) + optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1) + optimizer1 = torch.optim.Adam(model.parameters(), lr=0.1) + + # single optimizer + lite_optimizer = lite.setup_optimizers(optimizer0) + assert isinstance(lite_optimizer, _LiteOptimizer) + assert lite_optimizer.optimizer is optimizer0 + + # multiple optimizers + lite_optimizer0, lite_optimizer1 = lite.setup_optimizers(optimizer0, optimizer1) + assert isinstance(lite_optimizer0, _LiteOptimizer) + assert isinstance(lite_optimizer1, _LiteOptimizer) + assert lite_optimizer0.optimizer is optimizer0 + assert lite_optimizer1.optimizer is optimizer1 + + def test_setup_twice_fails(): - """Test that calling setup with a model or optimizer that is already wrapped fails.""" + """Test that calling `setup` with a model or optimizer that is already wrapped fails.""" lite = EmptyLite() model = nn.Linear(1, 2) optimizer = torch.optim.Adam(model.parameters()) @@ -186,6 +222,49 @@ def test_setup_twice_fails(): lite.setup(model, lite_optimizer) +def test_setup_module_twice_fails(): + """Test that calling `setup_module` with a model that is already wrapped fails.""" + lite = EmptyLite() + model = nn.Linear(1, 2) + + lite_model = lite.setup_module(model) + with pytest.raises(ValueError, match="A model should be passed only once to the"): + lite.setup_module(lite_model) + + +def test_setup_optimizers_twice_fails(): + """Test that calling `setup_module` with a model that is already wrapped fails.""" + lite = 
EmptyLite() + model = nn.Linear(1, 2) + optimizer = torch.optim.Adam(model.parameters()) + + lite_optimizer = lite.setup_optimizers(optimizer) + with pytest.raises(ValueError, match="An optimizer should be passed only once to"): + lite.setup_optimizers(lite_optimizer) + + +@pytest.mark.parametrize("strategy_cls", [DDPShardedStrategy, DDPSpawnShardedStrategy]) +def test_setup_module_not_supported(strategy_cls): + """Test that `setup_module` validates the strategy supports setting up model and optimizers independently.""" + lite = EmptyLite() + model = nn.Linear(1, 2) + lite._strategy = Mock(spec=strategy_cls) + with pytest.raises(RuntimeError, match=escape("requires the model and optimizer(s) to be set up jointly through")): + lite.setup_module(model) + + +@pytest.mark.parametrize("strategy_cls", [DeepSpeedStrategy, DDPShardedStrategy, DDPSpawnShardedStrategy, XLAStrategy]) +def test_setup_optimizers_not_supported(strategy_cls): + """Test that `setup_optimizers` validates the strategy supports setting up model and optimizers + independently.""" + lite = EmptyLite() + model = nn.Linear(1, 2) + optimizer = torch.optim.Adam(model.parameters()) + lite._strategy = Mock(spec=strategy_cls) + with pytest.raises(RuntimeError, match=escape("requires the model and optimizer(s) to be set up jointly through")): + lite.setup_optimizers(optimizer) + + def test_setup_tracks_num_models(): """Test that setup() tracks how many times it has setup a model.""" lite = EmptyLite() @@ -199,10 +278,15 @@ def test_setup_tracks_num_models(): lite.setup(model, optimizer) assert lite._models_setup == 2 + lite.setup_module(model) + assert lite._models_setup == 3 -def test_setup_dataloaders_unsupported_type(): + +def test_setup_dataloaders_unsupported_input(): """Test that the setup_dataloaders method fails when provided with non-DataLoader objects.""" lite = EmptyLite() + with pytest.raises(ValueError, match="`setup_dataloaders` requires at least one dataloader"): + lite.setup_dataloaders() with pytest.raises(TypeError, match="Only PyTorch DataLoader are currently supported"): lite.setup_dataloaders(range(2)) # type: ignore
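
For reference, a minimal usage sketch of the split setup flow this patch introduces (not part of the patch itself; the `MyLite` subclass, layer sizes, and the toy training loop are illustrative placeholders, while `setup_module`, `setup_optimizers`, `backward`, and `self.device` are the LightningLite APIs shown or referenced above):

    import torch
    from torch import nn

    from lightning_lite.lite import LightningLite


    class MyLite(LightningLite):
        def run(self):
            model = nn.Linear(32, 2)

            # Set up the module first. Per the docstring added in this PR, this is
            # useful for inference and for strategies (e.g. FSDP-like sharding) that
            # need the module set up before an optimizer can be created from it.
            model = self.setup_module(model)

            # Create the optimizer from the already set-up module, then wrap it.
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
            optimizer = self.setup_optimizers(optimizer)

            model.train()
            for _ in range(3):
                x = torch.randn(4, 32, device=self.device)
                loss = model(x).sum()
                optimizer.zero_grad()
                self.backward(loss)
                optimizer.step()


    if __name__ == "__main__":
        MyLite(accelerator="cpu").run()

Note that for strategies which require joint setup (DeepSpeed, the fairscale sharded strategies, XLA), the validation added in this patch raises a `RuntimeError` for the split calls, and `.setup(model, optimizer, ...)` must be used instead.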