Revert removal of empty-parameters check for `configure_optimizers()` with FSDP (#18785)
parent 20ce3aeeb4
commit 6f6c07dddf
@@ -327,7 +327,18 @@ class FSDPStrategy(ParallelStrategy):
     def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         if self.kwargs.get("use_orig_params"):
             return super().setup_optimizers(trainer)
-        if any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
+
+        invalid_params_error = False
+        try:
+            # In PyTorch < 2.0, or if `use_orig_params=False`, the user needs to access
+            # `self.trainer.model.parameters()` in `configure_optimizers()`
+            super().setup_optimizers(trainer)
+        except ValueError as ex:
+            if "optimizer got an empty parameter list" not in str(ex):
+                raise
+            invalid_params_error = True
+
+        if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
             # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
             raise ValueError(
                 "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
@@ -359,16 +359,22 @@ def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg):
 
 
 @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
-def test_invalid_parameters_in_optimizer():
+@pytest.mark.parametrize("use_orig_params", [None, False, True])
+def test_invalid_parameters_in_optimizer(use_orig_params):
+    fsdp_kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not None:
+        fsdp_kwargs = {"use_orig_params": use_orig_params}
+
     trainer = Trainer(
-        strategy="fsdp",
+        strategy=FSDPStrategy(**fsdp_kwargs),
         accelerator="cuda",
         devices=1,
         fast_dev_run=1,
     )
 
     error_context = (
         nullcontext()
-        if _TORCH_GREATER_EQUAL_2_0
+        if _TORCH_GREATER_EQUAL_2_0 and (_TORCH_GREATER_EQUAL_2_1 or use_orig_params is not False)
         else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
     )
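The `error_context` idiom above pairs `contextlib.nullcontext` with `pytest.raises` so one parametrized test can cover both the raising and the non-raising configurations. A self-contained illustration of the pattern (hypothetical test, not part of this commit):

    from contextlib import nullcontext

    import pytest

    @pytest.mark.parametrize("should_raise", [False, True])
    def test_conditional_raise(should_raise):
        # nullcontext() is a no-op context manager, so the same `with` block
        # serves both the raising and the non-raising parametrizations.
        error_context = pytest.raises(ValueError, match="boom") if should_raise else nullcontext()
        with error_context:
            if should_raise:
                raise ValueError("boom")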
@@ -385,6 +391,12 @@ def test_invalid_parameters_in_optimizer():
             layer = torch.nn.Linear(4, 5)
             return torch.optim.Adam(layer.parameters(), lr=1e-2)
 
+    error_context = (
+        nullcontext()
+        if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not False
+        else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
+    )
+
     model = NoFlatParametersModel()
     with error_context:
         trainer.fit(model)
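For orientation, `NoFlatParametersModel` is only partially visible in this hunk. Reconstructed from the context lines above (the import and class header are assumptions), it presumably looks like:

    import torch
    from lightning.pytorch.demos.boring_classes import BoringModel

    class NoFlatParametersModel(BoringModel):
        def configure_optimizers(self):
            # The optimizer references a freshly created layer rather than the
            # FSDP-wrapped module, so it holds no FSDP flat parameters and
            # triggers the check reinstated above.
            layer = torch.nn.Linear(4, 5)
            return torch.optim.Adam(layer.parameters(), lr=1e-2)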
@@ -388,6 +388,7 @@ def test_multiple_optimizers_step(tmpdir):
 
 def test_step_with_optimizer_closure(tmpdir):
     """Tests that `step` works with optimizer_closure."""
+    seed_everything(1)
 
     class TestModel(BoringModel):
         _losses = []
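The test touched here steps an optimizer with a closure, following PyTorch's standard closure protocol for `Optimizer.step`. A minimal plain-PyTorch illustration (model and data are made up):

    import torch

    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
    inputs, targets = torch.randn(8, 2), torch.randn(8, 1)

    def closure():
        # The closure re-evaluates the model and returns the loss; optimizers
        # such as LBFGS invoke it several times within a single step() call.
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        return loss

    optimizer.step(closure)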