From 0e619a490e7f2c94cdf03eeae1c594eda7f3b9a3 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Sat, 9 Mar 2024 03:38:55 +0100
Subject: [PATCH] hack

---
 src/lightning/fabric/plugins/precision/bitsandbytes.py |  9 +++++++++
 src/lightning/fabric/strategies/fsdp.py                | 10 +++++-----
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py
index 12a0ac3998..a41f387b2f 100644
--- a/src/lightning/fabric/plugins/precision/bitsandbytes.py
+++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -155,6 +155,15 @@ class BitsandbytesPrecision(Precision):
     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
 
+    @property
+    def mixed_precision_config(self):
+        from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
+        return TorchMixedPrecision(
+            param_dtype=torch.bfloat16,
+            reduce_dtype=torch.bfloat16,
+            buffer_dtype=torch.bfloat16,
+        )
+
 
 def _quantize_on_load_hook(quantize_fn: Callable[[torch.Tensor], None], state_dict: OrderedDict, *_: Any) -> None:
     # There is only one key that ends with `*.weight`, the other one is the bias
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index ed89629f72..bf85d6c4c5 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -222,8 +222,8 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         if self.mixed_precision:
             return self.mixed_precision
         plugin = self.precision
-        if isinstance(plugin, FSDPPrecision):
-            return plugin.mixed_precision_config
+        # if isinstance(plugin, FSDPPrecision):
+        return plugin.mixed_precision_config
         return None
 
     @property
@@ -231,15 +231,15 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
     def precision(self) -> FSDPPrecision:
         plugin = self._precision
         if plugin is not None:
-            assert isinstance(plugin, FSDPPrecision)
+            # assert isinstance(plugin, FSDPPrecision)
             return plugin
         return FSDPPrecision("32-true")
 
     @precision.setter
     @override
     def precision(self, precision: Optional[FSDPPrecision]) -> None:
-        if precision is not None and not isinstance(precision, FSDPPrecision):
-            raise TypeError(f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision}")
+        # if precision is not None and not isinstance(precision, FSDPPrecision):
+        #     raise TypeError(f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision}")
         self._precision = precision
 
     @override
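
Usage note (not part of the patch): a minimal sketch of how the patched combination might be driven, assuming the patch above is applied, a machine with two CUDA GPUs, and that no other connector-level validation rejects the pairing. The model is a placeholder; everything else uses the public Fabric API.

import torch
from lightning.fabric import Fabric
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.strategies import FSDPStrategy

# 4-bit NormalFloat quantization for Linear layers, with bfloat16 compute to
# match the hardcoded mixed_precision_config the patch adds to the plugin.
precision = BitsandbytesPrecision(mode="nf4", dtype=torch.bfloat16)
strategy = FSDPStrategy()  # the patched setter no longer type-checks the plugin

fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, plugins=precision)
fabric.launch()

with fabric.init_module():
    # Placeholder model; bitsandbytes replaces nn.Linear with quantized variants
    model = torch.nn.Sequential(
        torch.nn.Linear(256, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, 2),
    )

model = fabric.setup(model)  # shards the quantized model with FSDP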