Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions backend/api/config.py

This file was deleted.

76 changes: 76 additions & 0 deletions backend/api/core/agent/orchestration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import functools

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.base import RunnableSequence
from langchain_core.tools import StructuredTool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.graph import MessagesState, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from api.core.agent.prompts import SYSTEM_PROMPT


class State(MessagesState):
next: str


def agent_factory(
llm: ChatOpenAI, tools: list[StructuredTool], system_prompt: str
) -> RunnableSequence:
prompt = ChatPromptTemplate.from_messages(
[
("system", system_prompt),
MessagesPlaceholder(variable_name="messages"),
]
)
if tools:
agent = prompt | llm.bind_tools(tools)
else:
agent = prompt | llm
return agent


def agent_node_factory(
state: State,
agent: RunnableSequence,
) -> State:
result = agent.invoke(state)
return dict(messages=[result])


def graph_factory(
agent_node: functools.partial,
tools: list[StructuredTool],
checkpointer: AsyncPostgresSaver | None = None,
name: str = "agent_node",
) -> CompiledStateGraph:
graph_builder = StateGraph(State)
graph_builder.add_node(name, agent_node)
graph_builder.add_node("tools", ToolNode(tools))

graph_builder.add_conditional_edges(name, tools_condition)
graph_builder.add_edge("tools", name)

graph_builder.set_entry_point(name)
graph = graph_builder.compile(checkpointer=checkpointer)
return graph


def get_graph(
llm: ChatOpenAI,
tools: list[StructuredTool] = [],
system_prompt: str = SYSTEM_PROMPT,
name: str = "agent_node",
checkpointer: AsyncPostgresSaver | None = None,
) -> CompiledStateGraph:
agent = agent_factory(llm, tools, system_prompt)
worker_node = functools.partial(agent_node_factory, agent=agent)
return graph_factory(worker_node, tools, checkpointer, name)


def get_config():
return dict(
configurable=dict(thread_id="1"),
)
43 changes: 43 additions & 0 deletions backend/api/core/agent/persistence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from contextlib import asynccontextmanager
from typing import AsyncGenerator

import psycopg
import psycopg.errors
import uvicorn
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg_pool import AsyncConnectionPool

from api.core.logs import uvicorn


@asynccontextmanager
async def checkpointer_context(
conn_str: str,
) -> AsyncGenerator[AsyncPostgresSaver]:
"""
Async context manager that sets up and yields a LangGraph checkpointer.

Uses a psycopg async connection pool to initialize AsyncPostgresSaver.
Skips setup if checkpointer is already configured.

Args:
conn_str (str): PostgreSQL connection string.

Yields:
AsyncPostgresSaver: The initialized checkpointer.
"""
# NOTE: LangGraph AsyncPostgresSaver does not support SQLAlchemy ORM Connections.
# A compatible psycopg connection is created via the connection pool to connect to the checkpointer.
async with AsyncConnectionPool(
conninfo=conn_str,
kwargs=dict(prepare_threshold=None),
) as pool:
checkpointer = AsyncPostgresSaver(pool)
try:
await checkpointer.setup()
except (
psycopg.errors.DuplicateColumn,
psycopg.errors.ActiveSqlTransaction,
):
uvicorn.warning("Skipping checkpointer setup — already configured.")
yield checkpointer
9 changes: 9 additions & 0 deletions backend/api/core/agent/prompts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import os


def read_system_prompt():
with open(os.path.join(os.path.dirname(__file__), "system.md"), "r") as f:
return f.read()


SYSTEM_PROMPT = read_system_prompt()
1 change: 1 addition & 0 deletions backend/api/core/agent/prompts/system.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
You are a helpful assistant.
33 changes: 33 additions & 0 deletions backend/api/core/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from pydantic import PostgresDsn, computed_field
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file="/opt/.env",
env_ignore_empty=True,
extra="ignore",
)

model: str = "gpt-4o-mini-2024-07-18"
openai_api_key: str = ""
mcp_server_port: int = 8050

postgres_dsn: PostgresDsn = (
"postgresql+psycopg://postgres:[email protected]:6543/postgres"
)

@computed_field
@property
def orm_conn_str(self) -> str:
return self.postgres_dsn.encoded_string()

@computed_field
@property
def checkpoint_conn_str(self) -> str:
# NOTE: LangGraph AsyncPostgresSaver has some issues
# with specifying psycopg driver explicitly
return self.postgres_dsn.encoded_string().replace("+psycopg", "")


settings = Settings()
49 changes: 49 additions & 0 deletions backend/api/core/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from contextlib import asynccontextmanager
from typing import Annotated, AsyncGenerator

