319 changes: 94 additions & 225 deletions aider/helpers/model_providers.py
@@ -60,7 +60,7 @@ def _first_env_value(names):
return None


class _JSONOpenAIProvider(CustomLLM if CustomLLM is not None else object): # type: ignore[misc]
class _JSONOpenAIProvider(OpenAILikeChatHandler): # type: ignore[misc]
"""CustomLLM wrapper that routes OpenAI-compatible providers through LiteLLM."""

def __init__(self, slug: str, config: Dict):
@@ -69,7 +69,6 @@ def __init__(self, slug: str, config: Dict):
super().__init__() # type: ignore[misc]
self.slug = slug
self.config = config
self._chat_handler = OpenAILikeChatHandler()

def _resolve_api_base(self, api_base: Optional[str]) -> str:
base = (
@@ -124,212 +123,109 @@ def _build_request_params(self, optional_params, stream: bool):
params["stream"] = True
return params

def _invoke_handler(
self,
*,
model,
messages,
api_base,
custom_prompt_dict,
model_response,
print_verbose,
encoding,
api_key,
logging_obj,
optional_params,
litellm_params,
logger_fn,
headers,
timeout,
client,
stream: bool,
):
api_base = self._resolve_api_base(api_base)
api_key = self._resolve_api_key(api_key)
headers = self._inject_headers(headers)
params = self._build_request_params(optional_params, stream)
cleaned_messages = self._apply_special_handling(messages)
api_model = self._normalize_model_name(model)
http_client = None
if HTTPHandler is not None and isinstance(client, HTTPHandler):
http_client = client
return self._chat_handler.completion(
model=api_model,
messages=cleaned_messages,
api_base=api_base,
custom_llm_provider="openai",
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=params,
litellm_params=litellm_params or {},
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=http_client,
def completion(self, *args, **kwargs):
kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None))
kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None))
kwargs["headers"] = self._inject_headers(kwargs.get("headers", None))
kwargs["optional_params"] = self._build_request_params(
kwargs.get("optional_params", None), False
)

def completion(
self,
model,
messages,
api_base,
custom_prompt_dict,
model_response,
print_verbose,
encoding,
api_key,
logging_obj,
optional_params,
litellm_params=None,
acompletion=None,
logger_fn=None,
headers=None,
timeout=None,
client=None,
):
return self._invoke_handler(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client,
stream=False,
kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", []))
kwargs["model"] = self._normalize_model_name(kwargs.get("model", None))
kwargs["custom_llm_provider"] = "openai"

return super().completion(*args, **kwargs)

async def acompletion(self, *args, **kwargs):
kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None))
kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None))
kwargs["headers"] = self._inject_headers(kwargs.get("headers", None))
kwargs["optional_params"] = self._build_request_params(
kwargs.get("optional_params", None), False
)

def streaming(
self,
model,
messages,
api_base,
custom_prompt_dict,
model_response,
print_verbose,
encoding,
api_key,
logging_obj,
optional_params,
litellm_params=None,
acompletion=None,
logger_fn=None,
headers=None,
timeout=None,
client=None,
):
return self._invoke_handler(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client,
stream=True,
kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", []))
kwargs["model"] = self._normalize_model_name(kwargs.get("model", None))
kwargs["custom_llm_provider"] = "openai"
kwargs["acompletion"] = True

return await super().completion(*args, **kwargs)

def streaming(self, *args, **kwargs):
kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None))
kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None))
kwargs["headers"] = self._inject_headers(kwargs.get("headers", None))
kwargs["optional_params"] = self._build_request_params(
kwargs.get("optional_params", None), True
)
kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", []))
kwargs["model"] = self._normalize_model_name(kwargs.get("model", None))
kwargs["custom_llm_provider"] = "openai"

def acompletion(
self,
model,
messages,
api_base,
custom_prompt_dict,
model_response,
print_verbose,
encoding,
api_key,
logging_obj,
optional_params,
litellm_params=None,
acompletion=None,
logger_fn=None,
headers=None,
timeout=None,
client=None,
):
return self.completion(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client,
)
response = super().completion(*args, **kwargs)

for chunk in response:
yield self.get_generic_chunk(chunk)

def astreaming(
self,
model,
messages,
api_base,
custom_prompt_dict,
model_response,
print_verbose,
encoding,
api_key,
logging_obj,
optional_params,
litellm_params=None,
acompletion=None,
logger_fn=None,
headers=None,
timeout=None,
client=None,
):
return self.streaming(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
client=client,
async def astreaming(self, *args, **kwargs):
kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None))
kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None))
kwargs["headers"] = self._inject_headers(kwargs.get("headers", None))
kwargs["optional_params"] = self._build_request_params(
kwargs.get("optional_params", None), True
)
kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", []))
kwargs["model"] = self._normalize_model_name(kwargs.get("model", None))
kwargs["custom_llm_provider"] = "openai"
kwargs["acompletion"] = True

response = await super().completion(*args, **kwargs)

async for chunk in response:
yield self.get_generic_chunk(chunk)

def get_generic_chunk(self, chunk):
# Extract the first choice (standard for single-candidate streams)
choice = chunk.choices[0] if chunk.choices else None
delta = choice.delta if choice else None

# Safe extraction of text (content can be None in tool-call chunks)
text_content = delta.content if delta and delta.content else ""

# Safe extraction of tool calls
# LiteLLM provides these as a list of objects; the generic chunk
# carries a single entry, so take the first (or None if empty)
tool_calls = delta.tool_calls if delta and delta.tool_calls else None

if tool_calls and len(tool_calls):
tool_calls = tool_calls[0]

# Handle Usage (often only present in the final chunk)
usage_data = getattr(chunk, "usage", None)
# If usage is a Pydantic object, dump it to dict; otherwise default to 0s
if hasattr(usage_data, "model_dump"):
usage_dict = usage_data.model_dump()
elif isinstance(usage_data, dict):
usage_dict = usage_data
else:
usage_dict = {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}

# Construct the GenericStreamingChunk dictionary
generic_chunk = {
"finish_reason": choice.finish_reason if choice else None,
"index": choice.index if choice else 0,
"is_finished": bool(choice.finish_reason) if choice else False,
"text": text_content,
"tool_use": tool_calls,
"usage": usage_dict,
}

return generic_chunk


def _register_provider_with_litellm(slug: str, config: Dict) -> None:
"""Register provider metadata and custom handlers with LiteLLM."""
try:
from litellm.llms.openai_like.json_loader import (
JSONProviderRegistry,
SimpleProviderConfig,
)
from litellm.llms.openai_like.json_loader import JSONProviderRegistry
except Exception:
return

@@ -340,40 +236,13 @@ def _register_provider_with_litellm(slug: str, config: Dict) -> None:
if not base_url or not api_key_env:
return

if not JSONProviderRegistry.exists(slug):
payload = {
"base_url": base_url,
"api_key_env": api_key_env,
}

api_base_env = _coerce_str(config.get("base_url_env"))
if api_base_env:
payload["api_base_env"] = api_base_env

if config.get("param_mappings"):
payload["param_mappings"] = config["param_mappings"]
if config.get("special_handling"):
payload["special_handling"] = config["special_handling"]
if config.get("base_class"):
payload["base_class"] = config["base_class"]

JSONProviderRegistry._providers[slug] = SimpleProviderConfig(slug, payload)

try:
import litellm # noqa: WPS433
except Exception:
return

provider_list = getattr(litellm, "provider_list", None)
if isinstance(provider_list, list) and slug not in provider_list:
provider_list.append(slug)

openai_like = getattr(litellm, "_openai_like_providers", None)
if isinstance(openai_like, list) and slug not in openai_like:
openai_like.append(slug)

handler = _CUSTOM_HANDLERS.get(slug)
if handler is None and CustomLLM is not None and OpenAILikeChatHandler is not None:
if handler is None:
handler = _JSONOpenAIProvider(slug, config)
_CUSTOM_HANDLERS[slug] = handler

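As a concrete illustration of the new streaming path, here is a minimal, self-contained sketch of what get_generic_chunk() emits for a single content delta. The "acme" slug, its config values, and the SimpleNamespace stand-in chunk are illustrative assumptions; only the attribute shape mirrors what the method actually reads.

from types import SimpleNamespace

from aider.helpers.model_providers import _JSONOpenAIProvider

# Hypothetical provider; only base_url and api_key_env are required by
# _register_provider_with_litellm in this diff.
handler = _JSONOpenAIProvider(
    "acme",
    {"base_url": "https://api.acme.example/v1", "api_key_env": "ACME_API_KEY"},
)

# Duck-typed stand-in for a LiteLLM streaming chunk: get_generic_chunk()
# only reads choices[0].delta/finish_reason/index and the usage attribute.
chunk = SimpleNamespace(
    choices=[
        SimpleNamespace(
            delta=SimpleNamespace(content="Hello", tool_calls=None),
            finish_reason=None,
            index=0,
        )
    ],
    usage=None,
)

print(handler.get_generic_chunk(chunk))
# {'finish_reason': None, 'index': 0, 'is_finished': False, 'text': 'Hello',
#  'tool_use': None, 'usage': {'completion_tokens': 0, 'prompt_tokens': 0,
#  'total_tokens': 0}}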
4 changes: 4 additions & 0 deletions aider/main.py
@@ -523,6 +523,10 @@ async def sanity_check_repo(repo, io):
"auto_save_session": True,
"input_task": True,
"output_task": True,
"check_output_queue": True,
"_animate_spinner": True,
"handle_output_message": True,
"update_spinner": True,
}


1 change: 1 addition & 0 deletions aider/prompts/agent.yml
@@ -1,5 +1,6 @@
# Agent prompts - inherits from base.yaml
# Overrides specific prompts
_inherits: [base]

files_content_assistant_reply: |
I understand. I'll use these files to help with your request.
1 change: 1 addition & 0 deletions aider/prompts/architect.yml
@@ -1,5 +1,6 @@
# Architect prompts - inherits from base.yaml
# Overrides specific prompts
_inherits: [base]

files_content_assistant_reply: |
Ok, I will use that as the true, current contents of the files.
1 change: 1 addition & 0 deletions aider/prompts/ask.yml
@@ -1,5 +1,6 @@
# Ask prompts - inherits from base.yaml
# Overrides specific prompts
_inherits: [base]

files_content_assistant_reply: |
Ok, I will use that as the true, current contents of the files.
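The three prompt files above each gain an explicit _inherits: [base] key. The loader that resolves it is not part of this diff, so the sketch below is a guess at the semantics under that assumption: parents are merged depth-first and the child's keys override inherited ones. The function and directory names here are hypothetical.

import pathlib

import yaml  # PyYAML

def load_prompts(name: str, prompt_dir: pathlib.Path) -> dict:
    """Resolve a prompt file, recursively merging its _inherits parents."""
    data = yaml.safe_load((prompt_dir / f"{name}.yml").read_text()) or {}
    merged: dict = {}
    for parent in data.pop("_inherits", []):
        merged.update(load_prompts(parent, prompt_dir))  # parents first
    merged.update(data)  # child keys win, e.g. files_content_assistant_reply
    return merged

prompts = load_prompts("agent", pathlib.Path("aider/prompts"))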