Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion src/agents/sql_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,4 +197,49 @@ def process_question(self, question: str) -> Dict:
"results": None
}
result = self.compiled_graph.invoke(initial_state)
return result
return result

def identify_data_needs(self, state: State) -> Dict:
    """Ask the chat model for a SQL query that fetches the raw data to analyse.

    The database schema from ``self.get_schema()`` is embedded in a system
    prompt that instructs the model to return only a SQL query (no
    calculations) for later Python analysis. The user's question is sent
    as the human message, and the SQL is pulled out of the reply with
    ``self.extract_sql``.

    Args:
        state: Current workflow state; ``user_question`` and ``messages``
            are read from it.

    Returns:
        A new state dict that carries over ``user_question`` and
        ``messages``, stores the generated SQL under ``code``, and sets
        ``data`` and ``results`` to ``None``.
    """
    print(f"\n{Fore.YELLOW}=== IDENTIFY DATA NEEDS ==={Style.RESET_ALL}")

    schema = self.get_schema()
    # NOTE: the prompt text is sent to the model verbatim, so its lines are
    # intentionally kept flush-left inside the triple-quoted literal.
    prompt = f"""You are a SQL expert. Write a SQL query to fetch the raw data needed for Python analysis.
Return ONLY the SQL query wrapped in ```sql``` blocks.

Database Schema:
{schema}

Important:
- DO NOT perform calculations in SQL
- Just fetch the necessary columns needed for Python analysis
- For N-day rolling calculations, fetch at least N+2 days of data
- For volatility calculations:
* Need N+2 days minimum (N days + 1 for pct_change + 1 for initial value)
* For 21-day volatility, fetch at least 23 days
- For correlation calculations, fetch at least 62 days
- For date-based queries, use 'ORDER BY date DESC LIMIT X'
- Include the date column in results
- Always join tables using proper date matching
- If user asks for last N days, fetch N+1 days
- Example: if user asks for 21 days, fetch LIMIT 22
"""

    # Build the two-message conversation: instructions first, then the question.
    system_message = SystemMessage(content=prompt)
    question = state["user_question"]
    human_message = HumanMessage(content=question)

    reply = self.chat_gpt.invoke([system_message, human_message])
    sql = self.extract_sql(reply.content)

    # Downstream fields are reset; only the question and history carry over.
    updated_state = {
        "user_question": question,
        "messages": state["messages"],
        "code": sql,
        "data": None,
        "results": None,
    }

    print(f"\n{Fore.YELLOW}SQL Generated:{Style.RESET_ALL}\n{sql}")

    return updated_state