hack
This commit is contained in:
parent
096b063d6e
commit
0e619a490e
|
@ -155,6 +155,15 @@ class BitsandbytesPrecision(Precision):
|
|||
def convert_output(self, data: Any) -> Any:
|
||||
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
|
||||
|
||||
@property
|
||||
def mixed_precision_config(self):
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
|
||||
return TorchMixedPrecision(
|
||||
param_dtype=torch.bfloat16,
|
||||
reduce_dtype=torch.bfloat16,
|
||||
buffer_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
|
||||
def _quantize_on_load_hook(quantize_fn: Callable[[torch.Tensor], None], state_dict: OrderedDict, *_: Any) -> None:
|
||||
# There is only one key that ends with `*.weight`, the other one is the bias
|
||||
|
|
|
@ -222,8 +222,8 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
|
|||
if self.mixed_precision:
|
||||
return self.mixed_precision
|
||||
plugin = self.precision
|
||||
if isinstance(plugin, FSDPPrecision):
|
||||
return plugin.mixed_precision_config
|
||||
# if isinstance(plugin, FSDPPrecision):
|
||||
return plugin.mixed_precision_config
|
||||
return None
|
||||
|
||||
@property
|
||||
|
@ -231,15 +231,15 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
|
|||
def precision(self) -> FSDPPrecision:
|
||||
plugin = self._precision
|
||||
if plugin is not None:
|
||||
assert isinstance(plugin, FSDPPrecision)
|
||||
# assert isinstance(plugin, FSDPPrecision)
|
||||
return plugin
|
||||
return FSDPPrecision("32-true")
|
||||
|
||||
@precision.setter
|
||||
@override
|
||||
def precision(self, precision: Optional[FSDPPrecision]) -> None:
|
||||
if precision is not None and not isinstance(precision, FSDPPrecision):
|
||||
raise TypeError(f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision}")
|
||||
# if precision is not None and not isinstance(precision, FSDPPrecision):
|
||||
# raise TypeError(f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision}")
|
||||
self._precision = precision
|
||||
|
||||
@override
|
||||
|
|
Loading…
Reference in New Issue