models : fix the attn_factor for mistral3 graphs + improve consistency #17945

Conversation
convert_hf_to_gguf.py (outdated diff):
        yarn_params = self.hparams["yarn"]
        self.gguf_writer.add_attn_temperature_length(yarn_params["original_max_position_embeddings"])
    -   self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1)  # mscale_all_dim * 0.1
    +   self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0)  # mscale_all_dim
I believe the old version of this was incorrect. AFAICT this affects Mistral 3 Large models. Since it's difficult to test these and given that they are recent and likely haven't proliferated much, I think it's best to accept this breaking change.
I think the old version should be correct, because I copied it from this LOC of DeepseekV2:

llama.cpp/convert_hf_to_gguf.py, line 7289 in a81a569:

    self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])

Also note that the current LOC is used by MistralMoeModel, which is based on DeepseekV2Model; it is different from the normal MistralModel, which is based on LlamaModel.
Otherwise, it's probably better to use another metadata key in mistral3; I would suggest {arch}.rope.scaling.yarn_mscale_all_dims.
It can default to 1.0 to avoid breaking changes for existing GGUFs.
Correct, I missed that Mistral 3 Large implements the DS2 arch. Reverted the change and added a comment.
    // special-case DEEPSEEK v2:
    // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
    if (arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
        mscale = mscale_all_dims;
For clarification, I think for deepseek we can do mscale_all_dims = mscale_all_dims * 10.0f, because it was scaled by 0.1 before being written into the GGUF.
Already done above on line 1642
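To make the convention concrete, here is a minimal standalone sketch (not the actual llama.cpp loader code; the example value is made up): since the convert script stores mscale_all_dim scaled by 0.1 under rope_yarn_log_mul for DeepSeek v2, the loader recovers the original value by multiplying back by 10.

```cpp
// Standalone illustration (not the actual llama.cpp loader code) of undoing the
// 0.1 scaling that the convert script applies to mscale_all_dim for DeepSeek v2.
#include <cstdio>

int main() {
    // value read from an existing DeepSeek v2 GGUF: 0.1 * mscale_all_dim
    // (0.0707f is a made-up example, not taken from a real model config)
    const float rope_yarn_log_mul = 0.0707f;

    // undo the 0.1 factor applied at conversion time
    const float mscale_all_dims = rope_yarn_log_mul * 10.0f;

    printf("rope_yarn_log_mul = %.4f -> mscale_all_dims = %.4f\n",
           rope_yarn_log_mul, mscale_all_dims);
    return 0;
}
```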
        return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
    };

    hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
We're calculating 0.1f * mscale * logf(scale) + 1.0f in 4 places in the code base:
- here
- yarn_attn_factor_adjust
- inside the model graph build function
- inside the rope kernel

I think there is a way to simplify it:
- leave yarn_attn_factor as-is (because it's also used by LLM_ARCH_GROK, so probably better not to adjust it?)
- leave yarn_attn_factor_adjust to adjust everything
- remove the adjustment from the deepseek2 cgraph?
In short, I think this LOC: hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims) can be redundant, because:
- The numerator get_mscale(factor, mscale) is already handled inside the rope kernel (llama.cpp/ggml/src/ggml-cuda/rope.cu, line 34 in c33a58b):
      mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
- The denominator get_mscale(factor, mscale_all_dims) will be handled by yarn_attn_factor_adjust()
> The numerator get_mscale(factor, mscale) is already handled inside the rope kernel

This numerator assumes mscale == 1.0f. I am not sure why it was added, but effectively we have to cancel it whenever we want to support mscale != 1.0f.
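To make the cancellation concrete, here is a small numerical sketch (not the actual llama.cpp implementation; the factor and mscale values are made up): the rope kernel always multiplies by get_mscale(factor, 1.0f), so to end up with the desired get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims) ratio, the externally supplied attention factor has to divide that built-in term back out.

```cpp
// Numerical sketch of the cancellation discussed above (illustrative only).
#include <cmath>
#include <cstdio>

// YaRN magnitude scale: 0.1 * mscale * ln(scale) + 1.0 for scale > 1
static float get_mscale(float scale, float mscale) {
    return scale <= 1.0f ? 1.0f : 0.1f * mscale * logf(scale) + 1.0f;
}

int main() {
    const float factor          = 40.0f;  // hypothetical YaRN scaling factor
    const float mscale          = 0.707f; // hypothetical, not from a real config
    const float mscale_all_dims = 0.707f;

    // desired overall attention scale for DeepSeek-style YaRN
    const float target = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);

    // the rope kernel bakes in get_mscale(factor, 1.0f), i.e. it assumes mscale == 1
    const float kernel_builtin = get_mscale(factor, 1.0f);

    // so the externally supplied factor has to cancel the kernel's assumption
    const float attn_factor = target / kernel_builtin;

    printf("target = %f, kernel builtin = %f, supplied attn_factor = %f\n",
           target, kernel_builtin, attn_factor);
    printf("check: supplied * builtin = %f (should equal target)\n",
           attn_factor * kernel_builtin);
    return 0;
}
```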
> because it's also used by LLM_ARCH_GROK, so probably better not to adjust it?

The hparams.yarn_attn_factor = get_mscale(... branch is only entered when hparams.rope_yarn_log_mul != 0.0f, so it should not trigger for GROK.
ah yeah, right. that sounds good then
ggerganov left a comment:
The logic is still quite messy, but at least I hope we are handling YaRN correctly now in all cases. @ngxson Thanks for the help and investigation of this issue.
Tested that logprobs with Ministral 3 and Devstral 2 are improved and DS2 are the same as before. Not sure which other models would be affected. Plan to merge this later today.
Is this supposed to affect Devstral Small 2 too? Getting improved generation now that this PR was merged; just wondering whether I'm just hitting a statistical anomaly or if this PR helped.
cont #17644
https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
- Apply the YaRN attention factor whenever yarn_ext_factor != 0.0f, instead of on a case-by-case basis
- Add llama_hparams::yarn_attn_factor_adjust() and reuse it both in llama_graph and llama_kv_cache::build_rope_shift() for consistency (see the sketch after this list)
- … the hparams.rope_yarn_log_mul parameter (a.k.a. mscale_all_dims)
- Keep scaling hparams.rope_yarn_log_mul with 0.1 in the convert script for Deepseek v2 for backwards compatibility. Negate it during model load. (see [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX])
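For illustration, a hypothetical sketch of the consolidation described in the second bullet: one shared helper that performs the YaRN attention-factor adjustment, called from both the graph build and the KV-cache rope shift instead of duplicating the formula. The member name follows the PR description, but the struct layout, signature, and exact adjustment here are assumptions, not the actual llama.cpp code.

```cpp
// Hypothetical sketch of a shared YaRN attention-factor adjustment helper
// (illustrative only; not the real llama_hparams layout or API).
#include <cmath>

struct hparams_sketch {
    // assumed to already hold the raw mscale_all_dims value (any conversion-time
    // 0.1 scaling undone at model load, as discussed earlier in the thread)
    float mscale_all_dims = 1.0f;

    // shared adjustment, meant to be called from both the graph build code and
    // the KV-cache rope-shift code so the two paths stay consistent
    float yarn_attn_factor_adjust(float attn_factor, float freq_scale) const {
        // nothing to adjust, or no YaRN interpolation is active
        if (mscale_all_dims == 1.0f || freq_scale >= 1.0f) {
            return attn_factor;
        }
        // divide out get_mscale(1/freq_scale, mscale_all_dims) so the all-dims
        // magnitude scale is not applied a second time along the rope-kernel path
        return attn_factor / (0.1f * mscale_all_dims * logf(1.0f / freq_scale) + 1.0f);
    }
};

int main() {
    hparams_sketch hp;
    hp.mscale_all_dims = 0.707f;           // made-up example value
    const float freq_scale = 1.0f / 40.0f; // made-up YaRN interpolation scale

    // both the graph build and the rope-shift path would funnel through the same call
    const float adjusted = hp.yarn_attn_factor_adjust(1.0f, freq_scale);
    (void) adjusted;
    return 0;
}
```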