Skip to content
2 changes: 1 addition & 1 deletion aider/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from packaging import version

__version__ = "0.88.16.dev"
__version__ = "0.88.17.dev"
safe_version = __version__

try:
Expand Down
165 changes: 96 additions & 69 deletions aider/coders/base_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,6 @@ def __init__(
self.dry_run = dry_run
self.pretty = self.io.pretty
self.linear_output = linear_output

self.main_model = main_model

# Set the reasoning tag name based on model settings or default
Expand All @@ -493,6 +492,8 @@ def __init__(
self.commands = commands or Commands(self.io, self)
self.commands.coder = self

self.data_cache = {"repo": {"last_key": ""}, "relative_files": None}

self.repo = repo
if use_git and self.repo is None:
try:
Expand Down Expand Up @@ -877,58 +878,81 @@ def get_repo_map(self, force_refresh=False):
self.io.update_spinner("Updating repo map")

cur_msg_text = self.get_cur_message_text()
mentioned_fnames = self.get_file_mentions(cur_msg_text)
mentioned_idents = self.get_ident_mentions(cur_msg_text)
staged_files_hash = hash(str([item.a_path for item in self.repo.repo.index.diff("HEAD")]))
read_only_count = len(set(self.abs_read_only_fnames)) + len(
set(self.abs_read_only_stubs_fnames)
)
self.data_cache["repo"]["mentioned_idents"] = self.get_ident_mentions(cur_msg_text)

mentioned_fnames.update(self.get_ident_filename_matches(mentioned_idents))
if (
staged_files_hash != self.data_cache["repo"]["last_key"]
or read_only_count != self.data_cache["repo"]["read_only_count"]
):
self.data_cache["repo"]["last_key"] = staged_files_hash

all_abs_files = set(self.get_all_abs_files())
mentioned_idents = self.data_cache["repo"]["mentioned_idents"]
mentioned_fnames = self.get_file_mentions(cur_msg_text)
mentioned_fnames.update(self.get_ident_filename_matches(mentioned_idents))

# Exclude metadata/docs from repo map inputs to reduce parsing overhead
def _include_in_map(abs_path):
try:
rel = self.get_rel_fname(abs_path)
except Exception:
rel = str(abs_path)
parts = Path(rel).parts
if ".meta" in parts or ".docs" in parts:
return False
if ".min." in parts[-1]:
return False
if self.repo.git_ignored_file(abs_path):
return False
return True
all_abs_files = set(self.get_all_abs_files())

all_abs_files = {p for p in all_abs_files if _include_in_map(p)}
repo_abs_read_only_fnames = set(self.abs_read_only_fnames) & all_abs_files
repo_abs_read_only_stubs_fnames = set(self.abs_read_only_stubs_fnames) & all_abs_files
chat_files = (
set(self.abs_fnames) | repo_abs_read_only_fnames | repo_abs_read_only_stubs_fnames
)
other_files = all_abs_files - chat_files
# Exclude metadata/docs from repo map inputs to reduce parsing overhead
def _include_in_map(abs_path):
try:
rel = self.get_rel_fname(abs_path)
except Exception:
rel = str(abs_path)
parts = Path(rel).parts
if ".meta" in parts or ".docs" in parts:
return False
if ".min." in parts[-1]:
return False
if self.repo.git_ignored_file(abs_path):
return False
return True

all_abs_files = {p for p in all_abs_files if _include_in_map(p)}
repo_abs_read_only_fnames = set(self.abs_read_only_fnames) & all_abs_files
repo_abs_read_only_stubs_fnames = set(self.abs_read_only_stubs_fnames) & all_abs_files
chat_files = (
set(self.abs_fnames) | repo_abs_read_only_fnames | repo_abs_read_only_stubs_fnames
)
other_files = all_abs_files - chat_files

self.data_cache["repo"].update(
{
"chat_files": chat_files,
"other_files": other_files,
"mentioned_fnames": mentioned_fnames,
"all_abs_files": all_abs_files,
"read_only_count": len(set(self.abs_read_only_fnames)) + len(
set(self.abs_read_only_stubs_fnames)
),
}
)

repo_content = self.repo_map.get_repo_map(
chat_files,
other_files,
mentioned_fnames=mentioned_fnames,
mentioned_idents=mentioned_idents,
self.data_cache["repo"]["chat_files"],
self.data_cache["repo"]["other_files"],
mentioned_fnames=self.data_cache["repo"]["mentioned_fnames"],
mentioned_idents=self.data_cache["repo"]["mentioned_idents"],
force_refresh=force_refresh,
)

# fall back to global repo map if files in chat are disjoint from rest of repo
if not repo_content:
repo_content = self.repo_map.get_repo_map(
set(),
all_abs_files,
mentioned_fnames=mentioned_fnames,
mentioned_idents=mentioned_idents,
self.data_cache["repo"]["all_abs_files"],
mentioned_fnames=self.data_cache["repo"]["mentioned_fnames"],
mentioned_idents=self.data_cache["repo"]["mentioned_idents"],
)

# fall back to completely unhinted repo
if not repo_content:
repo_content = self.repo_map.get_repo_map(
set(),
all_abs_files,
self.data_cache["repo"]["all_abs_files"],
)

self.io.update_spinner(self.io.last_spinner_text)
Expand Down Expand Up @@ -1085,7 +1109,7 @@ async def _run_linear(self, with_message=None, preproc=True):

user_message = None
await self.io.cancel_input_task()
await self.io.cancel_processing_task()
await self.io.cancel_output_task()

while True:
try:
Expand All @@ -1101,11 +1125,9 @@ async def _run_linear(self, with_message=None, preproc=True):
await self.io.input_task
user_message = self.io.input_task.result()

self.io.processing_task = asyncio.create_task(
self._processing_logic(user_message, preproc)
)
self.io.output_task = asyncio.create_task(self._generate(user_message, preproc))

await self.io.processing_task
await self.io.output_task

self.io.ring_bell()
user_message = None
Expand All @@ -1114,8 +1136,8 @@ async def _run_linear(self, with_message=None, preproc=True):
self.io.set_placeholder("")
await self.io.cancel_input_task()

if self.io.processing_task:
await self.io.cancel_processing_task()
if self.io.output_task:
await self.io.cancel_output_task()
self.io.stop_spinner()

self.keyboard_interrupt()
Expand All @@ -1127,7 +1149,7 @@ async def _run_linear(self, with_message=None, preproc=True):
return
finally:
await self.io.cancel_input_task()
await self.io.cancel_processing_task()
await self.io.cancel_output_task()

async def _run_patched(self, with_message=None, preproc=True):
try:
Expand All @@ -1139,7 +1161,7 @@ async def _run_patched(self, with_message=None, preproc=True):
user_message = None
self.user_message = ""
await self.io.cancel_input_task()
await self.io.cancel_processing_task()
await self.io.cancel_output_task()

while True:
try:
Expand All @@ -1151,7 +1173,7 @@ async def _run_patched(self, with_message=None, preproc=True):
or self.io.input_task.done()
or self.io.input_task.cancelled()
)
and (not self.io.processing_task or not self.io.placeholder)
and (not self.io.output_task or not self.io.placeholder)
):
if not self.suppress_announcements_for_next_prompt:
self.show_announcements()
Expand All @@ -1163,8 +1185,8 @@ async def _run_patched(self, with_message=None, preproc=True):
await self.io.recreate_input()

