
Commit b1b60e4

Support transformers loading quantized moe model (#1067)
1 parent 9c02a92 commit b1b60e4

File tree

6 files changed: +227 -28 lines

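The commit's point is that a MoE model quantized by AutoRound can now be loaded back through the regular transformers API. A minimal usage sketch under that assumption; the checkpoint path below is a placeholder, not taken from this commit:

# Hypothetical usage sketch: load an AutoRound-quantized MoE checkpoint with
# plain transformers. The path is a placeholder; a GPT-OSS or Llama-4 model
# quantized and saved in the "auto_round" format is expected to follow this flow.
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "/path/to/quantized-moe-model"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto")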

auto_round/inference/convert_model.py

Lines changed: 4 additions & 0 deletions
@@ -31,6 +31,7 @@
 from auto_round.inference.utils import _expand_regex_config
 from auto_round.logger import logger
 from auto_round.schemes import QuantizationScheme
+from auto_round.special_model_handler import _handle_moe_model
 from auto_round.utils import (
     SUPPORTED_LAYER_TYPES,
     check_start_with_block_name,
@@ -582,6 +583,9 @@ def convert_hf_model(model: nn.Module, target_device: str = "cpu") -> tuple[nn.M
     elif packing_format == "auto_round:gptq":
         packing_format = "auto_round:auto_gptq"
 
+    # Preprocess model before replace layers
+    model = _handle_moe_model(model)
+
     # Replace layers with quantized versions
     layer_configs = get_layer_config(model, quantization_config)
     used_backends = _replace_by_quant_layers(model, layer_configs, backend, target_device, packing_format)
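The point of this hunk is ordering: the fused MoE expert modules have to be unpacked into per-expert Linear layers before the quantized-layer replacement pass runs, otherwise that pass has nothing to match. A condensed, hedged sketch of the flow; the helper functions are passed in as arguments only to keep the sketch self-contained, whereas in convert_model.py they are the module-level functions named in the diff above:

# Condensed sketch of the ordering enforced inside convert_hf_model();
# not the real function, which does considerably more.
def convert_for_inference(model, quantization_config, backend, target_device, packing_format,
                          handle_moe_model, get_layer_config, replace_by_quant_layers):
    model = handle_moe_model(model)        # 1. unpack fused MoE experts into per-expert nn.Linear modules
    layer_configs = get_layer_config(model, quantization_config)
    return replace_by_quant_layers(        # 2. only now can Linear layers be swapped for QuantLinear
        model, layer_configs, backend, target_device, packing_format
    )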

auto_round/modelling/gpt_oss.py

Lines changed: 10 additions & 12 deletions
@@ -19,6 +19,8 @@
 from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
 from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP
 
+from auto_round.utils import unsupported_meta_device
+
 __all__ = ["get_replacement_info"]
 
 
@@ -82,19 +84,15 @@ def __init__(self, config: GptOssConfig, original: GptOssMLP):
         for _ in range(E):
             self.experts.append(GPTOssSingleExpert(hidden_size, intermediate_size, dtype=dtype))
 
-        gup = original.experts.gate_up_proj  # [E, H, 2I]
-        gup_b = original.experts.gate_up_proj_bias  # [E, 2I]
-        dwn = original.experts.down_proj  # [E, I, H]
-        dwn_b = original.experts.down_proj_bias  # [E, H]
-
-        for i, mlp in enumerate(self.experts):
-            _update_parameter(mlp.gate_proj, "weight", original.experts.gate_up_proj[i, :, ::2].T)
-            _update_parameter(mlp.up_proj, "weight", original.experts.gate_up_proj[i, :, 1::2].T)
-            _update_parameter(mlp.down_proj, "weight", original.experts.down_proj[i].T)
+        if not unsupported_meta_device(original):
+            for i, mlp in enumerate(self.experts):
+                _update_parameter(mlp.gate_proj, "weight", original.experts.gate_up_proj[i, :, ::2].T)
+                _update_parameter(mlp.up_proj, "weight", original.experts.gate_up_proj[i, :, 1::2].T)
+                _update_parameter(mlp.down_proj, "weight", original.experts.down_proj[i].T)
 
-            _update_parameter(mlp.gate_proj, "bias", original.experts.gate_up_proj_bias[i, ::2])
-            _update_parameter(mlp.up_proj, "bias", original.experts.gate_up_proj_bias[i, 1::2])
-            _update_parameter(mlp.down_proj, "bias", original.experts.down_proj_bias[i])  # [H]
+                _update_parameter(mlp.gate_proj, "bias", original.experts.gate_up_proj_bias[i, ::2])
+                _update_parameter(mlp.up_proj, "bias", original.experts.gate_up_proj_bias[i, 1::2])
+                _update_parameter(mlp.down_proj, "bias", original.experts.down_proj_bias[i])  # [H]
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         B, T, H = hidden_states.shape
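For GPT-OSS the fused gate_up_proj tensor is [E, H, 2I] with gate and up columns interleaved, so [i, :, ::2].T recovers the gate weight and [i, :, 1::2].T the up weight of expert i; the new unsupported_meta_device guard skips the copy when the source module holds only meta (unmaterialized) weights, which is the case while a quantized checkpoint is being loaded. A standalone sketch of the slicing, with illustrative shapes:

# Illustrative shapes only; this mirrors the slicing in the hunk above.
import torch

E, H, I = 2, 8, 4                        # experts, hidden size, intermediate size
gate_up_proj = torch.randn(E, H, 2 * I)  # fused per-expert weight, [E, H, 2I]

gate_w = gate_up_proj[0, :, ::2].T       # even columns -> gate, [I, H]
up_w = gate_up_proj[0, :, 1::2].T        # odd columns  -> up,   [I, H]
assert gate_w.shape == up_w.shape == torch.Size([I, H])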

auto_round/modelling/llama4.py

Lines changed: 14 additions & 10 deletions
@@ -20,23 +20,27 @@
 from transformers.modeling_utils import no_init_weights
 from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
 
+from auto_round.utils import unsupported_meta_device
+
 
 class SequentialLlama4TextExperts(torch.nn.ModuleList):
     def __init__(self, config, original):
         self.num_experts = original.gate_up_proj.shape[0]
         with no_init_weights():
             super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
-        intermediate_size = original.down_proj.shape[1]
 
-        for i in range(self.num_experts):
-            gate_up = original.gate_up_proj[i]
-            down = original.down_proj[i]
-            gate_proj = gate_up[:, :intermediate_size]
-            up_proj = gate_up[:, intermediate_size:]
-
-            self[i].gate_proj.weight.data = gate_proj.t().contiguous()
-            self[i].up_proj.weight.data = up_proj.t().contiguous()
-            self[i].down_proj.weight.data = down.t().contiguous()
+        if not unsupported_meta_device(original):
+            intermediate_size = original.down_proj.shape[1]
+
+            for i in range(self.num_experts):
+                gate_up = original.gate_up_proj[i]
+                down = original.down_proj[i]
+                gate_proj = gate_up[:, :intermediate_size]
+                up_proj = gate_up[:, intermediate_size:]
+
+                self[i].gate_proj.weight.data.copy_(gate_proj.t())
+                self[i].up_proj.weight.data.copy_(up_proj.t())
+                self[i].down_proj.weight.data.copy_(down.t())
 
 
 class SequentialLlama4TextMoe(torch.nn.Module):
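Llama-4 packs its fused weight differently: as the slicing above implies, each expert's gate_up tensor is [H, 2I] with the gate columns first and the up columns second, so the split is a plain slice at intermediate_size rather than an interleaved stride. The hunk also switches from rebinding .data to an in-place .data.copy_(), which writes into the storage already allocated under no_init_weights instead of replacing it. A small standalone sketch, shapes illustrative:

# Illustrative sketch of the contiguous split and in-place copy used above.
import torch

H, I = 8, 4
gate_up = torch.randn(H, 2 * I)     # one expert's fused weight
gate_proj = gate_up[:, :I]          # first I columns -> gate
up_proj = gate_up[:, I:]            # last I columns  -> up

linear = torch.nn.Linear(H, I, bias=False)
linear.weight.data.copy_(gate_proj.t())  # fill the existing [I, H] storage in place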

auto_round/special_model_handler.py

Lines changed: 4 additions & 6 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import auto_round.modelling as auto_round_modelling
-from auto_round.utils import LazyImport, logger
+from auto_round.utils import LazyImport, logger, unsupported_meta_device
 
 mllms_with_limited_bs = ("llava", "qwen2_vl", "phi3_v", "mllama")  # Limitations on batch_size
 
@@ -76,8 +76,9 @@ def _handle_moe_model(model, formats=None):
     from auto_round.utils import clear_memory
 
     new_moe_class, convert_config, orig_cls_name = _get_moe_converter(model.config)
-    model = model.to("cpu")
-    clear_memory()
+    if not unsupported_meta_device(model):
+        model = model.to("cpu")
+        clear_memory()
 
     for name, module in tqdm(model.named_modules(), desc="Converting model"):
         cls_name = module.__class__.__name__
@@ -87,9 +88,6 @@ def _handle_moe_model(model, formats=None):
            parent = model.get_submodule(parent)
            setattr(parent, child, new_module)
 
-    logger.warning(
-        f"{model.config.model_type} experts are converted, the quantized model can not run on transformers."
-    )
     return model
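With the new guard, _handle_moe_model only moves the model to CPU and clears memory when the weights are actually materialized; a model instantiated on the meta device (as when transformers loads a quantized checkpoint with empty weights before filling them) is left alone, and the old warning that the converted model cannot run on transformers is dropped because loading now works. A hedged sketch of what such a meta-device check can look like; the real helper is auto_round.utils.unsupported_meta_device and may differ in detail:

# Simplified stand-in for a meta-device check (not the actual
# auto_round.utils.unsupported_meta_device implementation).
import torch


def on_meta_device(model: torch.nn.Module) -> bool:
    # True if any parameter is unmaterialized, i.e. lives on the "meta" device.
    return any(p.device.type == "meta" for p in model.parameters())


with torch.device("meta"):
    empty = torch.nn.Linear(16, 16)

print(on_meta_device(empty))  # True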

test/test_cpu/test_moe_model.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
import shutil

import pytest
from transformers import AutoConfig, AutoTokenizer, Llama4ForConditionalGeneration
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM

from auto_round import AutoRound


@pytest.fixture
def setup_gpt_oss():
    """Fixture to set up the GPT-OSS model and tokenizer."""
    model_name = "/tf_dataset/auto_round/models/unsloth/gpt-oss-20b-BF16"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config.num_hidden_layers = 1  # Reduce layers for testing
    model = GptOssForCausalLM(config)
    output_dir = "/tmp/test_quantized_gpt_oss"
    return model, tokenizer, output_dir, config


@pytest.fixture
def setup_llama4():
    """Fixture to set up the Llama-4 model and tokenizer."""
    model_name = "/tf_dataset/auto_round/models/meta-llama/Llama-4-Scout-17B-16E-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config.vision_config.num_hidden_layers = 2  # Reduce layers for testing
    config.text_config.num_hidden_layers = 2
    model = Llama4ForConditionalGeneration(config)
    output_dir = "/tmp/test_quantized_llama4"
    return model, tokenizer, output_dir, config


def quantize_model(model, tokenizer, output_dir, scheme, iters=0):
    """Helper function to quantize the model with the given scheme."""
    autoround = AutoRound(
        model,
        tokenizer,
        scheme=scheme,
        nsamples=2,
        iters=iters,
        fp_layers="self_attn,router,lm_head,mlp.gate",
    )
    quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir)
    return quantized_model


def test_gptoss(setup_gpt_oss):
    model, tokenizer, output_dir, config = setup_gpt_oss

    # The parameter below is set to match the full model.
    # Remove it to avoid a mismatch when loading the quantized model.
    delattr(model.config, "layer_types")

    quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4")

    # Ensure the quantized model is not None
    assert quantized_model is not None, "Quantized model should not be None."

    loaded_model = GptOssForCausalLM.from_pretrained(output_dir)
    for n, m in quantized_model.named_modules():
        if m.__class__.__name__ == "QuantLinear":
            loaded_m = loaded_model.get_submodule(n)
            assert (loaded_m.weight_packed.to("cpu") == m.weight_packed.to("cpu")).all()
    # Clean the output directory after the test
    shutil.rmtree(output_dir, ignore_errors=True)


def test_llama4(setup_llama4):
    model, tokenizer, output_dir, config = setup_llama4

    # The parameters below are set to match the full model.
    # Remove them to avoid a mismatch when loading the quantized model.
    model.config.text_config.no_rope_layers = []
    delattr(model.config.text_config, "moe_layers")
    delattr(model.config.text_config, "layer_types")

    quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4")

    # Ensure the quantized model is not None
    assert quantized_model is not None, "Quantized model should not be None."

    loaded_model = Llama4ForConditionalGeneration.from_pretrained(output_dir)
    for n, m in quantized_model.named_modules():
        if m.__class__.__name__ == "QuantLinear":
            loaded_m = loaded_model.get_submodule(n)
            assert (loaded_m.weight_packed.to("cpu") == m.weight_packed.to("cpu")).all()
    # Clean the output directory after the test
    shutil.rmtree(output_dir, ignore_errors=True)

test/test_cuda/test_moe_model.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
import shutil

import pytest
import torch
from transformers import AutoConfig, AutoTokenizer, Llama4ForConditionalGeneration
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM

from auto_round import AutoRound


@pytest.fixture
def setup_gpt_oss():
    """Fixture to set up the GPT-OSS model and tokenizer."""
    model_name = "/models/gpt-oss-20b-BF16"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config.num_hidden_layers = 1  # Reduce layers for testing
    model = GptOssForCausalLM(config)
    output_dir = "test_quantized_gpt_oss"
    return model, tokenizer, output_dir, config


@pytest.fixture
def setup_llama4():
    """Fixture to set up the Llama-4 model and tokenizer."""
    model_name = "/dataset/Llama-4-Scout-17B-16E-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config.vision_config.num_hidden_layers = 2  # Reduce layers for testing
    config.text_config.num_hidden_layers = 2
    model = Llama4ForConditionalGeneration(config)
    output_dir = "test_quantized_llama4"
    return model, tokenizer, output_dir, config


def quantize_model(model, tokenizer, output_dir, scheme, iters=0):
    """Helper function to quantize the model with the given scheme."""
    autoround = AutoRound(
        model,
        tokenizer,
        scheme=scheme,
        nsamples=2,
        iters=iters,
        fp_layers="self_attn,router,lm_head,mlp.gate",
    )
    quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir)
    return quantized_model


def test_gptoss(setup_gpt_oss):
    model, tokenizer, output_dir, config = setup_gpt_oss

    # The parameter below is set to match the full model.
    # Remove it to avoid a mismatch when loading the quantized model.
    delattr(model.config, "layer_types")

    quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4")

    # Ensure the quantized model is not None
    assert quantized_model is not None, "Quantized model should not be None."

    loaded_model = GptOssForCausalLM.from_pretrained(output_dir)
    quantized_model.to("cuda")
    loaded_model.to("cuda")
    for n, m in quantized_model.named_modules():
        if m.__class__.__name__ == "QuantLinear":
            loaded_m = loaded_model.get_submodule(n)
            assert (loaded_m.weight_packed == m.weight_packed).all()

    inp = torch.randint(0, 100, (1, 64)).to("cuda")
    with torch.inference_mode():
        loaded_out = loaded_model(inp)

    # Clean the output directory after the test
    shutil.rmtree(output_dir, ignore_errors=True)


def test_llama4(setup_llama4):
    model, tokenizer, output_dir, config = setup_llama4

    # The parameters below are set to match the full model.
    # Remove them to avoid a mismatch when loading the quantized model.
    model.config.text_config.no_rope_layers = []
    delattr(model.config.text_config, "moe_layers")
    delattr(model.config.text_config, "layer_types")

    quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4")

    # Ensure the quantized model is not None
    assert quantized_model is not None, "Quantized model should not be None."

    loaded_model = Llama4ForConditionalGeneration.from_pretrained(output_dir)
    quantized_model.to("cuda")
    loaded_model.to("cuda")
    for n, m in quantized_model.named_modules():
        if m.__class__.__name__ == "QuantLinear":
            loaded_m = loaded_model.get_submodule(n)
            assert (loaded_m.weight_packed == m.weight_packed).all()

    inp = torch.randint(0, 100, (1, 64)).to("cuda")
    with torch.inference_mode():
        loaded_out = loaded_model(inp)

    # Clean the output directory after the test
    shutil.rmtree(output_dir, ignore_errors=True)
