diff --git a/aider/helpers/model_providers.py b/aider/helpers/model_providers.py index fe9c58b4054..a1052d92ad4 100644 --- a/aider/helpers/model_providers.py +++ b/aider/helpers/model_providers.py @@ -60,7 +60,7 @@ def _first_env_value(names): return None -class _JSONOpenAIProvider(CustomLLM if CustomLLM is not None else object): # type: ignore[misc] +class _JSONOpenAIProvider(OpenAILikeChatHandler): # type: ignore[misc] """CustomLLM wrapper that routes OpenAI-compatible providers through LiteLLM.""" def __init__(self, slug: str, config: Dict): @@ -69,7 +69,6 @@ def __init__(self, slug: str, config: Dict): super().__init__() # type: ignore[misc] self.slug = slug self.config = config - self._chat_handler = OpenAILikeChatHandler() def _resolve_api_base(self, api_base: Optional[str]) -> str: base = ( @@ -124,212 +123,109 @@ def _build_request_params(self, optional_params, stream: bool): params["stream"] = True return params - def _invoke_handler( - self, - *, - model, - messages, - api_base, - custom_prompt_dict, - model_response, - print_verbose, - encoding, - api_key, - logging_obj, - optional_params, - litellm_params, - logger_fn, - headers, - timeout, - client, - stream: bool, - ): - api_base = self._resolve_api_base(api_base) - api_key = self._resolve_api_key(api_key) - headers = self._inject_headers(headers) - params = self._build_request_params(optional_params, stream) - cleaned_messages = self._apply_special_handling(messages) - api_model = self._normalize_model_name(model) - http_client = None - if HTTPHandler is not None and isinstance(client, HTTPHandler): - http_client = client - return self._chat_handler.completion( - model=api_model, - messages=cleaned_messages, - api_base=api_base, - custom_llm_provider="openai", - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - api_key=api_key, - logging_obj=logging_obj, - optional_params=params, - litellm_params=litellm_params or {}, - 
logger_fn=logger_fn, - headers=headers, - timeout=timeout, - client=http_client, + def completion(self, *args, **kwargs): + kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None)) + kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None)) + kwargs["headers"] = self._inject_headers(kwargs.get("headers", None)) + kwargs["optional_params"] = self._build_request_params( + kwargs.get("optional_params", None), False ) - - def completion( - self, - model, - messages, - api_base, - custom_prompt_dict, - model_response, - print_verbose, - encoding, - api_key, - logging_obj, - optional_params, - litellm_params=None, - acompletion=None, - logger_fn=None, - headers=None, - timeout=None, - client=None, - ): - return self._invoke_handler( - model=model, - messages=messages, - api_base=api_base, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - api_key=api_key, - logging_obj=logging_obj, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=headers, - timeout=timeout, - client=client, - stream=False, + kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", [])) + kwargs["model"] = self._normalize_model_name(kwargs.get("model", None)) + kwargs["custom_llm_provider"] = "openai" + + return super().completion(*args, **kwargs) + + async def acompletion(self, *args, **kwargs): + kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None)) + kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None)) + kwargs["headers"] = self._inject_headers(kwargs.get("headers", None)) + kwargs["optional_params"] = self._build_request_params( + kwargs.get("optional_params", None), False ) - - def streaming( - self, - model, - messages, - api_base, - custom_prompt_dict, - model_response, - print_verbose, - encoding, - api_key, - logging_obj, - optional_params, - litellm_params=None, - acompletion=None, - 
logger_fn=None, - headers=None, - timeout=None, - client=None, - ): - return self._invoke_handler( - model=model, - messages=messages, - api_base=api_base, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - api_key=api_key, - logging_obj=logging_obj, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=headers, - timeout=timeout, - client=client, - stream=True, + kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", [])) + kwargs["model"] = self._normalize_model_name(kwargs.get("model", None)) + kwargs["custom_llm_provider"] = "openai" + kwargs["acompletion"] = True + + return await super().completion(*args, **kwargs) + + def streaming(self, *args, **kwargs): + kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None)) + kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None)) + kwargs["headers"] = self._inject_headers(kwargs.get("headers", None)) + kwargs["optional_params"] = self._build_request_params( + kwargs.get("optional_params", None), True ) + kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", [])) + kwargs["model"] = self._normalize_model_name(kwargs.get("model", None)) + kwargs["custom_llm_provider"] = "openai" - def acompletion( - self, - model, - messages, - api_base, - custom_prompt_dict, - model_response, - print_verbose, - encoding, - api_key, - logging_obj, - optional_params, - litellm_params=None, - acompletion=None, - logger_fn=None, - headers=None, - timeout=None, - client=None, - ): - return self.completion( - model=model, - messages=messages, - api_base=api_base, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - api_key=api_key, - logging_obj=logging_obj, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=headers, - 
timeout=timeout, - client=client, - ) + response = super().completion(*args, **kwargs) + + for chunk in response: + yield self.get_generic_chunk(chunk) - def astreaming( - self, - model, - messages, - api_base, - custom_prompt_dict, - model_response, - print_verbose, - encoding, - api_key, - logging_obj, - optional_params, - litellm_params=None, - acompletion=None, - logger_fn=None, - headers=None, - timeout=None, - client=None, - ): - return self.streaming( - model=model, - messages=messages, - api_base=api_base, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - api_key=api_key, - logging_obj=logging_obj, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=headers, - timeout=timeout, - client=client, + async def astreaming(self, *args, **kwargs): + kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None)) + kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None)) + kwargs["headers"] = self._inject_headers(kwargs.get("headers", None)) + kwargs["optional_params"] = self._build_request_params( + kwargs.get("optional_params", None), True ) + kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", [])) + kwargs["model"] = self._normalize_model_name(kwargs.get("model", None)) + kwargs["custom_llm_provider"] = "openai" + kwargs["acompletion"] = True + + response = await super().completion(*args, **kwargs) + + async for chunk in response: + yield self.get_generic_chunk(chunk) + + def get_generic_chunk(self, chunk): + # Extract the first choice (standard for single-candidate streams) + choice = chunk.choices[0] if chunk.choices else None + delta = choice.delta if choice else None + + # Safe extraction of text (content can be None in tool-call chunks) + text_content = delta.content if delta and delta.content else "" + + # Safe extraction of tool calls + # LiteLLM provides these as a list of objects, we 
pass them through + # or set to None if empty + tool_calls = delta.tool_calls if delta and delta.tool_calls else None + + if tool_calls and len(tool_calls): + tool_calls = tool_calls[0] + + # Handle Usage (often only present in the final chunk) + usage_data = getattr(chunk, "usage", None) + # If usage is a Pydantic object, dump it to dict; otherwise default to 0s + if hasattr(usage_data, "model_dump"): + usage_dict = usage_data.model_dump() + elif isinstance(usage_data, dict): + usage_dict = usage_data + else: + usage_dict = {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0} + + # 3. Construct the GenericStreamingChunk dictionary + generic_chunk = { + "finish_reason": choice.finish_reason if choice else None, + "index": choice.index if choice else 0, + "is_finished": bool(choice.finish_reason) if choice else False, + "text": text_content, + "tool_use": tool_calls, + "usage": usage_dict, + } + + return generic_chunk def _register_provider_with_litellm(slug: str, config: Dict) -> None: """Register provider metadata and custom handlers with LiteLLM.""" try: - from litellm.llms.openai_like.json_loader import ( - JSONProviderRegistry, - SimpleProviderConfig, - ) + from litellm.llms.openai_like.json_loader import JSONProviderRegistry except Exception: return @@ -340,40 +236,13 @@ def _register_provider_with_litellm(slug: str, config: Dict) -> None: if not base_url or not api_key_env: return - if not JSONProviderRegistry.exists(slug): - payload = { - "base_url": base_url, - "api_key_env": api_key_env, - } - - api_base_env = _coerce_str(config.get("base_url_env")) - if api_base_env: - payload["api_base_env"] = api_base_env - - if config.get("param_mappings"): - payload["param_mappings"] = config["param_mappings"] - if config.get("special_handling"): - payload["special_handling"] = config["special_handling"] - if config.get("base_class"): - payload["base_class"] = config["base_class"] - - JSONProviderRegistry._providers[slug] = SimpleProviderConfig(slug, 
payload) - try: import litellm # noqa: WPS433 except Exception: return - provider_list = getattr(litellm, "provider_list", None) - if isinstance(provider_list, list) and slug not in provider_list: - provider_list.append(slug) - - openai_like = getattr(litellm, "_openai_like_providers", None) - if isinstance(openai_like, list) and slug not in openai_like: - openai_like.append(slug) - handler = _CUSTOM_HANDLERS.get(slug) - if handler is None and CustomLLM is not None and OpenAILikeChatHandler is not None: + if handler is None: handler = _JSONOpenAIProvider(slug, config) _CUSTOM_HANDLERS[slug] = handler diff --git a/aider/main.py b/aider/main.py index 2e8891db99a..1f4b8386a92 100644 --- a/aider/main.py +++ b/aider/main.py @@ -523,6 +523,10 @@ async def sanity_check_repo(repo, io): "auto_save_session": True, "input_task": True, "output_task": True, + "check_output_queue": True, + "_animate_spinner": True, + "handle_output_message": True, + "update_spinner": True, } diff --git a/aider/prompts/agent.yml b/aider/prompts/agent.yml index 89b3cb9027b..7429525e4be 100644 --- a/aider/prompts/agent.yml +++ b/aider/prompts/agent.yml @@ -1,5 +1,6 @@ # Agent prompts - inherits from base.yaml # Overrides specific prompts +_inherits: [base] files_content_assistant_reply: | I understand. I'll use these files to help with your request. diff --git a/aider/prompts/architect.yml b/aider/prompts/architect.yml index be8ec65c692..0dcce2db588 100644 --- a/aider/prompts/architect.yml +++ b/aider/prompts/architect.yml @@ -1,5 +1,6 @@ # Architect prompts - inherits from base.yaml # Overrides specific prompts +_inherits: [base] files_content_assistant_reply: | Ok, I will use that as the true, current contents of the files. 
diff --git a/aider/prompts/ask.yml b/aider/prompts/ask.yml index 947b3d7a960..69b0daa7407 100644 --- a/aider/prompts/ask.yml +++ b/aider/prompts/ask.yml @@ -1,5 +1,6 @@ # Ask prompts - inherits from base.yaml # Overrides specific prompts +_inherits: [base] files_content_assistant_reply: | Ok, I will use that as the true, current contents of the files. diff --git a/aider/prompts/base.yml b/aider/prompts/base.yml index 06fbe5fc470..f8c8353f063 100644 --- a/aider/prompts/base.yml +++ b/aider/prompts/base.yml @@ -1,5 +1,6 @@ # Base prompts for all coder types # This file contains default prompts that can be overridden by specific YAML files +_inherits: [] system_reminder: "" @@ -48,6 +49,8 @@ files_no_full_files_with_repo_map_reply: | Ok, based on your requests I will suggest which files need to be edited and then stop and wait for your approval. +main_system: "" + repo_content_prefix: | Here are summaries of some files present in my git repository. Do not propose changes to these files, treat them as *read-only*. diff --git a/aider/prompts/context.yml b/aider/prompts/context.yml index 6470b468535..d5067c9fcc1 100644 --- a/aider/prompts/context.yml +++ b/aider/prompts/context.yml @@ -1,5 +1,6 @@ # Context prompts - inherits from base.yaml # Overrides specific prompts +_inherits: [base] files_content_assistant_reply: | Ok, I will use that as the true, current contents of the files. 
diff --git a/aider/prompts/copypaste.yml b/aider/prompts/copypaste.yml index 98132569e2a..1048b5dee76 100644 --- a/aider/prompts/copypaste.yml +++ b/aider/prompts/copypaste.yml @@ -2,3 +2,4 @@ # Overrides specific prompts for copypaste format # Copypaste mode doesn't have its own prompts - it mirrors prompts from other coders # This file exists for completeness in the prompt registry +_inherits: [base] diff --git a/aider/prompts/editblock.yml b/aider/prompts/editblock.yml index de4bb5f5d37..e50147d326d 100644 --- a/aider/prompts/editblock.yml +++ b/aider/prompts/editblock.yml @@ -1,5 +1,6 @@ # EditBlock prompts - inherits from base.yaml # Overrides specific prompts for editblock format +_inherits: [base] main_system: | Act as an expert software developer. diff --git a/aider/prompts/editblock_fenced.yml b/aider/prompts/editblock_fenced.yml index e3f265e4b78..4e895d16b6e 100644 --- a/aider/prompts/editblock_fenced.yml +++ b/aider/prompts/editblock_fenced.yml @@ -1,5 +1,7 @@ # Editblock_Fenced prompts - inherits from base.yaml # Overrides specific prompts for editblock_fenced format +_inherits: [editblock, base] + example_messages: - role: user content: Change get_factorial() to use math.factorial diff --git a/aider/prompts/editblock_func.yml b/aider/prompts/editblock_func.yml index 475dffe372a..8b38308a26d 100644 --- a/aider/prompts/editblock_func.yml +++ b/aider/prompts/editblock_func.yml @@ -1,5 +1,6 @@ # Editblock_Func prompts - inherits from base.yaml # Overrides specific prompts +_inherits: [base] files_content_prefix: | Here is the current content of the files: diff --git a/aider/prompts/editor_diff_fenced.yml b/aider/prompts/editor_diff_fenced.yml index b77083f6f3a..e51b2d4b279 100644 --- a/aider/prompts/editor_diff_fenced.yml +++ b/aider/prompts/editor_diff_fenced.yml @@ -1,5 +1,6 @@ # Editor_Diff_Fenced prompts - inherits from base.yaml # Overrides specific prompts for editor_diff_fenced format +_inherits: [editblock_fenced, editblock, base] go_ahead_tip: '' 
diff --git a/aider/prompts/editor_editblock.yml b/aider/prompts/editor_editblock.yml index b3904d20ac6..f1b1c582033 100644 --- a/aider/prompts/editor_editblock.yml +++ b/aider/prompts/editor_editblock.yml @@ -1,5 +1,6 @@ # Editor_Editblock prompts - inherits from base.yaml # Overrides specific prompts +_inherits: [editblock, base] go_ahead_tip: '' diff --git a/aider/prompts/editor_whole.yml b/aider/prompts/editor_whole.yml index c285afd0d1c..90d3eb5ef9f 100644 --- a/aider/prompts/editor_whole.yml +++ b/aider/prompts/editor_whole.yml @@ -1,5 +1,6 @@ # Editor_Whole prompts - inherits from base.yaml # Overrides specific prompts +_inherits: [wholefile, base] main_system: | Act as an expert software developer tasked with editing source code based on instructions from an architect. diff --git a/aider/prompts/help.yml b/aider/prompts/help.yml index 968f5cfbfac..0597bf76aa1 100644 --- a/aider/prompts/help.yml +++ b/aider/prompts/help.yml @@ -1,5 +1,6 @@ # Help prompts - inherits from base.yaml # Overrides specific prompts +_inherits: [base] files_content_prefix: | These are some files we have been discussing that we may want to edit after you answer my questions: diff --git a/aider/prompts/patch.yml b/aider/prompts/patch.yml index c6ebcd09e17..30adb4ce509 100644 --- a/aider/prompts/patch.yml +++ b/aider/prompts/patch.yml @@ -1,5 +1,6 @@ # Patch prompts - inherits from base.yaml # Overrides specific prompts for patch format +_inherits: [editblock, base] main_system: | Act as an expert software developer. 
diff --git a/aider/prompts/single_wholefile_func.yml b/aider/prompts/single_wholefile_func.yml index 0c3a6e7aa4f..ed7cc7fe225 100644 --- a/aider/prompts/single_wholefile_func.yml +++ b/aider/prompts/single_wholefile_func.yml @@ -1,5 +1,6 @@ # Single_Wholefile_Func prompts - inherits from base.yaml # Overrides specific prompts +_inherits: [base] files_content_prefix: | Here is the current content of the file: diff --git a/aider/prompts/udiff.yml b/aider/prompts/udiff.yml index 207eb598dd3..614fb5206f0 100644 --- a/aider/prompts/udiff.yml +++ b/aider/prompts/udiff.yml @@ -1,5 +1,6 @@ # UnifiedDiff prompts - inherits from base.yaml # Overrides specific prompts for unified diff format +_inherits: [base] main_system: | Act as an expert software developer. diff --git a/aider/prompts/udiff_simple.yml b/aider/prompts/udiff_simple.yml index f7eef9be1c8..d2f1183cb77 100644 --- a/aider/prompts/udiff_simple.yml +++ b/aider/prompts/udiff_simple.yml @@ -1,5 +1,6 @@ # Udiff_Simple prompts - inherits from base.yaml # Overrides specific prompts +_inherits: [udiff, base] system_reminder: | # File editing rules: diff --git a/aider/prompts/utils/prompt_registry.py b/aider/prompts/utils/prompt_registry.py index 75173b1b9d9..a8b906e5fce 100644 --- a/aider/prompts/utils/prompt_registry.py +++ b/aider/prompts/utils/prompt_registry.py @@ -2,13 +2,16 @@ Central registry for managing all prompts in YAML format. This module implements a YAML-based prompt inheritance system where: -1. base.yml contains default prompts -2. Specific YAML files can override/extend base.yml -3. No Python prompt classes needed +1. base.yml contains default prompts with `_inherits: []` +2. Specific YAML files can override/extend using `_inherits` key +3. Inheritance chains are resolved recursively (e.g., editor_diff_fenced → editblock_fenced → editblock → base) +4. Prompts are merged in inheritance order (base → intermediate → specific) +5. The `_inherits` key is removed from final merged results +6. 
Circular dependencies are detected and prevented """ from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import yaml @@ -59,6 +62,54 @@ def _merge_prompts(self, base: Dict[str, Any], override: Dict[str, Any]) -> Dict return result + def _resolve_inheritance_chain( + self, prompt_name: str, visited: Optional[set] = None + ) -> List[str]: + """ + Resolve the full inheritance chain for a prompt type. + + Args: + prompt_name: Name of the prompt type + visited: Set of already visited prompts to detect circular dependencies + + Returns: + List of prompt names in inheritance order (from base to most specific) + """ + if visited is None: + visited = set() + + if prompt_name in visited: + raise ValueError(f"Circular dependency detected in prompt inheritance: {prompt_name}") + + visited.add(prompt_name) + + # Special case for base.yml + if prompt_name == "base": + return ["base"] + + # Load the prompt file to get its inheritance chain + prompt_path = self._prompts_dir / f"{prompt_name}.yml" + if not prompt_path.exists(): + raise FileNotFoundError(f"Prompt file not found: {prompt_path}") + + prompt_data = self._load_yaml_file(prompt_path) + inherits = prompt_data.get("_inherits", []) + + # Resolve inheritance chain recursively + inheritance_chain = [] + for parent in inherits: + parent_chain = self._resolve_inheritance_chain(parent, visited.copy()) + # Add parent chain, avoiding duplicates while preserving order + for item in parent_chain: + if item not in inheritance_chain: + inheritance_chain.append(item) + + # Add current prompt to the end of the chain + if prompt_name not in inheritance_chain: + inheritance_chain.append(prompt_name) + + return inheritance_chain + def get_prompt(self, prompt_name: str) -> Dict[str, Any]: """ Get prompts for a specific prompt type. 
@@ -73,15 +124,25 @@ def get_prompt(self, prompt_name: str) -> Dict[str, Any]: if prompt_name in self._prompts_cache: return self._prompts_cache[prompt_name] - # Load base prompts - base_prompts = self._get_base_prompts() + # Resolve inheritance chain + inheritance_chain = self._resolve_inheritance_chain(prompt_name) - # Load specific prompt file if it exists - prompt_path = self._prompts_dir / f"{prompt_name}.yml" - specific_prompts = self._load_yaml_file(prompt_path) + # Start with empty dict and merge in inheritance order + merged_prompts: Dict[str, Any] = {} + + for current_name in inheritance_chain: + # Load prompts for this level + if current_name == "base": + current_prompts = self._get_base_prompts() + else: + prompt_path = self._prompts_dir / f"{current_name}.yml" + current_prompts = self._load_yaml_file(prompt_path) + + # Merge current prompts into accumulated result + merged_prompts = self._merge_prompts(merged_prompts, current_prompts) - # Merge base with specific overrides - merged_prompts = self._merge_prompts(base_prompts, specific_prompts) + # Remove _inherits key from final result (it's metadata, not a prompt) + merged_prompts.pop("_inherits", None) # Cache the result self._prompts_cache[prompt_name] = merged_prompts diff --git a/aider/prompts/wholefile.yml b/aider/prompts/wholefile.yml index 718ade72d13..bdfe408a44c 100644 --- a/aider/prompts/wholefile.yml +++ b/aider/prompts/wholefile.yml @@ -1,5 +1,6 @@ # WholeFile prompts - inherits from base.yaml # Overrides specific prompts for wholefile format +_inherits: [base] main_system: | Act as an expert software developer. 
diff --git a/aider/prompts/wholefile_func.yml b/aider/prompts/wholefile_func.yml index 743e3f99ed6..eb2f25b4e12 100644 --- a/aider/prompts/wholefile_func.yml +++ b/aider/prompts/wholefile_func.yml @@ -1,5 +1,6 @@ # Wholefile_Func prompts - inherits from base.yaml # Overrides specific prompts +_inherits: [base] files_content_prefix: | Here is the current content of the files: diff --git a/aider/resources/providers.json b/aider/resources/providers.json index 7c022a21095..66171d1a23e 100644 --- a/aider/resources/providers.json +++ b/aider/resources/providers.json @@ -1,27 +1,4 @@ { - "openrouter": { - "api_base": "https://openrouter.ai/api/v1", - "models_url": "https://openrouter.ai/api/v1/models", - "api_key_env": [ - "OPENROUTER_API_KEY" - ], - "requires_api_key": false, - "default_headers": { - "HTTP-Referer": "https://aider.chat", - "X-Title": "aider" - } - }, - "openai": { - "api_base": "https://api.openai.com/v1", - "models_url": "https://api.openai.com/v1/models", - "api_key_env": [ - "OPENAI_API_KEY" - ], - "base_url_env": [ - "OPENAI_API_BASE" - ], - "display_name": "openai" - }, "apertis": { "api_base": "https://api.stima.tech/v1", "api_key_env": [ diff --git a/tests/basic/test_prompts.py b/tests/basic/test_prompts.py new file mode 100644 index 00000000000..6d3950dedb6 --- /dev/null +++ b/tests/basic/test_prompts.py @@ -0,0 +1,424 @@ +""" +Tests for the prompt inheritance system and prompt registry. + +This module tests the YAML-based prompt inheritance system where: +1. base.yml contains default prompts with `_inherits: []` +2. Specific YAML files can override/extend using `_inherits` key +3. Inheritance chains are resolved recursively +4. Prompts are merged in inheritance order (base → intermediate → specific) +5. The `_inherits` key is removed from final merged results +6. 
Circular dependencies are detected and prevented +""" + +import os +import tempfile +from pathlib import Path + +import pytest +import yaml + +from aider.prompts.utils.prompt_registry import PromptRegistry + + +class TestPromptRegistry: + """Test suite for PromptRegistry class.""" + + def setup_method(self): + """Set up test fixtures.""" + # Create a fresh instance for each test + self.registry = PromptRegistry.__new__(PromptRegistry) + self.registry._prompts_dir = Path(__file__).parent / "../../aider/prompts" + self.registry._initialized = True + self.registry._prompts_cache = {} + self.registry._base_prompts = None + + def test_singleton_pattern(self): + """Test that PromptRegistry follows singleton pattern.""" + registry1 = PromptRegistry() + registry2 = PromptRegistry() + assert registry1 is registry2, "PromptRegistry should be a singleton" + + def test_get_base_prompts(self): + """Test loading base prompts.""" + base_prompts = self.registry._get_base_prompts() + assert isinstance(base_prompts, dict) + assert "_inherits" in base_prompts + assert base_prompts["_inherits"] == [] + assert "system_reminder" in base_prompts + + def test_load_yaml_file_valid(self): + """Test loading a valid YAML file.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f: + yaml.dump({"test_key": "test_value", "nested": {"key": "value"}}, f) + temp_path = f.name + + try: + result = self.registry._load_yaml_file(Path(temp_path)) + assert result == {"test_key": "test_value", "nested": {"key": "value"}} + finally: + os.unlink(temp_path) + + def test_load_yaml_file_not_found(self): + """Test loading a non-existent YAML file returns empty dict.""" + result = self.registry._load_yaml_file(Path("/nonexistent/path/file.yml")) + assert result == {} + + def test_load_yaml_file_invalid_yaml(self): + """Test loading an invalid YAML file raises ValueError.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as f: + f.write("invalid: yaml: : :") 
+ temp_path = f.name + + try: + with pytest.raises(ValueError, match="Error parsing YAML file"): + self.registry._load_yaml_file(Path(temp_path)) + finally: + os.unlink(temp_path) + + def test_merge_prompts_simple(self): + """Test simple dictionary merging.""" + base = {"key1": "value1", "key2": "value2"} + override = {"key2": "new_value2", "key3": "value3"} + result = self.registry._merge_prompts(base, override) + expected = {"key1": "value1", "key2": "new_value2", "key3": "value3"} + assert result == expected + + def test_merge_prompts_nested(self): + """Test nested dictionary merging.""" + base = {"key1": "value1", "nested": {"a": 1, "b": 2}} + override = {"nested": {"b": 20, "c": 30}, "key2": "value2"} + result = self.registry._merge_prompts(base, override) + expected = {"key1": "value1", "nested": {"a": 1, "b": 20, "c": 30}, "key2": "value2"} + assert result == expected + + def test_merge_prompts_deep_nested(self): + """Test deeply nested dictionary merging.""" + base = {"a": {"b": {"c": {"d": 1, "e": 2}}}} + override = {"a": {"b": {"c": {"e": 20, "f": 30}}}} + result = self.registry._merge_prompts(base, override) + expected = {"a": {"b": {"c": {"d": 1, "e": 20, "f": 30}}}} + assert result == expected + + def test_resolve_inheritance_chain_base(self): + """Test inheritance chain resolution for base.yml.""" + chain = self.registry._resolve_inheritance_chain("base") + assert chain == ["base"] + + def test_resolve_inheritance_chain_simple(self): + """Test inheritance chain resolution for a simple prompt.""" + # Create a temporary directory with test YAML files + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create base.yml + base_path = temp_path / "base.yml" + with open(base_path, "w") as f: + yaml.dump({"_inherits": []}, f) + + # Create simple.yml that inherits from base + simple_path = temp_path / "simple.yml" + with open(simple_path, "w") as f: + yaml.dump({"_inherits": ["base"]}, f) + + # Create a test registry with our 
temp directory + test_registry = PromptRegistry.__new__(PromptRegistry) + test_registry._prompts_dir = temp_path + test_registry._initialized = True + test_registry._prompts_cache = {} + test_registry._base_prompts = None + + chain = test_registry._resolve_inheritance_chain("simple") + assert chain == ["base", "simple"] + + def test_resolve_inheritance_chain_complex(self): + """Test inheritance chain resolution for a complex prompt.""" + # Create a temporary directory with test YAML files + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create base.yml + base_path = temp_path / "base.yml" + with open(base_path, "w") as f: + yaml.dump({"_inherits": []}, f) + + # Create editblock.yml that inherits from base + editblock_path = temp_path / "editblock.yml" + with open(editblock_path, "w") as f: + yaml.dump({"_inherits": ["base"]}, f) + + # Create editblock_fenced.yml that inherits from editblock and base + editblock_fenced_path = temp_path / "editblock_fenced.yml" + with open(editblock_fenced_path, "w") as f: + yaml.dump({"_inherits": ["editblock", "base"]}, f) + + # Create a test registry with our temp directory + test_registry = PromptRegistry.__new__(PromptRegistry) + test_registry._prompts_dir = temp_path + test_registry._initialized = True + test_registry._prompts_cache = {} + test_registry._base_prompts = None + + chain = test_registry._resolve_inheritance_chain("editblock_fenced") + assert chain == ["base", "editblock", "editblock_fenced"] + + def test_resolve_inheritance_chain_circular_dependency(self): + """Test detection of circular dependencies.""" + # Create a temporary directory with circular YAML files + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create a.yml that inherits from b.yml + a_path = temp_path / "a.yml" + with open(a_path, "w") as f: + yaml.dump({"_inherits": ["b"]}, f) + + # Create b.yml that inherits from a.yml (circular!) 
+ b_path = temp_path / "b.yml" + with open(b_path, "w") as f: + yaml.dump({"_inherits": ["a"]}, f) + + # Create a test registry with our temp directory + test_registry = PromptRegistry.__new__(PromptRegistry) + test_registry._prompts_dir = temp_path + test_registry._initialized = True + test_registry._prompts_cache = {} + test_registry._base_prompts = None + + # Should detect circular dependency + with pytest.raises(ValueError, match="Circular dependency detected"): + test_registry._resolve_inheritance_chain("a") + + def test_resolve_inheritance_chain_file_not_found(self): + """Test error when prompt file doesn't exist.""" + with pytest.raises(FileNotFoundError, match="Prompt file not found"): + self.registry._resolve_inheritance_chain("nonexistent") + + def test_get_prompt_base(self): + """Test getting base prompts.""" + prompts = self.registry.get_prompt("base") + assert isinstance(prompts, dict) + assert "_inherits" not in prompts # Should be removed + assert "system_reminder" in prompts + # Base has empty system_reminder + assert prompts["system_reminder"] == "" + + def test_get_prompt_editblock(self): + """Test getting editblock prompts.""" + prompts = self.registry.get_prompt("editblock") + assert isinstance(prompts, dict) + assert "_inherits" not in prompts # Should be removed + assert "main_system" in prompts + assert "system_reminder" in prompts + assert "Act as an expert software developer" in prompts["main_system"] + + def test_get_prompt_patch(self): + """Test getting patch prompts (inherits from editblock).""" + prompts = self.registry.get_prompt("patch") + assert isinstance(prompts, dict) + assert "_inherits" not in prompts # Should be removed + assert "main_system" in prompts + assert "example_messages" in prompts + # Patch should have its own system_reminder that overrides editblock's + assert "V4A Diff Format" in prompts["system_reminder"] + + def test_get_prompt_caching(self): + """Test that prompts are cached.""" + # Clear cache + 
        # NOTE(review): this chunk begins mid-method — the `def` line of this
        # caching test (and its enclosing class header) precede the visible
        # region. The statements below verify that get_prompt() results are
        # memoized in registry._prompts_cache.
        self.registry.reload_prompts()
        assert len(self.registry._prompts_cache) == 0

        # First call should populate cache
        prompts1 = self.registry.get_prompt("editblock")
        assert len(self.registry._prompts_cache) == 1

        # Second call should use cache
        prompts2 = self.registry.get_prompt("editblock")
        assert len(self.registry._prompts_cache) == 1
        assert prompts1 is prompts2  # Same object from cache

    def test_get_prompt_removes_inherits_key(self):
        """Test that _inherits key is removed from final prompts."""
        # Test with a few different prompt types
        for prompt_name in ["base", "editblock", "patch", "editor_diff_fenced"]:
            prompts = self.registry.get_prompt(prompt_name)
            assert "_inherits" not in prompts, f"_inherits key found in {prompt_name}"

    def test_reload_prompts(self):
        """Test that reload_prompts clears cache."""
        # Populate cache
        self.registry.get_prompt("editblock")
        self.registry.get_prompt("patch")
        assert len(self.registry._prompts_cache) == 2

        # Reload should clear cache
        self.registry.reload_prompts()
        assert len(self.registry._prompts_cache) == 0
        # reload also discards the lazily-loaded base prompts
        assert self.registry._base_prompts is None

    def test_list_available_prompts(self):
        """Test listing available prompts."""
        prompts = self.registry.list_available_prompts()
        assert isinstance(prompts, list)
        assert len(prompts) > 0
        assert "editblock" in prompts
        assert "patch" in prompts
        assert "base" not in prompts  # base.yml should be excluded
        assert all(isinstance(p, str) for p in prompts)

    def test_inheritance_chain_real_example(self):
        """Test a real inheritance chain from the actual YAML files."""
        # Test editor_diff_fenced which has a deep inheritance chain
        chain = self.registry._resolve_inheritance_chain("editor_diff_fenced")
        expected_chain = ["base", "editblock", "editblock_fenced", "editor_diff_fenced"]
        assert chain == expected_chain, f"Expected {expected_chain}, got {chain}"

        # Get the prompts and verify they have expected content
        prompts = self.registry.get_prompt("editor_diff_fenced")
        assert "main_system" in prompts
        assert "system_reminder" in prompts
        assert "go_ahead_tip" in prompts
        assert prompts["go_ahead_tip"] == ""  # editor_diff_fenced overrides this to empty string

    def test_all_prompts_loadable(self):
        """Test that all available prompts can be loaded without errors."""
        prompt_names = self.registry.list_available_prompts()

        for name in prompt_names:
            try:
                prompts = self.registry.get_prompt(name)
                assert isinstance(prompts, dict)
                # Some prompts might be minimal (like copypaste)
                if name != "copypaste":
                    assert len(prompts) > 0, f"Prompt '{name}' is empty"
            except Exception as e:
                # Surface the failing prompt name rather than a bare traceback.
                pytest.fail(f"Failed to load prompt '{name}': {e}")

    def test_prompt_override_behavior(self):
        """Test that prompt overrides work correctly in inheritance chain."""
        # Get editblock prompts
        editblock_prompts = self.registry.get_prompt("editblock")

        # Get patch prompts (inherits from editblock)
        patch_prompts = self.registry.get_prompt("patch")

        # Patch should have different system_reminder than editblock
        assert editblock_prompts["system_reminder"] != patch_prompts["system_reminder"]

        # But they should share some common fields from base
        assert "files_content_prefix" in editblock_prompts
        assert "files_content_prefix" in patch_prompts
        # The files_content_prefix should be the same (inherited from base)
        assert editblock_prompts["files_content_prefix"] == patch_prompts["files_content_prefix"]


