# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from unittest.mock import Mock

import lightning.fabric
import pytest
import torch
import torch.distributed
from lightning.fabric import Fabric
from lightning.fabric.connector import _Connector
from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision
from lightning.fabric.utilities.init import _materialize_meta_tensors

from tests_fabric.helpers.runif import RunIf


@pytest.mark.skipif(_BITSANDBYTES_AVAILABLE, reason="bitsandbytes needs to be unavailable")
def test_bitsandbytes_plugin(monkeypatch):
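    """Exercise the plugin logic against a mocked `bitsandbytes` package: connector registration, default-dtype
    handling, linear-layer replacement, and the `ignore_modules` filter."""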
    module = lightning.fabric.plugins.precision.bitsandbytes
    monkeypatch.setattr(module, "_BITSANDBYTES_AVAILABLE", lambda: True)
    bitsandbytes_mock = Mock()
    monkeypatch.setitem(sys.modules, "bitsandbytes", bitsandbytes_mock)

    class ModuleMock(torch.nn.Linear):
        def __init__(self, in_features, out_features, bias=True, *_, **__):
            super().__init__(in_features, out_features, bias)

    bitsandbytes_mock.nn.Linear8bitLt = ModuleMock
    bitsandbytes_mock.nn.Linear4bit = ModuleMock
    bitsandbytes_mock.nn.Params4bit = object

    precision = BitsandbytesPrecision("nf4", dtype=torch.float16)
    connector = _Connector(plugins=precision)
    assert connector.precision is precision
    assert precision.dtype == torch.float16

    # same logic as in `test_default_dtype_is_restored`
    assert torch.get_default_dtype() is torch.float32
    with pytest.raises(RuntimeError, match="foo"), precision.module_init_context():
        assert torch.get_default_dtype() is not torch.float32
        raise RuntimeError("foo")
    assert torch.get_default_dtype() is torch.float32

    class SubModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l = torch.nn.Linear(1, 3)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(16, 48)
            self.l2 = SubModule()

    _NF4Linear = vars(module)["_NF4Linear"]
    quantize_mock = lambda self, p, w, d: p
    _NF4Linear.quantize = quantize_mock

    with precision.module_init_context():
        assert torch.get_default_dtype() == torch.float16
        model = MyModule()
    assert isinstance(model.l1, _NF4Linear)
    assert isinstance(model.l2.l, _NF4Linear)
    model = precision.convert_module(model)
    assert model.l1.compute_dtype is precision.dtype
    assert model.l2.l.compute_dtype is precision.dtype

    model = MyModule()
    precision.convert_module(model)
    assert isinstance(model.l1, _NF4Linear)
    assert isinstance(model.l2.l, _NF4Linear)

    precision.ignore_modules = {"l2"}
    model = MyModule()
    precision.convert_module(model)
    assert isinstance(model.l1, _NF4Linear)
    assert isinstance(model.l2.l, torch.nn.Linear)

    model = torch.nn.Conv1d(1, 1, 1)
    with pytest.raises(TypeError, match="your model has no Linear"):
        precision.convert_module(model)


@RunIf(min_cuda_gpus=1)
@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
@pytest.mark.parametrize(
    ("args", "expected"),
    [
        (("int8", torch.float16), torch.int8),
        (("nf4", torch.bfloat16), torch.uint8),
    ],
)
def test_bitsandbytes_layers(args, expected):
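    """With real bitsandbytes on CUDA: layers are quantized at instantiation inside `init_module`, an unquantized
    state dict can still be loaded, and `ignore_modules` defers quantization to `setup`."""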
    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l = torch.nn.Linear(2, 2)
            self.ln = torch.nn.LayerNorm(2)

    state_dict = MyModel().state_dict()
    fabric = Fabric(devices=1, plugins=BitsandbytesPrecision(*args))
    with fabric.init_module():
        model = MyModel()

    # the model was instantiated on-device and quantized straight away
    assert model.l.weight.device.type == "cuda"
    assert model.l.weight.dtype == expected
    # this has no impact
    model = fabric.setup(model)
    assert model.l.weight.device.type == "cuda"
    assert model.l.weight.dtype == expected
    # unquantized state dict loading still works even though the weights are quantized
    weight_before = model.l.weight.data.clone()
    keys = model.load_state_dict(state_dict, strict=True)
    assert not keys.missing_keys
    assert not torch.equal(weight_before, model.l.weight.data)
    assert model.l.weight.device.type == "cuda"
    assert model.l.weight.dtype == expected
    # quantized state dict can be loaded into a quantized model
    quantized_state_dict = model.state_dict()
    keys = model.load_state_dict(quantized_state_dict, strict=True)
    assert not keys.missing_keys
    # TODO: support unquantizing the state_dict so that it can be loaded into the original model

    fabric = Fabric(devices=1, plugins=BitsandbytesPrecision(*args, ignore_modules={"foo"}))
    with pytest.raises(RuntimeError, match="not supported"), fabric.init_module():
        pass
    model = MyModel()

    # When ignore_modules is set, we only quantize on `setup`
    assert model.l.weight.device.type == "cpu"
    assert model.l.weight.dtype == torch.float32
    # this quantizes now
    model = fabric.setup(model)
    assert model.l.weight.device.type == "cuda"
    assert model.l.weight.dtype == expected


@RunIf(min_cuda_gpus=1, min_torch="2.1")
@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
@pytest.mark.parametrize(
    ("args", "expected"),
    [
        pytest.param(("int8", torch.float16), torch.int8, marks=pytest.mark.xfail(raises=NotImplementedError)),
        pytest.param(("nf4", torch.bfloat16), torch.uint8, marks=RunIf(bf16_cuda=True)),
    ],
)
def test_bitsandbytes_layers_meta_device(args, expected, tmp_path):
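    """Instantiate on the meta device and verify that quantization happens on materialization or on
    `load_state_dict`, including the partial state dict and mmap/`assign=True` paths."""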
    class MyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l = torch.nn.Linear(2, 2)
            self.ln = torch.nn.LayerNorm(2, bias=False)

    state_dict = MyModel().state_dict()
    plugin = BitsandbytesPrecision(*args)
    fabric = Fabric(plugins=plugin, devices=1)

    # case 1
    # empty_init=True with devices=1 doesn't use meta device at the moment so set it explicitly
    with fabric.init_module(empty_init=False), torch.device("meta"):
        model = MyModel()
    # the model was instantiated on meta and is not quantized
    assert model.l.weight.device.type == "meta"
    assert model.l.weight.dtype == args[1]
    # materializing performs quantization
    _materialize_meta_tensors(model, "cuda")
    assert model.l.weight.device.type == "cuda"
    assert model.l.weight.dtype == expected
    # state dict loading still works even though the weights are quantized
    weight_before = model.l.weight.data.clone()
    keys = model.load_state_dict(state_dict, strict=True)
    assert not keys.missing_keys
    assert not torch.equal(weight_before, model.l.weight.data)
    assert model.l.weight.device.type == "cuda"
    assert model.l.weight.dtype == expected

    # case 2
    with fabric.init_module(empty_init=False), torch.device("meta"):
        model = MyModel()
    assert model.l.weight.device.type == "meta"
    assert model.l.weight.dtype == args[1]
    # the model layers are already replaced, this won't do anything relevant
    model = fabric.setup(model, move_to_device=False)
    assert model.l.weight.device.type == "meta"
    assert model.l.weight.dtype == args[1]
    keys = model.load_state_dict(state_dict, strict=True)  # quantizes
    assert not keys.missing_keys
    assert model.l.weight.device.type == "cuda"
    assert model.l.weight.dtype == expected

    # case 2 with an incomplete state_dict
    with fabric.init_module(empty_init=False), torch.device("meta"):
        model = MyModel()
    assert model.l.weight.device.type == "meta"
    assert model.l.weight.dtype == args[1]
    partial_state_dict = {k: v for k, v in state_dict.items() if "ln" not in k}
    keys = model.load_state_dict(partial_state_dict, strict=False)  # quantizes
    assert keys.missing_keys == ["ln.weight"]
    assert model.l.weight.device.type == "cuda"
    assert model.l.weight.dtype == expected
    assert model.ln.weight.device.type == "meta"
    assert model.ln.weight.dtype == args[1]
    # now we need to materialize just for LayerNorm
    _materialize_meta_tensors(model, fabric.device)
    assert model.l.weight.device.type == "cuda"
    assert model.l.weight.dtype == expected
    assert model.ln.weight.device.type == "cuda"
    assert model.ln.weight.dtype == args[1]

    # test mmap and assign on a meta bnb layer
    with fabric.init_module(empty_init=False), torch.device("meta"):
        model = MyModel()
    ckpt_path = tmp_path / "foo.ckpt"
    torch.save(state_dict, ckpt_path)
    state_dict = torch.load(str(ckpt_path), mmap=True)  # reload memory-mapped so `assign=True` gets mmap tensors
    keys = model.load_state_dict(state_dict, strict=True, assign=True)  # quantizes
    assert not keys.missing_keys
    assert model.l.weight.device.type == "cuda"
    assert model.l.weight.dtype == expected