
Commit 4084e7f

Passes long and short factors for phi3+ models using longrope (#3375)
In the canonical HF implementation of Phi3+ models, the longrope embedding uses both the long and short factors, selecting between them based on sequence length; see https://github.com/huggingface/transformers/blob/7b325cd573e40bbb12951b8446176c96e8b1afaa/src/transformers/modeling_rope_utils.py#L521. To achieve this in MLC, we need to pass both the long and short factors to the KV cache creation. The TVM side of this patch is apache/tvm#18422.
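As a rough sketch of the selection behavior described above (illustrative only; pick_longrope_factors is a hypothetical helper, not the actual transformers function):

def pick_longrope_factors(seq_len, original_max_position_embeddings, rope_scaling):
    # Long factors apply once the sequence exceeds the model's original
    # training context; otherwise the short factors apply. This mirrors
    # the selection in HF's modeling_rope_utils.py linked above.
    if seq_len > original_max_position_embeddings:
        return rope_scaling["long_factor"]
    return rope_scaling["short_factor"]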
Parent: 7b15b19

2 files changed: +6 additions, −2 deletions

python/mlc_llm/model/phi3/phi3_model.py

Lines changed: 3 additions & 1 deletion
@@ -238,7 +238,9 @@ def __init__(self, config: Phi3Config) -> None:
         self.rope_scaling = config.rope_scaling
         self.rope_theta = config.position_embedding_base
         self.rope_ext_factors = (
-            config.rope_scaling["long_factor"] if config.rope_scaling is not None else None
+            (config.rope_scaling["long_factor"] + config.rope_scaling["short_factor"])
+            if config.rope_scaling is not None
+            else None
         )
         self.tensor_parallel_shards = config.tensor_parallel_shards
         self.partial_rotary_factor = config.partial_rotary_factor
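With this change, rope_ext_factors carries both lists back-to-back, long factors first. A hypothetical sketch of how a consumer could split the flat list back into its halves (the actual consumption happens on the TVM side, apache/tvm#18422):

ext_factors = rope_scaling["long_factor"] + rope_scaling["short_factor"]
half = len(ext_factors) // 2          # both lists have equal length in Phi3-style longrope configs
long_factor = ext_factors[:half]      # first half: long factors
short_factor = ext_factors[half:]     # second half: short factors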

python/mlc_llm/model/phi3v/phi3v_model.py

Lines changed: 3 additions & 1 deletion
@@ -143,7 +143,9 @@ def __init__(self, config: Phi3VConfig) -> None:
         self.rope_scaling = config.rope_scaling
         self.rope_theta = config.position_embedding_base
         self.rope_ext_factors = (
-            config.rope_scaling["long_factor"] if config.rope_scaling is not None else None
+            (config.rope_scaling["long_factor"] + config.rope_scaling["short_factor"])
+            if config.rope_scaling is not None
+            else None
         )
         self.tensor_parallel_shards = config.tensor_parallel_shards
         self.dtype = "float32"
