[FIX] Native FSDP precision + tests (#12985)

Sean Naren 2022-07-20 12:32:35 +01:00 committed by GitHub
parent c028ff3b95
commit d78698528d
5 changed files with 155 additions and 86 deletions

View File

@@ -11,10 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, Optional
import torch
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
else:
MixedPrecision = None
class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
@@ -29,3 +38,18 @@ class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
raise MisconfigurationException(
f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
)
@property
def mixed_precision_config(self) -> Optional[MixedPrecision]:
assert MixedPrecision is not None
if self.precision == PrecisionType.HALF:
dtype = torch.float16
elif self.precision == PrecisionType.BFLOAT:
dtype = torch.bfloat16
else:
raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
return MixedPrecision(
param_dtype=dtype,
reduce_dtype=dtype,
buffer_dtype=dtype,
)
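The property above mirrors the new test added further down in this commit; a minimal usage sketch, assuming a CUDA-capable environment with the sharded gradient scaler dependencies installed (as that test requires):

import torch

from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin

# precision=16 resolves every FSDP dtype to torch.float16; precision="bf16" would resolve to torch.bfloat16.
plugin = FullyShardedNativeMixedPrecisionPlugin(precision=16, device="cuda")
config = plugin.mixed_precision_config
assert config.param_dtype == config.reduce_dtype == config.buffer_dtype == torch.float16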

View File

@@ -23,6 +23,8 @@ import pytorch_lightning as pl
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
@@ -35,18 +37,23 @@ from pytorch_lightning.utilities.distributed import (
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info
from pytorch_lightning.utilities.seed import reset_seed
if _TORCH_GREATER_EQUAL_1_11:
if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch,
CPUOffload,
FullyShardedDataParallel,
MixedPrecision,
)
from torch.distributed.fsdp.wrap import enable_wrap
else:
MixedPrecision = None
BackwardPrefetch = None # type: ignore[misc,assignment]
CPUOffload = None # type: ignore[misc,assignment]
log = logging.getLogger(__name__)
@@ -56,7 +63,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
strategy_name = "fsdp_native"
_registered_strategies: List[str] = []
def __init__( # type: ignore[no-untyped-def]
def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
parallel_devices: Optional[List[torch.device]] = None,
@@ -64,10 +71,12 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
process_group_backend: Optional[str] = None,
cpu_offload=None,
backward_prefetch=None,
cpu_offload: Optional[CPUOffload] = None,
backward_prefetch: Optional[BackwardPrefetch] = None,
mixed_precision: Optional[MixedPrecision] = None,
**kwargs: Any,
) -> None:
"""Strategy for Fully Sharded Data Parallel provided by torch.Distributed.
r"""Strategy for Fully Sharded Data Parallel provided by torch.Distributed.
Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
@@ -84,7 +93,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
`https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html`
Arguments:
cpu_offload (Optional [CPUOffload]):
cpu_offload:
CPU offloading config. Currently, only parameter and gradient CPU
offload is supported. It can be enabled via passing in
``cpu_offload=CPUOffload(offload_params=True)``. Note that this
@@ -92,14 +101,21 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
params and grads to be on the same device to work with the optimizer. This
API is subject to change. Default is ``None`` in which case there
will be no offloading.
backward_prefetch: (Optional[BackwardPrefetch]):
backward_prefetch:
This is an experimental feature that is subject to change in the
near future. It allows users to enable two different backward_prefetch
algorithms to help backward communication and computation overlapping.
Pros and cons of each algorithm are explained in the class ``BackwardPrefetch``.
mixed_precision:
Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
or BF16 if ``precision=bf16`` unless a config is passed in.
This is only available in PyTorch 1.12 and later.
\**kwargs: Passed to the FSDP Context manager which will configure the FSDP class when wrapping modules.
"""
if not _TORCH_GREATER_EQUAL_1_11:
raise MisconfigurationException("DDPFullyShardedNativeStrategy is supported from pytorch v1.11.0 onwards.")
if not _TORCH_GREATER_EQUAL_1_12:
raise MisconfigurationException(
"`DDPFullyShardedNativeStrategy` is supported from PyTorch v1.12.0 onwards."
)
super().__init__(
accelerator=accelerator,
@@ -109,16 +125,23 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
precision_plugin=precision_plugin,
)
self._process_group = None
self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
self._process_group_backend: Optional[str] = process_group_backend
self.cpu_offload: Optional[CPUOffload] = cpu_offload
self.backward_prefetch: Optional[BackwardPrefetch] = backward_prefetch
self.num_nodes = 1
self._process_group_backend = process_group_backend
self.cpu_offload = cpu_offload
self.backward_prefetch = backward_prefetch
self.mixed_precision = mixed_precision
self._rank_0_will_call_children_scripts: bool = False
self.kwargs = kwargs
@property
def root_device(self) -> torch.device:
assert self.parallel_devices is not None
return self.parallel_devices[self.local_rank]
@property
def num_processes(self) -> int:
return len(self.parallel_devices) if self.parallel_devices is not None else 0
@property
def process_group(self) -> Optional[ProcessGroup]:
if self._process_group is None:
@@ -130,10 +153,28 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
def process_group_backend(self) -> Optional[str]:
return self._process_group_backend
@property
def mixed_precision_config(self) -> Optional[MixedPrecision]:
if self.mixed_precision:
return self.mixed_precision
plugin = self.precision_plugin
if isinstance(plugin, FullyShardedNativeMixedPrecisionPlugin):
return plugin.mixed_precision_config
@property
def distributed_sampler_kwargs(self) -> Dict:
return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
def setup_environment(self) -> None:
log.detail(f"{self.__class__.__name__}: setting up distributed...")
reset_seed()
# determine which process we are and world size
self.set_world_ranks()
# set warning rank
rank_zero_only.rank = self.global_rank
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
init_dist_connection(self.cluster_environment, self._process_group_backend)
@@ -146,15 +187,31 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
or get_default_process_group_backend_for_device(self.root_device)
)
def set_world_ranks(self) -> None:
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()
def _configure_launcher(self) -> None:
assert self.cluster_environment is not None
if not self.cluster_environment.creates_processes_externally:
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
self._rank_0_will_call_children_scripts = True
def setup(self, trainer: "pl.Trainer") -> None:
self.accelerator.setup(trainer)
# share ddp pids to all processes
self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
assert self.model is not None
self.model = self._layer_sync.apply(self.model)
if not self.cpu_offload:
self.model_to_device()
# we set the device so that optimizers can be created with distributed comms.
assert self.lightning_module is not None
self.lightning_module._device = self.root_device
self.barrier()
self.setup_optimizers(trainer)
@@ -162,20 +219,19 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
self.setup_precision_plugin()
def model_to_device(self) -> None:
# ensure we update the device type in the lightning module
assert self.lightning_module is not None
log.info(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
self.lightning_module.to(self.root_device)
pass
@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")
with enable_wrap(
wrapper_cls=FullyShardedDataParallel,
process_group=self.process_group,
cpu_offload=self.cpu_offload,
backward_prefetch=self.backward_prefetch,
mixed_precision=self.mixed_precision_config,
device_id=self.root_device.index,
**self.kwargs,
):
yield
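Because `model_sharded_context` enters `enable_wrap` with the settings above, a LightningModule can shard its own submodules; a hedged sketch of what `configure_sharded_model` can look like (module name and layer sizes are illustrative, not from this commit):

import torch
from torch.distributed.fsdp.wrap import wrap

import pytorch_lightning as pl


class ManuallyWrappedModel(pl.LightningModule):
    def configure_sharded_model(self) -> None:
        # `wrap` picks up the process group, CPU offload, backward prefetch and mixed
        # precision configured on the strategy via the surrounding `enable_wrap` context.
        self.layer = torch.nn.Sequential(
            wrap(torch.nn.Linear(32, 32)),
            torch.nn.ReLU(),
            wrap(torch.nn.Linear(32, 2)),
        )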
@@ -219,7 +275,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
return [self.root_device.index]
def teardown(self) -> None:
log.info(f"{self.__class__.__name__}: tearing down strategy...")
rank_zero_info(f"{self.__class__.__name__}: tearing down strategy...")
if (
self.lightning_module is not None
and self.lightning_module.trainer is not None
@@ -229,7 +285,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
assert self.model is not None
self.model = self._layer_sync.revert(self.model)
super().teardown()
assert self.cluster_environment is not None
self.cluster_environment.teardown()
self.precision_plugin.teardown()
self.accelerator.teardown()
@classmethod
def get_registered_strategies(cls) -> List[str]:
@@ -237,7 +296,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
if _TORCH_GREATER_EQUAL_1_11:
if _TORCH_GREATER_EQUAL_1_12:
strategy_registry.register(
"fsdp_native",
cls,
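Taken together, the constructor arguments documented in this file can be passed explicitly; a hedged sketch, where the device count and the particular offload/prefetch/dtype choices are illustrative and not part of the commit:

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload, MixedPrecision

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

# An explicit `mixed_precision` config takes precedence over the one derived from `precision`.
strategy = DDPFullyShardedNativeStrategy(
    cpu_offload=CPUOffload(offload_params=True),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
    ),
)
trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, precision=16, max_epochs=1)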

View File

@@ -700,17 +700,13 @@ class AcceleratorConnector:
if self._precision_flag == 16
else "Using bfloat16 Automatic Mixed Precision (AMP)"
)
if isinstance(self.strategy, DDPFullyShardedNativeStrategy):
raise MisconfigurationException(
"DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision"
)
if self._amp_type_flag == AMPType.NATIVE:
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
return ShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
if isinstance(self.strategy, DDPFullyShardedStrategy):
if isinstance(self.strategy, (DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy)):
return FullyShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
return NativeMixedPrecisionPlugin(self._precision_flag, device)
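With the connector change above, requesting 16-bit or bf16 precision together with `fsdp_native` now resolves to the FSDP-aware precision plugin instead of raising; a hedged sketch, assuming a CUDA GPU is available:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

trainer = Trainer(accelerator="gpu", devices=1, strategy="fsdp_native", precision=16)
assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
assert isinstance(trainer.strategy.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)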

View File

@@ -574,7 +574,7 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock
assert trainer.strategy.local_rank == 0
@RunIf(min_torch="1.11")
@RunIf(min_torch="1.12")
def test_check_native_fsdp_strategy_and_fallback():
with pytest.raises(
MisconfigurationException,
@@ -584,25 +584,6 @@ def test_check_native_fsdp_strategy_and_fallback():
Trainer(accelerator="cpu", strategy="fsdp_native")
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
@mock.patch("torch.cuda.device_count", return_value=1)
@mock.patch("torch.cuda.is_available", return_value=True)
@RunIf(min_torch="1.11")
def test_mixed_precision_support_with_native_fsdp_strategy(device_count_mock, mock_cuda_available, tmpdir):
with pytest.raises(
MisconfigurationException, match="DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision"
):
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
strategy="fsdp_native",
accelerator="gpu",
devices=1,
precision=16,
)
assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
def test_unsupported_tpu_choice(mock_tpu_acc_avail):

View File

@@ -1,6 +1,5 @@
import os
from typing import Any, Dict, Optional
from unittest import mock
import pytest
import torch
@@ -8,19 +7,21 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from pytorch_lightning.utilities.types import STEP_OUTPUT
from tests_pytorch.helpers.runif import RunIf
if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import wrap
@RunIf(min_torch="1.12dev")
@RunIf(min_torch="1.12")
def test_invalid_on_cpu(tmpdir):
"""Test to ensure that to raise Misconfiguration for Native FSDP on CPU."""
"""Test to ensure that we raise Misconfiguration for Native FSDP on CPU."""
with pytest.raises(
MisconfigurationException,
match=f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, "
@@ -31,29 +32,27 @@ def test_invalid_on_cpu(tmpdir):
trainer.strategy.setup_environment()
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
@mock.patch("torch.cuda.device_count", return_value=1)
@mock.patch("torch.cuda.is_available", return_value=True)
@RunIf(min_torch="1.12dev")
def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir):
"""Test to ensure that plugin native amp plugin raises Misconfiguration error."""
with pytest.raises(
MisconfigurationException, match="DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision"
):
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
strategy="fsdp_native",
accelerator="gpu",
devices=1,
precision=16,
)
assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
@RunIf(min_torch="1.12", min_cuda_gpus=1)
@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)])
def test_precision_plugin_config(precision, expected):
plugin = FullyShardedNativeMixedPrecisionPlugin(precision=precision, device="cuda")
config = plugin.mixed_precision_config
assert config.param_dtype == expected
assert config.buffer_dtype == expected
assert config.reduce_dtype == expected
@RunIf(min_torch="1.12")
def test_fsdp_custom_mixed_precision(tmpdir):
"""Test to ensure that passing a custom mixed precision config works."""
config = MixedPrecision()
strategy = DDPFullyShardedNativeStrategy(mixed_precision=config)
assert strategy.mixed_precision_config == config
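Beyond the empty `MixedPrecision()` used in the test above, a user-supplied config can pin individual dtypes; an illustrative variant (not part of the commit, and subject to the same PyTorch 1.12 requirement):

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision

from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

# Keep buffers in full precision while casting parameters and gradient reduction to bfloat16.
custom = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
strategy = DDPFullyShardedNativeStrategy(mixed_precision=custom)
assert strategy.mixed_precision_config is custom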
class TestFSDPModel(BoringModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self):
super().__init__()
self.layer: Optional[torch.nn.Module] = None
def _init_model(self) -> None:
@@ -79,16 +78,20 @@ class TestFSDPModel(BoringModel):
def configure_optimizers(self):
return torch.optim.SGD(self.layer.parameters(), lr=0.1)
def on_train_start(self) -> None:
def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
self._assert_layer_fsdp_instance()
def on_test_start(self) -> None:
def on_test_batch_end(
self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
self._assert_layer_fsdp_instance()
def on_validation_start(self) -> None:
def on_validation_batch_end(
self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
self._assert_layer_fsdp_instance()
def on_prediction_start(self) -> None:
def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None:
self._assert_layer_fsdp_instance()
def _assert_layer_fsdp_instance(self) -> None:
@@ -101,8 +104,13 @@ class TestFSDPModel(BoringModel):
assert self.layer.module[0].reshard_after_forward is True
assert self.layer.module[2].reshard_after_forward is True
precision = torch.float16 if self.precision == 16 else torch.bfloat16
assert self.layer.mixed_precision.param_dtype == precision
assert self.layer.mixed_precision.reduce_dtype == precision
assert self.layer.mixed_precision.buffer_dtype == precision
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12dev")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
"""Test to ensure that sync_batchnorm works when using fsdp_native and GPU, and all stages can be run."""
@@ -119,18 +127,19 @@ def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
_run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))
@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12dev")
def test_fully_sharded_native_strategy_checkpoint(tmpdir):
@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12")
@pytest.mark.parametrize("precision", [16, "bf16"])
def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision):
"""Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
model = TestFSDPModel()
trainer = Trainer(
default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy="fsdp_native", precision=16, max_epochs=1
default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy="fsdp_native", precision=precision, max_epochs=1
)
_run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12dev")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir):
"""Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run."""
@@ -150,7 +159,7 @@ def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir):
def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
trainer.fit(model)
model_path = trainer.strategy.broadcast(model_path)
model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path
trainer.save_checkpoint(model_path, weights_only=True)
@@ -158,7 +167,7 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
_assert_save_equality(trainer, model_path, cls=TestFSDPModel)
# Test entry point
trainer.test(model) # model is wrapped, will not call configure_shared_model
trainer.test(model) # model is wrapped, will not call `configure_sharded_model`
# provide model path, will create a new unwrapped model, load it, and then call `configure_sharded_model` to wrap
trainer.test(ckpt_path=model_path)