Commit fb40266

server: handle limiting maximum reasoning budget
1 parent fb615a2 commit fb40266

9 files changed: +228 -11 lines changed

common/arg.cpp

Lines changed: 12 additions & 2 deletions

@@ -2574,12 +2574,22 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
     add_opt(common_arg(
         {"--reasoning-budget"}, "N",
-        "controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)",
+        "controls the maximum number of thinking tokens allowed; -1 for unlimited, 0 to disable thinking, or a positive value to limit thinking tokens (default: -1)",
         [](common_params & params, int value) {
-            if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); }
+            if (value < -1) { throw std::invalid_argument("invalid value: must be -1 (unlimited), 0 (disabled), or a positive number"); }
             params.reasoning_budget = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK_BUDGET"));
+    add_opt(common_arg(
+        {"--reasoning-force-close-message"}, "STRING",
+        string_format(
+            "if specified, forces the model to close its reasoning/thoughts when generating this message (default: %s)\n",
+            params.reasoning_force_close_message.c_str()
+        ),
+        [](common_params & params, const std::string & value) {
+            params.reasoning_force_close_message = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK_FORCE_CLOSE_MESSAGE"));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         string_format(
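
For reference, the flag now accepts three kinds of values instead of a strict -1/0 switch. A minimal standalone sketch (illustrative only, not part of the patch) that applies the same check and interpretation:

#include <iostream>
#include <stdexcept>
#include <string>

// Same semantics as the updated --reasoning-budget flag:
// -1 = unlimited, 0 = thinking disabled, N > 0 = cap thinking at N tokens.
static std::string describe_reasoning_budget(int value) {
    if (value < -1) {
        throw std::invalid_argument("invalid value: must be -1 (unlimited), 0 (disabled), or a positive number");
    }
    if (value == -1) { return "unlimited thinking"; }
    if (value ==  0) { return "thinking disabled"; }
    return "at most " + std::to_string(value) + " thinking tokens";
}

int main() {
    std::cout << describe_reasoning_budget(512) << std::endl; // "at most 512 thinking tokens"
    return 0;
}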

common/common.cpp

Lines changed: 8 additions & 0 deletions

@@ -1078,6 +1078,14 @@ struct common_init_result common_init_from_params(common_params & params) {
 
     common_init_sampler_from_model(model, params.sampling);
 
+    // Allow models to override the forced reasoning close message via GGUF metadata
+    if (params.reasoning_force_close_message == COMMON_DEFAULT_REASONING_FORCE_CLOSE_MESSAGE) {
+        char buf[512] = {0};
+        if (llama_model_meta_val_str(model, "tokenizer.ggml.reasoning_force_close_message", buf, sizeof(buf)) > 0) {
+            params.reasoning_force_close_message = buf;
+        }
+    }
+
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
     auto cparams = common_context_params_to_llama(params);

common/common.h

Lines changed: 3 additions & 0 deletions

@@ -102,6 +102,8 @@ enum llama_example {
     LLAMA_EXAMPLE_COUNT,
 };
 
+inline constexpr const char * COMMON_DEFAULT_REASONING_FORCE_CLOSE_MESSAGE = "... I now conclude my reasoning and will provide the final answer.";
+
 enum common_sampler_type {
     COMMON_SAMPLER_TYPE_NONE = 0,
     COMMON_SAMPLER_TYPE_DRY = 1,

@@ -466,6 +468,7 @@ struct common_params {
     bool enable_chat_template = true;
     common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
     int reasoning_budget = -1;
+    std::string reasoning_force_close_message = COMMON_DEFAULT_REASONING_FORCE_CLOSE_MESSAGE;
     bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
 
     std::vector<std::string> api_keys;

tools/server/README.md

Lines changed: 2 additions & 1 deletion

@@ -203,7 +203,8 @@ For the ful list of features, please refer to [server's changelog](https://githu
 | `--jinja` | use jinja template for chat (default: enabled)<br/><br/>(env: LLAMA_ARG_JINJA) |
 | `--no-jinja` | disable jinja template for chat (default: enabled)<br/><br/>(env: LLAMA_ARG_NO_JINJA) |
 | `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:<br/>- none: leaves thoughts unparsed in `message.content`<br/>- deepseek: puts thoughts in `message.reasoning_content`<br/>- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`<br/>(default: auto)<br/>(env: LLAMA_ARG_THINK) |
-| `--reasoning-budget N` | controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)<br/>(env: LLAMA_ARG_THINK_BUDGET) |
+| `--reasoning-budget N` | controls the maximum number of thinking tokens allowed; -1 for unlimited, 0 to disable thinking, or a positive value to limit thinking tokens. When the budget is exceeded, the server automatically injects a closing `</think>` and continues with the final answer. Individual OpenAI-compatible requests can override this value with `thinking_budget_tokens`. (default: -1)<br/>(env: LLAMA_ARG_THINK_BUDGET) |
+| `--reasoning-force-close-message STRING` | when the reasoning budget is exceeded, this message is appended to the current user message to signal the model to close any open thought tags. (default: '... I now conclude my reasoning and will provide the final answer.')<br/>(env: LLAMA_ARG_THINK_FORCE_CLOSE_MESSAGE) |
 | `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
 | `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) |
 | `--no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)<br/>when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled<br/><br/>(env: LLAMA_ARG_NO_PREFILL_ASSISTANT) |
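
For reference, the per-request override mentioned above is read from the request body of the OpenAI-compatible endpoints. A hedged sketch of building such a body with nlohmann::json (the JSON library the server already uses); the model name, message content, and close message here are placeholders:

#include <nlohmann/json.hpp>
#include <iostream>

// Sketch of a request body that overrides the server-wide reasoning budget
// for a single OpenAI-compatible chat completion request.
int main() {
    nlohmann::json body;
    body["model"]    = "any"; // placeholder model name
    body["messages"] = nlohmann::json::array({
        {{"role", "user"}, {"content", "Explain quicksort briefly."}}
    });
    body["thinking_budget_tokens"]        = 256; // per-request cap on thinking tokens
    body["reasoning_force_close_message"] = "... wrapping up my reasoning now."; // optional per-request override

    std::cout << body.dump(2) << std::endl; // POST this as the payload to /v1/chat/completions
    return 0;
}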

tools/server/server-context.cpp

Lines changed: 125 additions & 2 deletions

@@ -18,6 +18,8 @@
 #include <memory>
 #include <unordered_set>
 #include <filesystem>
+#include <deque>
+#include <exception>
 
 // fix problem with std::min and std::max
 #if defined(_WIN32)

@@ -47,6 +49,13 @@ enum server_state {
     SERVER_STATE_READY, // Server is ready and model is loaded
 };
 
+enum reasoning_state {
+    REASONING_STATE_NONE,
+    REASONING_STATE_REASONING,
+    REASONING_STATE_PENDING_FORCE_CLOSE,
+    REASONING_STATE_FINISHED,
+};
+
 static bool server_task_type_need_embd(server_task_type task_type) {
     switch (task_type) {
         case SERVER_TASK_TYPE_EMBEDDING:

@@ -113,6 +122,12 @@ struct server_slot {
     bool has_new_line = false;
     bool truncated = false;
 
+    // reasoning budget tracking
+    int32_t n_reasoning_tokens = 0; // number of tokens generated while in reasoning/thinking mode
+    reasoning_state reasoning = REASONING_STATE_NONE; // are we currently in reasoning mode
+    std::string reasoning_end_tag; // the closing tag to inject when budget is exceeded (e.g., "</think>")
+    std::deque<llama_token> forced_tokens; // tokens we must feed back to the model (e.g., forced </think>)
+
     stop_type stop;
 
     std::string stopping_word;

@@ -162,9 +177,11 @@ struct server_slot {
     size_t n_sent_text = 0; // number of sent text character
 
     int64_t t_start_process_prompt;
+    int64_t t_start_reasoning;
     int64_t t_start_generation;
 
     double t_prompt_processing; // ms
+    double t_reasoning_token_generation; // ms
     double t_token_generation; // ms
 
     std::function<void(int)> callback_on_release;

@@ -188,6 +205,13 @@ struct server_slot {
 
     drafted.clear();
     i_batch_dft.clear();
+
+    // reset reasoning budget tracking
+    n_reasoning_tokens = 0;
+    reasoning = REASONING_STATE_NONE;
+    reasoning_end_tag = "";
+    forced_tokens.clear();
+
     generated_tokens.clear();
     generated_token_probs.clear();
     json_schema = json();

@@ -372,15 +396,20 @@ struct server_slot {
     const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
     const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
 
+    const double t_reasoning = t_reasoning_token_generation / n_reasoning_tokens;
+    const double n_reasoning_second = 1e3 / t_reasoning_token_generation * n_reasoning_tokens;
+
     const double t_gen = t_token_generation / n_decoded;
     const double n_gen_second = 1e3 / t_token_generation * n_decoded;
 
     SLT_INF(*this,
         "\n"
         "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
+        " reasoning time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
         " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
         " total time = %10.2f ms / %5d tokens\n",
         t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
+        t_reasoning_token_generation, n_reasoning_tokens, t_reasoning, n_reasoning_second,
         t_token_generation, n_decoded, t_gen, n_gen_second,
         t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
 

@@ -1079,6 +1108,13 @@ struct server_context_impl {
         ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
         : SLOT_STATE_STARTED;
 
+    // Initialize reasoning tracking
+    slot.forced_tokens.clear();
+    slot.n_reasoning_tokens = 0;
+    slot.reasoning = REASONING_STATE_NONE;
+    slot.reasoning_end_tag.clear();
+
+
     SLT_INF(slot, "%s", "processing task\n");
 
     return true;

@@ -1154,6 +1190,85 @@ struct server_context_impl {
         SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict);
     }
 
+    const int32_t reasoning_budget = (slot.task ? slot.task->params.reasoning_budget : params_base.reasoning_budget);
+
+    // check reasoning budget limit
+    // Track reasoning tokens using the chat parser to detect reasoning segments consistently across formats
+    // When the budget is exceeded we enqueue the closing tag tokens so they get sent to the client
+    // and fed back into the model before continuing normal generation
+    if (slot.has_next_token && reasoning_budget > 0 && slot.reasoning != REASONING_STATE_FINISHED) {
+        const auto parsed_msg = common_chat_parse(
+            slot.generated_text,
+            /* is_partial = */ true,
+            slot.task->params.oaicompat_chat_syntax);
+        const auto & rstatus = parsed_msg.reasoning_status;
+
+        if (rstatus.active && slot.reasoning != REASONING_STATE_PENDING_FORCE_CLOSE) {
+            if (slot.reasoning != REASONING_STATE_REASONING) {
+                SLT_DBG(slot, "detected reasoning start via parser%s\n", "");
+                slot.reasoning = REASONING_STATE_REASONING;
+                slot.reasoning_end_tag = rstatus.end_tag;
+                slot.n_reasoning_tokens = 0;
+                slot.t_start_reasoning = ggml_time_us();
+            }
+        } else if (!rstatus.active && slot.reasoning == REASONING_STATE_REASONING) {
+            SLT_DBG(slot, "detected reasoning end '%s' via parser\n", rstatus.end_tag.c_str());
+            slot.reasoning = REASONING_STATE_FINISHED;
+            slot.t_reasoning_token_generation = (ggml_time_us() - slot.t_start_reasoning) / 1e3;
+        }
+
+        if (slot.reasoning == REASONING_STATE_REASONING) {
+            slot.n_reasoning_tokens++;
+
+            // Detect if we are in the middle of emitting a tool call this step.
+            // The parser sets tool_call_in_progress when it catches a partial exception
+            // while parsing tool calls, indicating incomplete tool call parsing.
+            // We also check for tool call diffs in this token as a fallback.
+            if (!parsed_msg.tool_call_in_progress && slot.n_reasoning_tokens >= reasoning_budget) {
+                SLT_INF(slot, "reasoning budget exceeded, forcing close with '%s', n_reasoning_tokens = %d, reasoning_budget = %d\n",
+                    slot.reasoning_end_tag.c_str(), slot.n_reasoning_tokens, reasoning_budget);
+
+                auto fail_close = [&](const char * reason) {
+                    SLT_WRN(slot, "failed to inject reasoning close tag (%s) -> stopping generation\n", reason);
+                    slot.stop = STOP_TYPE_LIMIT;
+                    slot.has_next_token = false;
+                };
+
+                if (slot.reasoning_end_tag.empty()) {
+                    fail_close("no closing tag detected");
+                } else {
+                    const std::string forced_message = slot.task->params.reasoning_force_close_message.empty()
+                        ? std::string(COMMON_DEFAULT_REASONING_FORCE_CLOSE_MESSAGE)
+                        : slot.task->params.reasoning_force_close_message;
+                    const std::string forced_injection = forced_message + slot.reasoning_end_tag;
+
+                    llama_tokens closing_tokens;
+                    try {
+                        closing_tokens = common_tokenize(ctx, forced_injection, /*add_special=*/false, /*parse_special=*/true);
+                    } catch (const std::exception & err) {
+                        SLT_WRN(slot, "tokenization error while forcing reasoning close: %s\n", err.what());
+                        fail_close("tokenization error");
+                        closing_tokens.clear();
+                    }
+
+                    if (!closing_tokens.empty()) {
+                        slot.forced_tokens.insert(slot.forced_tokens.end(), closing_tokens.begin(), closing_tokens.end());
+                        slot.reasoning = REASONING_STATE_PENDING_FORCE_CLOSE;
+                    } else if (slot.has_next_token) {
+                        fail_close("closing tag produced no tokens");
+                    }
+                }
+            }
+        } else if (slot.reasoning == REASONING_STATE_PENDING_FORCE_CLOSE) {
+            // We've already scheduled the forced close, wait until it's done
+            if (slot.forced_tokens.empty()) {
+                SLT_DBG(slot, "completed forced reasoning close with '%s'\n", slot.reasoning_end_tag.c_str());
+                slot.reasoning = REASONING_STATE_FINISHED;
+                slot.t_reasoning_token_generation = (ggml_time_us() - slot.t_start_reasoning) / 1e3;
+            }
+        }
+    }
+
     if (slot.has_new_line) {
         // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
         if (slot.task->params.n_indent > 0) {

@@ -2484,7 +2599,15 @@
 
     const int tok_idx = slot.i_batch - i;
 
-    llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
+    const bool has_forced_token = !slot.forced_tokens.empty();
+    llama_token id = 0;
+
+    if (has_forced_token) {
+        id = slot.forced_tokens.front();
+        slot.forced_tokens.pop_front();
+    } else {
+        id = common_sampler_sample(slot.smpl, ctx, tok_idx);
+    }
 
     slot.i_batch = -1;
 

@@ -2522,7 +2645,7 @@
 
     // speculative decoding - main model sample and accept
     for (auto & slot : slots) {
-        if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) {
+        if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty() || !slot.forced_tokens.empty()) {
            continue;
        }
 
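Taken together, the additions above implement a small per-slot state machine: REASONING_STATE_NONE until the chat parser reports an open reasoning segment, REASONING_STATE_REASONING while thinking tokens are counted, REASONING_STATE_PENDING_FORCE_CLOSE once the budget is hit and the tokenized close message is queued in forced_tokens (where it bypasses sampling), and REASONING_STATE_FINISHED when the parser sees the reasoning end or the forced queue drains. A stripped-down sketch of that flow (illustrative only; the actual patch also handles tool calls, timings, and tokenization failures):

#include <cstdint>
#include <deque>

enum reasoning_state { NONE, REASONING, PENDING_FORCE_CLOSE, FINISHED };

struct reasoning_tracker {
    reasoning_state state = NONE;
    int32_t n_reasoning_tokens = 0;
    std::deque<int32_t> forced; // token ids queued for forced injection

    // Called once per generated token. 'thinking_open' is what the chat parser
    // reports for the text so far; 'close_tokens' is the tokenized close
    // message plus the end tag (e.g. "</think>").
    void step(bool thinking_open, int32_t budget, const std::deque<int32_t> & close_tokens) {
        switch (state) {
            case NONE:
                if (thinking_open) { state = REASONING; n_reasoning_tokens = 0; }
                break;
            case REASONING:
                if (!thinking_open) { state = FINISHED; break; }
                if (++n_reasoning_tokens >= budget) {
                    forced = close_tokens; // these bypass sampling until drained
                    state  = PENDING_FORCE_CLOSE;
                }
                break;
            case PENDING_FORCE_CLOSE:
                if (forced.empty()) { state = FINISHED; } // forced close fully emitted
                break;
            case FINISHED:
                break;
        }
    }
};

int main() {
    reasoning_tracker t;
    const std::deque<int32_t> close_tokens = {1, 2, 3}; // stand-in for tokenized close message + "</think>"
    t.step(true, /*budget=*/1, close_tokens); // reasoning segment opens
    t.step(true, /*budget=*/1, close_tokens); // budget hit -> forced close queued
    return t.state == PENDING_FORCE_CLOSE ? 0 : 1;
}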
tools/server/server-task.cpp

Lines changed: 6 additions & 1 deletion

@@ -130,13 +130,15 @@ json task_params::to_json(bool only_metrics) const {
     {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
     {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
     {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
+    {"reasoning_force_close_message", reasoning_force_close_message},
     {"samplers", samplers},
     {"speculative.n_max", speculative.n_max},
     {"speculative.n_min", speculative.n_min},
     {"speculative.p_min", speculative.p_min},
     {"timings_per_token", timings_per_token},
     {"post_sampling_probs", post_sampling_probs},
     {"lora", lora},
+    {"thinking_budget_tokens", reasoning_budget},
 };
 }
 

@@ -159,8 +161,8 @@ task_params server_task::params_from_json_cmpl(
     defaults.speculative = params_base.speculative;
     defaults.n_keep = params_base.n_keep;
     defaults.n_predict = params_base.n_predict;
-    defaults.n_cache_reuse = params_base.n_cache_reuse;
     defaults.antiprompt = params_base.antiprompt;
+    defaults.reasoning_force_close_message = params_base.reasoning_force_close_message;
 
     // enabling this will output extra debug information in the HTTP responses from the server
     params.verbose = params_base.verbosity > 9;

@@ -182,6 +184,9 @@
     params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
     params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
 
+    params.reasoning_budget = json_value(data, "thinking_budget_tokens", params_base.reasoning_budget);
+    params.reasoning_force_close_message = json_value(data, "reasoning_force_close_message", defaults.reasoning_force_close_message);
+
     params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
     params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
     params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
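
Note the precedence established here: a per-request thinking_budget_tokens falls back to the server-wide --reasoning-budget, and reasoning_force_close_message falls back to the default initialized from the command line (or GGUF metadata). A minimal sketch of that fallback pattern, using a simplified stand-in for the server's json_value() helper:

#include <nlohmann/json.hpp>
#include <cstdint>

using json = nlohmann::json;

// Simplified stand-in for the server's json_value() helper: return the field
// if present and non-null, otherwise the supplied default.
template <typename T>
static T json_value_or(const json & body, const char * key, const T & def) {
    return (body.contains(key) && !body.at(key).is_null()) ? body.at(key).get<T>() : def;
}

int main() {
    const int32_t server_default_budget = -1; // from --reasoning-budget
    const json request = {{"thinking_budget_tokens", 128}};

    // The per-request value wins; an absent field falls back to the server-wide setting.
    const int32_t effective = json_value_or(request, "thinking_budget_tokens", server_default_budget);
    return effective == 128 ? 0 : 1;
}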

tools/server/server-task.h

Lines changed: 7 additions & 5 deletions

@@ -72,11 +72,13 @@ struct task_params {
     struct common_params_speculative speculative;
 
     // response formatting
-    bool verbose = false;
-    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
-    std::string oaicompat_model;
-    std::string oaicompat_cmpl_id;
-    common_chat_syntax oaicompat_chat_syntax;
+    bool               verbose  = false;
+    task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
+    std::string        oaicompat_model;
+    std::string        oaicompat_cmpl_id;
+    common_chat_syntax oaicompat_chat_syntax;
+    int32_t            reasoning_budget;
+    std::string        reasoning_force_close_message;
 
     // Embeddings
     int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
