Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ Upcoming (TBD)

Features
--------
* Update query processing functions to allow automatic show_warnings to work for more code paths like DDL
* Update query processing functions to allow automatic show_warnings to work for more code paths like DDL.
* Rework reconnect logic to actually create a new connection instead of simply changing the database (#746).

Bug Fixes
--------
Expand Down
63 changes: 40 additions & 23 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def close(self) -> None:
def register_special_commands(self) -> None:
special.register_special_command(self.change_db, "use", "\\u", "Change to a new database.", aliases=["\\u"])
special.register_special_command(
self.change_db,
self.manual_reconnect,
"connect",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing the full command from "connect" to "reconnect" is a needless breaking change for the user. Maybe it is more rational since both the long and short versions start with "r"? We should be able to add "reconnect" as an alias here:

 aliases=["\\r", "reconnect"],

if that is the motivation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! I switched it back to connect. Goal was consistency, but your point makes sense and it matches the official client anyway

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And RE: adding the alias, I pondered it but decided to leave it off for now. Theoretically in some future change we could split it out to actually be "connect" and "reconnect" with different logic (may never happen, but possible) so didn't want to further muddy the waters there.

"\\r",
"Reconnect to the database. Optional database argument.",
Expand Down Expand Up @@ -260,6 +260,14 @@ def register_special_commands(self) -> None:
self.change_prompt_format, "prompt", "\\R", "Change prompt format.", aliases=["\\R"], case_sensitive=True
)

def manual_reconnect(self, arg: str = "", **_) -> Generator[tuple, None, None]:
"""
wrapper function to use for the \r command so that the real function
may be cleanly used elsewhere
"""
self.reconnect(arg)
yield (None, None, None, None)

def enable_show_warnings(self, **_) -> Generator[tuple, None, None]:
self.show_warnings = True
msg = "Show warnings enabled."
Expand Down Expand Up @@ -912,18 +920,11 @@ def one_iteration(text: str | None = None) -> None:
special.unset_once_if_written(self.post_redirect_command)
special.flush_pipe_once_if_written(self.post_redirect_command)
except err.InterfaceError:
logger.debug("Attempting to reconnect.")
self.echo("Reconnecting...", fg="yellow")
try:
sqlexecute.connect()
logger.debug("Reconnected successfully.")
one_iteration(text)
return # OK to just return, cuz the recursion call runs to the end.
except OperationalError as e2:
logger.debug("Reconnect failed. e: %r", e2)
self.echo(str(e2), err=True, fg="red")
# If reconnection failed, don't proceed further.
# attempt to reconnect
if not self.reconnect():
return
one_iteration(text)
return # OK to just return, cuz the recursion call runs to the end.
except EOFError as e:
raise e
except KeyboardInterrupt:
Expand Down Expand Up @@ -957,18 +958,11 @@ def one_iteration(text: str | None = None) -> None:
except OperationalError as e1:
logger.debug("Exception: %r", e1)
if e1.args[0] in (2003, 2006, 2013):
logger.debug("Attempting to reconnect.")
self.echo("Reconnecting...", fg="yellow")
try:
sqlexecute.connect()
logger.debug("Reconnected successfully.")
one_iteration(text)
return # OK to just return, cuz the recursion call runs to the end.
except OperationalError as e2:
logger.debug("Reconnect failed. e: %r", e2)
self.echo(str(e2), err=True, fg="red")
# If reconnection failed, don't proceed further.
# attempt to reconnect
if not self.reconnect():
return
one_iteration(text)
return # OK to just return, cuz the recursion call runs to the end.
else:
logger.error("sql: %r, error: %r", text, e1)
logger.error("traceback: %r", traceback.format_exc())
Expand Down Expand Up @@ -1040,6 +1034,29 @@ def one_iteration(text: str | None = None) -> None:
if not self.less_chatty:
self.echo("Goodbye!")

def reconnect(self, database: str = "") -> bool:
"""
Attempt to reconnect to the database. Return True if successful,
False if unsuccessful.
"""
assert self.sqlexecute is not None
self.logger.debug("Attempting to reconnect.")
self.echo("Reconnecting...", fg="yellow")
try:
self.sqlexecute.connect()
except OperationalError as e:
self.logger.debug("Reconnect failed. e: %r", e)
self.echo(str(e), err=True, fg="red")
return False
self.logger.debug("Reconnected successfully.")
self.echo("Reconnected successfully.\n", fg="yellow")
if database and self.sqlexecute.dbname != database:
for result in self.change_db(database):
self.echo(result[3])
elif database:
self.echo(f'You are already connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"')
return True

def log_output(self, output: str) -> None:
"""Log the output in the audit log, if it's enabled."""
if isinstance(self.logfile, TextIOWrapper):
Expand Down
2 changes: 2 additions & 0 deletions test/features/fixture_data/help_commands.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
| help | \? | Show this help. |
| nopager | \n | Disable pager, print to stdout. |
| notee | notee | Stop writing results to an output file. |
| nowarnings | \w | Disable automatic warnings display. |
| pager | \P [command] | Set PAGER. Print the query results via PAGER. |
| prompt | \R | Change prompt format. |
| quit | \q | Quit. |
Expand All @@ -30,5 +31,6 @@
| tableformat | \T | Change the table format used to output results. |
| tee | tee [-o] filename | Append all results to an output file (overwrite using -o). |
| use | \u | Change to a new database. |
| warnings | \W | Enable automatic warnings display. |
| watch | watch [seconds] [-c] query | Executes the query every [seconds] seconds (by default 5). |
+----------------+----------------------------+------------------------------------------------------------+
29 changes: 29 additions & 0 deletions test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,35 @@
]


@dbtest
def test_reconnect_no_database(executor):
runner = CliRunner()
sql = "\\r"
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
expected = "Reconnecting...\nReconnected successfully.\n\n"
assert expected in result.output


@dbtest
def test_reconnect_with_different_database(executor):
runner = CliRunner()
database = "mysql"
sql = f"\\r {database}"
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
expected = f'Reconnecting...\nReconnected successfully.\n\nYou are now connected to database "{database}" as user "{USER}"\n'
assert expected in result.output


@dbtest
def test_reconnect_with_same_database(executor):
runner = CliRunner()
database = "mysql"
sql = f"\\u {database}; \\r {database}"
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
expected = f'Reconnecting...\nReconnected successfully.\n\nYou are already connected to database "{database}" as user "{USER}"\n'
assert expected in result.output


@dbtest
def test_prompt_no_host_only_socket(executor):
mycli = MyCli()
Expand Down