From 864bd183da96f7a2d711bc6e1d24f24ed5712f54 Mon Sep 17 00:00:00 2001 From: maan2003 Date: Tue, 25 Feb 2025 03:25:07 +0530 Subject: [PATCH] feat: add thinking budget support for Anthropic requests --- llm_client/src/clients/anthropic.rs | 23 +++++++++++++++++++++++ llm_client/src/clients/types.rs | 22 ++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/llm_client/src/clients/anthropic.rs b/llm_client/src/clients/anthropic.rs index ac28681ff..a2bf4e036 100644 --- a/llm_client/src/clients/anthropic.rs +++ b/llm_client/src/clients/anthropic.rs @@ -227,6 +227,13 @@ enum ContentBlockDeltaType { }, } +#[derive(serde::Serialize, Debug, Clone)] +struct AnthropicThinking { + #[serde(rename = "type")] + thinking_type: String, + budget_tokens: usize, +} + #[derive(serde::Serialize, Debug, Clone)] struct AnthropicRequest { system: Vec, @@ -238,6 +245,8 @@ struct AnthropicRequest { stream: bool, #[serde(skip_serializing_if = "Option::is_none")] max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + thinking: Option, model: String, } @@ -262,6 +271,12 @@ impl AnthropicRequest { } } }; + let thinking = completion_request + .thinking_budget() + .map(|budget| AnthropicThinking { + thinking_type: "enabled".to_owned(), + budget_tokens: budget, + }); let messages = completion_request.messages(); // grab the tools over here ONLY from the system message let tools = messages @@ -346,6 +361,7 @@ impl AnthropicRequest { tools, stream: true, max_tokens, + thinking, model: model_str, } } @@ -356,6 +372,12 @@ impl AnthropicRequest { ) -> Self { let temperature = completion_request.temperature(); let max_tokens = completion_request.get_max_tokens(); + let thinking = completion_request + .thinking_budget() + .map(|budget| AnthropicThinking { + thinking_type: "enabled".to_owned(), + budget_tokens: budget, + }); let messages = vec![AnthropicMessage::new( "user".to_owned(), completion_request.prompt().to_owned(), @@ -367,6 +389,7 @@ impl AnthropicRequest { tools: vec![], stream: true, max_tokens, + thinking, model: model_str, } } diff --git a/llm_client/src/clients/types.rs b/llm_client/src/clients/types.rs index a9168ee8a..3f0c79853 100644 --- a/llm_client/src/clients/types.rs +++ b/llm_client/src/clients/types.rs @@ -693,6 +693,7 @@ pub struct LLMClientCompletionRequest { frequency_penalty: Option, stop_words: Option>, max_tokens: Option, + thinking_budget: Option, } #[derive(Clone)] @@ -703,6 +704,7 @@ pub struct LLMClientCompletionStringRequest { frequency_penalty: Option, stop_words: Option>, max_tokens: Option, + thinking_budget: Option, } impl LLMClientCompletionStringRequest { @@ -719,6 +721,7 @@ impl LLMClientCompletionStringRequest { frequency_penalty, stop_words: None, max_tokens: None, + thinking_budget: None, } } @@ -755,6 +758,15 @@ impl LLMClientCompletionStringRequest { pub fn get_max_tokens(&self) -> Option { self.max_tokens } + + pub fn set_thinking_budget(mut self, thinking_budget: usize) -> Self { + self.thinking_budget = Some(thinking_budget); + self + } + + pub fn thinking_budget(&self) -> Option { + self.thinking_budget + } } impl LLMClientCompletionRequest { @@ -771,6 +783,7 @@ impl LLMClientCompletionRequest { frequency_penalty, stop_words: None, max_tokens: None, + thinking_budget: None, } } @@ -859,6 +872,15 @@ impl LLMClientCompletionRequest { pub fn get_max_tokens(&self) -> Option { self.max_tokens } + + pub fn set_thinking_budget(mut self, thinking_budget: usize) -> Self { + self.thinking_budget = Some(thinking_budget); + self + } + + pub fn thinking_budget(&self) -> Option { + self.thinking_budget + } } #[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]