
Commit 59b9e36

cont : rework attn_factor correction logic
1 parent 1df2e90 commit 59b9e36

8 files changed, +47 -23 lines changed


convert_hf_to_gguf.py

Lines changed: 2 additions & 2 deletions
@@ -7286,7 +7286,7 @@ def set_gguf_parameters(self):
             self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
             self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
             self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
-            self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])
+            self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_scaling["mscale_all_dim"])

     _experts: list[dict[str, Tensor]] | None = None

@@ -10041,7 +10041,7 @@ def set_gguf_parameters(self):
         MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
         yarn_params = self.hparams["yarn"]
         self.gguf_writer.add_attn_temperature_length(yarn_params["original_max_position_embeddings"])
-        self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) # mscale_all_dim * 0.1
+        self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
         if name.startswith("vision_") or name.startswith("patch_merger.") or "mm_projector" in name:
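
The converter now passes mscale_all_dim itself to add_rope_scaling_yarn_log_mul instead of pre-multiplying it by 0.1; the fixed 0.1 coefficient now appears on the loader side (see the get_mscale lambda in src/llama-model.cpp below). A quick arithmetic sketch of why the stored value can be moved like this, with hypothetical numbers that are not taken from any particular model:

// Arithmetic sketch only (hypothetical values): folding the 0.1 coefficient into the
// stored value or applying it later yields the same log-multiplier term.
#include <cmath>
#include <cstdio>

int main() {
    const float mscale_all_dim = 0.707f; // hypothetical model value
    const float scale          = 4.0f;   // hypothetical YaRN factor (1 / freq_scale)

    const float term_old = 1.0f + (0.1f * mscale_all_dim) * logf(scale); // 0.1 baked in at conversion
    const float term_new = 1.0f + 0.1f * mscale_all_dim * logf(scale);   // 0.1 applied at load time

    printf("old = %.6f, new = %.6f\n", term_old, term_new); // identical by construction
    return 0;
}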

src/llama-graph.cpp

Lines changed: 1 addition & 1 deletion
@@ -574,7 +574,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     freq_base   (cparams.rope_freq_base),
     freq_scale  (cparams.rope_freq_scale),
     ext_factor  (cparams.yarn_ext_factor),
-    attn_factor (cparams.yarn_attn_factor),
+    attn_factor (llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor)),
     beta_fast   (cparams.yarn_beta_fast),
     beta_slow   (cparams.yarn_beta_slow),
     norm_eps    (hparams.f_norm_eps),

src/llama-hparams.cpp

Lines changed: 8 additions & 0 deletions
@@ -229,3 +229,11 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama

     return false;
 }
+
+float llama_hparams::yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
+    if (ext_factor != 0.0f) {
+        attn_factor *= 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
+    }
+
+    return attn_factor;
+}
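
For context, a minimal standalone sketch (hypothetical values, not part of the commit) of what this helper does: when ext_factor != 0.0f the rope op internally applies a magnitude scale of roughly 1 + 0.1 * ln(1 / freq_scale) (see the ops.cpp link in the header comment below), and the helper pre-divides attn_factor by that same quantity so the two cancel out:

// Standalone sketch (hypothetical values): the adjustment pre-divides attn_factor by the
// YaRN magnitude scale that the rope op applies when ext_factor != 0.0f, so the net effect
// on attention magnitudes stays at the user-supplied attn_factor.
#include <cmath>
#include <cstdio>

static float yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
    if (ext_factor != 0.0f) {
        attn_factor *= 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
    }
    return attn_factor;
}

int main() {
    const float freq_scale  = 0.25f; // hypothetical: context extended 4x
    const float ext_factor  = 1.0f;  // YaRN mixing enabled
    const float attn_factor = 1.0f;  // user-facing default

    const float adjusted    = yarn_attn_factor_adjust(attn_factor, freq_scale, ext_factor);
    const float rope_mscale = 1.0f + 0.1f * logf(1.0f / freq_scale); // scale applied inside the rope op

    // adjusted * rope_mscale == attn_factor, i.e. the correction cancels the internal scale
    printf("adjusted = %.4f, adjusted * rope_mscale = %.4f\n", adjusted, adjusted * rope_mscale);
    return 0;
}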

src/llama-hparams.h

Lines changed: 8 additions & 1 deletion
@@ -107,6 +107,7 @@ struct llama_hparams {
     float rope_freq_base_train_swa;
     float rope_freq_scale_train;
     float rope_freq_scale_train_swa;
+
     uint32_t n_ctx_orig_yarn;
     float rope_yarn_log_mul = 0.0f;

@@ -267,7 +268,13 @@ struct llama_hparams {
     // TODO: think of a better place for this function
     // TODO: pack the SWA params in a struct?
     static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
+
+    // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
+    // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
+    //
+    // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
+    // https://github.com/ggml-org/llama.cpp/pull/17945
+    static float yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor);
 };

 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
-
src/llama-kv-cache.cpp

Lines changed: 4 additions & 9 deletions
@@ -1369,9 +1369,10 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
         float freq_scale) const {
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;

-    const auto & yarn_ext_factor = cparams.yarn_ext_factor;
-    const auto & yarn_beta_fast  = cparams.yarn_beta_fast;
-    const auto & yarn_beta_slow  = cparams.yarn_beta_slow;
+    const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
+    const auto & yarn_beta_fast   = cparams.yarn_beta_fast;
+    const auto & yarn_beta_slow   = cparams.yarn_beta_slow;
+    const auto & yarn_attn_factor = llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor);

     const auto & n_rot     = hparams.n_rot;
     const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
@@ -1382,12 +1383,6 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
         ? LLAMA_ROPE_TYPE_NEOX
         : hparams.rope_type;

-    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
-    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
-        ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
-        : cparams.yarn_attn_factor;
-
     ggml_tensor * tmp;

     if (ggml_is_quantized(cur->type)) {

src/llama-model.cpp

Lines changed: 22 additions & 5 deletions
@@ -1635,7 +1635,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     // that have no expert_gating_func model parameter set
                     hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
                 }
-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f);

                 // (optional) temperature tuning - used by mistral-large
                 ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
@@ -2267,9 +2267,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);

-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f);

                 // TODO: maybe add n_attn_temp_floor_scale as a separate KV?
                 if (hparams.f_attn_temp_scale != 0.0f) {
@@ -2289,6 +2289,23 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         default: throw std::runtime_error("unsupported model architecture");
     }

+    // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
+    if (hparams.rope_yarn_log_mul != 0.0f) {
+        const float factor = 1.0f / hparams.rope_freq_scale_train;
+
+        const float mscale = 1.0f;
+        const float mscale_all_dims = hparams.rope_yarn_log_mul;
+
+        static auto get_mscale = [](float scale, float mscale) {
+            return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
+        };
+
+        hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
+
+        LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
+                __func__, hparams.yarn_attn_factor, mscale, mscale_all_dims);
+    }
+
     pimpl->n_bytes = ml.n_bytes;

     pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name();
@@ -6794,6 +6811,7 @@ void llama_model::print_info() const {
     LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
     LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
+    LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n", __func__, hparams.rope_yarn_log_mul);
     LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
     // MRoPE (Multi-axis Rotary Position Embedding) sections
     if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) {
@@ -6857,7 +6875,6 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
         LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
         LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
-        LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
     }

     if (arch == LLM_ARCH_QWEN2MOE) {
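
A short numeric sketch (hypothetical hparams, mirroring the get_mscale correction in the load_hparams hunk above and the linked transformers reference) of the yarn_attn_factor it produces:

// Numeric sketch with hypothetical hparams; reproduces the get_mscale ratio from load_hparams above.
#include <cmath>
#include <cstdio>

// YaRN magnitude scale for a given mscale coefficient (same formula as the lambda above)
static float get_mscale(float scale, float mscale) {
    return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
}

int main() {
    const float rope_freq_scale_train = 0.25f;  // hypothetical: trained with a 4x YaRN extension
    const float rope_yarn_log_mul     = 0.707f; // hypothetical mscale_all_dim read from the GGUF

    const float factor = 1.0f / rope_freq_scale_train;

    const float mscale          = 1.0f;
    const float mscale_all_dims = rope_yarn_log_mul;

    // ratio of the two magnitude scales, as in the transformers reference linked above
    const float yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);

    printf("yarn_attn_factor = %.4f\n", yarn_attn_factor); // ~1.037 for these values
    return 0;
}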

src/models/deepseek2.cpp

Lines changed: 2 additions & 3 deletions
@@ -20,9 +20,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr

     // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
-    const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));
-    const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
+    const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
+    const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));

     ggml_tensor * cur;
     ggml_tensor * inpL;
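
With the local attn_factor override removed, the value reaching this builder is already the one adjusted in llm_graph_context (see the src/llama-graph.cpp hunk above). A small numeric sketch of the pre-scaling that remains here, using hypothetical values:

// Numeric sketch (hypothetical values) of the mscale / kq_scale pre-scaling kept in this builder.
#include <cmath>
#include <cstdio>

int main() {
    const float attn_factor       = 0.8782f; // hypothetical: already adjusted upstream
    const float rope_yarn_log_mul = 0.707f;  // hypothetical mscale_all_dim
    const float freq_scale        = 0.25f;   // hypothetical YaRN scaling
    const int   n_embd_head_k     = 192;     // hypothetical head size

    const float mscale   = attn_factor * (1.0f + rope_yarn_log_mul * logf(1.0f / freq_scale));
    const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));

    printf("mscale = %.4f, kq_scale = %.6f\n", mscale, kq_scale);
    return 0;
}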

src/models/mistral3.cpp

Lines changed: 0 additions & 2 deletions
@@ -3,8 +3,6 @@
 llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;

-    const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
-
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
     GGML_ASSERT(n_embd_head == hparams.n_rot);
