Skip to content

Commit d0fe2b2

Browse files
authored
Merge pull request #106 from iryna-kondr/cot_classifier
Added the chain of thought classifier
2 parents fd3a4ac + fab8f78 commit d0fe2b2

File tree

8 files changed

+174
-10
lines changed

8 files changed

+174
-10
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dependencies = [
1111
"google-cloud-aiplatform[pipelines]>=1.27.0,<2.0.0"
1212
]
1313
name = "scikit-llm"
14-
version = "1.2.0"
14+
version = "1.3.0"
1515
authors = [
1616
{ name="Oleh Kostromin", email="[email protected]" },
1717
{ name="Iryna Kondrashchenko", email="[email protected]" },

skllm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__version__ = '1.2.0'
1+
__version__ = '1.3.0'
22
__author__ = 'Iryna Kondrashchenko, Oleh Kostromin'

skllm/classification.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
## GPT
2+
3+
from skllm.models.gpt.classification.zero_shot import (
4+
ZeroShotGPTClassifier,
5+
MultiLabelZeroShotGPTClassifier,
6+
CoTGPTClassifier,
7+
)
8+
from skllm.models.gpt.classification.few_shot import (
9+
FewShotGPTClassifier,
10+
DynamicFewShotGPTClassifier,
11+
MultiLabelFewShotGPTClassifier,
12+
)
13+
from skllm.models.gpt.classification.tunable import (
14+
GPTClassifier as TunableGPTClassifier,
15+
)
16+
17+
## Vertex
18+
from skllm.models.vertex.classification.zero_shot import (
19+
ZeroShotVertexClassifier,
20+
MultiLabelZeroShotVertexClassifier,
21+
)
22+
from skllm.models.vertex.classification.tunable import (
23+
VertexClassifier as TunableVertexClassifier,
24+
)

skllm/models/_base/classifier.py

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
FEW_SHOT_MLCLF_PROMPT_TEMPLATE,
2424
ZERO_SHOT_CLF_SHORT_PROMPT_TEMPLATE,
2525
ZERO_SHOT_MLCLF_SHORT_PROMPT_TEMPLATE,
26+
COT_CLF_PROMPT_TEMPLATE,
27+
COT_MLCLF_PROMPT_TEMPLATE,
2628
)
2729
from skllm.prompts.builders import (
2830
build_zero_shot_prompt_slc,
@@ -33,7 +35,8 @@
3335
from skllm.memory.base import IndexConstructor
3436
from skllm.memory._sklearn_nn import SklearnMemoryIndex
3537
from skllm.models._base.vectorizer import BaseVectorizer as _BaseVectorizer
36-
import ast
38+
from skllm.utils import re_naive_json_extractor
39+
import json
3740

3841
_TRAINING_SAMPLE_PROMPT_TEMPLATE = """
3942
Sample input:
@@ -221,7 +224,7 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]], num_workers: int =
221224
----------
222225
X : Union[np.ndarray, pd.Series, List[str]]
223226
The input data to predict the class of.
224-
227+
225228
num_workers : int
226229
number of workers to use for multithreaded prediction, default 1
227230
@@ -231,12 +234,16 @@ def predict(self, X: Union[np.ndarray, pd.Series, List[str]], num_workers: int =
231234
The predicted classes as a numpy array.
232235
"""
233236
X = _to_numpy(X)
234-
237+
235238
if num_workers > 1:
236-
warnings.warn("Passing num_workers to predict is temporary and will be removed in the future.")
239+
warnings.warn(
240+
"Passing num_workers to predict is temporary and will be removed in the future."
241+
)
237242
with ThreadPoolExecutor(max_workers=num_workers) as executor:
238-
predictions = list(tqdm(executor.map(self._predict_single, X), total=len(X)))
239-
243+
predictions = list(
244+
tqdm(executor.map(self._predict_single, X), total=len(X))
245+
)
246+
240247
return np.array(predictions)
241248

242249
def _get_unique_targets(self, y: Any):
@@ -286,6 +293,47 @@ def _get_prompt(self, x: str) -> dict:
286293
return {"messages": prompt, "system_message": self.system_msg}
287294

288295

296+
class BaseCoTClassifier(BaseClassifier):
297+
def _get_prompt_template(self) -> str:
298+
"""Returns the prompt template to use for a single input."""
299+
if self.prompt_template is not None:
300+
return self.prompt_template
301+
elif isinstance(self, SingleLabelMixin):
302+
return COT_CLF_PROMPT_TEMPLATE
303+
return COT_MLCLF_PROMPT_TEMPLATE
304+
305+
def _get_prompt(self, x: str) -> dict:
306+
"""Returns the prompt to use for a single input."""
307+
if isinstance(self, SingleLabelMixin):
308+
prompt = build_zero_shot_prompt_slc(
309+
x, repr(self.classes_), template=self._get_prompt_template()
310+
)
311+
else:
312+
prompt = build_zero_shot_prompt_mlc(
313+
x,
314+
repr(self.classes_),
315+
self.max_labels,
316+
template=self._get_prompt_template(),
317+
)
318+
return {"messages": prompt, "system_message": self.system_msg}
319+
320+
def _predict_single(self, x: Any) -> Any:
321+
prompt_dict = self._get_prompt(x)
322+
# this will be inherited from the LLM
323+
completion = self._get_chat_completion(model=self.model, **prompt_dict)
324+
completion = self._convert_completion_to_str(completion)
325+
try:
326+
as_dict = json.loads(re_naive_json_extractor(completion))
327+
label = as_dict["label"]
328+
explanation = str(as_dict["explanation"])
329+
except Exception as e:
330+
label = "None"
331+
explanation = "Explanation is not available."
332+
# this will be inherited from the sl/ml mixin
333+
prediction = self.validate_prediction(label)
334+
return [prediction, explanation]
335+
336+
289337
class BaseFewShotClassifier(BaseClassifier):
290338
def _get_prompt_template(self) -> str:
291339
"""Returns the prompt template to use for a single input."""
@@ -427,6 +475,18 @@ def _get_prompt_template(self) -> str:
427475
return self.prompt_template
428476
return FEW_SHOT_CLF_PROMPT_TEMPLATE
429477

478+
def _reorder_examples(self, examples):
479+
n_classes = len(self.classes_)
480+
n_examples = self.n_examples
481+
482+
shuffled_list = []
483+
484+
for i in range(n_examples):
485+
for cls in range(n_classes):
486+
shuffled_list.append(cls * n_examples + i)
487+
488+
return [examples[i] for i in shuffled_list]
489+
430490
def _get_prompt(self, x: str) -> dict:
431491
"""
432492
Generates the prompt for the given input.
@@ -455,7 +515,7 @@ def _get_prompt(self, x: str) -> dict:
455515
]
456516
)
457517

458-
training_data_str = "\n".join(training_data)
518+
training_data_str = "\n".join(self._reorder_examples(training_data))
459519

460520
msg = build_few_shot_prompt_slc(
461521
x=x,

skllm/models/gpt/classification/zero_shot.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
SingleLabelMixin as _SingleLabelMixin,
33
MultiLabelMixin as _MultiLabelMixin,
44
BaseZeroShotClassifier as _BaseZeroShotClassifier,
5+
BaseCoTClassifier as _BaseCoTClassifier,
56
)
67
from skllm.llm.gpt.mixin import GPTClassifierMixin as _GPTClassifierMixin
78
from typing import Optional
@@ -44,6 +45,41 @@ def __init__(
4445
self._set_keys(key, org)
4546

4647

48+
class CoTGPTClassifier(_BaseCoTClassifier, _GPTClassifierMixin, _SingleLabelMixin):
49+
def __init__(
50+
self,
51+
model: str = "gpt-3.5-turbo",
52+
default_label: str = "Random",
53+
prompt_template: Optional[str] = None,
54+
key: Optional[str] = None,
55+
org: Optional[str] = None,
56+
**kwargs,
57+
):
58+
"""
59+
Chain-of-thought text classifier using OpenAI/GPT API-compatible models.
60+
61+
Parameters
62+
----------
63+
model : str, optional
64+
model to use, by default "gpt-3.5-turbo"
65+
default_label : str, optional
66+
default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
67+
prompt_template : Optional[str], optional
68+
custom prompt template to use, by default None
69+
key : Optional[str], optional
70+
estimator-specific API key; if None, retrieved from the global config, by default None
71+
org : Optional[str], optional
72+
estimator-specific ORG key; if None, retrieved from the global config, by default None
73+
"""
74+
super().__init__(
75+
model=model,
76+
default_label=default_label,
77+
prompt_template=prompt_template,
78+
**kwargs,
79+
)
80+
self._set_keys(key, org)
81+
82+
4783
class MultiLabelZeroShotGPTClassifier(
4884
_BaseZeroShotClassifier, _GPTClassifierMixin, _MultiLabelMixin
4985
):

skllm/prompts/templates.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,24 @@
1515
Your JSON response:
1616
"""
1717

18+
COT_CLF_PROMPT_TEMPLATE = """
19+
You are tasked with classifying a given text sample based on a list of potential categories. Please adhere to the following guidelines:
20+
21+
1. The text intended for classification is presented between triple backticks.
22+
2. The possible categories are enumerated in square brackets, with each category enclosed in single quotes and separated by commas.
23+
24+
Tasks:
25+
1. Examine the text and provide detailed justifications for the possibility of the text belonging or not belonging to each category listed.
26+
2. Determine and select the most appropriate category for the text based on your comprehensive justifications.
27+
3. Format your decision into a JSON object containing two keys: `explanation` and `label`. The `explanation` should concisely capture the rationale for each category before concluding with the chosen category.
28+
29+
Category List: {labels}
30+
31+
Text Sample: ```{x}```
32+
33+
Provide your JSON response below, ensuring that justifications for all categories are clearly detailed:
34+
"""
35+
1836
ZERO_SHOT_CLF_SHORT_PROMPT_TEMPLATE = """
1937
Classify the following text into one of the following classes: {labels}. Provide your response in a JSON format containing a single key `label`.
2038
Text: ```{x}```
@@ -84,6 +102,24 @@
84102
Your JSON response:
85103
"""
86104

105+
COT_MLCLF_PROMPT_TEMPLATE = """
106+
You are tasked with classifying a given text sample based on a list of potential categories. Please adhere to the following guidelines:
107+
108+
1. The text intended for classification is presented between triple backticks.
109+
2. The possible categories are enumerated in square brackets, with each category enclosed in quotes and separated by commas.
110+
111+
Tasks:
112+
1. Examine the text and provide detailed justifications for the possibility of the text belonging or not belonging to each category listed.
113+
2. Determine and select at most {max_cats} most appropriate categories for the text based on your comprehensive justifications.
114+
3. Format your decision into a JSON object containing two keys: `explanation` and `label`. The `explanation` should concisely capture the rationale for each category before concluding with the chosen category. The `label` should contain an array of the chosen categories.
115+
116+
Category List: {labels}
117+
118+
Text Sample: ```{x}```
119+
120+
Provide your JSON response below, ensuring that justifications for all categories are clearly detailed:
121+
"""
122+
87123
SUMMARY_PROMPT_TEMPLATE = """
88124
Your task is to generate a summary of the text sample.
89125
Summarize the text sample provided below, delimited by triple backticks, in at most {max_words} words.
@@ -204,4 +240,3 @@
204240
205241
Output json:
206242
"""
207-

skllm/text2text.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
## GPT
2+
from skllm.models.gpt.text2text.summarization import GPTSummarizer
3+
from skllm.models.gpt.text2text.translation import GPTTranslator
4+
from skllm.models.gpt.text2text.tunable import TunableGPTText2Text
5+
6+
## Vertex
7+
8+
from skllm.models.vertex.text2text.tunable import TunableVertexText2Text

skllm/vectorization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from skllm.models.gpt.vectorization import GPTVectorizer

0 commit comments

Comments (0)