Skip to content

Commit 792af89

Browse files
committed
added support for gpt-5
1 parent 11d6517 commit 792af89

File tree

24 files changed

+33
-32
lines changed

24 files changed

+33
-32
lines changed

skllm/llm/anthropic/completion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Dict, List, Optional
22
from skllm.llm.anthropic.credentials import set_credentials
33
from skllm.utils import retry
4-
from model_constants import ANTHROPIC_CLAUDE_MODEL
4+
from skllm.model_constants import ANTHROPIC_CLAUDE_MODEL
55

66
@retry(max_retries=3)
77
def get_chat_completion(

skllm/llm/gpt/clients/openai/completion.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
set_credentials,
66
)
77
from skllm.utils import retry
8-
from model_constants import OPENAI_GPT_MODEL
8+
from skllm.model_constants import OPENAI_GPT_MODEL
99

1010

1111
@retry(max_retries=3)
@@ -44,10 +44,11 @@ def get_chat_completion(
4444
client = set_azure_credentials(key, org)
4545
else:
4646
raise ValueError("Invalid API")
47-
model_dict = {"model": model}
47+
model_dict: dict = {"model": model}
4848
if json_response and api == "openai":
4949
model_dict["response_format"] = {"type": "json_object"}
50-
completion = client.chat.completions.create(
51-
temperature=0.0, messages=messages, **model_dict
52-
)
50+
if not model.startswith(("gpt-o", "gpt-5")):
51+
model_dict["temperature"] = 0.0
52+
print("Setting the temperature ", model_dict.get("temperature"))
53+
completion = client.chat.completions.create(messages=messages, **model_dict) # type: ignore
5354
return completion

skllm/llm/gpt/clients/openai/credentials.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from skllm.config import SKLLMConfig as _Config
66

77

8-
def set_credentials(key: str, org: str) -> None:
8+
def set_credentials(key: str, org: str) -> OpenAI:
99
"""Set the OpenAI key and organization.
1010
1111
Parameters
@@ -20,7 +20,7 @@ def set_credentials(key: str, org: str) -> None:
2020
return client
2121

2222

23-
def set_azure_credentials(key: str, org: str) -> None:
23+
def set_azure_credentials(key: str, org: str) -> AzureOpenAI:
2424
"""Sets OpenAI credentials for Azure.
2525
2626
Parameters

skllm/llm/gpt/clients/openai/embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from skllm.utils import retry
33
import openai
44
from openai import OpenAI
5-
from model_constants import OPENAI_EMBEDDING_MODEL
5+
from skllm.model_constants import OPENAI_EMBEDDING_MODEL
66

77

88
@retry(max_retries=3)

skllm/llm/gpt/completion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
)
88
from skllm.llm.gpt.utils import split_to_api_and_model
99
from skllm.config import SKLLMConfig as _Config
10-
from model_constants import OPENAI_GPT_MODEL
10+
from skllm.model_constants import OPENAI_GPT_MODEL
1111

1212

1313
def get_chat_completion(
1414
messages: dict,
15-
openai_key: str = None,
16-
openai_org: str = None,
15+
openai_key: str,
16+
openai_org: str,
1717
model: str = OPENAI_GPT_MODEL,
1818
json_response: bool = False,
1919
):

skllm/llm/gpt/embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from skllm.llm.gpt.clients.openai.embedding import get_embedding as _oai_get_embedding
22
from skllm.llm.gpt.utils import split_to_api_and_model
3-
from model_constants import OPENAI_EMBEDDING_MODEL
3+
from skllm.model_constants import OPENAI_EMBEDDING_MODEL
44

55

66
def get_embedding(

skllm/llm/gpt/mixin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import numpy as np
2222
from tqdm import tqdm
2323
import json
24-
from model_constants import OPENAI_GPT_TUNABLE_MODEL
24+
from skllm.model_constants import OPENAI_GPT_TUNABLE_MODEL
2525

2626

2727
def construct_message(role: str, content: str) -> dict:

skllm/model_constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# OpenAI models
2-
OPENAI_GPT_MODEL = "gpt-3.5-turbo"
2+
OPENAI_GPT_MODEL = "gpt-4.1"
33
OPENAI_GPT_TUNABLE_MODEL = "gpt-3.5-turbo-0613"
44
OPENAI_EMBEDDING_MODEL = "text-embedding-ada-002"
55

66
# Anthropic (Claude) models
77
ANTHROPIC_CLAUDE_MODEL = "claude-3-haiku-20240307"
88

99
# Vertex AI models
10-
VERTEX_DEFAULT_MODEL = "text-bison@002"
10+
VERTEX_DEFAULT_MODEL = "text-bison@002"

skllm/models/anthropic/classification/few_shot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from skllm.models._base.vectorizer import BaseVectorizer
1010
from skllm.memory.base import IndexConstructor
1111
from typing import Optional
12-
from model_constants import ANTHROPIC_CLAUDE_MODEL, OPENAI_EMBEDDING_MODEL
12+
from skllm.model_constants import ANTHROPIC_CLAUDE_MODEL, OPENAI_EMBEDDING_MODEL
1313

1414

1515
class FewShotClaudeClassifier(BaseFewShotClassifier, ClaudeClassifierMixin, SingleLabelMixin):

skllm/models/anthropic/classification/zero_shot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
)
77
from skllm.llm.anthropic.mixin import ClaudeClassifierMixin as _ClaudeClassifierMixin
88
from typing import Optional
9-
from model_constants import ANTHROPIC_CLAUDE_MODEL
9+
from skllm.model_constants import ANTHROPIC_CLAUDE_MODEL
1010

1111

1212
class ZeroShotClaudeClassifier(

0 commit comments

Comments
 (0)