Added a check to validate that wrapped FSDP models are used while initializing optimizers (#15301)
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
parent 18f7f2d395
commit 0886e6352e
@@ -352,6 +352,12 @@ Model layers should be wrapped in FSDP in a nested way to save peak memory and e
simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't
have to ``wrap`` layers manually as in the case of manual wrapping.

.. note::
    When initializing optimizers inside the ``configure_optimizers`` hook, make sure to use ``self.trainer.model.parameters()``,
    otherwise PyTorch will raise an error. This is required because, with auto-wrap, the model layers are sharded and
    ``lightning_module.parameters()`` returns an empty generator. This inconvenience will be addressed in the future.

.. code-block:: python

    model = BoringModel()

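To make the note above concrete, here is a minimal sketch of a ``configure_optimizers`` implementation that references the wrapped parameters. It is illustrative only and not part of this diff; the ``MyAutoWrappedModel`` name and the choice of ``AdamW`` are assumptions.

.. code-block:: python

    import torch
    from pytorch_lightning.demos.boring_classes import BoringModel


    class MyAutoWrappedModel(BoringModel):
        def configure_optimizers(self):
            # With FSDP auto-wrap, the layers are flattened into the wrapped module,
            # so `self.parameters()` would be empty here. Reference the wrapped
            # model's parameters instead, which is what the new check enforces.
            return torch.optim.AdamW(self.trainer.model.parameters(), lr=1e-2)
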
@@ -196,3 +196,9 @@ class _FairscaleBackwardSyncControl(_BackwardSyncControl):
            )
        with module.no_sync():
            yield None


def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
    from fairscale.nn.misc.flatten_params_wrapper import FlatParameter

    return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])

@@ -0,0 +1,20 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.optim import Optimizer


def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
    from torch.distributed.fsdp import FlatParameter

    return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])

@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
-


- Added a check to validate that wrapped FSDP models are used while initializing optimizers ([#15301](https://github.com/Lightning-AI/lightning/pull/15301))


### Changed

- From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))

@@ -19,7 +19,7 @@ import torch

import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE, _optimizer_has_flat_params
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.optimizer import _optimizers_to_device
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase

@@ -28,7 +28,6 @@ from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _FAIRSCALE_AVAILABLE:

@@ -191,16 +190,27 @@ class DDPFullyShardedStrategy(DDPStrategy):

        self.setup_precision_plugin()

    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        invalid_params_error = False
        try:
            super().setup_optimizers(trainer)
        except ValueError as e:
            if "optimizer got an empty parameter list" not in str(e):
                raise
            invalid_params_error = True

        if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
            raise ValueError(
                "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
                " optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"
                " `configure_optimizers()` hook."
            )

    def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
        """Wraps the model into a
        :class:`~fairscale.nn.data_parallel.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
        log.detail(f"setting up `Fairscale FSDP` model with device id: {self.root_device.index}.")

        rank_zero_info(
            "When using FairScale FSDP auto-wrap, make sure to initialize your optimizer with the wrapped model's"
            " parameters: `torch.optim.Optimizer(self.trainer.model.parameters(), ...)`"
        )

        return FullyShardedDataParallel(
            module=model,
            process_group=self.process_group,

@@ -20,6 +20,7 @@ from torch import Tensor

import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.strategies.fsdp_native import _optimizer_has_flat_params
from lightning_lite.utilities.distributed import (
    _get_default_process_group_backend_for_device,
    _init_dist_connection,

@@ -215,6 +216,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
            del self.kwargs["auto_wrap_policy"]

        log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")

        return FullyShardedDataParallel(
            module=model,
            process_group=self.process_group,

@@ -255,6 +257,22 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):

        self.setup_precision_plugin()

    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        invalid_params_error = False
        try:
            super().setup_optimizers(trainer)
        except ValueError as e:
            if "optimizer got an empty parameter list" not in str(e):
                raise
            invalid_params_error = True

        if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
            raise ValueError(
                "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
                " optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"
                " `configure_optimizers()` hook."
            )

    def model_to_device(self) -> None:
        pass

@@ -18,46 +18,6 @@ if _TORCH_GREATER_EQUAL_1_12:
    from torch.distributed.fsdp.wrap import wrap


def custom_auto_wrap_policy(
    module,
    recurse,
    unwrapped_params: int,
    min_num_params: int = int(1e8),
) -> bool:
    return unwrapped_params >= 2


@RunIf(min_torch="1.12")
def test_invalid_on_cpu(tmpdir):
    """Test to ensure that we raise Misconfiguration for Native FSDP on CPU."""
    with pytest.raises(
        MisconfigurationException,
        match=f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, "
        "but GPU accelerator is not used.",
    ):
        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp_native")
        assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
        trainer.strategy.setup_environment()


@RunIf(min_torch="1.12", min_cuda_gpus=1)
@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)])
def test_precision_plugin_config(precision, expected):
    plugin = FullyShardedNativeNativeMixedPrecisionPlugin(precision=precision, device="cuda")
    config = plugin.mixed_precision_config
    assert config.param_dtype == expected
    assert config.buffer_dtype == expected
    assert config.reduce_dtype == expected


@RunIf(min_torch="1.12")
def test_fsdp_custom_mixed_precision(tmpdir):
    """Test to ensure that passing a custom mixed precision config works."""
    config = MixedPrecision()
    strategy = DDPFullyShardedNativeStrategy(mixed_precision=config)
    assert strategy.mixed_precision_config == config


class TestFSDPModel(BoringModel):
    def __init__(self):
        super().__init__()

@@ -154,6 +114,80 @@ class TestFSDPModelAutoWrapped(BoringModel):
            assert self.layer[layer_num].mixed_precision.buffer_dtype == precision


def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
    trainer.fit(model)
    model_path = trainer.strategy.broadcast(model_path)
    model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path

    trainer.save_checkpoint(model_path, weights_only=True)

    _assert_save_equality(trainer, model_path, cls=model.__class__)

    # Test entry point
    trainer.test(model)  # model is wrapped, will not call `configure_sharded_model`

    # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
    trainer.test(ckpt_path=model_path)

    # Predict entry point
    trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`

    # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
    trainer.predict(ckpt_path=model_path)


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
    # Use FullySharded to get the state dict for the sake of comparison
    model_state_dict = trainer.strategy.lightning_module_state_dict()

    if trainer.is_global_zero:
        saved_model = cls.load_from_checkpoint(ckpt_path)

        # Assert model parameters are identical after loading
        for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()):
            assert torch.equal(ddp_param.float().cpu(), shard_param)


def custom_auto_wrap_policy(
    module,
    recurse,
    unwrapped_params: int,
    min_num_params: int = int(1e8),
) -> bool:
    return unwrapped_params >= 2


@RunIf(min_torch="1.12")
def test_invalid_on_cpu(tmpdir):
    """Test to ensure that we raise Misconfiguration for Native FSDP on CPU."""
    with pytest.raises(
        MisconfigurationException,
        match=f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, "
        "but GPU accelerator is not used.",
    ):
        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp_native")
        assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
        trainer.strategy.setup_environment()


@RunIf(min_torch="1.12", min_cuda_gpus=1)
@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)])
def test_precision_plugin_config(precision, expected):
    plugin = FullyShardedNativeNativeMixedPrecisionPlugin(precision=precision, device="cuda")
    config = plugin.mixed_precision_config
    assert config.param_dtype == expected
    assert config.buffer_dtype == expected
    assert config.reduce_dtype == expected


@RunIf(min_torch="1.12")
def test_fsdp_custom_mixed_precision(tmpdir):
    """Test to ensure that passing a custom mixed precision config works."""
    config = MixedPrecision()
    strategy = DDPFullyShardedNativeStrategy(mixed_precision=config)
    assert strategy.mixed_precision_config == config


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
    """Test to ensure that sync_batchnorm works when using fsdp_native and GPU, and all stages can be run."""

@@ -214,35 +248,23 @@ def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir, model, stra
    _run_multiple_stages(trainer, model)


def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
    trainer.fit(model)
    model_path = trainer.strategy.broadcast(model_path)
    model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path
@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12")
def test_invalid_parameters_in_optimizer(tmpdir):
    trainer = Trainer(strategy="fsdp_native", accelerator="cuda", devices=1)

    trainer.save_checkpoint(model_path, weights_only=True)
    class EmptyParametersModel(BoringModel):
        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-2)

    _assert_save_equality(trainer, model_path, cls=model.__class__)
    model = EmptyParametersModel()
    with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
        trainer.fit(model)

    # Test entry point
    trainer.test(model)  # model is wrapped, will not call `configure_sharded_model`
    class NoFlatParametersModel(BoringModel):
        def configure_optimizers(self):
            layer = torch.nn.Linear(4, 5)
            return torch.optim.Adam(layer.parameters(), lr=1e-2)

    # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
    trainer.test(ckpt_path=model_path)

    # Predict entry point
    trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`

    # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
    trainer.predict(ckpt_path=model_path)


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
    # Use FullySharded to get the state dict for the sake of comparison
    model_state_dict = trainer.strategy.lightning_module_state_dict()

    if trainer.is_global_zero:
        saved_model = cls.load_from_checkpoint(ckpt_path)

        # Assert model parameters are identical after loading
        for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()):
            assert torch.equal(ddp_param.float().cpu(), shard_param)
    model = NoFlatParametersModel()
    with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
        trainer.fit(model)

@@ -18,27 +18,6 @@ if _FAIRSCALE_AVAILABLE:
    from fairscale.nn import FullyShardedDataParallel, wrap


def test_invalid_on_cpu(tmpdir):
    """Test to ensure that to raise Misconfiguration for FSDP on CPU."""
    with pytest.raises(
        MisconfigurationException, match="You selected strategy to be `ddp_fully_sharded`, but GPU is not available."
    ):
        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp")
        assert isinstance(trainer.strategy, DDPFullyShardedStrategy)
        trainer.strategy.setup_environment()


@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
@RunIf(fairscale=True)
def test_fsdp_with_sharded_amp(cuda_count_1, tmpdir):
    """Test to ensure that plugin native amp plugin is correctly chosen when using sharded."""
    trainer = Trainer(
        default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp", accelerator="gpu", devices=1, precision=16
    )
    assert isinstance(trainer.strategy, DDPFullyShardedStrategy)
    assert isinstance(trainer.strategy.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)


class TestFSDPModelManualWrapped(BoringModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

@@ -123,6 +102,72 @@ class TestFSDPModelAutoWrapped(BoringModel):
        assert self.trainer.model.mixed_precision


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModelManualWrapped):
    # Use FullySharded to get the state dict for the sake of comparison
    model_state_dict = trainer.strategy.lightning_module_state_dict()

    if trainer.is_global_zero:
        saved_model = cls.load_from_checkpoint(ckpt_path)

        # Assert model parameters are identical after loading
        for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()):
            assert torch.equal(ddp_param.float().cpu(), shard_param)


def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
    trainer.fit(model)

    model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path

    trainer.save_checkpoint(model_path, weights_only=True)

    _assert_save_equality(trainer, model_path, cls=model.__class__)

    # Test entry point
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.test(model)  # model is wrapped, will not call configure_shared_model

    # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.test(model, ckpt_path=model_path)

    # Predict entry point
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()

    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`

    # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.predict(model, ckpt_path=model_path)


def test_invalid_on_cpu(tmpdir):
    """Test to ensure that to raise Misconfiguration for FSDP on CPU."""
    with pytest.raises(
        MisconfigurationException, match="You selected strategy to be `ddp_fully_sharded`, but GPU is not available."
    ):
        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp")
        assert isinstance(trainer.strategy, DDPFullyShardedStrategy)
        trainer.strategy.setup_environment()


@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
@RunIf(fairscale=True)
def test_fsdp_with_sharded_amp(cuda_count_1, tmpdir):
    """Test to ensure that plugin native amp plugin is correctly chosen when using sharded."""
    trainer = Trainer(
        default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp", accelerator="gpu", devices=1, precision=16
    )
    assert isinstance(trainer.strategy, DDPFullyShardedStrategy)
    assert isinstance(trainer.strategy.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)


@RunIf(min_cuda_gpus=1, standalone=True, fairscale=True)
def test_fully_sharded_strategy_checkpoint(tmpdir):
    """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""

@@ -171,51 +216,6 @@ def test_fully_sharded_strategy_checkpoint_multi_gpus(tmpdir, model, strategy):
    _run_multiple_stages(trainer, model)


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModelManualWrapped):
    # Use FullySharded to get the state dict for the sake of comparison
    model_state_dict = trainer.strategy.lightning_module_state_dict()

    if trainer.is_global_zero:
        saved_model = cls.load_from_checkpoint(ckpt_path)

        # Assert model parameters are identical after loading
        for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()):
            assert torch.equal(ddp_param.float().cpu(), shard_param)


def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
    trainer.fit(model)

    model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path

    trainer.save_checkpoint(model_path, weights_only=True)

    _assert_save_equality(trainer, model_path, cls=model.__class__)

    # Test entry point
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.test(model)  # model is wrapped, will not call configure_shared_model

    # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.test(model, ckpt_path=model_path)

    # Predict entry point
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()

    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`

    # provide model path, will create a new unwrapped model and load and then call `configure_shared_model` to wrap
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.predict(model, ckpt_path=model_path)


@RunIf(min_cuda_gpus=1, standalone=True, fairscale=True)
def test_fsdp_gradient_clipping_raises(tmpdir):
    """Test to ensure that an exception is raised when clipping gradients by value with FSDP."""

@@ -254,3 +254,25 @@ def test_fsdp_rewrap_limitation(tmpdir):

    with pytest.raises(MisconfigurationException, match="Using the same instance of model .* not supported"):
        trainer.test(model)


@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, fairscale_fully_sharded=True)
def test_invalid_parameters_in_optimizer(tmpdir):
    trainer = Trainer(strategy="fsdp", accelerator="gpu", devices=1)

    class EmptyParametersModel(BoringModel):
        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-2)

    model = EmptyParametersModel()
    with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
        trainer.fit(model)

    class NoFlatParametersModel(BoringModel):
        def configure_optimizers(self):
            layer = torch.nn.Linear(4, 5)
            return torch.optim.Adam(layer.parameters(), lr=1e-2)

    model = NoFlatParametersModel()
    with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
        trainer.fit(model)