From 490def66158325d50715ea9d42727857e2b44a48 Mon Sep 17 00:00:00 2001 From: SZU-ZJW Date: Fri, 23 Jan 2026 20:52:04 +0800 Subject: [PATCH 1/2] Update the RecLM-cgen, Named RecLM-uni --- README.md | 15 +- RecLM-cgen/README.md | 288 ----------- {RecLM-cgen => RecLM-uni}/.gitignore | 0 RecLM-uni/GRPO/rl_dataset.py | 126 +++++ RecLM-uni/GRPO/run_GRPO.sh | 49 ++ RecLM-uni/PLUGIN_CONFIG.md | 168 +++++++ RecLM-uni/README.md | 471 ++++++++++++++++++ {RecLM-cgen => RecLM-uni}/accelerate.yaml | 0 {RecLM-cgen => RecLM-uni}/cli_serve.py | 0 RecLM-uni/grpo_train.py | 218 ++++++++ RecLM-uni/index/datasets.py | 21 + RecLM-uni/index/generate_indices.py | 216 ++++++++ RecLM-uni/index/main.py | 98 ++++ RecLM-uni/index/models/layers.py | 108 ++++ RecLM-uni/index/models/rq.py | 56 +++ RecLM-uni/index/models/rqvae.py | 104 ++++ RecLM-uni/index/models/vq.py | 101 ++++ RecLM-uni/index/trainer.py | 255 ++++++++++ RecLM-uni/index/utils.py | 37 ++ {RecLM-cgen => RecLM-uni}/main.py | 0 RecLM-uni/plugin.env.example | 30 ++ RecLM-uni/plugin.py | 409 +++++++++++++++ .../preprocess/data_preprocess_amazon.py | 0 .../preprocess/transform2unirec.py | 0 {RecLM-cgen => RecLM-uni}/requirements.txt | 0 .../scripts/data_preprocess_amazon.sh | 0 .../scripts/run_SFT_merge.sh | 0 .../scripts/train_RecLM_cgen.sh | 0 .../scripts/train_RecLM_ret.sh | 0 .../scripts/unirec_serve.sh | 0 .../scripts/unirec_train.sh | 0 {RecLM-cgen => RecLM-uni}/task_MR_test.py | 0 {RecLM-cgen => RecLM-uni}/task_test.py | 0 RecLM-uni/task_test_tokenizer.py | 193 +++++++ .../train_utils/__init__.py | 0 .../train_utils/dataset.py | 5 +- {RecLM-cgen => RecLM-uni}/train_utils/loss.py | 0 .../train_utils/metrics.py | 0 .../train_utils/model.py | 10 + .../train_utils/param.py | 0 .../train_utils/processor.py | 0 .../train_utils/template.py | 0 .../train_utils/utils.py | 41 ++ {RecLM-cgen => RecLM-uni}/trainer.py | 14 + .../unirec/asyc_server.py | 0 .../unirec/config/base.yaml | 0 .../unirec/config/model/SASRec.yaml | 0 47 files changed, 2738 insertions(+), 295 deletions(-) delete mode 100644 RecLM-cgen/README.md rename {RecLM-cgen => RecLM-uni}/.gitignore (100%) create mode 100644 RecLM-uni/GRPO/rl_dataset.py create mode 100755 RecLM-uni/GRPO/run_GRPO.sh create mode 100644 RecLM-uni/PLUGIN_CONFIG.md create mode 100644 RecLM-uni/README.md rename {RecLM-cgen => RecLM-uni}/accelerate.yaml (100%) rename {RecLM-cgen => RecLM-uni}/cli_serve.py (100%) create mode 100644 RecLM-uni/grpo_train.py create mode 100644 RecLM-uni/index/datasets.py create mode 100644 RecLM-uni/index/generate_indices.py create mode 100644 RecLM-uni/index/main.py create mode 100644 RecLM-uni/index/models/layers.py create mode 100644 RecLM-uni/index/models/rq.py create mode 100644 RecLM-uni/index/models/rqvae.py create mode 100644 RecLM-uni/index/models/vq.py create mode 100644 RecLM-uni/index/trainer.py create mode 100644 RecLM-uni/index/utils.py rename {RecLM-cgen => RecLM-uni}/main.py (100%) create mode 100644 RecLM-uni/plugin.env.example create mode 100644 RecLM-uni/plugin.py rename {RecLM-cgen => RecLM-uni}/preprocess/data_preprocess_amazon.py (100%) rename {RecLM-cgen => RecLM-uni}/preprocess/transform2unirec.py (100%) rename {RecLM-cgen => RecLM-uni}/requirements.txt (100%) rename {RecLM-cgen => RecLM-uni}/scripts/data_preprocess_amazon.sh (100%) rename {RecLM-cgen => RecLM-uni}/scripts/run_SFT_merge.sh (100%) rename {RecLM-cgen => RecLM-uni}/scripts/train_RecLM_cgen.sh (100%) rename {RecLM-cgen => RecLM-uni}/scripts/train_RecLM_ret.sh (100%) rename {RecLM-cgen => 
RecLM-uni}/scripts/unirec_serve.sh (100%) rename {RecLM-cgen => RecLM-uni}/scripts/unirec_train.sh (100%) rename {RecLM-cgen => RecLM-uni}/task_MR_test.py (100%) rename {RecLM-cgen => RecLM-uni}/task_test.py (100%) create mode 100644 RecLM-uni/task_test_tokenizer.py rename {RecLM-cgen => RecLM-uni}/train_utils/__init__.py (100%) rename {RecLM-cgen => RecLM-uni}/train_utils/dataset.py (99%) rename {RecLM-cgen => RecLM-uni}/train_utils/loss.py (100%) rename {RecLM-cgen => RecLM-uni}/train_utils/metrics.py (100%) rename {RecLM-cgen => RecLM-uni}/train_utils/model.py (97%) rename {RecLM-cgen => RecLM-uni}/train_utils/param.py (100%) rename {RecLM-cgen => RecLM-uni}/train_utils/processor.py (100%) rename {RecLM-cgen => RecLM-uni}/train_utils/template.py (100%) rename {RecLM-cgen => RecLM-uni}/train_utils/utils.py (87%) rename {RecLM-cgen => RecLM-uni}/trainer.py (98%) rename {RecLM-cgen => RecLM-uni}/unirec/asyc_server.py (100%) rename {RecLM-cgen => RecLM-uni}/unirec/config/base.yaml (100%) rename {RecLM-cgen => RecLM-uni}/unirec/config/model/SASRec.yaml (100%) diff --git a/README.md b/README.md index 47c4a17..02a9b60 100644 --- a/README.md +++ b/README.md @@ -186,13 +186,16 @@ And corresponding paper in the subfolder: } ``` -#### RecLM-cgen: +#### RecLM-uni: ``` -@article{liao2025avoid, - title={Avoid Recommending Out-of-Domain Items: Constrained Generative Recommendation with LLMs}, - author={Liao, Hao and Lu, Wensheng and Lian, Jianxun and Wu, Mingqi and Wang, Shuo and Zhang, Yong and Huang, Yitian and Zhou, Mingyang and Xie, Xing}, - journal={arXiv preprint arXiv:2505.03336} - year={2025}, +@misc{liao2026eliminatingoutofdomainrecommendationsllmbased, + title={Eliminating Out-of-Domain Recommendations in LLM-based Recommender Systems: A Unified View}, + author={Hao Liao and Jiwei Zhang and Jianxun Lian and Wensheng Lu and Mingqi Wu and Shuo Wang and Yong Zhang and Yitian Huang and Mingyang Zhou and Rui Mao}, + year={2026}, + eprint={2505.03336}, + archivePrefix={arXiv}, + primaryClass={cs.IR}, + url={https://arxiv.org/abs/2505.03336}, } ``` diff --git a/RecLM-cgen/README.md b/RecLM-cgen/README.md deleted file mode 100644 index 74460d0..0000000 --- a/RecLM-cgen/README.md +++ /dev/null @@ -1,288 +0,0 @@ - -# RecLM-cgen -## Introduction -This project introduces methods for avoid recommending out-of-domain items in LLM-based recsys. It contains the code for implementing two methods in (arXiv preprint arXiv:2505.03336), i.e., RecLM-cgen and RecLM-ret. - -**RecLM-cgen** is a generative recommendation framework in the native structure of LLMs. This framework divides the output space of LLMs into item generation and general text generation parts by introducing item control tokens, and simultaneously employs a decoding strategy with prefix tree constraints to prevent the generation of out-of-domain items. RecLM-cgen enables LLMs to acquire the ability to recommend products without sacrificing their original general capabilities. - -The RecLM-cgen framework seamlessly integrates LLMs with recommendation scenarios. Interacting with RecLM-cgen is just like interacting with general LLMs, enabling users to complete recommendation tasks and other general tasks in multi-round conversations. - -The pipeline of RecLM-cgen has 4 steps: -1. Preprocessing raw dataset (Section 1) -2. Training teacher model (Section 2.3) -3. Deploying teacher model service (Section 2.4) -4. 
Training RecLM-cgen (Section 3.1) - -This project is mainly contributed by College of Computer Science and Software Engineering, Shenzhen University. - -Our implementation leverages the [`transformers`](https://github.com/huggingface/transformers) library by Hugging Face. - -## 1. Raw dataset preprocess -We provide the code in `preprocess/data_preprocess_amazon.py` to automatically generate the intermediate dataset with above format from the downloaded raw dataset. - -Firstly, download `Movies_and_TV_5.json.gz` and `meta_Movies_and_TV.json.gz` from [Amazon](https://cseweb.ucsd.edu/~jmcauley/datasets/amazon_v2/), then place them in `data/dataset/movies/` and run the next command. - -Then, change the data path and dataset full name in [./scripts/data_preprocess_amazon.sh](scripts/data_preprocess_amazon.sh). -```shell -TOKENIZER_PATH="meta-llama/Meta-Llama-3-8B-Instruct" -DATASET_FULL_NAME="Movies_and_TV" -DATASET_NAME="movies" # used for selecting dataset in subsequent experiments. -DATA_PATH="./data/dataset/${DATASET_NAME}/" -UNIREC_DATA_PATH="./unirec/data/${DATASET_NAME}/" -UNIREC_CONFIG_PATH="./unirec/config/dataset/${DATASET_NAME}.yaml" -``` -After that, run the command `./scripts/data_preprocess_amazon.sh` to generate the intermediate dataset. - - -### Intermediate dataset format - -To use this repo, you'll need an intermediate dataset comprising at least three files located in data_path: `category.jsonl`, `metas.jsonl`, and `sequential.jsonl`. -You can prepare your own dataset in this format to train the model. - -**A volunteer has prepared a copy of data for reproducing the experiments. You can download it from [Google Drive link](https://drive.google.com/file/d/1jZMa0Sx-zVccCpkep5KiY6VXoOdl6PCl/view?usp=drive_link), and place each file of it in the respective path. Thanks [Luuuk12321](https://github.com/Luuuk12321)!** - -#### category.jsonl -This file contains a dictionary where the keys are category names, and the values are lists of item IDs belonging to those categories. -```json -{ - "category_1": ["item_id_1", "..."], - "category_2": ["item_id_i", "..."], - "...": "...", - "category_k": ["item_id_j", "..."] -} -``` -#### metas.jsonl -This file contains a dictionary where the keys are item IDs, and the values are dictionaries with at least one field of item index. This field is used for prefix tree construction (such as `title` or `title_t`). -```json -{ - "item_id_1": {"title": "...", "title_t": "...", "description": "..."}, - "item_id_2": {"title": "...", "title_t": "...", "description": "..."}, - "...": "...", - "item_id_n": {"title": "...", "title_t": "...", "description": "..."} -} -``` - -#### sequential.jsonl -This file contains a dictionary where the keys are user IDs, and the values are lists of item IDs that represent the user's historical interactions in a time-dependent order. - -```json -{ - "user_id_1": ["item_id_1", "...", "item_id_x"], - "...": "...", - "user_id_m": ["item_id_1", "...", "item_id_y"] -} -``` - - -## 2. SASRec Server -We utilize the [UniRec](https://github.com/microsoft/UniRec) library to implement the SASRec teacher model and deploy as a server. - -### 2.1. 
Install UniRec - -Clone the UniRec repository and install the necessary packages: - -```shell -git clone https://github.com/microsoft/UniRec.git -pip install --user --upgrade setuptools wheel twine -``` - -Modify the `unirec/setup.py` file to update the `torch` dependency: - -```python -install_requires = [ - "torch>=1.10.0,<=1.13.1" # Change this line to the one below - # "torch>=1.10.0,<=2.1.2", - "..." -] -``` - -Continue with the installation: - -```shell -cd UniRec -python setup.py sdist bdist_wheel -pip install dist/unirec-*.whl -``` - -### 2.2. Unirec dataset for SASRec model training -You need the dataset files `train.pkl`, `valid.pkl`, `test.pkl`, `user_history.pkl`, `map.pkl`, and `category.jsonl` to train SASRec model with UniRec library. - -1. After running of `./scripts/data_preprocess_amazon.sh`, these files will be placed in `./unirec/data/movies/`. - -2. If you had prepared the intermediate dataset, these files will be automatically generated according to the intermediate dataset in `./data/dataset/movies/`. - - -### 2.3. SASRec model training - -Train the model by specifying the dataset name (e.g., `movies`): - -```shell -./scripts/unirec_train.sh movies -``` -Model parameters and weights are saved in `./unirec/output/`. - -### 2.4. SASRec service deploying - -Update the `MODEL_PATH` and `DATASET_NAME` in [./scripts/unirec_serve.sh](./scripts/unirec_serve.sh) to point to the model files: - -```python -DATASET_NAME="movies" -MODEL_PATH="./unirec/output/movies/SASRec/train/checkpoint_.../SASRec-SASRec-movies.pth" -``` - -Start the server by specifying the serve port(`2068`): - -```shell -./scripts/unirec_serve.sh 2068 -``` - - -## 3. SFT stage - -### 3.1. SFT train - -The training dataset is dynamically generated during the `__getitem__` function call of the dataset class. An example script for training can be found at [./scripts/train_RecLM_cgen.sh](scripts/train_RecLM_cgen.sh) for **RecLM-cgen** and [./scripts/train_RecLM_ret.sh](scripts/train_RecLM_ret.sh) for **RecLM-ret**. -```shell -./scripts/train_RecLM_cgen.sh movies # RecLM-cgen -./scripts/train_RecLM_ret.sh movies # RecLM-ret -``` - -### 3.2. SFT model merge - -Merge the trained models using the script found at [./scripts/run_SFT_merge.sh](scripts/run_SFT_merge.sh). The merged model will be saved to `snap/.../SFT_Epoch20/`. -```shell -./scripts/run_SFT_merge.sh -``` - -## 4. RecLM-cgen testing - -### 4.1. Recommendation testing -```shell -python task_test.py \ ---data_path data/dataset/movies/ \ ---SFT_test_task SFTTestSeqRec-MR \ ---model_name snap/.../SFT_Epoch20/ \ ---gpu cuda:0 \ ---use_control_symbol \ ---batch_size 16 \ ---use_CBS \ ---CBS_type 2 \ ---topk 10 \ ---idx - -# setting --data_path to `data/dataset/toys/` for cross-domain evaluation. -``` - -### 4.2. Multi-round conversation testing -```shell -python task_MR_test.py \ ---data_path data/dataset/movies/ \ ---SFT_test_task SFTTestSeqRec-CS-MR \ ---model_name snap/.../SFT_Epoch20/ \ ---gpu cuda:0 \ ---use_control_symbol \ ---batch_size 8 \ ---use_CBS \ ---CBS_type 2 \ ---topk 10 \ ---idx -``` - -### 4.3. SFT model deploying -```shell -python cli_serve.py \ ---model_name snap/.../SFT_Epoch20/ \ ---gpu cuda:0 -``` - -## 5. RecLM-ret testing - -### 5.1. 
Recommendation testing -```shell -python main.py \ ---seed 0 \ ---data_path data/dataset/movies/ \ ---SFT_test_task SFTTestSeqRec-MR \ ---gpu cuda:0 \ ---use_control_symbol \ ---test_batch_size 8 \ ---topk 10 \ ---item_index title_t \ ---idx \ ---gen_max_length 512 \ ---max_token_length 1024 \ ---train_stage SFT_Embedding_Test \ ---SFT_actor_lora_r 16 \ ---SFT_actor_lora_a 8 \ ---chat_template llama-3 \ ---FA2 \ ---backbone meta-llama/Meta-Llama-3-8B-Instruct \ ---embedding_model BAAI/bge-m3 \ ---SFT_load snap/.../Epoch20_SFT_Embedding -``` - -### 5.2. Multi-round conversation testing -```shell -python main.py \ ---seed 0 \ ---data_path data/dataset/movies/ \ ---SFT_test_task SFTTestSeqRec-CS-MR \ ---gpu cuda:0 \ ---use_control_symbol \ ---test_batch_size 8 \ ---topk 10 \ ---item_index title_t \ ---idx \ ---gen_max_length 512 \ ---max_token_length 1024 \ ---train_stage SFT_Embedding_Test \ ---SFT_actor_lora_r 16 \ ---SFT_actor_lora_a 8 \ ---chat_template llama-3 \ ---FA2 \ ---backbone meta-llama/Meta-Llama-3-8B-Instruct \ ---embedding_model BAAI/bge-m3 \ ---SFT_load snap/.../Epoch20_SFT_Embedding -``` - -## 6. Build domain item prefix tree for enabling constrained generation -You can customize the recommendation domain and build the domain item prefix tree for enabling constrained generation following the next code. -```python -from train_utils.processor import FastPrefixConstrainedLogitsProcessor, Trie_link -from transformers import AutoTokenizer, AutoModelForCausalLM - -tokenizer = AutoTokenizer.from_pretrained(...) -tokenizer.soi_token_id = xxx # specific a token -tokenizer.eoi_token_id = xxx # specific a token -model = AutoModelForCausalLM.from_pretrained(...) - -in_domain_titles: list[str] = [...] # customized domain titles -item_ids = tokenizer.batch_encode_plus(in_domain_titles).data['input_ids'] - -num_beams = 1 -# create prefix tree -item_prefix_tree = Trie_link(item_ids, tokenizer) -# create logit processor base on prefix tree -processor = FastPrefixConstrainedLogitsProcessor( - item_prefix_tree.constrain_search_list, - num_beams -) - -output = model.generate( - ..., - logits_processor=[processor], - num_beams=num_beams -) -``` - -## Citation -If you find this project useful in your research, please cite our research paper: - -``` -@article{liao2025avoid, - title={Avoid Recommending Out-of-Domain Items: Constrained Generative Recommendation with LLMs}, - author={Liao, Hao and Lu, Wensheng and Lian, Jianxun and Wu, Mingqi and Wang, Shuo and Zhang, Yong and Huang, Yitian and Zhou, Mingyang and Xie, Xing}, - journal={arXiv preprint arXiv:2505.03336} - year={2025}, -} -``` diff --git a/RecLM-cgen/.gitignore b/RecLM-uni/.gitignore similarity index 100% rename from RecLM-cgen/.gitignore rename to RecLM-uni/.gitignore diff --git a/RecLM-uni/GRPO/rl_dataset.py b/RecLM-uni/GRPO/rl_dataset.py new file mode 100644 index 0000000..279a7ac --- /dev/null +++ b/RecLM-uni/GRPO/rl_dataset.py @@ -0,0 +1,126 @@ +import random +from dataclasses import dataclass +from typing import List, Dict, Tuple + +from SFT.SFT_templates import SeqRec_MR_group +from utils import load_json, get_history_text, get_output_text + + +SYSTEM_PROMPT = "You are an expert recommender engine as well as a helpful, respectful and honest assistant." 
+DEFAULT_TEMPLATE_ID = next(iter(SeqRec_MR_group.keys())) + + +@dataclass +class RLSample: + prompt: str + reference_response: str + target_title: str + target_item: str + history_text: str + + +def build_llama3_prompt(user_turn: str) -> str: + """Create a Meta-Llama-3 style prompt with a single user turn.""" + parts = [ + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n", + SYSTEM_PROMPT, + "<|eot_id|>", + "<|start_header_id|>user<|end_header_id|>\n\n", + user_turn, + "<|eot_id|>", + "<|start_header_id|>assistant<|end_header_id|>\n\n", + ] + return "".join(parts) + + +def _sample_history(sequential: List[str], max_item_length: int, rng: random.Random) -> Tuple[List[str], str]: + """Replicate SFT sub sequence sampling strategy.""" + if len(sequential) <= 2: + return [], None + prefix = sequential[:-1] + target_idx = rng.randrange(1, len(prefix)) + min_start = max(0, target_idx - max_item_length) + start_idx = rng.randrange(min_start, target_idx + 1) + history = prefix[start_idx:target_idx] + target = prefix[target_idx] + return history, target + + +def build_rl_samples( + data_path: str, + item_index_field: str, + max_samples: int, + seed: int, + topk: int, + max_item_length: int, + use_control_symbol: bool, + use_idx: bool, +) -> List[RLSample]: + """Create RL samples directly from sequential logs.""" + metas: Dict[str, Dict] = load_json(f"{data_path}metas.jsonl") + sequential: Dict[str, List[str]] = load_json(f"{data_path}sequential.jsonl") + template = SeqRec_MR_group[DEFAULT_TEMPLATE_ID] + rng = random.Random(seed) + valid_items = [iid for iid, meta in metas.items() if meta.get(item_index_field)] + samples: List[RLSample] = [] + + for user, seq in sequential.items(): + if len(samples) >= max_samples: + break + history, target_item = _sample_history(seq, max_item_length, rng) + if not history or target_item is None: + continue + if target_item not in metas: + continue + target_title = metas[target_item].get(item_index_field) + if not target_title: + continue + history_titles = [ + metas[item_id].get(item_index_field) + for item_id in history + if item_id in metas and metas[item_id].get(item_index_field) + ] + if len(history_titles) < 1: + continue + history_text = get_history_text([f"'{title}'" for title in history_titles]) + + # Build synthetic recommendation list: target + random negatives (order shuffled later) + negatives = [] + pool = list(valid_items) + rng.shuffle(pool) + for candidate in pool: + if candidate == target_item: + continue + negatives.append(candidate) + if len(negatives) >= max(topk - 1, 0): + break + output_items = [target_item] + negatives[: max(topk - 1, 0)] + output_titles = [ + metas[item_id].get(item_index_field, "") + for item_id in output_items + if metas[item_id].get(item_index_field) + ] + if not output_titles: + continue + recommendation_text = get_output_text( + output_titles, idx=use_idx, user_control_symbol=use_control_symbol + ) + "\n" + + instruction_fields = {"history": history_text, "item_count": topk} + user_turn = " ".join(template.get_input_text(instruction_fields)) + assistant_turn = " ".join( + template.get_output_text({"item_title_list": recommendation_text}) + ) + + prompt = build_llama3_prompt(user_turn) + samples.append( + RLSample( + prompt=prompt, + reference_response=assistant_turn, + target_title=target_title, + target_item=target_item, + history_text=history_text, + ) + ) + + return samples diff --git a/RecLM-uni/GRPO/run_GRPO.sh b/RecLM-uni/GRPO/run_GRPO.sh new file mode 100755 index 0000000..7f0b9f8 --- /dev/null +++ 
b/RecLM-uni/GRPO/run_GRPO.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +set -euo pipefail + +DATASET="movies" +DATA_PATH="data/${DATASET}/" + + +INITIAL_MODEL="path/to/your/initial/model" +OUTPUT_BASE=./sensitivate2/$(date "+%m%d")-${DATASET}-RL + +mkdir -p "${OUTPUT_BASE}" +CURRENT_MODEL="${INITIAL_MODEL}" + +if [ "$DATASET" = "toys" ]; then + GPU_ID=0,1,2,5 + MAIN_PORT=13355 +elif [ "$DATASET" = "movies" ]; then + GPU_ID=1,2,5,6 + MAIN_PORT=13356 +elif [ "$DATASET" = "steam" ]; then + GPU_ID=0,1,2,5 + MAIN_PORT=13357 +fi + +nohup accelerate launch --gpu_ids $GPU_ID --config_file ./accelerate.yaml grpo_train.py \ + --data_path "${DATA_PATH}" \ + --model_path "${CURRENT_MODEL}" \ + --output_dir "${OUTPUT_BASE}" \ + --topk 10 \ + --rl_max_samples 10000 \ + --num_generations 16 \ + --num_train_epochs 2 \ + --train_batch_size 32 \ + --eval_batch_size 32 \ + --eval_steps 500 \ + --eval_ratio 0.05 \ + --gradient_accumulation_steps 4 \ + --learning_rate 1e-5 \ + --beta 1e-3 \ + --reward_type ranking \ + --use_control_symbol \ + --idx \ + --logging_steps 5 \ + --use_lora \ + --lora_r 16 \ + --lora_alpha 32 \ + --lora_dropout 0.05 > ${OUTPUT_BASE}/output.log 2>&1 & + diff --git a/RecLM-uni/PLUGIN_CONFIG.md b/RecLM-uni/PLUGIN_CONFIG.md new file mode 100644 index 0000000..d75e8f0 --- /dev/null +++ b/RecLM-uni/PLUGIN_CONFIG.md @@ -0,0 +1,168 @@ +# Plugin Configuration Guide + +This document explains how to configure the reward functions in `plugin.py` for item title rewrite training. + +## Environment Variables + +The plugin uses environment variables for configuration. Copy `plugin.env.example` to `plugin.env` and fill in your actual values: + +```bash +cp plugin.env.example plugin.env +# Edit plugin.env with your configuration +source plugin.env # Load variables before running training +``` + +## Configuration Options + +### 1. Model Configuration + +```bash +export MODEL_PATH="meta-llama/Meta-Llama-3-8B-Instruct" +``` +- Path to the LLaMA model used for tokenization + +### 2. Embedding API Service + +```bash +export EMBEDDING_API_URL="http://localhost:8010/v1" +export EMBEDDING_API_KEY="not-needed" +``` +- URL and API key for the embedding service (e.g., BGE-M3) +- Used by `Item2ItemReward` for semantic similarity computation + +### 3. vLLM Endpoints + +```bash +# Option 1: Single set of endpoints for all data sources +export VLLM_ENDPOINTS="http://localhost:8020/v1,http://localhost:8021/v1" + +# Option 2: Separate endpoints per data source +export VLLM_ENDPOINTS_STEAM="http://localhost:8020/v1" +export VLLM_ENDPOINTS_MOVIES="http://localhost:8021/v1" +export VLLM_ENDPOINTS_TOYS="http://localhost:8022/v1" +``` +- vLLM server endpoints for LLM inference +- Multiple endpoints provide automatic failover +- Used by `ConditionalPPL` and `DiscriminativeReward` + +### 4. Embedding Service Ports + +```bash +export EMBEDDING_SERVICE_BASE_URL="http://localhost:5000" +export EMBEDDING_PORT_STEAM="5003" +export EMBEDDING_PORT_MOVIES="5004" +export EMBEDDING_PORT_TOYS="5005" +``` +- Different ports for different dataset embedding services +- Used by `User2ItemReward` for user preference embedding + +### 5. Data Directory + +```bash +export EMBEDDING_DATA_DIR="./data/embeddings" +``` +- Directory containing pre-computed item embeddings in JSONL format +- Expected files: + - `{source}_all_item_embedding.jsonl` - Item description embeddings + - `{source}_all_item_embedding_t-desc.jsonl` - User history embeddings + +### 6. 
Performance Settings + +```bash +export MAX_WORKERS="8" +``` +- Number of parallel threads for API requests +- Adjust based on your system capabilities + +## Setting Up Services + +### vLLM Server Setup + +Start vLLM servers for inference: + +```bash +# Example for movies dataset +vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ + --port 8020 \ + --host 0.0.0.0 \ + --tensor-parallel-size 1 +``` + +### Embedding Service Setup + +Start embedding service (e.g., using FlagAI or similar): + +```bash +# Example for BGE-M3 +python -m flagai.embedding_server \ + --model_name BAAI/bge-m3 \ + --port 8010 +``` + +## Preparing Embedding Files + +Generate embedding files in JSONL format: + +```jsonl +{"title": "Item 1", "info": {"embedding": [0.1, 0.2, ...]}} +{"title": "Item 2", "info": {"embedding": [0.3, 0.4, ...]}} +``` + +Example script to generate embeddings: + +```python +import json +import numpy as np +from sentence_transformers import SentenceTransformer + +model = SentenceTransformer('BAAI/bge-m3') +items = load_items_from_metas_jsonl() + +with open('./data/embeddings/movies_all_item_embedding.jsonl', 'w') as f: + for item in items: + text = f"{item['title']} {item.get('description', '')}" + embedding = model.encode(text).tolist() + f.write(json.dumps({"title": item['title'], "info": {"embedding": embedding}})) + f.write('\n') +``` + +## Quick Start + +1. Configure environment variables: +```bash +source plugin.env +``` + +2. Start required services (vLLM, embedding service) + +3. Run training: +```bash +swift rl \ + --model_type llama3-8b \ + --dataset path/to/rewrite_data.jsonl \ + --plugin_module plugin \ + --reward_functions average_ppl length discriminative item2item user2item +``` + +## Troubleshooting + +### Connection Errors + +If you see connection errors: +- Verify all services are running: `curl http://localhost:8020/v1/models` +- Check firewall rules allow connections between services +- Ensure environment variables are correctly set + +### Missing Embedding Files + +If you get "embedding file not found" errors: +- Verify `EMBEDDING_DATA_DIR` points to the correct location +- Check that embedding files exist: `{source}_all_item_embedding.jsonl` +- Ensure file format matches the expected JSONL structure + +### Performance Issues + +If training is slow: +- Increase `MAX_WORKERS` (but be mindful of resource limits) +- Use multiple vLLM endpoints with `VLLM_ENDPOINTS` +- Pre-generate and cache embeddings to reduce API calls diff --git a/RecLM-uni/README.md b/RecLM-uni/README.md new file mode 100644 index 0000000..d578438 --- /dev/null +++ b/RecLM-uni/README.md @@ -0,0 +1,471 @@ + +# RecLM-uni +## Introduction +This project introduces methods for avoid recommending out-of-domain items in LLM-based recsys. It contains the code for implementing three methods, i.e., RecLM-cgen, RecLM-ret and RecLM-token. + +**RecLM-uni** is a generative recommendation framework in the native structure of LLMs. This framework divides the output space of LLMs into item generation and general text generation parts by introducing item control tokens, and simultaneously employs a decoding strategy with prefix tree constraints to prevent the generation of out-of-domain items. RecLM-uni enables LLMs to acquire the ability to recommend products without sacrificing their original general capabilities. + +The RecLM-uni framework seamlessly integrates LLMs with recommendation scenarios. 
Interacting with RecLM-uni is just like interacting with general LLMs, enabling users to complete recommendation tasks and other general tasks in multi-round conversations. + +The pipeline of RecLM-uni has 4 steps: +1. Preprocessing raw dataset +2. Training teacher model +3. Deploying teacher model service +4. Training RecLM-uni + +This project is mainly contributed by College of Computer Science and Software Engineering, Shenzhen University. + +Our implementation leverages the [`transformers`](https://github.com/huggingface/transformers) library by Hugging Face. + +## 1. Raw dataset preprocess +We provide the code in `preprocess/data_preprocess_amazon.py` to automatically generate the intermediate dataset with above format from the downloaded raw dataset. + +Firstly, download `Movies_and_TV_5.json.gz` and `meta_Movies_and_TV.json.gz` from [Amazon](https://cseweb.ucsd.edu/~jmcauley/datasets/amazon_v2/), then place them in `data/dataset/movies/` and run the next command. + +Then, change the data path and dataset full name in [./scripts/data_preprocess_amazon.sh](scripts/data_preprocess_amazon.sh). +```shell +TOKENIZER_PATH="meta-llama/Meta-Llama-3-8B-Instruct" +DATASET_FULL_NAME="Movies_and_TV" +DATASET_NAME="movies" # used for selecting dataset in subsequent experiments. +DATA_PATH="./data/dataset/${DATASET_NAME}/" +UNIREC_DATA_PATH="./unirec/data/${DATASET_NAME}/" +UNIREC_CONFIG_PATH="./unirec/config/dataset/${DATASET_NAME}.yaml" +``` +After that, run the command `./scripts/data_preprocess_amazon.sh` to generate the intermediate dataset. + + +### Intermediate dataset format + +To use this repo, you'll need an intermediate dataset comprising at least three files located in data_path: `category.jsonl`, `metas.jsonl`, and `sequential.jsonl`. +You can prepare your own dataset in this format to train the model. + +**A volunteer has prepared a copy of data for reproducing the experiments. You can download it from [Google Drive link](https://drive.google.com/file/d/1jZMa0Sx-zVccCpkep5KiY6VXoOdl6PCl/view?usp=drive_link), and place each file of it in the respective path. Thanks [Luuuk12321](https://github.com/Luuuk12321)!** + +#### category.jsonl +This file contains a dictionary where the keys are category names, and the values are lists of item IDs belonging to those categories. +```json +{ + "category_1": ["item_id_1", "..."], + "category_2": ["item_id_i", "..."], + "...": "...", + "category_k": ["item_id_j", "..."] +} +``` +#### metas.jsonl +This file contains a dictionary where the keys are item IDs, and the values are dictionaries with at least one field of item index. This field is used for prefix tree construction (such as `title` or `title_t`). +```json +{ + "item_id_1": {"title": "...", "title_t": "...", "description": "..."}, + "item_id_2": {"title": "...", "title_t": "...", "description": "..."}, + "...": "...", + "item_id_n": {"title": "...", "title_t": "...", "description": "..."} +} +``` + +#### sequential.jsonl +This file contains a dictionary where the keys are user IDs, and the values are lists of item IDs that represent the user's historical interactions in a time-dependent order. + +```json +{ + "user_id_1": ["item_id_1", "...", "item_id_x"], + "...": "...", + "user_id_m": ["item_id_1", "...", "item_id_y"] +} +``` + + +## 2. SASRec Server +We utilize the [UniRec](https://github.com/microsoft/UniRec) library to implement the SASRec teacher model and deploy as a server. + +### 2.1. 
Install UniRec
+
+Clone the UniRec repository and install the necessary packages:
+
+```shell
+git clone https://github.com/microsoft/UniRec.git
+pip install --user --upgrade setuptools wheel twine
+```
+
+Modify the `unirec/setup.py` file to update the `torch` dependency:
+
+```python
+install_requires = [
+    "torch>=1.10.0,<=1.13.1"  # Change this line to the one below
+    # "torch>=1.10.0,<=2.1.2",
+    "..."
+]
+```
+
+Continue with the installation:
+
+```shell
+cd UniRec
+python setup.py sdist bdist_wheel
+pip install dist/unirec-*.whl
+```
+
+### 2.2. UniRec dataset for SASRec model training
+You need the dataset files `train.pkl`, `valid.pkl`, `test.pkl`, `user_history.pkl`, `map.pkl`, and `category.jsonl` to train the SASRec model with the UniRec library.
+
+1. After running `./scripts/data_preprocess_amazon.sh`, these files will be placed in `./unirec/data/movies/`.
+
+2. If you have prepared the intermediate dataset yourself, these files will be generated automatically from the intermediate dataset in `./data/dataset/movies/`.
+
+
+### 2.3. SASRec model training
+
+Train the model by specifying the dataset name (e.g., `movies`):
+
+```shell
+./scripts/unirec_train.sh movies
+```
+Model parameters and weights are saved in `./unirec/output/`.
+
+### 2.4. SASRec service deploying
+
+Update the `MODEL_PATH` and `DATASET_NAME` in [./scripts/unirec_serve.sh](./scripts/unirec_serve.sh) to point to the model files:
+
+```python
+DATASET_NAME="movies"
+MODEL_PATH="./unirec/output/movies/SASRec/train/checkpoint_.../SASRec-SASRec-movies.pth"
+```
+
+Start the server by specifying the serving port (e.g., `2068`):
+
+```shell
+./scripts/unirec_serve.sh 2068
+```
+
+
+## 3. SFT stage
+
+### 3.1. SFT train
+
+The training dataset is dynamically generated during the `__getitem__` function call of the dataset class. An example script for training can be found at [./scripts/train_RecLM_cgen.sh](scripts/train_RecLM_cgen.sh) for **RecLM-cgen** and [./scripts/train_RecLM_ret.sh](scripts/train_RecLM_ret.sh) for **RecLM-ret**. The training script for **RecLM-token** is shown in Section 4.4.
+```shell
+./scripts/train_RecLM_cgen.sh movies  # RecLM-cgen
+./scripts/train_RecLM_ret.sh movies   # RecLM-ret
+```
+
+### 3.2. SFT model merge
+
+Merge the trained models using the script found at [./scripts/run_SFT_merge.sh](scripts/run_SFT_merge.sh). The merged model will be saved to `snap/.../SFT_Epoch20/`.
+```shell
+./scripts/run_SFT_merge.sh
+```
+
+## 4. RecLM-token
+
+**RecLM-token** uses RQ-VAE to encode items as sequences of special codebook tokens, enabling semantic item representation while maintaining generation constraints.
+
+### 4.1. Step 1: Prepare Item Embeddings
+
+Generate item description embeddings by extracting the last hidden layer from LLaMA-3-8B and applying attention-masked weighted average pooling. Each item's text (title + description from `metas.jsonl`) should be encoded into a fixed-dimensional embedding vector.
+
+The embeddings should be saved as a numpy array with shape `(num_items, hidden_dim)`, where `hidden_dim` is the LLaMA model's hidden size. The output file should be named `{dataset}.emb-llama-td.npy` and placed in the data directory (e.g., `data/dataset/movies/movies.emb-llama-td.npy`).
+
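+A minimal sketch of this step is shown below. It assumes the Hugging Face `transformers` API and the `metas.jsonl` format described above; the pooling details, paths, and field names are illustrative and should be adapted to your setup. Note that the row order of the saved array must match the item order later assumed by `generate_indices.py` (dictionary order of `metas.jsonl`, or `emb_idx` if present).
+
+```python
+import json
+
+import numpy as np
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+model = AutoModel.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16).to("cuda:0")
+model.eval()
+
+with open("data/dataset/movies/metas.jsonl") as f:
+    metas = json.load(f)
+
+embeddings = []
+with torch.no_grad():
+    for item_id, meta in metas.items():
+        text = f"{meta.get('title', '')} {meta.get('description', '')}".strip()
+        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
+        hidden = model(**inputs).last_hidden_state            # (1, seq_len, hidden_dim)
+        mask = inputs["attention_mask"].unsqueeze(-1)          # (1, seq_len, 1)
+        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)  # attention-masked average pooling
+        embeddings.append(pooled.float().cpu().numpy()[0])
+
+np.save("data/dataset/movies/movies.emb-llama-td.npy", np.stack(embeddings))
+```
+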
+### 4.2. Step 2: Train RQ-VAE Model
+
+Train the Residual-Quantized Variational AutoEncoder (RQ-VAE) to learn item codebook mappings using the `index/` module from RecLM-LC1 or RecLM-cgen:
+
+```bash
+cd index
+
+python main.py \
+    --lr 1e-3 \
+    --epochs 10000 \
+    --batch_size 1024 \
+    --weight_decay 1e-4 \
+    --lr_scheduler_type linear \
+    --dropout_prob 0.0 \
+    --bn False \
+    --e_dim 32 \
+    --quant_loss_weight 1.0 \
+    --contrastive_loss_weight 0.1 \
+    --beta 0.25 \
+    --num_emb_list 256 256 256 256 \
+    --sk_epsilons 0.0 0.0 0.0 0.003 \
+    --layers 2048 1024 512 256 128 64 \
+    --device cuda:0 \
+    --data_path ../../RecAI/RecLM-uni/data/dataset/movies/movies.emb-llama-td.npy \
+    --ckpt_dir ./ckpt/movies/
+```
+
+**Key Parameters:**
+- `--num_emb_list`: Codebook sizes per layer (e.g., `256 256 256 256` = 4 layers with 256 codes each)
+- `--sk_epsilons`: Sinkhorn algorithm epsilons for collision avoidance (enabled on the last layer)
+- `--e_dim`: Codebook embedding dimension
+
+**Output:** Model checkpoint saved to `./ckpt/movies/{timestamp}/best_collision_model.pth`
+
+### 4.3. Step 3: Generate Item Indices
+
+Generate the `.index.json` and `.item2id` files from the trained RQ-VAE model:
+
+```bash
+python generate_indices.py \
+    --dataset movies \
+    --ckpt_path ./ckpt/movies/best_collision_model.pth \
+    --output_dir ../../RecAI/RecLM-uni/data/dataset/movies/ \
+    --metas_file ../../RecAI/RecLM-uni/data/dataset/movies/metas.jsonl \
+    --device cuda:0
+```
+
+**Output Files:**
+- `movies.index.json`: Maps indices to code token sequences (e.g., `"0": ["<a_..>", "<b_..>", "<c_..>", "<d_..>"]`)
+- `movies.item2id`: Maps item IDs to indices (e.g., `B00001234 0`)
+
+Return to the RecLM-uni directory after generating the files.
+
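+For reference, a minimal sketch of how the two generated files can be consumed is shown below; the file names follow this section, while the concrete code token strings depend on your RQ-VAE run.
+
+```python
+import json
+
+dataset = "movies"
+data_dir = "data/dataset/movies/"
+
+# {"0": ["<a_..>", "<b_..>", ...], "1": [...], ...}
+with open(f"{data_dir}{dataset}.index.json") as f:
+    index2tokens = json.load(f)
+
+# Each line is "<item_id> <index>", e.g. "B00001234 0"
+item2id = {}
+with open(f"{data_dir}{dataset}.item2id") as f:
+    for line in f:
+        item_id, idx = line.split()
+        item2id[item_id] = idx
+
+def item_token_sequence(item_id: str) -> str:
+    """Concatenate an item's code tokens into a single string."""
+    return "".join(index2tokens[item2id[item_id]])
+```
+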
+### 4.4. Step 4: Train SFT Model with Codebook
+
+> **Prerequisite**: Before training the SFT model, ensure that the SASRec teacher model has been trained and deployed as a service (follow the steps in Sections 2.2-2.4). The training script will connect to the teacher service at the port specified by `--teacher_port`.
+
+Train the recommendation model using the codebook representation. The codebase will **automatically detect** and use the codebook files if they exist in the data directory.
+
+```bash
+python main.py \
+    --seed 0 \
+    --data_path data/dataset/movies/ \
+    --backbone meta-llama/Meta-Llama-3-8B-Instruct \
+    --item_index rq_token_seq \
+    --train_stage SFT \
+    --SFT_train_tasks SFTSeqRec-CS-MR \
+    --SFT_val_tasks SFTTestSeqRec-MR \
+    --multi_round_ratio 0.1 \
+    --use_control_symbol \
+    --use_CBS \
+    --CBS_type 2 \
+    --batch_size 2 \
+    --topk 10 \
+    --epoch 20 \
+    --lr 0.0001 \
+    --gradient_accumulation_steps 32 \
+    --SFT_actor_lora_r 16 \
+    --SFT_actor_lora_a 8 \
+    --chat_template llama-3 \
+    --teacher_port 2609 \
+    --FA2 \
+    --backup_ip 0.0.0.0 \
+    --loss_type 3 \
+    --scope_mask_type 3 \
+    --fl_gamma 2 \
+    --token_emb \
+    --output snap/movies-token/
+```
+
+**Key Configuration:**
+- `--item_index rq_token_seq`: Use the codebook representation (optional, will auto-detect)
+- `--use_control_symbol`: Enable the start-of-item / end-of-item control markers around each recommended item
+- `--use_CBS`: Enable Constrained Beam Search for valid token sequences
+
+**Note:** The codebook files (`{dataset}.index.json` and `{dataset}.item2id`) must be present in the `data_path` directory. If they are missing, the model will automatically fall back to the text-based item representation (`title_t`).
+
+### 4.5. Step 5: GRPO Training
+
+After SFT training, you can apply GRPO to further improve recommendation quality using reinforcement learning. This step uses the SFT model as the initial policy and optimizes ranking-based rewards.
+
+**Prerequisite**: Complete the SFT model merge (Section 3.2) to obtain the merged model checkpoint.
+
+Update the configuration in [./GRPO/run_GRPO.sh](./GRPO/run_GRPO.sh):
+
+```bash
+DATASET="movies"
+DATA_PATH="data/${DATASET}/"
+INITIAL_MODEL="snap/movies-token/SFT_Epoch20"  # Path to merged SFT model
+OUTPUT_BASE=./output/$(date "+%m%d")-${DATASET}-RL
+```
+
+Then launch GRPO training:
+
+```bash
+./GRPO/run_GRPO.sh
+```
+
+**Key GRPO Parameters:**
+- `--num_generations 16`: Number of generations per prompt for advantage estimation
+- `--reward_type ranking`: Use a ranking-based reward (hit rate, NDCG)
+- `--beta 1e-3`: KL divergence coefficient for the policy constraint
+- `--use_lora`: Apply LoRA for efficient parameter updates
+- `--rl_max_samples 10000`: Maximum number of RL training samples
+
+The GRPO-trained model will be saved to `${OUTPUT_BASE}` and can be used for evaluation in the next step.
+
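+For reference, the `ranking` reward optimized in this step follows the NDCG-style scoring implemented in `grpo_train.py` (`make_ndcg_reward`). A simplified sketch of the per-completion score, with the completion-parsing step omitted, looks like this:
+
+```python
+import math
+from typing import List
+
+def ranking_reward(predicted_titles: List[str], target_title: str) -> float:
+    """NDCG-style reward: larger when the ground-truth item appears earlier in the generated list."""
+    titles = [t.strip().lower() for t in predicted_titles]
+    target = target_title.strip().lower()
+    if target not in titles:
+        return 0.0
+    rank = titles.index(target)        # 0-based position in the recommendation list
+    return 1.0 / math.log2(rank + 2)   # rank 0 -> 1.0, rank 1 -> ~0.63, ...
+```
+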
+### 4.6. Step 6: RecLM-token Evaluation
+
+#### Recommendation Testing
+
+```bash
+python task_test_tokenizer.py \
+    --data_path data/dataset/movies/ \
+    --SFT_test_task SFTTestSeqRec-MR \
+    --model_name snap/movies-token/SFT_Epoch20/ \
+    --gpu cuda:0 \
+    --item_index rq_token_seq \
+    --use_control_symbol \
+    --batch_size 72 \
+    --use_CBS \
+    --CBS_type 2 \
+    --idx \
+    --topk 10
+```
+
+**Key Difference from RecLM-cgen:**
+- **RecLM-cgen**: Items represented as text (e.g., `"The Dark Knight"`)
+- **RecLM-token**: Items represented as codebook token sequences (e.g., `<a_..><b_..><c_..><d_..>`)
+
+## 5. RecLM-cgen (Text-based Item Generation)
+
+**RecLM-cgen** uses item title text directly as the generation target, enabling flexible and interpretable item recommendations.
+
+> **Prerequisite**: Before training RecLM-cgen, you need to train an item title rewrite model using MS-Swift (Section 8). The rewrite model helps improve recommendation quality by generating paraphrased item titles that maintain semantic meaning while varying surface form.
+
+### 5.1. Recommendation testing
+```shell
+python task_test.py \
+--data_path data/dataset/movies/ \
+--SFT_test_task SFTTestSeqRec-MR \
+--model_name snap/.../SFT_Epoch20/ \
+--gpu cuda:0 \
+--use_control_symbol \
+--batch_size 16 \
+--use_CBS \
+--CBS_type 2 \
+--topk 10 \
+--idx
+
+# set --data_path to `data/dataset/toys/` for cross-domain evaluation.
+```
+
+### 5.2. Multi-round conversation testing
+```shell
+python task_MR_test.py \
+--data_path data/dataset/movies/ \
+--SFT_test_task SFTTestSeqRec-CS-MR \
+--model_name snap/.../SFT_Epoch20/ \
+--gpu cuda:0 \
+--use_control_symbol \
+--batch_size 8 \
+--use_CBS \
+--CBS_type 2 \
+--topk 10 \
+--idx
+```
+
+### 5.3. SFT model deploying
+```shell
+python cli_serve.py \
+--model_name snap/.../SFT_Epoch20/ \
+--gpu cuda:0
+```
+
+## 6. RecLM-ret testing
+
+### 6.1. Recommendation testing
+```shell
+python main.py \
+--seed 0 \
+--data_path data/dataset/movies/ \
+--SFT_test_task SFTTestSeqRec-MR \
+--gpu cuda:0 \
+--use_control_symbol \
+--test_batch_size 8 \
+--topk 10 \
+--item_index title_t \
+--idx \
+--gen_max_length 512 \
+--max_token_length 1024 \
+--train_stage SFT_Embedding_Test \
+--SFT_actor_lora_r 16 \
+--SFT_actor_lora_a 8 \
+--chat_template llama-3 \
+--FA2 \
+--backbone meta-llama/Meta-Llama-3-8B-Instruct \
+--embedding_model BAAI/bge-m3 \
+--SFT_load snap/.../Epoch20_SFT_Embedding
+```
+
+### 6.2. Multi-round conversation testing
+```shell
+python main.py \
+--seed 0 \
+--data_path data/dataset/movies/ \
+--SFT_test_task SFTTestSeqRec-CS-MR \
+--gpu cuda:0 \
+--use_control_symbol \
+--test_batch_size 8 \
+--topk 10 \
+--item_index title_t \
+--idx \
+--gen_max_length 512 \
+--max_token_length 1024 \
+--train_stage SFT_Embedding_Test \
+--SFT_actor_lora_r 16 \
+--SFT_actor_lora_a 8 \
+--chat_template llama-3 \
+--FA2 \
+--backbone meta-llama/Meta-Llama-3-8B-Instruct \
+--embedding_model BAAI/bge-m3 \
+--SFT_load snap/.../Epoch20_SFT_Embedding
+```
+
+## 7. Build domain item prefix tree for enabling constrained generation
+You can customize the recommendation domain and build the domain item prefix tree for constrained generation using the following code.
+```python
+from train_utils.processor import FastPrefixConstrainedLogitsProcessor, Trie_link
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+tokenizer = AutoTokenizer.from_pretrained(...)
+tokenizer.soi_token_id = xxx  # specify the start-of-item token id
+tokenizer.eoi_token_id = xxx  # specify the end-of-item token id
+model = AutoModelForCausalLM.from_pretrained(...)
+
+in_domain_titles: list[str] = [...]  # customized domain titles
+item_ids = tokenizer.batch_encode_plus(in_domain_titles).data['input_ids']
+
+num_beams = 1
+# create the prefix tree
+item_prefix_tree = Trie_link(item_ids, tokenizer)
+# create a logits processor based on the prefix tree
+processor = FastPrefixConstrainedLogitsProcessor(
+    item_prefix_tree.constrain_search_list,
+    num_beams
+)
+
+output = model.generate(
+    ...,
+    logits_processor=[processor],
+    num_beams=num_beams
+)
+```
+
+## 8. Item Title Rewrite Training with ms-swift
+
+The item title rewrite model is a crucial component of RecLM-cgen that paraphrases item titles to improve recommendation robustness and diversity. This model is trained with reinforcement learning via the [ms-swift](https://github.com/modelscope/ms-swift) framework and custom reward functions.
+
+The rewrite model learns to generate paraphrased versions of item titles while preserving their semantic meaning. This is achieved through GRPO (Group Relative Policy Optimization) training with multiple reward signals:
+
+- **ConditionalPPL**: Measures the language-model quality of the rewritten title in context
+- **LengthReward**: Encourages an appropriate title length relative to the original
+- **DiscriminativeReward**: Tests whether the rewrite can be correctly matched back to the original item
+- **Item2ItemReward**: Evaluates semantic similarity using embedding-based retrieval
+- **User2ItemReward**: Assesses whether the rewrite maintains user relevance in the recommendation context
+
+All reward functions are implemented in [./plugin.py](./plugin.py).
+
+The paraphrased titles generated by the rewrite model are then used as additional training samples for RecLM-cgen (Section 3.1), improving the model's robustness to variations in item title phrasing and enhancing recommendation diversity.
+
+## 9. 
Citation +If you find this project useful in your research, please cite our research paper: + +``` +@misc{liao2026eliminatingoutofdomainrecommendationsllmbased, + title={Eliminating Out-of-Domain Recommendations in LLM-based Recommender Systems: A Unified View}, + author={Hao Liao and Jiwei Zhang and Jianxun Lian and Wensheng Lu and Mingqi Wu and Shuo Wang and Yong Zhang and Yitian Huang and Mingyang Zhou and Rui Mao}, + year={2026}, + eprint={2505.03336}, + archivePrefix={arXiv}, + primaryClass={cs.IR}, + url={https://arxiv.org/abs/2505.03336}, +} +``` diff --git a/RecLM-cgen/accelerate.yaml b/RecLM-uni/accelerate.yaml similarity index 100% rename from RecLM-cgen/accelerate.yaml rename to RecLM-uni/accelerate.yaml diff --git a/RecLM-cgen/cli_serve.py b/RecLM-uni/cli_serve.py similarity index 100% rename from RecLM-cgen/cli_serve.py rename to RecLM-uni/cli_serve.py diff --git a/RecLM-uni/grpo_train.py b/RecLM-uni/grpo_train.py new file mode 100644 index 0000000..188ff80 --- /dev/null +++ b/RecLM-uni/grpo_train.py @@ -0,0 +1,218 @@ +import argparse +import math +import os +import random +import re +import torch + +from typing import List, Optional + +from datasets import Dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import GRPOConfig, GRPOTrainer + +from GRPO.rl_dataset import build_rl_samples + +try: + from peft import LoraConfig, get_peft_model +except ImportError: + LoraConfig = None + get_peft_model = None + + +def parse_args(): + parser = argparse.ArgumentParser(description="GRPO training on RecLM data") + parser.add_argument("--data_path", type=str, required=True, help="Path to dataset folder, e.g. data/steam/") + parser.add_argument("--model_path", type=str, required=True, help="HF model directory produced by SFT merge.") + parser.add_argument("--output_dir", type=str, required=True, help="Directory to store GRPO checkpoints.") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--topk", type=int, default=10, help="Recommendation list length.") + parser.add_argument("--max_item_length", type=int, default=20, help="History window length.") + parser.add_argument("--rl_max_samples", type=int, default=20000, help="Maximum RL samples to build.") + parser.add_argument("--eval_ratio", type=float, default=0.05, help="Portion of samples for evaluation.") + parser.add_argument("--train_batch_size", type=int, default=32) + parser.add_argument("--eval_batch_size", type=int, default=64) + parser.add_argument("--eval_steps", type=int, default=None, help="Run evaluation every N steps. 
Defaults to logging_steps.") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--num_generations", type=int, default=8) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=1e-5) + parser.add_argument("--beta", type=float, default=0.04, help="KL coefficient.") + parser.add_argument("--max_completion_length", type=int, default=128) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--reward_type", choices=["rule", "ranking"], default="ranking") + parser.add_argument("--logging_steps", type=int, default=10) + parser.add_argument("--bf16", action="store_true") + parser.add_argument("--use_control_symbol", action="store_true") + parser.add_argument("--idx", action="store_true", help="Prefix recommendation list with indices.") + parser.add_argument("--item_index", type=str, default="rq_token_seq", help="Meta field to use as item surface text.") + parser.add_argument("--save_total_limit", type=int, default=5) + parser.add_argument("--use_lora", action="store_true") + parser.add_argument("--lora_r", type=int, default=16) + parser.add_argument("--lora_alpha", type=int, default=32) + parser.add_argument("--lora_dropout", type=float, default=0.05) + parser.add_argument( + "--lora_target_modules", + type=str, + default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj", + help="Comma separated module names to apply LoRA on.", + ) + return parser.parse_args() + + +def normalize_title(title: str) -> str: + return re.sub(r"\s+", " ", title).strip().lower() + + +def extract_generated_items(text: str, use_control_symbol: bool) -> List[str]: + if use_control_symbol: + matches = re.findall(r"\s*(.*?)\s*", text, flags=re.MULTILINE) + if matches: + return [normalize_title(m) for m in matches] + # fallback: split by newline or numbering + lines = [] + for raw_line in text.splitlines(): + clean = raw_line.strip() + if not clean: + continue + clean = re.sub(r"^\d+[\.\)]\s*", "", clean) + lines.append(normalize_title(clean)) + return lines + + +def make_rule_reward(use_control_symbol: bool): + def reward_fn(prompts, completions, target_title, **_): + rewards = [] + for comp, target in zip(completions, target_title): + pred_items = extract_generated_items(comp, use_control_symbol) + target_norm = normalize_title(target) + score = 1.0 if target_norm in pred_items else 0.0 + rewards.append(score) + return rewards + + return reward_fn + + +def make_ndcg_reward(use_control_symbol: bool): + def reward_fn(prompts, completions, target_title, **_): + rewards = [] + for comp, target in zip(completions, target_title): + pred_items = extract_generated_items(comp, use_control_symbol) + target_norm = normalize_title(target) + if target_norm in pred_items: + rank = pred_items.index(target_norm) + rewards.append(1.0 / math.log2(rank + 2)) + else: + rewards.append(0.0) + return rewards + + return reward_fn + + +def main(): + args = parse_args() + random.seed(args.seed) + + samples = build_rl_samples( + data_path=args.data_path, + item_index_field=args.item_index, + max_samples=args.rl_max_samples, + seed=args.seed, + topk=args.topk, + max_item_length=args.max_item_length, + use_control_symbol=args.use_control_symbol, + use_idx=args.idx, + ) + + if len(samples) < 10: + raise RuntimeError("RL sample size too small, please check data or increase rl_max_samples.") + + eval_size = max(1, int(len(samples) * args.eval_ratio)) + train_samples = samples[:-eval_size] + 
eval_samples = samples[-eval_size:] + + def sample_to_dict(sample): + return { + "prompt": sample.prompt, + "reference_response": sample.reference_response, + "target_title": sample.target_title, + "target_item": sample.target_item, + "history": sample.history_text, + } + + train_dataset = Dataset.from_list([sample_to_dict(s) for s in train_samples]) + eval_dataset = Dataset.from_list([sample_to_dict(s) for s in eval_samples]) + + os.makedirs(args.output_dir, exist_ok=True) + + eval_steps = args.eval_steps if args.eval_steps and args.eval_steps > 0 else max(1, args.logging_steps) + + training_args = GRPOConfig( + output_dir=args.output_dir, + per_device_train_batch_size=args.train_batch_size, + per_device_eval_batch_size=args.eval_batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps, + num_generations=args.num_generations, + max_completion_length=args.max_completion_length, + learning_rate=args.learning_rate, + beta=args.beta, + num_train_epochs=args.num_train_epochs, + logging_steps=args.logging_steps, + eval_strategy="steps", + eval_steps=eval_steps, + save_strategy="epoch", + temperature=args.temperature, + bf16=True, + save_total_limit=args.save_total_limit, + report_to=[], + ) + + tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False) + + model_init: Optional[object] + if args.use_lora: + if LoraConfig is None or get_peft_model is None: + raise ImportError("peft is required for LoRA training but was not found.") + base_model = AutoModelForCausalLM.from_pretrained( + args.model_path, + torch_dtype=torch.bfloat16 if args.bf16 else None, + ) + target_modules = [ + module.strip() + for module in args.lora_target_modules.split(",") + if module.strip() + ] + lora_config = LoraConfig( + r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + bias="none", + task_type="CAUSAL_LM", + target_modules=target_modules or None, + ) + model_init = get_peft_model(base_model, lora_config) + try: + model_init.print_trainable_parameters() + except AttributeError: + pass + else: + model_init = args.model_path + + reward_funcs = [make_rule_reward(args.use_control_symbol)] + if args.reward_type == "ranking": + reward_funcs.append(make_ndcg_reward(args.use_control_symbol)) + + trainer = GRPOTrainer( + model=model_init, + reward_funcs=reward_funcs, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=tokenizer, + args=training_args, + ) + + trainer.train() + trainer.save_model(args.output_dir) + +if __name__ == "__main__": + main() diff --git a/RecLM-uni/index/datasets.py b/RecLM-uni/index/datasets.py new file mode 100644 index 0000000..08d6eb9 --- /dev/null +++ b/RecLM-uni/index/datasets.py @@ -0,0 +1,21 @@ +import numpy as np +import torch +import torch.utils.data as data + + +class EmbDataset(data.Dataset): + + def __init__(self,data_path): + + self.data_path = data_path + # self.embeddings = np.fromfile(data_path, dtype=np.float32).reshape(16859,-1) + self.embeddings = np.load(data_path) + self.dim = self.embeddings.shape[-1] + + def __getitem__(self, index): + emb = self.embeddings[index] + tensor_emb=torch.FloatTensor(emb) + return tensor_emb + + def __len__(self): + return len(self.embeddings) diff --git a/RecLM-uni/index/generate_indices.py b/RecLM-uni/index/generate_indices.py new file mode 100644 index 0000000..68c82f7 --- /dev/null +++ b/RecLM-uni/index/generate_indices.py @@ -0,0 +1,216 @@ +import collections +import json +import logging +import argparse + +import numpy as np +import torch +from time import 
time +from torch import optim +from tqdm import tqdm + +from torch.utils.data import DataLoader + +from datasets import EmbDataset +from models.rqvae import RQVAE + +import os + +def check_collision(all_indices_str): + tot_item = len(all_indices_str) + tot_indice = len(set(all_indices_str.tolist())) + return tot_item==tot_indice + +def get_indices_count(all_indices_str): + indices_count = collections.defaultdict(int) + for index in all_indices_str: + indices_count[index] += 1 + return indices_count + +def get_collision_item(all_indices_str): + index2id = {} + for i, index in enumerate(all_indices_str): + if index not in index2id: + index2id[index] = [] + index2id[index].append(i) + + collision_item_groups = [] + + for index in index2id: + if len(index2id[index]) > 1: + collision_item_groups.append(index2id[index]) + + return collision_item_groups + + +def parse_args(): + parser = argparse.ArgumentParser(description="Generate item indices from trained RQ-VAE model") + parser.add_argument("--dataset", type=str, default="Movies", help="Dataset name") + parser.add_argument("--ckpt_path", type=str, default=None, + help="Path to RQ-VAE checkpoint. If not provided, uses ./ckpt/{dataset}/best_collision_model.pth") + parser.add_argument("--output_dir", type=str, default=None, + help="Output directory. If not provided, uses ../data/{dataset}/") + parser.add_argument("--metas_file", type=str, default=None, + help="Path to metas.jsonl file for extracting item_ids. If not provided, uses sequential indices.") + parser.add_argument("--device", type=str, default="cuda:0", help="Device to use") + return parser.parse_args() + + +cmd_args = parse_args() + +# Set default paths if not provided +dataset = cmd_args.dataset +ckpt_path = cmd_args.ckpt_path or f"./ckpt/{dataset}/best_collision_model.pth" +output_dir = cmd_args.output_dir or f"../data/{dataset}/" +device = torch.device(cmd_args.device) + +# Create output directory if needed +os.makedirs(output_dir, exist_ok=True) + +# Output file paths +index_output_file = os.path.join(output_dir, f"{dataset}.index.json") +item2id_output_file = os.path.join(output_dir, f"{dataset}.item2id") + +print(f"Dataset: {dataset}") +print(f"Checkpoint: {ckpt_path}") +print(f"Output directory: {output_dir}") +print(f"Index output: {index_output_file}") +print(f"Item2id output: {item2id_output_file}") + +ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'), weights_only=False) +train_args = ckpt["args"] +state_dict = ckpt["state_dict"] + + +data = EmbDataset(train_args.data_path) + +model = RQVAE(in_dim=data.dim, + num_emb_list=train_args.num_emb_list, + e_dim=train_args.e_dim, + layers=train_args.layers, + dropout_prob=train_args.dropout_prob, + bn=train_args.bn, + loss_type=train_args.loss_type, + quant_loss_weight=train_args.quant_loss_weight, + kmeans_init=train_args.kmeans_init, + kmeans_iters=train_args.kmeans_iters, + sk_epsilons=train_args.sk_epsilons, + sk_iters=train_args.sk_iters, + ) + +model.load_state_dict(state_dict) +model = model.to(device) +model.eval() +print(model) + +data_loader = DataLoader(data, num_workers=train_args.num_workers, + batch_size=64, shuffle=False, + pin_memory=True) + +all_indices = [] +all_indices_str = [] +prefix = ["","","","",""] + +for d in tqdm(data_loader): + d = d.to(device) + indices = model.get_indices(d,use_sk=False) + indices = indices.view(-1, indices.shape[-1]).cpu().numpy() + for index in indices: + code = [] + for i, ind in enumerate(index): + code.append(prefix[i].format(int(ind))) + + all_indices.append(code) + 
all_indices_str.append(str(code)) + # break + +all_indices = np.array(all_indices) +all_indices_str = np.array(all_indices_str) + +for vq in model.rq.vq_layers[:-1]: + vq.sk_epsilon=0.0 +# model.rq.vq_layers[-1].sk_epsilon = 0.005 +if model.rq.vq_layers[-1].sk_epsilon == 0.0: + model.rq.vq_layers[-1].sk_epsilon = 0.003 + +tt = 0 +#There are often duplicate items in the dataset, and we no longer differentiate them +while True: + if tt >= 20 or check_collision(all_indices_str): + break + + collision_item_groups = get_collision_item(all_indices_str) + print(collision_item_groups) + print(len(collision_item_groups)) + for collision_items in collision_item_groups: + d = data[collision_items].to(device) + + indices = model.get_indices(d, use_sk=True) + indices = indices.view(-1, indices.shape[-1]).cpu().numpy() + for item, index in zip(collision_items, indices): + code = [] + for i, ind in enumerate(index): + code.append(prefix[i].format(int(ind))) + + all_indices[item] = code + all_indices_str[item] = str(code) + tt += 1 + + +print("All indices number: ",len(all_indices)) +print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values())) + +tot_item = len(all_indices_str) +tot_indice = len(set(all_indices_str.tolist())) +print("Collision Rate",(tot_item-tot_indice)/tot_item) + +# Prepare item_id list +print("\n" + "="*50) +print("Generating item_id mappings...") + +item_ids = [] +if cmd_args.metas_file and os.path.exists(cmd_args.metas_file): + # Load item_ids from metas.jsonl + print(f"Loading item_ids from: {cmd_args.metas_file}") + with open(cmd_args.metas_file, 'r', encoding='utf-8') as f: + metas = json.load(f) + + # Sort by emb_idx if available, otherwise use dict order + if metas and 'emb_idx' in next(iter(metas.values())): + print("Sorting by 'emb_idx' field") + sorted_items = sorted(metas.items(), key=lambda x: x[1].get('emb_idx', 0)) + item_ids = [item_id for item_id, _ in sorted_items] + else: + print("Using dictionary order (no emb_idx found)") + item_ids = list(metas.keys()) + + if len(item_ids) != len(all_indices): + print(f"WARNING: Metas has {len(item_ids)} items but embeddings have {len(all_indices)} items") + print(f"Using first {len(all_indices)} items from metas") + item_ids = item_ids[:len(all_indices)] +else: + # Use sequential indices as item_ids + print(f"No metas file provided, using sequential indices (0, 1, 2, ...)") + item_ids = [str(i) for i in range(len(all_indices))] + +# Generate .index.json (key: index, value: token list) +all_indices_dict = {} +for idx, indices in enumerate(all_indices.tolist()): + all_indices_dict[str(idx)] = list(indices) + +print(f"\nSaving index file: {index_output_file}") +with open(index_output_file, 'w', encoding='utf-8') as fp: + json.dump(all_indices_dict, fp, indent=2, ensure_ascii=False) + +# Generate .item2id file (format: item_id index) +print(f"Saving item2id file: {item2id_output_file}") +with open(item2id_output_file, 'w', encoding='utf-8') as f: + for idx, item_id in enumerate(item_ids): + f.write(f"{item_id} {idx}\n") + +print("\n" + "="*50) +print("Generation completed successfully!") +print(f"Total items: {len(all_indices)}") +print(f"Index file: {index_output_file}") +print(f"Item2id file: {item2id_output_file}") +print("="*50) \ No newline at end of file diff --git a/RecLM-uni/index/main.py b/RecLM-uni/index/main.py new file mode 100644 index 0000000..98da7cc --- /dev/null +++ b/RecLM-uni/index/main.py @@ -0,0 +1,98 @@ +import argparse +import random +import torch +import numpy as np +from time import 
time +import logging + +from torch.utils.data import DataLoader + +from datasets import EmbDataset +from models.rqvae import RQVAE +from trainer import Trainer + +def parse_args(): + parser = argparse.ArgumentParser(description="Index") + + parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') + parser.add_argument('--epochs', type=int, default=5000, help='number of epochs') + parser.add_argument('--batch_size', type=int, default=2048, help='batch size') + parser.add_argument('--num_workers', type=int, default=4, ) + parser.add_argument('--eval_step', type=int, default=50, help='eval step') + parser.add_argument('--learner', type=str, default="AdamW", help='optimizer') + parser.add_argument('--lr_scheduler_type', type=str, default="constant", help='scheduler') + parser.add_argument('--warmup_epochs', type=int, default=50, help='warmup epochs') + parser.add_argument("--data_path", type=str, + default="../data/Games/Games.emb-llama-td.npy", + help="Input data path.") + + parser.add_argument("--weight_decay", type=float, default=0.0, help='l2 regularization weight') + parser.add_argument("--dropout_prob", type=float, default=0.0, help="dropout ratio") + parser.add_argument("--bn", type=bool, default=False, help="use bn or not") + parser.add_argument("--loss_type", type=str, default="mse", help="loss_type") + parser.add_argument("--kmeans_init", type=bool, default=True, help="use kmeans_init or not") + parser.add_argument("--kmeans_iters", type=int, default=100, help="max kmeans iters") + parser.add_argument('--sk_epsilons', type=float, nargs='+', default=[0.0, 0.0, 0.0], help="sinkhorn epsilons") + parser.add_argument("--sk_iters", type=int, default=50, help="max sinkhorn iters") + + parser.add_argument("--device", type=str, default="cuda:0", help="gpu or cpu") + + parser.add_argument('--num_emb_list', type=int, nargs='+', default=[256,256,256], help='emb num of every vq') + parser.add_argument('--e_dim', type=int, default=32, help='vq codebook embedding size') + parser.add_argument('--quant_loss_weight', type=float, default=1.0, help='vq quantion loss weight') + parser.add_argument('--contrastive_loss_weight', type=float, default=0.0, help='contrastive loss weight') + parser.add_argument('--temperature', type=float, default=0.1, help='contrastive loss temperature') + parser.add_argument("--beta", type=float, default=0.25, help="Beta for commitment loss") + parser.add_argument('--layers', type=int, nargs='+', default=[2048,1024,512,256,128,64], help='hidden sizes of every layer') + + parser.add_argument('--save_limit', type=int, default=5) + parser.add_argument("--ckpt_dir", type=str, default="", help="output directory for model") + + return parser.parse_args() + + +if __name__ == '__main__': + """fix the random seed""" + seed = 2024 + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + args = parse_args() + print("=================================================") + print(args) + print("=================================================") + + logging.basicConfig(level=logging.DEBUG) + + """build dataset""" + data = EmbDataset(args.data_path) + model = RQVAE(in_dim=data.dim, + num_emb_list=args.num_emb_list, + e_dim=args.e_dim, + layers=args.layers, + dropout_prob=args.dropout_prob, + bn=args.bn, + loss_type=args.loss_type, + quant_loss_weight=args.quant_loss_weight, + contrastive_loss_weight=args.contrastive_loss_weight, + 
temperature=args.temperature, + beta=args.beta, + kmeans_init=args.kmeans_init, + kmeans_iters=args.kmeans_iters, + sk_epsilons=args.sk_epsilons, + sk_iters=args.sk_iters, + ) + print(model) + data_loader = DataLoader(data,num_workers=args.num_workers, + batch_size=args.batch_size, shuffle=True, + pin_memory=True) + trainer = Trainer(args,model, len(data_loader)) + best_loss, best_collision_rate = trainer.fit(data_loader) + + print("Best Loss",best_loss) + print("Best Collision Rate", best_collision_rate) + diff --git a/RecLM-uni/index/models/layers.py b/RecLM-uni/index/models/layers.py new file mode 100644 index 0000000..6148d13 --- /dev/null +++ b/RecLM-uni/index/models/layers.py @@ -0,0 +1,108 @@ +import torch +import torch.nn as nn +from torch.nn.init import xavier_normal_ +from sklearn.cluster import KMeans + + +class MLPLayers(nn.Module): + + def __init__( + self, layers, dropout=0.0, activation="relu", bn=False + ): + super(MLPLayers, self).__init__() + self.layers = layers + self.dropout = dropout + self.activation = activation + self.use_bn = bn + + mlp_modules = [] + for idx, (input_size, output_size) in enumerate( + zip(self.layers[:-1], self.layers[1:]) + ): + mlp_modules.append(nn.Dropout(p=self.dropout)) + mlp_modules.append(nn.Linear(input_size, output_size)) + + if self.use_bn and idx != (len(self.layers)-2): + mlp_modules.append(nn.BatchNorm1d(num_features=output_size)) + + activation_func = activation_layer(self.activation, output_size) + if activation_func is not None and idx != (len(self.layers)-2): + mlp_modules.append(activation_func) + + self.mlp_layers = nn.Sequential(*mlp_modules) + self.apply(self.init_weights) + + def init_weights(self, module): + # We just initialize the module with normal distribution as the paper said + if isinstance(module, nn.Linear): + xavier_normal_(module.weight.data) + if module.bias is not None: + module.bias.data.fill_(0.0) + + def forward(self, input_feature): + return self.mlp_layers(input_feature) + +def activation_layer(activation_name="relu", emb_dim=None): + + if activation_name is None: + activation = None + elif isinstance(activation_name, str): + if activation_name.lower() == "sigmoid": + activation = nn.Sigmoid() + elif activation_name.lower() == "tanh": + activation = nn.Tanh() + elif activation_name.lower() == "relu": + activation = nn.ReLU() + elif activation_name.lower() == "leakyrelu": + activation = nn.LeakyReLU() + elif activation_name.lower() == "none": + activation = None + elif issubclass(activation_name, nn.Module): + activation = activation_name() + else: + raise NotImplementedError( + "activation function {} is not implemented".format(activation_name) + ) + + return activation + +def kmeans( + samples, + num_clusters, + num_iters = 10, +): + B, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device + x = samples.cpu().detach().numpy() + + cluster = KMeans(n_clusters = num_clusters, max_iter = num_iters).fit(x) + + centers = cluster.cluster_centers_ + tensor_centers = torch.from_numpy(centers).to(device) + + return tensor_centers + + +@torch.no_grad() +def sinkhorn_algorithm(distances, epsilon, sinkhorn_iterations): + Q = torch.exp(- distances / epsilon) + + B = Q.shape[0] # number of samples to assign + K = Q.shape[1] # how many centroids per block (usually set to 256) + + # make the matrix sums to 1 + sum_Q = Q.sum(-1, keepdim=True).sum(-2, keepdim=True) + Q /= sum_Q + # print(Q.sum()) + for it in range(sinkhorn_iterations): + + # normalize each column: total weight per sample 
must be 1/B + Q /= torch.sum(Q, dim=1, keepdim=True) + Q /= B + + # normalize each row: total weight per prototype must be 1/K + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= K + + + Q *= B # the colomns must sum to 1 so that Q is an assignment + return Q \ No newline at end of file diff --git a/RecLM-uni/index/models/rq.py b/RecLM-uni/index/models/rq.py new file mode 100644 index 0000000..85f678b --- /dev/null +++ b/RecLM-uni/index/models/rq.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn + +from .vq import VectorQuantizer + + +class ResidualVectorQuantizer(nn.Module): + """ References: + SoundStream: An End-to-End Neural Audio Codec + https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, n_e_list, e_dim, sk_epsilons, beta = 0.25, + kmeans_init = False, kmeans_iters = 100, sk_iters=100,): + super().__init__() + self.n_e_list = n_e_list + self.e_dim = e_dim + self.num_quantizers = len(n_e_list) + self.beta = beta + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.sk_epsilons = sk_epsilons + self.sk_iters = sk_iters + self.vq_layers = nn.ModuleList([VectorQuantizer(n_e, e_dim, + beta=self.beta, + kmeans_init = self.kmeans_init, + kmeans_iters = self.kmeans_iters, + sk_epsilon=sk_epsilon, + sk_iters=sk_iters) + for n_e, sk_epsilon in zip(n_e_list,sk_epsilons) ]) + + def get_codebook(self): + all_codebook = [] + for quantizer in self.vq_layers: + codebook = quantizer.get_codebook() + all_codebook.append(codebook) + return torch.stack(all_codebook) + + def forward(self, x, use_sk=True): + all_losses = [] + all_indices = [] + + x_q = 0 + residual = x + for quantizer in self.vq_layers: + x_res, loss, indices = quantizer(residual, use_sk=use_sk) + residual = residual - x_res + x_q = x_q + x_res + + all_losses.append(loss) + all_indices.append(indices) + + mean_losses = torch.stack(all_losses).mean() + all_indices = torch.stack(all_indices, dim=-1) + + return x_q, mean_losses, all_indices \ No newline at end of file diff --git a/RecLM-uni/index/models/rqvae.py b/RecLM-uni/index/models/rqvae.py new file mode 100644 index 0000000..0cc52c2 --- /dev/null +++ b/RecLM-uni/index/models/rqvae.py @@ -0,0 +1,104 @@ +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from .layers import MLPLayers +from .rq import ResidualVectorQuantizer + + +class RQVAE(nn.Module): + def __init__(self, + in_dim=768, + # num_emb_list=[256,256,256,256], + num_emb_list=None, + e_dim=64, + # layers=[512,256,128], + layers=None, + dropout_prob=0.0, + bn=False, + loss_type="mse", + quant_loss_weight=1.0, + contrastive_loss_weight=0.0, + temperature=0.1, + beta=0.25, + kmeans_init=False, + kmeans_iters=100, + # sk_epsilons=[0,0,0.003,0.01]], + sk_epsilons=None, + sk_iters=100, + ): + super(RQVAE, self).__init__() + + self.in_dim = in_dim + self.num_emb_list = num_emb_list + self.e_dim = e_dim + + self.layers = layers + self.dropout_prob = dropout_prob + self.bn = bn + self.loss_type = loss_type + self.quant_loss_weight=quant_loss_weight + self.contrastive_loss_weight = contrastive_loss_weight + self.temperature = temperature + self.beta = beta + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.sk_epsilons = sk_epsilons + self.sk_iters = sk_iters + + self.encode_layer_dims = [self.in_dim] + self.layers + [self.e_dim] + self.encoder = MLPLayers(layers=self.encode_layer_dims, + dropout=self.dropout_prob,bn=self.bn) + + self.rq = ResidualVectorQuantizer(num_emb_list, e_dim, + beta=self.beta, + kmeans_init = 
self.kmeans_init, + kmeans_iters = self.kmeans_iters, + sk_epsilons=self.sk_epsilons, + sk_iters=self.sk_iters,) + + self.decode_layer_dims = self.encode_layer_dims[::-1] + self.decoder = MLPLayers(layers=self.decode_layer_dims, + dropout=self.dropout_prob,bn=self.bn) + + def forward(self, x, use_sk=True): + x = self.encoder(x) + x_q, rq_loss, indices = self.rq(x,use_sk=use_sk) + out = self.decoder(x_q) + + return out, rq_loss, indices + + @torch.no_grad() + def get_indices(self, xs, use_sk=False): + x_e = self.encoder(xs) + _, _, indices = self.rq(x_e, use_sk=use_sk) + return indices + + def compute_loss(self, out, quant_loss, xs=None): + + if self.loss_type == 'mse': + loss_recon = F.mse_loss(out, xs, reduction='mean') + elif self.loss_type == 'l1': + loss_recon = F.l1_loss(out, xs, reduction='mean') + else: + raise ValueError('incompatible loss type') + + loss_total = loss_recon + self.quant_loss_weight * quant_loss + + if self.contrastive_loss_weight > 0: + # InfoNCE Loss + # Normalize vectors + out_norm = F.normalize(out, p=2, dim=1) + xs_norm = F.normalize(xs, p=2, dim=1) + + # Compute similarity matrix (Batch_Size, Batch_Size) + logits = torch.matmul(out_norm, xs_norm.t()) / self.temperature + + # Targets are the diagonal elements (0, 1, 2, ...) + labels = torch.arange(logits.size(0), device=logits.device) + + loss_contrastive = F.cross_entropy(logits, labels) + loss_total += self.contrastive_loss_weight * loss_contrastive + + return loss_total, loss_recon \ No newline at end of file diff --git a/RecLM-uni/index/models/vq.py b/RecLM-uni/index/models/vq.py new file mode 100644 index 0000000..30474fb --- /dev/null +++ b/RecLM-uni/index/models/vq.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .layers import kmeans, sinkhorn_algorithm + + +class VectorQuantizer(nn.Module): + + def __init__(self, n_e, e_dim, + beta = 0.25, kmeans_init = False, kmeans_iters = 10, + sk_epsilon=0.003, sk_iters=100,): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.sk_epsilon = sk_epsilon + self.sk_iters = sk_iters + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + if not kmeans_init: + self.initted = True + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + else: + self.initted = False + self.embedding.weight.data.zero_() + + def get_codebook(self): + return self.embedding.weight + + def get_codebook_entry(self, indices, shape=None): + # get quantized latent vectors + z_q = self.embedding(indices) + if shape is not None: + z_q = z_q.view(shape) + + return z_q + + def init_emb(self, data): + + centers = kmeans( + data, + self.n_e, + self.kmeans_iters, + ) + + self.embedding.weight.data.copy_(centers) + self.initted = True + + @staticmethod + def center_distance_for_constraint(distances): + # distances: B, K + max_distance = distances.max() + min_distance = distances.min() + + middle = (max_distance + min_distance) / 2 + amplitude = max_distance - middle + 1e-5 + assert amplitude > 0 + centered_distances = (distances - middle) / amplitude + return centered_distances + + def forward(self, x, use_sk=True): + # Flatten input + latent = x.view(-1, self.e_dim) + + if not self.initted and self.training: + self.init_emb(latent) + + # Calculate the L2 Norm between latent and Embedded weights + d = torch.sum(latent**2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1, keepdim=True).t()- \ + 2 * torch.matmul(latent, 
self.embedding.weight.t()) + if not use_sk or self.sk_epsilon <= 0: + indices = torch.argmin(d, dim=-1) + else: + d = self.center_distance_for_constraint(d) + d = d.double() + Q = sinkhorn_algorithm(d, self.sk_epsilon, self.sk_iters) + + if torch.isnan(Q).any() or torch.isinf(Q).any(): + print(f"Sinkhorn Algorithm returns nan/inf values.") + indices = torch.argmax(Q, dim=-1) + + # indices = torch.argmin(d, dim=-1) + + x_q = self.embedding(indices).view(x.shape) + + # compute loss for embedding + commitment_loss = F.mse_loss(x_q.detach(), x) + codebook_loss = F.mse_loss(x_q, x.detach()) + loss = codebook_loss + self.beta * commitment_loss + + # preserve gradients + x_q = x + (x_q - x).detach() + + indices = indices.view(x.shape[:-1]) + + return x_q, loss, indices + + diff --git a/RecLM-uni/index/trainer.py b/RecLM-uni/index/trainer.py new file mode 100644 index 0000000..9c620ab --- /dev/null +++ b/RecLM-uni/index/trainer.py @@ -0,0 +1,255 @@ +import logging + +import numpy as np +import torch +from time import time +from torch import optim +from tqdm import tqdm +from transformers import get_linear_schedule_with_warmup, get_constant_schedule_with_warmup + +from utils import ensure_dir,set_color,get_local_time,delete_file +import os + +import heapq +class Trainer(object): + + def __init__(self, args, model, data_num): + self.args = args + self.model = model + self.logger = logging.getLogger() + + self.lr = args.lr + self.learner = args.learner + self.lr_scheduler_type = args.lr_scheduler_type + + self.weight_decay = args.weight_decay + self.epochs = args.epochs + self.warmup_steps = args.warmup_epochs * data_num + self.max_steps = args.epochs * data_num + + self.save_limit = args.save_limit + self.best_save_heap = [] + self.newest_save_queue = [] + self.eval_step = min(args.eval_step, self.epochs) + self.device = args.device + self.device = torch.device(self.device) + self.ckpt_dir = args.ckpt_dir + saved_model_dir = "{}".format(get_local_time()) + self.ckpt_dir = os.path.join(self.ckpt_dir,saved_model_dir) + ensure_dir(self.ckpt_dir) + + self.best_loss = np.inf + self.best_collision_rate = np.inf + self.best_loss_ckpt = "best_loss_model.pth" + self.best_collision_ckpt = "best_collision_model.pth" + self.optimizer = self._build_optimizer() + self.scheduler = self._get_scheduler() + self.model = self.model.to(self.device) + + def _build_optimizer(self): + + params = self.model.parameters() + learner = self.learner + learning_rate = self.lr + weight_decay = self.weight_decay + + if learner.lower() == "adam": + optimizer = optim.Adam(params, lr=learning_rate, weight_decay=weight_decay) + elif learner.lower() == "sgd": + optimizer = optim.SGD(params, lr=learning_rate, weight_decay=weight_decay) + elif learner.lower() == "adagrad": + optimizer = optim.Adagrad( + params, lr=learning_rate, weight_decay=weight_decay + ) + for state in optimizer.state.values(): + for k, v in state.items(): + if torch.is_tensor(v): + state[k] = v.to(self.device) + elif learner.lower() == "rmsprop": + optimizer = optim.RMSprop( + params, lr=learning_rate, weight_decay=weight_decay + ) + elif learner.lower() == 'adamw': + optimizer = optim.AdamW( + params, lr=learning_rate, weight_decay=weight_decay + ) + else: + self.logger.warning( + "Received unrecognized optimizer, set default Adam optimizer" + ) + optimizer = optim.Adam(params, lr=learning_rate) + return optimizer + + def _get_scheduler(self): + if self.lr_scheduler_type.lower() == "linear": + lr_scheduler = get_linear_schedule_with_warmup(optimizer=self.optimizer, 
+ num_warmup_steps=self.warmup_steps, + num_training_steps=self.max_steps) + else: + lr_scheduler = get_constant_schedule_with_warmup(optimizer=self.optimizer, + num_warmup_steps=self.warmup_steps) + + return lr_scheduler + def _check_nan(self, loss): + if torch.isnan(loss): + raise ValueError("Training loss is nan") + + + def _train_epoch(self, train_data, epoch_idx): + + self.model.train() + + total_loss = 0 + total_recon_loss = 0 + iter_data = tqdm( + train_data, + total=len(train_data), + ncols=100, + desc=set_color(f"Train {epoch_idx}","pink"), + ) + + for batch_idx, data in enumerate(iter_data): + data = data.to(self.device) + self.optimizer.zero_grad() + out, rq_loss, indices = self.model(data) + loss, loss_recon = self.model.compute_loss(out, rq_loss, xs=data) + self._check_nan(loss) + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + self.optimizer.step() + self.scheduler.step() + # print(self.scheduler.get_last_lr()) + total_loss += loss.item() + total_recon_loss += loss_recon.item() + + return total_loss, total_recon_loss + + @torch.no_grad() + def _valid_epoch(self, valid_data): + + self.model.eval() + + iter_data =tqdm( + valid_data, + total=len(valid_data), + ncols=100, + desc=set_color(f"Evaluate ", "pink"), + ) + + indices_set = set() + num_sample = 0 + for batch_idx, data in enumerate(iter_data): + num_sample += len(data) + data = data.to(self.device) + indices = self.model.get_indices(data) + indices = indices.view(-1,indices.shape[-1]).cpu().numpy() + for index in indices: + code = "-".join([str(int(_)) for _ in index]) + indices_set.add(code) + + collision_rate = (num_sample - len(list(indices_set)))/num_sample + + return collision_rate + + def _save_checkpoint(self, epoch, collision_rate=1, ckpt_file=None): + + ckpt_path = os.path.join(self.ckpt_dir,ckpt_file) if ckpt_file \ + else os.path.join(self.ckpt_dir, 'epoch_%d_collision_%.4f_model.pth' % (epoch, collision_rate)) + state = { + "args": self.args, + "epoch": epoch, + "best_loss": self.best_loss, + "best_collision_rate": self.best_collision_rate, + "state_dict": self.model.state_dict(), + "optimizer": self.optimizer.state_dict(), + } + torch.save(state, ckpt_path, pickle_protocol=4) + + self.logger.info( + set_color("Saving current", "blue") + f": {ckpt_path}" + ) + + return ckpt_path + + def _generate_train_loss_output(self, epoch_idx, s_time, e_time, loss, recon_loss): + train_loss_output = ( + set_color("epoch %d training", "green") + + " [" + + set_color("time", "blue") + + ": %.2fs, " + ) % (epoch_idx, e_time - s_time) + train_loss_output += set_color("train loss", "blue") + ": %.4f" % loss + train_loss_output +=", " + train_loss_output += set_color("reconstruction loss", "blue") + ": %.4f" % recon_loss + return train_loss_output + "]" + + + def fit(self, data): + + cur_eval_step = 0 + + for epoch_idx in range(self.epochs): + # train + training_start_time = time() + train_loss, train_recon_loss = self._train_epoch(data, epoch_idx) + training_end_time = time() + train_loss_output = self._generate_train_loss_output( + epoch_idx, training_start_time, training_end_time, train_loss, train_recon_loss + ) + self.logger.info(train_loss_output) + + + # eval + if (epoch_idx + 1) % self.eval_step == 0: + valid_start_time = time() + collision_rate = self._valid_epoch(data) + + if train_loss < self.best_loss: + self.best_loss = train_loss + self._save_checkpoint(epoch=epoch_idx, ckpt_file=self.best_loss_ckpt) + + if collision_rate < self.best_collision_rate: + self.best_collision_rate = 
collision_rate + cur_eval_step = 0 + self._save_checkpoint(epoch_idx, collision_rate=collision_rate, + ckpt_file=self.best_collision_ckpt) + else: + cur_eval_step += 1 + + + valid_end_time = time() + valid_score_output = ( + set_color("epoch %d evaluating", "green") + + " [" + + set_color("time", "blue") + + ": %.2fs, " + + set_color("collision_rate", "blue") + + ": %f]" + ) % (epoch_idx, valid_end_time - valid_start_time, collision_rate) + + self.logger.info(valid_score_output) + ckpt_path = self._save_checkpoint(epoch_idx, collision_rate=collision_rate) + now_save = (-collision_rate, ckpt_path) + if len(self.newest_save_queue) < self.save_limit: + self.newest_save_queue.append(now_save) + heapq.heappush(self.best_save_heap, now_save) + else: + old_save = self.newest_save_queue.pop(0) + self.newest_save_queue.append(now_save) + if collision_rate < -self.best_save_heap[0][0]: + bad_save = heapq.heappop(self.best_save_heap) + heapq.heappush(self.best_save_heap, now_save) + + if bad_save not in self.newest_save_queue: + delete_file(bad_save[1]) + + if old_save not in self.best_save_heap: + delete_file(old_save[1]) + + + + return self.best_loss, self.best_collision_rate + + + + diff --git a/RecLM-uni/index/utils.py b/RecLM-uni/index/utils.py new file mode 100644 index 0000000..abd33c4 --- /dev/null +++ b/RecLM-uni/index/utils.py @@ -0,0 +1,37 @@ + +import datetime +import os + + +def ensure_dir(dir_path): + + os.makedirs(dir_path, exist_ok=True) + +def set_color(log, color, highlight=True): + color_set = ["black", "red", "green", "yellow", "blue", "pink", "cyan", "white"] + try: + index = color_set.index(color) + except: + index = len(color_set) - 1 + prev_log = "\033[" + if highlight: + prev_log += "1;3" + else: + prev_log += "0;3" + prev_log += str(index) + "m" + return prev_log + log + "\033[0m" + +def get_local_time(): + r"""Get current time + + Returns: + str: current time + """ + cur = datetime.datetime.now() + cur = cur.strftime("%b-%d-%Y_%H-%M-%S") + + return cur + +def delete_file(filename): + if os.path.exists(filename): + os.remove(filename) diff --git a/RecLM-cgen/main.py b/RecLM-uni/main.py similarity index 100% rename from RecLM-cgen/main.py rename to RecLM-uni/main.py diff --git a/RecLM-uni/plugin.env.example b/RecLM-uni/plugin.env.example new file mode 100644 index 0000000..a337136 --- /dev/null +++ b/RecLM-uni/plugin.env.example @@ -0,0 +1,30 @@ +# Environment Variables for plugin.py +# Copy this file to plugin.env and fill in your actual values + +# Model Configuration +MODEL_PATH=meta-llama/Meta-Llama-3-8B-Instruct + +# Embedding API Service (BGE-M3 or similar) +EMBEDDING_API_URL=http://localhost:8010/v1 +EMBEDDING_API_KEY=not-needed + +# vLLM Endpoints for LLM Inference +# For multiple endpoints, separate with commas +VLLM_ENDPOINTS=http://localhost:8020/v1,http://localhost:8021/v1,http://localhost:8022/v1 + +# Or configure per data source: +VLLM_ENDPOINTS_STEAM=http://localhost:8020/v1 +VLLM_ENDPOINTS_MOVIES=http://localhost:8020/v1 +VLLM_ENDPOINTS_TOYS=http://localhost:8020/v1 + +# Embedding Service for User/Item Similarity Search +EMBEDDING_SERVICE_BASE_URL=http://localhost:5000 +EMBEDDING_PORT_STEAM=5003 +EMBEDDING_PORT_MOVIES=5004 +EMBEDDING_PORT_TOYS=5005 + +# Data Directory for Embedding Files +EMBEDDING_DATA_DIR=./data/embeddings + +# Performance Settings +MAX_WORKERS=8 diff --git a/RecLM-uni/plugin.py b/RecLM-uni/plugin.py new file mode 100644 index 0000000..54de013 --- /dev/null +++ b/RecLM-uni/plugin.py @@ -0,0 +1,409 @@ +from math import exp +import numpy as np 
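+# Reward plugin for the swift `orms` registry: defines five outcome reward functions
+# (conditional perplexity, length, a discriminative multiple-choice check, and Faiss-based
+# item2item / user2item similarity). Faiss indexes and item embeddings are cached in a
+# global ResourceManager so they are only built once per data source.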
+import random
+from copy import deepcopy
+from typing import List, Dict, Any
+from transformers import AutoTokenizer
+from swift.plugin import ORM, orms
+from swift.utils import get_logger
+import openai
+import time
+import json
+import math
+import requests
+from scipy.stats import spearmanr
+from sklearn.metrics.pairwise import cosine_similarity
+import faiss
+from functools import lru_cache
+from concurrent.futures import ThreadPoolExecutor
+from collections import defaultdict
+
+logger = get_logger()
+
+# --- Global configuration and clients ---
+# These configuration items are intended to be defined once at the top level
+import os
+
+MODEL_PATH = os.getenv("MODEL_PATH", "meta-llama/Meta-Llama-3-8B-Instruct")
+TOKENIZER = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=True)
+
+# Embedding service configuration
+EMBEDDING_API_URL = os.getenv("EMBEDDING_API_URL", "http://localhost:8010/v1")
+EMBEDDING_API_KEY = os.getenv("EMBEDDING_API_KEY", "not-needed")
+EMBEDDING_CLIENT = openai.OpenAI(base_url=EMBEDDING_API_URL, api_key=EMBEDDING_API_KEY)
+
+MAX_WORKERS = int(os.getenv("MAX_WORKERS", "8"))  # Adjust the concurrency to your system and network conditions
+
+_embedding_file_cache = {}
+
+def read_embedding_file_cached(embedding_path):
+    if embedding_path in _embedding_file_cache:
+        return _embedding_file_cache[embedding_path]
+
+    item2embedding = {}
+    with open(embedding_path, 'r', encoding='utf-8') as f:
+        for line in f:
+            data = json.loads(line)
+            item2embedding[data['title']] = data['info']['embedding']
+    _embedding_file_cache[embedding_path] = item2embedding
+    return item2embedding
+
+# --- Core of the optimization: a centralized manager for Faiss indexes and embeddings ---
+class ResourceManager:
+    """
+    Loads, builds, and caches heavyweight resources (such as Faiss indexes and item embeddings)
+    to avoid repeated I/O and computation within a batch.
+    """
+    def __init__(self):
+        self._faiss_indexes: Dict[str, Any] = {}
+        self._item_names: Dict[str, List[str]] = {}
+        self._item_embeddings: Dict[str, Dict[str, List[float]]] = {}
+
+    def _load_and_build(self, source: str, embedding_file: str):
+        # Return immediately if an index for this data source already exists
+        if source in self._faiss_indexes:
+            return
+
+        logger.info(f"Building Faiss index for data source: {source}...")
+        item2embedding = read_embedding_file_cached(embedding_file)
+        self._item_embeddings[source] = item2embedding
+
+        item_names = list(item2embedding.keys())
+        # Use float32 dtype for compatibility with Faiss
+        embeddings_matrix = np.vstack([np.array(v, dtype=np.float32) for v in item2embedding.values()])
+        faiss.normalize_L2(embeddings_matrix)
+
+        index = faiss.IndexFlatIP(embeddings_matrix.shape[1])
+        index.add(embeddings_matrix)
+
+        self._item_names[source] = item_names
+        self._faiss_indexes[source] = index
+        logger.info(f"Finished building Faiss index for data source {source}.")
+
+    def get_faiss_resources(self, source: str, embedding_file: str):
+        self._load_and_build(source, embedding_file)
+        return self._faiss_indexes[source], self._item_names[source]
+
+    def get_item_embeddings(self, source: str, embedding_file: str):
+        self._load_and_build(source, embedding_file)
+        return self._item_embeddings[source]
+
+# Global instance of the resource manager
+resource_manager = ResourceManager()
+
+
+# --- Network request helpers ---
+def _get_vllm_endpoints(source: str) -> List[str]:
+    # Centralize endpoint selection logic so it is easier to maintain
+    # Configure vLLM endpoints via environment variables or use defaults
+    default_endpoints = os.getenv("VLLM_ENDPOINTS", "").split(",") if os.getenv("VLLM_ENDPOINTS") else []
+
+    if default_endpoints:
+        return default_endpoints
+
+    # Default endpoints for different data sources (replace with your own)
+    endpoints = {
+        "steam": os.getenv("VLLM_ENDPOINTS_STEAM", "http://localhost:8020/v1").split(","),
+        "movies": os.getenv("VLLM_ENDPOINTS_MOVIES", "http://localhost:8020/v1").split(","),
+        "toys": os.getenv("VLLM_ENDPOINTS_TOYS", "http://localhost:8020/v1").split(","),
+    }
+    return endpoints.get(source, endpoints["steam"])  # Fall back to the steam endpoints for unknown sources
+
+def try_vllm_chat(prompt: str, source: str, max_tokens: int = 1, temperature: float = 0.0) -> str:
+    for url in _get_vllm_endpoints(source):
+        try:
+            client = openai.OpenAI(base_url=url, api_key="not-needed")
+            response = client.chat.completions.create(
+                model="llama3-8b",
+                messages=[{"role": "user", "content": prompt}],
+                max_tokens=max_tokens,
+                temperature=temperature,
+                top_p=1.0,
+            )
+            return response.choices[0].message.content.strip()
+        except Exception as e:
+            logger.warning(f"[vLLM Chat] Request failed at {url}: {e}")
+    logger.error("All vLLM chat endpoints failed.")
+    time.sleep(50)
+    return None
+
+def try_vllm_completion(prompt: str, source: str) -> object:
+    for url in _get_vllm_endpoints(source):
+        try:
+            client = openai.OpenAI(base_url=url, api_key="not-needed")
+            response = client.completions.create(
+                model="llama3-8b",
+                prompt=prompt,
+                max_tokens=0,
+                logprobs=True,
+                echo=True,
+            )
+            return response
+        except Exception as e:
+            logger.warning(f"[vLLM Completion] Request failed at {url}: {e}")
+    logger.error("All vLLM completion endpoints failed.")
+    return None
+
+@lru_cache(maxsize=10000)
+def cached_generate_embedding(text: str) -> tuple:
+    response = EMBEDDING_CLIENT.embeddings.create(input=text, model='bge-m3')
+    return tuple(response.data[0].embedding)
+
+def _get_request_embedding_url(source: str) -> str:
+    # Configure embedding service URLs via environment variables
+    base_url = os.getenv("EMBEDDING_SERVICE_BASE_URL", "http://localhost:5000")
+    port_map = {
+        "steam": os.getenv("EMBEDDING_PORT_STEAM", "5003"),
+        "movies": os.getenv("EMBEDDING_PORT_MOVIES", "5004"),
+        "toys": os.getenv("EMBEDDING_PORT_TOYS", "5005")
+    }
+    port = port_map.get(source, "5000")
+    return f"{base_url.replace(':5000', '')}:{port}/embedding"
+
+@lru_cache(maxsize=10000)
+def cached_request_embedding(text: str, source: str) -> tuple:
+    url = _get_request_embedding_url(source)
+    if not url:
+        logger.error(f"Unknown embedding request source: {source}")
+        # Return a zero vector of the expected dimension on error
+        return tuple([0.0] * 1024)
+    response = requests.post(url, json={"text": text})
+    response.raise_for_status()  # Raise an exception on failed requests (e.g. 4xx or 5xx)
+    return tuple(response.json()["embedding"])
+
+
+# --- Helper functions ---
+def spearman_rank_correlation(top10_positions, top10_items):
+    original_ranks = list(range(1, len(top10_items) + 1))
+    correlation, _ = spearmanr(original_ranks, top10_positions)
+    # Map the Spearman correlation from the [-1, 1] range to the [0, 1] range
+    return (correlation + 1) / 2
+
+
+# ========= 1. ConditionalPPL (parallelized) =========
+class ConditionalPPL(ORM):
+    def _process_item(self, args):
+        content, sol = args
+        try:
+            prompt = sol["recommend_prompt"]
+            full_text = prompt + content
+
+            input_tokens = TOKENIZER(prompt, add_special_tokens=False)
+            output_tokens = TOKENIZER(content, add_special_tokens=False)
+            split_index = len(input_tokens["input_ids"])
+            output_len = len(output_tokens["input_ids"])
+
+            response = try_vllm_completion(full_text, sol["source"])
+            if response is None: return 0.0
+
+            logprobs = response.choices[0].logprobs.token_logprobs[split_index : split_index + output_len]
+
+            if not logprobs or len(logprobs) != output_len:
+                logger.warning("[ConditionalPPL] Logprob length mismatch.")
+                return 0.0
+
+            cross_entropy = -np.mean(logprobs)
+            ppl = np.exp(cross_entropy)
+            score = np.exp(-0.02 * ppl)
+            return score
+        except Exception as e:
+            logger.warning(f"[ConditionalPPL] Processing failed: {e}")
+            return 0.0
+
+    def __call__(self, completions, task, solution, **kwargs) -> List[float]:
+        # Select the samples that need to be processed
+        single_tasks_args = [(c, s) for c, t, s in zip(completions, task, solution) if t == "single"]
+
+        # Run the network requests in parallel with a thread pool
+        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
+            results_iter = executor.map(self._process_item, single_tasks_args)
+
+        results = list(results_iter)
+
+        # Reassemble the results in the original order
+        final_rewards = []
+        result_idx = 0
+        for t in task:
+            if t == "single":
+                final_rewards.append(results[result_idx])
+                result_idx += 1
+            else:
+                final_rewards.append(None)
+        return final_rewards
+
+# ========= 2. LengthReward (fast on its own; only minor code cleanup) =========
+class LengthReward(ORM):
+    def __call__(self, completions, task, solution, **kwargs) -> List[float]:
+        rewards = []
+        for content, sol, t in zip(completions, solution, task):
+            if t == "single":
+                try:
+                    input_len = len(TOKENIZER(sol["recommend_item"], add_special_tokens=False)["input_ids"])
+                    output_len = len(TOKENIZER(content, add_special_tokens=False)["input_ids"])
+                    if input_len == 0:
+                        rewards.append(0.0)
+                        continue
+                    ratio = output_len / input_len
+                    rewards.append(1.0 / (1.0 + ratio ** 2))
+                except Exception as e:
+                    logger.warning(f"[LengthReward] Computation failed: {e}")
+                    rewards.append(0.0)
+            else:
+                rewards.append(None)
+        return rewards
+
+# ========= 3. DiscriminativeReward (parallelized) =========
+class DiscriminativeReward(ORM):
+    def _process_item(self, args):
+        content, sol = args
+        try:
+            target_item = sol["recommend_item"]
+            all_items = list(set(sol["top3"] + [target_item]))
+            random.shuffle(all_items)
+            options = all_items[:4]
+
+            option_labels = ["A", "B", "C", "D"]
+            labeled_options = {label: item for label, item in zip(option_labels, options)}
+
+            prompt_parts = [
+                "I have a rewritten result that is derived from one of the following four options. "
+                "Please tell me which option it corresponds to by answering with A, B, C, or D only."
+            ]
+            for label, item in labeled_options.items():
+                prompt_parts.append(f"{label}. {item}")
+            prompt_parts.append(f"\nThe rewritten result is:\n{content}\nPlease respond with the correct option letter.")
+            prompt = "\n".join(prompt_parts)
+
+            answer = try_vllm_chat(prompt, sol["source"], max_tokens=1)
+            if answer and labeled_options.get(answer.strip().upper()) == target_item:
+                return 1.0
+            return 0.0
+        except Exception as e:
+            logger.warning(f"[DiscriminativeReward] Processing failed: {e}")
+            return 0.0
+
+    def __call__(self, completions, task, solution, **kwargs) -> List[float]:
+        single_tasks_args = [(c, s) for c, t, s in zip(completions, task, solution) if t == "single"]
+
+        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
+            results_iter = executor.map(self._process_item, single_tasks_args)
+
+        results = list(results_iter)
+
+        final_rewards = []
+        result_idx = 0
+        for t in task:
+            if t == "single":
+                final_rewards.append(results[result_idx])
+                result_idx += 1
+            else:
+                final_rewards.append(None)
+        return final_rewards
+
+# ========= 4. item2item (optimized with batched Faiss search) =========
+class Item2ItemReward(ORM):
+    def __call__(self, completions, task, solution, **kwargs) -> List[float]:
+        # Group tasks by data source so each Faiss index is reused
+        grouped_tasks = defaultdict(list)
+        for i, (c, t, s) in enumerate(zip(completions, task, solution)):
+            if t == "single":
+                grouped_tasks[s['source']].append({'original_idx': i, 'completion': c, 'solution': s})
+
+        results = {}
+        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
+            for source, tasks in grouped_tasks.items():
+                # Path to item embedding files (configure via environment variable)
+                embedding_base_dir = os.getenv("EMBEDDING_DATA_DIR", "./data/embeddings")
+                embedding_file = f"{embedding_base_dir}/{source}_all_item_embedding.jsonl"
+                faiss_index, all_item_names = resource_manager.get_faiss_resources(source, embedding_file)
+
+                # Step 1: generate embeddings for all queries in parallel
+                def get_embedding(task_item):
+                    new_item_info = (f"The item's title : {task_item['completion']}\n"
+                                     f"The item's description: {task_item['solution']['title_desc']}\n"
+                                     f"The item's category: {task_item['solution']['title_category']}")
+                    return cached_generate_embedding(new_item_info)
+
+                embedding_futures = [executor.submit(get_embedding, task_item) for task_item in tasks]
+                content_embeddings = np.array([future.result() for future in embedding_futures], dtype=np.float32)
+                faiss.normalize_L2(content_embeddings)
+
+                # Step 2: run a single batched search
+                _, all_indices = faiss_index.search(content_embeddings, len(all_item_names))
+
+                # Step 3: compute a score for every item in the batch
+                for i, task_item in enumerate(tasks):
+                    # Ranked results for the current query
+                    sorted_items = [all_item_names[idx] for idx in all_indices[i]]
+
+                    # Remove the query item itself from the ranking for a fairer comparison
+                    original_title = task_item['solution']['recommend_item']
+                    if original_title in sorted_items:
+                        sorted_items.remove(original_title)
+
+                    top10_items = task_item['solution']['similarity_top10']
+                    try:
+                        top10_positions = [sorted_items.index(item) + 1 for item in top10_items]
+                        score = spearman_rank_correlation(top10_positions, top10_items)
+                    except ValueError:  # A top-10 item was not found in the new ranking
+                        score = 0.0
+                    results[task_item['original_idx']] = score
+
+        # Reassemble the final reward list in the original order
+        final_rewards = [results.get(i) for i in range(len(completions))]
+        return final_rewards
+
+# ========= 5. User2Item (optimized with batched Faiss search) =========
+class User2ItemReward(ORM):
+    def __call__(self, completions, task, solution, **kwargs) -> List[float]:
+        # Group by data source
+        grouped_tasks = defaultdict(list)
+        for i, (c, t, s) in enumerate(zip(completions, task, solution)):
+            if t == "group":
+                grouped_tasks[s['source']].append({'original_idx': i, 'completion': c, 'solution': s})
+
+        results = {}
+        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
+            for source, tasks in grouped_tasks.items():
+                # Path to user embedding files (configure via environment variable)
+                embedding_base_dir = os.getenv("EMBEDDING_DATA_DIR", "./data/embeddings")
+                embedding_file = f"{embedding_base_dir}/{source}_all_item_embedding_t-desc.jsonl"
+                faiss_index, item_names = resource_manager.get_faiss_resources(source, embedding_file)
+
+                # Step 1: generate all embeddings in parallel
+                def get_embedding(task_item):
+                    prompt = (f"You need to generate a recommendation list considering user's preference from "
+                              f"historical interactions. The historical interactions are provided as follows: "
+                              f"{task_item['completion']}. Please generate a recommendation list with 1 different "
+                              f"items. Each item should be enclosed by and . should be generated "
+                              f"before item title, and should be generated after item title.")
+                    return cached_request_embedding(prompt, source)
+
+                embedding_futures = [executor.submit(get_embedding, task_item) for task_item in tasks]
+                prompt_embeddings = np.array([future.result() for future in embedding_futures], dtype=np.float32)
+                faiss.normalize_L2(prompt_embeddings)
+
+                # Step 2: run a single batched search
+                _, all_indices = faiss_index.search(prompt_embeddings, len(item_names))
+
+                # Step 3: compute scores
+                for i, task_item in enumerate(tasks):
+                    sorted_items = [item_names[idx] for idx in all_indices[i]]
+                    item_title = task_item['solution']['target_item']
+                    try:
+                        item_rank = sorted_items.index(item_title) + 1
+                        score = math.exp(-(item_rank - 1) / 2000)
+                    except ValueError:  # Target item not found in the ranking
+                        score = 0.0
+                    results[task_item['original_idx']] = score
+
+        # Reassemble the results in the original order
+        final_rewards = [results.get(i) for i in range(len(completions))]
+        return final_rewards
+
+# ========= Register all reward functions =========
+orms['average_ppl'] = ConditionalPPL
+orms['length'] = LengthReward
+orms['discriminative'] = DiscriminativeReward
+orms['item2item'] = Item2ItemReward
+orms['user2item'] = User2ItemReward
\ No newline at end of file
diff --git a/RecLM-cgen/preprocess/data_preprocess_amazon.py b/RecLM-uni/preprocess/data_preprocess_amazon.py
similarity index 100%
rename from RecLM-cgen/preprocess/data_preprocess_amazon.py
rename to RecLM-uni/preprocess/data_preprocess_amazon.py
diff --git a/RecLM-cgen/preprocess/transform2unirec.py b/RecLM-uni/preprocess/transform2unirec.py
similarity index 100%
rename from RecLM-cgen/preprocess/transform2unirec.py
rename to RecLM-uni/preprocess/transform2unirec.py
diff --git a/RecLM-cgen/requirements.txt b/RecLM-uni/requirements.txt
similarity index 100%
rename from RecLM-cgen/requirements.txt
rename to RecLM-uni/requirements.txt
diff --git a/RecLM-cgen/scripts/data_preprocess_amazon.sh b/RecLM-uni/scripts/data_preprocess_amazon.sh
similarity index 100%
rename from RecLM-cgen/scripts/data_preprocess_amazon.sh
rename to RecLM-uni/scripts/data_preprocess_amazon.sh
diff --git a/RecLM-cgen/scripts/run_SFT_merge.sh b/RecLM-uni/scripts/run_SFT_merge.sh
similarity index 100%
rename from RecLM-cgen/scripts/run_SFT_merge.sh
rename to RecLM-uni/scripts/run_SFT_merge.sh
diff --git a/RecLM-cgen/scripts/train_RecLM_cgen.sh b/RecLM-uni/scripts/train_RecLM_cgen.sh
similarity index 100% rename from RecLM-cgen/scripts/train_RecLM_cgen.sh rename to RecLM-uni/scripts/train_RecLM_cgen.sh diff --git a/RecLM-cgen/scripts/train_RecLM_ret.sh b/RecLM-uni/scripts/train_RecLM_ret.sh similarity index 100% rename from RecLM-cgen/scripts/train_RecLM_ret.sh rename to RecLM-uni/scripts/train_RecLM_ret.sh diff --git a/RecLM-cgen/scripts/unirec_serve.sh b/RecLM-uni/scripts/unirec_serve.sh similarity index 100% rename from RecLM-cgen/scripts/unirec_serve.sh rename to RecLM-uni/scripts/unirec_serve.sh diff --git a/RecLM-cgen/scripts/unirec_train.sh b/RecLM-uni/scripts/unirec_train.sh similarity index 100% rename from RecLM-cgen/scripts/unirec_train.sh rename to RecLM-uni/scripts/unirec_train.sh diff --git a/RecLM-cgen/task_MR_test.py b/RecLM-uni/task_MR_test.py similarity index 100% rename from RecLM-cgen/task_MR_test.py rename to RecLM-uni/task_MR_test.py diff --git a/RecLM-cgen/task_test.py b/RecLM-uni/task_test.py similarity index 100% rename from RecLM-cgen/task_test.py rename to RecLM-uni/task_test.py diff --git a/RecLM-uni/task_test_tokenizer.py b/RecLM-uni/task_test_tokenizer.py new file mode 100644 index 0000000..503c9d0 --- /dev/null +++ b/RecLM-uni/task_test_tokenizer.py @@ -0,0 +1,193 @@ +import argparse +import copy +import json +import os +from concurrent.futures import ProcessPoolExecutor + +import torch +from Levenshtein import distance +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM + +from train_utils.dataset import Test_task_group_mapping, SFTDataset +from train_utils.processor import FastPrefixConstrainedLogitsProcessor +from train_utils.metrics import Metrics +from train_utils.utils import ( + save_json, + get_ctrl_item, + rm_idx, + load_json, + load_pickle, + side_tokenizer, + process_train_sample, + load_item_code_mapping, + ITEM_CODE_FIELD, +) + + +@torch.no_grad() +def process_dataset_hf(data_list): + if len(data_list) == 0: + return + + eot_token = "<|eot_id|>" + eot_token_id = tokenizer.convert_tokens_to_ids(eot_token) + num_beams = 1 + logits_processors = [ + FastPrefixConstrainedLogitsProcessor(test_data.item_prefix_tree.constrain_search_list, num_beams) + ] if args.use_CBS else None + model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.bfloat16, device_map=args.gpu).eval() + bs = args.batch_size + for i in tqdm(range(0, len(data_list), bs)): + input_texts = [ + d['input_text'] if 'input_text' in d else process_train_sample(d['input_texts'], d['output_texts'], tokenizer)[3] + for d in data_list[i: i + bs] + ] + input_data = side_tokenizer(input_texts, 'left', tokenizer, padding=True, truncation=True, + max_length=args.max_token_length, return_tensors='pt').to(device=args.gpu).data + input_ids_length = input_data['input_ids'].shape[1] + output_ids = model.generate( + **input_data, + logits_processor=logits_processors if args.use_CBS else None, + max_length=args.max_token_length + args.gen_max_length, + num_beams=num_beams, + num_return_sequences=1, + eos_token_id=eot_token_id, + pad_token_id=tokenizer.pad_token_id + ) + output_texts = tokenizer.batch_decode(output_ids[:, input_ids_length:], + skip_special_tokens=False if args.use_control_symbol else True) + for d, o in zip(data_list[i: i + bs], output_texts): + d[f'{args.model_name}_output'] = o + + if i == 0: + print(output_texts[0]) + + +if __name__ == "__main__": + def vague_mapping(ts): + for idx, __ in enumerate(ts): + if __ in test_data.title2item: + continue + for ___ in test_data.title2item: + if distance(__, ___) <= 3: + 
ts[idx] = ___ + break + + def process_api_output(d): + if f'{args.model_name}_output' not in d: + return d + if d[f'{args.model_name}_output'] == "": + d[f'{args.SFT_test_task}_output_title_list'] = [] + return d + if f'{args.SFT_test_task}_output_title_list' in d: + return d + + raw_output = d[f'{args.model_name}_output'] + if args.use_control_symbol: + ts = get_ctrl_item(raw_output) + else: + ts = [_.strip() for _ in raw_output.strip().split('\n')] + ts = [rm_idx(_) if args.idx else _ for _ in ts] + + vague_mapping(ts) + d[f'{args.SFT_test_task}_output_title_list'] = ts + + return d + + parser = argparse.ArgumentParser() + parser.add_argument("--data_path", type=str, default='data/dataset/sub_movie/', help="processed_data path") + parser.add_argument('--SFT_test_task', type=str, default='', help='in {SFTTestSeqRec, SFTTestRanking, SFT+TestPersonalControlRec, SFT-TestPersonalControlRec, SFTTestPersonalCategoryRate_xx%, SFTTestItemCount}') + parser.add_argument("--num_process", type=int, default=128) + parser.add_argument("--model_name", type=str, default='Llama-2-7b-hf-chat', help="openai model") + parser.add_argument("--try_num", type=int, default=2, help="The number of attempts to call the API") + parser.add_argument("--max_item_length", type=int, default=10) + parser.add_argument("--max_token_length", type=int, default=512, help="The max length of input text to gpt") + parser.add_argument("--gen_max_length", type=int, default=1024) + parser.add_argument("--candidate_num", type=int, default=10) + parser.add_argument("--topk", type=int, default=10) + parser.add_argument("--item_index", type=str, default='title_t') + parser.add_argument("--backup_ip", type=str, default='0.0.0.0') + parser.add_argument("--idx", action='store_true') + parser.add_argument("--use_control_symbol", action='store_true') + parser.add_argument("--use_CBS", action='store_true') + parser.add_argument("--CBS_type", type=int, default=2) + parser.add_argument("--port", type=int, default=13579) + parser.add_argument("--gpu", type=str, default='cuda:0') + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--ids", type=int, default=0) + args = parser.parse_args() + args.is_main_process = True + print(json.dumps(args.__dict__, ensure_ascii=False, indent=2)) + data = { + 'category': load_json(args.data_path + 'category.jsonl'), + 'metas': load_json(args.data_path + 'metas.jsonl'), + 'sequential': load_json(args.data_path + 'sequential.jsonl'), + 'share_chat_gpt': None, + } + code_info = load_item_code_mapping(args.data_path) + args.item_code_tokens = [] + if code_info: + args.item_code_tokens = code_info.get('token_vocab', []) + for item_id, token_seq in code_info.get('item_seq', {}).items(): + if item_id in data['metas']: + data['metas'][item_id][ITEM_CODE_FIELD] = token_seq + if args.item_code_tokens: + print(f"Loaded item code mapping for {len(code_info.get('item_seq', {}))} items " + f"({len(args.item_code_tokens)} unique tokens, {code_info.get('missing', 0)} missing).") + TestTaskTemplate = {args.SFT_test_task: Test_task_group_mapping[args.SFT_test_task]} + TestTaskNum = {args.SFT_test_task: 1} + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + tokenizer.pad_token = '<|reserved_special_token_250|>' + tokenizer.pad_token_id = 128255 + tokenizer.soi_token = "" + tokenizer.eoi_token = "" + tokenizer.soi_token_id = tokenizer.convert_tokens_to_ids(tokenizer.soi_token) + tokenizer.eoi_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eoi_token) + tokenizer.eos_token = "<|eot_id|>" 
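+    # <|eot_id|> is Llama-3's end-of-turn marker; it is used as the effective eos here (and passed
+    # as eos_token_id to generate) so decoding stops at the end of the assistant turn. Item code
+    # tokens loaded above, if any, are registered as additional special tokens below so that each
+    # code token maps to a single vocabulary id.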
+ tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("<|eot_id|>") + if getattr(args, 'item_code_tokens', None): + added = tokenizer.add_special_tokens({'additional_special_tokens': args.item_code_tokens}) + if added > 0: + print(f"Added {added} item code tokens to tokenizer for testing.") + test_data = SFTDataset(args, TestTaskTemplate, TestTaskNum, data, tokenizer, 'test') + dataset = args.data_path.strip('/').split('/')[-1] + + result_file = os.path.join(args.model_name, f'{dataset}_{args.SFT_test_task}_Top10{f"_CBS{args.CBS_type}" if args.use_CBS else ""}_test_Result_{args.ids}.jsonl') + print(f"load file from {result_file}") + test_data_list = load_json(result_file) + _test_data_list = [_ for _ in test_data] + if test_data_list and len(test_data_list) == len(_test_data_list): + for _, __ in zip(test_data_list, _test_data_list): + _.update(__) + else: + test_data_list = _test_data_list + + remain_test_data_list = [_ for _ in test_data_list if f'{args.model_name}_output' not in _][:] + print(f"Loading Test Task: '{args.SFT_test_task}'. Remain Example Count: {len(remain_test_data_list)}") + print(test_data_list[1]['input_texts'] if 'input_texts' in test_data_list[1] else test_data_list[1]['input_text']) + + process_dataset_hf(remain_test_data_list) + + if len(remain_test_data_list) > 0: + save_json(test_data_list, result_file) + + with ProcessPoolExecutor(max_workers=args.num_process) as executor: + result = list(tqdm(executor.map(process_api_output, test_data_list), total=len(test_data_list))) + test_data_list = result + + metrics_dict = Metrics([args.SFT_test_task], args.topk, test_data.category2item, test_data.title2item) + for step_i, example in tqdm(enumerate(test_data_list[:])): + if f'{args.SFT_test_task}_output_title_list' not in example or len(example[f'{args.SFT_test_task}_output_title_list']) == 0: + continue + if args.use_control_symbol: + output_label = [example['output_texts'][-1]] + else: + output_label = [_.strip() for _ in example['output_texts'][-1].strip().split('\n')] + output_label = [rm_idx(_) if args.idx else _ for _ in output_label] + metrics_dict.add_sample(example['task'], example['input_field_data'], example[f'{args.SFT_test_task}_output_title_list'], output_label, vague_mapping=False) + + metrics_dict.print() + + if len(remain_test_data_list) > 0: + save_json(test_data_list, result_file) diff --git a/RecLM-cgen/train_utils/__init__.py b/RecLM-uni/train_utils/__init__.py similarity index 100% rename from RecLM-cgen/train_utils/__init__.py rename to RecLM-uni/train_utils/__init__.py diff --git a/RecLM-cgen/train_utils/dataset.py b/RecLM-uni/train_utils/dataset.py similarity index 99% rename from RecLM-cgen/train_utils/dataset.py rename to RecLM-uni/train_utils/dataset.py index f04dffc..e66dce6 100644 --- a/RecLM-cgen/train_utils/dataset.py +++ b/RecLM-uni/train_utils/dataset.py @@ -8,7 +8,7 @@ from .processor import Trie_link from .template import * from .utils import get_item_list, get_output_text, get_history_text, \ - process_train_sample, process_train_sample_llama2, save_json, load_json, pad_sequence + process_train_sample, process_train_sample_llama2, save_json, load_json, pad_sequence,ITEM_CODE_FIELD class SFTDataset(Dataset): @@ -149,6 +149,9 @@ def __len__(self): return len(self.datum_info) def get_item_index(self, item): + token_seq = self.metas[item].get(ITEM_CODE_FIELD) + if token_seq: + return token_seq return self.metas[item][self.args.item_index] def get_sub_sequential(self, user): diff --git a/RecLM-cgen/train_utils/loss.py 
b/RecLM-uni/train_utils/loss.py similarity index 100% rename from RecLM-cgen/train_utils/loss.py rename to RecLM-uni/train_utils/loss.py diff --git a/RecLM-cgen/train_utils/metrics.py b/RecLM-uni/train_utils/metrics.py similarity index 100% rename from RecLM-cgen/train_utils/metrics.py rename to RecLM-uni/train_utils/metrics.py diff --git a/RecLM-cgen/train_utils/model.py b/RecLM-uni/train_utils/model.py similarity index 97% rename from RecLM-cgen/train_utils/model.py rename to RecLM-uni/train_utils/model.py index 3fc2111..c66f79a 100644 --- a/RecLM-cgen/train_utils/model.py +++ b/RecLM-uni/train_utils/model.py @@ -34,6 +34,7 @@ def __init__(self, args, device, actor_lora_scope='actor', item_emb=None): self.model_config = self.create_model_config() self.tokenizer = self.create_tokenizer() self.model = self.create_model(device) + self.register_item_tokens() if args.use_control_symbol: self.resize_init_embedding(self.model, self.tokenizer) @@ -152,6 +153,15 @@ def resize_init_embedding(self, model, tokenizer): new_embedding = model.get_input_embeddings().weight[tokenized_ids].mean(axis=0) model.get_input_embeddings().weight[token_id, :] = new_embedding.clone().detach() + def register_item_tokens(self): + new_tokens = getattr(self.args, 'item_code_tokens', None) + if not new_tokens: + return + added = self.tokenizer.add_special_tokens({'additional_special_tokens': new_tokens}) + if added > 0: + self.model.resize_token_embeddings(len(self.tokenizer)) + self.model_config.vocab_size = len(self.tokenizer) + def create_model_config(self): config_class = AutoConfig config = config_class.from_pretrained(self.args.backbone) diff --git a/RecLM-cgen/train_utils/param.py b/RecLM-uni/train_utils/param.py similarity index 100% rename from RecLM-cgen/train_utils/param.py rename to RecLM-uni/train_utils/param.py diff --git a/RecLM-cgen/train_utils/processor.py b/RecLM-uni/train_utils/processor.py similarity index 100% rename from RecLM-cgen/train_utils/processor.py rename to RecLM-uni/train_utils/processor.py diff --git a/RecLM-cgen/train_utils/template.py b/RecLM-uni/train_utils/template.py similarity index 100% rename from RecLM-cgen/train_utils/template.py rename to RecLM-uni/train_utils/template.py diff --git a/RecLM-cgen/train_utils/utils.py b/RecLM-uni/train_utils/utils.py similarity index 87% rename from RecLM-cgen/train_utils/utils.py rename to RecLM-uni/train_utils/utils.py index a175f43..532312e 100644 --- a/RecLM-cgen/train_utils/utils.py +++ b/RecLM-uni/train_utils/utils.py @@ -7,6 +7,7 @@ import requests import torch +ITEM_CODE_FIELD = 'rq_token_seq' def pad_sequence(seq: list[list], pad_token_id, device, pad_side='right'): max_len = max([len(s) for s in seq]) @@ -228,3 +229,43 @@ def gsm8K_is_correct(completion, answer): def gsm8K_clean_answer(text): text = text.split("Question:")[0] return text + + +def load_item_code_mapping(data_path): + if data_path is None: + return None + dataset_root = os.path.normpath(data_path) + dataset_name = os.path.basename(dataset_root) + index_path = os.path.join(dataset_root, f"{dataset_name}.index.json") + mapping_path = os.path.join(dataset_root, f"{dataset_name}.item2id") + if not (os.path.exists(index_path) and os.path.exists(mapping_path)): + return None + + index_data = load_json(index_path) + if not index_data: + return None + + item2idx = {} + with open(mapping_path, 'r') as f: + for line in f: + parts = line.strip().split() + if len(parts) < 2: + continue + item2idx[parts[0]] = parts[1] + + token_vocab = set() + item_seq = {} + missing_indices = 0 + 
for item_id, idx in item2idx.items(): + tokens = index_data.get(str(idx)) + if not tokens: + missing_indices += 1 + continue + item_seq[item_id] = ''.join(tokens) + token_vocab.update(tokens) + + return { + 'item_seq': item_seq, + 'token_vocab': sorted(token_vocab), + 'missing': missing_indices, + } \ No newline at end of file diff --git a/RecLM-cgen/trainer.py b/RecLM-uni/trainer.py similarity index 98% rename from RecLM-cgen/trainer.py rename to RecLM-uni/trainer.py index 3d0eb83..e102844 100644 --- a/RecLM-cgen/trainer.py +++ b/RecLM-uni/trainer.py @@ -45,6 +45,20 @@ def __init__(self, args): 'sequential': load_json(args.data_path + 'sequential.jsonl'), 'share_chat_gpt': load_pickle('data/share_chat_gpt2.pickle'), } + + code_info = load_item_code_mapping(self.args.data_path) + self.args.item_code_tokens = [] + if code_info: + self.args.item_code_tokens = code_info.get('token_vocab', []) + for item_id, token_seq in code_info.get('item_seq', {}).items(): + if item_id in self.data['metas']: + self.data['metas'][item_id][ITEM_CODE_FIELD] = token_seq + if self.args.is_main_process: + covered_items = sum(1 for _ in self.data['metas'].values() if ITEM_CODE_FIELD in _) + print(f"Loaded item code mapping for {covered_items} items " + f"({len(self.args.item_code_tokens)} unique tokens, " + f"{code_info.get('missing', 0)} missing indices).") + self.item_emb = self.create_embeddings() if self.args.train_stage in ['SFT_Embedding', 'SFT_Embedding_Test'] else None self.actor = BaseModel(args=self.args, device=self.args.gpu, item_emb=self.item_emb) diff --git a/RecLM-cgen/unirec/asyc_server.py b/RecLM-uni/unirec/asyc_server.py similarity index 100% rename from RecLM-cgen/unirec/asyc_server.py rename to RecLM-uni/unirec/asyc_server.py diff --git a/RecLM-cgen/unirec/config/base.yaml b/RecLM-uni/unirec/config/base.yaml similarity index 100% rename from RecLM-cgen/unirec/config/base.yaml rename to RecLM-uni/unirec/config/base.yaml diff --git a/RecLM-cgen/unirec/config/model/SASRec.yaml b/RecLM-uni/unirec/config/model/SASRec.yaml similarity index 100% rename from RecLM-cgen/unirec/config/model/SASRec.yaml rename to RecLM-uni/unirec/config/model/SASRec.yaml From d078e57bdfc8665f82c17765942f193648bc4162 Mon Sep 17 00:00:00 2001 From: SZU-ZJW Date: Mon, 26 Jan 2026 21:40:05 +0800 Subject: [PATCH 2/2] Fix some bugs --- RecLM-uni/GRPO/rl_dataset.py | 6 +- RecLM-uni/README.md | 4 +- RecLM-uni/{ => TitleRewrite}/PLUGIN_CONFIG.md | 0 RecLM-uni/{ => TitleRewrite}/plugin.py | 150 +++++++++--------- RecLM-uni/cli_serve.py | 5 +- RecLM-uni/grpo_train.py | 2 +- RecLM-uni/index/datasets.py | 1 - RecLM-uni/index/generate_indices.py | 15 +- RecLM-uni/index/main.py | 9 +- RecLM-uni/index/models/rq.py | 5 - RecLM-uni/index/models/rqvae.py | 6 - RecLM-uni/index/models/vq.py | 6 - RecLM-uni/index/trainer.py | 8 +- RecLM-uni/index/utils.py | 7 +- .../preprocess/data_preprocess_amazon.py | 9 -- RecLM-uni/scripts/data_preprocess_amazon.sh | 2 - RecLM-uni/scripts/run_SFT_merge.sh | 1 - RecLM-uni/scripts/train_RecLM_cgen.sh | 1 - RecLM-uni/scripts/train_RecLM_ret.sh | 1 - RecLM-uni/task_MR_test.py | 3 +- RecLM-uni/task_test.py | 7 +- RecLM-uni/task_test_tokenizer.py | 6 +- RecLM-uni/train_utils/metrics.py | 1 - RecLM-uni/train_utils/param.py | 1 - RecLM-uni/train_utils/processor.py | 4 +- RecLM-uni/train_utils/utils.py | 2 +- RecLM-uni/trainer.py | 6 +- 27 files changed, 108 insertions(+), 160 deletions(-) rename RecLM-uni/{ => TitleRewrite}/PLUGIN_CONFIG.md (100%) rename RecLM-uni/{ => TitleRewrite}/plugin.py (81%) diff --git 
a/RecLM-uni/GRPO/rl_dataset.py b/RecLM-uni/GRPO/rl_dataset.py index 279a7ac..66a6dbb 100644 --- a/RecLM-uni/GRPO/rl_dataset.py +++ b/RecLM-uni/GRPO/rl_dataset.py @@ -2,8 +2,8 @@ from dataclasses import dataclass from typing import List, Dict, Tuple -from SFT.SFT_templates import SeqRec_MR_group -from utils import load_json, get_history_text, get_output_text +from train_utils.template import SeqRec_MR_group +from train_utils.utils import load_json, get_history_text, get_output_text SYSTEM_PROMPT = "You are an expert recommender engine as well as a helpful, respectful and honest assistant." @@ -64,7 +64,7 @@ def build_rl_samples( valid_items = [iid for iid, meta in metas.items() if meta.get(item_index_field)] samples: List[RLSample] = [] - for user, seq in sequential.items(): + for _, seq in sequential.items(): if len(samples) >= max_samples: break history, target_item = _sample_history(seq, max_item_length, rng) diff --git a/RecLM-uni/README.md b/RecLM-uni/README.md index d578438..95e1923 100644 --- a/RecLM-uni/README.md +++ b/RecLM-uni/README.md @@ -1,7 +1,7 @@ # RecLM-uni ## Introduction -This project introduces methods for avoid recommending out-of-domain items in LLM-based recsys. It contains the code for implementing three methods, i.e., RecLM-cgen, RecLM-ret and RecLM-token. +This project introduces methods for avoiding recommending out-of-domain items in LLM-based recsys. It contains the code for implementing three methods, i.e., RecLM-cgen, RecLM-ret and RecLM-token. **RecLM-uni** is a generative recommendation framework in the native structure of LLMs. This framework divides the output space of LLMs into item generation and general text generation parts by introducing item control tokens, and simultaneously employs a decoding strategy with prefix tree constraints to prevent the generation of out-of-domain items. RecLM-uni enables LLMs to acquire the ability to recommend products without sacrificing their original general capabilities. @@ -166,7 +166,7 @@ The embeddings should be saved as a numpy array with shape `(num_items, hidden_d ### 4.2. 
Step 2: Train RQ-VAE Model -Train the Residual Quantized Variational AutoEncoder(RQ-VAE) to learn item codebook mappings using the `index/` module from RecLM-LC1 or RecLM-cgen: +Train the Residual Quantized Variational AutoEncoder(RQ-VAE) to learn item codebook mappings using the `index/` module from RecLM-uni or RecLM-cgen: ```bash cd index diff --git a/RecLM-uni/PLUGIN_CONFIG.md b/RecLM-uni/TitleRewrite/PLUGIN_CONFIG.md similarity index 100% rename from RecLM-uni/PLUGIN_CONFIG.md rename to RecLM-uni/TitleRewrite/PLUGIN_CONFIG.md diff --git a/RecLM-uni/plugin.py b/RecLM-uni/TitleRewrite/plugin.py similarity index 81% rename from RecLM-uni/plugin.py rename to RecLM-uni/TitleRewrite/plugin.py index 54de013..dbaf1e5 100644 --- a/RecLM-uni/plugin.py +++ b/RecLM-uni/TitleRewrite/plugin.py @@ -1,29 +1,24 @@ -from math import exp -import numpy as np +import os import random -from copy import deepcopy -from typing import List, Dict, Any -from transformers import AutoTokenizer -from swift.plugin import ORM, orms -from swift.utils import get_logger import openai -import time import json import math import requests -from scipy.stats import spearmanr -from sklearn.metrics.pairwise import cosine_similarity import faiss + +import numpy as np +from typing import List, Dict, Any +from transformers import AutoTokenizer +from swift.plugin import ORM, orms +from swift.utils import get_logger +from scipy.stats import spearmanr + from functools import lru_cache from concurrent.futures import ThreadPoolExecutor from collections import defaultdict logger = get_logger() -# --- 全局配置与客户端 --- -# 建议将这些配置项在顶层统一定义 -import os - MODEL_PATH = os.getenv("MODEL_PATH", "meta-llama/Meta-Llama-3-8B-Instruct") TOKENIZER = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=True) @@ -32,7 +27,7 @@ EMBEDDING_API_KEY = os.getenv("EMBEDDING_API_KEY", "not-needed") EMBEDDING_CLIENT = openai.OpenAI(base_url=EMBEDDING_API_URL, api_key=EMBEDDING_API_KEY) -MAX_WORKERS = int(os.getenv("MAX_WORKERS", "8")) # 可根据您的系统和网络状况调整并发数 +MAX_WORKERS = int(os.getenv("MAX_WORKERS", "8")) # Adjust concurrency based on your system and network conditions _embedding_file_cache = {} @@ -48,11 +43,11 @@ def read_embedding_file_cached(embedding_path): _embedding_file_cache[embedding_path] = item2embedding return item2embedding -# --- 优化的核心:用于Faiss索引和Embeddings的集中式资源管理器 --- +# --- Optimization Core: Centralized resource manager for Faiss indexes and Embeddings --- class ResourceManager: """ - 处理加载、构建和缓存重量级资源(如Faiss索引和item embeddings), - 以避免在批处理中重复进行I/O和计算。 + Handles loading, building, and caching heavy resources (such as Faiss indexes and item embeddings) + to avoid repetitive I/O and computation during batch processing. 
""" def __init__(self): self._faiss_indexes: Dict[str, Any] = {} @@ -60,25 +55,25 @@ def __init__(self): self._item_embeddings: Dict[str, Dict[str, List[float]]] = {} def _load_and_build(self, source: str, embedding_file: str): - # 如果该数据源的索引已存在,则直接返回 + # If the index for this data source already exists, return directly if source in self._faiss_indexes: return - logger.info(f"为数据源构建Faiss索引: {source}...") + logger.info(f"Building Faiss index for data source: {source}...") item2embedding = read_embedding_file_cached(embedding_file) self._item_embeddings[source] = item2embedding - + item_names = list(item2embedding.keys()) - # 指定dtype为float32以兼容Faiss + # Specify dtype as float32 for Faiss compatibility embeddings_matrix = np.vstack([np.array(v, dtype=np.float32) for v in item2embedding.values()]) faiss.normalize_L2(embeddings_matrix) - + index = faiss.IndexFlatIP(embeddings_matrix.shape[1]) index.add(embeddings_matrix) - + self._item_names[source] = item_names self._faiss_indexes[source] = index - logger.info(f"为数据源 {source} 构建Faiss索引完成。") + logger.info(f"Faiss index construction completed for data source {source}.") def get_faiss_resources(self, source: str, embedding_file: str): self._load_and_build(source, embedding_file) @@ -88,13 +83,13 @@ def get_item_embeddings(self, source: str, embedding_file: str): self._load_and_build(source, embedding_file) return self._item_embeddings[source] -# 创建资源管理器的全局实例 +# Create global instance of resource manager resource_manager = ResourceManager() -# --- 重构后的网络请求工具函数 --- +# --- Refactored network request utility functions --- def _get_vllm_endpoints(source: str) -> List[str]: - # 将端点选择逻辑集中管理,便于维护 + # Centralize endpoint selection logic for easier maintenance # Configure vLLM endpoints via environment variables or use defaults default_endpoints = os.getenv("VLLM_ENDPOINTS", "").split(",") if os.getenv("VLLM_ENDPOINTS") else [] @@ -107,7 +102,7 @@ def _get_vllm_endpoints(source: str) -> List[str]: "movies": os.getenv("VLLM_ENDPOINTS_MOVIES", "http://localhost:8020/v1").split(","), "toys": os.getenv("VLLM_ENDPOINTS_TOYS", "http://localhost:8020/v1").split(","), } - return endpoints.get(source, endpoints["steam"]) # 如果来源未知,默认使用steam的端点 + return endpoints.get(source, endpoints["steam"]) # Default to steam endpoints if source is unknown def try_vllm_chat(prompt: str, source: str, max_tokens: int = 1, temperature: float = 0.0) -> str: for url in _get_vllm_endpoints(source): @@ -122,10 +117,10 @@ def try_vllm_chat(prompt: str, source: str, max_tokens: int = 1, temperature: fl ) return response.choices[0].message.content.strip() except Exception as e: - logger.warning(f"[vLLM Chat] 请求失败于 {url}: {e}") - logger.error("所有vLLM聊天端点均请求失败。") - time.sleep(50) - return None + logger.warning(f"[vLLM Chat] Request failed at {url}: {e}") + logger.error("All vLLM chat endpoints failed.") + raise RuntimeError(f"All vLLM chat endpoints failed for source '{source}'.") + def try_vllm_completion(prompt: str, source: str) -> object: for url in _get_vllm_endpoints(source): @@ -140,8 +135,8 @@ def try_vllm_completion(prompt: str, source: str) -> object: ) return response except Exception as e: - logger.warning(f"[vLLM Completion] 请求失败于 {url}: {e}") - logger.error("所有vLLM补全端点均请求失败。") + logger.warning(f"[vLLM Completion] Request failed at {url}: {e}") + logger.error("All vLLM completion endpoints failed.") return None @lru_cache(maxsize=10000) @@ -163,20 +158,19 @@ def _get_request_embedding_url(source: str) -> str: @lru_cache(maxsize=10000) def cached_request_embedding(text: str, 
source: str) -> tuple: url = _get_request_embedding_url(source) - if not url: - logger.error(f"未知的embedding请求来源: {source}") - # 在出错时返回一个零向量,维度与预期一致 - return tuple([0.0] * 1024) - response = requests.post(url, json={"text": text}) - response.raise_for_status() # 如果请求失败 (如 4xx or 5xx), 抛出异常 - return tuple(response.json()["embedding"]) - - -# --- 辅助函数 --- + try: + response = requests.post(url, json={"text": text}, timeout=30) + response.raise_for_status() # Raise exception if request fails (e.g., 4xx or 5xx) + return tuple(response.json()["embedding"]) + except requests.exceptions.RequestException as e: + logger.error(f"[Embedding] Request failed for source '{source}': {e}") + raise RuntimeError(f"Failed to get embedding for source '{source}': {e}") from e + +# --- Helper functions --- def spearman_rank_correlation(top10_positions, top10_items): original_ranks = list(range(1, len(top10_items) + 1)) correlation, _ = spearmanr(original_ranks, top10_positions) - # 将Spearman相关系数从[-1, 1]范围映射到[0, 1]范围 + # Map Spearman correlation coefficient from [-1, 1] range to [0, 1] range return (correlation + 1) / 2 @@ -199,7 +193,7 @@ def _process_item(self, args): logprobs = response.choices[0].logprobs.token_logprobs[split_index : split_index + output_len] if not logprobs or len(logprobs) != output_len: - logger.warning("[ConditionalPPL] Logprob长度不匹配。") + logger.warning("[ConditionalPPL] Logprob length mismatch.") return 0.0 cross_entropy = -np.mean(logprobs) @@ -207,20 +201,20 @@ def _process_item(self, args): score = np.exp(-0.02 * ppl) return score except Exception as e: - logger.warning(f"[ConditionalPPL] 处理失败: {e}") + logger.warning(f"[ConditionalPPL] Processing failed: {e}") return 0.0 def __call__(self, completions, task, solution, **kwargs) -> List[float]: - # 筛选出需要处理的任务 + # Filter tasks that need processing single_tasks_args = [(c, s) for c, t, s in zip(completions, task, solution) if t == "single"] - - # 使用线程池并行执行网络请求 + + # Use thread pool to execute network requests in parallel with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: results_iter = executor.map(self._process_item, single_tasks_args) - + results = list(results_iter) - - # 按原始顺序重组结果 + + # Reassemble results in original order final_rewards = [] result_idx = 0 for t in task: @@ -231,7 +225,7 @@ def __call__(self, completions, task, solution, **kwargs) -> List[float]: final_rewards.append(None) return final_rewards -# ========= 2. LengthReward (该函数本身很快,做少量代码清理) ========= +# ========= 2. LengthReward (This function is fast, minimal code cleanup) ========= class LengthReward(ORM): def __call__(self, completions, task, solution, **kwargs) -> List[float]: rewards = [] @@ -246,7 +240,7 @@ def __call__(self, completions, task, solution, **kwargs) -> List[float]: ratio = output_len / input_len rewards.append(1.0 / (1.0 + ratio ** 2)) except Exception as e: - logger.warning(f"[LengthReward] 计算失败: {e}") + logger.warning(f"[LengthReward] Calculation failed: {e}") rewards.append(0.0) else: rewards.append(None) @@ -279,7 +273,7 @@ def _process_item(self, args): return 1.0 return 0.0 except Exception as e: - logger.warning(f"[DiscriminativeReward] 处理失败: {e}") + logger.warning(f"[DiscriminativeReward] Processing failed: {e}") return 0.0 def __call__(self, completions, task, solution, **kwargs) -> List[float]: @@ -303,7 +297,7 @@ def __call__(self, completions, task, solution, **kwargs) -> List[float]: # ========= 4. 
item2item (使用批量Faiss搜索进行优化) ========= class Item2ItemReward(ORM): def __call__(self, completions, task, solution, **kwargs) -> List[float]: - # 按数据源对任务进行分组,以便复用Faiss索引 + # Group tasks by data source to reuse Faiss indexes grouped_tasks = defaultdict(list) for i, (c, t, s) in enumerate(zip(completions, task, solution)): if t == "single": @@ -316,8 +310,8 @@ def __call__(self, completions, task, solution, **kwargs) -> List[float]: embedding_base_dir = os.getenv("EMBEDDING_DATA_DIR", "./data/embeddings") embedding_file = f"{embedding_base_dir}/{source}_all_item_embedding.jsonl" faiss_index, all_item_names = resource_manager.get_faiss_resources(source, embedding_file) - - # 步骤 1: 并行生成所有查询的embeddings + + # Step 1: Parallel generation of all query embeddings def get_embedding(task_item): new_item_info = (f"The item's title : {task_item['completion']}\n" f"The item's description: {task_item['solution']['title_desc']}\n" @@ -328,35 +322,35 @@ def get_embedding(task_item): content_embeddings = np.array([future.result() for future in embedding_futures], dtype=np.float32) faiss.normalize_L2(content_embeddings) - # 步骤 2: 执行一次性的批量搜索 + # Step 2: Execute one-time batch search _, all_indices = faiss_index.search(content_embeddings, len(all_item_names)) - # 步骤 3: 为批次中的每个项目计算分数 + # Step 3: Calculate scores for each item in the batch for i, task_item in enumerate(tasks): - # 获取当前查询的排序结果 + # Get current query ranking results sorted_items = [all_item_names[idx] for idx in all_indices[i]] - - # 从排序结果中移除查询项目本身,以获得更公平的排名 + + # Remove the query item itself from the ranking results for a fairer ranking original_title = task_item['solution']['recommend_item'] if original_title in sorted_items: sorted_items.remove(original_title) - + top10_items = task_item['solution']['similarity_top10'] try: top10_positions = [sorted_items.index(item) + 1 for item in top10_items] score = spearman_rank_correlation(top10_positions, top10_items) - except ValueError: # 如果某个top10 item在新排名中未找到 - score = 0.0 + except ValueError: # If a top10 item is not found in the new ranking + score = 0.0 results[task_item['original_idx']] = score - # 按原始顺序重组最终的奖励列表 + # Reassemble the final reward list in original order final_rewards = [results.get(i) for i in range(len(completions))] return final_rewards # ========= 5. User2Item (使用批量Faiss搜索进行优化) ========= class User2ItemReward(ORM): def __call__(self, completions, task, solution, **kwargs) -> List[float]: - # 按数据源分组 + # Group by data source grouped_tasks = defaultdict(list) for i, (c, t, s) in enumerate(zip(completions, task, solution)): if t == "group": @@ -369,8 +363,8 @@ def __call__(self, completions, task, solution, **kwargs) -> List[float]: embedding_base_dir = os.getenv("EMBEDDING_DATA_DIR", "./data/embeddings") embedding_file = f"{embedding_base_dir}/{source}_all_item_embedding_t-desc.jsonl" faiss_index, item_names = resource_manager.get_faiss_resources(source, embedding_file) - - # 步骤 1: 并行生成所有embeddings + + # Step 1: Parallel generation of all embeddings def get_embedding(task_item): prompt = (f"You need to generate a recommendation list considering user's preference from " f"historical interactions. 
The historical interactions are provided as follows: " @@ -382,26 +376,26 @@ def get_embedding(task_item): embedding_futures = [executor.submit(get_embedding, task_item) for task_item in tasks] prompt_embeddings = np.array([future.result() for future in embedding_futures], dtype=np.float32) faiss.normalize_L2(prompt_embeddings) - - # 步骤 2: 执行一次性的批量搜索 + + # Step 2: Execute one-time batch search _, all_indices = faiss_index.search(prompt_embeddings, len(item_names)) - # 步骤 3: 计算分数 + # Step 3: Calculate scores for i, task_item in enumerate(tasks): sorted_items = [item_names[idx] for idx in all_indices[i]] item_title = task_item['solution']['target_item'] try: item_rank = sorted_items.index(item_title) + 1 score = math.exp(-(item_rank - 1) / 2000) - except ValueError: # 目标项目未在排名中找到 + except ValueError: # Target item not found in ranking score = 0.0 results[task_item['original_idx']] = score - # 按原始顺序重组结果 + # Reassemble results in original order final_rewards = [results.get(i) for i in range(len(completions))] return final_rewards -# ========= 注册所有奖励函数 ========= +# ========= Register all reward functions ========= orms['average_ppl'] = ConditionalPPL orms['length'] = LengthReward orms['discriminative'] = DiscriminativeReward diff --git a/RecLM-uni/cli_serve.py b/RecLM-uni/cli_serve.py index d9b9be9..8cf60cd 100644 --- a/RecLM-uni/cli_serve.py +++ b/RecLM-uni/cli_serve.py @@ -1,14 +1,15 @@ import argparse import html -import os.path +import os +import torch import gradio as gr -import torch from transformers import AutoTokenizer, AutoModelForCausalLM from train_utils.processor import FastPrefixConstrainedLogitsProcessor, Trie_link from train_utils.utils import load_json + domains = ["steam", "movies", "toys"] system_message = { 'role': "system", diff --git a/RecLM-uni/grpo_train.py b/RecLM-uni/grpo_train.py index 188ff80..6bc7116 100644 --- a/RecLM-uni/grpo_train.py +++ b/RecLM-uni/grpo_train.py @@ -194,7 +194,7 @@ def sample_to_dict(sample): try: model_init.print_trainable_parameters() except AttributeError: - pass + print("Note: print_trainable_parameters() method not available on this PEFT model type") else: model_init = args.model_path diff --git a/RecLM-uni/index/datasets.py b/RecLM-uni/index/datasets.py index 08d6eb9..d83b31f 100644 --- a/RecLM-uni/index/datasets.py +++ b/RecLM-uni/index/datasets.py @@ -8,7 +8,6 @@ class EmbDataset(data.Dataset): def __init__(self,data_path): self.data_path = data_path - # self.embeddings = np.fromfile(data_path, dtype=np.float32).reshape(16859,-1) self.embeddings = np.load(data_path) self.dim = self.embeddings.shape[-1] diff --git a/RecLM-uni/index/generate_indices.py b/RecLM-uni/index/generate_indices.py index 68c82f7..0fa9251 100644 --- a/RecLM-uni/index/generate_indices.py +++ b/RecLM-uni/index/generate_indices.py @@ -1,20 +1,15 @@ +import argparse import collections +import os import json -import logging -import argparse +import torch import numpy as np -import torch -from time import time -from torch import optim from tqdm import tqdm - from torch.utils.data import DataLoader -from datasets import EmbDataset -from models.rqvae import RQVAE - -import os +from index.datasets import EmbDataset +from index.models.rqvae import RQVAE def check_collision(all_indices_str): tot_item = len(all_indices_str) diff --git a/RecLM-uni/index/main.py b/RecLM-uni/index/main.py index 98da7cc..39fad5c 100644 --- a/RecLM-uni/index/main.py +++ b/RecLM-uni/index/main.py @@ -1,15 +1,14 @@ import argparse import random import torch -import numpy as np -from time import time 
import logging +import numpy as np from torch.utils.data import DataLoader -from datasets import EmbDataset -from models.rqvae import RQVAE -from trainer import Trainer +from index.datasets import EmbDataset +from index.models.rqvae import RQVAE +from index.trainer import Trainer def parse_args(): parser = argparse.ArgumentParser(description="Index") diff --git a/RecLM-uni/index/models/rq.py b/RecLM-uni/index/models/rq.py index 85f678b..ef68468 100644 --- a/RecLM-uni/index/models/rq.py +++ b/RecLM-uni/index/models/rq.py @@ -5,11 +5,6 @@ class ResidualVectorQuantizer(nn.Module): - """ References: - SoundStream: An End-to-End Neural Audio Codec - https://arxiv.org/pdf/2107.03312.pdf - """ - def __init__(self, n_e_list, e_dim, sk_epsilons, beta = 0.25, kmeans_init = False, kmeans_iters = 100, sk_iters=100,): super().__init__() diff --git a/RecLM-uni/index/models/rqvae.py b/RecLM-uni/index/models/rqvae.py index 0cc52c2..c8c58b9 100644 --- a/RecLM-uni/index/models/rqvae.py +++ b/RecLM-uni/index/models/rqvae.py @@ -1,4 +1,3 @@ -import numpy as np import torch from torch import nn from torch.nn import functional as F @@ -87,15 +86,10 @@ def compute_loss(self, out, quant_loss, xs=None): loss_total = loss_recon + self.quant_loss_weight * quant_loss if self.contrastive_loss_weight > 0: - # InfoNCE Loss - # Normalize vectors out_norm = F.normalize(out, p=2, dim=1) xs_norm = F.normalize(xs, p=2, dim=1) - - # Compute similarity matrix (Batch_Size, Batch_Size) logits = torch.matmul(out_norm, xs_norm.t()) / self.temperature - # Targets are the diagonal elements (0, 1, 2, ...) labels = torch.arange(logits.size(0), device=logits.device) loss_contrastive = F.cross_entropy(logits, labels) diff --git a/RecLM-uni/index/models/vq.py b/RecLM-uni/index/models/vq.py index 30474fb..675257e 100644 --- a/RecLM-uni/index/models/vq.py +++ b/RecLM-uni/index/models/vq.py @@ -67,7 +67,6 @@ def forward(self, x, use_sk=True): if not self.initted and self.training: self.init_emb(latent) - # Calculate the L2 Norm between latent and Embedded weights d = torch.sum(latent**2, dim=1, keepdim=True) + \ torch.sum(self.embedding.weight**2, dim=1, keepdim=True).t()- \ 2 * torch.matmul(latent, self.embedding.weight.t()) @@ -82,16 +81,11 @@ def forward(self, x, use_sk=True): print(f"Sinkhorn Algorithm returns nan/inf values.") indices = torch.argmax(Q, dim=-1) - # indices = torch.argmin(d, dim=-1) - x_q = self.embedding(indices).view(x.shape) - - # compute loss for embedding commitment_loss = F.mse_loss(x_q.detach(), x) codebook_loss = F.mse_loss(x_q, x.detach()) loss = codebook_loss + self.beta * commitment_loss - # preserve gradients x_q = x + (x_q - x).detach() indices = indices.view(x.shape[:-1]) diff --git a/RecLM-uni/index/trainer.py b/RecLM-uni/index/trainer.py index 9c620ab..cf84122 100644 --- a/RecLM-uni/index/trainer.py +++ b/RecLM-uni/index/trainer.py @@ -1,16 +1,16 @@ +import os import logging +import torch +import heapq import numpy as np -import torch from time import time from torch import optim from tqdm import tqdm from transformers import get_linear_schedule_with_warmup, get_constant_schedule_with_warmup -from utils import ensure_dir,set_color,get_local_time,delete_file -import os +from .utils import ensure_dir, set_color, get_local_time, delete_file -import heapq class Trainer(object): def __init__(self, args, model, data_num): diff --git a/RecLM-uni/index/utils.py b/RecLM-uni/index/utils.py index abd33c4..06b2a0c 100644 --- a/RecLM-uni/index/utils.py +++ b/RecLM-uni/index/utils.py @@ -1,17 +1,14 @@ - -import 
datetime import os - +import datetime def ensure_dir(dir_path): - os.makedirs(dir_path, exist_ok=True) def set_color(log, color, highlight=True): color_set = ["black", "red", "green", "yellow", "blue", "pink", "cyan", "white"] try: index = color_set.index(color) - except: + except ValueError: index = len(color_set) - 1 prev_log = "\033[" if highlight: diff --git a/RecLM-uni/preprocess/data_preprocess_amazon.py b/RecLM-uni/preprocess/data_preprocess_amazon.py index 10a2b3a..65f247c 100644 --- a/RecLM-uni/preprocess/data_preprocess_amazon.py +++ b/RecLM-uni/preprocess/data_preprocess_amazon.py @@ -313,17 +313,8 @@ def main_process(data_name, args, data_type='Amazon'): user_items = get_interaction(datas) # dict of {user: interaction list sorted by time} print(f'{data_name} Raw data has been processed! Lower than {rating_score} are deleted!') print(f'User Num: {len(user_items)}') - # raw_id user: [item1, item2, item3...] - # user 25-core item 10-core - # user_core, item_core = 25, 10 - # user_items = filter_Kcore(user_items, user_core=user_core, item_core=item_core) - # print(f'User {user_core}-core complete! Item {item_core}-core complete!') - # user_num, item_num, datamaps = id_map(user_items) # get mapping dicts - # user_count, item_count, isKcore = check_Kcore(user_items, user_core=user_core, item_core=item_core) - # assert isKcore is True - # sample 10000 users, item max len is 17 user_num, item_len = 10000, 17 user_items = sample_inter(user_items, user_num=user_num, item_len=item_len) user_num, item_num, datamaps = id_map(user_items) # get mapping dicts diff --git a/RecLM-uni/scripts/data_preprocess_amazon.sh b/RecLM-uni/scripts/data_preprocess_amazon.sh index 5ec7c08..76c3df8 100644 --- a/RecLM-uni/scripts/data_preprocess_amazon.sh +++ b/RecLM-uni/scripts/data_preprocess_amazon.sh @@ -1,5 +1,3 @@ - - TOKENIZER_PATH="meta-llama/Meta-Llama-3-8B-Instruct" DATASET_FULL_NAME="Movies_and_TV" DATASET_NAME="movies" diff --git a/RecLM-uni/scripts/run_SFT_merge.sh b/RecLM-uni/scripts/run_SFT_merge.sh index ebb44bf..ca4afda 100644 --- a/RecLM-uni/scripts/run_SFT_merge.sh +++ b/RecLM-uni/scripts/run_SFT_merge.sh @@ -1,4 +1,3 @@ - MODEL_PATH="meta-llama/Meta-Llama-3-8B-Instruct" OUTPUT_PATH="./snap/.../" LORA_PATH="${OUTPUT_PATH}Epoch20_SFT" diff --git a/RecLM-uni/scripts/train_RecLM_cgen.sh b/RecLM-uni/scripts/train_RecLM_cgen.sh index 24fcaa2..554ae12 100644 --- a/RecLM-uni/scripts/train_RecLM_cgen.sh +++ b/RecLM-uni/scripts/train_RecLM_cgen.sh @@ -1,4 +1,3 @@ - MODEL_PATH="meta-llama/Meta-Llama-3-8B-Instruct" SFT_TRAIN_TASKS="SFTSeqRec-MR" DATASET=$1 diff --git a/RecLM-uni/scripts/train_RecLM_ret.sh b/RecLM-uni/scripts/train_RecLM_ret.sh index dcf2a8d..cdc2c39 100644 --- a/RecLM-uni/scripts/train_RecLM_ret.sh +++ b/RecLM-uni/scripts/train_RecLM_ret.sh @@ -1,4 +1,3 @@ - MODEL_PATH="meta-llama/Meta-Llama-3-8B-Instruct" DATASET=$1 TRAIN_ARGS="" diff --git a/RecLM-uni/task_MR_test.py b/RecLM-uni/task_MR_test.py index e7458e8..d24b5a6 100644 --- a/RecLM-uni/task_MR_test.py +++ b/RecLM-uni/task_MR_test.py @@ -1,7 +1,6 @@ import argparse import json import os -import re from concurrent.futures import ProcessPoolExecutor import numpy as np @@ -14,7 +13,7 @@ from train_utils.dataset import Test_task_group_mapping, SFTDataset from train_utils.processor import FastPrefixConstrainedLogitsProcessor from train_utils.metrics import Metrics -from train_utils.utils import save_json, get_ctrl_item, rm_idx, load_json, load_pickle, side_tokenizer, process_train_sample, gsm8K_clean_answer, gsm8K_is_correct +from 
train_utils.utils import save_json, get_ctrl_item, rm_idx, load_json, side_tokenizer, process_train_sample, gsm8K_clean_answer, gsm8K_is_correct
 
 headers = {"User-Agent": "Test Client"}
 GSM8K_Q1 = '''Question: In 2004, there were 60 kids at a cookout. In 2005, half the number of kids came to the cookout as compared to 2004. In 2006, 2/3 as many kids came to the cookout as in 2005. How many kids came to the cookout in 2006?'''
diff --git a/RecLM-uni/task_test.py b/RecLM-uni/task_test.py
index 9b9b0c8..6edc5c3 100644
--- a/RecLM-uni/task_test.py
+++ b/RecLM-uni/task_test.py
@@ -1,10 +1,9 @@
 import argparse
-import copy
 import json
 import os
-from concurrent.futures import ProcessPoolExecutor
-
 import torch
+
+from concurrent.futures import ProcessPoolExecutor
 from Levenshtein import distance
 from tqdm import tqdm
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -12,7 +11,7 @@
 from train_utils.dataset import Test_task_group_mapping, SFTDataset
 from train_utils.processor import FastPrefixConstrainedLogitsProcessor
 from train_utils.metrics import Metrics
-from train_utils.utils import save_json, get_ctrl_item, rm_idx, load_json, load_pickle, side_tokenizer, process_train_sample
+from train_utils.utils import save_json, get_ctrl_item, rm_idx, load_json, side_tokenizer, process_train_sample
 
 
 @torch.no_grad()
diff --git a/RecLM-uni/task_test_tokenizer.py b/RecLM-uni/task_test_tokenizer.py
index 503c9d0..7dd6657 100644
--- a/RecLM-uni/task_test_tokenizer.py
+++ b/RecLM-uni/task_test_tokenizer.py
@@ -1,10 +1,9 @@
 import argparse
-import copy
 import json
 import os
-from concurrent.futures import ProcessPoolExecutor
-
 import torch
+
+from concurrent.futures import ProcessPoolExecutor
 from Levenshtein import distance
 from tqdm import tqdm
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -17,7 +16,6 @@
     get_ctrl_item,
     rm_idx,
     load_json,
-    load_pickle,
     side_tokenizer,
     process_train_sample,
     load_item_code_mapping,
diff --git a/RecLM-uni/train_utils/metrics.py b/RecLM-uni/train_utils/metrics.py
index d2e8982..208d29f 100644
--- a/RecLM-uni/train_utils/metrics.py
+++ b/RecLM-uni/train_utils/metrics.py
@@ -1,5 +1,4 @@
 import copy
-
 import math
 
 from Levenshtein import distance
diff --git a/RecLM-uni/train_utils/param.py b/RecLM-uni/train_utils/param.py
index d8b2aa9..82d3527 100644
--- a/RecLM-uni/train_utils/param.py
+++ b/RecLM-uni/train_utils/param.py
@@ -1,6 +1,5 @@
 import argparse
 import pprint
-
 import yaml
 
 
diff --git a/RecLM-uni/train_utils/processor.py b/RecLM-uni/train_utils/processor.py
index 011a5f0..15e78d0 100644
--- a/RecLM-uni/train_utils/processor.py
+++ b/RecLM-uni/train_utils/processor.py
@@ -1,7 +1,7 @@
-from typing import Callable, List
-
 import math
 import torch
+
+from typing import Callable, List
 
 from transformers import LogitsProcessor, add_start_docstrings
 from transformers.generation.logits_process import LOGITS_PROCESSOR_INPUTS_DOCSTRING
diff --git a/RecLM-uni/train_utils/utils.py b/RecLM-uni/train_utils/utils.py
index 532312e..533cb9c 100644
--- a/RecLM-uni/train_utils/utils.py
+++ b/RecLM-uni/train_utils/utils.py
@@ -3,12 +3,12 @@
 import os.path
 import pickle
 import re
-
 import requests
 import torch
 
 ITEM_CODE_FIELD = 'rq_token_seq'
 
+
 def pad_sequence(seq: list[list], pad_token_id, device, pad_side='right'):
     max_len = max([len(s) for s in seq])
     for i, s in enumerate(seq):
diff --git a/RecLM-uni/trainer.py b/RecLM-uni/trainer.py
index e102844..daa11f8 100644
--- a/RecLM-uni/trainer.py
+++ b/RecLM-uni/trainer.py
@@ -1,8 +1,8 @@
-import os.path
-
+import os
 import math
-import numpy as np
 import torch
+
+import numpy as np
 from FlagEmbedding import BGEM3FlagModel
 from Levenshtein import distance
 from accelerate import Accelerator
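
For reference, a minimal sketch of how the RQ-VAE item codes loaded by `load_item_code_mapping()` end up in the model vocabulary, mirroring `register_item_tokens()` in `train_utils/model.py`. The token strings and file-free setup below are illustrative assumptions, not the patch's actual index format:

```python
# Illustrative sketch (assumed token format; real codes come from index/generate_indices.py).
from transformers import AutoTokenizer, AutoModelForCausalLM

backbone = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(backbone)
model = AutoModelForCausalLM.from_pretrained(backbone)

# In the patch this vocabulary is load_item_code_mapping(data_path)['token_vocab'];
# the hypothetical codes here stand in for one token per codebook level.
token_vocab = ["<a_3>", "<b_17>", "<c_201>"]

added = tokenizer.add_special_tokens({"additional_special_tokens": token_vocab})
if added > 0:
    # Grow the embedding matrix so the new token ids have trainable rows.
    model.resize_token_embeddings(len(tokenizer))

# Each item is then referenced by its concatenated code sequence rather than its title,
# stored under metas[item_id]['rq_token_seq'] (ITEM_CODE_FIELD) and returned by
# SFTDataset.get_item_index() when present.
```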