Added a check to validate that wrapped FSDP models are used while initializing optimizers (#15301)

Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
Rohit Gupta 2022-11-08 07:40:35 +05:30 committed by GitHub
parent 18f7f2d395
commit 0886e6352e
8 changed files with 248 additions and 141 deletions


@@ -352,6 +352,12 @@ Model layers should be wrapped in FSDP in a nested way to save peak memory and e
simplest way to do it is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code. You don't
have to ``wrap`` layers manually as in the case of manual wrapping.

.. note::
    When initializing optimizers inside the ``configure_optimizers`` hook, make sure to use ``self.trainer.model.parameters()``;
    otherwise, PyTorch will raise an error. This is required because when you use auto-wrap, the model layers are sharded and
    ``lightning_module.parameters()`` will return a generator with no parameters. This inconvenience will be addressed in the
    future (a short example is sketched after the code block below).

.. code-block:: python

    model = BoringModel()
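
As a minimal sketch of the note above (the module, layer size, optimizer, and learning rate are illustrative and not part of this change), the optimizer can be created from the wrapped model's parameters like so:

.. code-block:: python

    import torch
    import pytorch_lightning as pl


    class MyAutoWrappedModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def configure_optimizers(self):
            # Reference the FSDP-wrapped model's parameters; after auto-wrapping,
            # ``self.parameters()`` would yield an empty generator and PyTorch
            # would raise "optimizer got an empty parameter list".
            return torch.optim.Adam(self.trainer.model.parameters(), lr=1e-3)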


@@ -196,3 +196,9 @@ class _FairscaleBackwardSyncControl(_BackwardSyncControl):
            )
        with module.no_sync():
            yield None


def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
    from fairscale.nn.misc.flatten_params_wrapper import FlatParameter

    return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])


@@ -0,0 +1,20 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.optim import Optimizer


def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
    from torch.distributed.fsdp import FlatParameter

    return any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"])
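
A quick illustration of what this helper detects (a sketch, not part of the commit; it assumes `torch>=1.12` so the FSDP import resolves): an optimizer built from an unwrapped module holds plain `Parameter` objects rather than `FlatParameter` instances, so the check returns `False` and the strategy validation below raises.

import torch

from lightning_lite.strategies.fsdp_native import _optimizer_has_flat_params

# A plain nn.Linear has not been wrapped/flattened by FSDP, so its optimizer
# does not reference any FlatParameter and the validation check fails.
layer = torch.nn.Linear(4, 5)
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-2)
assert not _optimizer_has_flat_params(optimizer)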


@@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
-
- Added a check to validate that wrapped FSDP models are used while initializing optimizers ([#15301](https://github.com/Lightning-AI/lightning/pull/15301))

### Changed

- From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))


@@ -19,7 +19,7 @@ import torch
import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE, _optimizer_has_flat_params
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.optimizer import _optimizers_to_device
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
@@ -28,7 +28,6 @@ from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _FAIRSCALE_AVAILABLE:
@@ -191,16 +190,27 @@ class DDPFullyShardedStrategy(DDPStrategy):
        self.setup_precision_plugin()

    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        invalid_params_error = False
        try:
            super().setup_optimizers(trainer)
        except ValueError as e:
            if "optimizer got an empty parameter list" not in str(e):
                raise
            invalid_params_error = True

        if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
            raise ValueError(
                "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
                " optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"
                " `configure_optimizers()` hook."
            )

    def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
        """Wraps the model into a
        :class:`~fairscale.nn.data_parallel.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
        log.detail(f"setting up `Fairscale FSDP` model with device id: {self.root_device.index}.")
        rank_zero_info(
            "When using FairScale FSDP auto-wrap, make sure to initialize your model using trainer: "
            "`torch.optim.Optimizer(self.trainer.model.parameters(), ...)`"
        )
        return FullyShardedDataParallel(
            module=model,
            process_group=self.process_group,


@@ -20,6 +20,7 @@ from torch import Tensor
import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.strategies.fsdp_native import _optimizer_has_flat_params
from lightning_lite.utilities.distributed import (
    _get_default_process_group_backend_for_device,
    _init_dist_connection,
@@ -215,6 +216,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
            del self.kwargs["auto_wrap_policy"]
        log.detail(f"setting up FSDP model with device id: {self.root_device.index}, kwargs: {self.kwargs}")

        return FullyShardedDataParallel(
            module=model,
            process_group=self.process_group,
@@ -255,6 +257,22 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
        self.setup_precision_plugin()

    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        invalid_params_error = False
        try:
            super().setup_optimizers(trainer)
        except ValueError as e:
            if "optimizer got an empty parameter list" not in str(e):
                raise
            invalid_params_error = True

        if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
            raise ValueError(
                "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
                " optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"
                " `configure_optimizers()` hook."
            )

    def model_to_device(self) -> None:
        pass


@@ -18,46 +18,6 @@ if _TORCH_GREATER_EQUAL_1_12:
    from torch.distributed.fsdp.wrap import wrap


def custom_auto_wrap_policy(
    module,
    recurse,
    unwrapped_params: int,
    min_num_params: int = int(1e8),
) -> bool:
    return unwrapped_params >= 2


@RunIf(min_torch="1.12")
def test_invalid_on_cpu(tmpdir):
    """Test to ensure that we raise Misconfiguration for Native FSDP on CPU."""
    with pytest.raises(
        MisconfigurationException,
        match=f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, "
        "but GPU accelerator is not used.",
    ):
        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp_native")
        assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
        trainer.strategy.setup_environment()


@RunIf(min_torch="1.12", min_cuda_gpus=1)
@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)])
def test_precision_plugin_config(precision, expected):
    plugin = FullyShardedNativeNativeMixedPrecisionPlugin(precision=precision, device="cuda")
    config = plugin.mixed_precision_config
    assert config.param_dtype == expected
    assert config.buffer_dtype == expected
    assert config.reduce_dtype == expected


@RunIf(min_torch="1.12")
def test_fsdp_custom_mixed_precision(tmpdir):
    """Test to ensure that passing a custom mixed precision config works."""
    config = MixedPrecision()
    strategy = DDPFullyShardedNativeStrategy(mixed_precision=config)
    assert strategy.mixed_precision_config == config


class TestFSDPModel(BoringModel):
    def __init__(self):
        super().__init__()
@@ -154,6 +114,80 @@ class TestFSDPModelAutoWrapped(BoringModel):
            assert self.layer[layer_num].mixed_precision.buffer_dtype == precision


def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
    trainer.fit(model)
    model_path = trainer.strategy.broadcast(model_path)
    model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path

    trainer.save_checkpoint(model_path, weights_only=True)
    _assert_save_equality(trainer, model_path, cls=model.__class__)

    # Test entry point
    trainer.test(model)  # model is wrapped, will not call `configure_sharded_model`

    # provide model path, will create a new unwrapped model, load it, and then call `configure_sharded_model` to wrap
    trainer.test(ckpt_path=model_path)

    # Predict entry point
    trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`

    # provide model path, will create a new unwrapped model, load it, and then call `configure_sharded_model` to wrap
    trainer.predict(ckpt_path=model_path)


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
    # Use FullySharded to get the state dict for the sake of comparison
    model_state_dict = trainer.strategy.lightning_module_state_dict()

    if trainer.is_global_zero:
        saved_model = cls.load_from_checkpoint(ckpt_path)

        # Assert model parameters are identical after loading
        for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()):
            assert torch.equal(ddp_param.float().cpu(), shard_param)


def custom_auto_wrap_policy(
    module,
    recurse,
    unwrapped_params: int,
    min_num_params: int = int(1e8),
) -> bool:
    return unwrapped_params >= 2


@RunIf(min_torch="1.12")
def test_invalid_on_cpu(tmpdir):
    """Test to ensure that we raise Misconfiguration for Native FSDP on CPU."""
    with pytest.raises(
        MisconfigurationException,
        match=f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, "
        "but GPU accelerator is not used.",
    ):
        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp_native")
        assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
        trainer.strategy.setup_environment()


@RunIf(min_torch="1.12", min_cuda_gpus=1)
@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)])
def test_precision_plugin_config(precision, expected):
    plugin = FullyShardedNativeNativeMixedPrecisionPlugin(precision=precision, device="cuda")
    config = plugin.mixed_precision_config
    assert config.param_dtype == expected
    assert config.buffer_dtype == expected
    assert config.reduce_dtype == expected


@RunIf(min_torch="1.12")
def test_fsdp_custom_mixed_precision(tmpdir):
    """Test to ensure that passing a custom mixed precision config works."""
    config = MixedPrecision()
    strategy = DDPFullyShardedNativeStrategy(mixed_precision=config)
    assert strategy.mixed_precision_config == config


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
    """Test to ensure that sync_batchnorm works when using fsdp_native and GPU, and all stages can be run."""
@@ -214,35 +248,23 @@ def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir, model, strategy):
    _run_multiple_stages(trainer, model)


def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
    trainer.fit(model)
    model_path = trainer.strategy.broadcast(model_path)
    model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path


@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12")
def test_invalid_parameters_in_optimizer(tmpdir):
    trainer = Trainer(strategy="fsdp_native", accelerator="cuda", devices=1)
    trainer.save_checkpoint(model_path, weights_only=True)

    class EmptyParametersModel(BoringModel):
        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-2)

    _assert_save_equality(trainer, model_path, cls=model.__class__)
    model = EmptyParametersModel()
    with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
        trainer.fit(model)

    # Test entry point
    trainer.test(model)  # model is wrapped, will not call `configure_sharded_model`

    class NoFlatParametersModel(BoringModel):
        def configure_optimizers(self):
            layer = torch.nn.Linear(4, 5)
            return torch.optim.Adam(layer.parameters(), lr=1e-2)

    # provide model path, will create a new unwrapped model, load it, and then call `configure_sharded_model` to wrap
    trainer.test(ckpt_path=model_path)

    # Predict entry point
    trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`

    # provide model path, will create a new unwrapped model, load it, and then call `configure_sharded_model` to wrap
    trainer.predict(ckpt_path=model_path)


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
    # Use FullySharded to get the state dict for the sake of comparison
    model_state_dict = trainer.strategy.lightning_module_state_dict()

    if trainer.is_global_zero:
        saved_model = cls.load_from_checkpoint(ckpt_path)

        # Assert model parameters are identical after loading
        for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()):
            assert torch.equal(ddp_param.float().cpu(), shard_param)

    model = NoFlatParametersModel()
    with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
        trainer.fit(model)


@@ -18,27 +18,6 @@ if _FAIRSCALE_AVAILABLE:
    from fairscale.nn import FullyShardedDataParallel, wrap


def test_invalid_on_cpu(tmpdir):
    """Test to ensure that we raise a MisconfigurationException for FSDP on CPU."""
    with pytest.raises(
        MisconfigurationException, match="You selected strategy to be `ddp_fully_sharded`, but GPU is not available."
    ):
        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp")
        assert isinstance(trainer.strategy, DDPFullyShardedStrategy)
        trainer.strategy.setup_environment()


@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
@RunIf(fairscale=True)
def test_fsdp_with_sharded_amp(cuda_count_1, tmpdir):
    """Test to ensure that the native AMP plugin is correctly chosen when using sharded."""
    trainer = Trainer(
        default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp", accelerator="gpu", devices=1, precision=16
    )
    assert isinstance(trainer.strategy, DDPFullyShardedStrategy)
    assert isinstance(trainer.strategy.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)


class TestFSDPModelManualWrapped(BoringModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
@@ -123,6 +102,72 @@ class TestFSDPModelAutoWrapped(BoringModel):
        assert self.trainer.model.mixed_precision


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModelManualWrapped):
    # Use FullySharded to get the state dict for the sake of comparison
    model_state_dict = trainer.strategy.lightning_module_state_dict()

    if trainer.is_global_zero:
        saved_model = cls.load_from_checkpoint(ckpt_path)

        # Assert model parameters are identical after loading
        for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()):
            assert torch.equal(ddp_param.float().cpu(), shard_param)


def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
    trainer.fit(model)
    model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path

    trainer.save_checkpoint(model_path, weights_only=True)
    _assert_save_equality(trainer, model_path, cls=model.__class__)

    # Test entry point
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.test(model)  # model is wrapped, will not call configure_sharded_model

    # provide model path, will create a new unwrapped model, load it, and then call `configure_sharded_model` to wrap
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.test(model, ckpt_path=model_path)

    # Predict entry point
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`

    # provide model path, will create a new unwrapped model, load it, and then call `configure_sharded_model` to wrap
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.predict(model, ckpt_path=model_path)


def test_invalid_on_cpu(tmpdir):
    """Test to ensure that we raise a MisconfigurationException for FSDP on CPU."""
    with pytest.raises(
        MisconfigurationException, match="You selected strategy to be `ddp_fully_sharded`, but GPU is not available."
    ):
        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp")
        assert isinstance(trainer.strategy, DDPFullyShardedStrategy)
        trainer.strategy.setup_environment()


@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
@RunIf(fairscale=True)
def test_fsdp_with_sharded_amp(cuda_count_1, tmpdir):
    """Test to ensure that the native AMP plugin is correctly chosen when using sharded."""
    trainer = Trainer(
        default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp", accelerator="gpu", devices=1, precision=16
    )
    assert isinstance(trainer.strategy, DDPFullyShardedStrategy)
    assert isinstance(trainer.strategy.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)


@RunIf(min_cuda_gpus=1, standalone=True, fairscale=True)
def test_fully_sharded_strategy_checkpoint(tmpdir):
    """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
@@ -171,51 +216,6 @@ def test_fully_sharded_strategy_checkpoint_multi_gpus(tmpdir, model, strategy):
    _run_multiple_stages(trainer, model)


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModelManualWrapped):
    # Use FullySharded to get the state dict for the sake of comparison
    model_state_dict = trainer.strategy.lightning_module_state_dict()

    if trainer.is_global_zero:
        saved_model = cls.load_from_checkpoint(ckpt_path)

        # Assert model parameters are identical after loading
        for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()):
            assert torch.equal(ddp_param.float().cpu(), shard_param)


def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
    trainer.fit(model)
    model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path

    trainer.save_checkpoint(model_path, weights_only=True)
    _assert_save_equality(trainer, model_path, cls=model.__class__)

    # Test entry point
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.test(model)  # model is wrapped, will not call configure_sharded_model

    # provide model path, will create a new unwrapped model, load it, and then call `configure_sharded_model` to wrap
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.test(model, ckpt_path=model_path)

    # Predict entry point
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.predict(model)  # model is wrapped, will not call `configure_sharded_model`

    # provide model path, will create a new unwrapped model, load it, and then call `configure_sharded_model` to wrap
    if model.__class__ is TestFSDPModelAutoWrapped:
        model = TestFSDPModelAutoWrapped()
    trainer.predict(model, ckpt_path=model_path)


@RunIf(min_cuda_gpus=1, standalone=True, fairscale=True)
def test_fsdp_gradient_clipping_raises(tmpdir):
    """Test to ensure that an exception is raised when clipping gradients by value with FSDP."""
@@ -254,3 +254,25 @@ def test_fsdp_rewrap_limitation(tmpdir):
    with pytest.raises(MisconfigurationException, match="Using the same instance of model .* not supported"):
        trainer.test(model)


@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, fairscale_fully_sharded=True)
def test_invalid_parameters_in_optimizer(tmpdir):
    trainer = Trainer(strategy="fsdp", accelerator="gpu", devices=1)

    class EmptyParametersModel(BoringModel):
        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-2)

    model = EmptyParametersModel()
    with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
        trainer.fit(model)

    class NoFlatParametersModel(BoringModel):
        def configure_optimizers(self):
            layer = torch.nn.Linear(4, 5)
            return torch.optim.Adam(layer.parameters(), lr=1e-2)

    model = NoFlatParametersModel()
    with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
        trainer.fit(model)