from fastapi import Depends
from langchain_mcp_adapters.tools import load_mcp_tools
from langchain_openai import ChatOpenAI
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from api.core.agent.persistence import checkpointer_context
from api.core.config import settings
from api.core.mcps import mcp_sse_client
from api.core.models import Resource


def get_llm() -> ChatOpenAI:
return ChatOpenAI(
streaming=True,
model=settings.model,
temperature=0,
api_key=settings.openai_api_key,
stream_usage=True,
)


LLMDep = Annotated[ChatOpenAI, Depends(get_llm)]


engine: AsyncEngine = create_async_engine(settings.orm_conn_str)


def get_engine() -> AsyncEngine:
return engine


EngineDep = Annotated[AsyncEngine, Depends(get_engine)]


@asynccontextmanager
async def setup_graph() -> AsyncGenerator[Resource]:
async with checkpointer_context(
settings.checkpoint_conn_str
) as checkpointer:
async with mcp_sse_client() as session:
tools = await load_mcp_tools(session)
yield Resource(
checkpointer=checkpointer,
tools=tools,
session=session,
)
7 changes: 7 additions & 0 deletions backend/api/core/logs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from logging import getLogger

from rich.pretty import pprint as print

print # facade

uvicorn = getLogger("uvicorn")
27 changes: 27 additions & 0 deletions backend/api/core/mcps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from mcp import ClientSession
from mcp.client.sse import sse_client

from api.core.config import settings


@asynccontextmanager
async def mcp_sse_client() -> AsyncGenerator[ClientSession]:
"""
Creates and initializes an MCP client session over SSE.

Establishes an SSE connection to the MCP server and yields an initialized
`ClientSession` for communication.

Yields:
ClientSession: An initialized MCP client session.
"""
async with sse_client(f"http://mcp:{settings.mcp_server_port}/sse") as (
read_stream,
write_stream,
):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
yield session
13 changes: 13 additions & 0 deletions backend/api/core/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from langchain_core.tools import StructuredTool
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from mcp import ClientSession
from pydantic import BaseModel


class Resource(BaseModel):
checkpointer: AsyncPostgresSaver
tools: list[StructuredTool]
session: ClientSession

class Config:
arbitrary_types_allowed = True
23 changes: 0 additions & 23 deletions backend/api/dependencies.py

This file was deleted.

3 changes: 2 additions & 1 deletion backend/api/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from fastapi import FastAPI

from api.routers import llms, mcps
from api.routers import checkpoints, llms, mcps

app = FastAPI(swagger_ui_parameters={"tryItOutEnabled": True})
app.include_router(llms.router, prefix="/v1")
app.include_router(mcps.router, prefix="/v1")
app.include_router(checkpoints.router, prefix="/v1")
4 changes: 3 additions & 1 deletion backend/api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@ dependencies = [
"fastapi[standard]==0.115.11",
"langchain==0.3.6",
"langchain-community==0.3.4",
"langchain-mcp-adapters>=0.0.9",
"langchain-openai==0.2.3",
"langchain-postgres==0.0.12",
"langfuse==2.60.2",
"langgraph==0.2.39",
"langgraph-checkpoint-postgres>=2.0.21",
"mcp[cli]>=1.6.0",
"prometheus-client==0.21.1",
"psycopg[binary]==3.2.3",
"pydantic-settings==2.6.0",
"pypdf==5.1.0",
"rich==13.9.4",
"sqlmodel>=0.0.24",
"sse-starlette==2.1.3",
]
38 changes: 38 additions & 0 deletions backend/api/routers/checkpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from fastapi import APIRouter
from sqlalchemy import text

from api.core.dependencies import EngineDep
from api.core.logs import uvicorn

TABLES = [
"checkpoints",
"checkpoint_migrations",
"checkpoint_blobs",
"checkpoint_writes",
]
router = APIRouter(tags=["checkpoints"])


@router.delete("/truncate")
async def truncate_checkpoints(engine: EngineDep):
"""
Truncates all checkpoint-related tables from LangGraph AsyncPostgresSaver.

This operation removes all records from the following tables:
- checkpoints
- checkpoint_migrations
- checkpoint_blobs
- checkpoint_writes

**Warning**: This action is irreversible and should be used with caution. Ensure proper backups are in place
before performing this operation.
"""

async with engine.begin() as conn:
for table in TABLES:
await conn.execute(text(f"TRUNCATE TABLE {table};"))
uvicorn.info(f"Truncated table {table}")
return {
"status": "success",
"message": "All checkpoint-related tables truncated successfully.",
}
Loading