llm_bedrock_anthropic.py: 178 changes (140 additions, 38 deletions)
@@ -13,6 +13,7 @@
import llm
from pydantic import Field, field_validator
from PIL import Image
import json

@dataclass
class AttachmentData:
@@ -73,20 +74,20 @@ def register_models(register):

# Claude 3 models (with attachment support)
register(
BedrockClaude("anthropic.claude-3-sonnet-20240229-v1:0", supports_attachments=True),
BedrockClaude("anthropic.claude-3-sonnet-20240229-v1:0", supports_attachments=True, supports_tools=True),
aliases=(
"bedrock-claude-v3-sonnet",
),
)
register(
BedrockClaude("us.anthropic.claude-3-5-sonnet-20241022-v2:0", supports_attachments=True),
BedrockClaude("us.anthropic.claude-3-5-sonnet-20241022-v2:0", supports_attachments=True, supports_tools=True),
aliases=(
"bedrock-claude-v3.5-sonnet-v2",
"bedrock-claude-sonnet-v2",
),
)
register(
BedrockClaude("anthropic.claude-3-5-sonnet-20240620-v1:0", supports_attachments=True),
BedrockClaude("anthropic.claude-3-5-sonnet-20240620-v1:0", supports_attachments=True, supports_tools=True),
aliases=(
"bedrock-claude-v3.5-sonnet",
"bedrock-claude-sonnet",
@@ -96,7 +97,7 @@ def register_models(register):
),
)
register(
BedrockClaude("anthropic.claude-3-opus-20240229-v1:0", supports_attachments=True),
BedrockClaude("anthropic.claude-3-opus-20240229-v1:0", supports_attachments=True, supports_tools=True),
aliases=(
"bedrock-claude-v3-opus",
"bedrock-claude-opus",
@@ -105,15 +106,15 @@ def register_models(register):
),
)
register(
BedrockClaude("us.anthropic.claude-3-5-haiku-20241022-v1:0", supports_attachments=False),
BedrockClaude("us.anthropic.claude-3-5-haiku-20241022-v1:0", supports_attachments=False, supports_tools=False),
aliases=(
"bedrock-claude-v3.5-haiku",
"bedrock-haiku-v3.5",
"bh-v3.5",
),
)
register(
BedrockClaude("anthropic.claude-3-haiku-20240307-v1:0", supports_attachments=True),
BedrockClaude("anthropic.claude-3-haiku-20240307-v1:0", supports_attachments=True, supports_tools=False),
aliases=(
"bedrock-claude-v3-haiku",
"bedrock-claude-haiku",
@@ -122,23 +123,23 @@ def register_models(register):
),
)
register(
BedrockClaude("us.anthropic.claude-3-7-sonnet-20250219-v1:0", supports_attachments=True),
BedrockClaude("anthropic.claude-3-7-sonnet-20250219-v1:0", supports_attachments=True, supports_tools=True),
aliases=(
"bedrock-claude-v3.7-sonnet",
"bedrock-claude-sonnet-v3.7",
"bc-v3.7",
),
)
register(
BedrockClaude("anthropic.claude-sonnet-4-20250514-v1:0", supports_attachments=True),
BedrockClaude("anthropic.claude-sonnet-4-20250514-v1:0", supports_attachments=True, supports_tools=True),
aliases=(
"bedrock-claude-v4-sonnet",
"bedrock-claude-sonnet-v4",
"bc-v4",
),
)
register(
BedrockClaude("anthropic.claude-opus-4-20250514-v1:0", supports_attachments=True),
BedrockClaude("anthropic.claude-opus-4-20250514-v1:0", supports_attachments=True, supports_tools=True),
aliases=(
"bedrock-claude-v4-opus",
"bedrock-claude-opus-v4",
@@ -172,9 +173,10 @@ def validate_length(cls, max_tokens_to_sample):
raise ValueError("max_tokens_to_sample must be in range 1-1,000,000")
return max_tokens_to_sample

def __init__(self, model_id, supports_attachments=False):
def __init__(self, model_id, supports_attachments=False, supports_tools=False):
self.model_id = model_id
self.supports_attachments = supports_attachments
self.supports_tools = supports_tools
if supports_attachments:
image_mime_types = {f"image/{fmt}" for fmt in BEDROCK_CONVERSE_IMAGE_FORMATS}
document_mime_types = set(MIME_TYPE_TO_BEDROCK_CONVERSE_DOCUMENT_FORMAT.keys())
@@ -382,7 +384,32 @@ def document_bytes_to_content_block(self, doc_bytes: bytes, mime_type: str, name
}
}

def prompt_to_content(self, prompt):
def tools_to_bedrock_format(self, tools):
"""Convert llm Tool objects to Bedrock's toolConfig format."""
if not tools:
return None

bedrock_tools = []
for tool in tools:
# Ensure description is not empty (Bedrock requires min length 1)
description = tool.description or f"Tool: {tool.name}"

bedrock_tool = {
"toolSpec": {
"name": tool.name,
"description": description,
"inputSchema": {
"json": tool.input_schema
}
}
}
bedrock_tools.append(bedrock_tool)

return {
"tools": bedrock_tools
}
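# For illustration: given a hypothetical get_weather tool that takes a single
# "city" string argument, the dict returned above would look roughly like:
#
#     {
#         "tools": [
#             {
#                 "toolSpec": {
#                     "name": "get_weather",
#                     "description": "Look up the current weather for a city.",
#                     "inputSchema": {
#                         "json": {
#                             "type": "object",
#                             "properties": {"city": {"type": "string"}},
#                             "required": ["city"]
#                         }
#                     }
#                 }
#             }
#         ]
#     }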

def prompt_to_content(self, prompt: llm.Prompt):
"""
Convert a llm.Prompt object to the content format expected by the Bedrock Converse API.
If we encounter the bedrock_attach_files option, detect the file type(s) and use the
@@ -418,14 +445,30 @@ def prompt_to_content(self, prompt):
data = [self.create_attachment_data(a) for a in prompt.attachments]
content_blocks = [self.process_attachment(d) for d in data]
content.extend(content_blocks)


# Append the prompt text as a text content block.
content.append(
{
'text': prompt.prompt
}
)
if prompt.prompt is not None and prompt.prompt.strip() != "":
# Append the prompt text as a text content block.
content.append(
{
'text': prompt.prompt
}
)
# If tool_results exist, add toolResult blocks
if prompt.tool_results:
for tool_result in prompt.tool_results:
# Check if output is valid JSON
try:
parsed = json.loads(tool_result.output)
content_block = {'json': parsed}
except json.JSONDecodeError:
content_block = {'text': tool_result.output}
content.append({
'toolResult': {
'toolUseId': tool_result.tool_call_id,
'content': [content_block],
'status': 'success'
}
})

return content
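# For illustration: a tool_result whose output is the JSON string '{"temp_c": 21}'
# would be appended above as (the toolUseId is hypothetical and must echo the id
# from the model's earlier toolUse block):
#
#     {
#         'toolResult': {
#             'toolUseId': 'tooluse_abc123',
#             'content': [{'json': {'temp_c': 21}}],
#             'status': 'success'
#         }
#     }
#
# Output that is not valid JSON falls back to a [{'text': ...}] content entry.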

@@ -489,25 +532,50 @@ def build_messages(self, prompt_content, conversation) -> List[dict]:
'text': response.prompt.prompt
}
]
assistant_content = [
{
'text': response.text()
}
]
messages.extend(
[
{
"role": "user",
"content": user_content
},
{
"role": "assistant",
"content": assistant_content
},
]
)

# Build assistant content - handle both text and tool calls
assistant_content = []
if response.text():
text_val = response.text()
if text_val and text_val.strip():
assistant_content.append({
'text': text_val
})

# Add tool calls if any
if hasattr(response, '_tool_calls') and response._tool_calls:
for tool_call in response._tool_calls:
assistant_content.append({
'toolUse': {
'name': tool_call.name,
'input': tool_call.arguments,
'toolUseId': tool_call.tool_call_id
}
})

# Only add assistant message if there is at least one content block
if assistant_content:
messages.extend(
[
{
"role": "user",
"content": user_content
},
{
"role": "assistant",
"content": assistant_content
},
]
)
else:
# Only add the user message if there is no assistant content
messages.append({
"role": "user",
"content": user_content
})

messages.append({"role": "user", "content": prompt_content})

return messages
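# For illustration: an earlier assistant turn that called the hypothetical
# get_weather tool is replayed into the history as:
#
#     {
#         'toolUse': {
#             'name': 'get_weather',
#             'input': {'city': 'Berlin'},
#             'toolUseId': 'tooluse_abc123'
#         }
#     }
#
# The same toolUseId must appear in the matching toolResult block on the next
# user turn so Bedrock can pair the call with its result.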

def execute(self, prompt, stream, response, conversation):
Expand All @@ -534,6 +602,7 @@ def execute(self, prompt, stream, response, conversation):
# Preserve the Bedrock-specific user content dict, so it can be re-used in
# future conversations.
response.response_json = {
'id': response.id,
'bedrock_user_content': self.encode_bytes(prompt_content)
}

Expand All @@ -548,6 +617,12 @@ def execute(self, prompt, stream, response, conversation):
'inferenceConfig': inference_config,
}

# Add tools if supported and provided
if self.supports_tools and prompt.tools:
tool_config = self.tools_to_bedrock_format(prompt.tools)
if tool_config:
params['toolConfig'] = tool_config

if prompt.system:
params['system'] = [
{
@@ -557,20 +632,47 @@

client = boto3.client('bedrock-runtime')
if stream:
raise NotImplementedError("Streaming is broken with tool use. To be fixed.")

bedrock_response = client.converse_stream(**params)
response.response_json |= bedrock_response
events = []
for event in bedrock_response['stream']:
(event_type, event_content), = event.items()
if event_type == "contentBlockDelta":
completion = event_content["delta"]["text"]
yield completion
if "text" in event_content["delta"]:
completion = event_content["delta"]["text"]
yield completion
elif event_type == "toolUse":
# Handle tool calls in streaming mode
tool_use = event_content["toolUse"]
tool_call = llm.ToolCall(
name=tool_use["name"],
arguments=tool_use["input"],
tool_call_id=tool_use.get("toolUseId")
)
response.add_tool_call(tool_call)
events.append(event)
response.response_json["stream"] = events
else:
bedrock_response = client.converse(**params)
completion = bedrock_response['output']['message']['content'][-1]['text']
response.response_json |= bedrock_response

# Handle tool calls in non-streaming mode
message_content = bedrock_response['output']['message']['content']
completion = ""
for content_block in message_content:
if 'text' in content_block:
completion = content_block['text']
elif 'toolUse' in content_block:
tool_use = content_block['toolUse']
tool_call = llm.ToolCall(
name=tool_use["name"],
arguments=tool_use["input"],
tool_call_id=tool_use.get("toolUseId")
)
response.add_tool_call(tool_call)

yield completion
self.set_usage(response)

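For a quick sanity check once this lands, here is a minimal sketch of driving the new tool path from llm's Python API. It assumes llm 0.26+ tool calling (tools= on model.prompt() and response.tool_calls()); the model alias and the get_weather function are illustrative, and stream=False is used because this draft raises NotImplementedError for streaming with tools.

import llm

def get_weather(city: str) -> str:
    """Return a fake weather report for the given city."""
    return '{"temp_c": 21, "conditions": "sunny"}'

# Any alias registered above with supports_tools=True should work here.
model = llm.get_model("bedrock-claude-sonnet-v4")

response = model.prompt(
    "What's the weather in Berlin? Use the get_weather tool.",
    tools=[get_weather],
    stream=False,
)

# Claude's toolUse block surfaces as llm.ToolCall objects; sending their output
# back via tool_results is what the toolResult handling in prompt_to_content covers.
for call in response.tool_calls():
    print(call.name, call.arguments)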