diff --git a/README.md b/README.md
index 53cba03..62d2a61 100644
--- a/README.md
+++ b/README.md
@@ -1,77 +1,66 @@
-# Text2SQL
+# Text-to-SQL Interface
-## Description
-📔 Text ➡️ SQL 🧑‍💻
+This project provides a natural language interface to query financial data stored in a SQLite database. Users can ask questions in plain English, and the system will convert them into SQL queries and return the results.
-## Tools
-- [LangChain SQL Q&A Tutorial](https://python.langchain.com/docs/tutorials/sql_qa/) 🦜🔗
-- [Vanna.AI](https://vanna.ai/) 🔮
-- [LlamaIndex](https://www.llamaindex.ai/) 🦙
+## Features
-## Tutorials
-- [How to Use LangChain to Build a Text-to-SQL Solution](https://medium.com/@marvin_thompson/how-to-use-langchain-to-build-a-text-to-sql-solution-54a173f312a5)
-- [Text2SQL GitHub Repository](https://github.com/WeitaoLu/Text2SQL)
-- [Text2SQL Workshop GitHub Repository](https://github.com/weet-ai/text2sql-workshop)
+- Interactive UI built with Streamlit
+- Database schema exploration with expandable table details
+- Sample data preview for each table
+- Support for both OpenAI and Google Gemini models
+- Natural language processing to convert questions to SQL
-## Papers
-- [A Survey on Employing Large Language Models for Text-to-SQL Tasks](https://arxiv.org/html/2407.15186v2)
-- [PET-SQL: A Prompt-enhanced Two-stage Text-to-SQL Framework with Cross-consistency](https://arxiv.org/html/2403.09732v1)
-- [SeaD: End-to-end Text-to-SQL Generation with Schema-aware Denoising](https://arxiv.org/pdf/2105.07911)
-- [Next-Generation Database Interfaces:A Survey of LLM-based Text-to-SQL](https://arxiv.org/pdf/2406.08426)
+## Database Tables
-## Running
+The application works with the following tables:
-1 - install requirements
-```bash
-pip install requirements.txt
-```
+1. **ohlc**: Stock price data with open, high, low, and close prices for each date
+2. **fxrates**: Foreign exchange rates for USD to EUR, GBP, and JPY
+3. **treasury_yields**: Treasury yields for 5-year, 7-year, and 10-year bonds
+4. **yahoo_ohlc**: Stock price data from Yahoo Finance with ticker symbols
-2- navigate to `/data` and run `sqlite-synthetic.py` to create a toy dataset
-```bash
-cd data
-python sqlite-synthetic.py
-```
-_after this step you should see a `synthetic_data.db` in `/src`_
+## Installation
-3- navigate to `src` and run `main.py`
-```bash
-cd ../src
-# Use OpenAI (default)
-python main.py
+1. Clone this repository
+2. Install the required dependencies:
+   ```
+   pip install -r requirements.txt
+   ```
+3. Set up your environment variables in a `.env` file:
+   ```
+   OPENAI_API_KEY=your_openai_api_key
+   GEMINI_API_KEY=your_gemini_api_key
+   ```
-# Or use Google's Gemini model
-python main.py -gemini
-```
+## Usage
-_Optional_ run `visualize_workflows.py` to show workflow graphs
+1. Generate the synthetic database (if not already done):
+   ```
+   python data/sqlite-synthetic.py
+   ```
-## API Keys
+2. Run the Streamlit UI:
+   ```
+   streamlit run src/ui.py
+   ```
-Create a `keys.env` file in the `src/agents` directory with your API keys:
-```
-OPENAI_API_KEY=your_openai_key_here
-GOOGLE_API_KEY=your_gemini_key_here # Optional, only if using Gemini
-```
+3. Open your browser and navigate to the URL shown in the terminal (typically http://localhost:8501)
-## Development
+4. Use the sidebar to explore the database schema and sample data
-See [TODO.md](TODO.md) for planned features and improvements.
+5. Enter your question in the text area and click "Submit Question"
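+
+Once the database has been generated (step 1 above), you can sanity-check the tables described in the Database Tables section with Python's built-in `sqlite3` module. The snippet below is an illustrative sketch only; it assumes the database sits at its default location, `src/synthetic_data.db`.
+
+```python
+# Illustrative sketch: inspect the synthetic database directly.
+# Assumes data/sqlite-synthetic.py has already been run and the
+# database lives at the default path, src/synthetic_data.db.
+import sqlite3
+
+conn = sqlite3.connect("src/synthetic_data.db")
+cursor = conn.cursor()
+
+# List the generated tables (ohlc, fxrates, treasury_yields, yahoo_ohlc)
+cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
+print(cursor.fetchall())
+
+# Preview a few rows of the stock price data
+cursor.execute("SELECT * FROM ohlc LIMIT 5;")
+for row in cursor.fetchall():
+    print(row)
+
+conn.close()
+```
+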
-## Implementaions
+## Example Questions
-Langgraph Workflow 🦜
+- "What are the closing prices for the last 30 dates in the ohlc table?"
+- "What is the average open price over the last 10 days in the ohlc table?"
+- "What is the stock volatility over the last 21 days from the ohlc table?"
+- "Calculate the correlation between 7-year treasury yields and close stock prices over the last 30 days"
+- "Find the days where the stock price movement was more than 2 standard deviations from the mean"
-- Master Workflow
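+
+If you prefer to run one of these questions from a script instead of the UI, a minimal sketch is shown below. It mirrors what `src/ui.py` does internally and assumes it is executed from the `src/` directory (so the `agents` package is importable), with the synthetic database generated and API keys configured.
+
+```python
+# Minimal sketch of programmatic use, mirroring src/ui.py.
+# Run from the src/ directory; requires the synthetic database and API keys.
+from agents.master import MasterWorkflow
+from agents.basic_agents.config import get_paths
+
+_, db_path = get_paths()  # second element is the SQLite database path
+
+workflow = MasterWorkflow(db_path=db_path, use_gemini=False)  # use_gemini=True for Gemini
+result = workflow.process_question(
+    "What is the average open price over the last 10 days in the ohlc table?"
+)
+
+print(result["workflow_type"], result["complexity_score"])
+if result.get("error"):
+    print("Error:", result["error"])
+else:
+    print(result.get("results"))
+```
+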

-- Python Workflow
-- SQL Workflow
+## Project Structure
+- `src/ui.py`: Streamlit UI interface
+- `src/main.py`: Command-line interface for running queries
+- `data/sqlite-synthetic.py`: Script to generate synthetic financial data
+- `agents/`: Contains the agents that process natural language queries
diff --git a/requirements.txt b/requirements.txt
index 8e75cda..101c801 100644
Binary files a/requirements.txt and b/requirements.txt differ
diff --git a/src/run_ui.py b/src/run_ui.py
new file mode 100644
index 0000000..a1a354a
--- /dev/null
+++ b/src/run_ui.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python
+"""
+Script to run the Text-to-SQL UI
+"""
+
+import os
+import subprocess
+import sys
+from pathlib import Path
+
+def main():
+    # Check if the database exists
+    db_path = Path(__file__).parent / 'synthetic_data.db'
+
+    if not db_path.exists():
+        print("Database not found. Generating synthetic data...")
+        data_dir = Path(__file__).parent.parent / 'data'
+        subprocess.run([sys.executable, str(data_dir / 'sqlite-synthetic.py')], check=True)
+        print("Synthetic data generated successfully.")
+
+    # Run the Streamlit UI
+    print("Starting Text-to-SQL UI...")
+    ui_path = Path(__file__).parent / 'ui.py'
+    subprocess.run(['streamlit', 'run', str(ui_path)], check=True)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/src/ui.py b/src/ui.py
new file mode 100644
index 0000000..5e20be9
--- /dev/null
+++ b/src/ui.py
@@ -0,0 +1,150 @@
+import streamlit as st
+import sqlite3
+import pandas as pd
+import os
+from pathlib import Path
+from agents.master import MasterWorkflow
+from agents.basic_agents.config import get_paths
+
+# Set page configuration
+st.set_page_config(
+    page_title="Text-to-SQL Interface",
+    page_icon="📊",
+    layout="wide"
+)
+
+# Get database path
+_, SQL_DB_FPATH = get_paths()
+
+# Function to get table schemas
+def get_table_schemas():
+    conn = sqlite3.connect(SQL_DB_FPATH)
+    cursor = conn.cursor()
+
+    # Get all tables
+    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
+    tables = cursor.fetchall()
+
+    schemas = {}
+    for table in tables:
+        table_name = table[0]
+        cursor.execute(f"PRAGMA table_info({table_name});")
+        columns = cursor.fetchall()
+        schemas[table_name] = columns
+
+    conn.close()
+    return schemas
+
+# Function to get sample data from a table
+def get_sample_data(table_name, limit=5):
+    conn = sqlite3.connect(SQL_DB_FPATH)
+    query = f"SELECT * FROM {table_name} LIMIT {limit}"
+    df = pd.read_sql_query(query, conn)
+    conn.close()
+    return df
+
+# Function to get table descriptions
+def get_table_descriptions():
+    return {
+        "ohlc": "Stock price data with open, high, low, and close prices for each date.",
+        "fxrates": "Foreign exchange rates for USD to EUR, GBP, and JPY.",
+        "treasury_yields": "Treasury yields for 5-year, 7-year, and 10-year bonds.",
+        "yahoo_ohlc": "Stock price data from Yahoo Finance with ticker symbols."
+    }
+
+# Main UI
+def main():
+    st.title("📊 Text-to-SQL Interface")
+    st.markdown("""
+    This interface allows you to ask natural language questions about the financial data in our database.
+    The system will convert your questions into SQL queries and return the results.
+    """)
+
+    # Sidebar with database information
+    with st.sidebar:
+        st.header("Database Information")
+
+        # Get table schemas
+        schemas = get_table_schemas()
+        table_descriptions = get_table_descriptions()
+
+        # Display tables with expandable details
+        for table_name, columns in schemas.items():
+            with st.expander(f"📋 {table_name}"):
+                st.markdown(f"**Description:** {table_descriptions.get(table_name, 'No description available.')}")
+                st.markdown("**Schema:**")
+
+                # Create a DataFrame for the schema
+                schema_df = pd.DataFrame(columns, columns=['cid', 'name', 'type', 'notnull', 'default', 'pk'])
+                schema_df = schema_df[['name', 'type', 'pk']]  # Only show relevant columns
+                schema_df.columns = ['Column', 'Type', 'Primary Key']
+                st.dataframe(schema_df, use_container_width=True)
+
+                # Show sample data
+                st.markdown("**Sample Data:**")
+                sample_data = get_sample_data(table_name)
+                st.dataframe(sample_data, use_container_width=True)
+
+    # Main content area
+    st.header("Ask a Question")
+
+    # Text input for the question
+    user_question = st.text_area(
+        "Enter your question about the data:",
+        placeholder="Example: What is the average closing price over the last 30 days?",
+        height=100
+    )
+
+    # Model selection
+    model_option = st.radio(
+        "Select the model to use:",
+        ["OpenAI", "Gemini"],
+        horizontal=True
+    )
+
+    # Submit button
+    if st.button("Submit Question"):
+        if user_question:
+            with st.spinner("Processing your question..."):
+                try:
+                    # Initialize the workflow
+                    use_gemini = model_option == "Gemini"
+                    workflow = MasterWorkflow(db_path=SQL_DB_FPATH, use_gemini=use_gemini)
+
+                    # Process the question
+                    result = workflow.process_question(user_question)
+
+                    # Display results
+                    st.subheader("Results")
+
+                    # Display workflow type and complexity score
+                    col1, col2 = st.columns(2)
+                    with col1:
+                        st.metric("Workflow Type", result['workflow_type'])
+                    with col2:
+                        st.metric("Complexity Score", f"{result['complexity_score']:.2f}")
+
+                    # Display the results
+                    if result.get('results') is not None:
+                        if isinstance(result['results'], (list, tuple)):
+                            # Convert to DataFrame if it's a list of tuples
+                            if result['results'] and isinstance(result['results'][0], tuple):
+                                df = pd.DataFrame(result['results'])
+                                st.dataframe(df, use_container_width=True)
+                            else:
+                                for row in result['results']:
+                                    st.write(row)
+                        else:
+                            st.write(result['results'])
+
+                    # Display any errors
+                    if result.get('error'):
+                        st.error(f"Error: {result['error']}")
+
+                except Exception as e:
+                    st.error(f"An error occurred: {str(e)}")
+        else:
+            st.warning("Please enter a question.")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file