if self.user_message:
self.io.processing_task = asyncio.create_task(
self._processing_logic(self.user_message, preproc)
self.io.output_task = asyncio.create_task(
self._generate(self.user_message, preproc)
)

self.user_message = ""
Expand All @@ -1177,17 +1199,14 @@ async def _run_patched(self, with_message=None, preproc=True):

tasks = set()

if self.io.processing_task:
if self.io.processing_task.done():
exception = self.io.processing_task.exception()
if self.io.output_task:
if self.io.output_task.done():
exception = self.io.output_task.exception()
if exception:
if isinstance(exception, SwitchCoder):
await self.io.processing_task
elif (
not self.io.processing_task.done()
and not self.io.processing_task.cancelled()
):
tasks.add(self.io.processing_task)
await self.io.output_task
elif not self.io.output_task.done() and not self.io.output_task.cancelled():
tasks.add(self.io.output_task)

if (
self.io.input_task
Expand All @@ -1202,9 +1221,9 @@ async def _run_patched(self, with_message=None, preproc=True):
)

if self.io.input_task and self.io.input_task in done:
if self.io.processing_task:
if self.io.output_task:
if not self.io.confirmation_in_progress:
await self.io.cancel_processing_task()
await self.io.cancel_output_task()
self.io.stop_spinner()

try:
Expand All @@ -1222,10 +1241,10 @@ async def _run_patched(self, with_message=None, preproc=True):
await self.io.cancel_input_task()
continue

if self.io.processing_task and self.io.processing_task in pending:
if self.io.output_task and self.io.output_task in pending:
try:
tasks = set()
tasks.add(self.io.processing_task)
tasks.add(self.io.output_task)

# We just did a confirmation so add a new input task
if self.io.get_confirmation_acknowledgement():
Expand All @@ -1241,7 +1260,7 @@ async def _run_patched(self, with_message=None, preproc=True):
and self.io.input_task in done
and not self.io.confirmation_in_progress
):
await self.io.cancel_processing_task()
await self.io.cancel_output_task()
self.io.stop_spinner()
self.io.acknowledge_confirmation()

