Revert removal of empty-parameters check for `configure_optimizers()` with FSDP (#18785)

Adrian Wälchli 2023-10-12 01:36:49 -07:00 committed by GitHub
parent 20ce3aeeb4
commit 6f6c07dddf
3 changed files with 28 additions and 4 deletions


@@ -327,7 +327,18 @@ class FSDPStrategy(ParallelStrategy):
     def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         if self.kwargs.get("use_orig_params"):
             return super().setup_optimizers(trainer)
-        if any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
+        invalid_params_error = False
+        try:
+            # In PyTorch < 2.0, or if `use_orig_params=False`, the user needs to access
+            # `self.trainer.model.parameters()` in configure_optimizers()
+            super().setup_optimizers(trainer)
+        except ValueError as ex:
+            if "optimizer got an empty parameter list" not in str(ex):
+                raise
+            invalid_params_error = True
+        if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
             # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
             raise ValueError(
                 "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"

@@ -359,16 +359,22 @@ def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg):
 @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
-def test_invalid_parameters_in_optimizer():
+@pytest.mark.parametrize("use_orig_params", [None, False, True])
+def test_invalid_parameters_in_optimizer(use_orig_params):
+    fsdp_kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not None:
+        fsdp_kwargs = {"use_orig_params": use_orig_params}
     trainer = Trainer(
-        strategy="fsdp",
+        strategy=FSDPStrategy(**fsdp_kwargs),
         accelerator="cuda",
         devices=1,
         fast_dev_run=1,
     )
     error_context = (
         nullcontext()
-        if _TORCH_GREATER_EQUAL_2_0
+        if _TORCH_GREATER_EQUAL_2_0 and (_TORCH_GREATER_EQUAL_2_1 or use_orig_params is not False)
         else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
     )
@@ -385,6 +391,12 @@ def test_invalid_parameters_in_optimizer():
             layer = torch.nn.Linear(4, 5)
             return torch.optim.Adam(layer.parameters(), lr=1e-2)
+    error_context = (
+        nullcontext()
+        if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not False
+        else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
+    )
     model = NoFlatParametersModel()
     with error_context:
         trainer.fit(model)
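The `use_orig_params=True` case added to the parametrization is the escape hatch the error message recommends on PyTorch >= 2.0: FSDP keeps the original parameters visible, so `setup_optimizers()` takes the early-return branch and a plain `self.parameters()` optimizer is accepted. A minimal sketch under the same single-GPU assumptions as the test; the model is illustrative, not from the diff:

import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.strategies import FSDPStrategy


class OrigParamsModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        # With `use_orig_params=True` (PyTorch >= 2.0), the original parameters stay
        # accessible, so no detour through `self.trainer.model` is needed here.
        return torch.optim.Adam(self.parameters(), lr=1e-2)


trainer = Trainer(strategy=FSDPStrategy(use_orig_params=True), accelerator="cuda", devices=1, fast_dev_run=1)
# trainer.fit(OrigParamsModel())  # requires a dataloader; shown only to illustrate the wiring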


@@ -388,6 +388,7 @@ def test_multiple_optimizers_step(tmpdir):
 def test_step_with_optimizer_closure(tmpdir):
     """Tests that `step` works with optimizer_closure."""
+    seed_everything(1)
     class TestModel(BoringModel):
         _losses = []