Support individual setup of model and optimizer in Lite (#15185)

Adrian Wälchli 2022-11-11 14:36:59 +01:00 committed by GitHub
parent 32e319af27
commit 0dfb3d28ce
8 changed files with 315 additions and 40 deletions

View File

@@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 -
--
+- Added `LightningLite.setup_module()` and `LightningLite.setup_optimizers()` to support strategies that need to set up the model before an optimizer can be created ([#15185](https://github.com/Lightning-AI/lightning/pull/15185))

 ### Changed

View File

@@ -31,7 +31,14 @@ from torch.utils.data import BatchSampler, DataLoader, DistributedSampler
 from lightning_lite.plugins import Precision  # avoid circular imports: # isort: split
 from lightning_lite.accelerators.accelerator import Accelerator
 from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
-from lightning_lite.strategies import DeepSpeedStrategy, SingleDeviceStrategy, Strategy, XLAStrategy
+from lightning_lite.strategies import (
+    DDPShardedStrategy,
+    DDPSpawnShardedStrategy,
+    DeepSpeedStrategy,
+    SingleDeviceStrategy,
+    Strategy,
+    XLAStrategy,
+)
 from lightning_lite.strategies.strategy import _Sharded, TBroadcast
 from lightning_lite.utilities import move_data_to_device
 from lightning_lite.utilities.apply_func import convert_to_tensors
@@ -139,42 +146,100 @@ class LightningLite(ABC):
     def setup(
         self,
-        model: nn.Module,
+        module: nn.Module,
         *optimizers: Optimizer,
         move_to_device: bool = True,
     ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
         """Set up a model and its optimizers for accelerated training.

         Args:
-            model: A model to set up
+            module: A :class:`torch.nn.Module` to set up
             *optimizers: The optimizer(s) to set up (no optimizers is also possible)
             move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                 and alternatively use :meth:`to_device` manually.

         Returns:
-            The tuple of the wrapped model and list of optimizers, in the same order they were passed in.
+            The tuple containing the wrapped module and the optimizers, in the same order they were passed in.
         """
-        self._validate_setup(model, optimizers)
-        original_model = model
+        self._validate_setup(module, optimizers)
+        original_module = module

-        model = self._precision.convert_module(model)
+        module = self._precision.convert_module(module)

         if move_to_device:
-            model = self._move_model_to_device(model=model, optimizers=list(optimizers))
+            module = self._move_model_to_device(model=module, optimizers=list(optimizers))

         # Let accelerator/plugin wrap and connect the models and optimizers
-        model, optimizers = self._strategy.setup_module_and_optimizers(model, list(optimizers))
-        model = _LiteModule(model, self._precision, original_module=original_model)
+        if optimizers:
+            module, optimizers = self._strategy.setup_module_and_optimizers(  # type: ignore[assignment]
+                module, list(optimizers)
+            )
+        else:
+            module = self._strategy.setup_module(module)
+
+        module = _LiteModule(module, self._precision, original_module=original_module)

         # Update the _DeviceDtypeModuleMixin's device parameter
-        model.to(self.device if move_to_device else next(model.parameters()).device)
+        module.to(self.device if move_to_device else next(module.parameters()).device)

         optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
         self._models_setup += 1

         if optimizers:
-            # join both types in a list for API convenience
-            return [model] + optimizers
-        return model
+            # join both types in a tuple for API convenience
+            return tuple((module, *optimizers))
+        return module
+
+    def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _LiteModule:
+        """Set up a model for accelerated training or inference.
+
+        This is the same as calling ``.setup(model)`` with no optimizers. It is useful for inference or for certain
+        strategies like `FSDP` that require setting up the module before the optimizer can be created and set up.
+        See also :meth:`setup_optimizers`.
+
+        Args:
+            module: A :class:`torch.nn.Module` to set up
+            move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
+                and alternatively use :meth:`to_device` manually.
+
+        Returns:
+            The wrapped model.
+        """
+        self._validate_setup_module(module)
+        original_module = module
+
+        module = self._precision.convert_module(module)
+
+        if move_to_device:
+            module = self._move_model_to_device(model=module, optimizers=[])
+
+        # Let strategy wrap and connect the module alone
+        module = self._strategy.setup_module(module)
+        module = _LiteModule(module, self._precision, original_module=original_module)
+
+        # Update the _DeviceDtypeModuleMixin's device parameter
+        module.to(self.device if move_to_device else next(module.parameters()).device)
+        self._models_setup += 1
+        return module
+
+    def setup_optimizers(self, *optimizers: Optimizer) -> Union[_LiteOptimizer, Tuple[_LiteOptimizer, ...]]:
+        """Set up one or more optimizers for accelerated training.
+
+        Some strategies do not allow setting up model and optimizer independently. For them, you should call
+        ``.setup(model, optimizer, ...)`` instead to jointly set them up.
+
+        Args:
+            *optimizers: One or more optimizers to set up.
+
+        Returns:
+            The wrapped optimizer(s).
+        """
+        self._validate_setup_optimizers(optimizers)
+        optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers]
+        optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
+        return optimizers[0] if len(optimizers) == 1 else tuple(optimizers)

     def setup_dataloaders(
         self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
@@ -529,17 +594,44 @@ class LightningLite(ABC):
         setattr(self, "run", partial(self._run_impl, self.run))

     @staticmethod
-    def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None:
-        if isinstance(model, _LiteModule):
+    def _validate_setup(module: nn.Module, optimizers: Sequence[Optimizer]) -> None:
+        if isinstance(module, _LiteModule):
             raise ValueError("A model should be passed only once to the `setup` method.")

         if any(isinstance(opt, _LiteOptimizer) for opt in optimizers):
             raise ValueError("An optimizer should be passed only once to the `setup` method.")

+    def _validate_setup_module(self, module: nn.Module) -> None:
+        if isinstance(module, _LiteModule):
+            raise ValueError("A model should be passed only once to the `setup_module` method.")
+
+        if isinstance(self._strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
+            raise RuntimeError(
+                f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly"
+                " through `.setup(model, optimizer, ...)`. For inference, choose a different strategy, for example"
+                " `ddp`."
+            )
+
+    def _validate_setup_optimizers(self, optimizers: Sequence[Optimizer]) -> None:
+        if isinstance(self._strategy, (DeepSpeedStrategy, DDPShardedStrategy, DDPSpawnShardedStrategy, XLAStrategy)):
+            raise RuntimeError(
+                f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly"
+                " through `.setup(model, optimizer, ...)`."
+            )
+
+        if not optimizers:
+            raise ValueError("`setup_optimizers` requires at least one optimizer as input.")
+
+        if any(isinstance(opt, _LiteOptimizer) for opt in optimizers):
+            raise ValueError("An optimizer should be passed only once to the `setup_optimizers` method.")
+
     @staticmethod
     def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None:
+        if not dataloaders:
+            raise ValueError("`setup_dataloaders` requires at least one dataloader as input.")
+
         if any(isinstance(dl, _LiteDataLoader) for dl in dataloaders):
-            raise ValueError("A dataloader should be passed only once to the `setup_dataloaders` method")
+            raise ValueError("A dataloader should be passed only once to the `setup_dataloaders` method.")

         if any(not isinstance(dl, DataLoader) for dl in dataloaders):
             raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")

View File

@@ -35,7 +35,7 @@ from lightning_lite.utilities.distributed import log
 from lightning_lite.utilities.enums import AMPType, PrecisionType
 from lightning_lite.utilities.rank_zero import rank_zero_info
 from lightning_lite.utilities.seed import reset_seed
-from lightning_lite.utilities.types import _LRScheduler, _PATH, ReduceLROnPlateau
+from lightning_lite.utilities.types import _PATH

 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")

 if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
@@ -305,11 +305,11 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
         return self._deepspeed_engine

     def setup_module_and_optimizers(
-        self, model: Module, optimizers: List[Optimizer]
+        self, module: Module, optimizers: List[Optimizer]
     ) -> Tuple["deepspeed.DeepSpeedEngine", List[Optimizer]]:
-        """Setup a model and multiple optimizers together.
+        """Set up a model and multiple optimizers together.

-        Currently only a single optimizer is supported.
+        Currently, only a single optimizer is supported.

         Return:
             The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single
@@ -321,10 +321,25 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
                 f" Got {len(optimizers)} optimizers instead."
             )

-        self._deepspeed_engine, optimizer = self._setup_module_and_optimizer(model, optimizers[0])
+        self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0])
         self._set_deepspeed_activation_checkpointing()
         return self._deepspeed_engine, [optimizer]

+    def setup_module(self, module: Module) -> "deepspeed.DeepSpeedEngine":
+        """Set up a module for inference (no optimizers).
+
+        For training, see :meth:`setup_module_and_optimizers`.
+        """
+        self._deepspeed_engine, _ = self._initialize_engine(module)
+        return self._deepspeed_engine
+
+    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        """Optimizers can only be set up jointly with the model in this strategy.
+
+        Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer together.
+        """
+        raise NotImplementedError(self._err_msg_joint_setup_required())
+
     @contextmanager
     def module_sharded_context(self) -> Generator[None, None, None]:
         # Current limitation in Lite: The config needs to be fully determined at the time of calling the
@@ -401,11 +416,10 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
             offload_optimizer_device="nvme",
         )

-    def _setup_module_and_optimizer(
+    def _initialize_engine(
         self,
         model: Module,
-        optimizer: Optional[Optimizer],
-        lr_scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None,
+        optimizer: Optional[Optimizer] = None,
     ) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.
@@ -420,7 +434,6 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
             model=model,
             model_parameters=model_parameters,
             optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
             dist_init_required=False,
         )
         return deepspeed_engine, deepspeed_optimizer
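One way to read the DeepSpeed changes above: `setup_module` now initializes the engine without an optimizer, which is what inference needs, while `setup_optimizer` on its own stays unsupported. The sketch below is hedged and illustrative only; it assumes `deepspeed` is installed and a GPU is available, and the `InferenceLite` class and layer sizes are made up.

```python
# Hedged inference-only sketch for the DeepSpeed path (illustrative, assumes deepspeed + CUDA).
from torch import nn

from lightning_lite.lite import LightningLite


class InferenceLite(LightningLite):  # hypothetical subclass for demonstration
    def run(self):
        # The engine is initialized without an optimizer, mirroring DeepSpeedStrategy.setup_module.
        model = self.setup_module(nn.Linear(32, 2))
        # self.setup_optimizers(...) would raise here: DeepSpeed requires the joint
        # `setup(model, optimizer)` call for training.
        return model


if __name__ == "__main__":
    InferenceLite(strategy="deepspeed", accelerator="cuda", devices=1).run()
```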

View File

@@ -18,6 +18,7 @@ from typing import Any, Dict, Generator, List, Optional, Tuple
 import torch
 from lightning_utilities.core.imports import module_available
 from torch.nn import Module
+from torch.nn.parallel import DistributedDataParallel
 from torch.optim import Optimizer

 from lightning_lite.accelerators import Accelerator
@@ -89,6 +90,20 @@ class DDPShardedStrategy(DDPStrategy):
         model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs)
         return model, optimizers

+    def setup_module(self, module: Module) -> DistributedDataParallel:
+        """Setting up the module without optimizers in this strategy is not supported.
+
+        Please use :meth:`setup_module_and_optimizers` instead.
+        """
+        raise NotImplementedError(self._err_msg_joint_setup_required())
+
+    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        """Optimizers can only be set up jointly with the model in this strategy.
+
+        Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer(s) together.
+        """
+        raise NotImplementedError(self._err_msg_joint_setup_required())
+
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register(
@@ -153,6 +168,20 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
         model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs)
         return model, optimizers

+    def setup_module(self, module: Module) -> DistributedDataParallel:
+        """Setting up the module without optimizers in this strategy is not supported.
+
+        Please use :meth:`setup_module_and_optimizers` instead.
+        """
+        raise NotImplementedError(self._err_msg_joint_setup_required())
+
+    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        """Optimizers can only be set up jointly with the model in this strategy.
+
+        Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer(s) together.
+        """
+        raise NotImplementedError(self._err_msg_joint_setup_required())
+
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register(

View File

@@ -118,7 +118,7 @@ class Strategy(ABC):
         """Set up a model and multiple optimizers together.

         The returned objects are expected to be in the same order they were passed in. The default implementation will
-        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs.
+        call :meth:`setup_module` and :meth:`setup_optimizer` on the inputs.
         """
         module = self.setup_module(module)
         optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers]
@@ -288,6 +288,12 @@ class Strategy(ABC):
     def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None:
         pass

+    def _err_msg_joint_setup_required(self) -> str:
+        return (
+            f"The `{type(self).__name__}` does not support setting up the module and optimizer(s) independently."
+            " Please call `setup_module_and_optimizers(model, [optimizer, ...])` to jointly set them up."
+        )
+

 class _BackwardSyncControl(ABC):
     """Interface for any :class:`Strategy` that wants to offer a functionality to enable or disable gradient

View File

@@ -13,8 +13,12 @@
 # limitations under the License.
 import json
 import os
+from re import escape
+from unittest import mock
+from unittest.mock import ANY, Mock

 import pytest
+import torch
 from tests_lite.helpers.runif import RunIf

 from lightning_lite.accelerators import CPUAccelerator
@@ -116,3 +120,34 @@ def test_deepspeed_config_zero_offload(deepspeed_zero_config):
     deepspeed_zero_config["zero_optimization"]["offload_optimizer"] = False
     strategy = DeepSpeedStrategy(config=deepspeed_zero_config)
     assert strategy.config["zero_optimization"]["offload_optimizer"] is False
+
+
+@RunIf(deepspeed=True)
+@mock.patch("lightning_lite.strategies.deepspeed.deepspeed.initialize")
+def test_deepspeed_setup_module(init_mock):
+    """Test that the DeepSpeed strategy can set up the model for inference (no optimizer required)."""
+    model = Mock()
+    model.parameters.return_value = []
+    strategy = DeepSpeedStrategy()
+    strategy.parallel_devices = [torch.device("cuda", 1)]
+    init_mock.return_value = [Mock()] * 4  # mock to make tuple unpacking work
+
+    strategy.setup_module(model)
+    init_mock.assert_called_with(
+        args=ANY,
+        config=strategy.config,
+        model=model,
+        model_parameters=ANY,
+        optimizer=None,
+        dist_init_required=False,
+    )
+
+
+@RunIf(deepspeed=True)
+def test_deepspeed_requires_joint_setup():
+    """Test that the DeepSpeed strategy does not support setting up model and optimizer independently."""
+    strategy = DeepSpeedStrategy()
+    with pytest.raises(
+        NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently")
+    ):
+        strategy.setup_optimizer(Mock())

View File

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from re import escape
 from unittest import mock
 from unittest.mock import MagicMock, Mock
@@ -89,3 +90,18 @@ def test_fairscale_no_backward_sync(cls):
         pass

     module.no_sync.assert_called_once()
+
+
+@pytest.mark.parametrize("cls", [DDPShardedStrategy, DDPSpawnShardedStrategy])
+def test_fairscale_requires_joint_setup(cls):
+    """Test that the fairscale sharded strategy does not support setting up model and optimizer independently."""
+    strategy = cls()
+    with pytest.raises(
+        NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently")
+    ):
+        strategy.setup_module(Mock())
+
+    with pytest.raises(
+        NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently")
+    ):
+        strategy.setup_optimizer(Mock())

View File

@@ -27,7 +27,16 @@ from torch.utils.data import DataLoader, DistributedSampler, Sampler
 from lightning_lite.lite import LightningLite
 from lightning_lite.plugins import Precision
-from lightning_lite.strategies import DDPStrategy, ParallelStrategy, SingleDeviceStrategy, Strategy
+from lightning_lite.strategies import (
+    DDPShardedStrategy,
+    DDPSpawnShardedStrategy,
+    DDPStrategy,
+    DeepSpeedStrategy,
+    ParallelStrategy,
+    SingleDeviceStrategy,
+    Strategy,
+    XLAStrategy,
+)
 from lightning_lite.strategies.strategy import _Sharded
 from lightning_lite.utilities import _StrategyType
 from lightning_lite.utilities.exceptions import MisconfigurationException
@@ -72,11 +81,13 @@ def test_run_input_output():

 @mock.patch("lightning_lite.strategies.ddp.DistributedDataParallel")
-def test_setup_model(ddp_mock):
+@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
+def test_setup_module(ddp_mock, setup_method):
     """Test that the setup method lets the strategy wrap the model, but keeps a reference to the original model."""
     lite = EmptyLite(accelerator="cpu", strategy="ddp", devices=2)
     model = nn.Linear(1, 2)
-    lite_model = lite.setup(model)
+    setup_method = getattr(lite, setup_method)
+    lite_model = setup_method(model)
     ddp_mock.assert_called_with(module=model, device_ids=ANY)
     assert lite_model.module == model
     assert lite_model.weight is model.weight
@@ -95,7 +106,8 @@ def test_setup_model(ddp_mock):
     ],
 )
 @pytest.mark.parametrize("move_to_device", [True, False])
-def test_setup_model_move_to_device(move_to_device, accelerator, initial_device, target_device):
+@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
+def test_setup_module_move_to_device(setup_method, move_to_device, accelerator, initial_device, target_device):
     """Test that `move_to_device` leads to parameters being moved to the correct device and that the device
     attributes on the wrapper are updated."""
     initial_device = torch.device(initial_device)
@@ -105,7 +117,8 @@ def test_setup_model_move_to_device(move_to_device, accelerator, initial_device,
     lite = EmptyLite(accelerator=accelerator, devices=1)
     model = nn.Linear(1, 2)
     model.to(initial_device)
-    lite_model = lite.setup(model, move_to_device=move_to_device)
+    setup_method = getattr(lite, setup_method)
+    lite_model = setup_method(model, move_to_device=move_to_device)

     # all parameters on the expected device
     assert all(param.device == expected_device for param in model.parameters())
@@ -117,7 +130,8 @@ def test_setup_model_move_to_device(move_to_device, accelerator, initial_device,

 @RunIf(min_cuda_gpus=1)
 @pytest.mark.parametrize("move_to_device", [True, False])
-def test_setup_model_parameters_on_different_devices(move_to_device):
+@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
+def test_setup_module_parameters_on_different_devices(setup_method, move_to_device):
     """Test that a warning is emitted when model parameters are on a different device prior to calling
     `setup()`."""
     device0 = torch.device("cpu")
@@ -129,9 +143,11 @@ def test_setup_model_parameters_on_different_devices(move_to_device):
     module1 = nn.Linear(1, 2).to(device1)
     model = nn.Sequential(module0, module1)

+    setup_method = getattr(lite, setup_method)
+
     if move_to_device:
         with pytest.warns(PossibleUserWarning, match="has parameters on different devices"):
-            lite_model = lite.setup(model, move_to_device=move_to_device)
+            lite_model = setup_method(model, move_to_device=move_to_device)

         # both have the same device now
         assert lite_model.device == device1
@@ -139,11 +155,11 @@ def test_setup_model_parameters_on_different_devices(move_to_device):
         assert module1.weight.device == module1.bias.device == device1
     else:
         with no_warning_call(expected_warning=PossibleUserWarning, match="has parameters on different devices"):
-            lite.setup(model, move_to_device=move_to_device)
+            setup_method(model, move_to_device=move_to_device)


-def test_setup_optimizers():
-    """Test that setup_optimizers can handle no optimizers, one optimizer, or multiple optimizers."""
+def test_setup_module_and_optimizers():
+    """Test that `setup()` can handle no optimizers, one optimizer, or multiple optimizers."""
     lite = EmptyLite()
     model = nn.Linear(1, 2)
     optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1)
@@ -171,8 +187,28 @@ def test_setup_optimizers():
     assert lite_optimizer1.optimizer is optimizer1


+def test_setup_optimizers():
+    """Test that `setup_optimizers()` can handle one or more optimizers."""
+    lite = EmptyLite()
+    model = nn.Linear(1, 2)
+    optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1)
+    optimizer1 = torch.optim.Adam(model.parameters(), lr=0.1)
+
+    # single optimizer
+    lite_optimizer = lite.setup_optimizers(optimizer0)
+    assert isinstance(lite_optimizer, _LiteOptimizer)
+    assert lite_optimizer.optimizer is optimizer0
+
+    # multiple optimizers
+    lite_optimizer0, lite_optimizer1 = lite.setup_optimizers(optimizer0, optimizer1)
+    assert isinstance(lite_optimizer0, _LiteOptimizer)
+    assert isinstance(lite_optimizer1, _LiteOptimizer)
+    assert lite_optimizer0.optimizer is optimizer0
+    assert lite_optimizer1.optimizer is optimizer1
+
+
 def test_setup_twice_fails():
-    """Test that calling setup with a model or optimizer that is already wrapped fails."""
+    """Test that calling `setup` with a model or optimizer that is already wrapped fails."""
     lite = EmptyLite()
     model = nn.Linear(1, 2)
     optimizer = torch.optim.Adam(model.parameters())
@@ -186,6 +222,49 @@ def test_setup_twice_fails():
         lite.setup(model, lite_optimizer)


+def test_setup_module_twice_fails():
+    """Test that calling `setup_module` with a model that is already wrapped fails."""
+    lite = EmptyLite()
+    model = nn.Linear(1, 2)
+    lite_model = lite.setup_module(model)
+
+    with pytest.raises(ValueError, match="A model should be passed only once to the"):
+        lite.setup_module(lite_model)
+
+
+def test_setup_optimizers_twice_fails():
+    """Test that calling `setup_optimizers` with an optimizer that is already wrapped fails."""
+    lite = EmptyLite()
+    model = nn.Linear(1, 2)
+    optimizer = torch.optim.Adam(model.parameters())
+    lite_optimizer = lite.setup_optimizers(optimizer)
+
+    with pytest.raises(ValueError, match="An optimizer should be passed only once to"):
+        lite.setup_optimizers(lite_optimizer)
+
+
+@pytest.mark.parametrize("strategy_cls", [DDPShardedStrategy, DDPSpawnShardedStrategy])
+def test_setup_module_not_supported(strategy_cls):
+    """Test that `setup_module` validates the strategy supports setting up model and optimizers independently."""
+    lite = EmptyLite()
+    model = nn.Linear(1, 2)
+    lite._strategy = Mock(spec=strategy_cls)
+    with pytest.raises(RuntimeError, match=escape("requires the model and optimizer(s) to be set up jointly through")):
+        lite.setup_module(model)
+
+
+@pytest.mark.parametrize("strategy_cls", [DeepSpeedStrategy, DDPShardedStrategy, DDPSpawnShardedStrategy, XLAStrategy])
+def test_setup_optimizers_not_supported(strategy_cls):
+    """Test that `setup_optimizers` validates the strategy supports setting up model and optimizers
+    independently."""
+    lite = EmptyLite()
+    model = nn.Linear(1, 2)
+    optimizer = torch.optim.Adam(model.parameters())
+    lite._strategy = Mock(spec=strategy_cls)
+    with pytest.raises(RuntimeError, match=escape("requires the model and optimizer(s) to be set up jointly through")):
+        lite.setup_optimizers(optimizer)
+
+
 def test_setup_tracks_num_models():
     """Test that setup() tracks how many times it has setup a model."""
     lite = EmptyLite()
@@ -199,10 +278,15 @@ def test_setup_tracks_num_models():
     lite.setup(model, optimizer)
     assert lite._models_setup == 2

+    lite.setup_module(model)
+    assert lite._models_setup == 3
+

-def test_setup_dataloaders_unsupported_type():
+def test_setup_dataloaders_unsupported_input():
     """Test that the setup_dataloaders method fails when provided with non-DataLoader objects."""
     lite = EmptyLite()
+    with pytest.raises(ValueError, match="`setup_dataloaders` requires at least one dataloader"):
+        lite.setup_dataloaders()
     with pytest.raises(TypeError, match="Only PyTorch DataLoader are currently supported"):
         lite.setup_dataloaders(range(2))  # type: ignore