diff --git a/src/agents/sql_workflow.py b/src/agents/sql_workflow.py index d2b71e6..aa96892 100644 --- a/src/agents/sql_workflow.py +++ b/src/agents/sql_workflow.py @@ -197,4 +197,49 @@ def process_question(self, question: str) -> Dict: "results": None } result = self.compiled_graph.invoke(initial_state) - return result \ No newline at end of file + return result + + def identify_data_needs(self, state: State) -> Dict: + print(f"\n{Fore.YELLOW}=== IDENTIFY DATA NEEDS ==={Style.RESET_ALL}") + + schema = self.get_schema() + prompt = f"""You are a SQL expert. Write a SQL query to fetch the raw data needed for Python analysis. + Return ONLY the SQL query wrapped in ```sql``` blocks. + + Database Schema: + {schema} + + Important: + - DO NOT perform calculations in SQL + - Just fetch the necessary columns needed for Python analysis + - For N-day rolling calculations, fetch at least N+2 days of data + - For volatility calculations: + * Need N+2 days minimum (N days + 1 for pct_change + 1 for initial value) + * For 21-day volatility, fetch at least 23 days + - For correlation calculations, fetch at least 62 days + - For date-based queries, use 'ORDER BY date DESC LIMIT X' + - Include the date column in results + - Always join tables using proper date matching + - If user asks for last N days, fetch N+1 days + - Example: if user asks for 21 days, fetch LIMIT 22 + """ + + messages = [ + SystemMessage(content=prompt), + HumanMessage(content=state['user_question']) + ] + + response = self.chat_gpt.invoke(messages) + sql = self.extract_sql(response.content) + + new_state = { + "user_question": state['user_question'], + "messages": state['messages'], + "code": sql, + "data": None, + "results": None + } + + print(f"\n{Fore.YELLOW}SQL Generated:{Style.RESET_ALL}\n{sql}") + + return new_state \ No newline at end of file