Skip to content
5 changes: 5 additions & 0 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,11 +415,16 @@ def ts_patched_forward(*args, **kwargs):
__make_16bit_traceable(model)
check_dummy_inputs_are_allowed(model, dummy_inputs)
input_info = _get_input_info(model, config, dummy_inputs)
conversion_extensions = getattr(patcher, "conversion_extensions", [])
module_extensions = getattr(patcher, "module_extensions", None)
if module_extensions is not None:
ts_decoder_kwargs["module_extensions"] = module_extensions
ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
ov_model = convert_model(
ts_decoder,
example_input=dummy_inputs,
input=[(item.shape, item.type) for item in input_info],
extension=conversion_extensions,
)

ov_model.validate_nodes_and_infer_types() # TODO: remove as unnecessary validation?
Expand Down
141 changes: 141 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
SanaTextEncoderModelPatcher,
XverseModelPatcher,
Zamba2ModelPatcher,
Qwen3NextModelPatcher,
)


Expand Down Expand Up @@ -4383,3 +4384,143 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")
return common_inputs


class Qwen3NextDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
"""
Generates dummy cache_params inputs for Zamba2 architectures.
"""

SUPPORTED_INPUT_NAMES = ("cache_params",)

def __init__(
self,
task: str,
normalized_config,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
**kwargs,
)

config = normalized_config.config
self.num_full_attn_layers = config.layer_types.count("full_attention")
self.num_linear_attn_layers = config.layer_types.count("linear_attention")
self.conv_kernel_size = config.linear_conv_kernel_dim
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.head_k_dim = config.linear_key_head_dim
self.head_v_dim = config.linear_value_head_dim
self.num_v_heads = config.linear_num_value_heads
self.num_k_heads = config.linear_num_key_heads
self.num_key_value_heads = config.num_key_value_heads

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
cache_params = []

for idx in range(self.num_linear_attn_layers):
# (batch_size, d_inner, d_conv)
d_inner = self.num_k_heads * (2 * self.head_k_dim + self.head_v_dim * self.num_v_heads // self.num_k_heads)
conv_state_shape = (
self.batch_size,
d_inner,
self.conv_kernel_size,
)
conv_state = self.random_float_tensor(conv_state_shape, framework=framework, dtype=float_dtype)
cache_params.append(conv_state)
#num_heads = self.num_key_value_heads * (self.num_v_heads // self.num_k_heads)
num_heads = self.num_v_heads
recurrent_state_shape = (self.batch_size, num_heads, self.head_k_dim, self.head_v_dim)
recurrent_state = self.random_float_tensor(recurrent_state_shape, framework=framework, dtype=float_dtype)
cache_params.append(recurrent_state)

for idx in range(self.num_full_attn_layers):
kv_shape = (self.batch_size, self.num_key_value_heads, self.sequence_length, self.head_dim)
k = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype)
v = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype)
cache_params.append(k)
cache_params.append(v)

return cache_params


@register_in_tasks_manager(
"qwen3_next",
*["text-generation", "text-generation-with-past"],
library_name="transformers",
)
class Qwen3NextOpenVINOConfig(Qwen3OpenVINOConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Qwen3NextDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = Qwen3NextDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
MIN_TRANSFORMERS_VERSION = "4.57.0"
_MODEL_PATCHER = Qwen3NextModelPatcher

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
cache_name_prefix = "cache_params.past"
else:
decoder_sequence_name = "past_sequence_length + sequence_length"
cache_name_prefix = "cache_params.present"

self.num_full_attn_layers = self._normalized_config.layer_types.count("full_attention")
self.num_linear_attn_layers = self._normalized_config.layer_types.count("linear_attention")

for i in range(self.num_linear_attn_layers):
# [batch_size, conv_kernel_size - 1, d_model]
inputs_or_outputs[f"{cache_name_prefix}.conv.{i}"] = {0: "batch_size"}
# [batch_size, d_state, d_model]
inputs_or_outputs[f"{cache_name_prefix}.ssm.{i}"] = {0: "batch_size"}

for i in range(self.num_full_attn_layers):
inputs_or_outputs[f"{cache_name_prefix}.key.{i}"] = {0: "batch_size", 2: decoder_sequence_name}
inputs_or_outputs[f"{cache_name_prefix}.value.{i}"] = {0: "batch_size", 2: decoder_sequence_name}

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
}
if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")
return common_inputs

def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
# need to override `generate_dummy_inputs` since mamba model has other states: ssm_states and conv_states
# which we separate and call them as past_ssm_states and past_conv_states
dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)

dummy_inputs = {}
input_names = [key for key in self.inputs.keys() if not key.startswith("cache_params")]
if self.use_past_in_inputs:
input_names.extend(["cache_params"])

for input_name in input_names:
input_was_inserted = False
for dummy_input_gen in dummy_inputs_generators:
if dummy_input_gen.supports_input(input_name):
dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
dummy_input_gen,
input_name,
framework,
input_shapes=kwargs,
)
input_was_inserted = True
break
if not input_was_inserted:
raise RuntimeError(
f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
)

return dummy_inputs
Loading