Fix Phi long context issue #1504

base: main
@@ -1493,41 +1493,69 @@ def _phi3_self_attn_sdpa_forward(
    return attn_output, None, past_key_value


class Phi3ModelPatcher(OVDecoderModelPatcher):
    def __enter__(self):
        super().__enter__()
# Adapted from https://github.com/huggingface/transformers/blob/v4.57.1/src/transformers/models/phi3/modeling_phi3.py#L324
# and https://github.com/huggingface/transformers/blob/v4.57.1/src/transformers/modeling_rope_utils.py#L30
def _phi3_longrope_forward(self, x, position_ids):
    """
    LongRoPE uses different scaling factors (short_factor vs long_factor) depending on whether
    the sequence length exceeds the original max position embeddings used during pretraining.

    Note: In transformers, the @dynamic_rope_update decorator replaces self.inv_freq before the forward pass.
    Here we use torch.where to select between original_inv_freq and long_inv_freq and add the selection logic into the model graph.
    """
    seq_len = torch.max(position_ids) + 1
    original_max_position_embeddings = (
        self.original_max_position_embeddings
        if hasattr(self, "original_max_positional_embeddings")
        else self.config.original_max_position_embeddings
    )

        # currently, long RoPE can not be traced for long context support, disable it for avoid potential accuracy issues
        if self._model.config.max_position_embeddings != getattr(
            self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
        ):
            self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings
    # Slow down all frequencies by scale factor for long prompts that makes attention more stable, i.e. preserve model accuracy
    # Use long_inv_freq for sequences exceeding original pretraining length, otherwise use short (default) inv_freq.
    inv_freq = torch.where(
        seq_len > original_max_position_embeddings,
        self.long_inv_freq,
        self.original_inv_freq,
    )

        if is_transformers_version("<", "4.48.0"):
            self._model.model._orig_forward = self._model.model.forward
            self._model.model.forward = types.MethodType(phi3_442_forward, self._model.model)
    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()

    # Force float32 since bfloat16 loses precision on long contexts
    # See https://github.com/huggingface/transformers/pull/29285
    device_type = x.device.type
    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
Comment on lines +1524 to +1527

Collaborator: not used and should we ensure fp32 dtype also?

Author (Collaborator): Thanks! Added the autocast line with enabled=False.
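As a side note (not part of the diff): a minimal sketch of the precision issue the forced float32 addresses, with an illustrative head dimension and position value.

```python
import torch

# Illustrative only: bfloat16 keeps ~8 bits of mantissa, so large position indices
# round coarsely before the cos/sin, which visibly changes the rotary embeddings.
dim = 64
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
position = torch.tensor([100_000.0])  # a long-context position

angles_fp32 = position * inv_freq
angles_bf16 = (position.bfloat16() * inv_freq.bfloat16()).float()

# The cosine values computed from the rounded bfloat16 angles differ substantially.
print((angles_fp32.cos() - angles_bf16.cos()).abs().max())
```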
    with torch.autocast(device_type=device_type, enabled=False):  # Force float32
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos() * self.attention_scaling
        sin = emb.sin() * self.attention_scaling
    return cos, sin

        # https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113
        # init inv_freq for torchscript tracing
        # 4.48 transformers version phi3 fixed, but issue still visible with trust_remote_true=True (trust_remote_code has _support_sdpa = False)
        for layer in self._model.model.layers:
            if (
                is_torch_version(">=", "2.1.0")
                and is_transformers_version("<", "4.48.0")
                or not getattr(self._model, "_supports_sdpa", False)
            ):
                orig_self_attn_fwd = layer.self_attn.forward
                layer.self_attn.forward = types.MethodType(_phi3_self_attn_sdpa_forward, layer.self_attn)
                layer.self_attn._orig_forward = orig_self_attn_fwd


class Phi3ModelPatcher(OVDecoderModelPatcher):
    def __enter__(self):
        super().__enter__()
        if not getattr(self, "_disable_longrope", False):
            if (
                hasattr(layer.self_attn, "rotary_emb")
                and getattr(layer.self_attn.rotary_emb, "inv_freq", None) is None
                hasattr(self._model.model, "rotary_emb")
                and getattr(self._model.model.rotary_emb, "rope_type", "default") == "longrope"
            ):
                rotary_emb = layer.self_attn.rotary_emb
                layer.self_attn.rotary_emb.inv_freq = 1.0 / (
                    rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                long_inv_freq, _ = self._model.model.rotary_emb.rope_init_fn(
                    self._model.config,
                    torch.device("cpu"),
                    seq_len=self._model.config.original_max_position_embeddings + 1,
                )
                self._model.model.rotary_emb.long_inv_freq = long_inv_freq
                self._model.model.rotary_emb._orig_forward = self._model.model.rotary_emb.forward
                self._model.model.rotary_emb.forward = types.MethodType(
                    _phi3_longrope_forward, self._model.model.rotary_emb
                )
            elif self._model.config.max_position_embeddings != getattr(
                self._model.config, "original_max_position_embeddings", self._model.config.max_position_embeddings
            ):
                self._orig_max_position_embeddings = self._model.config.max_position_embeddings
                self._model.config.max_position_embeddings = self._model.config.original_max_position_embeddings

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
@@ -1538,6 +1566,8 @@ def __exit__(self, exc_type, exc_value, traceback):
        for layer in self._model.model.layers:
            if hasattr(layer.self_attn, "_orig_forward"):
                layer.self_attn.forward = layer.self_attn._orig_forward
        if hasattr(self, "_orig_max_position_embeddings"):
            self._model.config.max_position_embeddings = self._orig_max_position_embeddings


# Modified from https://github.com/huggingface/transformers/blob/v4.50.2/src/transformers/models/phimoe/modeling_phimoe.py#L756

@@ -1589,8 +1619,19 @@ def _phi_moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torc


class PhiMoEModelPatcher(Phi3ModelPatcher):
    def __init__(
        self,
        config: "OnnxConfig",
        model: "PreTrainedModel",
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(config, model, model_kwargs)
        # Disable longrope for Phi3-MOE
        self._disable_longrope = True

    def __enter__(self):
        super().__enter__()

        for layer in self._model.model.layers:
            layer.block_sparse_moe._orig_forward = layer.block_sparse_moe.forward
            layer.block_sparse_moe.forward = types.MethodType(

@@ -1599,6 +1640,7 @@ def __enter__(self):

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)

        for layer in self._model.model.layers:
            layer.block_sparse_moe.forward = layer.block_sparse_moe._orig_forward
@@ -355,11 +355,6 @@ def _export(
            variant=variant,
        )

        if config.model_type == "phi3" and config.max_position_embeddings != getattr(
            config, "original_max_position_embeddings", config.max_position_embeddings
        ):
            config.max_position_embeddings = config.original_max_position_embeddings

        return cls._from_pretrained(
            model_id=save_dir_path,
            config=config,

@@ -868,6 +863,8 @@ def _from_pretrained(
            init_cls = OVBloomForCausalLM
        elif model_type == "gpt_bigcode":
            init_cls = OVGPTBigCodeForCausalLM
        elif model_type == "phi3":
            init_cls = OVPhi3ForCausalLM
        elif model_type in SSM_MODELS:
            init_cls = OVModelWithMambaForCausalLM
        else:

@@ -912,6 +909,48 @@ def _from_pretrained(
        return causal_model


class OVPhi3ForCausalLM(OVModelForCausalLM):
    # Adapted from https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/phi3/modeling_phi3.py#L493
    def prepare_inputs_for_generation(

Collaborator: would you mind adding a link to the original code

Author (Collaborator): Done
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the

Collaborator: Am I correct that we have a problem when we have short and long prompts in consecutive

Author (Collaborator): As discussed offline: this is handled in optimum-intel by resetting the kv-cache when the number of input tokens is equal to the long rope boundary (e.g. 4096). This is done the same way in the transformers code. Tested that this works as expected in a chat context with https://gist.github.com/helena-intel/b55522cda91d9d61a644f153e71f0f98 .
        # process

        # When the first time input length reached long and short factor switching point, enforce re-compute cache
        # The downside is slower inference at this single token position, however, this is better than wrong results
        if (
            past_key_values
            and self.config.rope_scaling
            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
        ):
            past_length = cache_position[0]
            if past_length <= self.config.original_max_position_embeddings:
                past_key_values = None

Collaborator: please add a link to

Author (Collaborator): Added the link a few lines above. The comment that was there was copied verbatim from the transformers code. I modified the second line a bit to make it clearer (the transformers comment references "current failure" but it is not clear what that is).

        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return model_inputs


class OVBloomForCausalLM(OVModelForCausalLM):
    # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
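As a usage note (illustrative, not part of the diff): after export, the new class is driven through the regular optimum-intel generation API, and the KV-cache reset shown above kicks in when generation crosses the original_max_position_embeddings boundary. A minimal sketch, assuming a Phi-3 checkpoint with LongRoPE scaling such as microsoft/Phi-3-mini-128k-instruct:

```python
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "microsoft/Phi-3-mini-128k-instruct"  # any Phi-3 LongRoPE checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

# A prompt longer than original_max_position_embeddings (4096 for this family)
# exercises the long_inv_freq branch directly; a shorter prompt that grows past
# the boundary during generation instead triggers the cache re-compute above.
inputs = tokenizer("some long prompt text " * 2000, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```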
@@ -80,7 +80,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
        "qwen2",
        "qwen2_moe",
        "arctic",
        "phi3",
        "gemma2",
        "exaone",
        "granite",

@@ -91,6 +90,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
    if is_transformers_version(">=", "4.49"):
        SUPPORTED_SSM_ARCHITECTURES += ("zamba2",)
        SUPPORTED_ARCHITECTURES += ("phi3", "phi3-longrope")

    if is_transformers_version(">=", "4.54.0") and is_openvino_version(">=", "2025.4.0"):
        SUPPORTED_SSM_ARCHITECTURES += ("lfm2",)

@@ -184,6 +184,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
        "pegasus": 2,
        "qwen": 2,
        "phi": 2,
        "phi3-longrope": 4,
        "internlm2": 4,
        "falcon": 2,
        "falcon-40b": 2,

@@ -781,3 +782,54 @@ def test_load_with_different_dtype(self):
                torch.allclose(torch.Tensor(ov_logits), ref_logits, atol=5e-3),
                f"values are not close for {dtype if dtype is not None else 'None'}, max diff = {torch.abs(ov_logits - ref_logits).max()}",
            )

    def test_phi3_longrope_support(self):

Collaborator: no need in this new test. Just add your model into
| """Test LongRoPE support for Phi3 with inputs > 4096 tokens.""" | ||
| if is_transformers_version("<", "4.49"): | ||
| self.skipTest("Incompatible transformers version: Phi3 longrope requires transformers>=4.49") | ||
| set_seed(SEED) | ||
| model_id = "optimum-intel-internal-testing/tiny-random-phi3-longrope" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tiny model is based on phi-4-mini-instruct, which has only 1.0 for short factor: https://huggingface.co/microsoft/Phi-4-mini-instruct/blob/main/config.json#L85
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please don't mix up tiny models for phi-3 and phi4. For phi4, it should be separate
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I initially named this model tiny-phi-4 but was asked to rename. phi-4-mini-instruct is phi3 architecture, so it is not that this tiny model is not representative for the phi3 model type. I will add a tiny model that is based on a phi-3- model.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I updated the tiny model and now used a base model that does not just use phi3 architecture as before but that also specifically has "phi-3-" in the model name. |
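For reference (illustrative values, not taken from the PR or from the tiny model's actual config): the LongRoPE fields the test depends on live in the checkpoint config, roughly in this shape.

```python
# Rough shape of a LongRoPE-style Phi-3 config excerpt; the factor values here are
# placeholders, and real checkpoints carry one entry per rotary dimension pair.
config_excerpt = {
    "max_position_embeddings": 131072,
    "original_max_position_embeddings": 4096,
    "rope_scaling": {
        "type": "longrope",
        "short_factor": [1.0, 1.0, 1.0],  # applied while seq_len <= original_max_position_embeddings
        "long_factor": [2.0, 3.5, 5.0],   # applied once the sequence exceeds that boundary
    },
}
```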
        transformers_model = AutoModelForCausalLM.from_pretrained(model_id)

        # Doublecheck that model has LongRoPE support
        original_max_pos = getattr(transformers_model.config, "original_max_position_embeddings", None)
        self.assertIsNotNone(
            original_max_pos,
            f"Model {model_id} does not have original_max_position_embeddings attribute required for LongRoPE",
        )

        set_seed(SEED)
        ov_model = OVModelForCausalLM.from_pretrained(
            model_id, export=True, ov_config=F32_CONFIG, device=OPENVINO_DEVICE
        )

        # Test 1: input tokens exceed original_max_pos
        # Creating model inputs with more than original max position embeddings and enough variation for varied output tokens

Collaborator: please test two cases when input_ids length exceeds threshold and when only max_new_tokens exceeds threshold

Author (Collaborator): Done
        tokens = torch.randint(high=transformers_model.config.vocab_size, size=(1, original_max_pos + 50))
        with torch.no_grad():
            transformers_outputs = transformers_model.generate(tokens, max_new_tokens=20)
            ov_outputs = ov_model.generate(tokens, max_new_tokens=20)

        self.assertTrue(
            torch.equal(transformers_outputs, ov_outputs),
            f"OpenVINO and PyTorch outputs do not match for LongRoPE test with inputs > original_max_pos.\n"
            f"ov_outputs: {ov_outputs}\ntransformers_outputs: {transformers_outputs}",
        )

        # Test 2: generation tokens exceed original_max_pos
        # Creating model inputs with slightly less than original max position embeddings
        tokens = torch.randint(high=transformers_model.config.vocab_size, size=(1, original_max_pos - 50))
        with torch.no_grad():
            transformers_outputs = transformers_model.generate(tokens, max_new_tokens=100)
            ov_outputs = ov_model.generate(tokens, max_new_tokens=100)

        self.assertTrue(
            torch.equal(transformers_outputs, ov_outputs),
            f"OpenVINO and PyTorch outputs do not match for LongRoPE test with cumulative context > max_pos.\n"
            f"ov_outputs: {ov_outputs}\ntransformers_outputs: {transformers_outputs}",
        )

        del transformers_model
        del ov_model
        gc.collect()
Collaborator: I don't see that short_factor is extracted and used anyhow in this patch. Please check the reference impl.: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L454C5-L454C18
We need to be aligned with HF.

Author (Collaborator): short_factor is not used explicitly because inv_freq is used directly. inv_freq is set here: https://github.com/huggingface/transformers/blob/v4.55.1/src/transformers/models/phi3/modeling_phi3.py#L313C1-L314C69 . It calls compute_longrope_parameters in https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py . During model export we will not use inputs exceeding original_max_position_embeddings, so inv_freq will be set based on the short factors.
This code with explicit short and long factors is needed for transformers because they use it for inference, but for model export inv_freq will be set with the short factor correctly.

Collaborator: I see that short_factor is used in the code
[screenshot of the transformers LongRoPE initialization code]
Now we are not aligned with this code and can expect additional bugs in the future.

Author (Collaborator): We are using this code during model export, as part of the model loading code. During model loading, compute_longrope_parameters is actually called with seq_len=None, so self.inv_freq will be set with short_factor (see your screenshot: if seq_len is None, ext_factors is set according to short_factor, which is defined earlier in the function from the model config, and inv_freq is then computed based on this short_factor). So, self.inv_freq is always set with the short factor, self.long_inv_freq with the long factor, and then in _phi3_longrope_forward inv_freq is set to self.long_inv_freq if the sequence is long and otherwise to self.inv_freq. This is aligned with transformers.

Collaborator: As I understand, it should not be executed during exporting of the model. This code should be executed at run-time for each input. Otherwise, it is strange that this model is exported for some concrete seq_len.

Author (Collaborator): The short_factor is used to compute inv_freq. We compute long_inv_freq in the patcher, but the short/default inv_freq is computed correctly during model initialization, so self.inv_freq will already be set correctly. And then during inference, we do "if seq len > max_pos: use long_inv_freq else: use default inv_freq".
I replaced self.inv_freq with self.original_inv_freq in forward:
[screenshot of the patched forward using original_inv_freq]
This is initialized here https://github.com/huggingface/transformers/blob/v4.55-release/src/transformers/models/phi3/modeling_phi3.py#L315 right after initializing self.inv_freq with the short factors. We never update self.inv_freq (we override forward()), so self.original_inv_freq == self.inv_freq and self.original_inv_freq is clearer.
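To make the exchange above concrete, here is a small sketch (not the PR code; the base, head dimension, and factor values are made up) of how the two inverse-frequency tables derive from the config factors and how the patched forward chooses between them:

```python
import torch


def inv_freq_from_factors(base: float, head_dim: int, factors: list) -> torch.Tensor:
    # Mirrors the structure of the transformers LongRoPE computation: per-dimension
    # extension factors scale the denominator of the usual RoPE inverse frequencies.
    ext = torch.tensor(factors, dtype=torch.float32)
    return 1.0 / (ext * base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))


base, head_dim, original_max = 10000.0, 64, 4096
short_inv_freq = inv_freq_from_factors(base, head_dim, [1.0] * 32)  # equals the unscaled RoPE frequencies
long_inv_freq = inv_freq_from_factors(base, head_dim, [4.0] * 32)   # placeholder long factors

# The patched forward keeps both tables in the exported graph and selects per call,
# instead of swapping self.inv_freq in Python the way @dynamic_rope_update does.
seq_len = torch.tensor(5000)
inv_freq = torch.where(seq_len > original_max, long_inv_freq, short_inv_freq)
```

With all short factors equal to 1.0, as in the tiny test model discussed earlier, short_inv_freq reduces to the standard RoPE frequencies; the long table only differs where the config's long_factor entries deviate from 1.0.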