update BitsandBytes version (#20313)

* upggrade requiremnets.txt

* update fabric bitsandbytes linear quantization for bnb 0.44.1

* add quant_storage param

* exclude macos from bnb upgrade

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ali Alshaarawy 2024-10-02 12:18:43 -04:00 committed by GitHub
parent 66508ff4b7
commit 5dea36c5e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 13 additions and 5 deletions

View File

@ -6,4 +6,5 @@
# note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods` # note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods`
# shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372 # shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372
deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" and platform_system != "Darwin" # strict deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" and platform_system != "Darwin" # strict
bitsandbytes >=0.42.0,<0.43.0 bitsandbytes >=0.44.0,<0.44.2; sys_platform == 'linux' or sys_platform == 'win32'
bitsandbytes >=0.42.0,<0.43.0 ; sys_platform == 'darwin'

View File

@ -8,4 +8,5 @@ hydra-core >=1.2.0, <1.4.0
jsonargparse[signatures] >=4.27.7, <4.28.0 jsonargparse[signatures] >=4.27.7, <4.28.0
rich >=12.3.0, <13.6.0 rich >=12.3.0, <13.6.0
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
bitsandbytes >=0.42.0,<0.43.0 bitsandbytes >=0.44.0,<0.44.2; sys_platform == 'linux' or sys_platform == 'win32'
bitsandbytes >=0.42.0,<0.43.0 ; sys_platform == 'darwin'

View File

@ -43,7 +43,7 @@ _BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes>=0.42.0")
class BitsandbytesPrecision(Precision): class BitsandbytesPrecision(Precision):
"""Plugin for quantizing weights with `bitsandbytes <https://github.com/TimDettmers/bitsandbytes>`__. """Plugin for quantizing weights with `bitsandbytes <https://github.com/bitsandbytes-foundation/bitsandbytes>`__.
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature. .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
@ -184,11 +184,15 @@ def _replace_param(
if param.device.type == "meta": if param.device.type == "meta":
if isinstance(param, bnb.nn.Params4bit): if isinstance(param, bnb.nn.Params4bit):
return bnb.nn.Params4bit( return bnb.nn.Params4bit(
data, data=data,
requires_grad=data.requires_grad, requires_grad=data.requires_grad,
quant_state=quant_state, quant_state=quant_state,
blocksize=param.blocksize,
compress_statistics=param.compress_statistics, compress_statistics=param.compress_statistics,
quant_type=param.quant_type, quant_type=param.quant_type,
quant_storage=param.quant_storage,
module=param.module,
bnb_quantized=param.bnb_quantized,
) )
return torch.nn.Parameter(data, requires_grad=data.requires_grad) return torch.nn.Parameter(data, requires_grad=data.requires_grad)
param.data = data param.data = data
@ -322,6 +326,7 @@ def _import_bitsandbytes() -> ModuleType:
return return
assert isinstance(self.weight, bnb.nn.Params4bit) assert isinstance(self.weight, bnb.nn.Params4bit)
self.weight = self.quantize(self.weight, weight, device) self.weight = self.quantize(self.weight, weight, device)
self.weight.bnb_quantized = True
@staticmethod @staticmethod
def quantize( def quantize(
@ -337,6 +342,7 @@ def _import_bitsandbytes() -> ModuleType:
blocksize=params4bit.blocksize, blocksize=params4bit.blocksize,
compress_statistics=params4bit.compress_statistics, compress_statistics=params4bit.compress_statistics,
quant_type=params4bit.quant_type, quant_type=params4bit.quant_type,
quant_storage=params4bit.quant_storage,
) )
return _replace_param(params4bit, w_4bit, quant_state) return _replace_param(params4bit, w_4bit, quant_state)

View File

@ -16,7 +16,7 @@ from lightning.pytorch.plugins.precision.precision import Precision
class BitsandbytesPrecision(Precision, FabricBNBPrecision): class BitsandbytesPrecision(Precision, FabricBNBPrecision):
"""Plugin for quantizing weights with `bitsandbytes <https://github.com/TimDettmers/bitsandbytes>`__. """Plugin for quantizing weights with `bitsandbytes <https://github.com/bitsandbytes-foundation/bitsandbytes>`__.
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature. .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.