Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 62 additions & 3 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
from mycli.packages.hybrid_redirection import get_redirect_components, is_redirect_command
from mycli.packages.parseutils import is_destructive, is_dropping_database
from mycli.packages.prompt_utils import confirm, confirm_destructive_query
from mycli.packages.special.favoritequeries import FavoriteQueries
from mycli.packages.special.main import ArgType
from mycli.packages.tabular_output import sql_format
from mycli.packages.toolkit.history import FileHistoryWithTimestamp
Expand Down Expand Up @@ -128,8 +127,6 @@ def __init__(
special.set_timing_enabled(c["main"].as_bool("timing"))
self.beep_after_seconds = float(c["main"]["beep_after_seconds"] or 0)

FavoriteQueries.instance = FavoriteQueries.from_config(self.config)

self.dsn_alias = None
self.main_formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
self.redirect_formatter = TabularOutputFormatter(format_name=c["main"].get("redirect_format", "csv"))
Expand Down Expand Up @@ -681,6 +678,47 @@ def get_continuation(width, *_):
def show_suggestion_tip():
return iterations < 2

def output_res(res, start):
result_count = 0
mutating = False
for title, cur, headers, status in res:
logger.debug("headers: %r", headers)
logger.debug("rows: %r", cur)
logger.debug("status: %r", status)
threshold = 1000
if is_select(status) and cur and cur.rowcount > threshold:
self.echo(
"The result set has more than {} rows.".format(threshold),
fg="red",
)
if not confirm("Do you want to continue?"):
self.echo("Aborted!", err=True, fg="red")
break

if self.auto_vertical_output:
max_width = self.prompt_app.output.get_size().columns
else:
max_width = None

formatted = self.format_output(title, cur, headers, special.is_expanded_output(), max_width)

t = time() - start
try:
if result_count > 0:
self.echo("")
try:
self.output(formatted, status)
except KeyboardInterrupt:
pass
self.echo("Time: %0.03fs" % t)
except KeyboardInterrupt:
pass

start = time()
result_count += 1
mutating = mutating or is_mutating(status)
return mutating

def one_iteration(text=None):
if text is None:
try:
Expand All @@ -707,6 +745,27 @@ def one_iteration(text=None):
logger.error("traceback: %r", traceback.format_exc())
self.echo(str(e), err=True, fg="red")
return
# LLM command support
while special.is_llm_command(text):
try:
start = time()
cur = sqlexecute.conn.cursor()
context, sql, duration = special.handle_llm(text, cur)
if context:
click.echo("LLM Response:")
click.echo(context)
click.echo("---")
click.echo(f"Time: {duration:.2f} seconds")
text = self.prompt_app.prompt(default=sql)
except KeyboardInterrupt:
return
except special.FinishIteration as e:
return output_res(e.results, start) if e.results else None
except RuntimeError as e:
logger.error("sql: %r, error: %r", text, e)
logger.error("traceback: %r", traceback.format_exc())
self.echo(str(e), err=True, fg="red")
return

if not text.strip():
return
Expand Down
2 changes: 2 additions & 0 deletions mycli/packages/completion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def suggest_special(text: str) -> list[dict[str, Any]]:
]
elif cmd in ["\\.", "source"]:
return [{"type": "file_name"}]
if cmd in ["\\llm", "\\ai"]:
return [{"type": "llm"}]

return [{"type": "keyword"}, {"type": "special"}]

Expand Down
1 change: 1 addition & 0 deletions mycli/packages/special/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ def export(defn: Callable):
from mycli.packages.special import (
dbcommands, # noqa: E402 F401
iocommands, # noqa: E402 F401
llm, # noqa: E402 F401
)
22 changes: 15 additions & 7 deletions mycli/packages/special/iocommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Any, Generator

import click
from configobj import ConfigObj
from pymysql.cursors import Cursor
import pyperclip
import sqlparse
Expand All @@ -36,6 +37,13 @@
'stdout_mode': None,
}
delimiter_command = DelimiterCommand()
favoritequeries = FavoriteQueries(ConfigObj())


@export
def set_favorite_queries(config):
global favoritequeries
favoritequeries = FavoriteQueries(config)


@export
Expand Down Expand Up @@ -261,7 +269,7 @@ def execute_favorite_query(cur: Cursor, arg: str, **_) -> Generator[tuple, None,
name, _separator, arg_str = arg.partition(" ")
args = shlex.split(arg_str)

query = FavoriteQueries.instance.get(name)
query = favoritequeries.get(name)
if query is None:
message = "No favorite query: %s" % (name)
yield (None, None, None, message)
Expand All @@ -286,10 +294,10 @@ def list_favorite_queries() -> list[tuple]:
Returns (title, rows, headers, status)"""

headers = ["Name", "Query"]
rows = [(r, FavoriteQueries.instance.get(r)) for r in FavoriteQueries.instance.list()]
rows = [(r, favoritequeries.get(r)) for r in favoritequeries.list()]

if not rows:
status = "\nNo favorite queries found." + FavoriteQueries.instance.usage
status = "\nNo favorite queries found." + favoritequeries.usage
else:
status = ""
return [("", rows, headers, status)]
Expand All @@ -316,7 +324,7 @@ def save_favorite_query(arg: str, **_) -> list[tuple]:
"""Save a new favorite query.
Returns (title, rows, headers, status)"""

usage = "Syntax: \\fs name query.\n\n" + FavoriteQueries.instance.usage
usage = "Syntax: \\fs name query.\n\n" + favoritequeries.usage
if not arg:
return [(None, None, None, usage)]

Expand All @@ -326,18 +334,18 @@ def save_favorite_query(arg: str, **_) -> list[tuple]:
if (not name) or (not query):
return [(None, None, None, usage + "Err: Both name and query are required.")]

FavoriteQueries.instance.save(name, query)
favoritequeries.save(name, query)
return [(None, None, None, "Saved.")]


@special_command("\\fd", "\\fd [name]", "Delete a favorite query.")
def delete_favorite_query(arg: str, **_) -> list[tuple]:
"""Delete an existing favorite query."""
usage = "Syntax: \\fd name.\n\n" + FavoriteQueries.instance.usage
usage = "Syntax: \\fd name.\n\n" + favoritequeries.usage
if not arg:
return [(None, None, None, usage)]

status = FavoriteQueries.instance.delete(arg)
status = favoritequeries.delete(arg)

return [(None, None, None, status)]

Expand Down
Loading