
Commit 8eff93c

Merge pull request #2 from redis-applied-ai/feat/bedrock
Bedrock functionality
2 parents 79e2617 + a41063f commit 8eff93c

File tree: 22 files changed, +743 −51 lines


CLAUDE.md

Lines changed: 0 additions & 31 deletions
This file was deleted.

README.md

Lines changed: 41 additions & 2 deletions
@@ -42,8 +42,8 @@ uv sync
 cp .env.example .env
 # Edit .env with your API keys

-# Start Redis (with RedisJSON module for content management)
-docker run -d -p 6379:6379 redis/redis-stack:latest
+# Start Redis 8 (no Redis Stack required)
+docker run -d -p 6379:6379 redis:8-alpine

 # Seed database
 uv run python scripts/seed.py
@@ -93,6 +93,45 @@ TAVILY_API_KEY=your-tavily-key # Web search tool
 REDIS_URL=redis://localhost:6379/0
 ```

+
+## Amazon Bedrock (LLM provider option)
+
+This repo includes scripts to automate IAM permissions for Bedrock and a local tool-calling test script.
+
+Prerequisites
+- AWS CLI v2 configured with credentials
+- Region: us-east-1 (default)
+
+1) Grant Bedrock invoke permissions to an IAM user
+```bash
+chmod +x scripts/bedrock_provision_access.sh
+scripts/bedrock_provision_access.sh user <YOUR_IAM_USER_NAME> us-east-1
+```
+This attaches a minimal policy that allows invoking Bedrock models and listing model info. If you see AccessDenied during inference, enable model access in the console.
+
+2) Enable model access (one-time, per account/region)
+Open the Bedrock Model access page and enable the providers/models you plan to use (default used here is Claude 3.5 Sonnet):
+- https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/model-access
+
+3) Switch provider to Bedrock and run locally
+```bash
+export AWS_DEFAULT_REGION=us-east-1
+export LLM_PROVIDER=bedrock
+export BEDROCK_MODEL_ID=anthropic.claude-3-5-sonnet-20240620-v1:0
+export LOG_LEVEL=INFO
+uv run python -m app.worker &
+uv run fastapi dev app/api/main.py
+```
+You should see logs like: "LLM configured: provider=bedrock model=anthropic.claude-3-5-sonnet-20240620-v1:0"
+
+Optional: add to your .env for convenience
+```bash
+AWS_DEFAULT_REGION=us-east-1
+BEDROCK_MODEL_ID=anthropic.claude-3-5-sonnet-20240620-v1:0
+# Future toggle; default provider may be Bedrock in this repo
+LLM_PROVIDER=bedrock
+```
+
 ## Deployment (AWS, single environment)

 This reference deploys a working agent stack on AWS with a single `terraform apply`:
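The provisioning script referenced above (`scripts/bedrock_provision_access.sh`) is part of this commit but not expanded in this view; the README only says it attaches a minimal policy for invoking Bedrock models and listing model info. As a rough, hypothetical sketch of that behavior (the policy name and exact action list are assumptions, not the script's actual contents), the equivalent grant could be applied with boto3:

```python
# Hypothetical sketch only -- mirrors what the README says bedrock_provision_access.sh
# does; the real script's policy name and action list may differ.
import json

import boto3


def attach_bedrock_invoke_policy(user_name: str) -> None:
    """Attach a minimal inline policy letting an IAM user invoke Bedrock models."""
    policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Sid": "BedrockInvoke",
                "Effect": "Allow",
                "Action": [
                    "bedrock:InvokeModel",
                    "bedrock:InvokeModelWithResponseStream",
                ],
                "Resource": "*",
            },
            {
                "Sid": "BedrockModelInfo",
                "Effect": "Allow",
                "Action": [
                    "bedrock:ListFoundationModels",
                    "bedrock:GetFoundationModel",
                ],
                "Resource": "*",
            },
        ],
    }
    iam = boto3.client("iam")
    iam.put_user_policy(
        UserName=user_name,
        PolicyName="BedrockInvokeMinimal",  # assumed name, not taken from the script
        PolicyDocument=json.dumps(policy),
    )


if __name__ == "__main__":
    attach_bedrock_invoke_policy("<YOUR_IAM_USER_NAME>")
```

Note that, as the README says, IAM permissions alone are not enough: model access still has to be enabled once per account/region in the Bedrock console.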

app/agent/core.py

Lines changed: 220 additions & 1 deletion
@@ -20,6 +20,11 @@
 from redisvl.utils.vectorize import OpenAITextVectorizer

 from app.agent.tools import get_search_knowledge_base_tool, get_web_search_tool
+from app.utilities.bedrock_client import (
+    bedrock_text_blocks_to_text,
+    get_bedrock_runtime_client,
+    map_openai_tools_to_bedrock_tool_config,
+)
 from app.utilities.openai_client import get_instrumented_client

 logger = logging.getLogger(__name__)
@@ -308,6 +313,19 @@ async def answer_question(
         thread_context: Optional conversation context
         progress_callback: Optional callback function to send progress updates
     """
+    # Provider toggle: route to Bedrock implementation when requested
+    provider = os.getenv("LLM_PROVIDER", "bedrock").lower()
+    if provider == "bedrock":
+        return await answer_question_bedrock(
+            index=index,
+            vectorizer=vectorizer,
+            query=query,
+            session_id=session_id,
+            user_id=user_id,
+            thread_context=thread_context,
+            progress_callback=progress_callback,
+        )
+
     # Get the underlying OpenAI client for direct access
     client = get_instrumented_client()._client

@@ -333,6 +351,8 @@ async def answer_question(
         *MemoryAPIClient.get_all_memory_tool_schemas(),
     ]

+    logger.info(f"Using LLM provider=openai model={CHAT_MODEL}")
+
     logger.info(f"Available tools: {[tool['function']['name'] for tool in tools]}")

     # Track total tokens and tool calls across all iterations
@@ -561,7 +581,7 @@ async def answer_question(

     # Record metrics for this answer completion
     try:
-        from app.metrics import get_token_metrics
+        from app.utilities.metrics import get_token_metrics

         token_metrics = get_token_metrics()
         if token_metrics:
@@ -652,3 +672,202 @@ def _parse_llm_response(content: str) -> tuple[str, bool]:
     except Exception as e:
         logger.error(f"Error parsing LLM response: {e}")
         return content, False
+
+
+async def answer_question_bedrock(
+    index: AsyncSearchIndex,
+    vectorizer: OpenAITextVectorizer,
+    query: str,
+    session_id: str,
+    user_id: str,
+    thread_context: list[dict] | None = None,
+    progress_callback=None,
+) -> str:
+    """Bedrock-based implementation of the agent loop using Converse API with tools."""
+    client = get_bedrock_runtime_client()
+    model_id = os.getenv(
+        "BEDROCK_MODEL_ID", "anthropic.claude-3-5-sonnet-20240620-v1:0"
+    )
+    logger.info(f"Using LLM provider=bedrock model={model_id}")
+
+    initial_message = create_initial_message_without_search(query, thread_context)
+    bedrock_messages: list[dict] = [
+        {"role": "user", "content": [{"text": initial_message}]}
+    ]
+
+    tools_openai = [
+        get_search_knowledge_base_tool(),
+        get_web_search_tool(),
+        *MemoryAPIClient.get_all_memory_tool_schemas(),
+    ]
+    tool_config = map_openai_tools_to_bedrock_tool_config(tools_openai)
+
+    max_iterations = 25
+    iteration = 0
+    total_tokens = 0
+    total_tool_calls = 0
+
+    while iteration < max_iterations:
+        iteration += 1
+        response = client.converse(
+            modelId=model_id,
+            system=[{"text": SYSTEM_PROMPT}],
+            messages=bedrock_messages,
+            toolConfig=tool_config,
+        )
+
+        usage = response.get("usage") or {}
+        total_tokens += int(usage.get("inputTokens", 0)) + int(
+            usage.get("outputTokens", 0)
+        )
+
+        output_message = response.get("output", {}).get("message", {})
+        stop_reason = response.get("stopReason")
+
+        if stop_reason == "tool_use":
+            # Collect toolUse requests and produce toolResult blocks
+            tool_result_blocks: list[dict] = []
+            if progress_callback:
+                await progress_callback("Using tools...")
+
+            for block in output_message.get("content", []) or []:
+                tool_use = block.get("toolUse") if isinstance(block, dict) else None
+                if not tool_use:
+                    continue
+                name = tool_use.get("name")
+                tool_use_id = tool_use.get("toolUseId")
+                input_payload = tool_use.get("input") or {}
+                total_tool_calls += 1
+
+                try:
+                    if name == "search_knowledge_base":
+                        if progress_callback:
+                            await progress_callback("Searching knowledge base...")
+                        from app.agent.tools.search_knowledge_base import (
+                            search_knowledge_base,
+                        )
+
+                        q = (input_payload or {}).get("query", "")
+                        result_text = await search_knowledge_base(index, vectorizer, q)
+                        tool_result_blocks.append(
+                            {
+                                "toolResult": {
+                                    "toolUseId": tool_use_id,
+                                    "content": [{"text": str(result_text)}],
+                                    "status": "success",
+                                }
+                            }
+                        )
+                    elif name == "web_search":
+                        if progress_callback:
+                            await progress_callback("Searching the web...")
+                        from app.agent.tools.web_search import perform_web_search
+
+                        q = (input_payload or {}).get("query", "")
+                        web_res = await perform_web_search(
+                            query=q,
+                            search_depth="basic",
+                            max_results=5,
+                            redis_focused=True,
+                        )
+                        tool_result_blocks.append(
+                            {
+                                "toolResult": {
+                                    "toolUseId": tool_use_id,
+                                    "content": [{"text": str(web_res)}],
+                                    "status": "success",
+                                }
+                            }
+                        )
+                    else:
+                        # Memory tools or others resolved via memory client
+                        if progress_callback:
+                            await progress_callback("Using memory tools...")
+                        memory_client = await get_memory_client()
+                        # Enforce user_id for memory tools
+                        args = dict(input_payload or {})
+                        memory_tool_names = {
+                            "search_memory",
+                            "add_memory_to_working_memory",
+                            "update_working_memory_data",
+                            "get_working_memory",
+                            "search_long_term_memory",
+                            "memory_prompt",
+                            "set_working_memory",
+                        }
+                        if name in memory_tool_names:
+                            args["user_id"] = user_id
+                        function_call = {"name": name, "arguments": json.dumps(args)}
+                        mem_res = await memory_client.resolve_tool_call(
+                            tool_call=function_call,
+                            session_id=session_id,
+                            user_id=user_id,
+                        )
+                        tool_content = (
+                            str(mem_res)
+                            if isinstance(mem_res, (dict, list))
+                            else str(mem_res)
+                        )
+                        tool_content += "\n\nReflect on this memory tool result and your instructions about how to use memory tools. Make subsequent memory tool calls if necessary."
+                        tool_result_blocks.append(
+                            {
+                                "toolResult": {
+                                    "toolUseId": tool_use_id,
+                                    "content": [{"text": tool_content}],
+                                    "status": "success",
+                                }
+                            }
+                        )
+                except Exception as e:
+                    logger.error(f"Tool execution error for {name}: {e}")
+                    tool_result_blocks.append(
+                        {
+                            "toolResult": {
+                                "toolUseId": tool_use_id,
+                                "content": [
+                                    {"text": f"Error executing tool {name}: {str(e)}"}
+                                ],
+                                "status": "error",
+                            }
+                        }
+                    )
+
+            # Append assistant request and our tool results back to the conversation
+            bedrock_messages.append(output_message)
+            if tool_result_blocks:
+                bedrock_messages.append({"role": "user", "content": tool_result_blocks})
+            # Continue loop for model to produce next step
+            continue
+
+        # No tool use requested; treat as final answer
+        final_text = bedrock_text_blocks_to_text(output_message.get("content", []))
+        response_text, use_org_search = _parse_llm_response(final_text)
+        if use_org_search:
+            logger.info("LLM wanted to use org search, but org search is disabled")
+
+        # Metrics
+        try:
+            from app.utilities.metrics import get_token_metrics
+
+            token_metrics = get_token_metrics()
+            if token_metrics:
+                token_metrics.record_answer_completion(
+                    model=model_id,
+                    total_tokens=total_tokens,
+                    tool_calls=total_tool_calls,
+                )
+                logger.info(
+                    f"Recorded metrics for answer completion: model={model_id}, tokens={total_tokens}, tool_calls={total_tool_calls}"
+                )
+        except Exception as e:
+            logger.warning(f"Failed to record metrics for answer completion: {e}")
+
+        return response_text
+
+    # Max iterations reached; return last assistant text if any
+    last_text = (
+        bedrock_text_blocks_to_text(output_message.get("content", []))
+        if "output_message" in locals()
+        else ""
+    )
+    return last_text or "I'm sorry, I couldn't complete the request."
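The Bedrock path above leans on three helpers imported from `app/utilities/bedrock_client.py`, one of the 22 changed files not expanded in this view. Judging only from how they are called in `answer_question_bedrock`, a minimal sketch of those helpers could look like the following; the real module may handle regions, streaming, or schema edge cases differently:

```python
# Minimal sketch of app/utilities/bedrock_client.py, inferred from the call sites above;
# the actual implementation in this commit may differ.
import boto3


def get_bedrock_runtime_client():
    """Return a boto3 Bedrock runtime client (region taken from AWS_DEFAULT_REGION)."""
    return boto3.client("bedrock-runtime")


def map_openai_tools_to_bedrock_tool_config(tools: list[dict]) -> dict:
    """Convert OpenAI-style function tool schemas into a Bedrock Converse toolConfig."""
    bedrock_tools = []
    for tool in tools:
        # Accept {"type": "function", "function": {...}} or a bare function schema
        fn = tool.get("function", tool)
        bedrock_tools.append(
            {
                "toolSpec": {
                    "name": fn["name"],
                    "description": fn.get("description") or fn["name"],
                    # Converse expects the JSON Schema under inputSchema.json
                    "inputSchema": {"json": fn.get("parameters", {"type": "object"})},
                }
            }
        )
    return {"tools": bedrock_tools}


def bedrock_text_blocks_to_text(content_blocks: list[dict]) -> str:
    """Join the text portions of a Converse message's content blocks."""
    return "\n".join(
        block["text"]
        for block in content_blocks or []
        if isinstance(block, dict) and "text" in block
    )
```

The `{"tools": [{"toolSpec": ...}]}` shape is what `client.converse(toolConfig=...)` expects, which is why the loop above can pass the mapped config through unchanged on every iteration.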

app/api/main.py

Lines changed: 21 additions & 1 deletion
@@ -5,6 +5,7 @@
 """

 import logging
+import os
 from contextlib import asynccontextmanager

 from docket.docket import Docket
@@ -26,7 +27,10 @@
 from app.api.slack_app import get_slack_app
 from app.utilities import keys
 from app.utilities.environment import get_env_var
-from app.utilities.logging_config import configure_uvicorn_logging
+from app.utilities.logging_config import (
+    configure_uvicorn_logging,
+    ensure_stdout_logging,
+)
 from app.utilities.telemetry import setup_telemetry
 from app.worker.task_registration import register_all_tasks

@@ -344,6 +348,22 @@ async def lifespan(app: FastAPI):
     """FastAPI lifespan context manager."""
     print("Starting up FastAPI application with Docket...")

+    # Ensure logs go to stdout with a sane default
+    ensure_stdout_logging()
+
+    # Log LLM provider/model at API startup for visibility
+    try:
+        provider = os.getenv("LLM_PROVIDER", "bedrock").lower()
+        if provider == "bedrock":
+            model = os.getenv(
+                "BEDROCK_MODEL_ID", "anthropic.claude-3-5-sonnet-20240620-v1:0"
+            )
+        else:
+            model = os.getenv("OPENAI_CHAT_MODEL", "gpt-4.1")
+        logger.info(f"LLM configured: provider={provider} model={model} (api)")
+    except Exception as e:
+        logger.warning(f"Could not determine LLM provider/model on API startup: {e}")
+
     try:
         await setup_slack_app()