[FIX] Native FSDP precision + tests (#12985)
parent c028ff3b95
commit d78698528d
@@ -11,10 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, Optional
+
+import torch
 
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+
+if _TORCH_GREATER_EQUAL_1_12:
+    from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+else:
+    MixedPrecision = None
 
 
 class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
@@ -29,3 +38,18 @@ class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
         raise MisconfigurationException(
             f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
         )
+
+    @property
+    def mixed_precision_config(self) -> Optional[MixedPrecision]:
+        assert MixedPrecision is not None
+        if self.precision == PrecisionType.HALF:
+            dtype = torch.float16
+        elif self.precision == PrecisionType.BFLOAT:
+            dtype = torch.bfloat16
+        else:
+            raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
+        return MixedPrecision(
+            param_dtype=dtype,
+            reduce_dtype=dtype,
+            buffer_dtype=dtype,
+        )
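The two hunks above touch the native-FSDP precision plugin (module `pytorch_lightning/plugins/precision/fully_sharded_native_amp.py`): `FullyShardedNativeMixedPrecisionPlugin` gains a `mixed_precision_config` property that translates Lightning's `precision` flag into an FSDP `MixedPrecision` config. A minimal sketch of the intended behaviour, mirroring the new `test_precision_plugin_config` test further down; it assumes PyTorch >= 1.12 and a CUDA-enabled environment:

import torch

from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin

# precision=16 maps parameters, gradient reduction and buffers to float16
plugin = FullyShardedNativeMixedPrecisionPlugin(precision=16, device="cuda")
config = plugin.mixed_precision_config
assert config.param_dtype == config.reduce_dtype == config.buffer_dtype == torch.float16

# precision="bf16" maps them to bfloat16 instead
plugin = FullyShardedNativeMixedPrecisionPlugin(precision="bf16", device="cuda")
assert plugin.mixed_precision_config.param_dtype == torch.bfloat16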
@@ -23,6 +23,8 @@ import pytorch_lightning as pl
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
+from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import TrainerFn
@@ -35,18 +37,23 @@ from pytorch_lightning.utilities.distributed import (
 from pytorch_lightning.utilities.distributed import group as _group
 from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
+from pytorch_lightning.utilities.rank_zero import rank_zero_info
 from pytorch_lightning.utilities.seed import reset_seed
 
-if _TORCH_GREATER_EQUAL_1_11:
+if _TORCH_GREATER_EQUAL_1_12:
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
         BackwardPrefetch,
         CPUOffload,
         FullyShardedDataParallel,
+        MixedPrecision,
     )
     from torch.distributed.fsdp.wrap import enable_wrap
-
+else:
+    MixedPrecision = None
+    BackwardPrefetch = None  # type: ignore[misc,assignment]
+    CPUOffload = None  # type: ignore[misc,assignment]
 
 log = logging.getLogger(__name__)
 
@@ -56,7 +63,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
     strategy_name = "fsdp_native"
     _registered_strategies: List[str] = []
 
-    def __init__(  # type: ignore[no-untyped-def]
+    def __init__(
         self,
         accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
         parallel_devices: Optional[List[torch.device]] = None,
@@ -64,10 +71,12 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
         process_group_backend: Optional[str] = None,
-        cpu_offload=None,
-        backward_prefetch=None,
+        cpu_offload: Optional[CPUOffload] = None,
+        backward_prefetch: Optional[BackwardPrefetch] = None,
+        mixed_precision: Optional[MixedPrecision] = None,
+        **kwargs: Any,
     ) -> None:
-        """Strategy for Fully Sharded Data Parallel provided by torch.Distributed.
+        r"""Strategy for Fully Sharded Data Parallel provided by torch.Distributed.
 
         Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
         size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
@@ -84,7 +93,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
         `https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html`
 
         Arguments:
-            cpu_offload (Optional [CPUOffload]):
+            cpu_offload:
                 CPU offloading config. Currently, only parameter and gradient CPU
                 offload is supported. It can be enabled via passing in
                 ``cpu_offload=CPUOffload(offload_params=True)``. Note that this
@@ -92,14 +101,21 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
                 params and grads to be on same device to work with optimizer. This
                 API is subject to change. Default is ``None`` in which case there
                 will be no offloading.
-            backward_prefetch: (Optional[BackwardPrefetch]):
+            backward_prefetch:
                 This is an experimental feature that is subject to change in the
                 the near future. It allows users to enable two different backward_prefetch
                 algorithms to help backward communication and computation overlapping.
                 Pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
+            mixed_precision:
+                Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
+                or BF16 if ``precision=bf16`` unless a config is passed in.
+                This is only available in PyTorch 1.12 and later.
+            \**kwargs: Passed to the FSDP Context manager which will configure the FSDP class when wrapping modules.
         """
-        if not _TORCH_GREATER_EQUAL_1_11:
-            raise MisconfigurationException("DDPFullyShardedNativeStrategy is supported from pytorch v1.11.0 onwards.")
+        if not _TORCH_GREATER_EQUAL_1_12:
+            raise MisconfigurationException(
+                "`DDPFullyShardedNativeStrategy` is supported from PyTorch v1.12.0 onwards."
+            )
 
         super().__init__(
             accelerator=accelerator,
@@ -109,16 +125,23 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
             precision_plugin=precision_plugin,
         )
         self._process_group = None
-        self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
-        self._process_group_backend: Optional[str] = process_group_backend
-        self.cpu_offload: Optional[CPUOffload] = cpu_offload
-        self.backward_prefetch: Optional[BackwardPrefetch] = backward_prefetch
+        self.num_nodes = 1
+        self._process_group_backend = process_group_backend
+        self.cpu_offload = cpu_offload
+        self.backward_prefetch = backward_prefetch
+        self.mixed_precision = mixed_precision
+        self._rank_0_will_call_children_scripts: bool = False
+        self.kwargs = kwargs
 
     @property
     def root_device(self) -> torch.device:
         assert self.parallel_devices is not None
         return self.parallel_devices[self.local_rank]
 
+    @property
+    def num_processes(self) -> int:
+        return len(self.parallel_devices) if self.parallel_devices is not None else 0
+
     @property
     def process_group(self) -> Optional[ProcessGroup]:
         if self._process_group is None:
@@ -130,10 +153,28 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
     def process_group_backend(self) -> Optional[str]:
         return self._process_group_backend
 
+    @property
+    def mixed_precision_config(self) -> Optional[MixedPrecision]:
+        if self.mixed_precision:
+            return self.mixed_precision
+        plugin = self.precision_plugin
+        if isinstance(plugin, FullyShardedNativeMixedPrecisionPlugin):
+            return plugin.mixed_precision_config
+
+    @property
+    def distributed_sampler_kwargs(self) -> Dict:
+        return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
+
     def setup_environment(self) -> None:
         log.detail(f"{self.__class__.__name__}: setting up distributed...")
         reset_seed()
 
+        # determine which process we are and world size
+        self.set_world_ranks()
+
+        # set warning rank
+        rank_zero_only.rank = self.global_rank
+
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         init_dist_connection(self.cluster_environment, self._process_group_backend)
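The strategy-level `mixed_precision_config` property added above resolves the FSDP config with a simple precedence: an explicit `mixed_precision=...` argument wins, otherwise the config is derived from an attached `FullyShardedNativeMixedPrecisionPlugin`, and `None` (full precision) is returned when neither is configured. A short sketch of the first case, modelled on the `test_fsdp_custom_mixed_precision` test added below (assumes PyTorch >= 1.12):

from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision

from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

# a user-supplied config is returned unchanged ...
config = MixedPrecision()
strategy = DDPFullyShardedNativeStrategy(mixed_precision=config)
assert strategy.mixed_precision_config == config
# ... otherwise the property falls back to whatever the precision plugin derives
# from `precision=16` / `precision="bf16"` once the Trainer attaches it.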
@@ -146,15 +187,31 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
             or get_default_process_group_backend_for_device(self.root_device)
         )
 
+    def set_world_ranks(self) -> None:
+        if self.cluster_environment is None:
+            return
+        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        rank_zero_only.rank = self.cluster_environment.global_rank()
+
+    def _configure_launcher(self) -> None:
+        assert self.cluster_environment is not None
+        if not self.cluster_environment.creates_processes_externally:
+            self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
+            self._rank_0_will_call_children_scripts = True
+
     def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)
+        # share ddp pids to all processes
+        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
 
         if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
             assert self.model is not None
             self.model = self._layer_sync.apply(self.model)
 
-        if not self.cpu_offload:
-            self.model_to_device()
+        # we set the device so that optimizers can be created with distributed comms.
+        assert self.lightning_module is not None
+        self.lightning_module._device = self.root_device
 
         self.barrier()
         self.setup_optimizers(trainer)
@@ -162,20 +219,19 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
         self.setup_precision_plugin()
 
     def model_to_device(self) -> None:
-        # ensure we update the device type in the lightning module
-        assert self.lightning_module is not None
-        log.info(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
-        self.lightning_module.to(self.root_device)
+        pass
 
     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator:
         log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")
 
         with enable_wrap(
             wrapper_cls=FullyShardedDataParallel,
             process_group=self.process_group,
             cpu_offload=self.cpu_offload,
             backward_prefetch=self.backward_prefetch,
+            mixed_precision=self.mixed_precision_config,
+            device_id=self.root_device.index,
+            **self.kwargs,
         ):
             yield
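With `model_sharded_context` now forwarding `mixed_precision`, `device_id` and the pass-through `**kwargs` to `enable_wrap`, any `wrap()` call made inside a LightningModule's `configure_sharded_model` hook inherits those settings. A hypothetical module as a sketch (`DemoFSDPModel` is a name invented here, modelled on the `TestFSDPModel` used in the tests below):

import torch.nn as nn
from torch.distributed.fsdp.wrap import wrap

from pytorch_lightning.demos.boring_classes import BoringModel


class DemoFSDPModel(BoringModel):
    def configure_sharded_model(self) -> None:
        # called inside `model_sharded_context`, so the wrapped submodules pick up the
        # strategy's process group, cpu_offload, backward_prefetch and mixed_precision
        self.layer = wrap(nn.Sequential(wrap(nn.Linear(32, 32)), nn.ReLU(), wrap(nn.Linear(32, 2))))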
@@ -219,7 +275,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
         return [self.root_device.index]
 
     def teardown(self) -> None:
-        log.info(f"{self.__class__.__name__}: tearing down strategy...")
+        rank_zero_info(f"{self.__class__.__name__}: tearing down strategy...")
         if (
             self.lightning_module is not None
             and self.lightning_module.trainer is not None
@@ -229,7 +285,10 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
             assert self.model is not None
             self.model = self._layer_sync.revert(self.model)
 
-        super().teardown()
+        assert self.cluster_environment is not None
+        self.cluster_environment.teardown()
+        self.precision_plugin.teardown()
+        self.accelerator.teardown()
 
     @classmethod
     def get_registered_strategies(cls) -> List[str]:
@@ -237,7 +296,7 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
 
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
-        if _TORCH_GREATER_EQUAL_1_11:
+        if _TORCH_GREATER_EQUAL_1_12:
            strategy_registry.register(
                 "fsdp_native",
                 cls,
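The hunks above are the strategy itself (`DDPFullyShardedNativeStrategy`, registered as "fsdp_native"): its FSDP options are now typed, a `mixed_precision` argument plus pass-through `**kwargs` are accepted, the minimum requirement moves from PyTorch 1.11 to 1.12, and teardown no longer relies on `super().teardown()`. A hedged sketch of how the extended constructor is meant to be used; it assumes PyTorch >= 1.12 and a multi-GPU machine:

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload, MixedPrecision

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

strategy = DDPFullyShardedNativeStrategy(
    cpu_offload=CPUOffload(offload_params=True),      # offload params/grads to CPU
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,   # overlap backward comm and compute
    mixed_precision=MixedPrecision(                    # overrides what `precision=16`/"bf16" would pick
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ),
)
trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_epochs=1)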
@@ -700,17 +700,13 @@ class AcceleratorConnector:
                 if self._precision_flag == 16
                 else "Using bfloat16 Automatic Mixed Precision (AMP)"
             )
-            if isinstance(self.strategy, DDPFullyShardedNativeStrategy):
-                raise MisconfigurationException(
-                    "DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision"
-                )
 
             if self._amp_type_flag == AMPType.NATIVE:
                 device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
 
                 if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
                     return ShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
-                if isinstance(self.strategy, DDPFullyShardedStrategy):
+                if isinstance(self.strategy, (DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy)):
                     return FullyShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
                 return NativeMixedPrecisionPlugin(self._precision_flag, device)
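The connector hunk above (in `AcceleratorConnector`) drops the hard error for AMP with the native FSDP strategy and instead routes it to `FullyShardedNativeMixedPrecisionPlugin`. A sketch of the resulting behaviour; it assumes a CUDA machine with PyTorch >= 1.12:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy

trainer = Trainer(accelerator="gpu", devices=1, strategy="fsdp_native", precision=16)
assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
# previously this combination raised "currently doesn't support Mixed Precision";
# now the FSDP-aware precision plugin is selected instead
assert isinstance(trainer.strategy.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)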
@@ -574,7 +574,7 @@ def test_strategy_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock
     assert trainer.strategy.local_rank == 0
 
 
-@RunIf(min_torch="1.11")
+@RunIf(min_torch="1.12")
 def test_check_native_fsdp_strategy_and_fallback():
     with pytest.raises(
         MisconfigurationException,
@@ -584,25 +584,6 @@ def test_check_native_fsdp_strategy_and_fallback():
         Trainer(accelerator="cpu", strategy="fsdp_native")
 
 
-@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
-@mock.patch("torch.cuda.device_count", return_value=1)
-@mock.patch("torch.cuda.is_available", return_value=True)
-@RunIf(min_torch="1.11")
-def test_mixed_precision_support_with_native_fsdp_strategy(device_count_mock, mock_cuda_available, tmpdir):
-    with pytest.raises(
-        MisconfigurationException, match="DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision"
-    ):
-        trainer = Trainer(
-            default_root_dir=tmpdir,
-            fast_dev_run=True,
-            strategy="fsdp_native",
-            accelerator="gpu",
-            devices=1,
-            precision=16,
-        )
-        assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
-
-
 @mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
 def test_unsupported_tpu_choice(mock_tpu_acc_avail):
@@ -1,6 +1,5 @@
 import os
 from typing import Any, Dict, Optional
-from unittest import mock
 
 import pytest
 import torch
@@ -8,19 +7,21 @@ import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
 from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 from tests_pytorch.helpers.runif import RunIf
 
 if _TORCH_GREATER_EQUAL_1_12:
-    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
+    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel, MixedPrecision
     from torch.distributed.fsdp.wrap import wrap
 
 
-@RunIf(min_torch="1.12dev")
+@RunIf(min_torch="1.12")
 def test_invalid_on_cpu(tmpdir):
-    """Test to ensure that to raise Misconfiguration for Native FSDP on CPU."""
+    """Test to ensure that we raise Misconfiguration for Native FSDP on CPU."""
     with pytest.raises(
         MisconfigurationException,
         match=f"You selected strategy to be `{DDPFullyShardedNativeStrategy.strategy_name}`, 
@@ -31,29 +32,27 @@ def test_invalid_on_cpu(tmpdir):
         trainer.strategy.setup_environment()
 
 
-@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
-@mock.patch("torch.cuda.device_count", return_value=1)
-@mock.patch("torch.cuda.is_available", return_value=True)
-@RunIf(min_torch="1.12dev")
-def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir):
-    """Test to ensure that plugin native amp plugin raises Misconfiguration error."""
-    with pytest.raises(
-        MisconfigurationException, match="DDPFullyShardedNativeStrategy currently doesn't support Mixed Precision"
-    ):
-        trainer = Trainer(
-            default_root_dir=tmpdir,
-            fast_dev_run=True,
-            strategy="fsdp_native",
-            accelerator="gpu",
-            devices=1,
-            precision=16,
-        )
-        assert isinstance(trainer.strategy, DDPFullyShardedNativeStrategy)
+@RunIf(min_torch="1.12", min_cuda_gpus=1)
+@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)])
+def test_precision_plugin_config(precision, expected):
+    plugin = FullyShardedNativeMixedPrecisionPlugin(precision=precision, device="cuda")
+    config = plugin.mixed_precision_config
+    assert config.param_dtype == expected
+    assert config.buffer_dtype == expected
+    assert config.reduce_dtype == expected
+
+
+@RunIf(min_torch="1.12")
+def test_fsdp_custom_mixed_precision(tmpdir):
+    """Test to ensure that passing a custom mixed precision config works."""
+    config = MixedPrecision()
+    strategy = DDPFullyShardedNativeStrategy(mixed_precision=config)
+    assert strategy.mixed_precision_config == config
 
 
 class TestFSDPModel(BoringModel):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self):
+        super().__init__()
         self.layer: Optional[torch.nn.Module] = None
 
     def _init_model(self) -> None:
@@ -79,16 +78,20 @@ class TestFSDPModel(BoringModel):
     def configure_optimizers(self):
         return torch.optim.SGD(self.layer.parameters(), lr=0.1)
 
-    def on_train_start(self) -> None:
+    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
         self._assert_layer_fsdp_instance()
 
-    def on_test_start(self) -> None:
+    def on_test_batch_end(
+        self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         self._assert_layer_fsdp_instance()
 
-    def on_validation_start(self) -> None:
+    def on_validation_batch_end(
+        self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         self._assert_layer_fsdp_instance()
 
-    def on_prediction_start(self) -> None:
+    def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         self._assert_layer_fsdp_instance()
 
     def _assert_layer_fsdp_instance(self) -> None:
@@ -101,8 +104,13 @@ class TestFSDPModel(BoringModel):
         assert self.layer.module[0].reshard_after_forward is True
         assert self.layer.module[2].reshard_after_forward is True
 
+        precision = torch.float16 if self.precision == 16 else torch.bfloat16
+        assert self.layer.mixed_precision.param_dtype == precision
+        assert self.layer.mixed_precision.reduce_dtype == precision
+        assert self.layer.mixed_precision.buffer_dtype == precision
+
 
-@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12dev")
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
 def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
     """Test to ensure that sync_batchnorm works when using fsdp_native and GPU, and all stages can be run."""
@@ -119,18 +127,19 @@ def test_fully_sharded_native_strategy_sync_batchnorm(tmpdir):
     _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))
 
 
-@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12dev")
-def test_fully_sharded_native_strategy_checkpoint(tmpdir):
+@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12")
+@pytest.mark.parametrize("precision", [16, "bf16"])
+def test_fully_sharded_native_strategy_checkpoint(tmpdir, precision):
     """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
 
     model = TestFSDPModel()
     trainer = Trainer(
-        default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy="fsdp_native", precision=16, max_epochs=1
+        default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy="fsdp_native", precision=precision, max_epochs=1
     )
     _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))
 
 
-@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12dev")
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
 def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir):
     """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run."""
@@ -150,7 +159,7 @@ def test_fully_sharded_native_strategy_checkpoint_multi_gpus(tmpdir):
 
 def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
     trainer.fit(model)
 
+    model_path = trainer.strategy.broadcast(model_path)
     model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path
 
     trainer.save_checkpoint(model_path, weights_only=True)
@@ -158,7 +167,7 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
     _assert_save_equality(trainer, model_path, cls=TestFSDPModel)
 
     # Test entry point
-    trainer.test(model)  # model is wrapped, will not call configure_shared_model
+    trainer.test(model)  # model is wrapped, will not call `configure_sharded_model`
 
     # provide model path, will create a new unwrapped model and load and then call configure_shared_model to wrap
     trainer.test(ckpt_path=model_path)
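The updated `_assert_layer_fsdp_instance` assertions above check that every FSDP-wrapped layer carries the dtype implied by the Trainer's `precision` flag. The same check can be written as a standalone helper; this is a sketch (the function name is invented here) relying only on the `FullyShardedDataParallel.mixed_precision` attribute that the test itself reads:

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel


def assert_fsdp_mixed_precision(model: torch.nn.Module, dtype: torch.dtype) -> None:
    """Walk the module tree and verify the mixed-precision dtypes of every FSDP wrapper."""
    for module in model.modules():
        if isinstance(module, FullyShardedDataParallel):
            assert module.mixed_precision.param_dtype == dtype
            assert module.mixed_precision.reduce_dtype == dtype
            assert module.mixed_precision.buffer_dtype == dtype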