class TestPromptInheritanceIntegration:
    """Integration tests for the prompt inheritance system."""

    def setup_method(self):
        """Set up test fixtures."""
        # Fresh registry per test; reload clears any cross-test cache state.
        self.registry = PromptRegistry()
        self.registry.reload_prompts()

    def test_complete_inheritance_workflow(self):
        """Test complete workflow from YAML files to merged prompts."""
        # Test a prompt with deep inheritance
        prompts = self.registry.get_prompt("editor_diff_fenced")

        # Verify it has content from all levels of inheritance
        assert "main_system" in prompts  # From editblock
        assert "example_messages" in prompts  # From editblock_fenced
        assert "go_ahead_tip" in prompts  # From editor_diff_fenced (overridden to empty string)
        assert "system_reminder" in prompts  # From editblock
        assert "files_content_prefix" in prompts  # From base

        # Verify specific overrides
        assert prompts["go_ahead_tip"] == ""  # editor_diff_fenced overrides this

    def test_yaml_structure_preserved(self):
        """Test that YAML structure (lists, multiline strings) is preserved."""
        # Get editblock prompts which have example_messages list
        prompts = self.registry.get_prompt("editblock")

        assert "example_messages" in prompts
        example_messages = prompts["example_messages"]
        assert isinstance(example_messages, list)
        assert len(example_messages) > 0

        # Check structure of first example message
        first_msg = example_messages[0]
        assert isinstance(first_msg, dict)
        assert "role" in first_msg
        assert "content" in first_msg

        # Check multiline strings are preserved
        assert "\n" in prompts["main_system"]  # Should have newlines


