From ff0b050d598580757884c15ba8fbb78a5e387c55 Mon Sep 17 00:00:00 2001 From: xujiayuan0205 Date: Wed, 15 Apr 2026 17:27:41 +0800 Subject: [PATCH] feat: add badcase recording, LLM judge fallback, dual metrics, and fix ToMi field mapping - Add StructuredResult dataclass wrapping parsed output with raw_response and reasoning_content - Add src/judge.py for LLM semantic judge fallback when structured extraction fails - Extend runner.py with collect_badcases() and build_corrected_predictions() - Add dual metrics (strict + judge-corrected) to all task run.py scripts - Fix ToMi field mapping: Story.full_story/Question/Answer.Correct_Answer - Fix reasoning_content capture: support both 'reasoning' and 'reasoning_content' field names - Fix run_all.py subprocess PYTHONPATH for src module resolution - Update SUMMARY.md with deepseek-chat and deepseek-r1 results --- datasets/.gitattributes | 60 +++++++ datasets/README.md | 10 ++ datasets/SocialIQA/dataset_dict.json | 1 + .../SocialIQA/dev/data-00000-of-00001.arrow | 3 + datasets/SocialIQA/dev/dataset_info.json | 114 ++++++++++++ datasets/SocialIQA/dev/state.json | 13 ++ .../SocialIQA/test/data-00000-of-00001.arrow | 3 + datasets/SocialIQA/test/dataset_info.json | 114 ++++++++++++ datasets/SocialIQA/test/state.json | 13 ++ .../SocialIQA/train/data-00000-of-00001.arrow | 3 + datasets/SocialIQA/train/dataset_info.json | 114 ++++++++++++ datasets/SocialIQA/train/state.json | 13 ++ .../ToMBench/test/data-00000-of-00001.arrow | 3 + datasets/ToMBench/test/dataset_info.json | 51 ++++++ datasets/ToMBench/test/state.json | 13 ++ .../ToMBench/train/data-00000-of-00001.arrow | 3 + datasets/ToMBench/train/dataset_info.json | 51 ++++++ datasets/ToMBench/train/state.json | 13 ++ datasets/ToMQA/dataset_dict.json | 1 + datasets/ToMQA/test/data-00000-of-00001.arrow | 3 + datasets/ToMQA/test/dataset_info.json | 163 ++++++++++++++++++ datasets/ToMQA/test/state.json | 13 ++ .../ToMQA/train/data-00000-of-00001.arrow | 3 + datasets/ToMQA/train/dataset_info.json | 163 ++++++++++++++++++ datasets/ToMQA/train/state.json | 13 ++ .../validation/data-00000-of-00001.arrow | 3 + datasets/ToMQA/validation/dataset_info.json | 163 ++++++++++++++++++ datasets/ToMQA/validation/state.json | 13 ++ datasets/ToMi/dataset_dict.json | 1 + datasets/ToMi/test/data-00000-of-00001.arrow | 3 + datasets/ToMi/test/dataset_info.json | 105 +++++++++++ datasets/ToMi/test/state.json | 13 ++ datasets/ToMi/train/data-00000-of-00001.arrow | 3 + datasets/ToMi/train/dataset_info.json | 105 +++++++++++ datasets/ToMi/train/state.json | 13 ++ datasets/ToMi/val/data-00000-of-00001.arrow | 3 + datasets/ToMi/val/dataset_info.json | 105 +++++++++++ datasets/ToMi/val/state.json | 13 ++ .../Tomato/test/data-00000-of-00001.arrow | 3 + datasets/Tomato/test/dataset_info.json | 163 ++++++++++++++++++ datasets/Tomato/test/state.json | 13 ++ .../Tomato/train/data-00000-of-00001.arrow | 3 + datasets/Tomato/train/dataset_info.json | 163 ++++++++++++++++++ datasets/Tomato/train/state.json | 13 ++ experiment_config.yaml | 16 +- run_all.py | 99 +++++++++-- src/judge.py | 118 +++++++++++++ src/llm/__init__.py | 3 +- src/llm/client.py | 101 ++++++++--- src/runner.py | 128 +++++++++++++- tables/SUMMARY.md | 10 +- ...66\344\273\226\346\214\207\346\240\207.md" | 70 ++++---- ...72\347\241\200\346\214\207\346\240\207.md" | 10 +- ...66\344\273\226\346\214\207\346\240\207.md" | 16 +- ...72\347\241\200\346\214\207\346\240\207.md" | 10 +- ...66\344\273\226\346\214\207\346\240\207.md" | 22 +-- ...72\347\241\200\346\214\207\346\240\207.md" | 10 +- tasks/ToMBench/run.py | 89 +++++++--- tasks/ToMQA/run.py | 88 ++++++++-- tasks/ToMi/metrics.py | 12 +- tasks/ToMi/prompts.py | 5 +- tasks/ToMi/run.py | 91 +++++++--- tasks/Tomato/run.py | 76 +++++++- 63 files changed, 2631 insertions(+), 198 deletions(-) create mode 100644 datasets/.gitattributes create mode 100644 datasets/README.md create mode 100644 datasets/SocialIQA/dataset_dict.json create mode 100644 datasets/SocialIQA/dev/data-00000-of-00001.arrow create mode 100644 datasets/SocialIQA/dev/dataset_info.json create mode 100644 datasets/SocialIQA/dev/state.json create mode 100644 datasets/SocialIQA/test/data-00000-of-00001.arrow create mode 100644 datasets/SocialIQA/test/dataset_info.json create mode 100644 datasets/SocialIQA/test/state.json create mode 100644 datasets/SocialIQA/train/data-00000-of-00001.arrow create mode 100644 datasets/SocialIQA/train/dataset_info.json create mode 100644 datasets/SocialIQA/train/state.json create mode 100644 datasets/ToMBench/test/data-00000-of-00001.arrow create mode 100644 datasets/ToMBench/test/dataset_info.json create mode 100644 datasets/ToMBench/test/state.json create mode 100644 datasets/ToMBench/train/data-00000-of-00001.arrow create mode 100644 datasets/ToMBench/train/dataset_info.json create mode 100644 datasets/ToMBench/train/state.json create mode 100644 datasets/ToMQA/dataset_dict.json create mode 100644 datasets/ToMQA/test/data-00000-of-00001.arrow create mode 100644 datasets/ToMQA/test/dataset_info.json create mode 100644 datasets/ToMQA/test/state.json create mode 100644 datasets/ToMQA/train/data-00000-of-00001.arrow create mode 100644 datasets/ToMQA/train/dataset_info.json create mode 100644 datasets/ToMQA/train/state.json create mode 100644 datasets/ToMQA/validation/data-00000-of-00001.arrow create mode 100644 datasets/ToMQA/validation/dataset_info.json create mode 100644 datasets/ToMQA/validation/state.json create mode 100644 datasets/ToMi/dataset_dict.json create mode 100644 datasets/ToMi/test/data-00000-of-00001.arrow create mode 100644 datasets/ToMi/test/dataset_info.json create mode 100644 datasets/ToMi/test/state.json create mode 100644 datasets/ToMi/train/data-00000-of-00001.arrow create mode 100644 datasets/ToMi/train/dataset_info.json create mode 100644 datasets/ToMi/train/state.json create mode 100644 datasets/ToMi/val/data-00000-of-00001.arrow create mode 100644 datasets/ToMi/val/dataset_info.json create mode 100644 datasets/ToMi/val/state.json create mode 100644 datasets/Tomato/test/data-00000-of-00001.arrow create mode 100644 datasets/Tomato/test/dataset_info.json create mode 100644 datasets/Tomato/test/state.json create mode 100644 datasets/Tomato/train/data-00000-of-00001.arrow create mode 100644 datasets/Tomato/train/dataset_info.json create mode 100644 datasets/Tomato/train/state.json create mode 100644 src/judge.py diff --git a/datasets/.gitattributes b/datasets/.gitattributes new file mode 100644 index 0000000..bed0738 --- /dev/null +++ b/datasets/.gitattributes @@ -0,0 +1,60 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.avro filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.lz4 filter=lfs diff=lfs merge=lfs -text +*.mds filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +# Audio files - uncompressed +*.pcm filter=lfs diff=lfs merge=lfs -text +*.sam filter=lfs diff=lfs merge=lfs -text +*.raw filter=lfs diff=lfs merge=lfs -text +# Audio files - compressed +*.aac filter=lfs diff=lfs merge=lfs -text +*.flac filter=lfs diff=lfs merge=lfs -text +*.mp3 filter=lfs diff=lfs merge=lfs -text +*.ogg filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text +# Image files - uncompressed +*.bmp filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.tiff filter=lfs diff=lfs merge=lfs -text +# Image files - compressed +*.jpg filter=lfs diff=lfs merge=lfs -text +*.jpeg filter=lfs diff=lfs merge=lfs -text +*.webp filter=lfs diff=lfs merge=lfs -text +# Video files - compressed +*.mp4 filter=lfs diff=lfs merge=lfs -text +*.webm filter=lfs diff=lfs merge=lfs -text diff --git a/datasets/README.md b/datasets/README.md new file mode 100644 index 0000000..d04d59b --- /dev/null +++ b/datasets/README.md @@ -0,0 +1,10 @@ +--- +license: apache-2.0 +task_categories: +- question-answering +language: +- en +--- +**Theory of Mind(心智理论)**指的是: +理解“别人有自己的想法、信念、情绪、意图”,而且这些可能和自己不同的能力。 +此数据库为 ToM 数据集集合 \ No newline at end of file diff --git a/datasets/SocialIQA/dataset_dict.json b/datasets/SocialIQA/dataset_dict.json new file mode 100644 index 0000000..631d407 --- /dev/null +++ b/datasets/SocialIQA/dataset_dict.json @@ -0,0 +1 @@ +{"splits": ["train", "dev", "test"]} \ No newline at end of file diff --git a/datasets/SocialIQA/dev/data-00000-of-00001.arrow b/datasets/SocialIQA/dev/data-00000-of-00001.arrow new file mode 100644 index 0000000..06bb36a --- /dev/null +++ b/datasets/SocialIQA/dev/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0170785f1586d4191664887d480ca14b0d1bbc5f171ce923c242afb6b486fe31 +size 688960 diff --git a/datasets/SocialIQA/dev/dataset_info.json b/datasets/SocialIQA/dev/dataset_info.json new file mode 100644 index 0000000..83f9dee --- /dev/null +++ b/datasets/SocialIQA/dev/dataset_info.json @@ -0,0 +1,114 @@ +{ + "citation": "", + "description": "", + "features": { + "State": { + "Human_State": {}, + "Environment_State": {} + }, + "Action": {}, + "Story": { + "background": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "full_story": { + "dtype": "string", + "_type": "Value" + }, + "summary": { + "dtype": "string", + "_type": "Value" + } + }, + "Question": { + "dtype": "string", + "_type": "Value" + }, + "Answer": { + "Correct_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "Wrong_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Meta": { + "id": { + "dtype": "string", + "_type": "Value" + }, + "dataset_source": { + "dtype": "string", + "_type": "Value" + }, + "dimension": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "order": { + "dtype": "int32", + "_type": "Value" + }, + "task_type": { + "dtype": "string", + "_type": "Value" + }, + "difficulty": { + "dtype": "string", + "_type": "Value" + }, + "ethics_category": { + "dtype": "string", + "_type": "Value" + }, + "charmap": { + "feature": { + "key": { + "dtype": "string", + "_type": "Value" + }, + "value": { + "dtype": "string", + "_type": "Value" + } + }, + "_type": "Sequence" + }, + "answerSourcesOrigins": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "answerSourcesWithCor": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "promptQuestionFocusChar": { + "dtype": "string", + "_type": "Value" + } + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/datasets/SocialIQA/dev/state.json b/datasets/SocialIQA/dev/state.json new file mode 100644 index 0000000..2824e40 --- /dev/null +++ b/datasets/SocialIQA/dev/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "522b767197a17e40", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/datasets/SocialIQA/test/data-00000-of-00001.arrow b/datasets/SocialIQA/test/data-00000-of-00001.arrow new file mode 100644 index 0000000..24afb43 --- /dev/null +++ b/datasets/SocialIQA/test/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5bece981dd330c890d627ecf5ad31d979213c75888d808865f35420118baba88 +size 787144 diff --git a/datasets/SocialIQA/test/dataset_info.json b/datasets/SocialIQA/test/dataset_info.json new file mode 100644 index 0000000..83f9dee --- /dev/null +++ b/datasets/SocialIQA/test/dataset_info.json @@ -0,0 +1,114 @@ +{ + "citation": "", + "description": "", + "features": { + "State": { + "Human_State": {}, + "Environment_State": {} + }, + "Action": {}, + "Story": { + "background": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "full_story": { + "dtype": "string", + "_type": "Value" + }, + "summary": { + "dtype": "string", + "_type": "Value" + } + }, + "Question": { + "dtype": "string", + "_type": "Value" + }, + "Answer": { + "Correct_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "Wrong_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Meta": { + "id": { + "dtype": "string", + "_type": "Value" + }, + "dataset_source": { + "dtype": "string", + "_type": "Value" + }, + "dimension": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "order": { + "dtype": "int32", + "_type": "Value" + }, + "task_type": { + "dtype": "string", + "_type": "Value" + }, + "difficulty": { + "dtype": "string", + "_type": "Value" + }, + "ethics_category": { + "dtype": "string", + "_type": "Value" + }, + "charmap": { + "feature": { + "key": { + "dtype": "string", + "_type": "Value" + }, + "value": { + "dtype": "string", + "_type": "Value" + } + }, + "_type": "Sequence" + }, + "answerSourcesOrigins": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "answerSourcesWithCor": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "promptQuestionFocusChar": { + "dtype": "string", + "_type": "Value" + } + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/datasets/SocialIQA/test/state.json b/datasets/SocialIQA/test/state.json new file mode 100644 index 0000000..ae31c7a --- /dev/null +++ b/datasets/SocialIQA/test/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "9bdbd0b40bc40280", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/datasets/SocialIQA/train/data-00000-of-00001.arrow b/datasets/SocialIQA/train/data-00000-of-00001.arrow new file mode 100644 index 0000000..54fe1c1 --- /dev/null +++ b/datasets/SocialIQA/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24fbbb2fd20257cecf7a35ad769a65ddc0b9e934472e12abb4839d9e9a2f5d2c +size 11772760 diff --git a/datasets/SocialIQA/train/dataset_info.json b/datasets/SocialIQA/train/dataset_info.json new file mode 100644 index 0000000..83f9dee --- /dev/null +++ b/datasets/SocialIQA/train/dataset_info.json @@ -0,0 +1,114 @@ +{ + "citation": "", + "description": "", + "features": { + "State": { + "Human_State": {}, + "Environment_State": {} + }, + "Action": {}, + "Story": { + "background": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "full_story": { + "dtype": "string", + "_type": "Value" + }, + "summary": { + "dtype": "string", + "_type": "Value" + } + }, + "Question": { + "dtype": "string", + "_type": "Value" + }, + "Answer": { + "Correct_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "Wrong_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Meta": { + "id": { + "dtype": "string", + "_type": "Value" + }, + "dataset_source": { + "dtype": "string", + "_type": "Value" + }, + "dimension": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "order": { + "dtype": "int32", + "_type": "Value" + }, + "task_type": { + "dtype": "string", + "_type": "Value" + }, + "difficulty": { + "dtype": "string", + "_type": "Value" + }, + "ethics_category": { + "dtype": "string", + "_type": "Value" + }, + "charmap": { + "feature": { + "key": { + "dtype": "string", + "_type": "Value" + }, + "value": { + "dtype": "string", + "_type": "Value" + } + }, + "_type": "Sequence" + }, + "answerSourcesOrigins": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "answerSourcesWithCor": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "promptQuestionFocusChar": { + "dtype": "string", + "_type": "Value" + } + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/datasets/SocialIQA/train/state.json b/datasets/SocialIQA/train/state.json new file mode 100644 index 0000000..f7db6b4 --- /dev/null +++ b/datasets/SocialIQA/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "0eca722626a45ed1", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/datasets/ToMBench/test/data-00000-of-00001.arrow b/datasets/ToMBench/test/data-00000-of-00001.arrow new file mode 100644 index 0000000..6bc8384 --- /dev/null +++ b/datasets/ToMBench/test/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:139c42b45a48c0b3bdcbb3d3f0b229c86a7c1b03e69635ce8485118fb7c81c7e +size 2004448 diff --git a/datasets/ToMBench/test/dataset_info.json b/datasets/ToMBench/test/dataset_info.json new file mode 100644 index 0000000..4bf9c1a --- /dev/null +++ b/datasets/ToMBench/test/dataset_info.json @@ -0,0 +1,51 @@ +{ + "citation": "", + "description": "", + "features": { + "State": { + "Environment State": {}, + "Human State": {} + }, + "Action": {}, + "Story": { + "dtype": "string", + "_type": "Value" + }, + "Question": { + "dtype": "string", + "_type": "Value" + }, + "Answer": { + "Correct Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "Wrong Answer": { + "feature": { + "dtype": "null", + "_type": "Value" + }, + "_type": "List" + } + }, + "Meta": { + "Index": { + "dtype": "int64", + "_type": "Value" + }, + "ability": { + "dtype": "string", + "_type": "Value" + }, + "filename": { + "dtype": "string", + "_type": "Value" + } + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/datasets/ToMBench/test/state.json b/datasets/ToMBench/test/state.json new file mode 100644 index 0000000..f0c4dc3 --- /dev/null +++ b/datasets/ToMBench/test/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "434bf781e5e90cea", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/datasets/ToMBench/train/data-00000-of-00001.arrow b/datasets/ToMBench/train/data-00000-of-00001.arrow new file mode 100644 index 0000000..5a7f77f --- /dev/null +++ b/datasets/ToMBench/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4366007deba55d37db7535a6e4ebc9900c78f02e276dd1da74f93bef19eb4e0 +size 865336 diff --git a/datasets/ToMBench/train/dataset_info.json b/datasets/ToMBench/train/dataset_info.json new file mode 100644 index 0000000..fd06259 --- /dev/null +++ b/datasets/ToMBench/train/dataset_info.json @@ -0,0 +1,51 @@ +{ + "citation": "", + "description": "", + "features": { + "State": { + "Environment State": {}, + "Human State": {} + }, + "Action": {}, + "Story": { + "dtype": "string", + "_type": "Value" + }, + "Question": { + "dtype": "string", + "_type": "Value" + }, + "Answer": { + "Correct Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "Wrong Answer": { + "feature": { + "dtype": "null", + "_type": "Value" + }, + "_type": "List" + } + }, + "Meta": { + "ability": { + "dtype": "string", + "_type": "Value" + }, + "id": { + "dtype": "string", + "_type": "Value" + }, + "qa_index": { + "dtype": "int64", + "_type": "Value" + } + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/datasets/ToMBench/train/state.json b/datasets/ToMBench/train/state.json new file mode 100644 index 0000000..4d26bfc --- /dev/null +++ b/datasets/ToMBench/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "821a86846bc2d5b4", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/datasets/ToMQA/dataset_dict.json b/datasets/ToMQA/dataset_dict.json new file mode 100644 index 0000000..9195703 --- /dev/null +++ b/datasets/ToMQA/dataset_dict.json @@ -0,0 +1 @@ +{"splits": ["train", "validation", "test"]} \ No newline at end of file diff --git a/datasets/ToMQA/test/data-00000-of-00001.arrow b/datasets/ToMQA/test/data-00000-of-00001.arrow new file mode 100644 index 0000000..d4b6836 --- /dev/null +++ b/datasets/ToMQA/test/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dfb82fcf2cc2ac6550d52a5f532625c9606cd9225413a6bb862c49a8af58da3e +size 11402032 diff --git a/datasets/ToMQA/test/dataset_info.json b/datasets/ToMQA/test/dataset_info.json new file mode 100644 index 0000000..8d31357 --- /dev/null +++ b/datasets/ToMQA/test/dataset_info.json @@ -0,0 +1,163 @@ +{ + "citation": "", + "description": "", + "features": { + "State": { + "Human_State": { + "beliefs": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "preferences": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "emotions": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "other_human_states": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Environment_State": { + "locations": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "objects": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "changes": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "other_env_states": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + } + }, + "Action": { + "actions": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "observers": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "timestamps": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Story": { + "background": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "full_story": { + "dtype": "string", + "_type": "Value" + }, + "summary": { + "dtype": "string", + "_type": "Value" + } + }, + "Question": { + "dtype": "string", + "_type": "Value" + }, + "Answer": { + "Correct_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "Wrong_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Meta": { + "id": { + "dtype": "string", + "_type": "Value" + }, + "dataset_source": { + "dtype": "string", + "_type": "Value" + }, + "dimension": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "order": { + "dtype": "int32", + "_type": "Value" + }, + "task_type": { + "dtype": "string", + "_type": "Value" + }, + "difficulty": { + "dtype": "string", + "_type": "Value" + }, + "ethics_category": { + "dtype": "string", + "_type": "Value" + } + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/datasets/ToMQA/test/state.json b/datasets/ToMQA/test/state.json new file mode 100644 index 0000000..ee6a312 --- /dev/null +++ b/datasets/ToMQA/test/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "2faf1d300bc9b008", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/datasets/ToMQA/train/data-00000-of-00001.arrow b/datasets/ToMQA/train/data-00000-of-00001.arrow new file mode 100644 index 0000000..bb1e16a --- /dev/null +++ b/datasets/ToMQA/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b78bda5af5376ea07f4f2a9a092e74db7635096bc31299c5a928c5195379bb3d +size 9292144 diff --git a/datasets/ToMQA/train/dataset_info.json b/datasets/ToMQA/train/dataset_info.json new file mode 100644 index 0000000..8d31357 --- /dev/null +++ b/datasets/ToMQA/train/dataset_info.json @@ -0,0 +1,163 @@ +{ + "citation": "", + "description": "", + "features": { + "State": { + "Human_State": { + "beliefs": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "preferences": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "emotions": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "other_human_states": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Environment_State": { + "locations": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "objects": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "changes": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "other_env_states": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + } + }, + "Action": { + "actions": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "observers": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "timestamps": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Story": { + "background": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "full_story": { + "dtype": "string", + "_type": "Value" + }, + "summary": { + "dtype": "string", + "_type": "Value" + } + }, + "Question": { + "dtype": "string", + "_type": "Value" + }, + "Answer": { + "Correct_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "Wrong_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Meta": { + "id": { + "dtype": "string", + "_type": "Value" + }, + "dataset_source": { + "dtype": "string", + "_type": "Value" + }, + "dimension": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "order": { + "dtype": "int32", + "_type": "Value" + }, + "task_type": { + "dtype": "string", + "_type": "Value" + }, + "difficulty": { + "dtype": "string", + "_type": "Value" + }, + "ethics_category": { + "dtype": "string", + "_type": "Value" + } + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/datasets/ToMQA/train/state.json b/datasets/ToMQA/train/state.json new file mode 100644 index 0000000..eaa60de --- /dev/null +++ b/datasets/ToMQA/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "1d73a39b0a90424f", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/datasets/ToMQA/validation/data-00000-of-00001.arrow b/datasets/ToMQA/validation/data-00000-of-00001.arrow new file mode 100644 index 0000000..88bf5fe --- /dev/null +++ b/datasets/ToMQA/validation/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51929f3d4425857a5bdc4d454b790d0772bf104eef37cc0126a9e3e982a9ce6c +size 11458640 diff --git a/datasets/ToMQA/validation/dataset_info.json b/datasets/ToMQA/validation/dataset_info.json new file mode 100644 index 0000000..8d31357 --- /dev/null +++ b/datasets/ToMQA/validation/dataset_info.json @@ -0,0 +1,163 @@ +{ + "citation": "", + "description": "", + "features": { + "State": { + "Human_State": { + "beliefs": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "preferences": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "emotions": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "other_human_states": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Environment_State": { + "locations": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "objects": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "changes": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "other_env_states": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + } + }, + "Action": { + "actions": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "observers": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "timestamps": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Story": { + "background": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "full_story": { + "dtype": "string", + "_type": "Value" + }, + "summary": { + "dtype": "string", + "_type": "Value" + } + }, + "Question": { + "dtype": "string", + "_type": "Value" + }, + "Answer": { + "Correct_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "Wrong_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Meta": { + "id": { + "dtype": "string", + "_type": "Value" + }, + "dataset_source": { + "dtype": "string", + "_type": "Value" + }, + "dimension": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "order": { + "dtype": "int32", + "_type": "Value" + }, + "task_type": { + "dtype": "string", + "_type": "Value" + }, + "difficulty": { + "dtype": "string", + "_type": "Value" + }, + "ethics_category": { + "dtype": "string", + "_type": "Value" + } + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/datasets/ToMQA/validation/state.json b/datasets/ToMQA/validation/state.json new file mode 100644 index 0000000..8d2cb16 --- /dev/null +++ b/datasets/ToMQA/validation/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "c7fde6c0c2636a5b", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/datasets/ToMi/dataset_dict.json b/datasets/ToMi/dataset_dict.json new file mode 100644 index 0000000..d5bb6b7 --- /dev/null +++ b/datasets/ToMi/dataset_dict.json @@ -0,0 +1 @@ +{"splits": ["train", "val", "test"]} \ No newline at end of file diff --git a/datasets/ToMi/test/data-00000-of-00001.arrow b/datasets/ToMi/test/data-00000-of-00001.arrow new file mode 100644 index 0000000..4a498cf --- /dev/null +++ b/datasets/ToMi/test/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb703db399f8675e176ec673c0967782ad05f78bd1d00429b15b4eb41948faaa +size 3230896 diff --git a/datasets/ToMi/test/dataset_info.json b/datasets/ToMi/test/dataset_info.json new file mode 100644 index 0000000..1156695 --- /dev/null +++ b/datasets/ToMi/test/dataset_info.json @@ -0,0 +1,105 @@ +{ + "citation": "", + "description": "", + "features": { + "Story": { + "background": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "full_story": { + "dtype": "string", + "_type": "Value" + }, + "summary": { + "dtype": "string", + "_type": "Value" + } + }, + "Question": { + "dtype": "string", + "_type": "Value" + }, + "Action": { + "actions": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "observers": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "timestamps": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Answer": { + "Correct_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "Wrong_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Meta": { + "id": { + "dtype": "string", + "_type": "Value" + }, + "dataset_source": { + "dtype": "string", + "_type": "Value" + }, + "dimension": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "order": { + "dtype": "int32", + "_type": "Value" + }, + "task_type": { + "dtype": "string", + "_type": "Value" + }, + "difficulty": { + "dtype": "string", + "_type": "Value" + }, + "ethics_category": { + "dtype": "string", + "_type": "Value" + } + }, + "State": { + "Human_State": {}, + "Environment_State": {} + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/datasets/ToMi/test/state.json b/datasets/ToMi/test/state.json new file mode 100644 index 0000000..f7e6efa --- /dev/null +++ b/datasets/ToMi/test/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "68fe9870db088830", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/datasets/ToMi/train/data-00000-of-00001.arrow b/datasets/ToMi/train/data-00000-of-00001.arrow new file mode 100644 index 0000000..364cfc1 --- /dev/null +++ b/datasets/ToMi/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0de13c7ecf64726e223ec663348366228cad1c6245c56238ac9ae86c0116a24f +size 3244640 diff --git a/datasets/ToMi/train/dataset_info.json b/datasets/ToMi/train/dataset_info.json new file mode 100644 index 0000000..1156695 --- /dev/null +++ b/datasets/ToMi/train/dataset_info.json @@ -0,0 +1,105 @@ +{ + "citation": "", + "description": "", + "features": { + "Story": { + "background": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "full_story": { + "dtype": "string", + "_type": "Value" + }, + "summary": { + "dtype": "string", + "_type": "Value" + } + }, + "Question": { + "dtype": "string", + "_type": "Value" + }, + "Action": { + "actions": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "observers": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "timestamps": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Answer": { + "Correct_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "Wrong_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Meta": { + "id": { + "dtype": "string", + "_type": "Value" + }, + "dataset_source": { + "dtype": "string", + "_type": "Value" + }, + "dimension": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "order": { + "dtype": "int32", + "_type": "Value" + }, + "task_type": { + "dtype": "string", + "_type": "Value" + }, + "difficulty": { + "dtype": "string", + "_type": "Value" + }, + "ethics_category": { + "dtype": "string", + "_type": "Value" + } + }, + "State": { + "Human_State": {}, + "Environment_State": {} + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/datasets/ToMi/train/state.json b/datasets/ToMi/train/state.json new file mode 100644 index 0000000..f276f73 --- /dev/null +++ b/datasets/ToMi/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "b25c0479756b5dce", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/datasets/ToMi/val/data-00000-of-00001.arrow b/datasets/ToMi/val/data-00000-of-00001.arrow new file mode 100644 index 0000000..3a30eb6 --- /dev/null +++ b/datasets/ToMi/val/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0752762704ed5d8a7732afce717b212b0bc5c62ee71ed749169fb1410597600d +size 3226712 diff --git a/datasets/ToMi/val/dataset_info.json b/datasets/ToMi/val/dataset_info.json new file mode 100644 index 0000000..1156695 --- /dev/null +++ b/datasets/ToMi/val/dataset_info.json @@ -0,0 +1,105 @@ +{ + "citation": "", + "description": "", + "features": { + "Story": { + "background": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "full_story": { + "dtype": "string", + "_type": "Value" + }, + "summary": { + "dtype": "string", + "_type": "Value" + } + }, + "Question": { + "dtype": "string", + "_type": "Value" + }, + "Action": { + "actions": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "observers": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "timestamps": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Answer": { + "Correct_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "Wrong_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "Meta": { + "id": { + "dtype": "string", + "_type": "Value" + }, + "dataset_source": { + "dtype": "string", + "_type": "Value" + }, + "dimension": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "Sequence" + }, + "order": { + "dtype": "int32", + "_type": "Value" + }, + "task_type": { + "dtype": "string", + "_type": "Value" + }, + "difficulty": { + "dtype": "string", + "_type": "Value" + }, + "ethics_category": { + "dtype": "string", + "_type": "Value" + } + }, + "State": { + "Human_State": {}, + "Environment_State": {} + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/datasets/ToMi/val/state.json b/datasets/ToMi/val/state.json new file mode 100644 index 0000000..6b22cfd --- /dev/null +++ b/datasets/ToMi/val/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "d73828672aabec43", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/datasets/Tomato/test/data-00000-of-00001.arrow b/datasets/Tomato/test/data-00000-of-00001.arrow new file mode 100644 index 0000000..75b4e98 --- /dev/null +++ b/datasets/Tomato/test/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d415c995dc836457c1569c01b7d6c9b62bbc14bf02671d5892d0884d0286ec6b +size 21889472 diff --git a/datasets/Tomato/test/dataset_info.json b/datasets/Tomato/test/dataset_info.json new file mode 100644 index 0000000..97ac35f --- /dev/null +++ b/datasets/Tomato/test/dataset_info.json @@ -0,0 +1,163 @@ +{ + "citation": "", + "description": "", + "features": { + "State": { + "Human_State": { + "beliefs": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "preferences": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "emotions": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "other_human_states": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + } + }, + "Environment_State": { + "locations": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "objects": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "changes": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "other_env_states": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + } + } + }, + "Action": { + "actions": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "observers": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "timestamps": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + } + }, + "Story": { + "background": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "full_story": { + "dtype": "string", + "_type": "Value" + }, + "summary": { + "dtype": "string", + "_type": "Value" + } + }, + "Question": { + "dtype": "string", + "_type": "Value" + }, + "Answer": { + "Correct_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "Wrong_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + } + }, + "Meta": { + "id": { + "dtype": "string", + "_type": "Value" + }, + "dataset_source": { + "dtype": "string", + "_type": "Value" + }, + "dimension": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "order": { + "dtype": "int32", + "_type": "Value" + }, + "task_type": { + "dtype": "string", + "_type": "Value" + }, + "difficulty": { + "dtype": "string", + "_type": "Value" + }, + "ethics_category": { + "dtype": "string", + "_type": "Value" + } + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/datasets/Tomato/test/state.json b/datasets/Tomato/test/state.json new file mode 100644 index 0000000..aee0c63 --- /dev/null +++ b/datasets/Tomato/test/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "3b840689f52a33c1", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/datasets/Tomato/train/data-00000-of-00001.arrow b/datasets/Tomato/train/data-00000-of-00001.arrow new file mode 100644 index 0000000..7fb30cd --- /dev/null +++ b/datasets/Tomato/train/data-00000-of-00001.arrow @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:43f0ba7d03e37d1d60c503c6e02ac358909dd272b7743db01dec29d6f230b5c1 +size 5239016 diff --git a/datasets/Tomato/train/dataset_info.json b/datasets/Tomato/train/dataset_info.json new file mode 100644 index 0000000..97ac35f --- /dev/null +++ b/datasets/Tomato/train/dataset_info.json @@ -0,0 +1,163 @@ +{ + "citation": "", + "description": "", + "features": { + "State": { + "Human_State": { + "beliefs": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "preferences": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "emotions": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "other_human_states": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + } + }, + "Environment_State": { + "locations": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "objects": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "changes": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "other_env_states": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + } + } + }, + "Action": { + "actions": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "observers": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "timestamps": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + } + }, + "Story": { + "background": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "full_story": { + "dtype": "string", + "_type": "Value" + }, + "summary": { + "dtype": "string", + "_type": "Value" + } + }, + "Question": { + "dtype": "string", + "_type": "Value" + }, + "Answer": { + "Correct_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "Wrong_Answer": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + } + }, + "Meta": { + "id": { + "dtype": "string", + "_type": "Value" + }, + "dataset_source": { + "dtype": "string", + "_type": "Value" + }, + "dimension": { + "feature": { + "dtype": "string", + "_type": "Value" + }, + "_type": "List" + }, + "order": { + "dtype": "int32", + "_type": "Value" + }, + "task_type": { + "dtype": "string", + "_type": "Value" + }, + "difficulty": { + "dtype": "string", + "_type": "Value" + }, + "ethics_category": { + "dtype": "string", + "_type": "Value" + } + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/datasets/Tomato/train/state.json b/datasets/Tomato/train/state.json new file mode 100644 index 0000000..2e62268 --- /dev/null +++ b/datasets/Tomato/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "b11b7c9884c49cc7", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/experiment_config.yaml b/experiment_config.yaml index 92c1d02..20dfbb3 100644 --- a/experiment_config.yaml +++ b/experiment_config.yaml @@ -1,20 +1,24 @@ llm: model_name: Qwen3-8B api_key: not-needed - api_url: http://0.0.0.0:8000/v1 + api_url: http://127.0.0.1:8006/v1 temperature: 0.6 - max_tokens: 32768 - max_workers: 64 + max_tokens: 8192 + max_workers: 4 enable_thinking: true judge: - model_name: gemma-3-4b-it + model_name: Qwen3-8B api_key: not-needed api_url: http://127.0.0.1:8006/v1 temperature: 0.0 max_tokens: 4096 + enable_llm_judge: true repeats: 1 -max_samples: 2 +max_samples: 10 datasets_path: datasets -results_path: results +results_path: results/Qwen3-8B-test + +badcase: + enabled: true diff --git a/run_all.py b/run_all.py index 66d67ae..6b15e74 100644 --- a/run_all.py +++ b/run_all.py @@ -1,32 +1,77 @@ -"""统一评测入口 +"""统一评测入口 & 可复用评测工具 -运行所有数据集的评测脚本。 -所有配置通过 experiment_config.yaml 管理,无需命令行参数。 +提供三层使用方式: +1. python run_all.py → 运行当前 experiment_config.yaml 下的所有数据集 +2. scripts/run_single_model.py → 指定 model yaml,运行所有数据集 +3. scripts/run_single_dataset.py → 指定 dataset yaml,在所有 model 上运行 """ +import os +import shutil import subprocess import sys from pathlib import Path +from typing import Dict, List, Optional +import yaml -# 数据集列表 DATASETS = [ "ToMBench", "Tomato", "ToMQA", + "ToMi", ] +EXPERIMENT_CONFIG = Path("experiment_config.yaml") +MODEL_CONFIGS_DIR = Path("experiment_configs") -def run_dataset(dataset: str) -> bool: - """运行指定数据集的评测 - Args: - dataset: 数据集名称 +# --------------------------------------------------------------------------- +# 可复用工具函数 +# --------------------------------------------------------------------------- + + +def apply_config(yaml_path: str) -> None: + """将指定 yaml 复制为 experiment_config.yaml 供各 task run.py 读取。""" + src = Path(yaml_path) + if not src.exists(): + raise FileNotFoundError(f"配置文件不存在: {src}") + shutil.copy2(src, EXPERIMENT_CONFIG) + print(f"[config] {src} → {EXPERIMENT_CONFIG}") + + +def discover_model_configs(configs_dir: str = str(MODEL_CONFIGS_DIR)) -> List[Path]: + """发现 experiment_configs/ 下所有 model yaml(按文件名排序)。""" + d = Path(configs_dir) + if not d.is_dir(): + raise FileNotFoundError(f"模型配置目录不存在: {d}") + yamls = sorted(p for p in d.iterdir() if p.suffix in (".yaml", ".yml") and p.is_file()) + if not yamls: + raise RuntimeError(f"{d} 中没有找到任何 yaml 配置文件") + return yamls + + +def get_model_name(yaml_path: Path) -> str: + """从 experiment config yaml 中提取 model_name。""" + with open(yaml_path, encoding="utf-8") as f: + cfg = yaml.safe_load(f) + return cfg.get("llm", {}).get("model_name", yaml_path.stem) + - Returns: - 是否成功 - """ +def get_dataset_name(yaml_path: str) -> str: + """从 dataset config yaml 中提取 dataset 名称。""" + with open(yaml_path, encoding="utf-8") as f: + cfg = yaml.safe_load(f) + name = cfg.get("dataset") + if not name: + raise ValueError(f"yaml 中缺少 'dataset' 字段: {yaml_path}") + return name + + +def run_dataset(dataset: str) -> bool: + """运行指定数据集的评测脚本。""" + project_root = Path(__file__).resolve().parent run_script = Path(f"tasks/{dataset}/run.py") - if not run_script.exists(): + if not (project_root / run_script).exists(): print(f"[{dataset}] run.py not found, skipping.") return False @@ -34,11 +79,16 @@ def run_dataset(dataset: str) -> bool: print(f"Running: {dataset}") print(f"{'='*60}") + env = os.environ.copy() + env["PYTHONPATH"] = str(project_root) + os.pathsep + env.get("PYTHONPATH", "") + try: - result = subprocess.run( + subprocess.run( [sys.executable, str(run_script)], check=True, capture_output=False, + cwd=str(project_root), + env=env, ) return True except subprocess.CalledProcessError as e: @@ -52,14 +102,31 @@ def run_dataset(dataset: str) -> bool: return False +def run_datasets(datasets: Optional[List[str]] = None) -> Dict[str, bool]: + """依次运行指定数据集列表(默认全部),返回 {dataset: success}。""" + if datasets is None: + datasets = DATASETS + results = {} + for ds in datasets: + results[ds] = run_dataset(ds) + return results + + +# --------------------------------------------------------------------------- +# 默认入口:运行所有数据集 +# --------------------------------------------------------------------------- + + def main(): - for dataset in DATASETS: - run_dataset(dataset) + results = run_datasets() print(f"\n{'='*60}") print("All datasets completed.") + for ds, ok in results.items(): + status = "OK" if ok else "FAILED" + print(f" {ds}: {status}") print(f"{'='*60}") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/judge.py b/src/judge.py new file mode 100644 index 0000000..ad8c9bd --- /dev/null +++ b/src/judge.py @@ -0,0 +1,118 @@ +"""LLM 语义判断模块 + +当结构化输出提取失败(max_retry 耗尽)时,使用 LLM 判断模型的原始回答 +与标准答案在语义上是否一致。 +""" + +import logging +import re +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel +from tqdm import tqdm + +from src.llm import LLMClient + +logger = logging.getLogger(__name__) + +JUDGE_PROMPT_TEMPLATE = """\ +You are an impartial judge. Given a model's response and a gold (correct) answer, \ +determine whether the model's response contains an answer that is semantically \ +equivalent to the gold answer. + +Focus ONLY on whether the final answer matches in meaning. Ignore formatting, \ +extra explanation, or reasoning traces. + +## Model Response +{raw_response} + +## Gold Answer +{gold_answer} + +## Question (for context) +{question} + +Does the model's response contain an answer semantically equivalent to the gold answer?\ +""" + + +class JudgeVerdict(BaseModel): + """LLM judge 输出 schema — 布尔值约束""" + is_correct: bool + + +def _strip_think_tags(text: str) -> str: + """Remove ... blocks so the judge sees only the final answer.""" + return re.sub(r"[\s\S]*?", "", text).strip() + + +def judge_single( + client: LLMClient, + raw_response: str, + gold_answer: str, + question: str = "", +) -> bool: + """判断单条原始回答是否语义等价于标准答案。 + + Args: + client: 用于 judge 的 LLMClient(通常低温度) + raw_response: 被评测模型的完整原始输出 + gold_answer: 标准答案 + question: 原始问题(提供上下文,可为空) + + Returns: + True 表示语义一致 + """ + cleaned = _strip_think_tags(raw_response) + prompt = JUDGE_PROMPT_TEMPLATE.format( + raw_response=cleaned, + gold_answer=gold_answer, + question=question, + ) + result = client.generate_structure(prompt, JudgeVerdict, max_retry=3) + return getattr(result, "is_correct", False) + + +def batch_judge( + client: LLMClient, + items: List[Dict[str, Any]], +) -> List[bool]: + """批量语义判断。 + + Args: + client: judge LLMClient + items: 每个元素是 {"raw_response": str, "gold_answer": str, "question": str} + + Returns: + 与 items 等长的布尔列表 + """ + if not items: + return [] + + with ThreadPoolExecutor(client.max_workers) as executor: + futures = [ + executor.submit( + judge_single, + client, + item["raw_response"], + item["gold_answer"], + item.get("question", ""), + ) + for item in items + ] + + results = [] + for future in tqdm( + futures, + total=len(futures), + desc="LLM Judge", + miniters=100, + ): + try: + results.append(future.result()) + except Exception: + logger.warning("[Judge] single judge call failed, treating as incorrect") + results.append(False) + + return results diff --git a/src/llm/__init__.py b/src/llm/__init__.py index b1fdb82..0693070 100644 --- a/src/llm/__init__.py +++ b/src/llm/__init__.py @@ -6,9 +6,10 @@ - 结构化输出: generate_structure(), batch_generate_structure() """ -from .client import LLMClient, LLMUsage +from .client import LLMClient, LLMUsage, StructuredResult __all__ = [ "LLMClient", "LLMUsage", + "StructuredResult", ] diff --git a/src/llm/client.py b/src/llm/client.py index c8cdbfc..fe2cbf3 100644 --- a/src/llm/client.py +++ b/src/llm/client.py @@ -39,6 +39,27 @@ class LLMUsage: latency: float = 0.0 +@dataclass +class StructuredResult: + """Wraps a structured generation result with raw output metadata. + + Provides backward-compatible access to the parsed Pydantic fields + (e.g. ``r.answer``) while also exposing the raw LLM response for + bad-case analysis and LLM-judge fallback. + """ + parsed: BaseModel + raw_response: str = "" + reasoning_content: str = "" + extraction_success: bool = True + + @property + def answer(self): + return getattr(self.parsed, "answer", "") + + def __getattr__(self, name: str): + return getattr(self.parsed, name) + + # --------------------------------------------------------------------------- # LLM Client # --------------------------------------------------------------------------- @@ -271,8 +292,8 @@ def generate_structure( prompt: str, response_object: Type[BaseModel], max_retry: int = 5, - ) -> BaseModel: - """调用 LLM,返回 Pydantic 对象(自动适配不同模型)。 + ) -> "StructuredResult": + """调用 LLM,返回 StructuredResult(自动适配不同模型)。 两阶段降级策略: 1. 首选:chat.completions.parse() - 直接返回 Pydantic 对象,最佳体验 @@ -284,13 +305,12 @@ def generate_structure( max_retry: 最大重试次数 Returns: - response_object 的实例,失败时返回空实例 + StructuredResult, extraction_success=False when all retries exhausted """ # 首次检测:尝试使用 parse API if self._parse_supported is None: with self._parse_lock: if self._parse_supported is None: - # 尝试一次,成功则标记支持,失败则不支持 try: result = self._generate_with_parse(prompt, response_object, max_retry=1) self._parse_supported = True @@ -298,9 +318,7 @@ def generate_structure( except Exception: self._parse_supported = False logging.warning(f"[LLM] Model {self.model} parse API failed, switching to JSON object mode") - # 继续使用降级模式 - # 根据检测结果选择模式 if self._parse_supported: return self._generate_with_parse(prompt, response_object, max_retry) else: @@ -311,12 +329,15 @@ def _generate_with_parse( prompt: str, response_object: Type[BaseModel], max_retry: int = 5, - ) -> BaseModel: + ) -> "StructuredResult": """使用 parse API 的原生结构化输出。""" extra_body: Dict[str, Any] = {"top_k": self.top_k} if not self.enable_thinking: extra_body["chat_template_kwargs"] = {"enable_thinking": False} + last_raw = "" + last_reasoning = "" + for attempt in range(max_retry): try: start = time.time() @@ -336,30 +357,45 @@ def _generate_with_parse( usage.completion_tokens = response.usage.completion_tokens usage.total_tokens = response.usage.total_tokens - result = response.choices[0].message.parsed + msg = response.choices[0].message + last_raw = msg.content or "" + last_reasoning = ( + getattr(msg, "reasoning_content", "") + or getattr(msg, "reasoning", "") + or "" + ) + + result = msg.parsed self._track_usage(usage, success=True) - return result + return StructuredResult( + parsed=result, + raw_response=last_raw, + reasoning_content=last_reasoning, + extraction_success=True, + ) except Exception as e: - import traceback logging.warning(f"[LLM] parse mode attempt {attempt + 1}") logging.error(f"[LLM] parse mode all {max_retry} attempts exhausted") self._track_usage(LLMUsage(), success=False) - return response_object.model_construct() + return StructuredResult( + parsed=response_object.model_construct(), + raw_response=last_raw, + reasoning_content=last_reasoning, + extraction_success=False, + ) def _generate_with_json_object( self, prompt: str, response_object: Type[BaseModel], max_retry: int = 5, - ) -> BaseModel: + ) -> "StructuredResult": """降级模式:使用 json_object response_format + prompt 引导 + 解析验证。""" import json - # 构建 schema 描述 schema_desc = self._format_schema_for_prompt(response_object) - # 增强提示词 enhanced_prompt = f"""{prompt} --- @@ -373,6 +409,9 @@ def _generate_with_json_object( if not self.enable_thinking: extra_body["chat_template_kwargs"] = {"enable_thinking": False} + last_raw = "" + last_reasoning = "" + for attempt in range(max_retry): try: start = time.time() @@ -393,24 +432,40 @@ def _generate_with_json_object( usage.completion_tokens = response.usage.completion_tokens usage.total_tokens = response.usage.total_tokens - content = response.choices[0].message.content or "" - # 提取 JSON + msg = response.choices[0].message + content = msg.content or "" + reasoning = ( + getattr(msg, "reasoning_content", "") + or getattr(msg, "reasoning", "") + or "" + ) + last_raw = content + last_reasoning = reasoning + json_data = self._extract_json(content) if json_data is None: raise ValueError(f"Failed to extract valid JSON: {content[:200]}") - # 用 Pydantic 验证(不符合就重试) result = response_object.model_validate(json_data) self._track_usage(usage, success=True) - return result + return StructuredResult( + parsed=result, + raw_response=content, + reasoning_content=reasoning, + extraction_success=True, + ) except Exception as e: - import traceback logging.warning(f"[LLM] json_object mode attempt {attempt + 1}") logging.error(f"[LLM] json_object mode all {max_retry} attempts exhausted") self._track_usage(LLMUsage(), success=False) - return response_object.model_construct() + return StructuredResult( + parsed=response_object.model_construct(), + raw_response=last_raw, + reasoning_content=last_reasoning, + extraction_success=False, + ) def _extract_json(self, text: str) -> Optional[Dict[str, Any]]: """从文本中提取 JSON。""" @@ -471,9 +526,9 @@ def batch_generate_structure( self, prompts: List[str], response_object: Type[BaseModel], - ) -> List[BaseModel]: + ) -> List["StructuredResult"]: """ - 批量调用 LLM,返回 Pydantic 对象列表。 + 批量调用 LLM,返回 StructuredResult 列表。 支持并行调用。 @@ -482,7 +537,7 @@ def batch_generate_structure( response_object: 必须提供的 Pydantic 模型类 Returns: - response_object 实例的列表 + StructuredResult 实例的列表 """ with ThreadPoolExecutor(self.max_workers) as executor: futures = [ diff --git a/src/runner.py b/src/runner.py index c392dc2..cf743bf 100644 --- a/src/runner.py +++ b/src/runner.py @@ -6,12 +6,12 @@ import os import sys from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import yaml from src.dataloader import load_dataset -from src.llm import LLMClient +from src.llm import LLMClient, StructuredResult import logging # 将日志级别设置为 WARNING 或更高 logging.getLogger("urllib3").setLevel(logging.WARNING) @@ -65,13 +65,17 @@ def load_experiment_config(config_path: str) -> Dict[str, Any]: """ with open(config_path, encoding="utf-8") as f: config = yaml.safe_load(f) + judge_block = config.get("judge", {}) + badcase_block = config.get("badcase", {}) return { "llm_config": config.get("llm", {}), "repeats": config.get("repeats", 1), "max_samples": config.get("max_samples", 0), "datasets_path": config.get("datasets_path", "datasets"), "results_path": config.get("results_path", "results"), - "judge_config": config.get("judge", {}), # 覆盖数据集的 judge 配置 + "judge_config": judge_block, + "enable_llm_judge": judge_block.get("enable_llm_judge", False), + "badcase_enabled": badcase_block.get("enabled", False), } @@ -141,6 +145,8 @@ def save_common_results( metadata: Optional[Dict[str, Any]] = None, dataset_config: Optional[Dict[str, Any]] = None, experiment_config: Optional[Dict[str, Any]] = None, + badcases: Optional[List[Dict[str, Any]]] = None, + all_metrics_with_judge: Optional[List[Dict[str, Any]]] = None, ) -> Tuple[Path, Path, Path]: """保存评测结果 @@ -157,6 +163,8 @@ def save_common_results( metadata: 额外元数据(如 judge_model) dataset_config: 数据集配置字典(保存到 config.json) experiment_config: 实验配置字典(保存到 config.json,会过滤 api_key 和 api_url) + badcases: bad case 记录列表(保存到 badcases.jsonl) + all_metrics_with_judge: LLM judge 兜底后的 metrics 列表(保存到 metrics.json) Returns: (config_path, metrics_path, prediction_path) 元组 @@ -176,17 +184,14 @@ def save_common_results( "repeats": len(all_metrics), } - # 添加 dataset_config 内容(排除 schemas_module 等非 JSON 可序列化对象) if dataset_config: dataset_config_copy = dict(dataset_config) dataset_config_copy.pop("schemas_module", None) - dataset_config_copy.pop("schema", None) # schema 是类对象,不可序列化 + dataset_config_copy.pop("schema", None) config_data["dataset_config"] = dataset_config_copy - # 添加 experiment_config 内容(排除敏感信息) if experiment_config: experiment_config_copy = dict(experiment_config) - # 过滤敏感信息 if "llm_config" in experiment_config_copy: llm_config_copy = dict(experiment_config_copy["llm_config"]) llm_config_copy.pop("api_key", None) @@ -208,13 +213,18 @@ def save_common_results( encoding="utf-8", ) - # 2. 保存 metrics.json + # 2. 保存 metrics.json(含 strict 和可选的 judge 两套指标) avg_metrics = _compute_average_metrics(all_metrics) - metrics_data = { + metrics_data: Dict[str, Any] = { "avg_metrics": avg_metrics, "all_metrics": all_metrics, } + if all_metrics_with_judge: + avg_judge = _compute_average_metrics(all_metrics_with_judge) + metrics_data["avg_metrics_with_judge"] = avg_judge + metrics_data["all_metrics_with_judge"] = all_metrics_with_judge + metrics_path = output_dir / "metrics.json" metrics_path.write_text( json.dumps(metrics_data, ensure_ascii=False, indent=2), @@ -235,6 +245,14 @@ def save_common_results( } f.write(json.dumps(record, ensure_ascii=False) + "\n") + # 4. 保存 badcases.jsonl(可选) + if badcases: + badcases_path = output_dir / "badcases.jsonl" + with open(badcases_path, "w", encoding="utf-8") as f: + for bc in badcases: + f.write(json.dumps(bc, ensure_ascii=False) + "\n") + print(f" - badcases.jsonl ({len(badcases)} bad cases)") + print(f"Results saved to: {output_dir}") print(f" - config.json") print(f" - metrics.json") @@ -290,3 +308,95 @@ def load_and_limit_data( random.seed(seed) data = random.sample(data, min(max_samples, len(data))) return data + + +# --------------------------------------------------------------------------- +# Bad-case collection & LLM-judge helpers +# --------------------------------------------------------------------------- + + +def collect_badcases( + results: List[StructuredResult], + predictions: List[str], + gold_answers: List[str], + prompts: List[str], + dataset_name: str, + is_correct_fn: Callable[[str, str], bool], + repeat_idx: int = 0, + judge_verdicts: Optional[List[Optional[bool]]] = None, +) -> List[Dict[str, Any]]: + """收集 bad case 记录。 + + Args: + results: StructuredResult 列表 + predictions: 提取到的预测答案列表 + gold_answers: 标准答案列表 + prompts: 输入 prompt 列表 + dataset_name: 数据集名称 + is_correct_fn: (prediction, gold) -> bool 的判定函数 + repeat_idx: 当前 repeat 索引 + judge_verdicts: LLM judge 判定结果(与 results 等长,未启用时为 None) + + Returns: + bad case 字典列表 + """ + badcases: List[Dict[str, Any]] = [] + + for i, (r, pred, gold, prompt) in enumerate( + zip(results, predictions, gold_answers, prompts) + ): + if not r.extraction_success: + error_type = "extraction_failed" + elif not is_correct_fn(pred, gold): + error_type = "wrong_answer" + else: + continue + + jv = judge_verdicts[i] if judge_verdicts is not None else None + badcases.append({ + "repeat": repeat_idx, + "sample_idx": i, + "dataset": dataset_name, + "error_type": error_type, + "prompt": prompt, + "raw_response": r.raw_response, + "reasoning_content": r.reasoning_content, + "prediction": pred, + "gold_answer": gold, + "judge_result": jv, + }) + + return badcases + + +def build_corrected_predictions( + predictions: List[str], + results: List[StructuredResult], + judge_verdicts: List[bool], + gold_answers: List[str], +) -> List[str]: + """构建 LLM judge 兜底后的预测列表。 + + 对于 extraction 成功的样本,保持原始预测不变; + 对于 extraction 失败但 judge 判定语义正确的样本,替换为 gold_answer; + 其余保持原值(空字符串)。 + + Args: + predictions: 原始预测列表 + results: StructuredResult 列表 + judge_verdicts: 与 extraction_failed 样本对应的 judge 结果 + gold_answers: 标准答案列表 + + Returns: + 修正后的预测列表(与 predictions 等长) + """ + corrected = list(predictions) + judge_idx = 0 + + for i, r in enumerate(results): + if not r.extraction_success: + if judge_idx < len(judge_verdicts) and judge_verdicts[judge_idx]: + corrected[i] = gold_answers[i] + judge_idx += 1 + + return corrected diff --git a/tables/SUMMARY.md b/tables/SUMMARY.md index 684fb10..f39e318 100644 --- a/tables/SUMMARY.md +++ b/tables/SUMMARY.md @@ -1,7 +1,7 @@ ## 总览表格:Accuracy -| 数据集 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it | -|---|-:|-:|-:|-:|-:| -| ToMBench | 0.6337 | 0.4104 | 0.6242 | 0.6688 | 0.6014 | -| ToMQA | - | 0.3549 | 0.5825 | 0.5611 | - | -| Tomato | 0.6844 | 0.4034 | 0.6453 | 0.6769 | 0.5632 | +| 数据集 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it | Qwen3-8B-SIPColdStart | deepseek-chat | deepseek-r1 | +| --- | -: | -: | -: | -: | -: | -: | -: | -: | +| ToMBench | 0.6337 | 0.4104 | 0.6242 | 0.6688 | 0.6014 | 0.6198 | 0.7698 | 0.8097 | +| ToMQA | - | 0.3549 | 0.5825 | 0.5611 | - | - | - | - | +| Tomato | 0.6844 | 0.4034 | 0.6453 | 0.6769 | 0.5632 | 0.5885 | 0.7938 | 0.7894 | diff --git "a/tables/ToMBench/\345\205\266\344\273\226\346\214\207\346\240\207.md" "b/tables/ToMBench/\345\205\266\344\273\226\346\214\207\346\240\207.md" index 81f426b..46c2a32 100644 --- "a/tables/ToMBench/\345\205\266\344\273\226\346\214\207\346\240\207.md" +++ "b/tables/ToMBench/\345\205\266\344\273\226\346\214\207\346\240\207.md" @@ -1,37 +1,37 @@ # ToMBench - 其他指标 -| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it | -|---|---|---|---|---|---| -| by_ability.Belief: Beliefs based action/emotions | 0.6620 | 0.4390 | 0.6385 | 0.6761 | 0.6150 | -| by_ability.Belief: Content false beliefs | 0.7017 | 0.4417 | 0.5267 | 0.5850 | 0.5817 | -| by_ability.Belief: Content false beliefs Belief: Second-order beliefs | 0.8133 | 0.3333 | 0.7933 | 0.8467 | 0.6667 | -| by_ability.Belief: Identity false beliefs | 0.8000 | 0.4333 | 0.6667 | 0.7833 | 0.7167 | -| by_ability.Belief: Location false beliefs | 0.7533 | 0.3933 | 0.7667 | 0.8000 | 0.8083 | -| by_ability.Belief: Location false beliefs Belief: Second-order beliefs | 0.4967 | 0.1933 | 0.1500 | 0.3167 | 0.3567 | -| by_ability.Belief: Sequence false beliefs | 0.5300 | 0.4433 | 0.5633 | 0.6100 | 0.5133 | -| by_ability.Desire: Desire-action contradiction | 0.6667 | 0.4833 | 0.6917 | 0.7500 | 0.6250 | -| by_ability.Desire: Desires influence on actions | 0.4912 | 0.3596 | 0.5175 | 0.4912 | 0.4781 | -| by_ability.Desire: Desires influence on emotions (beliefs) | 0.4722 | 0.3611 | 0.5278 | 0.5139 | 0.5417 | -| by_ability.Desire: Discrepant desires | 0.4667 | 0.2500 | 0.3333 | 0.4167 | 0.4000 | -| by_ability.Desire: Multiple desires | 0.6500 | 0.3833 | 0.6000 | 0.7167 | 0.6500 | -| by_ability.Emotion: Atypical emotional reactions | 0.4567 | 0.2433 | 0.5033 | 0.5733 | 0.5667 | -| by_ability.Emotion: Discrepant emotions | 0.5667 | 0.3833 | 0.6250 | 0.7000 | 0.5750 | -| by_ability.Emotion: Emotion regulation | 0.4333 | 0.3000 | 0.3667 | 0.4500 | 0.5000 | -| by_ability.Emotion: Hidden emotions | 0.5792 | 0.3250 | 0.6500 | 0.7250 | 0.5333 | -| by_ability.Emotion: Mixed emotions | 0.3417 | 0.5000 | 0.4500 | 0.4083 | 0.3250 | -| by_ability.Emotion: Moral emotions | 0.7333 | 0.4750 | 0.7000 | 0.7250 | 0.7500 | -| by_ability.Emotion: Typical emotional reactions | 0.8400 | 0.6533 | 0.8567 | 0.8333 | 0.8933 | -| by_ability.Intention: Completion of failed actions | 0.3500 | 0.3167 | 0.4333 | 0.4167 | 0.4500 | -| by_ability.Intention: Discrepant intentions | 0.8250 | 0.5250 | 0.7250 | 0.8000 | 0.6750 | -| by_ability.Intention: Intentions explanations | 0.7423 | 0.3795 | 0.6974 | 0.7449 | 0.6256 | -| by_ability.Intention: Prediction of actions | 0.6667 | 0.2667 | 0.5833 | 0.4667 | 0.5000 | -| by_ability.Knowledge: Information-knowledge links | 0.3333 | 0.2350 | 0.3917 | 0.4117 | 0.3083 | -| by_ability.Knowledge: Knowledge-attention links | 0.4000 | 0.2333 | 0.1667 | 0.3667 | 0.3000 | -| by_ability.Knowledge: Knowledge-pretend play links | 0.1111 | 0.2778 | 0.2444 | 0.2444 | 0.1000 | -| by_ability.Knowledge: Percepts-knowledge links | 0.7000 | 0.4083 | 0.6000 | 0.6917 | 0.6333 | -| by_ability.Non-Literal Communication: Faux pas | 0.6310 | 0.5435 | 0.7095 | 0.7393 | 0.6601 | -| by_ability.Non-Literal Communication: Involuntary lies | 0.7698 | 0.4127 | 0.7619 | 0.8889 | 0.6905 | -| by_ability.Non-Literal Communication: Irony/Sarcasm | 0.5897 | 0.2821 | 0.7436 | 0.7436 | 0.6154 | -| by_ability.Non-literal communication: Egocentric lies | 0.9167 | 0.4750 | 0.9083 | 0.9333 | 0.8500 | -| by_ability.Non-literal communication: Humor | 0.9250 | 0.4250 | 0.8167 | 0.9250 | 0.7500 | -| by_ability.Non-literal communication: White lies | 0.9083 | 0.3333 | 0.8167 | 0.8667 | 0.7250 | +| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it | Qwen3-8B-SIPColdStart | +| --- | --- | --- | --- | --- | --- | --- | +| by_ability.Belief: Beliefs based action/emotions | 0.6620 | 0.4390 | 0.6385 | 0.6761 | 0.6150 | 0.6197 | +| by_ability.Belief: Content false beliefs | 0.7017 | 0.4417 | 0.5267 | 0.5850 | 0.5817 | 0.6417 | +| by_ability.Belief: Content false beliefs Belief: Second-order beliefs | 0.8133 | 0.3333 | 0.7933 | 0.8467 | 0.6667 | 0.3667 | +| by_ability.Belief: Identity false beliefs | 0.8000 | 0.4333 | 0.6667 | 0.7833 | 0.7167 | 0.7167 | +| by_ability.Belief: Location false beliefs | 0.7533 | 0.3933 | 0.7667 | 0.8000 | 0.8083 | 0.9333 | +| by_ability.Belief: Location false beliefs Belief: Second-order beliefs | 0.4967 | 0.1933 | 0.1500 | 0.3167 | 0.3567 | 0.3733 | +| by_ability.Belief: Sequence false beliefs | 0.5300 | 0.4433 | 0.5633 | 0.6100 | 0.5133 | 0.5100 | +| by_ability.Desire: Desire-action contradiction | 0.6667 | 0.4833 | 0.6917 | 0.7500 | 0.6250 | 0.6833 | +| by_ability.Desire: Desires influence on actions | 0.4912 | 0.3596 | 0.5175 | 0.4912 | 0.4781 | 0.4254 | +| by_ability.Desire: Desires influence on emotions (beliefs) | 0.4722 | 0.3611 | 0.5278 | 0.5139 | 0.5417 | 0.4583 | +| by_ability.Desire: Discrepant desires | 0.4667 | 0.2500 | 0.3333 | 0.4167 | 0.4000 | 0.3833 | +| by_ability.Desire: Multiple desires | 0.6500 | 0.3833 | 0.6000 | 0.7167 | 0.6500 | 0.6833 | +| by_ability.Emotion: Atypical emotional reactions | 0.4567 | 0.2433 | 0.5033 | 0.5733 | 0.5667 | 0.5000 | +| by_ability.Emotion: Discrepant emotions | 0.5667 | 0.3833 | 0.6250 | 0.7000 | 0.5750 | 0.5167 | +| by_ability.Emotion: Emotion regulation | 0.4333 | 0.3000 | 0.3667 | 0.4500 | 0.5000 | 0.3167 | +| by_ability.Emotion: Hidden emotions | 0.5792 | 0.3250 | 0.6500 | 0.7250 | 0.5333 | 0.6458 | +| by_ability.Emotion: Mixed emotions | 0.3417 | 0.5000 | 0.4500 | 0.4083 | 0.3250 | 0.4583 | +| by_ability.Emotion: Moral emotions | 0.7333 | 0.4750 | 0.7000 | 0.7250 | 0.7500 | 0.6833 | +| by_ability.Emotion: Typical emotional reactions | 0.8400 | 0.6533 | 0.8567 | 0.8333 | 0.8933 | 0.8367 | +| by_ability.Intention: Completion of failed actions | 0.3500 | 0.3167 | 0.4333 | 0.4167 | 0.4500 | 0.3500 | +| by_ability.Intention: Discrepant intentions | 0.8250 | 0.5250 | 0.7250 | 0.8000 | 0.6750 | 0.7167 | +| by_ability.Intention: Intentions explanations | 0.7423 | 0.3795 | 0.6974 | 0.7449 | 0.6256 | 0.6667 | +| by_ability.Intention: Prediction of actions | 0.6667 | 0.2667 | 0.5833 | 0.4667 | 0.5000 | 0.5667 | +| by_ability.Knowledge: Information-knowledge links | 0.3333 | 0.2350 | 0.3917 | 0.4117 | 0.3083 | 0.3550 | +| by_ability.Knowledge: Knowledge-attention links | 0.4000 | 0.2333 | 0.1667 | 0.3667 | 0.3000 | 0.2667 | +| by_ability.Knowledge: Knowledge-pretend play links | 0.1111 | 0.2778 | 0.2444 | 0.2444 | 0.1000 | 0.0556 | +| by_ability.Knowledge: Percepts-knowledge links | 0.7000 | 0.4083 | 0.6000 | 0.6917 | 0.6333 | 0.7667 | +| by_ability.Non-Literal Communication: Faux pas | 0.6310 | 0.5435 | 0.7095 | 0.7393 | 0.6601 | 0.6911 | +| by_ability.Non-Literal Communication: Involuntary lies | 0.7698 | 0.4127 | 0.7619 | 0.8889 | 0.6905 | 0.7143 | +| by_ability.Non-Literal Communication: Irony/Sarcasm | 0.5897 | 0.2821 | 0.7436 | 0.7436 | 0.6154 | 0.8077 | +| by_ability.Non-literal communication: Egocentric lies | 0.9167 | 0.4750 | 0.9083 | 0.9333 | 0.8500 | 0.8417 | +| by_ability.Non-literal communication: Humor | 0.9250 | 0.4250 | 0.8167 | 0.9250 | 0.7500 | 0.8583 | +| by_ability.Non-literal communication: White lies | 0.9083 | 0.3333 | 0.8167 | 0.8667 | 0.7250 | 0.7750 | diff --git "a/tables/ToMBench/\345\237\272\347\241\200\346\214\207\346\240\207.md" "b/tables/ToMBench/\345\237\272\347\241\200\346\214\207\346\240\207.md" index ed2e5d2..a181a19 100644 --- "a/tables/ToMBench/\345\237\272\347\241\200\346\214\207\346\240\207.md" +++ "b/tables/ToMBench/\345\237\272\347\241\200\346\214\207\346\240\207.md" @@ -1,7 +1,7 @@ # ToMBench - 基础指标 -| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it | -|---|---|---|---|---|---| -| accuracy | 0.6337 | 0.4104 | 0.6242 | 0.6688 | 0.6014 | -| correct | 1812.3333 | 1173.6667 | 1785.3333 | 1912.6667 | 1720 | -| total | 2860 | 2860 | 2860 | 2860 | 2860 | +| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it | Qwen3-8B-SIPColdStart | +| --- | --- | --- | --- | --- | --- | --- | +| accuracy | 0.6337 | 0.4104 | 0.6242 | 0.6688 | 0.6014 | 0.6198 | +| correct | 1812.3333 | 1173.6667 | 1785.3333 | 1912.6667 | 1720 | 1772.6667 | +| total | 2860 | 2860 | 2860 | 2860 | 2860 | 2860 | diff --git "a/tables/ToMQA/\345\205\266\344\273\226\346\214\207\346\240\207.md" "b/tables/ToMQA/\345\205\266\344\273\226\346\214\207\346\240\207.md" index ac0c22c..6874294 100644 --- "a/tables/ToMQA/\345\205\266\344\273\226\346\214\207\346\240\207.md" +++ "b/tables/ToMQA/\345\205\266\344\273\226\346\214\207\346\240\207.md" @@ -1,10 +1,10 @@ # ToMQA - 其他指标 -| 指标 \ 模型 | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | -|---|---|---|---| -| by_dimension.first_order_belief | 0.4174 | 0.6253 | 0.6226 | -| by_dimension.second_order_belief | 0.2923 | 0.5397 | 0.4997 | -| by_difficulty.easy | 0.3549 | 0.5825 | 0.5611 | -| by_order.1 | 0.4174 | 0.6253 | 0.6226 | -| by_order.2 | 0.2923 | 0.5397 | 0.4997 | -| by_task_type.qa | 0.3549 | 0.5825 | 0.5611 | +| 指标 \ 模型 | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | Qwen3-8B-SIPColdStart | +| --- | --- | --- | --- | --- | +| by_dimension.first_order_belief | 0.4174 | 0.6253 | 0.6226 | - | +| by_dimension.second_order_belief | 0.2923 | 0.5397 | 0.4997 | - | +| by_difficulty.easy | 0.3549 | 0.5825 | 0.5611 | - | +| by_order.1 | 0.4174 | 0.6253 | 0.6226 | - | +| by_order.2 | 0.2923 | 0.5397 | 0.4997 | - | +| by_task_type.qa | 0.3549 | 0.5825 | 0.5611 | - | diff --git "a/tables/ToMQA/\345\237\272\347\241\200\346\214\207\346\240\207.md" "b/tables/ToMQA/\345\237\272\347\241\200\346\214\207\346\240\207.md" index 1e67d53..e472092 100644 --- "a/tables/ToMQA/\345\237\272\347\241\200\346\214\207\346\240\207.md" +++ "b/tables/ToMQA/\345\237\272\347\241\200\346\214\207\346\240\207.md" @@ -1,7 +1,7 @@ # ToMQA - 基础指标 -| 指标 \ 模型 | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | -|---|---|---|---| -| accuracy | 0.3549 | 0.5825 | 0.5611 | -| correct | 4258.3333 | 6990.3333 | 6733.3333 | -| total | 12000 | 12000 | 12000 | +| 指标 \ 模型 | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | Qwen3-8B-SIPColdStart | +| --- | --- | --- | --- | --- | +| accuracy | 0.3549 | 0.5825 | 0.5611 | - | +| correct | 4258.3333 | 6990.3333 | 6733.3333 | - | +| total | 12000 | 12000 | 12000 | - | diff --git "a/tables/Tomato/\345\205\266\344\273\226\346\214\207\346\240\207.md" "b/tables/Tomato/\345\205\266\344\273\226\346\214\207\346\240\207.md" index 11e181f..2b23db1 100644 --- "a/tables/Tomato/\345\205\266\344\273\226\346\214\207\346\240\207.md" +++ "b/tables/Tomato/\345\205\266\344\273\226\346\214\207\346\240\207.md" @@ -1,13 +1,13 @@ # Tomato - 其他指标 -| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it | -|---|---|---|---|---|---| -| by_dimension_1.belief | 0.6564 | 0.3883 | 0.6032 | 0.6240 | 0.5362 | -| by_dimension_1.desire | 0.7431 | 0.4309 | 0.7047 | 0.7379 | 0.5997 | -| by_dimension_1.emotion | 0.7001 | 0.3939 | 0.6754 | 0.6995 | 0.5699 | -| by_dimension_1.intention | 0.6892 | 0.4093 | 0.6649 | 0.6958 | 0.5556 | -| by_dimension_1.knowledge | 0.6411 | 0.3959 | 0.5883 | 0.6362 | 0.5581 | -| by_dimension_2.first_order | 0.7473 | 0.4155 | 0.7013 | 0.7322 | 0.5904 | -| by_dimension_2.second_order | 0.6226 | 0.3914 | 0.5902 | 0.6226 | 0.5364 | -| by_dimension_3.__none__ | 0.7106 | 0.4207 | 0.6693 | 0.7030 | 0.5854 | -| by_dimension_3.false_belief | 0.5347 | 0.3044 | 0.5083 | 0.5281 | 0.4363 | +| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it | Qwen3-8B-SIPColdStart | +| --- | --- | --- | --- | --- | --- | --- | +| by_dimension_1.belief | 0.6564 | 0.3883 | 0.6032 | 0.6240 | 0.5362 | 0.5407 | +| by_dimension_1.desire | 0.7431 | 0.4309 | 0.7047 | 0.7379 | 0.5997 | 0.6263 | +| by_dimension_1.emotion | 0.7001 | 0.3939 | 0.6754 | 0.6995 | 0.5699 | 0.6195 | +| by_dimension_1.intention | 0.6892 | 0.4093 | 0.6649 | 0.6958 | 0.5556 | 0.6168 | +| by_dimension_1.knowledge | 0.6411 | 0.3959 | 0.5883 | 0.6362 | 0.5581 | 0.5469 | +| by_dimension_2.first_order | 0.7473 | 0.4155 | 0.7013 | 0.7322 | 0.5904 | 0.6486 | +| by_dimension_2.second_order | 0.6226 | 0.3914 | 0.5902 | 0.6226 | 0.5364 | 0.5294 | +| by_dimension_3.__none__ | 0.7106 | 0.4207 | 0.6693 | 0.7030 | 0.5854 | 0.6146 | +| by_dimension_3.false_belief | 0.5347 | 0.3044 | 0.5083 | 0.5281 | 0.4363 | 0.4396 | diff --git "a/tables/Tomato/\345\237\272\347\241\200\346\214\207\346\240\207.md" "b/tables/Tomato/\345\237\272\347\241\200\346\214\207\346\240\207.md" index 6579601..1101b47 100644 --- "a/tables/Tomato/\345\237\272\347\241\200\346\214\207\346\240\207.md" +++ "b/tables/Tomato/\345\237\272\347\241\200\346\214\207\346\240\207.md" @@ -1,7 +1,7 @@ # Tomato - 基础指标 -| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it | -|---|---|---|---|---|---| -| accuracy | 0.6844 | 0.4034 | 0.6453 | 0.6769 | 0.5632 | -| correct | 3696.3333 | 2178.6667 | 3485 | 3656 | 3041.6667 | -| total | 5401 | 5401 | 5401 | 5401 | 5401 | +| 指标 \ 模型 | Meta-Llama-3.1-8B-Instruct | Qwen3-0.6B | Qwen3-4B | Qwen3-8B | gemma-3-4b-it | Qwen3-8B-SIPColdStart | +| --- | --- | --- | --- | --- | --- | --- | +| accuracy | 0.6844 | 0.4034 | 0.6453 | 0.6769 | 0.5632 | 0.5885 | +| correct | 3696.3333 | 2178.6667 | 3485 | 3656 | 3041.6667 | 3178.3333 | +| total | 5401 | 5401 | 5401 | 5401 | 5401 | 5401 | diff --git a/tasks/ToMBench/run.py b/tasks/ToMBench/run.py index 625dacd..32e7f8f 100644 --- a/tasks/ToMBench/run.py +++ b/tasks/ToMBench/run.py @@ -3,39 +3,40 @@ from pathlib import Path from typing import Any, Dict, List -# 添加父目录到路径以导入 src sys.path.insert(0, str(Path(__file__).parent.parent)) -from src.dataloader import load_dataset -from src.llm import LLMClient from src import runner +from src import judge as judge_module from ToMBench.prompts import get_template, build_prompt from ToMBench.metrics import compute_metrics import logging -# 彻底关闭 httpx 和 httpcore 的请求日志 logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING) logging.getLogger("uvicorn.access").setLevel(logging.WARNING) + + +def _is_correct(pred: str, gold: str) -> bool: + return pred == gold + + def main(): - # 加载数据集配置 dataset_config = runner.load_dataset_config("tasks/ToMBench/config.yaml") - - # 加载实验配置 experiment_config = runner.load_experiment_config("experiment_config.yaml") schema = dataset_config["schema"] prompt_method = dataset_config["default_prompt"] - - # 获取 prompt 模板 template = get_template(prompt_method) - - # 创建 LLM 客户端 client = runner.create_llm_client(experiment_config["llm_config"]) - # 加载数据 + badcase_enabled = experiment_config["badcase_enabled"] + enable_judge = experiment_config["enable_llm_judge"] + judge_client = None + if enable_judge: + judge_client = runner.create_llm_client(experiment_config["judge_config"]) + data = runner.load_and_limit_data( subset=dataset_config["subset"], datasets_path=experiment_config["datasets_path"], @@ -46,17 +47,18 @@ def main(): print(f"Prompt method: {prompt_method}") print(f"Repeats: {experiment_config['repeats']}") - # 构建 prompts(每个 repeat 构建相同的 prompts) prompts = [build_prompt(template, row) for row in data] all_prompts = prompts * experiment_config["repeats"] - # 批量结构化推理 print(f"Running inference ({len(all_prompts)} prompts)...") results = client.batch_generate_structure(all_prompts, schema) - # 使用数据集的 metrics 函数计算 - all_predictions = [] - all_metrics = [] + gold_answers = [row['Answer']['Correct Answer'][0] for row in data] + + all_predictions: List[List[str]] = [] + all_metrics: List[Dict[str, Any]] = [] + all_metrics_with_judge: List[Dict[str, Any]] = [] + all_badcases: List[Dict[str, Any]] = [] for i in range(experiment_config["repeats"]): start = i * len(data) @@ -65,13 +67,55 @@ def main(): predictions = [r.answer for r in repeat_results] all_predictions.append(predictions) - # 调用数据集的 metrics 函数 metrics = compute_metrics(predictions, data) all_metrics.append(metrics) print(f"Run {i+1}: Accuracy={metrics['accuracy']:.4f}, Correct={metrics['correct']}/{metrics['total']}") - # 保存结果 - gold_answers = [row['Answer']['Correct Answer'][0] for row in data] + # --- LLM Judge 兜底 --- + judge_verdicts = None + if judge_client: + failed_items = [ + { + "raw_response": r.raw_response, + "gold_answer": gold, + "question": prompt, + } + for r, gold, prompt in zip(repeat_results, gold_answers, prompts) + if not r.extraction_success + ] + if failed_items: + judge_results = judge_module.batch_judge(judge_client, failed_items) + judge_verdicts_full: List[bool] = [] + ji = 0 + for r in repeat_results: + if not r.extraction_success: + judge_verdicts_full.append(judge_results[ji]) + ji += 1 + else: + judge_verdicts_full.append(False) + judge_verdicts = judge_verdicts_full + + corrected = runner.build_corrected_predictions( + predictions, repeat_results, judge_results, gold_answers, + ) + metrics_j = compute_metrics(corrected, data) + all_metrics_with_judge.append(metrics_j) + print( + f" [Judge] Accuracy={metrics_j['accuracy']:.4f}, " + f"Recovered={metrics_j['correct'] - metrics['correct']}" + ) + else: + all_metrics_with_judge.append(metrics) + + # --- Bad case 收集 --- + if badcase_enabled: + bcs = runner.collect_badcases( + repeat_results, predictions, gold_answers, prompts, + dataset_config["dataset"], _is_correct, + repeat_idx=i, judge_verdicts=judge_verdicts, + ) + all_badcases.extend(bcs) + runner.save_common_results( dataset_name=dataset_config["dataset"], model=experiment_config["llm_config"]["model_name"], @@ -82,11 +126,12 @@ def main(): results_path=experiment_config["results_path"], dataset_config=dataset_config, experiment_config=experiment_config, + badcases=all_badcases if badcase_enabled else None, + all_metrics_with_judge=all_metrics_with_judge if enable_judge else None, ) - # 打印统计摘要 runner.print_summary_stats(all_metrics, experiment_config["repeats"], len(gold_answers)) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tasks/ToMQA/run.py b/tasks/ToMQA/run.py index 4161b17..ef680f3 100644 --- a/tasks/ToMQA/run.py +++ b/tasks/ToMQA/run.py @@ -1,17 +1,18 @@ """ToMQA 评测脚本(基于结构化输出)""" import sys from pathlib import Path +from typing import Any, Dict, List -# 添加父目录到路径以导入 src sys.path.insert(0, str(Path(__file__).parent.parent)) from src import runner +from src import judge as judge_module + from ToMQA.prompts import get_template, build_prompt -from ToMQA.metrics import compute_metrics +from ToMQA.metrics import compute_metrics, normalize_answer import logging -# 关闭不必要日志 logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING) logging.getLogger("uvicorn.access").setLevel(logging.WARNING) @@ -34,23 +35,27 @@ def extract_gold_answers(data): return golds +def _is_correct(pred: str, gold: str) -> bool: + p = normalize_answer(pred) + g = normalize_answer(gold) + return bool(p) and p == g + + def main(): - # 加载数据集配置 dataset_config = runner.load_dataset_config("tasks/ToMQA/config.yaml") - - # 加载实验配置 experiment_config = runner.load_experiment_config("experiment_config.yaml") schema = dataset_config["schema"] prompt_method = dataset_config["default_prompt"] - - # 获取 prompt 模板 template = get_template(prompt_method) - - # 创建 LLM 客户端 client = runner.create_llm_client(experiment_config["llm_config"]) - # 加载数据 + badcase_enabled = experiment_config["badcase_enabled"] + enable_judge = experiment_config["enable_llm_judge"] + judge_client = None + if enable_judge: + judge_client = runner.create_llm_client(experiment_config["judge_config"]) + data = runner.load_and_limit_data( subset=dataset_config["subset"], datasets_path=experiment_config["datasets_path"], @@ -61,17 +66,18 @@ def main(): print(f"Prompt method: {prompt_method}") print(f"Repeats: {experiment_config['repeats']}") - # 构建 prompts prompts = [build_prompt(template, row) for row in data] all_prompts = prompts * experiment_config["repeats"] - # 批量结构化推理 print(f"Running inference ({len(all_prompts)} prompts)...") results = client.batch_generate_structure(all_prompts, schema) - # 计算 metrics - all_predictions = [] - all_metrics = [] + gold_answers = extract_gold_answers(data) + + all_predictions: List[List[str]] = [] + all_metrics: List[Dict[str, Any]] = [] + all_metrics_with_judge: List[Dict[str, Any]] = [] + all_badcases: List[Dict[str, Any]] = [] for i in range(experiment_config["repeats"]): start = i * len(data) @@ -87,8 +93,51 @@ def main(): f"Correct={metrics['correct']}/{metrics['total']}" ) - # 保存结果 - gold_answers = extract_gold_answers(data) + # --- LLM Judge 兜底 --- + judge_verdicts = None + if judge_client: + failed_items = [ + { + "raw_response": r.raw_response, + "gold_answer": gold, + "question": prompt, + } + for r, gold, prompt in zip(repeat_results, gold_answers, prompts) + if not r.extraction_success + ] + if failed_items: + judge_results = judge_module.batch_judge(judge_client, failed_items) + judge_verdicts_full: List[bool] = [] + ji = 0 + for r in repeat_results: + if not r.extraction_success: + judge_verdicts_full.append(judge_results[ji]) + ji += 1 + else: + judge_verdicts_full.append(False) + judge_verdicts = judge_verdicts_full + + corrected = runner.build_corrected_predictions( + predictions, repeat_results, judge_results, gold_answers, + ) + metrics_j = compute_metrics(corrected, data) + all_metrics_with_judge.append(metrics_j) + print( + f" [Judge] Accuracy={metrics_j['accuracy']:.4f}, " + f"Recovered={metrics_j['correct'] - metrics['correct']}" + ) + else: + all_metrics_with_judge.append(metrics) + + # --- Bad case 收集 --- + if badcase_enabled: + bcs = runner.collect_badcases( + repeat_results, predictions, gold_answers, prompts, + dataset_config["dataset"], _is_correct, + repeat_idx=i, judge_verdicts=judge_verdicts, + ) + all_badcases.extend(bcs) + runner.save_common_results( dataset_name=dataset_config["dataset"], model=experiment_config["llm_config"]["model_name"], @@ -99,9 +148,10 @@ def main(): results_path=experiment_config["results_path"], dataset_config=dataset_config, experiment_config=experiment_config, + badcases=all_badcases if badcase_enabled else None, + all_metrics_with_judge=all_metrics_with_judge if enable_judge else None, ) - # 打印统计摘要 runner.print_summary_stats(all_metrics, experiment_config["repeats"], len(gold_answers)) diff --git a/tasks/ToMi/metrics.py b/tasks/ToMi/metrics.py index 9a7f9da..aa469dd 100644 --- a/tasks/ToMi/metrics.py +++ b/tasks/ToMi/metrics.py @@ -11,14 +11,18 @@ def _normalize_word(text: Any) -> str: def compute_metrics(predictions: List[str], data: List[Dict[str, Any]]) -> Dict[str, Any]: """计算 ToMi 的 metrics(单词答案精确匹配)""" - gold_answers = [_normalize_word(row.get("output", "")) for row in data] + gold_answers = [] + for row in data: + correct = row.get("Answer", {}).get("Correct_Answer", []) + gold_answers.append(_normalize_word(correct[0]) if correct else "") + pred_answers = [_normalize_word(p) for p in predictions] - correct = sum(1 for p, g in zip(pred_answers, gold_answers) if p == g) - accuracy = correct / len(pred_answers) if pred_answers else 0 + correct_count = sum(1 for p, g in zip(pred_answers, gold_answers) if p == g) + accuracy = correct_count / len(pred_answers) if pred_answers else 0 return { "accuracy": accuracy, - "correct": correct, + "correct": correct_count, "total": len(pred_answers), } diff --git a/tasks/ToMi/prompts.py b/tasks/ToMi/prompts.py index 7b6c9fe..c3e1184 100644 --- a/tasks/ToMi/prompts.py +++ b/tasks/ToMi/prompts.py @@ -23,8 +23,9 @@ def build_prompt(template: str, row: Dict[str, Any]) -> str: """构建 prompt""" - story = row.get("instruction", "") - question = row.get("input", "") + story_info = row.get("Story", {}) if isinstance(row.get("Story"), dict) else {} + story = story_info.get("full_story", "") or "" + question = row.get("Question", "") or "" return template.format(story=story, question=question) diff --git a/tasks/ToMi/run.py b/tasks/ToMi/run.py index 5d96c26..f0c8f0d 100644 --- a/tasks/ToMi/run.py +++ b/tasks/ToMi/run.py @@ -3,44 +3,49 @@ from pathlib import Path from typing import Any, Dict, List -# 添加父目录到路径以导入 src sys.path.insert(0, str(Path(__file__).parent.parent)) from src import runner +from src import judge as judge_module from ToMi.prompts import get_template, build_prompt from ToMi.metrics import compute_metrics import logging -# 关闭不必要日志 logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING) logging.getLogger("uvicorn.access").setLevel(logging.WARNING) def extract_gold_answers(data: List[Dict[str, Any]]) -> List[str]: - """提取标准答案。""" - return [str(row.get("output", "")).strip().lower() for row in data] + """提取标准答案(取 Answer.Correct_Answer 列表的第一个元素)。""" + answers = [] + for row in data: + correct = row.get("Answer", {}).get("Correct_Answer", []) + answers.append(str(correct[0]).strip().lower() if correct else "") + return answers + + +def _is_correct(pred: str, gold: str) -> bool: + return str(pred).strip().lower() == str(gold).strip().lower() def main(): - # 加载数据集配置 dataset_config = runner.load_dataset_config("tasks/ToMi/config.yaml") - - # 加载实验配置 experiment_config = runner.load_experiment_config("experiment_config.yaml") schema = dataset_config["schema"] prompt_method = dataset_config["default_prompt"] - - # 获取 prompt 模板 template = get_template(prompt_method) - - # 创建 LLM 客户端 client = runner.create_llm_client(experiment_config["llm_config"]) - # 加载数据 + badcase_enabled = experiment_config["badcase_enabled"] + enable_judge = experiment_config["enable_llm_judge"] + judge_client = None + if enable_judge: + judge_client = runner.create_llm_client(experiment_config["judge_config"]) + data = runner.load_and_limit_data( subset=dataset_config["subset"], datasets_path=experiment_config["datasets_path"], @@ -51,17 +56,19 @@ def main(): print(f"Prompt method: {prompt_method}") print(f"Repeats: {experiment_config['repeats']}") - # 构建 prompts(每个 repeat 构建相同的 prompts) prompts = [build_prompt(template, row) for row in data] all_prompts = prompts * experiment_config["repeats"] - # 批量结构化推理 print(f"Running inference ({len(all_prompts)} prompts)...") results = client.batch_generate_structure(all_prompts, schema) - # 计算 metrics - all_predictions = [] - all_metrics = [] + gold_answers = extract_gold_answers(data) + + all_predictions: List[List[str]] = [] + all_metrics: List[Dict[str, Any]] = [] + all_metrics_with_judge: List[Dict[str, Any]] = [] + all_badcases: List[Dict[str, Any]] = [] + for i in range(experiment_config["repeats"]): start = i * len(data) end = start + len(data) @@ -73,8 +80,51 @@ def main(): all_metrics.append(metrics) print(f"Run {i+1}: Accuracy={metrics['accuracy']:.4f}, Correct={metrics['correct']}/{metrics['total']}") - # 保存结果 - gold_answers = extract_gold_answers(data) + # --- LLM Judge 兜底 --- + judge_verdicts = None + if judge_client: + failed_items = [ + { + "raw_response": r.raw_response, + "gold_answer": gold, + "question": prompt, + } + for r, gold, prompt in zip(repeat_results, gold_answers, prompts) + if not r.extraction_success + ] + if failed_items: + judge_results = judge_module.batch_judge(judge_client, failed_items) + judge_verdicts_full: List[bool] = [] + ji = 0 + for r in repeat_results: + if not r.extraction_success: + judge_verdicts_full.append(judge_results[ji]) + ji += 1 + else: + judge_verdicts_full.append(False) + judge_verdicts = judge_verdicts_full + + corrected = runner.build_corrected_predictions( + predictions, repeat_results, judge_results, gold_answers, + ) + metrics_j = compute_metrics(corrected, data) + all_metrics_with_judge.append(metrics_j) + print( + f" [Judge] Accuracy={metrics_j['accuracy']:.4f}, " + f"Recovered={metrics_j['correct'] - metrics['correct']}" + ) + else: + all_metrics_with_judge.append(metrics) + + # --- Bad case 收集 --- + if badcase_enabled: + bcs = runner.collect_badcases( + repeat_results, predictions, gold_answers, prompts, + dataset_config["dataset"], _is_correct, + repeat_idx=i, judge_verdicts=judge_verdicts, + ) + all_badcases.extend(bcs) + runner.save_common_results( dataset_name=dataset_config["dataset"], model=experiment_config["llm_config"]["model_name"], @@ -85,9 +135,10 @@ def main(): results_path=experiment_config["results_path"], dataset_config=dataset_config, experiment_config=experiment_config, + badcases=all_badcases if badcase_enabled else None, + all_metrics_with_judge=all_metrics_with_judge if enable_judge else None, ) - # 打印统计摘要 runner.print_summary_stats(all_metrics, experiment_config["repeats"], len(gold_answers)) diff --git a/tasks/Tomato/run.py b/tasks/Tomato/run.py index 6194152..45cb2b1 100644 --- a/tasks/Tomato/run.py +++ b/tasks/Tomato/run.py @@ -11,8 +11,11 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) from src import runner +from src import judge as judge_module + from Tomato.prompts import get_template, build_prompt from Tomato.metrics import compute_metrics + logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING) logging.getLogger("uvicorn.access").setLevel(logging.WARNING) @@ -97,6 +100,10 @@ def shuffle_mcq_options(mcq: Dict[str, Any], seed: int) -> Dict[str, Any]: return {**mcq, "original_choices": new_choices, "gold_letter": new_gold} +def _is_correct(pred: str, gold: str) -> bool: + return bool(pred) and pred == gold + + def main() -> None: dataset_config = runner.load_dataset_config("tasks/Tomato/config.yaml") experiment_config = runner.load_experiment_config("experiment_config.yaml") @@ -106,6 +113,12 @@ def main() -> None: template = get_template(prompt_method) client = runner.create_llm_client(experiment_config["llm_config"]) + badcase_enabled = experiment_config["badcase_enabled"] + enable_judge = experiment_config["enable_llm_judge"] + judge_client = None + if enable_judge: + judge_client = runner.create_llm_client(experiment_config["judge_config"]) + data = runner.load_and_limit_data( subset=dataset_config["subset"], datasets_path=experiment_config["datasets_path"], @@ -122,16 +135,21 @@ def main() -> None: all_prompts: List[str] = [] repeat_data: List[List[Dict[str, Any]]] = [] + repeat_prompts_list: List[List[str]] = [] for i in range(repeats): shuffled_rows: List[Dict[str, Any]] = [] + cur_prompts: List[str] = [] for j, row in enumerate(data): shuffled_mcq = shuffle_mcq_options(row["_mcq"], seed=42 * (i + 1) + j) shuffled_row = dict(row) shuffled_row["_mcq"] = shuffled_mcq shuffled_rows.append(shuffled_row) - all_prompts.append(build_prompt(template, shuffled_row)) + p = build_prompt(template, shuffled_row) + all_prompts.append(p) + cur_prompts.append(p) repeat_data.append(shuffled_rows) + repeat_prompts_list.append(cur_prompts) print(f"Running inference ({len(all_prompts)} prompts)...") results = client.batch_generate_structure(all_prompts, schema) @@ -139,7 +157,9 @@ def main() -> None: n = len(data) all_predictions: List[List[str]] = [] all_metrics: List[Dict[str, Any]] = [] + all_metrics_with_judge: List[Dict[str, Any]] = [] all_gold: List[List[str]] = [] + all_badcases: List[Dict[str, Any]] = [] for i in range(repeats): start = i * n @@ -149,11 +169,59 @@ def main() -> None: predictions = [r.answer for r in repeat_results] all_predictions.append(predictions) + repeat_gold = [row["_mcq"]["gold_letter"] for row in rows] + repeat_pr = repeat_prompts_list[i] + metrics = compute_metrics(predictions, rows) all_metrics.append(metrics) - all_gold.append([row["_mcq"]["gold_letter"] for row in rows]) + all_gold.append(repeat_gold) print(f"Run {i+1}: Accuracy={metrics['accuracy']:.4f}, Correct={metrics['correct']}/{metrics['total']}") + # --- LLM Judge 兜底 --- + judge_verdicts = None + if judge_client: + failed_items = [ + { + "raw_response": r.raw_response, + "gold_answer": gold, + "question": prompt, + } + for r, gold, prompt in zip(repeat_results, repeat_gold, repeat_pr) + if not r.extraction_success + ] + if failed_items: + judge_results = judge_module.batch_judge(judge_client, failed_items) + judge_verdicts_full: List[bool] = [] + ji = 0 + for r in repeat_results: + if not r.extraction_success: + judge_verdicts_full.append(judge_results[ji]) + ji += 1 + else: + judge_verdicts_full.append(False) + judge_verdicts = judge_verdicts_full + + corrected = runner.build_corrected_predictions( + predictions, repeat_results, judge_results, repeat_gold, + ) + metrics_j = compute_metrics(corrected, rows) + all_metrics_with_judge.append(metrics_j) + print( + f" [Judge] Accuracy={metrics_j['accuracy']:.4f}, " + f"Recovered={metrics_j['correct'] - metrics['correct']}" + ) + else: + all_metrics_with_judge.append(metrics) + + # --- Bad case 收集 --- + if badcase_enabled: + bcs = runner.collect_badcases( + repeat_results, predictions, repeat_gold, repeat_pr, + dataset_config["dataset"], _is_correct, + repeat_idx=i, judge_verdicts=judge_verdicts, + ) + all_badcases.extend(bcs) + runner.save_common_results( dataset_name=dataset_config["dataset"], model=experiment_config["llm_config"]["model_name"], @@ -162,6 +230,10 @@ def main() -> None: gold_answers=all_gold, all_metrics=all_metrics, results_path=experiment_config["results_path"], + dataset_config=dataset_config, + experiment_config=experiment_config, + badcases=all_badcases if badcase_enabled else None, + all_metrics_with_judge=all_metrics_with_judge if enable_judge else None, ) runner.print_summary_stats(all_metrics, repeats, n)