Support individual setup of model and optimizer in Lite (#15185)
parent 32e319af27
commit 0dfb3d28ce

@@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 -
--
+- Added `LightningLite.setup_module()` and `LightningLite.setup_optimizers()` to support strategies that need to set up the model before an optimizer can be created ([#15185](https://github.com/Lightning-AI/lightning/pull/15185))

 ### Changed

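For context, the two new methods split what `setup()` previously did in one call; a minimal usage sketch (the `MyLite` subclass, `MyModel`, and the optimizer choice are illustrative, not part of this diff):

class MyLite(LightningLite):
    def run(self):
        model = MyModel()  # hypothetical torch.nn.Module
        # Strategies such as FSDP need the wrapped module before the optimizer exists:
        model = self.setup_module(model)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        optimizer = self.setup_optimizers(optimizer)
        # Joint form, where the strategy supports it:
        # model, optimizer = self.setup(model, optimizer)
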
@@ -31,7 +31,14 @@ from torch.utils.data import BatchSampler, DataLoader, DistributedSampler
 from lightning_lite.plugins import Precision  # avoid circular imports: # isort: split
 from lightning_lite.accelerators.accelerator import Accelerator
 from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
-from lightning_lite.strategies import DeepSpeedStrategy, SingleDeviceStrategy, Strategy, XLAStrategy
+from lightning_lite.strategies import (
+    DDPShardedStrategy,
+    DDPSpawnShardedStrategy,
+    DeepSpeedStrategy,
+    SingleDeviceStrategy,
+    Strategy,
+    XLAStrategy,
+)
 from lightning_lite.strategies.strategy import _Sharded, TBroadcast
 from lightning_lite.utilities import move_data_to_device
 from lightning_lite.utilities.apply_func import convert_to_tensors

@@ -139,42 +146,100 @@ class LightningLite(ABC):

     def setup(
         self,
-        model: nn.Module,
+        module: nn.Module,
         *optimizers: Optimizer,
         move_to_device: bool = True,
     ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
         """Set up a model and its optimizers for accelerated training.

         Args:
-            model: A model to set up
+            module: A :class:`torch.nn.Module` to set up
             *optimizers: The optimizer(s) to set up (no optimizers is also possible)
             move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                 and alternatively use :meth:`to_device` manually.

         Returns:
-            The tuple of the wrapped model and list of optimizers, in the same order they were passed in.
+            The tuple containing wrapped module and the optimizers, in the same order they were passed in.
         """
-        self._validate_setup(model, optimizers)
-        original_model = model
+        self._validate_setup(module, optimizers)
+        original_module = module

-        model = self._precision.convert_module(model)
+        module = self._precision.convert_module(module)

         if move_to_device:
-            model = self._move_model_to_device(model=model, optimizers=list(optimizers))
+            module = self._move_model_to_device(model=module, optimizers=list(optimizers))

         # Let accelerator/plugin wrap and connect the models and optimizers
-        model, optimizers = self._strategy.setup_module_and_optimizers(model, list(optimizers))
-        model = _LiteModule(model, self._precision, original_module=original_model)
+        if optimizers:
+            module, optimizers = self._strategy.setup_module_and_optimizers(  # type: ignore[assignment]
+                module, list(optimizers)
+            )
+        else:
+            module = self._strategy.setup_module(module)
+
+        module = _LiteModule(module, self._precision, original_module=original_module)

         # Update the _DeviceDtypeModuleMixin's device parameter
-        model.to(self.device if move_to_device else next(model.parameters()).device)
+        module.to(self.device if move_to_device else next(module.parameters()).device)

         optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]

         self._models_setup += 1

         if optimizers:
-            # join both types in a list for API convenience
-            return [model] + optimizers
-        return model
+            # join both types in a tuple for API convenience
+            return tuple((module, *optimizers))
+        return module
+
+    def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _LiteModule:
+        """Set up a model for accelerated training or inference.
+
+        This is the same as calling ``.setup(model)`` with no optimizers. It is useful for inference or for certain
+        strategies like `FSDP` that require setting up the module before the optimizer can be created and set up.
+        See also :meth:`setup_optimizers`.
+
+        Args:
+            module: A :class:`torch.nn.Module` to set up
+            move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
+                and alternatively use :meth:`to_device` manually.
+
+        Returns:
+            The wrapped model.
+        """
+        self._validate_setup_module(module)
+        original_module = module
+
+        module = self._precision.convert_module(module)
+
+        if move_to_device:
+            module = self._move_model_to_device(model=module, optimizers=[])
+
+        # Let strategy wrap and connect the module alone
+        module = self._strategy.setup_module(module)
+        module = _LiteModule(module, self._precision, original_module=original_module)
+
+        # Update the _DeviceDtypeModuleMixin's device parameter
+        module.to(self.device if move_to_device else next(module.parameters()).device)
+
+        self._models_setup += 1
+        return module
+
+    def setup_optimizers(self, *optimizers: Optimizer) -> Union[_LiteOptimizer, Tuple[_LiteOptimizer, ...]]:
+        """Set up one or more optimizers for accelerated training.
+
+        Some strategies do not allow setting up model and optimizer independently. For them, you should call
+        ``.setup(model, optimizer, ...)`` instead to jointly set them up.
+
+        Args:
+            *optimizers: One or more optimizers to set up.
+
+        Returns:
+            The wrapped optimizer(s).
+        """
+        self._validate_setup_optimizers(optimizers)
+        optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers]
+        optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
+        return optimizers[0] if len(optimizers) == 1 else tuple(optimizers)

     def setup_dataloaders(
         self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True

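Since the joint path now returns a tuple rather than a list, unpacking works the same way in user code; a short sketch (inside an illustrative `run()` body):

model, optimizer = self.setup(model, optimizer)  # wrapped module plus wrapped optimizer(s), same order
model = self.setup(model)                        # just the wrapped module when no optimizer is passed
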
@@ -529,17 +594,44 @@ class LightningLite(ABC):
         setattr(self, "run", partial(self._run_impl, self.run))

     @staticmethod
-    def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None:
-        if isinstance(model, _LiteModule):
+    def _validate_setup(module: nn.Module, optimizers: Sequence[Optimizer]) -> None:
+        if isinstance(module, _LiteModule):
             raise ValueError("A model should be passed only once to the `setup` method.")

         if any(isinstance(opt, _LiteOptimizer) for opt in optimizers):
             raise ValueError("An optimizer should be passed only once to the `setup` method.")

+    def _validate_setup_module(self, module: nn.Module) -> None:
+        if isinstance(module, _LiteModule):
+            raise ValueError("A model should be passed only once to the `setup_module` method.")
+
+        if isinstance(self._strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
+            raise RuntimeError(
+                f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly"
+                " through `.setup(model, optimizer, ...)`. For inference, choose a different strategy, for example"
+                " `ddp`."
+            )
+
+    def _validate_setup_optimizers(self, optimizers: Sequence[Optimizer]) -> None:
+        if isinstance(self._strategy, (DeepSpeedStrategy, DDPShardedStrategy, DDPSpawnShardedStrategy, XLAStrategy)):
+            raise RuntimeError(
+                f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly"
+                " through `.setup(model, optimizer, ...)`."
+            )
+
+        if not optimizers:
+            raise ValueError("`setup_optimizers` requires at least one optimizer as input.")
+
+        if any(isinstance(opt, _LiteOptimizer) for opt in optimizers):
+            raise ValueError("An optimizer should be passed only once to the `setup_optimizers` method.")
+
     @staticmethod
     def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None:
         if not dataloaders:
             raise ValueError("`setup_dataloaders` requires at least one dataloader as input.")

         if any(isinstance(dl, _LiteDataLoader) for dl in dataloaders):
-            raise ValueError("A dataloader should be passed only once to the `setup_dataloaders` method")
+            raise ValueError("A dataloader should be passed only once to the `setup_dataloaders` method.")

         if any(not isinstance(dl, DataLoader) for dl in dataloaders):
             raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")

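In practice these checks reject the split API up front for strategies that cannot support it; a hedged sketch (the `MyLite` instance and the strategy flag strings are illustrative):

lite = MyLite(strategy="ddp_sharded")
lite.setup_module(model)          # RuntimeError: model and optimizer(s) must be set up jointly via .setup(...)

lite = MyLite(strategy="deepspeed")
lite.setup_optimizers(optimizer)  # RuntimeError: joint setup is required for DeepSpeed as well
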
@@ -35,7 +35,7 @@ from lightning_lite.utilities.distributed import log
 from lightning_lite.utilities.enums import AMPType, PrecisionType
 from lightning_lite.utilities.rank_zero import rank_zero_info
 from lightning_lite.utilities.seed import reset_seed
-from lightning_lite.utilities.types import _LRScheduler, _PATH, ReduceLROnPlateau
+from lightning_lite.utilities.types import _PATH

 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
 if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:

@@ -305,11 +305,11 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
         return self._deepspeed_engine

     def setup_module_and_optimizers(
-        self, model: Module, optimizers: List[Optimizer]
+        self, module: Module, optimizers: List[Optimizer]
     ) -> Tuple["deepspeed.DeepSpeedEngine", List[Optimizer]]:
         """Set up a model and multiple optimizers together.

-        Currently only a single optimizer is supported.
+        Currently, only a single optimizer is supported.

         Return:
             The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single

@@ -321,10 +321,25 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
                 f" Got {len(optimizers)} optimizers instead."
             )

-        self._deepspeed_engine, optimizer = self._setup_module_and_optimizer(model, optimizers[0])
+        self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0])
         self._set_deepspeed_activation_checkpointing()
         return self._deepspeed_engine, [optimizer]

+    def setup_module(self, module: Module) -> "deepspeed.DeepSpeedEngine":
+        """Set up a module for inference (no optimizers).
+
+        For training, see :meth:`setup_module_and_optimizers`.
+        """
+        self._deepspeed_engine, _ = self._initialize_engine(module)
+        return self._deepspeed_engine
+
+    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        """Optimizers can only be set up jointly with the model in this strategy.
+
+        Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer together.
+        """
+        raise NotImplementedError(self._err_msg_joint_setup_required())
+
     @contextmanager
     def module_sharded_context(self) -> Generator[None, None, None]:
         # Current limitation in Lite: The config needs to be fully determined at the time of calling the

@@ -401,11 +416,10 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
                 offload_optimizer_device="nvme",
             )

-    def _setup_module_and_optimizer(
+    def _initialize_engine(
         self,
         model: Module,
-        optimizer: Optional[Optimizer],
-        lr_scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None,
+        optimizer: Optional[Optimizer] = None,
     ) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.

@@ -420,7 +434,6 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
             model=model,
             model_parameters=model_parameters,
             optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
             dist_init_required=False,
         )
         return deepspeed_engine, deepspeed_optimizer

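A short sketch of what the new DeepSpeed path enables (the `strategy` instance and `module` are illustrative, and the surrounding environment setup is omitted):

strategy = DeepSpeedStrategy()
engine = strategy.setup_module(module)   # returns a deepspeed.DeepSpeedEngine, no optimizer required
# strategy.setup_optimizer(optimizer)    # raises NotImplementedError: joint setup is required for training
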
@@ -18,6 +18,7 @@ from typing import Any, Dict, Generator, List, Optional, Tuple
 import torch
 from lightning_utilities.core.imports import module_available
 from torch.nn import Module
+from torch.nn.parallel import DistributedDataParallel
 from torch.optim import Optimizer

 from lightning_lite.accelerators import Accelerator

@@ -89,6 +90,20 @@ class DDPShardedStrategy(DDPStrategy):
         model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs)
         return model, optimizers

+    def setup_module(self, module: Module) -> DistributedDataParallel:
+        """Setting up the module without optimizers in this strategy is not supported.
+
+        Please use :meth:`setup_module_and_optimizers` instead.
+        """
+        raise NotImplementedError(self._err_msg_joint_setup_required())
+
+    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        """Optimizers can only be set up jointly with the model in this strategy.
+
+        Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer(s) together.
+        """
+        raise NotImplementedError(self._err_msg_joint_setup_required())
+
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register(

@@ -153,6 +168,20 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
         model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs)
         return model, optimizers

+    def setup_module(self, module: Module) -> DistributedDataParallel:
+        """Setting up the module without optimizers in this strategy is not supported.
+
+        Please use :meth:`setup_module_and_optimizers` instead.
+        """
+        raise NotImplementedError(self._err_msg_joint_setup_required())
+
+    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        """Optimizers can only be set up jointly with the model in this strategy.
+
+        Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer(s) together.
+        """
+        raise NotImplementedError(self._err_msg_joint_setup_required())
+
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register(

@@ -118,7 +118,7 @@ class Strategy(ABC):
         """Set up a model and multiple optimizers together.

         The returned objects are expected to be in the same order they were passed in. The default implementation will
-        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs.
+        call :meth:`setup_module` and :meth:`setup_optimizer` on the inputs.
         """
         module = self.setup_module(module)
         optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers]

@@ -288,6 +288,12 @@
     def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None:
         pass

+    def _err_msg_joint_setup_required(self) -> str:
+        return (
+            f"The `{type(self).__name__}` does not support setting up the module and optimizer(s) independently."
+            " Please call `setup_module_and_optimizers(model, [optimizer, ...])` to jointly set them up."
+        )
+

 class _BackwardSyncControl(ABC):
     """Interface for any :class:`Strategy` that wants to offer a functionality to enable or disable gradient

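Condensed, the default `setup_module_and_optimizers()` now simply chains the two finer-grained hooks, while strategies that cannot split them override the hooks to raise. A short sketch of that contract (names as in the diff above):

module = strategy.setup_module(module)
optimizers = [strategy.setup_optimizer(optimizer) for optimizer in optimizers]
# Strategies requiring joint setup (DeepSpeed, the fairscale sharded strategies) instead override
# setup_module()/setup_optimizer() to raise NotImplementedError(self._err_msg_joint_setup_required()).
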
@@ -13,8 +13,12 @@
 # limitations under the License.
 import json
 import os
+from re import escape
+from unittest import mock
+from unittest.mock import ANY, Mock

 import pytest
 import torch
 from tests_lite.helpers.runif import RunIf

 from lightning_lite.accelerators import CPUAccelerator

@@ -116,3 +120,34 @@ def test_deepspeed_config_zero_offload(deepspeed_zero_config):
     deepspeed_zero_config["zero_optimization"]["offload_optimizer"] = False
     strategy = DeepSpeedStrategy(config=deepspeed_zero_config)
     assert strategy.config["zero_optimization"]["offload_optimizer"] is False
+
+
+@RunIf(deepspeed=True)
+@mock.patch("lightning_lite.strategies.deepspeed.deepspeed.initialize")
+def test_deepspeed_setup_module(init_mock):
+    """Test that the DeepSpeed strategy can set up the model for inference (no optimizer required)."""
+    model = Mock()
+    model.parameters.return_value = []
+    strategy = DeepSpeedStrategy()
+    strategy.parallel_devices = [torch.device("cuda", 1)]
+    init_mock.return_value = [Mock()] * 4  # mock to make tuple unpacking work
+
+    strategy.setup_module(model)
+    init_mock.assert_called_with(
+        args=ANY,
+        config=strategy.config,
+        model=model,
+        model_parameters=ANY,
+        optimizer=None,
+        dist_init_required=False,
+    )
+
+
+@RunIf(deepspeed=True)
+def test_deepspeed_requires_joint_setup():
+    """Test that the DeepSpeed strategy does not support setting up model and optimizer independently."""
+    strategy = DeepSpeedStrategy()
+    with pytest.raises(
+        NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently")
+    ):
+        strategy.setup_optimizer(Mock())

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from re import escape
 from unittest import mock
 from unittest.mock import MagicMock, Mock

@@ -89,3 +90,18 @@ def test_fairscale_no_backward_sync(cls):
         pass

     module.no_sync.assert_called_once()
+
+
+@pytest.mark.parametrize("cls", [DDPShardedStrategy, DDPSpawnShardedStrategy])
+def test_fairscale_requires_joint_setup(cls):
+    """Test that the fairscale sharded strategy does not support setting up model and optimizer independently."""
+    strategy = cls()
+    with pytest.raises(
+        NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently")
+    ):
+        strategy.setup_module(Mock())
+
+    with pytest.raises(
+        NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently")
+    ):
+        strategy.setup_optimizer(Mock())

@@ -27,7 +27,16 @@ from torch.utils.data import DataLoader, DistributedSampler, Sampler
 from lightning_lite.lite import LightningLite
 from lightning_lite.plugins import Precision
-from lightning_lite.strategies import DDPStrategy, ParallelStrategy, SingleDeviceStrategy, Strategy
+from lightning_lite.strategies import (
+    DDPShardedStrategy,
+    DDPSpawnShardedStrategy,
+    DDPStrategy,
+    DeepSpeedStrategy,
+    ParallelStrategy,
+    SingleDeviceStrategy,
+    Strategy,
+    XLAStrategy,
+)
 from lightning_lite.strategies.strategy import _Sharded
 from lightning_lite.utilities import _StrategyType
 from lightning_lite.utilities.exceptions import MisconfigurationException

@@ -72,11 +81,13 @@ def test_run_input_output():


 @mock.patch("lightning_lite.strategies.ddp.DistributedDataParallel")
-def test_setup_model(ddp_mock):
+@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
+def test_setup_module(ddp_mock, setup_method):
     """Test that the setup method lets the strategy wrap the model, but keeps a reference to the original model."""
     lite = EmptyLite(accelerator="cpu", strategy="ddp", devices=2)
     model = nn.Linear(1, 2)
-    lite_model = lite.setup(model)
+    setup_method = getattr(lite, setup_method)
+    lite_model = setup_method(model)
     ddp_mock.assert_called_with(module=model, device_ids=ANY)
     assert lite_model.module == model
     assert lite_model.weight is model.weight

@@ -95,7 +106,8 @@ def test_setup_model(ddp_mock):
     ],
 )
 @pytest.mark.parametrize("move_to_device", [True, False])
-def test_setup_model_move_to_device(move_to_device, accelerator, initial_device, target_device):
+@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
+def test_setup_module_move_to_device(setup_method, move_to_device, accelerator, initial_device, target_device):
     """Test that `move_to_device` leads to parameters being moved to the correct device and that the device
     attributes on the wrapper are updated."""
     initial_device = torch.device(initial_device)

@@ -105,7 +117,8 @@ def test_setup_model_move_to_device(move_to_device, accelerator, initial_device,
     lite = EmptyLite(accelerator=accelerator, devices=1)
     model = nn.Linear(1, 2)
     model.to(initial_device)
-    lite_model = lite.setup(model, move_to_device=move_to_device)
+    setup_method = getattr(lite, setup_method)
+    lite_model = setup_method(model, move_to_device=move_to_device)

     # all parameters on the expected device
     assert all(param.device == expected_device for param in model.parameters())

@@ -117,7 +130,8 @@ def test_setup_model_move_to_device(move_to_device, accelerator, initial_device,

 @RunIf(min_cuda_gpus=1)
 @pytest.mark.parametrize("move_to_device", [True, False])
-def test_setup_model_parameters_on_different_devices(move_to_device):
+@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
+def test_setup_module_parameters_on_different_devices(setup_method, move_to_device):
     """Test that a warning is emitted when model parameters are on a different device prior to calling
     `setup()`."""
     device0 = torch.device("cpu")

@@ -129,9 +143,11 @@ def test_setup_model_parameters_on_different_devices(move_to_device):
     module1 = nn.Linear(1, 2).to(device1)
     model = nn.Sequential(module0, module1)

+    setup_method = getattr(lite, setup_method)
+
     if move_to_device:
         with pytest.warns(PossibleUserWarning, match="has parameters on different devices"):
-            lite_model = lite.setup(model, move_to_device=move_to_device)
+            lite_model = setup_method(model, move_to_device=move_to_device)

         # both have the same device now
         assert lite_model.device == device1

@@ -139,11 +155,11 @@ def test_setup_model_parameters_on_different_devices(move_to_device):
         assert module1.weight.device == module1.bias.device == device1
     else:
         with no_warning_call(expected_warning=PossibleUserWarning, match="has parameters on different devices"):
-            lite.setup(model, move_to_device=move_to_device)
+            setup_method(model, move_to_device=move_to_device)


-def test_setup_optimizers():
-    """Test that setup_optimizers can handle no optimizers, one optimizer, or multiple optimizers."""
+def test_setup_module_and_optimizers():
+    """Test that `setup()` can handle no optimizers, one optimizer, or multiple optimizers."""
     lite = EmptyLite()
     model = nn.Linear(1, 2)
     optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1)

@@ -171,8 +187,28 @@ def test_setup_optimizers():
     assert lite_optimizer1.optimizer is optimizer1


+def test_setup_optimizers():
+    """Test that `setup_optimizers()` can handle one or more optimizers."""
+    lite = EmptyLite()
+    model = nn.Linear(1, 2)
+    optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1)
+    optimizer1 = torch.optim.Adam(model.parameters(), lr=0.1)
+
+    # single optimizer
+    lite_optimizer = lite.setup_optimizers(optimizer0)
+    assert isinstance(lite_optimizer, _LiteOptimizer)
+    assert lite_optimizer.optimizer is optimizer0
+
+    # multiple optimizers
+    lite_optimizer0, lite_optimizer1 = lite.setup_optimizers(optimizer0, optimizer1)
+    assert isinstance(lite_optimizer0, _LiteOptimizer)
+    assert isinstance(lite_optimizer1, _LiteOptimizer)
+    assert lite_optimizer0.optimizer is optimizer0
+    assert lite_optimizer1.optimizer is optimizer1
+
+
 def test_setup_twice_fails():
-    """Test that calling setup with a model or optimizer that is already wrapped fails."""
+    """Test that calling `setup` with a model or optimizer that is already wrapped fails."""
     lite = EmptyLite()
     model = nn.Linear(1, 2)
     optimizer = torch.optim.Adam(model.parameters())

@@ -186,6 +222,49 @@ def test_setup_twice_fails():
         lite.setup(model, lite_optimizer)


+def test_setup_module_twice_fails():
+    """Test that calling `setup_module` with a model that is already wrapped fails."""
+    lite = EmptyLite()
+    model = nn.Linear(1, 2)
+
+    lite_model = lite.setup_module(model)
+    with pytest.raises(ValueError, match="A model should be passed only once to the"):
+        lite.setup_module(lite_model)
+
+
+def test_setup_optimizers_twice_fails():
+    """Test that calling `setup_optimizers` with an optimizer that is already wrapped fails."""
+    lite = EmptyLite()
+    model = nn.Linear(1, 2)
+    optimizer = torch.optim.Adam(model.parameters())
+
+    lite_optimizer = lite.setup_optimizers(optimizer)
+    with pytest.raises(ValueError, match="An optimizer should be passed only once to"):
+        lite.setup_optimizers(lite_optimizer)
+
+
+@pytest.mark.parametrize("strategy_cls", [DDPShardedStrategy, DDPSpawnShardedStrategy])
+def test_setup_module_not_supported(strategy_cls):
+    """Test that `setup_module` validates the strategy supports setting up model and optimizers independently."""
+    lite = EmptyLite()
+    model = nn.Linear(1, 2)
+    lite._strategy = Mock(spec=strategy_cls)
+    with pytest.raises(RuntimeError, match=escape("requires the model and optimizer(s) to be set up jointly through")):
+        lite.setup_module(model)
+
+
+@pytest.mark.parametrize("strategy_cls", [DeepSpeedStrategy, DDPShardedStrategy, DDPSpawnShardedStrategy, XLAStrategy])
+def test_setup_optimizers_not_supported(strategy_cls):
+    """Test that `setup_optimizers` validates the strategy supports setting up model and optimizers
+    independently."""
+    lite = EmptyLite()
+    model = nn.Linear(1, 2)
+    optimizer = torch.optim.Adam(model.parameters())
+    lite._strategy = Mock(spec=strategy_cls)
+    with pytest.raises(RuntimeError, match=escape("requires the model and optimizer(s) to be set up jointly through")):
+        lite.setup_optimizers(optimizer)
+
+
 def test_setup_tracks_num_models():
     """Test that setup() tracks how many times it has setup a model."""
     lite = EmptyLite()

@@ -199,10 +278,15 @@ def test_setup_tracks_num_models():
     lite.setup(model, optimizer)
     assert lite._models_setup == 2

+    lite.setup_module(model)
+    assert lite._models_setup == 3
+

-def test_setup_dataloaders_unsupported_type():
+def test_setup_dataloaders_unsupported_input():
     """Test that the setup_dataloaders method fails when provided with non-DataLoader objects."""
     lite = EmptyLite()
+    with pytest.raises(ValueError, match="`setup_dataloaders` requires at least one dataloader"):
+        lite.setup_dataloaders()
     with pytest.raises(TypeError, match="Only PyTorch DataLoader are currently supported"):
         lite.setup_dataloaders(range(2))  # type: ignore