update BitsandBytes version (#20313)
* upgrade requirements.txt * update fabric bitsandbytes linear quantization for bnb 0.44.1 * add quant_storage param * exclude macos from bnb upgrade * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
66508ff4b7
commit
5dea36c5e2
|
@ -6,4 +6,5 @@
|
||||||
# note: there is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods`
|
# note: there is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods`
|
||||||
# shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372
|
# shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372
|
||||||
deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" and platform_system != "Darwin" # strict
|
deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" and platform_system != "Darwin" # strict
|
||||||
bitsandbytes >=0.42.0,<0.43.0
|
bitsandbytes >=0.44.0,<0.44.2; sys_platform == 'linux' or sys_platform == 'win32'
|
||||||
|
bitsandbytes >=0.42.0,<0.43.0 ; sys_platform == 'darwin'
|
||||||
|
|
|
@ -8,4 +8,5 @@ hydra-core >=1.2.0, <1.4.0
|
||||||
jsonargparse[signatures] >=4.27.7, <4.28.0
|
jsonargparse[signatures] >=4.27.7, <4.28.0
|
||||||
rich >=12.3.0, <13.6.0
|
rich >=12.3.0, <13.6.0
|
||||||
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
|
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
|
||||||
bitsandbytes >=0.42.0,<0.43.0
|
bitsandbytes >=0.44.0,<0.44.2; sys_platform == 'linux' or sys_platform == 'win32'
|
||||||
|
bitsandbytes >=0.42.0,<0.43.0 ; sys_platform == 'darwin'
|
||||||
|
|
|
@ -43,7 +43,7 @@ _BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes>=0.42.0")
|
||||||
|
|
||||||
|
|
||||||
class BitsandbytesPrecision(Precision):
|
class BitsandbytesPrecision(Precision):
|
||||||
"""Plugin for quantizing weights with `bitsandbytes <https://github.com/TimDettmers/bitsandbytes>`__.
|
"""Plugin for quantizing weights with `bitsandbytes <https://github.com/bitsandbytes-foundation/bitsandbytes>`__.
|
||||||
|
|
||||||
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
|
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
|
||||||
|
|
||||||
|
@ -184,11 +184,15 @@ def _replace_param(
|
||||||
if param.device.type == "meta":
|
if param.device.type == "meta":
|
||||||
if isinstance(param, bnb.nn.Params4bit):
|
if isinstance(param, bnb.nn.Params4bit):
|
||||||
return bnb.nn.Params4bit(
|
return bnb.nn.Params4bit(
|
||||||
data,
|
data=data,
|
||||||
requires_grad=data.requires_grad,
|
requires_grad=data.requires_grad,
|
||||||
quant_state=quant_state,
|
quant_state=quant_state,
|
||||||
|
blocksize=param.blocksize,
|
||||||
compress_statistics=param.compress_statistics,
|
compress_statistics=param.compress_statistics,
|
||||||
quant_type=param.quant_type,
|
quant_type=param.quant_type,
|
||||||
|
quant_storage=param.quant_storage,
|
||||||
|
module=param.module,
|
||||||
|
bnb_quantized=param.bnb_quantized,
|
||||||
)
|
)
|
||||||
return torch.nn.Parameter(data, requires_grad=data.requires_grad)
|
return torch.nn.Parameter(data, requires_grad=data.requires_grad)
|
||||||
param.data = data
|
param.data = data
|
||||||
|
@ -322,6 +326,7 @@ def _import_bitsandbytes() -> ModuleType:
|
||||||
return
|
return
|
||||||
assert isinstance(self.weight, bnb.nn.Params4bit)
|
assert isinstance(self.weight, bnb.nn.Params4bit)
|
||||||
self.weight = self.quantize(self.weight, weight, device)
|
self.weight = self.quantize(self.weight, weight, device)
|
||||||
|
self.weight.bnb_quantized = True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def quantize(
|
def quantize(
|
||||||
|
@ -337,6 +342,7 @@ def _import_bitsandbytes() -> ModuleType:
|
||||||
blocksize=params4bit.blocksize,
|
blocksize=params4bit.blocksize,
|
||||||
compress_statistics=params4bit.compress_statistics,
|
compress_statistics=params4bit.compress_statistics,
|
||||||
quant_type=params4bit.quant_type,
|
quant_type=params4bit.quant_type,
|
||||||
|
quant_storage=params4bit.quant_storage,
|
||||||
)
|
)
|
||||||
return _replace_param(params4bit, w_4bit, quant_state)
|
return _replace_param(params4bit, w_4bit, quant_state)
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ from lightning.pytorch.plugins.precision.precision import Precision
|
||||||
|
|
||||||
|
|
||||||
class BitsandbytesPrecision(Precision, FabricBNBPrecision):
|
class BitsandbytesPrecision(Precision, FabricBNBPrecision):
|
||||||
"""Plugin for quantizing weights with `bitsandbytes <https://github.com/TimDettmers/bitsandbytes>`__.
|
"""Plugin for quantizing weights with `bitsandbytes <https://github.com/bitsandbytes-foundation/bitsandbytes>`__.
|
||||||
|
|
||||||
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
|
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue