
Conversation

@helena-intel (Collaborator)

This is #1297 updated to latest main branch.

Currently, inference on Phi-3-mini and Phi-4-mini returns bad outputs (random characters) when the context gets larger than about 2000 tokens. This PR, contributed by @eaidova, fixes that. This is not my code. The original PR is no longer being updated; I'm making this a new PR to make it easier to discuss and add updates.

I saw no negative impact on inference speed. I see slightly different outputs with shorter contexts on SPR (on inference with the model exported with the PR vs the model exported with main). Any suggestions to fix that would be much appreciated.

Draft PR for now, awaiting some feedback and testing, but I hope we can merge this soon.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@helena-intel helena-intel added the openvino-slow Runs OpenVINO slow tests with different versions of transformers label Oct 30, 2025
@nikita-savelyevv (Collaborator)

I see slightly different outputs with shorter contexts on SPR (on inference with the model exported with the PR vs the model exported with main).

I believe minor differences are expected on SPR. But if possible, WWB similarity should be run to see if the difference is significant or not.

@helena-intel helena-intel marked this pull request as ready for review October 31, 2025 10:24
logits_to_keep=None,
**kwargs,
):
# Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
Collaborator

Am I correct that we have a problem when there are short and long prompts in consecutive generate calls? We can't re-initialize inv_freqs from long_inv_freqs to short_inv_freqs and vice versa. How is this problem solved?

Collaborator Author

As discussed offline: this is handled in optimum-intel by resetting the kv-cache when the number of input tokens is equal to the long rope boundary (e.g. 4096). This is done the same way in transformers code. Tested that this works as expected in chat context with https://gist.github.com/helena-intel/b55522cda91d9d61a644f153e71f0f98 .
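The reset rule described here can be sketched as follows (a minimal sketch with a hypothetical helper name, not the exact optimum-intel code; it mirrors the transformers Phi-3 logic of dropping the cache once the total token count crosses the long-rope boundary):

```python
def maybe_reset_kv_cache(original_max_position_embeddings, past_length, total_length, past_key_values):
    # If the cache was filled in the short-rope regime but the total token
    # count now crosses the long-rope boundary, drop the cache so the next
    # step re-prefills all tokens with the long-rope inv_freqs.
    if (
        past_key_values is not None
        and total_length > original_max_position_embeddings
        and past_length <= original_max_position_embeddings
    ):
        return None  # forces a full prefill on the next forward pass
    return past_key_values
```

Dropping the cache trades one extra prefill at the boundary for correct positional scaling on every subsequent token.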

@echarlaix (Collaborator) left a comment

Thanks a lot @helena-intel !!



class OVPhi3ForCausalLM(OVModelForCausalLM):
def prepare_inputs_for_generation(
Collaborator Author

Done

Comment on lines 1593 to 1648
super().__enter__()
# Call OVDecoderModelPatcher.__enter__() directly to skip Phi3ModelPatcher's longrope logic
# PhiMoE has a different rotary embedding structure, longrope is not yet supported
Collaborator

Why do we need to add all these modifications to PhiMoEModelPatcher? If longrope is not yet supported, then self._model.model.rotary_emb will never be set to "longrope". If we want to make sure, we can raise an error in case that ever happens.

Collaborator Author

Initially tests failed for phi_moe, see https://github.com/huggingface/optimum-intel/actions/runs/18952102871/job/54119192964 . We should have longrope support for the MoE model too but not in this PR. I would be happy with a simpler solution to not enable longrope for the MoE model (but still have it working as it is now).

Collaborator Author

I will fix this in a better way.

Collaborator Author

I added a _disable_longrope property instead of the previous code.
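A minimal sketch of that gating pattern (hypothetical class names and bodies for illustration only; the real patchers live in optimum-intel's exporter patching code):

```python
class Phi3ModelPatcherSketch:
    # Subclasses opt out of the longrope patch via this class attribute.
    _disable_longrope = False

    def __enter__(self):
        self.longrope_applied = False
        if not self._disable_longrope:
            # placeholder for the real longrope patching logic
            self.longrope_applied = True
        return self

    def __exit__(self, *exc):
        return False


class PhiMoEModelPatcherSketch(Phi3ModelPatcherSketch):
    # PhiMoE has a different rotary embedding structure, so longrope is
    # skipped for now while the rest of the patching still applies.
    _disable_longrope = True
```

This keeps the MoE patcher inheriting everything else from the Phi3 patcher while cleanly excluding the one unsupported feature.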

return torch.where(seq_len <= max_pos_embeddings, short_factor, long_factor)


def long_rope(self, x, position_ids, seq_len=None):
Collaborator Author

Done

scaling_factor = 1.0
else:
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
cos = emb.cos() * scaling_factor
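The factor in the quoted snippet is the longrope attention scaling used in transformers; a quick standalone check, assuming scale = max_position_embeddings / original_max_position_embeddings:

```python
import math


def longrope_attention_scaling(max_position_embeddings, original_max_position_embeddings):
    # Matches the quoted formula: cos/sin are magnified by
    # sqrt(1 + ln(scale) / ln(original_max)) once the context window is
    # extended beyond the original training length; otherwise the factor is 1.
    scale = max_position_embeddings / original_max_position_embeddings
    if scale <= 1.0:
        return 1.0
    return math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
```

For a Phi-3-mini-128k-style config (131072 extended from 4096) this gives roughly 1.19.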
Collaborator

Yes, it is a good point. @helena-intel, please use it.

Collaborator Author

Done

Comment on lines +1519 to +1522
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
Collaborator

device_type is not used. Also, should we ensure fp32 dtype?

Collaborator Author

Thanks! Added the autocast line with enabled=False.
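For context, a minimal sketch of that precision guard (hypothetical function name; it mirrors the pattern in the transformers rotary-embedding forward referenced above):

```python
import torch


def rotary_cos_sin_fp32(inv_freq, position_ids, device_type="cpu"):
    # Compute rotary cos/sin with autocast disabled so the matmul stays in
    # float32 even when the model runs under bfloat16 autocast; bfloat16
    # loses precision on long contexts (see transformers PR #29285).
    inv_freq_expanded = inv_freq[None, :, None].float()       # (1, dim, 1)
    position_ids_expanded = position_ids[:, None, :].float()  # (batch, 1, seq)
    with torch.autocast(device_type=device_type, enabled=False):
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)               # (batch, seq, 2*dim)
        return emb.cos(), emb.sin()
```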

return torch.where(seq_len <= max_pos_embeddings, short_factor, long_factor)


def long_rope(self, x, position_ids, seq_len=None):
@rkazants (Collaborator) Nov 14, 2025

@helena-intel, you actually patch this function: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L442, but I don't see short_factor from the model config being used in the patch. Please clarify.

Collaborator

@helena-intel, I think we need to rewrite this patch more carefully so it is aligned with https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L442 for longrope.

Collaborator Author

short_factor is in the select_ext_factor function: return torch.where(seq_len <= max_pos_embeddings, short_factor, long_factor)

I agree it would be clearer to rewrite, but it is functionally working now. We see the same outputs as transformers for both short and long context.
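For reference, a traceable toy version of that selection (tensor inputs assumed; a Python `if` would be frozen at export time, whereas torch.where keeps both branches in the exported graph so the OpenVINO model can switch at runtime):

```python
import torch


def select_ext_factor(seq_len, max_pos_embeddings, short_factor, long_factor):
    # Select the short or long rescale factors based on the current sequence
    # length; both tensors remain in the graph and the choice happens per call.
    return torch.where(seq_len <= max_pos_embeddings, short_factor, long_factor)


short = torch.tensor([1.0, 1.0, 1.0])
long = torch.tensor([2.0, 4.0, 8.0])
select_ext_factor(torch.tensor(1000), 4096, short, long)  # short regime
select_ext_factor(torch.tensor(5000), 4096, short, long)  # long regime
```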

Collaborator Author

@rkazants I refactored the function and added more comments. I think it is clearer now, please review.

):
past_length = cache_position[0]
if past_length <= self.config.original_max_position_embeddings:
past_key_values = None
Collaborator

Please add a link to https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py#L522 and a comment that this is aligned with phi3 for long context. Also add a comment that we reset the KV cache, which means the next step will be a prefill over all tokens computed so far.

Collaborator Author

Added the link a few lines above. The comment that was there was copied verbatim from the transformers code. I modified the second line a bit to make it clearer (transformers comment references "current failure" but it is not clear what that is).

@rkazants (Collaborator)

@helena-intel, we also need to create a tiny phi3 model with small values original_max_embedding < max_embedding, for example 10 and 20. This is how we test the KV cache reset and the application of the new scaling factors. You can easily add this tiny model to the existing tests in test_decoder, and it will automatically cover this scenario.

@helena-intel (Collaborator Author)

@helena-intel, we also need to create a tiny phi3 model with small values original_max_embedding < max_embedding, for example 10 and 20. This is how we test the KV cache reset and the application of the new scaling factors. You can easily add this tiny model to the existing tests in test_decoder, and it will automatically cover this scenario.

Yes, I added a model yesterday (https://huggingface.co/optimum-intel-internal-testing/tiny-random-phi-4-mini-instruct) and just added a test that fails on the main branch and passes with this PR.

f"values are not close for {dtype if dtype is not None else 'None'}, max diff = {torch.abs(ov_logits - ref_logits).max()}",
)

def test_phi3_longrope_support(self):
@rkazants (Collaborator) Nov 14, 2025

There is no need for this new test. Just add your model to SUPPORTED_ARCHITECTURES above; all required testing will be activated. You also need to add the model id to utils_tests.py.

@rkazants commented Nov 14, 2025

Yes, I added a model yesterday (https://huggingface.co/optimum-intel-internal-testing/tiny-random-phi-4-mini-instruct) and just added a test that fails on the main branch and passes with this PR.

Why phi4? Also, please initialize original_max_position_embeddings to the small value I suggested. Just name it as tiny-random-phi3, and make sure that rope_type is longrope.

@helena-intel commented Nov 17, 2025

Why phi4?

I wanted to use a model that is being used by people who reported this issue, and I figured it would be useful to have a phi-4 tiny model too. I can change it if needed.

Just name it as tiny-random-phi3

So it should replace the existing model? https://github.com/huggingface/optimum-intel/blob/main/tests/openvino/utils_tests.py#L152

There is no need for this new test. Just add your model to SUPPORTED_ARCHITECTURES above; all required testing will be activated.

I think it's useful to test both short and long context because it is also relevant to know if short context starts failing. And long context should be tested with prompts above the threshold value so if we rely on existing tests we should always remember that the generic model input needs to exceed the long context threshold. If someone changes the existing "same output as transformers" test, or the tiny model, the test may miss issues.

Also, please initialize original_max_position_embeddings to the small value I suggested. Just name it as tiny-random-phi3, and make sure that rope_type is longrope.

I will look into that. Values probably need to be a bit higher, but can be lower than default. We can't just set the values to 10 and 20, the model is sensitive to parameters and it's easy to get collapsing outputs or differences between PyTorch and OpenVINO.

- Explicitly disable torch.autocast to ensure float32 precision
- Add sources for adapted code
- Use self.attention_scaling instead of manual computation
- Save and restore original _orig_max_position_embeddings
- Modify F32_CONFIG to use EXECUTION_MODE_HINT
- Exclude longrope for phi3-moe with _disable_longrope
- Add more comments
- Remove superfluous select_ext_factor function
- Rename long_rope to _phi3_longrope_forward for clarity
)

# Creating model inputs with more than original max position embeddings and enough variation for varied output tokens
tokens = torch.as_tensor(list(tokenizer.get_vocab().values())[: original_max_pos + 50]).unsqueeze(0)
Member

The tokenizer is not really needed here; you can use torch.randint with model.config.vocab_size.
Also, shouldn't we test starting with less than max position embeddings and generating enough to surpass it (to trigger cache re-computation)?

Collaborator Author

Changed the test to use randint, and now we test both scenarios: where the input tokens exceed original_max_pos, and where the generated tokens exceed it.
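A sketch of the two input shapes (toy numbers for illustration; the real test derives original_max_pos and the vocabulary size from the tiny model's config):

```python
import torch

original_max_pos, vocab_size = 64, 1000  # toy values, not the real config

# Case 1: the prompt itself already exceeds the long-rope boundary,
# so the very first prefill uses the long factors.
long_prompt = torch.randint(0, vocab_size, (1, original_max_pos + 50))

# Case 2: a short prompt whose generation crosses the boundary,
# exercising the KV-cache reset mid-decode.
short_prompt = torch.randint(0, vocab_size, (1, original_max_pos - 10))
max_new_tokens = 20  # 54 prompt tokens + 20 generated > 64
```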

Comment on lines 30 to 31
# With this config, inference runs in f32 and optimizations that may influence accuracy are disabled
F32_CONFIG = {"EXECUTION_MODE_HINT": "ACCURACY"}
Member

Do we want to change it for all models?

Collaborator

I think we should not change it globally. Let's have it only for phi3 models.

Collaborator Author

Reverted

model_id, export=True, ov_config=F32_CONFIG, device=OPENVINO_DEVICE
)

# Creating model inputs with more than original max position embeddings and enough variation for varied output tokens
Collaborator

Please test two cases: when the input_ids length exceeds the threshold, and when only max_new_tokens pushes generation past it.

Collaborator Author

Done

def test_phi3_longrope_support(self):
"""Test LongRoPE support for Phi3 with inputs > 4096 tokens."""
set_seed(SEED)
model_id = "optimum-intel-internal-testing/tiny-random-phi-4-mini-instruct"
@rkazants (Collaborator) Nov 20, 2025

Please change the model card id. Having phi-4 in the name is quite confusing, since this is not phi-4.

Collaborator Author

Done

- rename tiny model to phi3
- add test for cumulative context
- revert F32_CONFIG change
- Set MIN_TRANSFORMERS_VERSION to 4.49 for Phi3
- Remove code specific for transformers<4.49
- Disable trust-remote-code for Phi3
@helena-intel helena-intel removed the openvino-slow Runs OpenVINO slow tests with different versions of transformers label Nov 21, 2025
@IlyasMoutawwakil (Member) left a comment

LGTM ! Thanks for the awesome fix !

@rkazants (Collaborator) left a comment

@IlyasMoutawwakil, please let me review before merge. Thanks!

7 participants