diff --git a/src/agents/basic_agents/agent1.py b/src/agents/basic_agents/agent1.py new file mode 100644 index 0000000..29cc32e --- /dev/null +++ b/src/agents/basic_agents/agent1.py @@ -0,0 +1,140 @@ +import os +import time +import json +import re +import numpy as np +import pandas as pd +from sqlalchemy import create_engine +from langchain_openai import ChatOpenAI +from langchain_community.utilities import SQLDatabase +from dotenv import load_dotenv + +class Config: + """SQL agent config""" + def __init__( + self, + env_fpath, + sql_db_fpath, + openai_model = 'gpt-4-1106-preview', + temperature = 0.7 + ): + load_dotenv(env_fpath) + openai_api_key = os.getenv("OPENAI_API_KEY") + + self.llm = ChatOpenAI( + model=openai_model, + temperature=temperature, + api_key=openai_api_key + ) + self.db = SQLDatabase.from_uri(f"sqlite:///{sql_db_fpath}") + self.engine = create_engine(f"sqlite:///{sql_db_fpath}") + +def get_schema_context(config): + db = config.db + tables = db.get_usable_table_names() + + schema_lines = [] + for table in tables: + if table in ['ohlc', 'treasury_yields']: + table_info = db.get_table_info([table]) + schema_lines.append(f"Table: {table}\n{table_info}\n") + + schema_context = ( + "DATABASE SCHEMA:\n" + + "\n".join(schema_lines) + ) + return schema_context + +def extract_query(response, type='sql'): + pattern = rf"```{type}\s+([\s\S]*?)\s+```" + match = re.search(pattern, response) + + if match: + return match.group(1).strip() + else: + print(f"Extracting Query Failed: returning response.strip():\n{response.strip()}") + return response.strip() + +def execute_sql1(user_query, config): + """generate and execute sql""" + schema_context = get_schema_context(config) + sql_prompt = f''' + Given a user query, generate a single syntatically correct SQLite query wrapped in ```sql QUERY``` that pulls all the + data relevant to the query. Do not perform any computation, assume that the user is able to take the raw data and + perform the relevant computations to reach the desired result. Be mindful of how many days of data to pull, as + certain queries may specific n days but require more than n to compute the result. For days do LIMIT rather than now(). + Make sure the column names in the resulting table are clear. Here is the database schema: {schema_context} + + Here is an example user query and then correct output: + User Query: Calculate the correlation between 7 year treasury yields and stocks close over the last 30 days in the table. + Correct Output: ``` sql + SELECT + y.date AS date, + y.yield_7_year AS treasury_yield_7_year, + o.close AS stock_close + FROM + treasury_yields y + INNER JOIN + ohlc o ON y.date = o.date + ORDER BY + y.date DESC + LIMIT 30; + ``` + Notice there is no calculation calculation. This is the correct response, as + the correlation will be calculated later. + + Give the correct output for user query: {user_query} + ''' + sql_query = extract_query(config.llm.invoke(sql_prompt).content) + print(f'SQL Query: {sql_query}') + df = pd.read_sql(sql_query, config.engine) + return df + +def execute_python1(user_query, df, config): + """generate and execute python""" + python_prompt = f''' + Given a user query and a pandas dataframe with the relevant data, generate syntatically correct python code + wrapped in ```python QUERY`` that takes the raw dataframe and performs any computations to fully answer the + user's query. Assume access to NumPy (v1.26.4), Pandas (v2.2.3) and that the dataframe is called df. The output + variable should always be called result. + + Here is an example user query and df and then correct output: + User Query: Calculate the correlation between 7 year treasury yields and stocks close over the last 30 days + in the table. + Dataframe (df): + Date,Treasury Yield (7-Year),Stock Close + 2024-01-01 00:00:00,4.113933,84.676268 + 2023-12-29 00:00:00,4.117221,100.393128 + 2023-12-28 00:00:00,2.391113,112.97598 + 2023-12-27 00:00:00,1.482054,119.224503 + 2023-12-26 00:00:00,4.187207,108.335695 + + *Labeled Answer:* + ``` python + ### calculate corr btwn 7yr tsy and stock closes + result = df['treasury_yield_7_year'].corr(df['stock_close']) + ``` + + User Query: \n{user_query} + df.head: \n{df.head()} + ''' + + code = extract_query(config.llm.invoke(python_prompt).content, type='python') + print(f'Python Code:\n {code}') + + namespace = {'pd': pd, 'np': np, 'df': df} + exec(code, namespace) + result = namespace.get('result', "No result variable found") + print(f'Result:\n{result}') + return result + +def run_agent1(user_query, config): + """main function to run agent 1""" + print(f'User Query:\n {user_query}') + start = time.time() + df = execute_sql1(user_query, config) + result = execute_python1(user_query, df, config) + end = time.time() + + print(f'runtime: {round(end-start,2)} seconds') + return result \ No newline at end of file diff --git a/src/agents/basic_agents/agent2.py b/src/agents/basic_agents/agent2.py new file mode 100644 index 0000000..e0961ed --- /dev/null +++ b/src/agents/basic_agents/agent2.py @@ -0,0 +1,203 @@ +import os +import time +import json +import re +import numpy as np +import pandas as pd +from sqlalchemy import create_engine +from langchain_openai import ChatOpenAI +from langchain_community.utilities import SQLDatabase +from dotenv import load_dotenv + +def get_schema_context(config): + db = config.db + tables = db.get_usable_table_names() + + schema_lines = [] + for table in tables: + if table in ['ohlc', 'treasury_yields']: + table_info = db.get_table_info([table]) + schema_lines.append(f"Table: {table}\n{table_info}\n") + + schema_context = ( + "DATABASE SCHEMA:\n" + + "\n".join(schema_lines) + ) + return schema_context + +def extract_query(response, type='sql'): + pattern = rf"```{type}\s+([\s\S]*?)\s+```" + match = re.search(pattern, response) + + if match: + return match.group(1).strip() + else: + print(f"Extracting Query Failed: returning response.strip():\n{response.strip()}") + return response.strip() + +def get_sql_prompt2(user_query, config): + schema_context = get_schema_context(config) + prompt = f'''Given a user query and a SQlite db schema, only write + QUERY DESCRIPTION: QUERY DESCRIPTION, where QUERY DESCRIPTION is a prompt that + describes the query (table, cols, new col name, joins, etc.) to get the + raw data necessary to answer this user query. Don't write any code or explain any + computation, but write the prompt such that if an independent SQL master with access + to the SQlite db + schema + your instructions could easily query the data. + + Be mindful of how many days of data to pull, as certain queries may specific n days but + require more than n to compute the result. Make sure the column names in the resulting + table are clear. + + Database Schema: {schema_context} + + Example User Query: Calculate the correlation between 7 year treasury yields and stocks + close over the last 30 days in the table. + + Example Labeled Answer: + QUERY DESCRIPTION: + Tables Involved: + - ohlc (Stock data with date and close price) + - treasury_yields (Treasury yields with 7-year yield and date) + Columns Required: + - from ohlc: date, close (rename stock_close) + - from treasury_yields: date, yield_7_year (rename -tsy_yield_7_year) + Filters: + - only consider the last 30 days of data in the table. + Joins: + - perform an inner join between ohlc and treasury_yields on the date column to + align stock data with treasury yields. + + Note that the correlation is not calculated here. The prompt should NOT include + any math. no standard deviation, no avg, nothing more advanced than multiplication. DO NOT MAKE + ANY NEW COLUMNS. SAY THAT CALCULATIONS WILL BE DONE LATER, BY THE MATH MASTER. No filters, + this will be done later by the FILTER MASTER. + + User Query: {user_query} + ''' + sql_prompt = config.llm.invoke(prompt).content + print(f'Generated SQL Prompt: {sql_prompt}') + return sql_prompt + +def execute_sql2(sql_prompt, config): + """generate and execute sql""" + schema_context = get_schema_context(config) + sql_prompt = f''' + Given a SQlite db schema and a query description, generate a syntactically correct SQLite + query wrapped in ```sql QUERY``` that pulls all the data relevant to the query. + + Database Schema: {schema_context} + + Example Input Prompt: + QUERY DESCRIPTION: + Tables Involved: + - ohlc (Stock data with date and close price) + - treasury_yields (Treasury yields with 7-year yield and date) + Columns Required: + - from ohlc: date, close (rename stock_close) + - from treasury_yields: date, yield_7_year (rename tsy_yield_7_year) + Joins: + - perform an inner join between ohlc and treasury_yields on the date column to align + stock data with treasury yields. + + *Labeled Answer:* + ``` sql + SELECT + y.date AS date, + y.yield_7_year AS treasury_yield_7_year, + o.close AS stock_close + FROM + treasury_yields y + INNER JOIN + ohlc o ON y.date = o.date + ORDER BY + y.date DESC + + Input Prompt: {sql_prompt} + ''' + sql_query = extract_query(config.llm.invoke(sql_prompt).content) + print(f'SQL Query: {sql_query}') + df = pd.read_sql(sql_query, config.engine) + return df + +def get_python_prompt2(user_query, df, config): + prompt = f''' + Given a user query and a pandas dataframe with the relevant data, only write + CODE DESCRIPTION: CODE DESCRIPTION, where CODE DESCRIPTION is a prompt that + describes how to take the dataframe (called df) and write python code to perform + relevant computations to answer the user query. Don't write any code, but write + the prompt such that if an independent python master with access to df + your instructions could + easily answer the original user query. Be specific about how to perform the computations, + including any relevant math, what functions to use (assume pandas, numpy access). + + Example User Query: Calculate the correlation between 7 year treasury yields and stocks close over the last 30 days + in the table. + Example Dataframe (df.head()): + Date,Treasury Yield (7-Year),Stock Close + 2024-01-01 00:00:00,4.113933,84.676268 + 2023-12-29 00:00:00,4.117221,100.393128 + 2023-12-28 00:00:00,2.391113,112.97598 + 2023-12-27 00:00:00,1.482054,119.224503 + 2023-12-26 00:00:00,4.187207,108.335695 + + Example Answer: + CODE DESCRIPTION: Given df with cols treasury_yield_7_year, stock_close, date, use pandas corr function + to compute the correlation between treasury_yield_7_year and stock close over the most recent 30 days. + + df.head(): {df.head()} + User Query: {user_query} + ''' + py_prompt = config.llm.invoke(prompt).content + print(f'Generated Python Prompt: {py_prompt}') + return py_prompt + +def execute_python2(py_prompt, df, config): + """generate and execute python""" + py_code = f''' + Given a pandas dataframe df and a description to perform a specific computation, + generate syntactically correct python code wrapped in ```python QUERY`` that takes + the raw dataframe and performs any computations to fully answer the user's query. + Assume access to NumPy (v{np.__version__}), Pandas (v{pd.__version__}) and that + the dataframe is called df. The output of the code should be the variable that + contains the result of the user's query (call this variable result) + + Example Dataframe (df): + Date,Treasury Yield (7-Year),Stock Close + 2024-01-01 00:00:00,4.113933,84.676268 + 2023-12-29 00:00:00,4.117221,100.393128 + 2023-12-28 00:00:00,2.391113,112.97598 + 2023-12-27 00:00:00,1.482054,119.224503 + 2023-12-26 00:00:00,4.187207,108.335695 + + Example Prompt: Given df with cols treasury_yield_7_year, stock_close, date, use pandas corr function + to compute the correlation between treasury_yield_7_year and stock close. + + Example Labeled Answer: + ``` python + ### calculate corr btwn 7yr tsy and stock closes + df = df.sort_values('date')[:30] + result = df['treasury_yield_7_year'].corr(df['stock_close']) + ``` + df.head(): {df.head()} + Prompt: {py_prompt} + ''' + code = extract_query(config.llm.invoke(py_code).content, type='python') + print(f'Python Code:\n {code}') + + namespace = {'pd': pd, 'np': np, 'df': df} + exec(code, namespace) + result = namespace.get('result', "No result variable found") + print(f'Result:\n{result}') + return result + +def run_agent2(user_query, config): + """main function to run agent 2""" + print(f'User Query:\n {user_query}') + start = time.time() + sql_prompt = get_sql_prompt2(user_query, config) + df = execute_sql2(sql_prompt, config) + py_prompt = get_python_prompt2(user_query, df, config) + result = execute_python2(py_prompt, df, config) + end = time.time() + + print(f'runtime: {round(end-start,2)} seconds') + return result \ No newline at end of file diff --git a/src/agents/basic_agents/agent3.py b/src/agents/basic_agents/agent3.py new file mode 100644 index 0000000..1214552 --- /dev/null +++ b/src/agents/basic_agents/agent3.py @@ -0,0 +1,241 @@ +import os +import time +import json +import re +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from sqlalchemy import create_engine +from langchain_openai import ChatOpenAI +from langchain_community.utilities import SQLDatabase +from dotenv import load_dotenv + +def get_schema_context(config): + db = config.db + tables = db.get_usable_table_names() + + schema_lines = [] + for table in tables: + if table in ['ohlc', 'treasury_yields']: + table_info = db.get_table_info([table]) + schema_lines.append(f"Table: {table}\n{table_info}\n") + + schema_context = ( + "DATABASE SCHEMA:\n" + + "\n".join(schema_lines) + ) + return schema_context + +def extract_query(response, type='sql'): + pattern = rf"```{type}\s+([\s\S]*?)\s+```" + match = re.search(pattern, response) + + if match: + return match.group(1).strip() + else: + print(f"Extracting Query Failed: returning response.strip():\n{response.strip()}") + return response.strip() + +def get_sql_json3(user_query, config): + """Get SQL query as JSON structure""" + schema_context = get_schema_context(config) + prompt = f'''Given a user query and a SQLite database schema, return ONLY a valid JSON describing the + data required to answer the user query. The JSON should be parsable and adhere to proper JSON syntax. + + Instructions: + 2. Ensure column names and new column names are strings, and there are no extraneous characters. + 3. Be mindful of how many days of data to pull, as certain queries may specify n days but + require more than n to compute the result. + 4. Avoid creating new columns or performing calculations; these will be handled in later steps. + 5. Ensure new column names are self-explanatory and clear. + + Answer ONLY with a JSON in this format: + {{ + "tables": {{ + "table1": [ + ["original_column_name", "new_column_name"], + ... + ], + "table2": [...], + }}, + "joins": [ + ["tableA", "tableB", "tableA_join_column", "tableB_join_column", "join_type"] + ] + }} + + Constraints: + - Join types must be one of: "inner", "left", "right", "outer". + - Joins will be applied in order of the json. Note the order does matter here. For example + if the first join is inner(A,B) then next is outer(B, C), what is really happening is + outer(C, inner(A,B)), i.e. not inner(A, outer(B,C)) + + Database Schema: {schema_context} + + Example User Query: Calculate the correlation between 7-year treasury yields and stocks' + close prices over the last 30 days. + + Example Labeled Answer: + ```json + {{ + "tables": {{ + "ohlc": [ + ["date", "date"], + ["close", "stock_close"] + ], + "treasury_yields": [ + ["date", "date"], + ["yield_7_year", "tsy_yield_7_year"] + ] + }}, + "joins": [ + ["ohlc", "treasury_yields", "date", "date", "inner"] + ] + }} + + Answer for the following user query: {user_query} + ''' + llm_response = config.llm.invoke(prompt).content + print('=' * 30) + print(f'LLM Response: {llm_response}') + sql_json = extract_query(llm_response, type='json') + try: + sql_json = json.loads(sql_json) + print(f'=' * 30) + print(f'SQL PARSED JSON: {sql_json}') + valid_join_types = {"inner", "left", "right", "outer"} + for join in sql_json.get("joins", []): + if len(join) != 5 or join[-1] not in valid_join_types: + raise ValueError(f"Invalid join type or structure: {join}") + + return sql_json + except json.JSONDecodeError: + raise ValueError("LLM response is not valid JSON.") + except Exception as e: + raise ValueError(f"Error validating SQL JSON: {e}") + +def compile_sql(sql_json: dict, config) -> str: + """ + Compiles a JSON definition of tables and joins into a SQLite SELECT query. + """ + tables = sql_json.get("tables", {}) + joins = sql_json.get("joins", []) + + # 1) Build the projection columns (the SELECT part) + select_columns = [] + for table_name, column_pairs in tables.items(): + for original, alias in column_pairs: + select_columns.append(f'"{table_name}"."{original}" AS "{alias}"') + + # If no tables at all, we can't form a valid query + if not tables: + raise ValueError("No tables were provided. At least one table is required.") + + columns_str = ",\n ".join(select_columns) + + # 2) Build the FROM/JOIN parts of the query + if joins: + base_table = joins[0][0] + from_clause = f'"{base_table}"' + + for (tbl_left, tbl_right, left_on, right_on, join_type) in joins: + join_type_upper = join_type.upper() + " JOIN" + from_clause += ( + f'\n{join_type_upper} "{tbl_right}" ' + f'ON "{tbl_left}"."{left_on}" = "{tbl_right}"."{right_on}"' + ) + else: + base_table = list(tables.keys())[0] + from_clause = f'"{base_table}"' + + # 3) Put it all together + query = f''' +SELECT + {columns_str} +FROM {from_clause}'''.strip() + print(f'=' * 30) + print(f'SQL Query: {query}') + return pd.read_sql(query, config.engine) + +def get_python_prompt3(user_query, df, config): + prompt = f''' + Given a user query and a pandas dataframe with the relevant data, only write + CODE DESCRIPTION: CODE DESCRIPTION, where CODE DESCRIPTION is a prompt that + describes how to take the dataframe (called df) and write python code to perform + relevant computations to answer the user query. Don't write any code, but write + the prompt such that if an independent python master with access to df + your instructions could + easily answer the original user query. Be specific about how to perform the computations, + including any relevant math, what functions to use (assume pandas, numpy access). + + Example User Query: Calculate the correlation between 7 year treasury yields and stocks close over the last 30 days + in the table. + Example Dataframe (df.head()): + Date,Treasury Yield (7-Year),Stock Close + 2024-01-01 00:00:00,4.113933,84.676268 + 2023-12-29 00:00:00,4.117221,100.393128 + 2023-12-28 00:00:00,2.391113,112.97598 + 2023-12-27 00:00:00,1.482054,119.224503 + 2023-12-26 00:00:00,4.187207,108.335695 + + Example Answer: + CODE DESCRIPTION: Given df with cols treasury_yield_7_year, stock_close, date, use pandas corr function + to compute the correlation between treasury_yield_7_year and stock close over the most recent 30 days. + + df.head(): {df.head()} + User Query: {user_query} + ''' + py_prompt = config.llm.invoke(prompt).content + print(f'Generated Python Prompt: {py_prompt}') + return py_prompt + +def execute_python3(py_prompt, df, config): + """generate and execute python""" + py_code = f''' + Given a pandas dataframe df and a description to perform a specific computation, + generate syntactically correct python code wrapped in ```python QUERY`` that takes + the raw dataframe and performs any computations to fully answer the user's query. + Assume access to NumPy (v{np.__version__}), Pandas (v{pd.__version__}) and that + the dataframe is called df. The output of the code should be the variable that + contains the result of the user's query (call this variable result) + + Example Dataframe (df): + Date,Treasury Yield (7-Year),Stock Close + 2024-01-01 00:00:00,4.113933,84.676268 + 2023-12-29 00:00:00,4.117221,100.393128 + 2023-12-28 00:00:00,2.391113,112.97598 + 2023-12-27 00:00:00,1.482054,119.224503 + 2023-12-26 00:00:00,4.187207,108.335695 + + Example Prompt: Given df with cols treasury_yield_7_year, stock_close, date, use pandas corr function + to compute the correlation between treasury_yield_7_year and stock close. + + Example Labeled Answer: + ``` python + ### calculate corr btwn 7yr tsy and stock closes + df = df.sort_values('date')[:30] + result = df['treasury_yield_7_year'].corr(df['stock_close']) + ``` + df.head(): {df.head()} + Prompt: {py_prompt} + ''' + code = extract_query(config.llm.invoke(py_code).content, type='python') + print(f'Python Code:\n {code}') + + namespace = {'pd': pd, 'np': np, 'df': df} + exec(code, namespace) + result = namespace.get('result', "No result variable found") + print(f'Result:\n{result}') + return result + +def run_agent3(user_query, config): + """main function to run agent 3""" + print('=' * 30) + print(f'User Query:\n {user_query}') + start = time.time() + sql_json = get_sql_json3(user_query, config) + df = compile_sql(sql_json, config) + py_prompt = get_python_prompt3(user_query, df, config) + result = execute_python3(py_prompt, df, config) + end = time.time() + print('=' * 30) + print(f'runtime: {round(end-start,2)} seconds') + return result \ No newline at end of file diff --git a/src/agents/basic_agent.ipynb b/src/agents/basic_agents/basic_agent.ipynb similarity index 100% rename from src/agents/basic_agent.ipynb rename to src/agents/basic_agents/basic_agent.ipynb diff --git a/src/agents/basic_agents/config.py b/src/agents/basic_agents/config.py new file mode 100644 index 0000000..e1a4218 --- /dev/null +++ b/src/agents/basic_agents/config.py @@ -0,0 +1,33 @@ +import os +from sqlalchemy import create_engine +from langchain_openai import ChatOpenAI +from langchain_community.utilities import SQLDatabase +from dotenv import load_dotenv + +class Config: + """SQL agent config""" + def __init__( + self, + env_fpath, + sql_db_fpath, + openai_model='gpt-4-1106-preview', + temperature=0.7 + ): + load_dotenv(env_fpath) + openai_api_key = os.getenv("OPENAI_API_KEY") + + self.llm = ChatOpenAI( + model=openai_model, + temperature=temperature, + api_key=openai_api_key + ) + self.db = SQLDatabase.from_uri(f"sqlite:///{sql_db_fpath}") + self.engine = create_engine(f"sqlite:///{sql_db_fpath}") + +def get_paths(): + """Get the paths for environment and database files.""" + SRC_PATH = os.path.dirname(os.getcwd()) + ENV_FPATH = f'{SRC_PATH}/keys.env' + SQL_DB_FPATH = f'{SRC_PATH}/synthetic_data.db' + + return ENV_FPATH, SQL_DB_FPATH \ No newline at end of file diff --git a/src/main.py b/src/main.py index 67adbc7..3f21c9c 100644 --- a/src/main.py +++ b/src/main.py @@ -1,49 +1,100 @@ import argparse +import os +from agents.basic_agents.agent1 import run_agent1 +from agents.basic_agents.agent2 import run_agent2 +from agents.basic_agents.agent3 import run_agent3 from agents.master import MasterWorkflow from colorama import Fore, Style +from dotenv import load_dotenv +from agents.basic_agents.config import Config, get_paths def main(): parser = argparse.ArgumentParser() parser.add_argument("-gemini", action="store_true", help="Use Google's Gemini model instead of OpenAI") + parser.add_argument("-basic", type=int, choices=[1, 2, 3], help="Specify the agent number to run (1, 2, or 3)") args = parser.parse_args() - workflow = MasterWorkflow(db_path="synthetic_data.db", use_gemini=args.gemini) - - questions = [ - # Simple SQL queries - "What are the closing prices for last 30 dates in the ohlc table?", - "What is the minimum price in the ohlc table?", - "What is the maximum price in the ohlc table?", - - # Moderate complexity - "What is the average open price over the last 10 days in the ohlc table ORDER BY date DESC", - - # Complex analysis requiring Python - "What is the stock volatility over the last 21 days from the ohlc table?", - "Calculate the correlation between 7 year treasury yields and close stock prices over the last 30 days", - "Find the days where the stock price movement was more than 2 standard deviations from the mean", - ] - - for question in questions: - print(f"\n{Fore.CYAN}{'='*80}{Style.RESET_ALL}") - print(f"{Fore.YELLOW}Question:{Style.RESET_ALL} {question}") - print(f"{Fore.YELLOW}Using Model:{Style.RESET_ALL} {'Gemini' if args.gemini else 'OpenAI'}") - - result = workflow.process_question(question) + load_dotenv() + + ENV_FPATH, SQL_DB_FPATH = get_paths() + + print(f'env fpath: {ENV_FPATH}') + print(f'sql db fpath: {SQL_DB_FPATH}') + + config = Config( + env_fpath=ENV_FPATH, + sql_db_fpath=SQL_DB_FPATH + ) + + if args.basic: + agent_number = args.basic - print(f"{Fore.MAGENTA}Workflow Type:{Style.RESET_ALL} {result['workflow_type']}") - print(f"{Fore.MAGENTA}Complexity Score:{Style.RESET_ALL} {result['complexity_score']:.2f}") + if agent_number == 1: + user_query = '''Find the days where stock price movement := close-open was more + than 2 standard deviations from the mean''' + print(f"\n{Fore.CYAN}{'='*80}{Style.RESET_ALL}") + print(f"{Fore.YELLOW}Running Agent {agent_number} for Query:{Style.RESET_ALL} {user_query}") + result = run_agent1(user_query, config) + + elif agent_number == 2: + user_query = '''Find the days where stock price movement := close-open was more + than 2 standard deviations from the mean''' + print(f"\n{Fore.CYAN}{'='*80}{Style.RESET_ALL}") + print(f"{Fore.YELLOW}Running Agent {agent_number} for Query:{Style.RESET_ALL} {user_query}") + result = run_agent2(user_query, config) + + elif agent_number == 3: + user_query = '''For s in [1,2], of the days where the stock price + movement := close - open was more than s std deviations from the mean, look at the distribution + of 7yr tsy yield - 5yr tsy yield. To visualize this, assume access to matplotlib.pyplot as plt and + make a 2 plots, the left where s = 1 and a histogram of 7yr - 5yr tsy yields with lines at 25 percentile, + 50th percentile, 75th, and then right same with s=2''' + print(f"\n{Fore.CYAN}{'='*80}{Style.RESET_ALL}") + print(f"{Fore.YELLOW}Running Agent {agent_number} for Query:{Style.RESET_ALL} {user_query}") + result = run_agent3(user_query, config) + + print(f"{Fore.MAGENTA}Results from Agent {agent_number}:{Style.RESET_ALL}") + print(result) + + else: + # Existing code for running the MasterWorkflow + workflow = MasterWorkflow(db_path="synthetic_data.db", use_gemini=args.gemini) - if result.get('results') is not None: - print(f"\n{Fore.GREEN}Results:{Style.RESET_ALL}") - if isinstance(result['results'], (list, tuple)): - for row in result['results']: - print(row) - else: - print(result['results']) + questions = [ + # Simple SQL queries + "What are the closing prices for last 30 dates in the ohlc table?", + "What is the minimum price in the ohlc table?", + "What is the maximum price in the ohlc table?", + + # Moderate complexity + "What is the average open price over the last 10 days in the ohlc table ORDER BY date DESC", + + # Complex analysis requiring Python + "What is the stock volatility over the last 21 days from the ohlc table?", + "Calculate the correlation between 7 year treasury yields and close stock prices over the last 30 days", + "Find the days where the stock price movement was more than 2 standard deviations from the mean", + ] - if result.get('error'): - print(f"\n{Fore.RED}Error:{Style.RESET_ALL} {result['error']}") + for question in questions: + print(f"\n{Fore.CYAN}{'='*80}{Style.RESET_ALL}") + print(f"{Fore.YELLOW}Question:{Style.RESET_ALL} {question}") + print(f"{Fore.YELLOW}Using Model:{Style.RESET_ALL} {'Gemini' if args.gemini else 'OpenAI'}") + + result = workflow.process_question(question) + + print(f"{Fore.MAGENTA}Workflow Type:{Style.RESET_ALL} {result['workflow_type']}") + print(f"{Fore.MAGENTA}Complexity Score:{Style.RESET_ALL} {result['complexity_score']:.2f}") + + if result.get('results') is not None: + print(f"\n{Fore.GREEN}Results:{Style.RESET_ALL}") + if isinstance(result['results'], (list, tuple)): + for row in result['results']: + print(row) + else: + print(result['results']) + + if result.get('error'): + print(f"\n{Fore.RED}Error:{Style.RESET_ALL} {result['error']}") if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/agents/graph.png b/src/workflow_visualizations/graph.png similarity index 100% rename from src/agents/graph.png rename to src/workflow_visualizations/graph.png