models : fix the attn_factor for mistral3 graphs + improve consistency #17945
Changes from all commits: 1df2e90, 59b9e36, 45930c9, 45875df, 06eb8e8, 01b77b5, 7320a2d, d6477e1
```diff
@@ -1635,7 +1635,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             // that have no expert_gating_func model parameter set
             hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
         }
-        ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
+        if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) {
+            // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+            // cancel the factor from the convert script
+            hparams.rope_yarn_log_mul /= 0.1f;
+        }
 
        // (optional) temperature tuning - used by mistral-large
        ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
```
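A minimal standalone sketch of the round trip this new branch compensates for, assuming (as the code comment and the review thread at the end describe) that the convert script stores the YaRN log-multiplier already scaled by 0.1; the variable names and the 0.707 value are illustrative only, not taken from this PR:

```cpp
// Illustrative sketch (not llama.cpp code): the GGUF stores 0.1 * mscale_all_dim,
// and the loader divides the 0.1 back out to recover the raw multiplier.
#include <cassert>
#include <cmath>

int main() {
    const float mscale_all_dim    = 0.707f;                // hypothetical value from a model's config.json
    const float gguf_log_mul      = 0.1f * mscale_all_dim; // what the convert script writes into the GGUF
    const float rope_yarn_log_mul = gguf_log_mul / 0.1f;   // what load_hparams() now recovers

    // the loader ends up with the raw mscale_all_dim again (up to float rounding)
    assert(std::fabs(rope_yarn_log_mul - mscale_all_dim) < 1e-6f);
    return 0;
}
```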
```diff
@@ -2267,9 +2272,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
             ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
 
-            ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
-            ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
-            ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
+            ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
+            ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
+            ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f);
 
            // TODO: maybe add n_attn_temp_floor_scale as a separate KV?
            if (hparams.f_attn_temp_scale != 0.0f) {
```
```diff
@@ -2279,18 +2284,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                }
            }
 
-           // TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f
-           // but may need further verification with other values
-           if (hparams.rope_yarn_log_mul != 0.0f) {
-               float factor = 1.0f / hparams.rope_freq_scale_train;
-               float mscale = 1.0f;
-               float mscale_all_dims = hparams.rope_yarn_log_mul;
-               static auto get_mscale = [](float scale, float mscale) {
-                   return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
-               };
-               hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
-           }
-
            switch (hparams.n_layer) {
                case 26: type = LLM_TYPE_3B; break;
                case 34: type = LLM_TYPE_8B; break;
```
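For reference, the quantity removed here (and re-added below, after the architecture switch) is the YaRN attention-scaling ratio from the linked transformers reference, written out from the code above:

$$
m(s, x) = \begin{cases} 1 & s \le 1 \\ 0.1\,x\,\ln s + 1 & s > 1 \end{cases},
\qquad
\text{yarn\_attn\_factor} = \frac{m(f, \text{mscale})}{m(f, \text{mscale\_all\_dim})},
\qquad
f = \frac{1}{\text{rope\_freq\_scale\_train}}
$$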
```diff
@@ -2301,6 +2294,32 @@ void llama_model::load_hparams(llama_model_loader & ml) {
            default: throw std::runtime_error("unsupported model architecture");
    }
 
+   // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
+   if (hparams.rope_yarn_log_mul != 0.0f) {
+       const float factor = 1.0f / hparams.rope_freq_scale_train;
+
+       // note: here we assume `mscale == 1.0f`
+       // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
+       float mscale = 1.0f;
+       const float mscale_all_dims = hparams.rope_yarn_log_mul;
+
+       // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+       // special-case DEEPSEEK v2:
+       // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
+       if (arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
+           mscale = mscale_all_dims;
+       }
+
+       static auto get_mscale = [](float scale, float mscale) {
+           return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
+       };
+
+       hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
```
**Collaborator:** We're calculating

I think there is a way to simplify it:

**Collaborator:** In short, I think this LOC:

**Member (Author):** This numerator assumes

**Member (Author):** The

**Collaborator:** ah yeah, right. that sounds good then
```diff
+
+       LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
+               __func__, hparams.yarn_attn_factor, mscale, mscale_all_dims);
+   }
+
    pimpl->n_bytes = ml.n_bytes;
 
    pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name();
```
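To see what the DEEPSEEK2 special case does numerically, here is a small standalone sketch of the same ratio. The 40x factor and the 0.707 multiplier mirror the values in the linked DeepSeek-V2-Lite config and are used only for illustration; they are not hard-coded anywhere in this PR:

```cpp
// Illustrative sketch (not llama.cpp code) of the yarn_attn_factor computation above.
#include <cmath>
#include <cstdio>

// same formula as the get_mscale lambda in the diff
static float get_mscale(float scale, float mscale) {
    return scale <= 1.0f ? 1.0f : (0.1f * mscale * std::log(scale) + 1.0f);
}

int main() {
    const float factor          = 40.0f;  // 1.0f / rope_freq_scale_train (illustrative)
    const float mscale_all_dims = 0.707f; // recovered rope_yarn_log_mul  (illustrative)

    // generic path: mscale is assumed to be 1.0f
    const float attn_generic  = get_mscale(factor, 1.0f) / get_mscale(factor, mscale_all_dims);

    // DEEPSEEK2 special case: mscale is set to mscale_all_dims, so the ratio collapses to 1.0
    const float attn_deepseek = get_mscale(factor, mscale_all_dims) / get_mscale(factor, mscale_all_dims);

    printf("generic: %.4f  deepseek2: %.4f\n", attn_generic, attn_deepseek); // deepseek2 prints 1.0000
    return 0;
}
```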
```diff
@@ -6806,6 +6825,7 @@ void llama_model::print_info() const {
    LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
    LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
    LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
+   LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n", __func__, hparams.rope_yarn_log_mul);
    LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
    // MRoPE (Multi-axis Rotary Position Embedding) sections
    if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) {
```
```diff
@@ -6869,7 +6889,6 @@ void llama_model::print_info() const {
        LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
        LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
        LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
-       LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
    }
 
    if (arch == LLM_ARCH_QWEN2MOE) {
```
For clarification, I think for deepseek we can do `mscale_all_dims = mscale_all_dims * 10.0f`, because it was scaled by 0.1 before being written into the GGUF.

Already done above on line 1642.