Merged
8 changes: 8 additions & 0 deletions convert_hf_to_gguf.py
@@ -7286,6 +7286,10 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])

# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
# ref https://github.com/ggml-org/llama.cpp/pull/17945
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])

_experts: list[dict[str, Tensor]] | None = None
@@ -10041,6 +10045,10 @@ def set_gguf_parameters(self):
MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
yarn_params = self.hparams["yarn"]
self.gguf_writer.add_attn_temperature_length(yarn_params["original_max_position_embeddings"])

# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
# ref https://github.com/ggml-org/llama.cpp/pull/17945
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) # mscale_all_dim * 0.1

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
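A minimal C++ sketch (not part of the diff) of the round trip that these two hunks set up together with the loader change in src/llama-model.cpp further down: the convert script stores 0.1 * mscale_all_dim under the YARN log-mul key, and the loader divides by 0.1f to recover mscale_all_dim. The mscale_all_dim value below is illustrative.

    #include <cassert>
    #include <cmath>

    int main() {
        // hypothetical value of rope_scaling["mscale_all_dim"] from a model's config.json
        const float mscale_all_dim = 1.0f;

        // convert_hf_to_gguf.py: value written under the rope_scaling_yarn_log_mul key
        const float gguf_value = 0.1f * mscale_all_dim;

        // llama_model::load_hparams(): cancel the 0.1 factor from the convert script
        float rope_yarn_log_mul = gguf_value;
        rope_yarn_log_mul /= 0.1f;

        // after loading, rope_yarn_log_mul holds mscale_all_dim again
        assert(std::fabs(rope_yarn_log_mul - mscale_all_dim) < 1e-6f);
        return 0;
    }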
2 changes: 1 addition & 1 deletion src/llama-graph.cpp
@@ -574,7 +574,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
freq_base (cparams.rope_freq_base),
freq_scale (cparams.rope_freq_scale),
ext_factor (cparams.yarn_ext_factor),
attn_factor (cparams.yarn_attn_factor),
attn_factor (llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor)),
beta_fast (cparams.yarn_beta_fast),
beta_slow (cparams.yarn_beta_slow),
norm_eps (hparams.f_norm_eps),
12 changes: 12 additions & 0 deletions src/llama-hparams.cpp
@@ -1,7 +1,9 @@
#include "llama-hparams.h"

#include "ggml.h"

#include <cassert>
#include <cmath>

void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
if (dense_first) {
@@ -229,3 +231,13 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama

return false;
}

float llama_hparams::yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
GGML_ASSERT(ext_factor >= 0.0f);

if (ext_factor != 0.0f) {
attn_factor *= 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
}

return attn_factor;
}
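A standalone sketch (not part of the diff) of why this helper works, assuming — as the comment in src/llama-hparams.h below states — that the RoPE kernel multiplies the attention factor by 1.0f + 0.1f * logf(1.0f / freq_scale) whenever ext_factor != 0.0f. The adjustment is the reciprocal of that term, so the net factor applied at runtime equals the caller's attn_factor. The freq_scale and ext_factor values are illustrative.

    #include <cassert>
    #include <cmath>

    // same formula as llama_hparams::yarn_attn_factor_adjust above
    static float yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
        if (ext_factor != 0.0f) {
            attn_factor *= 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
        }
        return attn_factor;
    }

    int main() {
        const float attn_factor = 1.0f;  // illustrative values
        const float freq_scale  = 0.25f; // e.g. 4x YaRN context extension
        const float ext_factor  = 1.0f;

        const float adjusted = yarn_attn_factor_adjust(attn_factor, freq_scale, ext_factor);

        // the RoPE kernel later multiplies by this same term (see the hparams.h comment below)
        const float kernel_scale = 1.0f + 0.1f * logf(1.0f / freq_scale);

        // net effect: the kernel-side scaling is cancelled and attn_factor is applied as-is
        assert(std::fabs(adjusted * kernel_scale - attn_factor) < 1e-6f);
        return 0;
    }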
9 changes: 8 additions & 1 deletion src/llama-hparams.h
@@ -107,6 +107,7 @@ struct llama_hparams {
float rope_freq_base_train_swa;
float rope_freq_scale_train;
float rope_freq_scale_train_swa;

uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul = 0.0f;

@@ -267,7 +268,13 @@ struct llama_hparams {
// TODO: think of a better place for this function
// TODO: pack the SWA params in a struct?
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);

// when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
// https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
//
// ref: https://github.com/ggml-org/llama.cpp/discussions/7416
// https://github.com/ggml-org/llama.cpp/pull/17945
static float yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor);
};

static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");

13 changes: 4 additions & 9 deletions src/llama-kv-cache.cpp
@@ -1369,9 +1369,10 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
float freq_scale) const {
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;

const auto & yarn_ext_factor = cparams.yarn_ext_factor;
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
const auto & yarn_attn_factor = llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor);

const auto & n_rot = hparams.n_rot;
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
@@ -1382,12 +1383,6 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
? LLAMA_ROPE_TYPE_NEOX
: hparams.rope_type;

// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
: cparams.yarn_attn_factor;

ggml_tensor * tmp;

if (ggml_is_quantized(cur->type)) {
53 changes: 36 additions & 17 deletions src/llama-model.cpp
@@ -1635,7 +1635,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// that have no expert_gating_func model parameter set
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
}
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);

if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) {
// [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
// cancel the factor from the convert script
hparams.rope_yarn_log_mul /= 0.1f;
}

// (optional) temperature tuning - used by mistral-large
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
@@ -2267,9 +2272,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);

ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f);

// TODO: maybe add n_attn_temp_floor_scale as a separate KV?
if (hparams.f_attn_temp_scale != 0.0f) {
Expand All @@ -2279,18 +2284,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
}
}

// TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f
// but may need further verification with other values
if (hparams.rope_yarn_log_mul != 0.0f) {
float factor = 1.0f / hparams.rope_freq_scale_train;
float mscale = 1.0f;
float mscale_all_dims = hparams.rope_yarn_log_mul;
static auto get_mscale = [](float scale, float mscale) {
return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
};
hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
}

switch (hparams.n_layer) {
case 26: type = LLM_TYPE_3B; break;
case 34: type = LLM_TYPE_8B; break;
@@ -2301,6 +2294,32 @@
default: throw std::runtime_error("unsupported model architecture");
}

// ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
if (hparams.rope_yarn_log_mul != 0.0f) {
const float factor = 1.0f / hparams.rope_freq_scale_train;

// note: here we assume `mscale == 1.0f`
// TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
float mscale = 1.0f;
const float mscale_all_dims = hparams.rope_yarn_log_mul;

// [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
// special-case DEEPSEEK v2:
// https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
if (arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
mscale = mscale_all_dims;
Collaborator: For clarification, I think for deepseek we can do mscale_all_dims = mscale_all_dims * 10.0f, because it was scaled by 0.1 before being written into GGUF.

Member (author): Already done above on line 1642.

}

static auto get_mscale = [](float scale, float mscale) {
return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
};

hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
Collaborator: We're calculating 0.1f * mscale * logf(scale) + 1.0f in 4 places in the code base:

  • here
  • yarn_attn_factor_adjust
  • inside the model graph build function
  • inside the rope kernel

I think there is a way to simplify it:

  • leave yarn_attn_factor as-is (because it's also used by LLM_ARCH_GROK, so probably better not to adjust it?)
  • leave yarn_attn_factor_adjust to adjust everything
  • remove the adjustment from the deepseek2 cgraph?

Collaborator: In short, I think this LOC: hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims) can be redundant because:

  • the numerator get_mscale(factor, mscale) is already handled inside the rope kernel:
    mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
  • the denominator get_mscale(factor, mscale_all_dims) will be handled by yarn_attn_factor_adjust()

Member (author): > The numerator get_mscale(factor, mscale) is already handled inside the rope kernel

This numerator assumes mscale == 1.0f. I am not sure why it was added, but effectively we have to cancel it every time if we want to support mscale != 1.0f.

Member (author): > because it's also used by LLM_ARCH_GROK, so probably better not to adjust it?

The hparams.yarn_attn_factor = get_mscale(... branch is only entered when hparams.rope_yarn_log_mul != 0.0f, so it should not trigger for GROK.

Collaborator: ah yeah, right. that sounds good then.

LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
__func__, hparams.yarn_attn_factor, mscale, mscale_all_dims);
}

pimpl->n_bytes = ml.n_bytes;

pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name();
@@ -6806,6 +6825,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n", __func__, hparams.rope_yarn_log_mul);
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
// MRoPE (Multi-axis Rotary Position Embedding) sections
if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) {
@@ -6869,7 +6889,6 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
}

if (arch == LLM_ARCH_QWEN2MOE) {
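A self-contained sketch (not part of the diff) of the yarn_attn_factor computation added to load_hparams above. The hyperparameter values are illustrative, and mscale is assumed to be 1.0f except for the DeepSeek-V2 special case, mirroring the TODO in the hunk.

    #include <cmath>
    #include <cstdio>

    // mirrors the get_mscale lambda in llama_model::load_hparams()
    static float get_mscale(float scale, float mscale) {
        return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
    }

    int main() {
        // illustrative values, not taken from a real GGUF
        const float rope_freq_scale_train = 0.25f;  // YaRN factor = 4
        const float rope_yarn_log_mul     = 0.707f; // = mscale_all_dim after the /= 0.1f fix
        const bool  is_deepseek2          = true;

        const float factor         = 1.0f / rope_freq_scale_train;
        const float mscale_all_dim = rope_yarn_log_mul;
        float       mscale         = 1.0f;

        // special case from the hunk: the referenced DeepSeek-V2 config sets mscale == mscale_all_dim
        if (is_deepseek2 && mscale_all_dim != 1.0f) {
            mscale = mscale_all_dim;
        }

        const float yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim);
        printf("yarn_attn_factor = %.4f\n", yarn_attn_factor);
        return 0;
    }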
14 changes: 9 additions & 5 deletions src/models/deepseek2.cpp
@@ -1,7 +1,5 @@
#include "models.h"



llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
@@ -20,9 +18,15 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr

// We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));
const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
// And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]

// first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor
GGML_ASSERT(ext_factor >= 0.0f);
const float attn_factor_org = attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale));

// use the original attn_factor to pre-scale the kq_scale
const float mscale = attn_factor_org * (1.0f + 0.1f * hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));

ggml_tensor * cur;
ggml_tensor * inpL;
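A numeric sketch (not part of the diff) of the pre-scaling in the hunk above: attn_factor arrives already adjusted by llama_hparams::yarn_attn_factor_adjust, so the cgraph first cancels that adjustment to recover the original value and then squares the resulting mscale into kq_scale. All input values are illustrative, including the head size.

    #include <cmath>
    #include <cstdio>

    int main() {
        // illustrative inputs, not from a real model
        const float freq_scale        = 0.25f;  // YaRN factor = 4
        const float rope_yarn_log_mul = 0.707f; // mscale_all_dim
        const float attn_factor       = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); // already adjusted
        const int   n_embd_head_k     = 192;    // assumed MLA head size (nope + rope dims)

        // cancel the adjustment from llama_hparams::yarn_attn_factor_adjust
        const float attn_factor_org = attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale));

        // pre-scale kq_scale with the original attn_factor
        const float mscale   = attn_factor_org * (1.0f + 0.1f * rope_yarn_log_mul * logf(1.0f / freq_scale));
        const float kq_scale = 1.0f * mscale * mscale / sqrtf((float) n_embd_head_k);

        printf("attn_factor_org = %.4f, mscale = %.4f, kq_scale = %.6f\n", attn_factor_org, mscale, kq_scale);
        return 0;
    }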