From 7ee3c35a231227b50d44266abed83cf7768eb0a2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 11 Dec 2025 15:00:49 +0200 Subject: [PATCH 1/3] common : refactor common_sampler + grammar logic changes --- common/arg.cpp | 2 +- common/common.cpp | 190 ++++++++++++------ common/common.h | 29 ++- common/sampling.cpp | 183 +++++++++-------- common/sampling.h | 17 +- common/speculative.cpp | 2 +- examples/batched/batched.cpp | 28 ++- examples/embedding/embedding.cpp | 6 +- examples/eval-callback/eval-callback.cpp | 6 +- examples/lookahead/lookahead.cpp | 6 +- examples/lookup/lookup-create.cpp | 8 +- examples/lookup/lookup-stats.cpp | 8 +- examples/lookup/lookup.cpp | 6 +- examples/parallel/parallel.cpp | 6 +- examples/retrieval/retrieval.cpp | 6 +- examples/save-load-state/save-load-state.cpp | 6 +- .../speculative-simple/speculative-simple.cpp | 12 +- examples/speculative/speculative.cpp | 16 +- examples/training/finetune.cpp | 17 +- tools/completion/completion.cpp | 20 +- tools/cvector-generator/cvector-generator.cpp | 6 +- tools/imatrix/imatrix.cpp | 6 +- tools/mtmd/mtmd-cli.cpp | 6 +- tools/perplexity/perplexity.cpp | 6 +- tools/server/server-context.cpp | 47 ++--- tools/tts/tts.cpp | 12 +- 26 files changed, 367 insertions(+), 290 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index a31dcbc689c..9f0fc52c61a 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1359,7 +1359,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sampling.top_k = value; params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K; } - ).set_sparam()); + ).set_sparam().set_env("LLAMA_ARG_TOP_K")); add_opt(common_arg( {"--top-p"}, "N", string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p), diff --git a/common/common.cpp b/common/common.cpp index 0497f90a280..b76dfa10eab 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1013,31 +1013,40 @@ bool tty_can_use_colors() { // Model utils // -static inline void common_init_sampler_from_model( +// TODO: move to common/sampling +static void common_init_sampler_from_model( const llama_model * model, common_params_sampling & sparams) { const uint64_t config = sparams.user_sampling_config; auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) { - if (config & user_config) return; + if (config & user_config) { + return; + } char buf[64] = {0}; if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { char * end = nullptr; int32_t v = strtol(buf, &end, 10); - if (end && end != buf) dst = v; + if (end && end != buf) { + dst = v; + } } }; auto get_float = [&](const char * key, float & dst, uint64_t user_config) { - if (config & user_config) return; + if (config & user_config) { + return; + } char buf[128] = {0}; if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { char * end = nullptr; float v = strtof(buf, &end); - if (end && end != buf) dst = v; + if (end && end != buf) { + dst = v; + } } }; @@ -1065,31 +1074,122 @@ static inline void common_init_sampler_from_model( get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA); } -struct common_init_result common_init_from_params(common_params & params) { - common_init_result iparams; - auto mparams = common_model_params_to_llama(params); +struct common_init_result::impl { + impl() = default; + ~impl() = default; + + llama_model_ptr model; + llama_context_ptr context; + + std::vector lora; + + std::vector samplers; +}; + +common_init_result::common_init_result(common_params & params) : + pimpl(new impl{}) { + const auto mparams = common_model_params_to_llama(params); llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); if (model == NULL) { - LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n", - __func__, params.model.path.c_str()); - return iparams; + return; } - common_init_sampler_from_model(model, params.sampling); + pimpl->model.reset(model); const llama_vocab * vocab = llama_model_get_vocab(model); + // updates params.sampling + // TODO: fix naming + common_init_sampler_from_model(model, params.sampling); + auto cparams = common_context_params_to_llama(params); + if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__); + params.sampling.ignore_eos = false; + } + + // initialize once + for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { + if (llama_vocab_is_eog(vocab, i)) { + LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY); + params.sampling.logit_bias_eog.push_back({i, -INFINITY}); + } + } + + if (params.sampling.ignore_eos) { + // add EOG biases to the active set of logit biases + params.sampling.logit_bias.insert( + params.sampling.logit_bias.end(), + params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end()); + } + + //if (params.sampling.penalty_last_n == -1) { + // LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); + // params.sampling.penalty_last_n = llama_n_ctx(lctx); + //} + + //if (params.sampling.dry_penalty_last_n == -1) { + // LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); + // params.sampling.dry_penalty_last_n = llama_n_ctx(lctx); + //} + + pimpl->samplers.resize(cparams.n_seq_max); + + for (int i = 0; i < (int) cparams.n_seq_max; ++i) { + pimpl->samplers[i].reset(common_sampler_init(model, params.sampling)); + } + llama_context * lctx = llama_init_from_model(model, cparams); + if (lctx == NULL) { + LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n", + __func__, params.model.path.c_str()); + return; + } + + pimpl->context.reset(lctx); +} + +llama_model * common_init_result::model() { + return pimpl->model.get(); +} + +llama_context * common_init_result::context() { + return pimpl->context.get(); +} + +common_sampler * common_init_result::sampler(llama_seq_id seq_id) { + return pimpl->samplers[seq_id].get(); +} + +std::vector & common_init_result::lora() { + return pimpl->lora; +} + +void common_init_result::free_context() { + pimpl->context.reset(); +} + +common_init_result_ptr common_init_from_params(common_params & params) { + common_init_result_ptr res(new common_init_result(params)); + + llama_model * model = res->model(); + if (model == NULL) { + LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n", + __func__, params.model.path.c_str()); + return res; + } + + llama_context * lctx = res->context(); if (lctx == NULL) { LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n", __func__, params.model.path.c_str()); - llama_model_free(model); - return iparams; + return res; } + const llama_vocab * vocab = llama_model_get_vocab(model); + if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) { LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__); params.ctx_shift = false; @@ -1101,10 +1201,7 @@ struct common_init_result common_init_from_params(common_params & params) { const auto cvec = common_control_vector_load(params.control_vectors); if (cvec.n_embd == -1) { - llama_free(lctx); - llama_model_free(model); - - return iparams; + return res; } int err = llama_apply_adapter_cvec( @@ -1115,10 +1212,7 @@ struct common_init_result common_init_from_params(common_params & params) { params.control_vector_layer_start, params.control_vector_layer_end); if (err) { - llama_free(lctx); - llama_model_free(model); - - return iparams; + return res; } } @@ -1142,10 +1236,7 @@ struct common_init_result common_init_from_params(common_params & params) { } if (!ok) { - llama_free(lctx); - llama_model_free(model); - - return iparams; + return res; } } @@ -1155,9 +1246,7 @@ struct common_init_result common_init_from_params(common_params & params) { lora.reset(llama_adapter_lora_init(model, la.path.c_str())); if (lora == nullptr) { LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); - llama_free(lctx); - llama_model_free(model); - return iparams; + return res; } char buf[1024]; @@ -1166,43 +1255,13 @@ struct common_init_result common_init_from_params(common_params & params) { la.task_name = buf; llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf)); la.prompt_prefix = buf; - iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters + res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters } if (!params.lora_init_without_apply) { common_set_adapter_lora(lctx, params.lora_adapters); } - if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__); - params.sampling.ignore_eos = false; - } - - // initialize once - for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { - if (llama_vocab_is_eog(vocab, i)) { - LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); - params.sampling.logit_bias_eog.push_back({i, -INFINITY}); - } - } - - if (params.sampling.ignore_eos) { - // add EOG biases to the active set of logit biases - params.sampling.logit_bias.insert( - params.sampling.logit_bias.end(), - params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end()); - } - - if (params.sampling.penalty_last_n == -1) { - LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); - params.sampling.penalty_last_n = llama_n_ctx(lctx); - } - - if (params.sampling.dry_penalty_last_n == -1) { - LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); - params.sampling.dry_penalty_last_n = llama_n_ctx(lctx); - } - if (params.warmup) { LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); @@ -1241,12 +1300,11 @@ struct common_init_result common_init_from_params(common_params & params) { llama_set_warmup(lctx, false); } - iparams.model.reset(model); - iparams.context.reset(lctx); - - return iparams; + return res; } +common_init_result::~common_init_result() = default; + std::string get_model_endpoint() { const char * model_endpoint_env = getenv("MODEL_ENDPOINT"); // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility. @@ -1255,7 +1313,9 @@ std::string get_model_endpoint() { std::string model_endpoint = "https://huggingface.co/"; if (endpoint_env) { model_endpoint = endpoint_env; - if (model_endpoint.back() != '/') model_endpoint += '/'; + if (model_endpoint.back() != '/') { + model_endpoint += '/'; + } } return model_endpoint; } diff --git a/common/common.h b/common/common.h index 2fd83f0cf9c..4edb74b7066 100644 --- a/common/common.h +++ b/common/common.h @@ -195,7 +195,6 @@ struct common_params_sampling { std::vector dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY - std::vector samplers = { COMMON_SAMPLER_TYPE_PENALTIES, COMMON_SAMPLER_TYPE_DRY, @@ -216,6 +215,10 @@ struct common_params_sampling { std::vector logit_bias; // logit biases to apply std::vector logit_bias_eog; // pre-calculated logit biases for EOG tokens + bool has_logit_bias() const { + return !logit_bias.empty(); + } + // print the parameters into a string std::string print() const; }; @@ -669,15 +672,29 @@ bool tty_can_use_colors(); // Model utils // -// note: defines object's lifetime +struct common_sampler; + +// note: defines the model, context, samplers, ets. lifetimes struct common_init_result { - llama_model_ptr model; - llama_context_ptr context; + common_init_result(common_params & params); + ~common_init_result(); - std::vector lora; + llama_model * model(); + llama_context * context(); + common_sampler * sampler(llama_seq_id seq_id); + + std::vector & lora(); + + void free_context(); + +private: + struct impl; + std::unique_ptr pimpl; }; -struct common_init_result common_init_from_params(common_params & params); +using common_init_result_ptr = std::unique_ptr; + +common_init_result_ptr common_init_from_params(common_params & params); struct llama_model_params common_model_params_to_llama ( common_params & params); struct llama_context_params common_context_params_to_llama(const common_params & params); diff --git a/common/sampling.cpp b/common/sampling.cpp index 7a6b7be1e0e..6935d84e226 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -104,9 +104,10 @@ struct ring_buffer { struct common_sampler { common_params_sampling params; - struct llama_sampler * grmr; struct llama_sampler * chain; + bool grammar; + ring_buffer prev; std::vector cur; @@ -116,7 +117,6 @@ struct common_sampler { void reset() { prev.clear(); - llama_sampler_reset(grmr); llama_sampler_reset(chain); } @@ -167,10 +167,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co lparams.no_perf = params.no_perf; - struct llama_sampler * grmr; + llama_sampler * chain = llama_sampler_chain_init(lparams); + + bool grammar = false; + std::vector samplers; + if (params.grammar.compare(0, 11, "%llguidance") == 0) { #ifdef LLAMA_USE_LLGUIDANCE - grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); + samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str())); + grammar = true; #else GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); #endif // LLAMA_USE_LLGUIDANCE @@ -217,30 +222,23 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co trigger_patterns_c.push_back(regex.c_str()); } - grmr = params.grammar_lazy - ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", - trigger_patterns_c.data(), trigger_patterns_c.size(), - trigger_tokens.data(), trigger_tokens.size()) - : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); - if (!grmr) { - return nullptr; + if (!params.grammar.empty()) { + if (params.grammar_lazy) { + samplers.push_back( + llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", + trigger_patterns_c.data(), trigger_patterns_c.size(), + trigger_tokens.data(), trigger_tokens.size())); + } else { + samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root")); + } + + grammar = true; } } - auto * result = new common_sampler { - /* .params = */ params, - /* .grmr = */ grmr, - /* .chain = */ llama_sampler_chain_init(lparams), - /* .prev = */ ring_buffer(std::max(32, params.n_prev)), - /* .cur = */ {}, - /* .cur_p = */ {}, - }; - - llama_sampler_chain_add(result->chain, - llama_sampler_init_logit_bias( - llama_vocab_n_tokens(vocab), - params.logit_bias.size(), - params.logit_bias.data())); + if (params.has_logit_bias()) { + samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data())); + } if (params.mirostat == 0) { for (const auto & cnstr : params.samplers) { @@ -253,58 +251,70 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co c_breakers.push_back(str.c_str()); } - llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); } break; case COMMON_SAMPLER_TYPE_TOP_K: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + samplers.push_back(llama_sampler_init_top_k (params.top_k)); break; case COMMON_SAMPLER_TYPE_TOP_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); + samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); + samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma)); break; case COMMON_SAMPLER_TYPE_MIN_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); + samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_XTC: - llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); break; case COMMON_SAMPLER_TYPE_TYPICAL_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); + samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_TEMPERATURE: - llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); break; case COMMON_SAMPLER_TYPE_INFILL: - llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); + samplers.push_back(llama_sampler_init_infill (vocab)); break; case COMMON_SAMPLER_TYPE_PENALTIES: - llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); break; default: GGML_ASSERT(false && "unknown sampler type"); } } - llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); + + samplers.push_back(llama_sampler_init_dist(params.seed)); } else if (params.mirostat == 1) { - llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); + samplers.push_back(llama_sampler_init_temp(params.temp)); + samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); } else if (params.mirostat == 2) { - llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); + samplers.push_back(llama_sampler_init_temp(params.temp)); + samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); } else { GGML_ASSERT(false && "unknown mirostat version"); } + for (auto * smpl : samplers) { + llama_sampler_chain_add(chain, smpl); + } + + auto * result = new common_sampler { + /* .params = */ params, + /* .chain = */ chain, + /* .grammar = */ grammar, + /* .prev = */ ring_buffer(std::max(32, params.n_prev)), + /* .cur = */ {}, + /* .cur_p = */ {}, + }; + return result; } void common_sampler_free(struct common_sampler * gsmpl) { if (gsmpl) { - llama_sampler_free(gsmpl->grmr); - llama_sampler_free(gsmpl->chain); delete gsmpl; @@ -314,11 +324,24 @@ void common_sampler_free(struct common_sampler * gsmpl) { void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) { const auto tm = gsmpl->tm(); - if (accept_grammar) { - llama_sampler_accept(gsmpl->grmr, token); - } + if (gsmpl->grammar) { + const int n_smpl = llama_sampler_chain_n(gsmpl->chain); - llama_sampler_accept(gsmpl->chain, token); + for (int i = 0; i < n_smpl; i++) { + auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); + + // the grammar sampler is always the first one + if (i == 0) { + if (accept_grammar) { + llama_sampler_accept(smpl, token); + } + } else { + llama_sampler_accept(smpl, token); + } + } + } else { + llama_sampler_accept(gsmpl->chain, token); + } gsmpl->prev.push_back(token); } @@ -329,12 +352,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) { struct common_sampler * common_sampler_clone(common_sampler * gsmpl) { return new common_sampler { - /* .params = */ gsmpl->params, - /* .grmr = */ llama_sampler_clone(gsmpl->grmr), - /* .chain = */ llama_sampler_clone(gsmpl->chain), - /* .prev = */ gsmpl->prev, - /* .cur = */ gsmpl->cur, - /* .cur_p = */ gsmpl->cur_p, + /* .params = */ gsmpl->params, + /* .chain = */ llama_sampler_clone(gsmpl->chain), + /* .grammar = */ gsmpl->grammar, + /* .prev = */ gsmpl->prev, + /* .cur = */ gsmpl->cur, + /* .cur_p = */ gsmpl->cur_p, }; } @@ -383,58 +406,33 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam } } -llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { +struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) { + return gsmpl->chain; +} + +llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) { llama_synchronize(ctx); // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations const auto tm = gsmpl->tm(); - gsmpl->set_logits(ctx, idx); + llama_token id = LLAMA_TOKEN_NULL; - auto & grmr = gsmpl->grmr; auto & chain = gsmpl->chain; auto & cur_p = gsmpl->cur_p; // initialized by set_logits - if (grammar_first) { - llama_sampler_apply(grmr, &cur_p); - } + gsmpl->set_logits(ctx, idx); llama_sampler_apply(chain, &cur_p); GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); - const llama_token id = cur_p.data[cur_p.selected].id; - - if (grammar_first) { - return id; - } - - // check if it the sampled token fits the grammar - { - llama_token_data single_token_data = { id, 1.0f, 0.0f }; - llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false }; - - llama_sampler_apply(grmr, &single_token_data_array); - - const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; - if (is_valid) { - return id; - } - } - - // resampling: - // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain - gsmpl->set_logits(ctx, idx); - - llama_sampler_apply(grmr, &cur_p); - llama_sampler_apply(chain, &cur_p); - - GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration"); + id = cur_p.data[cur_p.selected].id; - return cur_p.data[cur_p.selected].id; + return id; } -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first) { +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft) { GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); std::vector result; @@ -442,7 +440,7 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample size_t i = 0; for (; i < draft.size(); i++) { - const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]); common_sampler_accept(gsmpl, id, true); @@ -454,7 +452,7 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample } if (i == draft.size()) { - const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]); common_sampler_accept(gsmpl, id, true); @@ -464,13 +462,13 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample return result; } -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) { std::vector idxs(draft.size() + 1); for (size_t i = 0; i < idxs.size(); ++i) { idxs[i] = i; } - return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first); + return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft); } uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { @@ -515,7 +513,8 @@ std::string common_sampler_print(const struct common_sampler * gsmpl) { for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) { const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); - result += std::string("-> ") + llama_sampler_name(smpl) + " "; + result += std::string("-> "); + result += std::string(llama_sampler_name(smpl)) + " "; } return result; diff --git a/common/sampling.h b/common/sampling.h index e198eecda38..ace5d3d020b 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -48,6 +48,8 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl); // arguments can be nullptr to skip printing void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl); +struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl); + // extended sampling implementation: // // - set logits @@ -55,10 +57,7 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam // - check if the token fits the grammar (if any) // - if not: resample by first applying the grammar constraints and then sampling again (slower path) // -// if grammar_first is true, the grammar is applied before the samplers (slower) -// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar -// -llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); +llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx); // generalized version of common_sampler_sample // @@ -76,10 +75,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co // // returns at least 1 token, up to idxs.size() // -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first = false); +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft); // assume idxs == [ 0, 1, 2, ..., draft.size() ] -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft); uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); @@ -107,3 +106,9 @@ std::vector common_sampler_types_from_chars(const std: llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind, const char * grammar_data); + +struct common_sampler_deleter { + void operator()(common_sampler * s) { common_sampler_free(s); } +}; + +typedef std::unique_ptr common_sampler_ptr; diff --git a/common/speculative.cpp b/common/speculative.cpp index 3e83b0964c8..1e12383ae6b 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -315,7 +315,7 @@ llama_tokens common_speculative_gen_draft( for (int i = 0; i < params.n_draft; ++i) { common_batch_clear(batch); - common_sampler_sample(smpl, ctx_dft, 0, true); + common_sampler_sample(smpl, ctx_dft, 0); const auto * cur_p = common_sampler_get_candidates(smpl, true); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 1a5de5928a5..a2bf95c3fd8 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -2,6 +2,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "sampling.h" #include #include @@ -64,17 +65,21 @@ int main(int argc, char ** argv) { ctx_params.n_ctx = n_kv_req; ctx_params.n_batch = std::max(n_predict, n_parallel); - llama_context * ctx = llama_init_from_model(model, ctx_params); - auto sparams = llama_sampler_chain_default_params(); sparams.no_perf = false; - llama_sampler * smpl = llama_sampler_chain_init(sparams); + std::vector sampler_configs; - llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k)); - llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep)); - llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); - llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); + for (int32_t i = 0; i < n_parallel; ++i) { + llama_sampler * smpl = llama_sampler_chain_init(sparams); + + llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k)); + llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep)); + llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); + llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); + } + + llama_context * ctx = llama_init_from_model(model, ctx_params); if (ctx == NULL) { LOG_ERR("%s: error: failed to create the llama_context\n" , __func__); @@ -173,7 +178,7 @@ int main(int argc, char ** argv) { continue; } - const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]); + const llama_token new_token_id = llama_sampler_sample(sampler_configs[i], ctx, i_batch[i]); // is it an end of generation? -> mark the stream as finished if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) { @@ -229,14 +234,17 @@ int main(int argc, char ** argv) { __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); LOG("\n"); - llama_perf_sampler_print(smpl); + llama_perf_sampler_print(sampler_configs[0]); llama_perf_context_print(ctx); fprintf(stderr, "\n"); llama_batch_free(batch); - llama_sampler_free(smpl); + for (auto & sampler_config : sampler_configs) { + llama_sampler_free(sampler_config); + } + llama_free(ctx); llama_model_free(model); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index fe91b308cdc..81111e81b2c 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -131,10 +131,10 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); // load the model - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); if (model == NULL) { LOG_ERR("%s: unable to load model\n", __func__); diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 80c693ce61c..408338f1afc 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -202,10 +202,10 @@ int main(int argc, char ** argv) { params.warmup = false; // init - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); if (model == nullptr || ctx == nullptr) { LOG_ERR("%s : failed to init\n", __func__); diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 1e26d8221b8..f54cfdd77f2 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -55,10 +55,10 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); // load the target model - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); auto * mem = llama_get_memory(ctx); diff --git a/examples/lookup/lookup-create.cpp b/examples/lookup/lookup-create.cpp index 3da45ed9e03..bb94a8fe06d 100644 --- a/examples/lookup/lookup-create.cpp +++ b/examples/lookup/lookup-create.cpp @@ -18,16 +18,16 @@ int main(int argc, char ** argv){ llama_numa_init(params.numa); // load the model - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model_ptr & model = llama_init.model; - llama_context_ptr & ctx = llama_init.context; + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); GGML_ASSERT(model != nullptr); // tokenize the prompt std::vector inp; - inp = common_tokenize(ctx.get(), params.prompt, true, true); + inp = common_tokenize(ctx, params.prompt, true, true); fprintf(stderr, "%s: tokenization done\n", __func__); common_ngram_cache ngram_cache; diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp index fcb289abe0e..135f6fcab95 100644 --- a/examples/lookup/lookup-stats.cpp +++ b/examples/lookup/lookup-stats.cpp @@ -28,13 +28,13 @@ int main(int argc, char ** argv){ llama_numa_init(params.numa); // load the model - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_context_ptr & ctx = llama_init.context; + llama_context * ctx = llama_init->context(); // tokenize the prompt std::vector inp; - inp = common_tokenize(ctx.get(), params.prompt, true, true); + inp = common_tokenize(ctx, params.prompt, true, true); common_ngram_cache ngram_cache_context; common_ngram_cache ngram_cache_dynamic; @@ -65,7 +65,7 @@ int main(int argc, char ** argv){ } const int n_input = inp.size(); - const int n_ctx = llama_n_ctx(ctx.get()); + const int n_ctx = llama_n_ctx(ctx); int n_drafted = 0; int n_accept = 0; diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 2bfa26b55f0..27f159940a4 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -29,10 +29,10 @@ int main(int argc, char ** argv){ llama_numa_init(params.numa); // load the model - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); const llama_vocab * vocab = llama_model_get_vocab(model); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index e48f48fc322..c92173ae291 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -192,10 +192,10 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); // load the target model - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); auto * mem = llama_get_memory(ctx); diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 042e12c2bf8..2c2143ad108 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -149,10 +149,10 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); // load the model - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); if (model == NULL) { LOG_ERR("%s: unable to load model\n", __func__); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 4cd3071f762..39d4464663d 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -34,10 +34,10 @@ int main(int argc, char ** argv) { std::string result2; // init - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); if (model == nullptr || ctx == nullptr) { fprintf(stderr, "%s : failed to init\n", __func__); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index a8e53f28eb5..a5bd5c9125e 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -40,10 +40,10 @@ int main(int argc, char ** argv) { llama_context * ctx_dft = NULL; // load the target model - common_init_result llama_init_tgt = common_init_from_params(params); + auto llama_init_tgt = common_init_from_params(params); - model_tgt = llama_init_tgt.model.get(); - ctx_tgt = llama_init_tgt.context.get(); + model_tgt = llama_init_tgt->model(); + ctx_tgt = llama_init_tgt->context(); const llama_vocab * vocab = llama_model_get_vocab(model_tgt); @@ -61,10 +61,10 @@ int main(int argc, char ** argv) { params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; params.tensor_buft_overrides = params.speculative.tensor_buft_overrides; - common_init_result llama_init_dft = common_init_from_params(params); + auto llama_init_dft = common_init_from_params(params); - //model_dft = llama_init_dft.model.get(); - ctx_dft = llama_init_dft.context.get(); + //model_dft = llama_init_dft->model(); + ctx_dft = llama_init_dft->context(); if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) { LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str()); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 5f5ac5eb64d..2fb7f6374ef 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -71,10 +71,10 @@ int main(int argc, char ** argv) { llama_context * ctx_dft = NULL; // load the target model - common_init_result llama_init_tgt = common_init_from_params(params); + auto llama_init_tgt = common_init_from_params(params); - model_tgt = llama_init_tgt.model.get(); - ctx_tgt = llama_init_tgt.context.get(); + model_tgt = llama_init_tgt->model(); + ctx_tgt = llama_init_tgt->context(); // load the draft model params.devices = params.speculative.devices; @@ -87,10 +87,10 @@ int main(int argc, char ** argv) { params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; params.tensor_buft_overrides = params.speculative.tensor_buft_overrides; - common_init_result llama_init_dft = common_init_from_params(params); + auto llama_init_dft = common_init_from_params(params); - model_dft = llama_init_dft.model.get(); - ctx_dft = llama_init_dft.context.get(); + model_dft = llama_init_dft->model(); + ctx_dft = llama_init_dft->context(); const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); @@ -242,7 +242,7 @@ int main(int argc, char ** argv) { bool accept = false; if (params.sampling.temp > 0) { // stochastic verification - common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true); + common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]); auto & dist_tgt = *common_sampler_get_candidates(smpl, true); @@ -491,7 +491,7 @@ int main(int argc, char ** argv) { continue; } - common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true); + common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft); const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true); diff --git a/examples/training/finetune.cpp b/examples/training/finetune.cpp index 416d8d8f6c8..c82de8d35d0 100644 --- a/examples/training/finetune.cpp +++ b/examples/training/finetune.cpp @@ -39,9 +39,10 @@ int main(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); // load the model and apply lora adapter, if any - common_init_result llama_init = common_init_from_params(params); - llama_model_ptr & model = llama_init.model; - llama_context_ptr & ctx = llama_init.context; + auto llama_init = common_init_from_params(params); + + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); if (model == NULL) { LOG_ERR("%s: unable to load model\n", __func__); @@ -54,8 +55,8 @@ int main(int argc, char ** argv) { LOG_INF("%s\n", common_params_get_system_info(params).c_str()); } - std::vector tokens = common_tokenize(ctx.get(), params.prompt, true); - ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get()) / 2); + std::vector tokens = common_tokenize(ctx, params.prompt, true); + ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx, tokens, llama_n_ctx(ctx) / 2); struct lr_opt & lr = params.lr; LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n", @@ -70,7 +71,7 @@ int main(int argc, char ** argv) { /*get_opt_pars_ud =*/¶ms.lr, /*optimizer_type =*/params.optimizer, }; - llama_opt_init(ctx.get(), model.get(), lopt_params); + llama_opt_init(ctx, model, lopt_params); const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split); @@ -78,7 +79,7 @@ int main(int argc, char ** argv) { ggml_opt_result_t result_eval = ggml_opt_result_init(); for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) { - llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split, + llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split, ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar); fprintf(stderr, "\n"); @@ -88,7 +89,7 @@ int main(int argc, char ** argv) { ggml_opt_result_free(result_train); ggml_opt_result_free(result_eval); - llama_model_save_to_file(model.get(), params.out_file.c_str()); + llama_model_save_to_file(model, params.out_file.c_str()); llama_backend_free(); diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp index cb2641ae0ae..85480f33691 100644 --- a/tools/completion/completion.cpp +++ b/tools/completion/completion.cpp @@ -141,13 +141,15 @@ int main(int argc, char ** argv) { // load the model and apply lora adapter, if any LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__); - common_init_result llama_init = common_init_from_params(params); - model = llama_init.model.get(); - ctx = llama_init.context.get(); + auto llama_init = common_init_from_params(params); - if (model == NULL) { - LOG_ERR("%s: error: unable to load model\n", __func__); + ctx = llama_init->context(); + model = llama_init->model(); + smpl = llama_init->sampler(0); + + if (ctx == NULL) { + LOG_ERR("%s: error: unable to create context\n", __func__); return 1; } @@ -474,12 +476,6 @@ int main(int argc, char ** argv) { } } - smpl = common_sampler_init(model, sparams); - if (!smpl) { - LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__); - return 1; - } - LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl)); LOG_INF("sampler params: \n%s\n", sparams.print().c_str()); LOG_INF("sampler chain: %s\n", common_sampler_print(smpl).c_str()); @@ -993,8 +989,6 @@ int main(int argc, char ** argv) { LOG("\n\n"); common_perf_print(ctx, smpl); - common_sampler_free(smpl); - llama_backend_free(); ggml_threadpool_free_fn(threadpool); diff --git a/tools/cvector-generator/cvector-generator.cpp b/tools/cvector-generator/cvector-generator.cpp index d2d97e05ceb..3ba7c529506 100644 --- a/tools/cvector-generator/cvector-generator.cpp +++ b/tools/cvector-generator/cvector-generator.cpp @@ -419,10 +419,10 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); // load the model to get hparams - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); // int n_ctx = llama_n_ctx(ctx); int n_layers = llama_model_n_layer(model); diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index f28a036deeb..669de55ddb9 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -1265,10 +1265,10 @@ int main(int argc, char ** argv) { params.warmup = false; // init - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); if (model == nullptr || ctx == nullptr) { LOG_ERR("%s : failed to init\n", __func__); diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 25d24603db6..332d2049e54 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -65,7 +65,7 @@ static void sigint_handler(int signo) { struct mtmd_cli_context { mtmd::context_ptr ctx_vision; - common_init_result llama_init; + common_init_result_ptr llama_init; llama_model * model; llama_context * lctx; @@ -89,8 +89,8 @@ struct mtmd_cli_context { llama_pos n_past = 0; mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) { - model = llama_init.model.get(); - lctx = llama_init.context.get(); + model = llama_init->model(); + lctx = llama_init->context(); vocab = llama_model_get_vocab(model); smpl = common_sampler_init(model, params.sampling); n_threads = params.cpuparams.n_threads; diff --git a/tools/perplexity/perplexity.cpp b/tools/perplexity/perplexity.cpp index caf080e8d18..1ead9c871e9 100644 --- a/tools/perplexity/perplexity.cpp +++ b/tools/perplexity/perplexity.cpp @@ -2024,10 +2024,10 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); // load the model and apply lora adapter, if any - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); if (model == NULL) { LOG_ERR("%s: unable to load model\n", __func__); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 5a67f508dfb..90898b5ec43 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -153,7 +153,7 @@ struct server_slot { // sampling json json_schema; - struct common_sampler * smpl = nullptr; + common_sampler_ptr smpl; llama_token sampled; // in speculative mode, this is the last accepted token llama_tokens drafted; @@ -510,8 +510,8 @@ struct server_context_impl { common_params params_base; // note: keep these alive - they determine the lifetime of the model, context, etc. - common_init_result llama_init; - common_init_result llama_init_dft; + common_init_result_ptr llama_init; + common_init_result_ptr llama_init_dft; llama_model * model = nullptr; llama_context * ctx = nullptr; @@ -557,9 +557,6 @@ struct server_context_impl { // Clear any sampling context for (server_slot & slot : slots) { - common_sampler_free(slot.smpl); - slot.smpl = nullptr; - llama_free(slot.ctx_dft); slot.ctx_dft = nullptr; @@ -580,8 +577,8 @@ struct server_context_impl { llama_init = common_init_from_params(params_base); - model = llama_init.model.get(); - ctx = llama_init.context.get(); + model = llama_init->model(); + ctx = llama_init->context(); if (model == nullptr) { SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); @@ -613,25 +610,25 @@ struct server_context_impl { llama_init_dft = common_init_from_params(params_dft); - model_dft = llama_init_dft.model.get(); + model_dft = llama_init_dft->model(); if (model_dft == nullptr) { SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); return false; } - vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get()); + vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft->context()); if (!vocab_dft_compatible) { SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); } - const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); + const int n_ctx_dft = llama_n_ctx(llama_init_dft->context()); cparams_dft = common_context_params_to_llama(params_dft); cparams_dft.n_batch = n_ctx_dft; // the context is not needed - we will create one for each slot - llama_init_dft.context.reset(); + llama_init_dft->free_context(); } chat_templates = common_chat_templates_init(model, params_base.chat_template); @@ -1051,18 +1048,15 @@ struct server_context_impl { // initialize samplers { - if (slot.smpl != nullptr) { - common_sampler_free(slot.smpl); - } + slot.smpl.reset(common_sampler_init(model, task.params.sampling)); - slot.smpl = common_sampler_init(model, task.params.sampling); if (slot.smpl == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); return false; } - SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str()); + SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str()); } // initialize draft batch @@ -1216,11 +1210,10 @@ struct server_context_impl { } void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { - size_t n_probs = slot.task->params.sampling.n_probs; - size_t n_vocab = llama_vocab_n_tokens(vocab); + const size_t n_probs = slot.task->params.sampling.n_probs; if (post_sampling) { - const auto * cur_p = common_sampler_get_candidates(slot.smpl, true); + const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true); const size_t max_probs = cur_p->size; // set probability for sampled token @@ -1245,7 +1238,7 @@ struct server_context_impl { std::vector cur = get_token_probabilities(ctx, idx); // set probability for sampled token - for (size_t i = 0; i < n_vocab; i++) { + for (size_t i = 0; i < cur.size(); i++) { // set probability for sampled token if (cur[i].id == result.tok) { result.prob = cur[i].p; @@ -1255,7 +1248,7 @@ struct server_context_impl { // set probability for top n_probs tokens result.probs.reserve(n_probs); - for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { + for (size_t i = 0; i < std::min(cur.size(), n_probs); i++) { result.probs.push_back({ cur[i].id, common_token_to_piece(ctx, cur[i].id, special), @@ -2301,13 +2294,13 @@ struct server_context_impl { GGML_ASSERT(batch.n_tokens > 0); - common_sampler_reset(slot.smpl); + common_sampler_reset(slot.smpl.get()); // Process all prompt tokens through sampler system for (int i = 0; i < slot.task->n_tokens(); ++i) { llama_token id = input_tokens[i]; if (id != LLAMA_TOKEN_NULL) { - common_sampler_accept(slot.smpl, id, false); + common_sampler_accept(slot.smpl.get(), id, false); } } @@ -2525,11 +2518,11 @@ struct server_context_impl { const int tok_idx = slot.i_batch - i; - llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx); slot.i_batch = -1; - common_sampler_accept(slot.smpl, id, true); + common_sampler_accept(slot.smpl.get(), id, true); slot.n_decoded += 1; @@ -2570,7 +2563,7 @@ struct server_context_impl { size_t n_draft = slot.drafted.size(); // the accepted tokens from the speculation - const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted); + const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted); slot.i_batch_dft.clear(); slot.drafted.clear(); diff --git a/tools/tts/tts.cpp b/tools/tts/tts.cpp index eaf56591d9d..8c39fce8baa 100644 --- a/tools/tts/tts.cpp +++ b/tools/tts/tts.cpp @@ -568,10 +568,10 @@ int main(int argc, char ** argv) { llama_context * ctx_ttc = NULL; llama_context * ctx_cts = NULL; - common_init_result llama_init_ttc = common_init_from_params(params); + auto llama_init_ttc = common_init_from_params(params); - model_ttc = llama_init_ttc.model.get(); - ctx_ttc = llama_init_ttc.context.get(); + model_ttc = llama_init_ttc->model(); + ctx_ttc = llama_init_ttc->context(); if (model_ttc == nullptr || ctx_ttc == nullptr) { return ENOENT; @@ -583,10 +583,10 @@ int main(int argc, char ** argv) { params.embedding = true; params.n_ubatch = params.n_batch; - common_init_result llama_init_cts = common_init_from_params(params); + auto llama_init_cts = common_init_from_params(params); - model_cts = llama_init_cts.model.get(); - ctx_cts = llama_init_cts.context.get(); + model_cts = llama_init_cts->model(); + ctx_cts = llama_init_cts->context(); if (model_cts == nullptr || ctx_cts == nullptr) { return ENOENT; From 68c56542be18d984c21544335b9061166b29318a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 11 Dec 2025 22:44:25 +0200 Subject: [PATCH 2/3] tests : increase max_tokens to get needed response --- tools/server/tests/unit/test_compat_anthropic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/server/tests/unit/test_compat_anthropic.py b/tools/server/tests/unit/test_compat_anthropic.py index d55dd1d9454..e0a003557e7 100644 --- a/tools/server/tests/unit/test_compat_anthropic.py +++ b/tools/server/tests/unit/test_compat_anthropic.py @@ -684,7 +684,7 @@ def test_anthropic_streaming_content_block_indices(): # Request that might produce both text and tool use res = server.make_stream_request("POST", "/v1/messages", data={ "model": "test", - "max_tokens": 200, + "max_tokens": 400, "stream": True, "tools": [{ "name": "test_tool", From 2fa9874da504e8020775d27ae1d6076240ebd129 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 12 Dec 2025 15:13:15 +0200 Subject: [PATCH 3/3] batched : fix uninitialized samplers --- examples/batched/batched.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index a2bf95c3fd8..36a12d299ff 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -68,7 +68,7 @@ int main(int argc, char ** argv) { auto sparams = llama_sampler_chain_default_params(); sparams.no_perf = false; - std::vector sampler_configs; + std::vector samplers; for (int32_t i = 0; i < n_parallel; ++i) { llama_sampler * smpl = llama_sampler_chain_init(sparams); @@ -77,6 +77,8 @@ int main(int argc, char ** argv) { llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep)); llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); + + samplers.push_back(smpl); } llama_context * ctx = llama_init_from_model(model, ctx_params); @@ -178,7 +180,7 @@ int main(int argc, char ** argv) { continue; } - const llama_token new_token_id = llama_sampler_sample(sampler_configs[i], ctx, i_batch[i]); + const llama_token new_token_id = llama_sampler_sample(samplers[i], ctx, i_batch[i]); // is it an end of generation? -> mark the stream as finished if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) { @@ -234,14 +236,14 @@ int main(int argc, char ** argv) { __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); LOG("\n"); - llama_perf_sampler_print(sampler_configs[0]); + llama_perf_sampler_print(samplers[0]); llama_perf_context_print(ctx); fprintf(stderr, "\n"); llama_batch_free(batch); - for (auto & sampler_config : sampler_configs) { + for (auto & sampler_config : samplers) { llama_sampler_free(sampler_config); }