From 600ca71f436c5687f915bda852773dab84e6c8dd Mon Sep 17 00:00:00 2001 From: manfredcalvo Date: Wed, 19 Feb 2025 09:58:05 -0600 Subject: [PATCH 1/5] Adding the option of returning metadata in genie agent langchain integration --- .../src/databricks_langchain/genie.py | 84 ++++++++------ .../langchain/tests/unit_tests/test_genie.py | 107 ++++++++++++++---- 2 files changed, 129 insertions(+), 62 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/genie.py b/integrations/langchain/src/databricks_langchain/genie.py index 026016eb..c351933e 100644 --- a/integrations/langchain/src/databricks_langchain/genie.py +++ b/integrations/langchain/src/databricks_langchain/genie.py @@ -1,53 +1,63 @@ import mlflow +from langchain_core.messages import AIMessage, BaseMessage + from databricks_ai_bridge.genie import Genie +from langchain_core.runnables import RunnableLambda -@mlflow.trace() -def _concat_messages_array(messages): - concatenated_message = "\n".join( - [ - f"{message.get('role', message.get('name', 'unknown'))}: {message.get('content', '')}" - if isinstance(message, dict) - else f"{getattr(message, 'role', getattr(message, 'name', 'unknown'))}: {getattr(message, 'content', '')}" - for message in messages - ] - ) - return concatenated_message +from typing import Dict, Any -@mlflow.trace() -def _query_genie_as_agent(input, genie_space_id, genie_agent_name): - from langchain_core.messages import AIMessage +@mlflow.trace(span_type="AGENT") +class GenieAgent(RunnableLambda): + def __init__(self, genie_space_id, + genie_agent_name: str = "Genie", + description: str = "", + return_metadata: bool = False): + self.genie_space_id = genie_space_id + self.genie_agent_name = genie_agent_name + self.description = description + self.return_metadata = return_metadata + self.genie = Genie(genie_space_id) + super().__init__(self._query_genie_as_agent) - genie = Genie(genie_space_id) + @mlflow.trace() + def _concat_messages_array(self, messages): - message = f"I will 
provide you a chat history, where your name is {genie_agent_name}. Please help with the described information in the chat history.\n" + data = [] - # Concatenate messages to form the chat history - message += _concat_messages_array(input.get("messages")) + for message in messages: + if isinstance(message, dict): + data.append(f"{message.get('role', 'unknown')}: {message.get('content', '')}") + elif isinstance(message, BaseMessage): + data.append(f"{message.type}: {message.content}") + else: + data.append(f"{getattr(message, 'role', getattr(message, 'name', 'unknown'))}: {getattr(message, 'content', '')}") - # Send the message and wait for a response - genie_response = genie.ask_question(message) + concatenated_message = "\n".join([e for e in data if e]) - if query_result := genie_response.result: - return {"messages": [AIMessage(content=query_result)]} - else: - return {"messages": [AIMessage(content="")]} + return concatenated_message + @mlflow.trace() + def _query_genie_as_agent(self, state: Dict[str, Any]): + message = (f"I will provide you a chat history, where your name is {self.genie_agent_name}. 
" + f"Please help with the described information in the chat history.\n") -@mlflow.trace(span_type="AGENT") -def GenieAgent(genie_space_id, genie_agent_name: str = "Genie", description: str = ""): - """Create a genie agent that can be used to query the API""" - from functools import partial + # Concatenate messages to form the chat history + message += self._concat_messages_array(state.get("messages")) + + # Send the message and wait for a response + genie_response = self.genie.ask_question(message) + + content = "" + metadata = None + + if genie_response.result: + content = genie_response.result + metadata = genie_response - from langchain_core.runnables import RunnableLambda + if self.return_metadata: + return {"messages": [AIMessage(content=content)], "metadata": metadata} - # Create a partial function with the genie_space_id pre-filled - partial_genie_agent = partial( - _query_genie_as_agent, - genie_space_id=genie_space_id, - genie_agent_name=genie_agent_name, - ) + return {"messages": [AIMessage(content=content)]} - # Use the partial function in the RunnableLambda - return RunnableLambda(partial_genie_agent) diff --git a/integrations/langchain/tests/unit_tests/test_genie.py b/integrations/langchain/tests/unit_tests/test_genie.py index 024ca3d7..b7004a87 100644 --- a/integrations/langchain/tests/unit_tests/test_genie.py +++ b/integrations/langchain/tests/unit_tests/test_genie.py @@ -1,28 +1,46 @@ from unittest.mock import patch from databricks_ai_bridge.genie import GenieResponse -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, HumanMessage -from databricks_langchain.genie import ( - GenieAgent, - _concat_messages_array, - _query_genie_as_agent, -) +from databricks_langchain.genie import GenieAgent +import pytest -def test_concat_messages_array(): + +@pytest.fixture +def agent(): + return GenieAgent("id-1", "Genie") + + +@pytest.fixture +def agent_with_metadata(): + return GenieAgent("id-1", "Genie", 
return_metadata=True) + + +def test_concat_messages_array_base_messages(agent): + messages = [HumanMessage("What is the weather?"), AIMessage("It is sunny.")] + + result = agent._concat_messages_array(messages) + + expected_result = "human: What is the weather?\nai: It is sunny." + + assert result == expected_result + + +def test_concat_messages_array(agent): # Test a simple case with multiple messages messages = [ {"role": "user", "content": "What is the weather?"}, {"role": "assistant", "content": "It is sunny."}, ] - result = _concat_messages_array(messages) + result = agent._concat_messages_array(messages) expected = "user: What is the weather?\nassistant: It is sunny." assert result == expected # Test case with missing content messages = [{"role": "user"}, {"role": "assistant", "content": "I don't know."}] - result = _concat_messages_array(messages) + result = agent._concat_messages_array(messages) expected = "user: \nassistant: I don't know." assert result == expected @@ -36,37 +54,76 @@ def __init__(self, role, content): Message("user", "Tell me a joke."), Message("assistant", "Why did the chicken cross the road?"), ] - result = _concat_messages_array(messages) + result = agent._concat_messages_array(messages) expected = "user: Tell me a joke.\nassistant: Why did the chicken cross the road?" 
assert result == expected -@patch("databricks_langchain.genie.Genie") -def test_query_genie_as_agent(MockGenie): - # Mock the Genie class and its response - mock_genie = MockGenie.return_value - mock_genie.ask_question.return_value = GenieResponse(result="It is sunny.") +@patch("databricks_ai_bridge.genie.Genie.ask_question") +def test_query_genie_as_agent(mock_ask_question, agent): + + genie_response = GenieResponse(result="It is sunny.") + + mock_ask_question.return_value = genie_response input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} - result = _query_genie_as_agent(input_data, "space-id", "Genie") + + result = agent._query_genie_as_agent(input_data) expected_message = {"messages": [AIMessage(content="It is sunny.")]} + assert result == expected_message # Test the case when genie_response is empty - mock_genie.ask_question.return_value = GenieResponse(result=None) - result = _query_genie_as_agent(input_data, "space-id", "Genie") + genie_empty_response = GenieResponse(result=None) + + mock_ask_question.return_value = genie_empty_response + + result = agent._query_genie_as_agent(input_data) expected_message = {"messages": [AIMessage(content="")]} + assert result == expected_message -@patch("langchain_core.runnables.RunnableLambda") -def test_create_genie_agent(MockRunnableLambda): - mock_runnable = MockRunnableLambda.return_value +@patch("databricks_ai_bridge.genie.Genie.ask_question") +def test_query_genie_as_agent_with_metadata(mock_ask_question, agent_with_metadata): + + genie_response = GenieResponse(result="It is sunny.", query="select a from data_table", description="description") + + mock_ask_question.return_value = genie_response + + input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} - agent = GenieAgent("space-id", "Genie") - assert agent == mock_runnable + result = agent_with_metadata._query_genie_as_agent(input_data) - # Check that the partial function is created with the correct arguments 
- MockRunnableLambda.assert_called() + expected_message = {"messages": [AIMessage(content="It is sunny.")], "metadata": genie_response} + + assert result == expected_message + + # Test the case when genie_response is empty + genie_empty_response = GenieResponse(result=None) + + mock_ask_question.return_value = genie_empty_response + + result = agent_with_metadata._query_genie_as_agent(input_data) + + expected_message = {"messages": [AIMessage(content="")], "metadata": None} + + assert result == expected_message + + +@patch("databricks_ai_bridge.genie.Genie.ask_question") +def test_query_genie_as_agent_invoke(mock_ask_question, agent): + + genie_response = GenieResponse(result="It is sunny.", query="select a from data_table", description="description") + + mock_ask_question.return_value = genie_response + + input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} + + result = agent.invoke(input_data) + + expected_message = {"messages": [AIMessage(content="It is sunny.")]} + + assert result == expected_message From f1740cead3d208bcbcf20f56bd5580a20e0b4125 Mon Sep 17 00:00:00 2001 From: manfredcalvo Date: Wed, 19 Feb 2025 11:57:28 -0600 Subject: [PATCH 2/5] Change span tracing --- integrations/langchain/src/databricks_langchain/genie.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/genie.py b/integrations/langchain/src/databricks_langchain/genie.py index c351933e..43c01529 100644 --- a/integrations/langchain/src/databricks_langchain/genie.py +++ b/integrations/langchain/src/databricks_langchain/genie.py @@ -8,7 +8,6 @@ from typing import Dict, Any -@mlflow.trace(span_type="AGENT") class GenieAgent(RunnableLambda): def __init__(self, genie_space_id, genie_agent_name: str = "Genie", @@ -19,7 +18,7 @@ def __init__(self, genie_space_id, self.description = description self.return_metadata = return_metadata self.genie = Genie(genie_space_id) - 
super().__init__(self._query_genie_as_agent) + super().__init__(self._query_genie_as_agent, name="Genie_Agent") @mlflow.trace() def _concat_messages_array(self, messages): From 4ba0b0c9bc50de978cae0a62e440b2b86266ed83 Mon Sep 17 00:00:00 2001 From: manfredcalvo Date: Wed, 19 Feb 2025 18:18:04 -0600 Subject: [PATCH 3/5] Using agent_name as the name it will be used in the agent trace --- integrations/langchain/src/databricks_langchain/genie.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/langchain/src/databricks_langchain/genie.py b/integrations/langchain/src/databricks_langchain/genie.py index 43c01529..e95dfcee 100644 --- a/integrations/langchain/src/databricks_langchain/genie.py +++ b/integrations/langchain/src/databricks_langchain/genie.py @@ -18,7 +18,7 @@ def __init__(self, genie_space_id, self.description = description self.return_metadata = return_metadata self.genie = Genie(genie_space_id) - super().__init__(self._query_genie_as_agent, name="Genie_Agent") + super().__init__(self._query_genie_as_agent, name=genie_agent_name) @mlflow.trace() def _concat_messages_array(self, messages): From 537bd23425c81b4d7e0d77ceb8f67f1637144cd6 Mon Sep 17 00:00:00 2001 From: manfredcalvo Date: Thu, 20 Feb 2025 15:05:04 -0600 Subject: [PATCH 4/5] Adding doc string to GenieAgent class. --- .../src/databricks_langchain/genie.py | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/integrations/langchain/src/databricks_langchain/genie.py b/integrations/langchain/src/databricks_langchain/genie.py index e95dfcee..2324508d 100644 --- a/integrations/langchain/src/databricks_langchain/genie.py +++ b/integrations/langchain/src/databricks_langchain/genie.py @@ -9,7 +9,42 @@ class GenieAgent(RunnableLambda): - def __init__(self, genie_space_id, + """ + A class that implements an agent to send user questions to Genie Space in Databricks through the Genie API. 
+ + This class implements an agent that uses the GenieAPI to send user questions to Genie Space in Databricks. + If return_metadata is False, the agent's response will be a dictionary containing a single key, 'messages', + which holds the result of the SQL query executed by the Genie Space. + If `return_metadata` is set to True, the agent's response will be a dictionary containing two keys: `messages` + and `metadata`. The `messages` key will contain only one element, similar to the previous case. + The `metadata` key will include the `GenieResponse` from the API, which will consist of the result of the SQL query, + the SQL query itself, and a brief description of what the query is doing. + + Attributes: + genie_space_id (str): The ID of the Genie space created in Databricks that will be called by the Genie API. + description (str): Description of the Genie space created in Databricks that will be accessed by the GenieAPI. + genie_agent_name (str): The name of the genie agent that will be displayed in the trace. + return_metadata (bool): Whether to return the GenieResponse generated by the GenieAPI when the agent is called. + genie (Genie): The Genie API class. + + Methods: + invoke(state): Returns a dictionary with two possible keys: "messages" and "metadata", which contain the results + of the query executed by Genie Space and the associated metadata. 
+ + Examples: + >>> genie_agent = GenieAgent("01ef92421857143785bb9e765454520f") + >>> genie_agent.invoke({"messages": [{"role": "user", "content": "What is the average total invoice across the different customers?"}]}) + {'messages': [AIMessage(content='| | average_total_invoice |\n|---:|------------------------:|\n| 0 | 195.648 |', + additional_kwargs={}, response_metadata={})]} + >>> genie_agent = GenieAgent("01ef92421857143785bb9e765454520f", return_metadata=True) + >>> genie_agent.invoke({"messages": [{"role": "user", "content": "What is the average total invoice across the different customers?"}]}) + {'messages': [AIMessage(content='| | avg_total_invoice |\n|---:|--------------------:|\n| 0 | 195.648 |', + additional_kwargs={}, response_metadata={})], + 'metadata': GenieResponse(result='| | avg_total_invoice |\n|---:|--------------------:|\n| 0 | 195.648 |', + query='SELECT AVG(`total_invoice`) AS avg_total_invoice FROM `finance`.`external_customers`.`invoices`', + description='This query calculates the average total invoice amount from all customer invoices, providing insight into overall billing trends.')} + """ + def __init__(self, genie_space_id: str, genie_agent_name: str = "Genie", description: str = "", return_metadata: bool = False): From 4a7113e18aebb94f42a33625353607082d359b2b Mon Sep 17 00:00:00 2001 From: manfredcalvo Date: Thu, 20 Feb 2025 15:35:58 -0600 Subject: [PATCH 5/5] Rename function that calls the api from query_genie_as_agent to call_genie_api. 
--- integrations/langchain/src/databricks_langchain/genie.py | 4 ++-- integrations/langchain/tests/unit_tests/test_genie.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/genie.py b/integrations/langchain/src/databricks_langchain/genie.py index 2324508d..efddb6bf 100644 --- a/integrations/langchain/src/databricks_langchain/genie.py +++ b/integrations/langchain/src/databricks_langchain/genie.py @@ -53,7 +53,7 @@ def __init__(self, genie_space_id: str, self.description = description self.return_metadata = return_metadata self.genie = Genie(genie_space_id) - super().__init__(self._query_genie_as_agent, name=genie_agent_name) + super().__init__(self._call_genie_api, name=genie_agent_name) @mlflow.trace() def _concat_messages_array(self, messages): @@ -73,7 +73,7 @@ def _concat_messages_array(self, messages): return concatenated_message @mlflow.trace() - def _query_genie_as_agent(self, state: Dict[str, Any]): + def _call_genie_api(self, state: Dict[str, Any]): message = (f"I will provide you a chat history, where your name is {self.genie_agent_name}. 
" f"Please help with the described information in the chat history.\n") diff --git a/integrations/langchain/tests/unit_tests/test_genie.py b/integrations/langchain/tests/unit_tests/test_genie.py index b7004a87..3d24528d 100644 --- a/integrations/langchain/tests/unit_tests/test_genie.py +++ b/integrations/langchain/tests/unit_tests/test_genie.py @@ -68,7 +68,7 @@ def test_query_genie_as_agent(mock_ask_question, agent): input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} - result = agent._query_genie_as_agent(input_data) + result = agent._call_genie_api(input_data) expected_message = {"messages": [AIMessage(content="It is sunny.")]} @@ -79,7 +79,7 @@ def test_query_genie_as_agent(mock_ask_question, agent): mock_ask_question.return_value = genie_empty_response - result = agent._query_genie_as_agent(input_data) + result = agent._call_genie_api(input_data) expected_message = {"messages": [AIMessage(content="")]} @@ -95,7 +95,7 @@ def test_query_genie_as_agent_with_metadata(mock_ask_question, agent_with_metada input_data = {"messages": [{"role": "user", "content": "What is the weather?"}]} - result = agent_with_metadata._query_genie_as_agent(input_data) + result = agent_with_metadata._call_genie_api(input_data) expected_message = {"messages": [AIMessage(content="It is sunny.")], "metadata": genie_response} @@ -106,7 +106,7 @@ def test_query_genie_as_agent_with_metadata(mock_ask_question, agent_with_metada mock_ask_question.return_value = genie_empty_response - result = agent_with_metadata._query_genie_as_agent(input_data) + result = agent_with_metadata._call_genie_api(input_data) expected_message = {"messages": [AIMessage(content="")], "metadata": None}