diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 251e88353c..199c653287 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -327,7 +327,18 @@ class FSDPStrategy(ParallelStrategy): def setup_optimizers(self, trainer: "pl.Trainer") -> None: if self.kwargs.get("use_orig_params"): return super().setup_optimizers(trainer) - if any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers): + + invalid_params_error = False + try: + # In PyTorch < 2.0, or if `use_orig_params=False` the user needs to do access + # `self.trainer.model.parameters()` in configure_optimizers() + super().setup_optimizers(trainer) + except ValueError as ex: + if "optimizer got an empty parameter list" not in str(ex): + raise + invalid_params_error = True + + if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers): # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True` raise ValueError( "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the" diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 25a9e122b6..4a5cbc9fad 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -359,16 +359,22 @@ def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg): @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True) -def test_invalid_parameters_in_optimizer(): +@pytest.mark.parametrize("use_orig_params", [None, False, True]) +def test_invalid_parameters_in_optimizer(use_orig_params): + fsdp_kwargs = {} + if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not None: + fsdp_kwargs = {"use_orig_params": use_orig_params} + trainer = Trainer( - strategy="fsdp", + strategy=FSDPStrategy(**fsdp_kwargs), accelerator="cuda", devices=1, fast_dev_run=1, ) + error_context = ( nullcontext() - if _TORCH_GREATER_EQUAL_2_0 + if _TORCH_GREATER_EQUAL_2_0 and (_TORCH_GREATER_EQUAL_2_1 or use_orig_params is not False) else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters") ) @@ -385,6 +391,12 @@ def test_invalid_parameters_in_optimizer(): layer = torch.nn.Linear(4, 5) return torch.optim.Adam(layer.parameters(), lr=1e-2) + error_context = ( + nullcontext() + if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not False + else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters") + ) + model = NoFlatParametersModel() with error_context: trainer.fit(model) diff --git a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py index 12348ac707..8476b07bf2 100644 --- a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py +++ b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py @@ -388,6 +388,7 @@ def test_multiple_optimizers_step(tmpdir): def test_step_with_optimizer_closure(tmpdir): """Tests that `step` works with optimizer_closure.""" + seed_everything(1) class TestModel(BoringModel): _losses = []