# NOTE(review): this __main__ guard sits BEFORE the last test class instead of
# at the end of the file. pytest.main([__file__]) still collects the whole
# module, so behavior is unaffected, but consider moving it to the bottom.
if __name__ == "__main__":
    # Run tests if executed directly
    pytest.main([__file__, "-v"])


class TestPromptInheritanceChains:
    """Test that all prompt inheritance chains are valid and match expected structure."""

    def setup_method(self):
        """Set up test fixtures."""
        self.registry = PromptRegistry()
        self.registry.reload_prompts()

    def test_all_inheritance_chains_resolvable(self):
        """Test that all inheritance chains can be resolved without errors."""
        prompt_names = self.registry.list_available_prompts()

        for name in prompt_names:
            try:
                chain = self.registry._resolve_inheritance_chain(name)
                assert isinstance(chain, list)
                assert len(chain) > 0
                assert "base" in chain, f"Prompt '{name}' should inherit from base"
                assert chain[-1] == name, f"Last item in chain should be '{name}'"
            except Exception as e:
                pytest.fail(f"Failed to resolve inheritance chain for '{name}': {e}")

    def test_expected_inheritance_chains(self):
        """Test specific inheritance chains that we expect to exist."""
        # NOTE(review): the "base" entry below is dead — the loop skips it via
        # `continue`, and base's ["base"] chain is never actually asserted here.
        expected_chains = {
            "base": ["base"],
            "editblock": ["base", "editblock"],
            "editblock_fenced": ["base", "editblock", "editblock_fenced"],
            "editor_diff_fenced": ["base", "editblock", "editblock_fenced", "editor_diff_fenced"],
            "editor_editblock": ["base", "editblock", "editor_editblock"],
            "editor_whole": ["base", "wholefile", "editor_whole"],
            "patch": ["base", "editblock", "patch"],
            "udiff": ["base", "udiff"],  # udiff inherits directly from base
            "udiff_simple": ["base", "udiff", "udiff_simple"],  # udiff_simple inherits from udiff
            "wholefile": ["base", "wholefile"],
            "wholefile_func": ["base", "wholefile_func"],  # inherits directly from base
            "single_wholefile_func": [
                "base",
                "single_wholefile_func",
            ],  # inherits directly from base
            "editblock_func": ["base", "editblock_func"],  # inherits directly from base
            "agent": ["base", "agent"],
            "architect": ["base", "architect"],
            "ask": ["base", "ask"],
            "context": ["base", "context"],
            "copypaste": ["base", "copypaste"],
            "help": ["base", "help"],
        }

        for prompt_name, expected_chain in expected_chains.items():
            if prompt_name == "base":
                continue  # Already tested separately

            try:
                chain = self.registry._resolve_inheritance_chain(prompt_name)
                assert (
                    chain == expected_chain
                ), f"Chain for '{prompt_name}' mismatch. Expected {expected_chain}, got {chain}"
            except FileNotFoundError:
                # Some prompts might not exist in all configurations
                if prompt_name in ["copypaste"]:
                    continue  # copypaste might not exist
                raise