From 1df2e90860ff2627980e46f59d470bfb6e5b98d8 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 11 Dec 2025 21:41:20 +0200
Subject: [PATCH 1/8] models : fix the attn_factor for mistral3 graphs

---
 src/llama-model.cpp     | 12 ------------
 src/models/mistral3.cpp |  2 ++
 2 files changed, 2 insertions(+), 12 deletions(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index fc337b045eb..aa51e902e6c 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2279,18 +2279,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     }
                 }
 
-                // TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f
-                //       but may need further verification with other values
-                if (hparams.rope_yarn_log_mul != 0.0f) {
-                    float factor = 1.0f / hparams.rope_freq_scale_train;
-                    float mscale = 1.0f;
-                    float mscale_all_dims = hparams.rope_yarn_log_mul;
-                    static auto get_mscale = [](float scale, float mscale) {
-                        return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
-                    };
-                    hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
-                }
-
                 switch (hparams.n_layer) {
                     case 26: type = LLM_TYPE_3B; break;
                     case 34: type = LLM_TYPE_8B; break;
diff --git a/src/models/mistral3.cpp b/src/models/mistral3.cpp
index 0b672235911..b62acd19167 100644
--- a/src/models/mistral3.cpp
+++ b/src/models/mistral3.cpp
@@ -3,6 +3,8 @@
 llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
+    const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
+
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 

From 59b9e36f872df37454c1a01a5dd4ea69ede65832 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 12 Dec 2025 11:37:29 +0200
Subject: [PATCH 2/8] cont : rework attn_factor correction logic

---
 convert_hf_to_gguf.py    |  4 ++--
 src/llama-graph.cpp      |  2 +-
 src/llama-hparams.cpp    |  8 ++++++++
 src/llama-hparams.h      |  9 ++++++++-
 src/llama-kv-cache.cpp   | 13 ++++---------
 src/llama-model.cpp      | 27 ++++++++++++++++++++++-----
 src/models/deepseek2.cpp |  5 ++---
 src/models/mistral3.cpp  |  2 --
 8 files changed, 47 insertions(+), 23 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 867bc90531c..3a532c2e0a5 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -7286,7 +7286,7 @@ def set_gguf_parameters(self):
             self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
             self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
             self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
-            self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])
+            self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_scaling["mscale_all_dim"])
 
     _experts: list[dict[str, Tensor]] | None = None
 
@@ -10041,7 +10041,7 @@ def set_gguf_parameters(self):
         MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
         yarn_params = self.hparams["yarn"]
         self.gguf_writer.add_attn_temperature_length(yarn_params["original_max_position_embeddings"])
-        self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) # mscale_all_dim * 0.1
+        self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
         if name.startswith("vision_") or name.startswith("patch_merger.") or "mm_projector" in name:
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 6cf9a883a6e..a1a32494b75 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -574,7 +574,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     freq_base        (cparams.rope_freq_base),
     freq_scale       (cparams.rope_freq_scale),
     ext_factor       (cparams.yarn_ext_factor),
-    attn_factor      (cparams.yarn_attn_factor),
+    attn_factor      (llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor)),
     beta_fast        (cparams.yarn_beta_fast),
     beta_slow        (cparams.yarn_beta_slow),
     norm_eps         (hparams.f_norm_eps),
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 8cdbaf69fc0..7aa517bfe9b 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -229,3 +229,11 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama
 
     return false;
 }
+
+float llama_hparams::yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
+    if (ext_factor != 0.0f) {
+        attn_factor *= 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
+    }
+
+    return attn_factor;
+}
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 6eff334a5fd..c9960e91697 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -107,6 +107,7 @@ struct llama_hparams {
     float rope_freq_base_train_swa;
     float rope_freq_scale_train;
     float rope_freq_scale_train_swa;
+
     uint32_t n_ctx_orig_yarn;
     float rope_yarn_log_mul = 0.0f;
 
@@ -267,7 +268,13 @@ struct llama_hparams {
     // TODO: think of a better place for this function
    // TODO: pack the SWA params in a struct?
     static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
+
+    // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
+    // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
+    //
+    // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
+    //      https://github.com/ggml-org/llama.cpp/pull/17945
+    static float yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor);
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
-
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 3e02bd62977..8f94c8820ce 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -1369,9 +1369,10 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
         float freq_scale) const {
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
 
-    const auto & yarn_ext_factor = cparams.yarn_ext_factor;
-    const auto & yarn_beta_fast  = cparams.yarn_beta_fast;
-    const auto & yarn_beta_slow  = cparams.yarn_beta_slow;
+    const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
+    const auto & yarn_beta_fast   = cparams.yarn_beta_fast;
+    const auto & yarn_beta_slow   = cparams.yarn_beta_slow;
+    const auto & yarn_attn_factor = llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor);
 
     const auto & n_rot     = hparams.n_rot;
     const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
@@ -1382,12 +1383,6 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
         ? LLAMA_ROPE_TYPE_NEOX
         : hparams.rope_type;
 
-    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
-    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
-        ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
-        : cparams.yarn_attn_factor;
-
     ggml_tensor * tmp;
 
     if (ggml_is_quantized(cur->type)) {
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index aa51e902e6c..3ca183e7f51 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1635,7 +1635,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     // that have no expert_gating_func model parameter set
                     hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
                 }
-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f);
 
                 // (optional) temperature tuning - used by mistral-large
                 ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
@@ -2267,9 +2267,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
 
-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast,    false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow,    false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL,   hparams.rope_yarn_log_mul, 0.0f);
 
                 // TODO: maybe add n_attn_temp_floor_scale as a separate KV?
                 if (hparams.f_attn_temp_scale != 0.0f) {
@@ -2289,6 +2289,23 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         default: throw std::runtime_error("unsupported model architecture");
     }
 
+    // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
+    if (hparams.rope_yarn_log_mul != 0.0f) {
+        const float factor = 1.0f / hparams.rope_freq_scale_train;
+
+        const float mscale = 1.0f;
+        const float mscale_all_dims = hparams.rope_yarn_log_mul;
+
+        static auto get_mscale = [](float scale, float mscale) {
+            return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
+        };
+
+        hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
+
+        LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
+                __func__, hparams.yarn_attn_factor, mscale, mscale_all_dims);
+    }
+
     pimpl->n_bytes = ml.n_bytes;
 
     pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name();
@@ -6794,6 +6811,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n", __func__, hparams.rope_freq_base_train);
         LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
         LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n", __func__, hparams.n_ctx_orig_yarn);
+        LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n", __func__, hparams.rope_yarn_log_mul);
         LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
         // MRoPE (Multi-axis Rotary Position Embedding) sections
         if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) {
@@ -6857,7 +6875,6 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
         LLAMA_LOG_INFO("%s: expert_weights_norm  = %d\n", __func__, hparams.expert_weights_norm);
         LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
-        LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n", __func__, hparams.rope_yarn_log_mul);
     }
 
     if (arch == LLM_ARCH_QWEN2MOE) {
diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp
index dbaa8297be9..1b31ce45240 100644
--- a/src/models/deepseek2.cpp
+++ b/src/models/deepseek2.cpp
@@ -20,9 +20,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
 
     // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
-    const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));
-    const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
+    const float mscale   = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
+    const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/src/models/mistral3.cpp b/src/models/mistral3.cpp
index b62acd19167..0b672235911 100644
--- a/src/models/mistral3.cpp
+++ b/src/models/mistral3.cpp
@@ -3,8 +3,6 @@
 llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
-    const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
-
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 

From 45930c97dc587ed2395407b3c1d95e4762522ff6 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 12 Dec 2025 11:42:00 +0200
Subject: [PATCH 3/8] cont : make deepseek2 consistent

---
 convert_hf_to_gguf.py    | 6 +++++-
 src/llama-model.cpp      | 7 ++++++-
 src/models/deepseek2.cpp | 3 ++-
 3 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 3a532c2e0a5..714da7eabf0 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -7286,7 +7286,11 @@ def set_gguf_parameters(self):
             self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
             self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
             self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
-            self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_scaling["mscale_all_dim"])
+
+            # [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+            # note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
+            # ref https://github.com/ggml-org/llama.cpp/pull/17945
+            self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])
 
     _experts: list[dict[str, Tensor]] | None = None
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 3ca183e7f51..5eb8993d68a 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1635,7 +1635,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                    // that have no expert_gating_func model parameter set
                     hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
                 }
-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f);
+
+                if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) {
+                    // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+                    // cancel the factor from the convert script
+                    hparams.rope_yarn_log_mul /= 0.1f;
+                }
 
                 // (optional) temperature tuning - used by mistral-large
                 ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp
index 1b31ce45240..59c10db01a9 100644
--- a/src/models/deepseek2.cpp
+++ b/src/models/deepseek2.cpp
@@ -20,7 +20,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
 
     // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float mscale   = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
+    // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+    const float mscale   = attn_factor * (1.0f + 0.1f * hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
     const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));
 
     ggml_tensor * cur;

From 45875df231227d31363e7a32e5a5f5453294fe09 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 12 Dec 2025 11:52:17 +0200
Subject: [PATCH 4/8] cont : add TODO

---
 src/llama-model.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 5eb8993d68a..dcd2dad29f6 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2298,6 +2298,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     if (hparams.rope_yarn_log_mul != 0.0f) {
         const float factor = 1.0f / hparams.rope_freq_scale_train;
 
+        // note: here we assume `mscale == 1.0f`
+        // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
         const float mscale = 1.0f;
         const float mscale_all_dims = hparams.rope_yarn_log_mul;
 

From 06eb8e868df6aa34fd2d5a60aecf94006640239b Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 12 Dec 2025 12:09:43 +0200
Subject: [PATCH 5/8] cont : special-case DSv2

---
 src/llama-model.cpp | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index dcd2dad29f6..e4808b1e1eb 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2300,9 +2300,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
         // note: here we assume `mscale == 1.0f`
         // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
-        const float mscale = 1.0f;
+        float mscale = 1.0f;
         const float mscale_all_dims = hparams.rope_yarn_log_mul;
 
+        // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+        // special-case DEEPSEEK v2:
+        // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
+        if (arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
+            mscale = mscale_all_dims;
+        }
+
         static auto get_mscale = [](float scale, float mscale) {
             return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
         };

From 01b77b5793989e0a1508a2fe3757d4e743a7bbd1 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 12 Dec 2025 12:39:05 +0200
Subject: [PATCH 6/8] cont : revert Mistral 3 Large changes

---
 convert_hf_to_gguf.py | 6 +++++-
 src/llama-hparams.cpp | 2 ++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 714da7eabf0..151608d56b8 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -10045,7 +10045,11 @@ def set_gguf_parameters(self):
         MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
         yarn_params = self.hparams["yarn"]
         self.gguf_writer.add_attn_temperature_length(yarn_params["original_max_position_embeddings"])
-        self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
+
+        # [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+        # note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
+        # ref https://github.com/ggml-org/llama.cpp/pull/17945
+        self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) # mscale_all_dim * 0.1
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
         if name.startswith("vision_") or name.startswith("patch_merger.") or "mm_projector" in name:
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 7aa517bfe9b..2b12b9c4923 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -1,7 +1,9 @@
 #include "llama-hparams.h"
 
 #include "ggml.h"
+
 #include <cassert>
+#include <cmath>
 
 void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
     if (dense_first) {

From 7320a2dc136918082c74654f4a857a710bb801f5 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 12 Dec 2025 13:00:41 +0200
Subject: [PATCH 7/8] cont : fix DS2 to use the original attn_factor

---
 src/llama-hparams.cpp    | 2 ++
 src/models/deepseek2.cpp | 9 ++++++---
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 2b12b9c4923..277d0bcfd3c 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -233,6 +233,8 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama
 }
 
 float llama_hparams::yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
+    GGML_ASSERT(ext_factor >= 0.0f);
+
     if (ext_factor != 0.0f) {
         attn_factor *= 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
     }
diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp
index 59c10db01a9..dd47168dffe 100644
--- a/src/models/deepseek2.cpp
+++ b/src/models/deepseek2.cpp
@@ -1,7 +1,5 @@
 #include "models.h"
 
-
-
 llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
 
     // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
@@ -21,7 +19,12 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
     // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
     // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
-    const float mscale   = attn_factor * (1.0f + 0.1f * hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
+
+    // first cancel the adjustment from llama_hparams::yarn_attn_factor to get the original attn_factor
+    GGML_ASSERT(ext_factor >= 0.0f);
+    const float attn_factor_org = attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale));
+
+    const float mscale   = attn_factor_org * (1.0f + 0.1f * hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
     const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));
 
     ggml_tensor * cur;

From d6477e1431a6e69a4b5d62dc30b0cb8528d5fb02 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 12 Dec 2025 17:12:08 +0200
Subject: [PATCH 8/8] cont : minor comments [no ci]

---
 src/models/deepseek2.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp
index dd47168dffe..49382874baa 100644
--- a/src/models/deepseek2.cpp
+++ b/src/models/deepseek2.cpp
@@ -20,10 +20,11 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
     // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
 
-    // first cancel the adjustment from llama_hparams::yarn_attn_factor to get the original attn_factor
+    // first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor
     GGML_ASSERT(ext_factor >= 0.0f);
     const float attn_factor_org = attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale));
 
+    // use the original attn_factor to pre-scale the kq_scale
     const float mscale   = attn_factor_org * (1.0f + 0.1f * hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
     const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));
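
The following standalone sketch is not part of the patch series above; it is a minimal C++ illustration of the two formulas the series converges on: the load-time yarn_attn_factor ratio (PATCH 2/8, llama-model.cpp) and the graph-time cancellation in llama_hparams::yarn_attn_factor_adjust (PATCH 2/8, llama-hparams.cpp). The helper names mirror the patch for readability, but the example values are hypothetical and this is not the llama.cpp API.

// attn_factor_sketch.cpp -- illustrative only, not part of llama.cpp
#include <cmath>
#include <cstdio>

// YaRN mscale helper (see the modeling_rope_utils.py reference in PATCH 2/8):
// returns 1.0 when no context extension is applied (scale <= 1)
static float get_mscale(float scale, float mscale) {
    return scale <= 1.0f ? 1.0f : (0.1f * mscale * std::log(scale) + 1.0f);
}

// load-time factor: ratio of the attention mscale to the all-dim mscale,
// stored in hparams.yarn_attn_factor at model load
static float yarn_attn_factor_from_hparams(float freq_scale_train, float mscale, float mscale_all_dim) {
    const float factor = 1.0f / freq_scale_train;
    return get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim);
}

// graph-time correction: cancels the 1 + 0.1*ln(1/freq_scale) term that the
// YaRN rope op applies internally whenever ext_factor != 0
static float yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
    if (ext_factor != 0.0f) {
        attn_factor *= 1.0f / (1.0f + 0.1f * std::log(1.0f / freq_scale));
    }
    return attn_factor;
}

int main() {
    const float freq_scale_train = 1.0f / 40.0f; // hypothetical YaRN factor of 40
    const float mscale           = 1.0f;
    const float mscale_all_dim   = 1.0f;

    const float attn_factor = yarn_attn_factor_from_hparams(freq_scale_train, mscale, mscale_all_dim);
    const float adjusted    = yarn_attn_factor_adjust(attn_factor, freq_scale_train, /*ext_factor=*/1.0f);

    // with mscale == mscale_all_dim the load-time ratio is exactly 1.0 and only the
    // cancellation of the internal YaRN scaling remains (~0.73 for a factor of 40)
    std::printf("attn_factor = %.6f, adjusted = %.6f\n", attn_factor, adjusted);
    return 0;
}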