Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 36 additions & 12 deletions llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,14 +371,17 @@ async def run_client():

def ensure_version_suffix(base_url):
"""
确保 base_url 以 '/v<n>/' 结尾,其中 n 是数字。
如果不是这样,则添加 '/v1/' 到末尾。
Ensures base_url is properly formatted for the API provider.
"""
# 正则表达式匹配 '/v<n>/' 其中 <n> 是任意字符
match = re.search(r'/v(.+)/$', base_url)
# Special case for Perplexity API
if "api.perplexity.ai" in base_url:
# Return a fixed URL with no manipulation
return "https://api.perplexity.ai"

# Original logic for OpenAI and other APIs
match = re.search(r'/v(.+)/$', base_url)
if not match:
# 如果没有匹配到,则添加 '/v1/'
# If no version suffix, add '/v1/'
if not base_url.endswith('/'):
base_url += '/'
base_url += 'v1/'
Expand Down Expand Up @@ -456,11 +459,22 @@ def send(
history[i]["role"] = "user"
history.append({"role": "assistant", "content": "好的,我会按照你的指示来操作"})
break
openai_client = OpenAI(
api_key= self.apikey,
base_url=self.baseurl,
)
if "openai.azure.com" in self.baseurl:
# Create the OpenAI client
if "api.perplexity.ai" in self.baseurl:
try:
# Use the simplest possible client initialization for Perplexity
openai_client = OpenAI(
api_key=self.apikey,
base_url="https://api.perplexity.ai"
)
except Exception as e:
print(f"Error initializing Perplexity client: {e}")
# Fallback to standard initialization
openai_client = OpenAI(
api_key=self.apikey,
base_url=self.baseurl
)
elif "openai.azure.com" in self.baseurl:
# 获取API版本
api_version = self.baseurl.split("=")[-1].split("/")[0]
# 获取azure_endpoint
Expand All @@ -471,6 +485,11 @@ def send(
azure_endpoint=azure_endpoint,
)
openai_client = azure
else:
openai_client = OpenAI(
api_key=self.apikey,
base_url=self.baseurl,
)
new_message = {"role": "user", "content": user_prompt}
history.append(new_message)
print(history)
Expand Down Expand Up @@ -1107,7 +1126,12 @@ def INPUT_TYPES(s):
CATEGORY = "大模型派对(llm_party)/模型加载器(model loader)"

def chatbot(self, model_name, base_url=None, api_key=None, is_ollama=False):
if is_ollama:
if "api.perplexity.ai" in (base_url or ""):
# For Perplexity API, simplify the base URL
openai.base_url = ensure_version_suffix(base_url)
openai.api_key = api_key
print(f"Using Perplexity API with base_url: {openai.base_url}")
elif is_ollama:
openai.api_key = "ollama"
openai.base_url = "http://127.0.0.1:11434/v1/"
else:
Expand All @@ -1130,7 +1154,7 @@ def chatbot(self, model_name, base_url=None, api_key=None, is_ollama=False):
api_keys = load_api_keys(config_path)
openai.api_key = api_keys.get("openai_api_key")
openai.base_url = api_keys.get("base_url")
if openai.base_url != "":
if openai.base_url != "" and "api.perplexity.ai" not in openai.base_url:
if openai.base_url[-1] != "/":
openai.base_url = openai.base_url + "/"

Expand Down