
Commit 9c02a92

yiliu30 and XuehaoSun authored
Add LLMC integration test (#1053)
Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: Sun, Xuehao <[email protected]>
1 parent e96756c commit 9c02a92

File tree: 3 files changed, +97 -4 lines changed


auto_round/compressors/base.py

Lines changed: 2 additions & 2 deletions
@@ -2601,7 +2601,7 @@ def _get_loss(
             tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
             tmp_attention_mask.unsqueeze_(-1)
             if self.amp:
-                with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                with autocast(device_type=str(device).split(":")[0], dtype=self.amp_dtype):
                     loss = mse_loss(  # pylint: disable=not-callable
                         (output_q * tmp_attention_mask).to(torch.float32),
                         (current_output * tmp_attention_mask).to(torch.float32),
@@ -2614,7 +2614,7 @@ def _get_loss(
 
         else:
            if self.amp:
-                with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                with autocast(device_type=str(device).split(":")[0], dtype=self.amp_dtype):
                     loss = mse_loss(  # pylint: disable=not-callable
                         output_q.to(torch.float32), current_output.to(torch.float32)
                     )
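
The only functional change here is wrapping device in str(): autocast expects a device-type string such as "cuda" or "cpu", and the fix suggests device may reach _get_loss as a torch.device object, which has no .split method. Below is a minimal sketch (illustrative only, not part of the patch) of why str(device).split(":")[0] works for both cases:

import torch

# Illustrative only: str(device).split(":")[0] yields the device type for both
# plain strings and torch.device objects, whereas device.split(":") would raise
# AttributeError when device is a torch.device.
for device in ("cuda:0", "cpu", torch.device("cuda:0"), torch.device("cpu")):
    device_type = str(device).split(":")[0]
    print(f"{device!r:>28} -> device_type={device_type!r}")
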

test/test_cpu/requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@ addict
 modelscope
 gguf
 torchvision
-compressed-tensors
 parameterized
 numba
-tbb
+#TODO: (yiliu30) replace it with the release version
+llmcompressor @ git+https://github.com/vllm-project/llm-compressor.git@7b28d78
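
The dependency update drops the explicit compressed-tensors and tbb entries and instead installs llmcompressor from a pinned commit via a pip direct reference (the TODO notes this should move to a released version later); compressed_tensors stays importable, presumably as a transitive dependency of llmcompressor. A quick sanity check of the entry points the new test relies on, assuming the pinned build is installed:

# Sanity check that the pinned llmcompressor build exposes the pieces the new
# integration test imports; nothing beyond those imports is assumed here.
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier

print(oneshot, AutoRoundModifier, QuantizationScheme, QuantizationArgs)
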
New test file (path not shown in this view)

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+import pytest
+import torch
+from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from llmcompressor import oneshot
+from llmcompressor.modifiers.autoround import AutoRoundModifier
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from auto_round.calib_dataset import get_dataset
+
+recipe_str = """
+quant_stage:
+    quant_modifiers:
+        AutoRoundModifier:
+            ignore: ["lm_head"]
+            iters: 1
+            config_groups:
+                group_0:
+                    targets:
+                        - "Linear"
+                    input_activations: null
+                    output_activations: null
+                    weights:
+                        num_bits: 4
+                        type: "int"
+                        symmetric: true
+                        strategy: group
+                        group_size: 128
+"""
+
+recipe_modifier_full = AutoRoundModifier(
+    ignore=["lm_head"],
+    iters=1,
+    config_groups={
+        "group_0": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
+        )
+    },
+)
+
+
+@pytest.mark.parametrize(
+    "recipe",
+    [
+        recipe_str,
+        recipe_modifier_full,
+    ],
+)
+def test_oneshot_application(recipe, tmp_path):
+    output = tmp_path / "oneshot_output"
+    model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    dataset = get_dataset(
+        tokenizer=tokenizer,
+        seqlen=16,
+        nsamples=2,
+    )
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+    oneshot(
+        model=model,
+        dataset=dataset,
+        output_dir=output,
+        recipe=recipe,
+    )
+    model_loaded = AutoModelForCausalLM.from_pretrained(output, device_map=device)
+
+    # Check that the model is quantized
+    # for compression_config - decompress() will attach a quantization_config
+    # to the model as we decompress right away
+    # for quantization_config - we have CompressedLinear which will only
+    # decompress on the forward pass and does not call decompress(). Results
+    # in a slightly different parameter tree to access the quant config
+    quantization_config = model_loaded.config.quantization_config.quantization_config
+    assert quantization_config is not None
+
+    # check config is set properly
+    assert "lm_head" in quantization_config.ignore
+    assert len(quantization_config.config_groups) == 1
+    quant_scheme = quantization_config.config_groups["group_0"]
+    assert isinstance(quant_scheme, QuantizationScheme)
+
+    weight_args = quantization_config.config_groups["group_0"].weights
+    assert isinstance(weight_args, QuantizationArgs)
+    assert weight_args.num_bits == 4
+
+    # Check a specific layer is quantized
+    targeted_linear_layer = model_loaded.model.layers[2].self_attn.q_proj
+    assert hasattr(targeted_linear_layer, "quantization_scheme")
+
+    # Check lm-head is not quantized
+    not_targeted = model_loaded.lm_head
+    assert not hasattr(not_targeted, "quantization_scheme")
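
The diff view above does not show the new test file's path, so the sketch below selects the test by keyword under test/test_cpu (the directory whose requirements were updated); the directory choice is an assumption, not something stated in the commit:

# Run only the new parametrized test by keyword; assumes pytest plus the
# requirements above are installed and that the test file lives somewhere
# under test/test_cpu.
import sys

import pytest

sys.exit(pytest.main(["-v", "-k", "test_oneshot_application", "test/test_cpu"]))
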
