diff --git a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
index 870e658bfc..8c693f2975 100644
--- a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
+++ b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
@@ -11,10 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, Optional
+
+import torch
 
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+
+if _TORCH_GREATER_EQUAL_1_12:
+    from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+else:
+    MixedPrecision = None
 
 
 class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
@@ -29,3 +38,18 @@ class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
             raise MisconfigurationException(
                 f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
             )
+
+    @property
+    def mixed_precision_config(self) -> Optional[MixedPrecision]:
+        assert MixedPrecision is not None
+        if self.precision == PrecisionType.HALF:
+            dtype = torch.float16
+        elif self.precision == PrecisionType.BFLOAT:
+            dtype = torch.bfloat16
+        else:
+            raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
+        return MixedPrecision(
+            param_dtype=dtype,
+            reduce_dtype=dtype,
+            buffer_dtype=dtype,
+        )
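For reviewers, a minimal sketch of what the new `mixed_precision_config` property resolves to; it mirrors the parametrized `test_precision_plugin_config` test added further down and assumes PyTorch >= 1.12 plus the fairscale requirement that `ShardedNativeMixedPrecisionPlugin` already carries:

```python
import torch

from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin

# precision=16 maps every FSDP dtype knob to float16; precision="bf16" maps them to bfloat16
plugin = FullyShardedNativeMixedPrecisionPlugin(precision=16, device="cuda")
config = plugin.mixed_precision_config  # an instance of torch's FSDP MixedPrecision dataclass
assert config.param_dtype == config.reduce_dtype == config.buffer_dtype == torch.float16
```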
diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py
index d8cd66be2e..7528d5b959 100644
--- a/src/pytorch_lightning/strategies/fully_sharded_native.py
+++ b/src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -23,6 +23,8 @@ import pytorch_lightning as pl
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
+from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import TrainerFn
@@ -35,18 +37,23 @@ from pytorch_lightning.utilities.distributed import (
 from pytorch_lightning.utilities.distributed import group as _group
 from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
+from pytorch_lightning.utilities.rank_zero import rank_zero_info
 from pytorch_lightning.utilities.seed import reset_seed
 
-if _TORCH_GREATER_EQUAL_1_11:
+if _TORCH_GREATER_EQUAL_1_12:
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
         BackwardPrefetch,
         CPUOffload,
         FullyShardedDataParallel,
+        MixedPrecision,
     )
     from torch.distributed.fsdp.wrap import enable_wrap
-
+else:
+    MixedPrecision = None
+    BackwardPrefetch = None  # type: ignore[misc,assignment]
+    CPUOffload = None  # type: ignore[misc,assignment]
 
 
 log = logging.getLogger(__name__)
@@ -56,7 +63,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
     strategy_name = "fsdp_native"
     _registered_strategies: List[str] = []
 
-    def __init__(  # type: ignore[no-untyped-def]
+    def __init__(
         self,
         accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
         parallel_devices: Optional[List[torch.device]] = None,
@@ -64,10 +71,12 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
         process_group_backend: Optional[str] = None,
-        cpu_offload=None,
-        backward_prefetch=None,
+        cpu_offload: Optional[CPUOffload] = None,
+        backward_prefetch: Optional[BackwardPrefetch] = None,
+        mixed_precision: Optional[MixedPrecision] = None,
+        **kwargs: Any,
     ) -> None:
-        """Strategy for Fully Sharded Data Parallel provided by torch.Distributed.
+        r"""Strategy for Fully Sharded Data Parallel provided by torch.Distributed.
 
         Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
         size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
@@ -84,7 +93,7 @@
         `https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html`
 
         Arguments:
-            cpu_offload (Optional [CPUOffload]):
+            cpu_offload:
                 CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It
                 can be enabled via passing in ``cpu_offload=CPUOffload(offload_params=True)``. Note that this
@@ -92,14 +101,21 @@
                 params and grads to be on same device to work with optimizer. This API is subject to change.
                 Default is ``None`` in which case there will be no offloading.
-            backward_prefetch: (Optional[BackwardPrefetch]):
+            backward_prefetch:
                 This is an experimental feature that is subject to change in the
                 the near future. It allows users to enable two different backward_prefetch
                 algorithms to help backward communication and computation overlapping.
                 Pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
+            mixed_precision:
+                Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
+                or BF16 if ``precision=bf16`` unless a config is passed in.
+                This is only available in PyTorch 1.12 and later.
+            \**kwargs: Passed to the FSDP context manager which will configure the FSDP class when wrapping modules.
         """
-        if not _TORCH_GREATER_EQUAL_1_11:
-            raise MisconfigurationException("DDPFullyShardedNativeStrategy is supported from pytorch v1.11.0 onwards.")
+        if not _TORCH_GREATER_EQUAL_1_12:
+            raise MisconfigurationException(
+                "`DDPFullyShardedNativeStrategy` is supported from PyTorch v1.12.0 onwards."
+            )
 
         super().__init__(
             accelerator=accelerator,
@@ -109,16 +125,23 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
             precision_plugin=precision_plugin,
         )
         self._process_group = None
-        self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
-        self._process_group_backend: Optional[str] = process_group_backend
-        self.cpu_offload: Optional[CPUOffload] = cpu_offload
-        self.backward_prefetch: Optional[BackwardPrefetch] = backward_prefetch
+        self.num_nodes = 1
+        self._process_group_backend = process_group_backend
+        self.cpu_offload = cpu_offload
+        self.backward_prefetch = backward_prefetch
+        self.mixed_precision = mixed_precision
+        self._rank_0_will_call_children_scripts: bool = False
+        self.kwargs = kwargs
 
     @property
     def root_device(self) -> torch.device:
         assert self.parallel_devices is not None
         return self.parallel_devices[self.local_rank]
 
+    @property
+    def num_processes(self) -> int:
+        return len(self.parallel_devices) if self.parallel_devices is not None else 0
+
     @property
     def process_group(self) -> Optional[ProcessGroup]:
         if self._process_group is None:
@@ -130,10 +153,28 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
     def process_group_backend(self) -> Optional[str]:
         return self._process_group_backend
 
+    @property
+    def mixed_precision_config(self) -> Optional[MixedPrecision]:
+        if self.mixed_precision:
+            return self.mixed_precision
+        plugin = self.precision_plugin
+        if isinstance(plugin, FullyShardedNativeMixedPrecisionPlugin):
+            return plugin.mixed_precision_config
+
+    @property
+    def distributed_sampler_kwargs(self) -> Dict:
+        return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
+
     def setup_environment(self) -> None:
+        log.detail(f"{self.__class__.__name__}: setting up distributed...")
         reset_seed()
+
+        # determine which process we are and world size
+        self.set_world_ranks()
+
         # set warning rank
         rank_zero_only.rank = self.global_rank
+
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         init_dist_connection(self.cluster_environment, self._process_group_backend)
@@ -146,15 +187,31 @@
             or get_default_process_group_backend_for_device(self.root_device)
         )
 
+    def set_world_ranks(self) -> None:
+        if self.cluster_environment is None:
+            return
+        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        rank_zero_only.rank = self.cluster_environment.global_rank()
+
+    def _configure_launcher(self) -> None:
+        assert self.cluster_environment is not None
+        if not self.cluster_environment.creates_processes_externally:
+            self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
+            self._rank_0_will_call_children_scripts = True
+
     def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)
+        # share ddp pids to all processes
+        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
 
         if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
             assert self.model is not None
             self.model = self._layer_sync.apply(self.model)
 
-        if not self.cpu_offload:
-            self.model_to_device()
+        # we set the device so that optimizers can be created with distributed comms.
+        assert self.lightning_module is not None
+        self.lightning_module._device = self.root_device
 
         self.barrier()
         self.setup_optimizers(trainer)
@@ -162,20 +219,19 @@
         self.setup_precision_plugin()
 
     def model_to_device(self) -> None:
-        # ensure we update the device type in the lightning module
-        assert self.lightning_module is not None
-        log.info(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
-        self.lightning_module.to(self.root_device)
+        pass
 
     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator:
         log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")
-
         with enable_wrap(
             wrapper_cls=FullyShardedDataParallel,
             process_group=self.process_group,
             cpu_offload=self.cpu_offload,
             backward_prefetch=self.backward_prefetch,
+            mixed_precision=self.mixed_precision_config,
+            device_id=self.root_device.index,
+            **self.kwargs,
         ):
             yield
@@ -219,7 +275,7 @@
         return [self.root_device.index]
 
     def teardown(self) -> None:
-        log.info(f"{self.__class__.__name__}: tearing down strategy...")
+        rank_zero_info(f"{self.__class__.__name__}: tearing down strategy...")
         if (
             self.lightning_module is not None
             and self.lightning_module.trainer is not None
@@ -229,7 +285,10 @@
             assert self.model is not None
             self.model = self._layer_sync.revert(self.model)
 
-        super().teardown()
+        assert self.cluster_environment is not None
+        self.cluster_environment.teardown()
+        self.precision_plugin.teardown()
+        self.accelerator.teardown()
 
     @classmethod
     def get_registered_strategies(cls) -> List[str]:
@@ -237,7 +296,7 @@
 
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
-        if _TORCH_GREATER_EQUAL_1_11:
+        if _TORCH_GREATER_EQUAL_1_12:
            strategy_registry.register(
                "fsdp_native",
                cls,
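For context, a hedged usage sketch of the widened constructor (not part of the patch; the concrete offload, prefetch and dtype choices are illustrative, and a CUDA machine with PyTorch >= 1.12 is assumed). Any extra keyword arguments would be forwarded verbatim to the FSDP wrapping context via `**kwargs`:

```python
import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload, MixedPrecision

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

strategy = DDPFullyShardedNativeStrategy(
    cpu_offload=CPUOffload(offload_params=True),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    # passing a config here overrides the one derived from Trainer(precision=...)
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
    ),
)
trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, precision=16)
```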
diff --git a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py
index ece0e5d27b..f72aba305e 100644
--- a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -700,17 +700,13 @@ class AcceleratorConnector:
                 if self._precision_flag == 16
                 else "Using bfloat16 Automatic Mixed Precision (AMP)"
             )
-            if isinstance(self.strategy, DDPFullyShardedNativeStrategy):
-                raise MisconfigurationException(
-                    "DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision"
-                )
 
             if self._amp_type_flag == AMPType.NATIVE:
                 device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
 
                 if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
                     return ShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
-                if isinstance(self.strategy, DDPFullyShardedStrategy):
+                if isinstance(self.strategy, (DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy)):
                     return FullyShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
                 return NativeMixedPrecisionPlugin(self._precision_flag, device)
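With the connector change above, combining `precision=16` or `precision="bf16"` with `fsdp_native` no longer raises; the FSDP-aware plugin is selected instead. A minimal sketch of the now-expected behaviour (it reuses the setup of the test removed below and assumes a CUDA machine, PyTorch >= 1.12 and fairscale):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

trainer = Trainer(accelerator="gpu", devices=1, strategy="fsdp_native", precision=16)
assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
assert isinstance(trainer.strategy.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)
```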
diff --git a/tests/tests_pytorch/accelerators/test_accelerator_connector.py b/tests/tests_pytorch/accelerators/test_accelerator_connector.py
index 33911bffb0..3575b739ee 100644
--- a/tests/tests_pytorch/accelerators/test_accelerator_connector.py
+++ b/tests/tests_pytorch/accelerators/test_accelerator_connector.py
@@ -574,7 +574,7 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock
     assert trainer.strategy.local_rank == 0
 
 
-@RunIf(min_torch="1.11")
+@RunIf(min_torch="1.12")
 def test_check_native_fsdp_strategy_and_fallback():
     with pytest.raises(
         MisconfigurationException,
@@ -584,25 +584,6 @@ def test_check_native_fsdp_strategy_and_fallback():
         Trainer(accelerator="cpu", strategy="fsdp_native")
 
 
-@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
-@mock.patch("torch.cuda.device_count", return_value=1)
-@mock.patch("torch.cuda.is_available", return_value=True)
-@RunIf(min_torch="1.11")
-def test_mixed_precision_support_with_native_fsdp_strategy(device_count_mock, mock_cuda_available, tmpdir):
-    with pytest.raises(
-        MisconfigurationException, match="DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision"
-    ):
-        trainer = Trainer(
-            default_root_dir=tmpdir,
-            fast_dev_run=True,
-            strategy="fsdp_native",
-            accelerator="gpu",
-            devices=1,
-            precision=16,
-        )
-        assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
-
-
 @mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
 def test_unsupported_tpu_choice(mock_tpu_acc_avail):
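The test module below builds its model through `configure_sharded_model` and FSDP's `wrap()`. For orientation, a condensed sketch of that user-facing pattern (the class name and layer sizes are illustrative, not taken from the patch):

```python
import torch

from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12

if _TORCH_GREATER_EQUAL_1_12:
    from torch.distributed.fsdp.wrap import wrap


class ShardedBoringModel(BoringModel):
    def configure_sharded_model(self) -> None:
        # called inside the strategy's `model_sharded_context()`, so `wrap()` inherits the
        # process group, CPU offload, backward prefetch and mixed precision configured above
        self.layer = wrap(
            torch.nn.Sequential(wrap(torch.nn.Linear(32, 32)), torch.nn.ReLU(), wrap(torch.nn.Linear(32, 2)))
        )

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
```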
diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py
index b6dbff1792..1ac7ad0b66 100644
--- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py
+++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py
@@ -1,6 +1,5 @@
 import os
 from typing import Any, Dict, Optional
-from unittest import mock
 
 import pytest
 import torch
@@ -8,19 +7,21 @@ import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
 from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 from tests_pytorch.helpers.runif import RunIf
 
 if _TORCH_GREATER_EQUAL_1_12:
-    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
+    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision
     from torch.distributed.fsdp.wrap import wrap
 
 
-@RunIf(min_torch="1.12dev")
+@RunIf(min_torch="1.12")
 def test_invalid_on_cpu(tmpdir):
-    """Test to ensure that to raise Misconfiguration for Native FSDP on CPU."""
+    """Test to ensure that we raise Misconfiguration for Native FSDP on CPU."""
     with pytest.raises(
         MisconfigurationException,
         match=f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, "
@@ -31,29 +32,27 @@ def test_invalid_on_cpu(tmpdir):
         trainer.strategy.setup_environment()
 
 
-@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
-@mock.patch("torch.cuda.device_count", return_value=1)
-@mock.patch("torch.cuda.is_available", return_value=True)
-@RunIf(min_torch="1.12dev")
-def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir):
-    """Test to ensure that plugin native amp plugin raises Misconfiguration error."""
-    with pytest.raises(
-        MisconfigurationException, match="DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision"
-    ):
-        trainer = Trainer(
-            default_root_dir=tmpdir,
-            fast_dev_run=True,
-            strategy="fsdp_native",
-            accelerator="gpu",
-            devices=1,
-            precision=16,
-        )
-        assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
+@RunIf(min_torch="1.12", min_cuda_gpus=1)
+@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)])
+def test_precision_plugin_config(precision, expected):
+    plugin = FullyShardedNativeMixedPrecisionPlugin(precision=precision, device="cuda")
+    config = plugin.mixed_precision_config
+    assert config.param_dtype == expected
+    assert config.buffer_dtype == expected
+    assert config.reduce_dtype == expected
+
+
+@RunIf(min_torch="1.12")
+def test_fsdp_custom_mixed_precision(tmpdir):
+    """Test to ensure that passing a custom mixed precision config works."""
+    config = MixedPrecision()
+    strategy = DDPFullyShardedNativeStrategy(mixed_precision=config)
+    assert strategy.mixed_precision_config == config
 
 
 class TestFSDPModel(BoringModel):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self):
+        super().__init__()
         self.layer: Optional[torch.nn.Module] = None
 
     def _init_model(self) -> None:
@@ -79,16 +78,20 @@ class TestFSDPModel(BoringModel):
     def configure_optimizers(self):
         return torch.optim.SGD(self.layer.parameters(), lr=0.1)
 
-    def on_train_start(self) -> None:
+    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
         self._assert_layer_fsdp_instance()
 
-    def on_test_start(self) -> None:
+    def on_test_batch_end(
+        self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         self._assert_layer_fsdp_instance()
 
-    def on_validation_start(self) -> None:
+    def on_validation_batch_end(
+        self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         self._assert_layer_fsdp_instance()
 
-    def on_prediction_start(self) -> None:
+    def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         self._assert_layer_fsdp_instance()
 
     def _assert_layer_fsdp_instance(self) -> None:
@@ -101,8 +104,13 @@ class TestFSDPModel(BoringModel):
         assert self.layer.module[0].reshard_after_forward is True
         assert self.layer.module[2].reshard_after_forward is True
 
+        precision = torch.float16 if self.precision == 16 else torch.bfloat16
+        assert self.layer.mixed_precision.param_dtype == precision
+        assert self.layer.mixed_precision.reduce_dtype == precision
+        assert self.layer.mixed_precision.buffer_dtype == precision
 
-@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12dev")
+
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
 def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
     """Test to ensure that sync_batchnorm works when using fsdp_native and GPU, and all stages can be run."""
@@ -119,18 +127,19 @@ def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
     _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))
 
 
-@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12dev")
-def test_fully_sharded_native_strategy_checkpoint(tmpdir):
+@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12")
+@pytest.mark.parametrize("precision", [16, "bf16"])
+def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision):
     """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
 
     model = TestFSDPModel()
     trainer = Trainer(
-        default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy="fsdp_native", precision=16, max_epochs=1
+        default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy="fsdp_native", precision=precision, max_epochs=1
     )
 
     _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))
 
 
-@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12dev")
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
 def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir):
     """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run."""
@@ -150,7 +159,7 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
     trainer.fit(model)
-
+    model_path = trainer.strategy.broadcast(model_path)
     model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path
 
     trainer.save_checkpoint(model_path, weights_only=True)
@@ -158,7 +167,7 @@
     _assert_save_equality(trainer, model_path, cls=TestFSDPModel)
 
     # Test entry point
-    trainer.test(model)  # model is wrapped, will not call configure_shared_model
+    trainer.test(model)  # model is wrapped, will not call `configure_sharded_model`
 
     # provide model path, will create a new unwrapped model and load and then call configure_shared_model to wrap
     trainer.test(ckpt_path=model_path)
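Finally, a sketch of the override behaviour that `test_fsdp_custom_mixed_precision` pins down: a user-supplied `MixedPrecision` object takes precedence over the config derived from the `precision` flag (assumes PyTorch >= 1.12 and a CUDA machine; the bf16/fp32 split is illustrative):

```python
import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

# keep parameters and reductions in bf16 but leave buffers in fp32; the strategy forwards
# this object untouched instead of building one from `precision=16`
custom = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
strategy = DDPFullyShardedNativeStrategy(mixed_precision=custom)
assert strategy.mixed_precision_config is custom

trainer = Trainer(accelerator="gpu", devices=1, strategy=strategy, precision=16)
```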