commit 0e619a490e (parent 096b063d6e)
Author: awaelchli
Date: 2024-03-09 03:38:55 +01:00

2 changed files with 14 additions and 5 deletions

@@ -155,6 +155,15 @@ class BitsandbytesPrecision(Precision):
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
 
+    @property
+    def mixed_precision_config(self):
+        from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
+        return TorchMixedPrecision(
+            param_dtype=torch.bfloat16,
+            reduce_dtype=torch.bfloat16,
+            buffer_dtype=torch.bfloat16,
+        )
+
 
 def _quantize_on_load_hook(quantize_fn: Callable[[torch.Tensor], None], state_dict: OrderedDict, *_: Any) -> None:
     # There is only one key that ends with `*.weight`, the other one is the bias

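For reference, a minimal sketch of what the new property exposes. This snippet is not part of the commit; it assumes bitsandbytes is installed, and the mode="nf4" / dtype arguments are example values only.

import torch
from lightning.fabric.plugins import BitsandbytesPrecision

# Example plugin; any supported bitsandbytes quantization mode would do.
precision = BitsandbytesPrecision(mode="nf4", dtype=torch.bfloat16)

# The property added above returns a torch MixedPrecision config that requests
# bfloat16 for parameters, gradient reduction, and buffers.
config = precision.mixed_precision_config
print(config.param_dtype, config.reduce_dtype, config.buffer_dtype)
# torch.bfloat16 torch.bfloat16 torch.bfloat16
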
@@ -222,8 +222,8 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         if self.mixed_precision:
             return self.mixed_precision
         plugin = self.precision
-        if isinstance(plugin, FSDPPrecision):
-            return plugin.mixed_precision_config
+        # if isinstance(plugin, FSDPPrecision):
+        return plugin.mixed_precision_config
         return None
 
     @property
@@ -231,15 +231,15 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
     def precision(self) -> FSDPPrecision:
         plugin = self._precision
         if plugin is not None:
-            assert isinstance(plugin, FSDPPrecision)
+            # assert isinstance(plugin, FSDPPrecision)
             return plugin
         return FSDPPrecision("32-true")
 
     @precision.setter
     @override
     def precision(self, precision: Optional[FSDPPrecision]) -> None:
-        if precision is not None and not isinstance(precision, FSDPPrecision):
-            raise TypeError(f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision}")
+        # if precision is not None and not isinstance(precision, FSDPPrecision):
+        #     raise TypeError(f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision}")
         self._precision = precision
 
     @override
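
Taken together with the new property above, relaxing these checks lets the FSDP strategy accept a precision plugin that is not an FSDPPrecision instance, such as BitsandbytesPrecision. A rough usage sketch, not part of the commit; the accelerator, device count, and quantization mode are placeholder values:

import torch
import lightning as L
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy

# Before this commit, the strategy's precision setter raised a TypeError for any
# plugin that was not an FSDPPrecision; with the checks commented out, a
# bitsandbytes precision plugin can be combined with FSDP sharding.
plugins = BitsandbytesPrecision(mode="nf4-dq", dtype=torch.bfloat16)
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=FSDPStrategy(), plugins=plugins)
fabric.launch()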