Expand All @@ -1263,24 +1282,22 @@ async def _run_patched(self, with_message=None, preproc=True):
self.io.ring_bell()
user_message = None
except KeyboardInterrupt:
if self.io.input_task:
self.io.set_placeholder("")
await self.io.cancel_input_task()
self.io.set_placeholder("")

if self.io.processing_task:
await self.io.cancel_processing_task()
self.io.stop_spinner()
await self.io.cancel_input_task()
await self.io.cancel_output_task()

self.io.stop_spinner()
self.keyboard_interrupt()

self.auto_save_session()
except EOFError:
return
finally:
await self.io.cancel_input_task()
await self.io.cancel_processing_task()
await self.io.cancel_output_task()

async def _processing_logic(self, user_message, preproc):
async def _generate(self, user_message, preproc):
await asyncio.sleep(0.1)

try:
Expand Down Expand Up @@ -2729,6 +2746,7 @@ async def check_for_file_mentions(self, content):
if await self.io.confirm_ask(
"Add file to the chat?", subject=rel_fname, group=group, allow_never=True
):
await self.io.recreate_input()
self.add_rel_fname(rel_fname)
added_fnames.append(rel_fname)
else:
Expand Down Expand Up @@ -3215,6 +3233,13 @@ def is_file_safe(self, fname):
return

def get_all_relative_files(self):
staged_files_hash = hash(str([item.a_path for item in self.repo.repo.index.diff("HEAD")]))
if (
staged_files_hash == self.data_cache["repo"]["last_key"]
and self.data_cache["relative_files"]
):
return self.data_cache["relative_files"]

if self.repo:
files = self.repo.get_tracked_files()
else:
Expand All @@ -3223,7 +3248,9 @@ def get_all_relative_files(self):
# This is quite slow in large repos
# files = [fname for fname in files if self.is_file_safe(fname)]

return sorted(set(files))
self.data_cache["relative_files"] = sorted(set(files))

return self.data_cache["relative_files"]

def get_all_abs_files(self):
files = self.get_all_relative_files()
Expand Down
9 changes: 8 additions & 1 deletion aider/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,7 +1248,14 @@ async def cmd_exit(self, args):
pass

await asyncio.sleep(0)
sys.exit()

try:
if self.coder.args.linear_output:
os._exit(0)
else:
sys.exit()
except Exception:
sys.exit()

def cmd_quit(self, args):
"Exit the application"
Expand Down
15 changes: 8 additions & 7 deletions aider/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def __init__(
# Variables used to interface with base_coder
self.coder = None
self.input_task = None
self.processing_task = None
self.output_task = None

# State tracking for confirmation input
self.confirmation_in_progress = False
Expand Down Expand Up @@ -886,6 +886,7 @@ def get_continuation(width, line_number, is_soft_wrap):
except EOFError:
raise
except KeyboardInterrupt:
await self.cancel_output_task()
self.console.print()
return ""
except UnicodeEncodeError as err:
Expand Down Expand Up @@ -961,13 +962,13 @@ async def cancel_input_task(self):
except (asyncio.CancelledError, EOFError, IndexError):
pass

async def cancel_processing_task(self):
if self.processing_task:
processing_task = self.processing_task
self.processing_task = None
async def cancel_output_task(self):
if self.output_task:
output_task = self.output_task
self.output_task = None
try:
processing_task.cancel()
await processing_task
output_task.cancel()
await output_task
except (asyncio.CancelledError, EOFError, IndexError):
pass

Expand Down
Loading
Loading