Conversation


@ggerganov ggerganov commented Dec 11, 2025

cont #17644

  • Fix the adjustment of the RoPE attention factor based on:

https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348

  • Apply the adjustment for all models that use yarn_ext_factor != 0.0f, instead of on a case-by-case basis
  • Extract the logic in a common function: llama_hparams::yarn_attn_factor_adjust() and reuse it both in llama_graph and llama_kv_cache::build_rope_shift() for consistency (a sketch of the adjustment follows below)
  • Normalize the meaning of the hparams.rope_yarn_log_mul parameter (a.k.a. mscale_all_dims)
  • Keep the pre-scaling of hparams.rope_yarn_log_mul with 0.1 in the convert script for Deepseek v2 for backwards compatibility. Undo it during model load. (see [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX])
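
For reference, here is a minimal C++ sketch of the adjustment as computed in the linked modeling_rope_utils.py; the free-function names are illustrative and not the exact llama.cpp API (the PR wraps this logic in llama_hparams::yarn_attn_factor_adjust()):

```cpp
#include <cmath>

// YaRN magnitude scaling as in the linked transformers code:
// get_mscale(scale, mscale) = 0.1 * mscale * log(scale) + 1 for scale > 1, else 1.
static float yarn_get_mscale(float scale, float mscale) {
    if (scale <= 1.0f) {
        return 1.0f;
    }
    return 0.1f * mscale * logf(scale) + 1.0f;
}

// attention_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)
static float yarn_attn_factor(float factor, float mscale, float mscale_all_dim) {
    return yarn_get_mscale(factor, mscale) / yarn_get_mscale(factor, mscale_all_dim);
}
```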

@ggerganov ggerganov force-pushed the gg/mistral-fix-attn-factor branch from 0ca55b6 to 45930c9 Compare December 12, 2025 09:49
@ggerganov ggerganov marked this pull request as ready for review December 12, 2025 09:56
@ggerganov ggerganov requested a review from CISC as a code owner December 12, 2025 09:56
@ggerganov ggerganov requested a review from ngxson December 12, 2025 09:56
yarn_params = self.hparams["yarn"]
self.gguf_writer.add_attn_temperature_length(yarn_params["original_max_position_embeddings"])
- self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) # mscale_all_dim * 0.1
+ self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
Member Author

I believe the old version of this was incorrect. AFAICT this affects Mistral 3 Large models. Since it's difficult to test these and given that they are recent and likely haven't proliferated much, I think it's best to accept this breaking change.

Collaborator

I think the old version should be correct because I copied from this LOC of DeepseekV2:

self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])

Also note that the current LOC is used by MistralMoeModel, which is based on DeepseekV2Model; it is different from the normal MistralModel, which is based on LlamaModel.

Collaborator

Otherwise, it's probably better to use another metadata key in mistral3; I would suggest {arch}.rope.scaling.yarn_mscale_all_dims.

It can default to 1.0 to avoid breaking changes for existing GGUFs.

Member Author

Correct, I missed that Mistral 3 Large implements the DS2 arch. Reverted the change and added a comment.

@github-actions github-actions bot added the python python script changes label Dec 12, 2025
@ggerganov ggerganov changed the title from "models : fix the attn_factor for mistral3 graphs" to "models : fix the attn_factor for mistral3 graphs + improve consistency" Dec 12, 2025
// special-case DEEPSEEK v2:
// https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
if (arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
    mscale = mscale_all_dims;
Collaborator

For clarification, I think for deepseek we can do mscale_all_dims = mscale_all_dims * 10.0f, because it was scaled by 0.1 before being written into the GGUF.

Member Author

Already done above on line 1642
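
For context, a rough sketch of what that load-time compensation presumably looks like near the [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] tag; this is an illustration based on the PR description, not a verbatim copy of the code:

```cpp
// The convert script keeps writing 0.1 * mscale_all_dim for DeepSeek v2 GGUFs
// (backwards compatibility), so the loader multiplies the value back by 10
// to recover the original mscale_all_dim.
if (arch == LLM_ARCH_DEEPSEEK2) {
    hparams.rope_yarn_log_mul *= 10.0f; // undo the 0.1 pre-scaling from conversion
}
```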

return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
};

hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
Collaborator

We're calculating 0.1f * mscale * logf(scale) + 1.0f in 4 places in the code base:

  • Here
  • yarn_attn_factor_adjust
  • inside model graph build function
  • inside the rope kernel

I think there is a way to simplify it:

  • leave the yarn_attn_factor as-is (because it's also used by LLM_ARCH_GROK, so probably better not to adjust it?)
  • leave the yarn_attn_factor_adjust to adjust everything
  • remove the adjustment from deepseek2 cgraph?

Collaborator

In short, I think this LOC: hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims) may be redundant because:

  • The numerator get_mscale(factor, mscale) is already handled inside the rope kernel:
    mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
  • The denominator get_mscale(factor, mscale_all_dims) will be handled by yarn_attn_factor_adjust()

Member Author

The numerator get_mscale(factor, mscale) is already handled inside the rope kernel:

This numerator assumes mscale == 1.0f. I am not sure why it was added, but effectively we have to cancel it every time if we want to support mscale != 1.0f
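
Concretely, the kernel's built-in factor 1.0f + 0.1f * logf(1.0f / freq_scale) equals get_mscale(scale, 1.0f) with scale = 1.0f / freq_scale, i.e. the kernel hard-codes mscale == 1.0f. A minimal check of that identity, using get_mscale as quoted from the diff above (the freq_scale value is just an example):

```cpp
#include <cassert>
#include <cmath>

// get_mscale as quoted from the diff above
static float get_mscale(float scale, float mscale) {
    return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
}

int main() {
    const float freq_scale = 0.25f; // e.g. a YaRN scaling factor of 4
    const float kernel_factor = 1.0f + 0.1f * logf(1.0f / freq_scale); // what the rope kernel applies
    // identical to get_mscale with mscale fixed to 1.0f
    assert(fabsf(kernel_factor - get_mscale(1.0f / freq_scale, 1.0f)) < 1e-6f);
    return 0;
}
```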

Member Author

because it's also used by LLM_ARCH_GROK, so probably better not to adjust it?

The hparams.yarn_attn_factor = get_mscale(... branch is only entered when hparams.rope_yarn_log_mul != 0.0f, so it should not trigger for GROK.
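
A rough sketch of the guard being described, using the hparams fields mentioned in this thread (illustrative, not a verbatim copy of the PR):

```cpp
// Only models that actually set a YaRN log-mul (a.k.a. mscale_all_dims) take
// this branch; GROK leaves rope_yarn_log_mul at 0.0f and is unaffected.
if (hparams.rope_yarn_log_mul != 0.0f) {
    hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
}
```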

Collaborator

ah yeah, right. that sounds good then

Member Author

@ggerganov ggerganov left a comment

The logic is still quite messy, but at least I hope we are handling YaRN correctly now in all cases. @ngxson Thanks for the help and investigation of this issue.

Tested that logprobs with Ministral 3 and Devstral 2 are improved and DS2 are the same as before. Not sure which other models would be affected. Plan to merge this later today.

@ggerganov ggerganov merged commit 7bed317 into master Dec 12, 2025
1 check passed
@ggerganov ggerganov deleted the gg/mistral-fix-attn-factor branch December 12, 2025 15:12
@Nindaleth (Contributor)

Is this supposed to affect Devstral Small 2 too?

Getting improved generation now that this PR was merged; just wondering whether I'm just hitting a statistical anomaly or if this PR helped.


Labels

model (Model specific), python (python script changes)