diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp
index ab6b3aa7cec..d8c7a0d4c52 100644
--- a/tools/server/server-common.cpp
+++ b/tools/server/server-common.cpp
@@ -757,11 +757,6 @@ json oaicompat_completion_params_parse(const json & body) {
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
 
-    // Handle "echo" field
-    if (json_value(body, "echo", false)) {
-        throw std::runtime_error("Only no echo is supported");
-    }
-
     // Params supported by OAI but unsupported by llama.cpp
     static const std::vector<std::string> unsupported_params { "best_of", "suffix" };
     for (const auto & param : unsupported_params) {
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 5a67f508dfb..86703140a57 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -70,6 +70,7 @@ static bool server_task_type_need_logits(server_task_type task_type) {
 
 struct server_slot {
     int id;
+    common_params params_base;
 
     llama_batch batch_spec = {};
     // TODO: change to unique_ptrs for consistency:
@@ -107,6 +108,9 @@ struct server_slot {
     // ref: https://github.com/ggml-org/llama.cpp/pull/17808
     std::vector<int32_t> i_batch_dft;
 
+    // idx of prompt tokens to get logits (for echo=true)
+    std::vector<std::pair<int32_t, llama_token>> i_batch_prompt;
+
    std::vector<completion_token_output> generated_token_probs;
 
     bool has_next_token = true;
@@ -209,6 +213,12 @@ struct server_slot {
         return server_task_type_need_embd(task->type);
     }
 
+    bool need_prompt_logits() const {
+        GGML_ASSERT(task);
+
+        return task->params.echo && task->params.sampling.n_probs > 0;
+    }
+
     bool need_logits() const {
         GGML_ASSERT(task);
 
@@ -255,6 +265,12 @@ struct server_slot {
         return ctx_dft;
     }
 
+    std::string token_to_piece(const llama_token & token) const {
+        bool is_special = params_base.special
+            || task->params.sampling.preserved_tokens.find(token) != task->params.sampling.preserved_tokens.end();
+        return common_token_to_piece(ctx, token, is_special);
+    }
+
     void add_token(const completion_token_output & token) {
         if (!is_processing()) {
             SLT_WRN(*this, "%s", "slot is not processing\n");
@@ -724,6 +740,7 @@ struct server_context_impl {
             slot.ctx = ctx;
             slot.n_ctx = n_ctx_slot;
             slot.mctx = mctx;
+            slot.params_base = params_base;
             slot.prompt.tokens.has_mtmd = mctx != nullptr;
 
             if (model_dft) {
@@ -1829,11 +1846,6 @@ struct server_context_impl {
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
 
-        auto accept_special_token = [&](server_slot & slot, llama_token token) {
-            return params_base.special ||
-                   slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end();
-        };
-
         // first, add sampled tokens from any ongoing sequences
         for (auto & slot : slots) {
             if (slot.state != SLOT_STATE_GENERATING) {
@@ -1919,6 +1931,8 @@ struct server_context_impl {
                 continue;
             }
 
+            slot.i_batch_prompt.clear();
+
             // this slot still has a prompt to be processed
             if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
                 const auto & input_tokens = slot.task->tokens;
@@ -2280,9 +2294,14 @@ struct server_context_impl {
                         cur_tok,
                         slot.prompt.tokens.pos_next(),
                         { slot.id },
-                        slot.need_embd());
+                        slot.need_embd() || slot.need_prompt_logits());
                 slot.prompt.tokens.push_back(cur_tok);
 
+                // track prompt tokens that need logits output
+                if (slot.need_prompt_logits()) {
+                    slot.i_batch_prompt.push_back({batch.n_tokens - 1, cur_tok});
+                }
+
                 slot.n_prompt_tokens_processed++;
 
                 // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
@@ -2486,11 +2505,24 @@ struct server_context_impl {
                 }
             }
 
-            // optionally send prompt processing progress
             if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
+                // optionally send prompt processing progress
                 if (slot.task->params.stream && slot.task->params.return_progress) {
                     send_partial_response(slot, {}, true);
                 }
+
+                // optionally get prompt logits (echo=true)
+                if (!slot.i_batch_prompt.empty()) {
+                    GGML_ASSERT(slot.task->params.stream); // TODO: support non-streaming if needed
+                    for (auto & [tok_idx, id] : slot.i_batch_prompt) {
+                        completion_token_output result;
+                        result.tok          = id;
+                        result.text_to_send = slot.token_to_piece(id);
+                        result.prob         = 1.0f;
+                        populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx);
+                        send_partial_response(slot, result, false);
+                    }
+                }
             }
 
             if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
@@ -2543,7 +2575,7 @@ struct server_context_impl {
 
             completion_token_output result;
             result.tok          = id;
-            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+            result.text_to_send = slot.token_to_piece(result.tok);
             result.prob         = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
 
             if (slot.task->params.sampling.n_probs > 0) {
@@ -2594,7 +2626,7 @@ struct server_context_impl {
 
                 completion_token_output result;
                 result.tok          = ids[i];
-                result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+                result.text_to_send = slot.token_to_piece(result.tok);
                 result.prob         = 1.0f; // set later
 
                 // TODO: set result.probs
diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp
index 360826062b1..42b8b76f30b 100644
--- a/tools/server/server-task.cpp
+++ b/tools/server/server-task.cpp
@@ -167,6 +167,7 @@ task_params server_task::params_from_json_cmpl(
     params.timings_per_token = json_value(data, "timings_per_token", false);
 
     params.stream            = json_value(data, "stream",            false);
+    params.echo              = json_value(data, "echo",              false);
     auto stream_opt          = json_value(data, "stream_options",    json::object());
     params.include_usage     = json_value(stream_opt, "include_usage", false);
     params.cache_prompt      = json_value(data, "cache_prompt",      true);
@@ -221,6 +222,14 @@ task_params server_task::params_from_json_cmpl(
         params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
     }
 
+    if (params.echo && params.sampling.n_probs == 0) {
+        throw std::runtime_error("Error: echo without logprobs is not yet supported");
+    }
+
+    if (params.echo && params.sampling.n_probs != 0 && !params.stream) {
+        throw std::runtime_error("Error: echo with logprobs requires streaming to be enabled");
+    }
+
     if (data.contains("lora")) {
         if (data.at("lora").is_array()) {
             params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
diff --git a/tools/server/server-task.h b/tools/server/server-task.h
index 0759094a01d..6c6e89ccd10 100644
--- a/tools/server/server-task.h
+++ b/tools/server/server-task.h
@@ -45,7 +45,8 @@ enum stop_type {
 struct task_params {
     bool stream          = true;
     bool include_usage   = false;
-    bool cache_prompt    = true; // remember the prompt to avoid reprocessing all prompt
+    bool echo            = false; // echo the prompt in the output, useful for eval use cases
+    bool cache_prompt    = true;  // remember the prompt to avoid reprocessing all prompt
     bool return_tokens   = false;
     bool return_progress = false;
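
Usage sketch (not part of the diff above): the checks added in params_from_json_cmpl mean "echo" is only accepted together with "logprobs" and "stream". The snippet below is a minimal Python client illustrating that combination. The host/port (localhost:8080), the OAI-compatible /v1/completions route, and the exact shape of the streamed chunks are assumptions about the surrounding server, so adjust them to your setup.

# Hypothetical client example -- not part of this patch. Assumes llama-server
# is listening on localhost:8080 and exposes the OAI-compatible
# /v1/completions endpoint; chunk field names may differ.
import json
import requests

resp = requests.post(
    "http://localhost:8080/v1/completions",
    json={
        "prompt": "The quick brown fox",
        "max_tokens": 8,
        "echo": True,    # new in this patch: prompt tokens are echoed back
        "logprobs": 1,   # required: echo without logprobs is rejected
        "stream": True,  # required: echo with logprobs needs streaming
    },
    stream=True,
)

for line in resp.iter_lines():
    # server-sent events: each payload line is prefixed with "data: "
    if not line or not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
        break
    chunk = json.loads(payload)
    # prompt tokens arrive first (with their logprobs), then sampled tokens
    for choice in chunk.get("choices", []):
        print(choice.get("text", ""), end="", flush=True)
print()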