Commit e96756c

export kv scheme (#1068)
Signed-off-by: yiliu30 <[email protected]>
1 parent 2ef7f31 commit e96756c

File tree

2 files changed: +36, -2 lines changed

auto_round/compressors/base.py

Lines changed: 2 additions & 1 deletion

@@ -3096,7 +3096,6 @@ def save_quantized(
         serialization_dict["autoround_version"] = __version__
         if "scale_dtype" in serialization_dict.keys():
             serialization_dict["scale_dtype"] = str(serialization_dict["scale_dtype"])
-
         compressed_model = save_quantized_as_format(  # TODO refine the code
             output_dir,
             model=self.model,
@@ -3121,6 +3120,8 @@ def save_quantized(
             to_quant_block_names=self.to_quant_block_names,
             quant_block_list=self.quant_block_list,
             device=self.device,
+            static_kv_dtype=self.static_kv_dtype,
+            static_attention_dtype=self.static_attention_dtype,
             **kwargs,
         )
         return compressed_model
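
For reference, a minimal caller-side sketch of how the new keyword arguments could be exercised. This is an assumption-laden example: only the `static_kv_dtype` / `static_attention_dtype` pass-through is confirmed by this commit; the AutoRound constructor arguments, the model name, and the `llm_compressor` format string follow typical auto-round usage rather than this diff.

```python
# Hypothetical usage sketch (not from this commit): drive save_quantized so the
# new static_kv_dtype kwarg reaches the llm-compressor exporter.
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Constructor arguments are assumptions based on common AutoRound usage;
# only the attribute pass-through shown in the diff above is confirmed here.
ar = AutoRound(model, tokenizer, scheme="FP8_STATIC", static_kv_dtype="fp8")
ar.quantize()
ar.save_quantized("./opt-125m-fp8-static-kv", format="llm_compressor")
```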

auto_round/export/export_to_llmcompressor/export_to_static_fp.py

Lines changed: 34 additions & 1 deletion

@@ -111,6 +111,39 @@ def pack_layer(layer_name: str, model: torch.nn.Module, data_type: str, device:
     set_module(model, layer_name, my_linear)


+def _construct_kv_scheme():
+    """Construct the default KV cache quantization scheme for FP8_STATIC export."""
+    from compressed_tensors.quantization import (  # pylint: disable=E0401
+        QuantizationArgs,
+        QuantizationStrategy,
+        QuantizationType,
+    )
+
+    default_kv_scheme = QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR,
+        symmetric=True,
+        dynamic=False,
+    )
+
+    logger.warning_once(
+        "Using default KV cache scheme: %s. "
+        "Currently, only this KV cache scheme is supported for FP8_STATIC + FP8 KV.",
+        repr(default_kv_scheme),
+    )
+
+    return default_kv_scheme
+
+
+def _use_fp8_kv(static_kv_dtype: str | None) -> bool:
+    """Return True if static KV cache should use FP8."""
+    if static_kv_dtype in ("fp8", "float8_e4m3fn"):
+        logger.warning_once("Exporting model with static KV cache in FP8 dtype.")
+        return True
+    return False
+
+
 def save_quantized_as_static_fp(output_dir: str, inplace: bool = True, **kwargs) -> torch.nn.Module:
     """
     Saves a quantized model of FP8_STATIC scheme in the llm-compressor format.
@@ -211,7 +244,7 @@ def wrapper(name):
     config_groups["group_0"] = scheme
     quantization_config = QuantizationConfig(
         config_groups=config_groups,
-        kv_cache_scheme=None,
+        kv_cache_scheme=_construct_kv_scheme() if _use_fp8_kv(kwargs.get("static_kv_dtype", None)) else None,
         quantization_status=QuantizationStatus.COMPRESSED,
         ignore=ignore,
     )
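
For context, a small standalone sketch of the quantization config this change produces when an FP8 static KV cache is requested. The compressed-tensors class names mirror the diff above; the `config_groups` and `ignore` values here are placeholders rather than the exporter's real contents.

```python
# Sketch of the kv_cache_scheme decision added above (placeholders noted inline).
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationStatus,
    QuantizationStrategy,
    QuantizationType,
)


def _use_fp8_kv(static_kv_dtype):
    # Same accepted spellings as in the diff: "fp8" or "float8_e4m3fn".
    return static_kv_dtype in ("fp8", "float8_e4m3fn")


# Default KV scheme built by _construct_kv_scheme(): per-tensor, symmetric,
# static 8-bit float quantization of the KV cache.
kv_scheme = QuantizationArgs(
    num_bits=8,
    type=QuantizationType.FLOAT,
    strategy=QuantizationStrategy.TENSOR,
    symmetric=True,
    dynamic=False,
)

static_kv_dtype = "fp8"  # would arrive via save_quantized(..., static_kv_dtype="fp8")
quantization_config = QuantizationConfig(
    config_groups={},  # placeholder; the exporter fills this with its FP8_STATIC scheme
    kv_cache_scheme=kv_scheme if _use_fp8_kv(static_kv_dtype) else None,
    quantization_status=QuantizationStatus.COMPRESSED,
    ignore=["lm_head"],  # placeholder ignore list
)
print(quantization_config.kv_cache_scheme)
```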
