44 commits
774cf23
initial commit for branch
ddh0 Dec 11, 2025
5ab4ff7
simplify constants
ddh0 Dec 11, 2025
66e2d17
Merge branch 'ggml-org:master' into power-law-sampler
ddh0 Dec 11, 2025
88fb0f3
add params to `struct common_params_sampling`, add reference to PR
ddh0 Dec 11, 2025
374bfd4
explicitly clamp `min_target` and `max_target` to `[0.0, 1.0]`
ddh0 Dec 11, 2025
ffe1639
add args, rename `queue_size` -> `window_size`
ddh0 Dec 11, 2025
4959878
improved comments
ddh0 Dec 11, 2025
f3457a8
minor
ddh0 Dec 11, 2025
9316959
remove old unused code from algorithm
ddh0 Dec 11, 2025
b3aea57
minor
ddh0 Dec 11, 2025
cd7de7c
add power law case to `common_sampler_init`, add sampler name mappings
ddh0 Dec 11, 2025
534cb4f
clarify behaviour when `window_size = 0`
ddh0 Dec 11, 2025
dcada03
add missing enums
ddh0 Dec 11, 2025
2d62bbe
remove `target_range` param, make `target == 1` no-op, cleanup code
ddh0 Dec 12, 2025
5c78b79
oops, straggler
ddh0 Dec 12, 2025
53380c1
add missing parameters in `server-task.cpp`
ddh0 Dec 13, 2025
94cb883
copy from author
ddh0 Dec 13, 2025
0a19a3f
remove old debug log, style nit
ddh0 Dec 13, 2025
824bb3a
fix compiler warning, add commented-out logging per token
ddh0 Dec 13, 2025
1879fc6
Merge branch 'ggml-org:master' into power-law-sampler
ddh0 Dec 13, 2025
67a7336
Merge branch 'ggml-org:master' into power-law-sampler
ddh0 Dec 13, 2025
a96ddd7
re-write + change parameters + simplify
ddh0 Dec 14, 2025
b8a9626
oops forgot args.cpp
ddh0 Dec 14, 2025
965bcc9
fix leftover `window_size`
ddh0 Dec 14, 2025
d1e5c60
add missing values to `common_params_sampling::print()`
ddh0 Dec 14, 2025
9613c48
with logging
ddh0 Dec 14, 2025
2a3f579
does this fix it?
ddh0 Dec 14, 2025
ec54fe5
no, but does this?
ddh0 Dec 14, 2025
667b70f
update default decay
ddh0 Dec 14, 2025
36b526d
Merge branch 'master' into power-law-sampler
ddh0 Dec 14, 2025
6934780
optimize
ddh0 Dec 14, 2025
f5d0872
fix bad merge
ddh0 Dec 15, 2025
493bf30
silence `missing initializer for member`
ddh0 Dec 15, 2025
6854325
update default decay to 0.9
ddh0 Dec 15, 2025
b5ed673
fix logging
ddh0 Dec 15, 2025
4e28eb2
format (double)
ddh0 Dec 15, 2025
1c58e9a
add power law to the new `samplers` vector
ddh0 Dec 15, 2025
4e04bd1
log sampler init values
ddh0 Dec 15, 2025
6e66095
Merge branch 'ggml-org:master' into power-law-sampler
ddh0 Dec 15, 2025
9c50b57
improve logging messages in llama_sampler_power_law
ddh0 Dec 15, 2025
0344068
remove extraneous logging
ddh0 Dec 15, 2025
1c2d2e9
simplify target computation
ddh0 Dec 16, 2025
85b6e52
Merge branch 'ggml-org:master' into power-law-sampler
ddh0 Dec 16, 2025
fcb5129
remove debug logging, explicitly clamp params at init
ddh0 Dec 16, 2025
18 changes: 18 additions & 0 deletions common/arg.cpp
@@ -1560,6 +1560,24 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_sparam());
add_opt(common_arg(
{"--power-law-target"}, "N",
string_format("power law sampler: select tokens near this probability (valid range 0.0 "
"to 1.0; <0 = disabled) (default: %.2f)\n"
"[(more info)]""(https://github.com/ggml-org/llama.cpp/pull/17927)",
(double)params.sampling.power_law_target),
[](common_params & params, const std::string & value) {
params.sampling.power_law_target = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--power-law-decay"}, "N",
string_format("decay rate for target adaptation over time. lower values -> faster but less stable adaptation.\n"
"(valid range 0.0 to 1.0; ≤0 = no adaptation) (default: %.2f)", (double)params.sampling.power_law_decay),
[](common_params & params, const std::string & value) {
params.sampling.power_law_decay = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--dynatemp-range"}, "N",
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),
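A hedged usage sketch for the two new flags (binary name and model path are placeholders, values illustrative):

llama-cli -m model.gguf --power-law-target 0.30 --power-law-decay 0.90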
53 changes: 28 additions & 25 deletions common/common.h
@@ -117,6 +117,7 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_INFILL = 9,
COMMON_SAMPLER_TYPE_PENALTIES = 10,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
COMMON_SAMPLER_TYPE_POWER_LAW = 12,
};

// dimensionality reduction methods, used by cvector-generator
@@ -164,32 +165,34 @@ enum common_params_sampling_config : uint64_t {
struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler

int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC
float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float top_n_sigma = -1.00f;// -1.0 = disabled
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC
float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
float power_law_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
float power_law_decay = 0.90f; // decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float top_n_sigma = -1.00f; // -1.0 = disabled
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool ignore_eos = false;
bool no_perf = false; // disable performance metrics
bool no_perf = false; // disable performance metrics
bool timing_per_token = false;

uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
22 changes: 17 additions & 5 deletions common/sampling.cpp
@@ -151,11 +151,11 @@ std::string common_params_sampling::print() const {
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, power_law_target = %.3f, power_law_decay = %.3f",
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
mirostat, mirostat_eta, mirostat_tau);
mirostat, mirostat_eta, mirostat_tau, power_law_target, power_law_decay);

return std::string(result);
}
@@ -241,6 +241,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
}

if (params.mirostat == 0) {
// if this flag is set, we will not need to add `dist` at the end of the sampler chain
bool has_distribution_sampler = false;

for (const auto & cnstr : params.samplers) {
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY:
@@ -250,7 +253,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
for (const auto & str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str());
}

samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
@@ -281,12 +283,18 @@
case COMMON_SAMPLER_TYPE_PENALTIES:
samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
case COMMON_SAMPLER_TYPE_POWER_LAW:
has_distribution_sampler = true;
samplers.push_back(llama_sampler_init_power_law (params.power_law_target, params.power_law_decay, params.seed));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
}

samplers.push_back(llama_sampler_init_dist(params.seed));
// only add `dist` to the end of the chain if no other distribution samplers were added
if (!has_distribution_sampler) {
samplers.push_back(llama_sampler_init_dist(params.seed));
}
} else if (params.mirostat == 1) {
samplers.push_back(llama_sampler_init_temp(params.temp));
samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
@@ -553,6 +561,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_XTC: return 'x';
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
case COMMON_SAMPLER_TYPE_POWER_LAW: return 'w';
default : return '?';
}
}
@@ -569,6 +578,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
case COMMON_SAMPLER_TYPE_POWER_LAW: return "power_law";
default : return "";
}
}
@@ -585,6 +595,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
{ "power_law", COMMON_SAMPLER_TYPE_POWER_LAW },
};

// since sampler names are written multiple ways
@@ -600,6 +611,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "power-law", COMMON_SAMPLER_TYPE_POWER_LAW },
};

std::vector<common_sampler_type> samplers;
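Since power-law takes over token selection from `dist`, a minimal sketch of wiring it up through the common API (assumes a loaded `model`; parameter values are illustrative, not defaults from this PR):

common_params_sampling sparams;
sparams.power_law_target = 0.30f; // aim for tokens near 30% probability
sparams.power_law_decay  = 0.90f; // ~10-token effective history
sparams.samplers = {
    COMMON_SAMPLER_TYPE_PENALTIES,
    COMMON_SAMPLER_TYPE_MIN_P,     // keep truncation minimal before power-law
    COMMON_SAMPLER_TYPE_POWER_LAW, // selects the token; `dist` is not appended
};
struct common_sampler * smpl = common_sampler_init(model, sparams);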
23 changes: 23 additions & 0 deletions include/llama.h
@@ -1304,6 +1304,29 @@ extern "C" {
const char ** seq_breakers,
size_t num_breakers);

/// power-law
///
/// this sampler implements a power law probability transformation with adaptive
/// target tracking. it reshapes token probability distributions to favor tokens near a
/// configurable target probability, rather than always selecting from the highest probability
/// candidates. it is ideal for creative, unpredictable text generation.
///
/// this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID
/// rather than just transforming logits. therefore it must always be the last sampler in the
/// sampler chain.
///
/// minimal truncation before this sampler is recommended.
///
/// @param target select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled)
/// @param decay decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; ≤0 = no adaptation)
///
/// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl)
/// ref: https://github.com/ggml-org/llama.cpp/pull/17927 (llama.cpp PR)
LLAMA_API struct llama_sampler * llama_sampler_init_power_law(
float target,
float decay,
uint32_t seed);

LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
int32_t n_vocab,
int32_t n_logit_bias,
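At the `llama.h` level, a minimal chain sketch (assumes a valid `llama_context * ctx`; truncation, target, and decay values are illustrative):

llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_min_p(0.05f, 1)); // light truncation only
llama_sampler_chain_add(chain, llama_sampler_init_power_law(0.30f, 0.90f, LLAMA_DEFAULT_SEED)); // must be last: selects the token
llama_token tok = llama_sampler_sample(chain, ctx, -1);
llama_sampler_free(chain);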
144 changes: 144 additions & 0 deletions src/llama-sampling.cpp
@@ -2313,6 +2313,150 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
return result;
}

// power-law
//
// this sampler implements a power law probability transformation with adaptive
// target tracking. it reshapes token probability distributions to favor tokens near a
// configurable target probability, rather than always selecting from the highest probability
// candidates. it is ideal for creative, unpredictable text generation.
//
// this sampler is like `greedy`, `dist`, and `mirostat` in that it actually selects a token ID
// rather than just transforming logits. therefore it must always be the last sampler in the
// sampler chain.
//
// minimal truncation before this sampler is recommended.
//
// ref: https://github.com/MrJackSpade/llama.cpp/tree/master (original impl)
// ref: https://github.com/ggml-org/llama.cpp/pull/17927 (llama.cpp PR)

struct llama_sampler_power_law {

// the desired average probability for selected tokens (0.0 to 1.0)
// higher values favor more probable tokens (more deterministic)
// lower values favor less probable tokens (more creative)
// negative values disable Power Law sampling (sample from distribution as-is)
const float target;

// controls how quickly history influence fades (0.0 to 0.99)
// lower values = faster adaptation, more reactive to recent tokens
// higher values = slower adaptation, more stable over time
// effective history length ≈ 1/(1-decay) tokens
// examples: decay=0.5 → ~2 tokens, decay=0.9 → ~10, decay=0.95 → ~20
// internally clamped to <= 0.99 to prevent unbounded accumulation
const float decay;

const uint32_t seed;
std::mt19937 rng;

// historical token probabilities weighted by recency
float weighted_sum;
// sum of weights, converges to 1/(1-decay)
float total_weight;
// used to store original token probabilities (needed for history update after selection)
std::vector<float> original_probs;
};

// transformation constants
static constexpr float DISTRIBUTION_WIDTH = 0.3f;
static constexpr float PEAK_LOGIT_VALUE = 5.0f;
static constexpr float INV_WIDTH = 1.0f / DISTRIBUTION_WIDTH;

static const char * llama_sampler_power_law_name(const struct llama_sampler * /*smpl*/) {
return "power-law";
}

static void llama_sampler_power_law_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_power_law *) smpl->ctx;

if (ctx->target < 0.0f) {
// no-op: just sample from the distribution as-is
llama_sampler_softmax_impl(cur_p, false);
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
return;
}

// softmax and store the original probabilities
llama_sampler_softmax_impl(cur_p, false);
ctx->original_probs.resize(cur_p->size);
for (size_t i = 0; i < cur_p->size; ++i) {
ctx->original_probs[i] = cur_p->data[i].p;
}

// compute the adapted target probability for the current sampling step
float computed_target = std::clamp(
ctx->total_weight == 0.0f ? ctx->target : 2.0f * ctx->target - (ctx->weighted_sum / ctx->total_weight),
0.0f, 1.0f
);

// power law transform
for (size_t i = 0; i < cur_p->size; ++i) {
float dist = (cur_p->data[i].p - computed_target) * INV_WIDTH;
cur_p->data[i].logit = PEAK_LOGIT_VALUE / (1.0f + dist * dist);
}

llama_sampler_softmax_impl(cur_p, false);

// sample from transformed distribution
const int idx = llama_sample_dist(cur_p, ctx->rng);
cur_p->selected = idx;

// update running history with the original probability of the selected token
ctx->weighted_sum = ctx->original_probs[idx] + ctx->decay * ctx->weighted_sum;
ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; // history fades over time
}
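
// worked example for the update above (illustrative values, not part of the PR):
// with target = 0.5, decay = 0.9, and selected tokens whose original
// probabilities were 0.8 then 0.6:
//   step 1: weighted_sum = 0.8                    total_weight = 1.0
//   step 2: weighted_sum = 0.6 + 0.9*0.8 = 1.32   total_weight = 1.0 + 0.9*1.0 = 1.9
// the running mean is 1.32/1.9 ≈ 0.695, so the next computed_target is
// clamp(2*0.5 - 0.695, 0, 1) ≈ 0.305: recent picks ran above target, so the
// sampler compensates by aiming below it.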

static void llama_sampler_power_law_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_power_law *) smpl->ctx;
ctx->weighted_sum = 0.0f;
ctx->total_weight = 0.0f;
}

static struct llama_sampler * llama_sampler_power_law_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_power_law *) smpl->ctx;
auto * result = llama_sampler_init_power_law(ctx->target, ctx->decay, ctx->seed);
auto * result_ctx = (llama_sampler_power_law *) result->ctx;

result_ctx->rng = ctx->rng;
result_ctx->weighted_sum = ctx->weighted_sum;
result_ctx->total_weight = ctx->total_weight;
result_ctx->original_probs.reserve(ctx->original_probs.capacity());

return result;
}

static void llama_sampler_power_law_free(struct llama_sampler * smpl) {
delete (llama_sampler_power_law *) smpl->ctx;
}

static struct llama_sampler_i llama_sampler_power_law_i = {
/* .name = */ llama_sampler_power_law_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_power_law_apply,
/* .reset = */ llama_sampler_power_law_reset,
/* .clone = */ llama_sampler_power_law_clone,
/* .free = */ llama_sampler_power_law_free,
};

struct llama_sampler * llama_sampler_init_power_law(
float target,
float decay,
uint32_t seed
) {
auto seed_cur = get_rng_seed(seed);
return llama_sampler_init(
/* .iface = */ &llama_sampler_power_law_i,
/* .ctx = */ new llama_sampler_power_law {
/* .target = */ std::clamp(target, 0.0f, 1.0f),
/* .decay = */ std::clamp(decay, 0.0f, 0.99f),
/* .seed = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
/* .weighted_sum = */ 0.0f,
/* .total_weight = */ 0.0f,
/* .original_probs = */ {},
}
);
}

// logit-bias

struct llama_sampler_logit_bias {
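To make the transform concrete, here is a small self-contained sketch (not part of the PR) that applies the same reshaping, with the PR's constants, to a toy three-token distribution:

#include <cmath>
#include <cstdio>

int main() {
    const float probs[3]  = {0.70f, 0.25f, 0.05f}; // toy softmaxed distribution
    const float target    = 0.25f;                 // computed_target for this step
    const float inv_width = 1.0f / 0.3f;           // DISTRIBUTION_WIDTH = 0.3

    float logits[3], sum = 0.0f;
    for (int i = 0; i < 3; ++i) {
        const float d = (probs[i] - target) * inv_width;
        logits[i] = 5.0f / (1.0f + d * d);         // PEAK_LOGIT_VALUE = 5.0; peaks where p == target
        sum += std::exp(logits[i]);
    }
    for (int i = 0; i < 3; ++i) {
        // softmax of the reshaped logits
        std::printf("p = %.2f -> %.3f\n", probs[i], std::exp(logits[i]) / sum);
    }
    return 0;
}

With these numbers the 0.25-probability token ends up with ≈0.80 of the reshaped mass, the original 0.70 front-runner drops to ≈0.03, and the 0.05 tail token rises to ≈0.17.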
2 changes: 2 additions & 0 deletions tools/server/server-task.cpp
@@ -203,6 +203,8 @@ task_params server_task::params_from_json_cmpl(
params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
params.sampling.power_law_target = json_value(data, "power_law_target", defaults.sampling.power_law_target);
params.sampling.power_law_decay = json_value(data, "power_law_decay", defaults.sampling.power_law_decay);
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
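On the server, the new fields can be supplied per request; a hedged example payload (prompt illustrative, field names as parsed above):

{
  "prompt": "Once upon a time",
  "power_law_target": 0.3,
  "power_law_decay": 0.9
}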