diff --git a/bugbug/ml_filter_finetune_tool.py b/bugbug/ml_filter_finetune_tool.py
new file mode 100644
index 0000000000..9341a6241b
--- /dev/null
+++ b/bugbug/ml_filter_finetune_tool.py
@@ -0,0 +1,91 @@
+from abc import ABC, abstractmethod
+from pathlib import Path
+
+import torch
+from datasets import Dataset
+from torch.nn.functional import softmax
+from transformers import (
+    AutoTokenizer,
+    ModernBertForSequenceClassification,
+    Trainer,
+    TrainingArguments,
+    set_seed,
+)
+
+
+class FineTuneMLClassifier(ABC):
+    def __init__(self, model_path, seed=42):
+        # Resolve the device before it is used to place the model.
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model = ModernBertForSequenceClassification.from_pretrained(
+            model_path, device_map=self.device, attn_implementation="sdpa"
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.seed = seed
+
+    def _tokenize(self, batch):
+        return self.tokenizer(
+            batch["comment"],
+            padding=True,
+            truncation=True,
+        )
+
+    def fit(self, inputs, labels, tmpdir):
+        set_seed(self.seed)
+
+        train_dataset = Dataset.from_dict(
+            {
+                "comment": inputs,
+                "label": labels,
+            }
+        )
+
+        train_dataset = train_dataset.map(
+            self._tokenize, batched=True, remove_columns=["comment"]
+        )
+
+        training_args = TrainingArguments(
+            # Required parameter:
+            output_dir=str(tmpdir),
+            # Optional training parameters:
+            num_train_epochs=30,
+            per_device_train_batch_size=128,
+            warmup_steps=500,
+            learning_rate=5e-5,
+            optim="adamw_torch",
+            # lr_scheduler_type="constant",
+            # warmup_ratio=0.1,
+            bf16=True,
+            save_strategy="no",
+            logging_strategy="epoch",
+            report_to="none",
+            seed=self.seed,
+            use_cpu=self.device.type == "cpu",
+        )
+        trainer = Trainer(
+            model=self.model,
+            args=training_args,
+            tokenizer=self.tokenizer,
+            train_dataset=train_dataset,
+            eval_dataset=None,
+        )
+
+        trainer.train()
+        self.model.save_pretrained(save_directory=tmpdir)
+        self.tokenizer.save_pretrained(save_directory=tmpdir)
+
+    def predict(self, inputs):
+        self.model.to(self.device).eval()
+
+        encoded = self.tokenizer(
+            inputs, padding=True, truncation=True, return_tensors="pt"
+        ).to(self.device)
+
+        with torch.no_grad():
+            logits = self.model(**encoded).logits
+            # Probability of class 0, i.e. that the comment should be rejected.
+            probs = softmax(logits, dim=1)[:, 0]
+        return probs.cpu().numpy()
+
+    @abstractmethod
+    def save(self, tmpdir: Path): ...
diff --git a/bugbug/ml_filter_tool.py b/bugbug/ml_filter_tool.py
new file mode 100644
index 0000000000..6e02a521bc
--- /dev/null
+++ b/bugbug/ml_filter_tool.py
@@ -0,0 +1,17 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+
+class MLCommentFilter(ABC):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    @abstractmethod
+    def query_ml_filter(self, comments, *args, **kwargs) -> Any: ...
+
+
+ml_comment_filters: dict[str, type[MLCommentFilter]] = {}
+
+
+def register_ml_comment_filters(name, cls):
+    ml_comment_filters[name] = cls
diff --git a/bugbug/ml_filter_trainer_tool.py b/bugbug/ml_filter_trainer_tool.py
new file mode 100644
index 0000000000..64ae84349a
--- /dev/null
+++ b/bugbug/ml_filter_trainer_tool.py
@@ -0,0 +1,64 @@
+from abc import ABC, abstractmethod
+from pathlib import Path
+
+import numpy as np
+from sklearn.metrics import recall_score
+
+
+class Trainer(ABC):
+    def __init__(
+        self,
+        min_recall: float = 0.9,
+        thr_metric: str = "acceptance_rate",
+        tmpdir: Path = Path(""),
+    ):
+        self.min_recall = min_recall
+        self.thr_metric = thr_metric
+        self.tmpdir = tmpdir
+
+    # Implementations must populate self.train_inputs, self.train_labels,
+    # self.val_inputs and self.val_labels.
+    @abstractmethod
+    def train_test_split(self, data, test_size=0.5, random_split=True): ...
+
+    def _fit(self, model):
+        model.fit(self.train_inputs, self.train_labels, self.tmpdir)
+        return model.predict(self.val_inputs)
+
+    def train(self, model):
+        probs = self._fit(model)
+        thresholds_results = {}
+        for thr in np.arange(0, 1.01, 0.01):
+            # probs holds P(class 0), so a comment is rejected (0) when its
+            # probability meets the threshold and accepted (1) otherwise.
+            preds = np.where(probs >= thr, 0, 1)
+            recalls = recall_score(
+                self.val_labels, preds, average=None, zero_division=0
+            )
+            n_accepted = int(preds.sum())
+            n_correctly_accepted = sum(
+                1 for pred, label in zip(preds, self.val_labels) if pred and label
+            )
+            acceptance_rate = n_correctly_accepted / n_accepted if n_accepted else 0.0
+            thresholds_results[thr] = {
+                "recall_accept": recalls[1],
+                "gmean": np.sqrt(recalls[0] * recalls[1]),
+                "acceptance_rate": acceptance_rate,
+            }
+        # Keep only thresholds that preserve the minimum accept recall, then
+        # pick the one that maximizes the configured metric.
+        thresholds_results = {
+            thr: metrics
+            for thr, metrics in thresholds_results.items()
+            if metrics["recall_accept"] >= self.min_recall
+        }
+        if not thresholds_results:
+            raise ValueError(
+                f"No threshold reaches an accept recall of {self.min_recall}"
+            )
+        thresholds_results = sorted(
+            thresholds_results.items(),
+            key=lambda x: x[1][self.thr_metric],
+            reverse=True,
+        )
+        return thresholds_results[0][0]
diff --git a/bugbug/tools/code_review.py b/bugbug/tools/code_review.py
index c099ec3d11..40552675e8 100644
--- a/bugbug/tools/code_review.py
+++ b/bugbug/tools/code_review.py
@@ -30,6 +30,7 @@
 from bugbug import db, phabricator, utils
 from bugbug.code_search.function_search import FunctionSearch
 from bugbug.generative_model_tool import GenerativeModelTool, get_tokenizer
+from bugbug.ml_filter_tool import MLCommentFilter
 from bugbug.utils import get_secret
 from bugbug.vectordb import PayloadScore, QueryFilter, VectorDB, VectorPoint
 
@@ -1138,6 +1139,7 @@ def __init__(
         verbose: bool = True,
         suggestions_feedback_db: Optional["SuggestionsFeedbackDB"] = None,
         target_software: Optional[str] = None,
+        ml_comment_filter: Optional[MLCommentFilter] = None,
     ) -> None:
         super().__init__()
 
@@ -1212,6 +1214,8 @@
 
         self.suggestions_feedback_db = suggestions_feedback_db
 
+        self.ml_comment_filter = ml_comment_filter
+
     def count_tokens(self, text):
         return len(self._tokenizer.encode(text))
 
@@ -1379,7 +1383,16 @@ def run(self, patch: Patch) -> list[InlineComment] | None:
         if self.verbose:
             GenerativeModelTool._print_answer(raw_output)
 
-        return list(generate_processed_output(raw_output, patch.patch_set))
+        generated_inline_comments = list(
+            generate_processed_output(raw_output, patch.patch_set)
+        )
+
+        if self.ml_comment_filter:
+            generated_inline_comments = self.ml_comment_filter.query_ml_filter(
+                generated_inline_comments
+            )
+
+        return generated_inline_comments
 
     def _get_generated_examples(self, patch, created_before: datetime | None = None):
         """Get examples of comments that were generated by an LLM.
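
For context on how the abstract pieces above are meant to compose, here is a minimal sketch of a concrete `Trainer` subclass. It is an illustrative assumption, not part of the patch: the `CommentTrainer` name, the `(comment, label)` input shape, and the label convention (0 = reject, 1 = accept, matching `recall_accept` above) are all hypothetical.

```python
# Hypothetical sketch, not part of the patch.
import random

from bugbug.ml_filter_trainer_tool import Trainer


class CommentTrainer(Trainer):
    def train_test_split(self, data, test_size=0.5, random_split=True):
        # `data` is assumed to be a list of (comment, label) pairs,
        # with label 0 meaning "reject" and 1 meaning "accept".
        if random_split:
            random.shuffle(data)
        split = int(len(data) * (1 - test_size))
        train, val = data[:split], data[split:]
        # Trainer._fit() and Trainer.train() read these four attributes.
        self.train_inputs = [comment for comment, _ in train]
        self.train_labels = [label for _, label in train]
        self.val_inputs = [comment for comment, _ in val]
        self.val_labels = [label for _, label in val]
```

After `train_test_split(data)` has populated the four attributes, `trainer.train(model)` fine-tunes the classifier, sweeps thresholds on the validation split, and returns the threshold that maximizes the configured metric while keeping accept recall at or above `min_recall`.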
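
Similarly, a sketch of how a concrete `MLCommentFilter` might wire the fine-tuned classifier and the selected threshold into `CodeReviewTool`. The `FineTunedCommentFilter` name, its constructor, the `"finetuned"` registry key, and the assumption that each `InlineComment` exposes its text as `.content` are hypothetical.

```python
# Hypothetical sketch, not part of the patch.
from bugbug.ml_filter_tool import MLCommentFilter, register_ml_comment_filters


class FineTunedCommentFilter(MLCommentFilter):
    def __init__(self, classifier, threshold):
        super().__init__()
        self.classifier = classifier  # a fitted FineTuneMLClassifier subclass
        self.threshold = threshold  # chosen by Trainer.train()

    def query_ml_filter(self, comments, *args, **kwargs):
        if not comments:
            return comments
        # predict() returns P(class 0), i.e. the probability that a comment
        # should be rejected; keep only comments below the threshold.
        probs = self.classifier.predict([comment.content for comment in comments])
        return [
            comment
            for comment, prob in zip(comments, probs)
            if prob < self.threshold
        ]


register_ml_comment_filters("finetuned", FineTunedCommentFilter)
```

An instance of such a filter would then be passed as the new `ml_comment_filter` argument to `CodeReviewTool.__init__`, which applies it in `run()` to the comments produced by `generate_processed_output`.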