基于贪吃蛇游戏,创建行为识别 + 行为正确性检测的 DEMO,用于生成带标注的视频素材数据。
整体 Pipeline 分为两条可选路径:纯网格路径(基于 scene 坐标)和 视觉路径(YOLO 检测 + 跟踪)。
flowchart TB
A["1. 数据生成<br/>data_generator, batch JSON"]
B["2. 渲染与导出<br/>render_and_export, 640×640 图像+YOLO label"]
C["3. 目标检测+跟踪<br/>YOLO+ByteTrack 可选"]
A --> B --> C
A -->|纯网格路径| D
B -->|视觉路径| D
C --> D
D["4. 序列构建 run_track_and_prepare<br/>每蛇每帧18维+蛇头前方格(4类), 输入label或YOLO track"]
D --> E
E["5. 行为正确性 train_behavior 双向LSTM+注意力<br/>correct/incorrect + reason 共7类"]
E --> F
F["6. 实战演示 demo_video<br/>对局渲染, YOLO框+行为标注, 输出MP4"]
| 阶段 | 脚本 | 输入 | 输出 |
|---|---|---|---|
| 1. 数据生成 | data_generator.py |
随机种子、规则参数 | batches/batch_*.json |
| 2. 渲染导出 | render_and_export.py |
batch JSON | dataset/ (640×640 images + labels + metadata) |
| 3. 目标检测 | YOLO train / track |
渲染图像 | 检测框 + track_id |
| 4. 序列构建 | run_track_and_prepare.py |
dataset(含 train/val) | sequences/track_sequences.json(每条序列带 split,与 dataset 一致) |
| 5. 行为训练 | train_behavior.py |
sequences(按 split 划分 train/val)或 grid |
checkpoints/behavior/best.pt(best 按 7类(P+R+AP)均值+阈值搜索选取) |
| 6. 视频演示 | demo_video.py |
batch + 模型权重 | 带标注的 MP4 视频 |
特征约束:仅使用 YOLO 能从图像检测到的信息(head/food/x2 位置),grid 与 track 两种路径完全一致。
每帧每蛇连续特征(18 维):[head_x, head_y, vel_x, vel_y, food_x, food_y, x2_x, x2_y, has_x2, dist_to_food, dist_to_x2, moving_towards_food, ate_food, ate_x2, is_dead, steps_since_food, ate_food_while_x2, about_to_timeout]
ate_food/ate_x2:由连续帧 head/food/x2 位置推导,YOLO 推理时同样可计算is_dead:蛇是否已撞击死亡(YOLO 5 类时 class 4=snake_head_dead表示圆形蛇头,label 路径从解析结果读取)steps_since_food:自上次吃食物以来的帧数,归一化min(count/80, 1.0)about_to_timeout:若下一步再没吃到果子就超时(80 步)则为 1,否则为 0- 蛇头前方一格(离散,Embedding 输入):0=空,1=自己身体,2=其他蛇身体,3=其他蛇头;用于区分 self_collision / snake_collision 及碰撞部位
- 训练时默认将前后共 3 帧(1 前 + 当前 + 1 后)合并为 1 帧输入(3×18=54 维 + head_forward 嵌入)
- 无墙壁,蛇撞自己即结束
- 吃 1 个食物 +1 格长度、+1 分
- 每当食物被吃掉时,50% 概率同时生成一个「x2」
- 先吃 x2 再吃食物 → 该食物得 2 分;否则得 1 分
- x2 每波只生效一次,生成新食物时自动失效
| 标注 | 含义 |
|---|---|
correct / ate_x2_then_food |
先吃 x2 再吃食物 ✓ |
correct / ate_food_no_x2 |
无 x2 时吃食物 ✓ |
in_progress |
进行中(未结束) |
incorrect / self_collision |
撞自己 ✗ |
incorrect / snake_collision |
蛇间碰撞 ✗ |
incorrect / x2_wasted |
先吃食物导致 x2 浪费 ✗ |
incorrect / timeout |
超时未吃食物 ✗ |
batch JSON 结构(支持多蛇):
{
"episodes": [
{
"scenes": [
{
"snakes": [
{ "body": [[x,y],...], "food": [x,y], "x2": [x,y]|null, "score": 0, "color_id": 0 }
],
"step": 0
}
],
"label": "correct",
"reason": "ate_food_no_x2",
"snake_annotations": [
{ "label": "correct", "reason": "ate_food_no_x2" }
]
}
]
}每个 scene 为一帧状态;snake_annotations[i] 为第 i 条蛇的波级标注。
# 默认:10 个 batch,每 batch 10 局,输出到 batches/
python data_generator.py
# 自定义参数(支持多进程加速)
python data_generator.py --batches 100 --batch-size 100 --output my_data --mistake-rate 0.2
# 使用 12 个进程并行生成
python data_generator.py -b 100 -s 100 -w 12| 参数 | 默认 | 说明 |
|---|---|---|
-b, --batches |
10 | batch 数量 |
-s, --batch-size |
10 | 每 batch 局数 |
-o, --output |
batches | 输出目录 |
-w, --workers |
CPU 核心数 | 并行进程数 |
-m, --mistake-rate |
0.15 | AI 犯错概率(先吃食物浪费 x2) |
-f, --max-foods |
12 | 每局需吃完的食物波数 |
每个 batch 保存为独立 JSON,如 batches/batch_00000.json。每局存储完整场景序列。
渲染输出 640×640 图像;蛇头视觉:开局菱形 → 三角(尖指方向) → 撞击圆形。
| 步骤 | 命令 | 说明 |
|---|---|---|
| 1 | 数据生成 | 多进程 -w 12 加速 |
| 2 | 渲染导出 | 640×640,全帧不跳帧 |
| 3 | YOLO 训练 | 可选,5 类含 snake_head_dead |
| 4 | 序列准备 | 路径 A 纯 label / 路径 B YOLO 跟踪 |
| 5 | 行为训练 | 双向 LSTM + 注意力;best 按 7类(P+R+AP)均值 选取(每轮 0.15~0.85 搜最优阈值,F1≥0.25 约束);正负样本都评;--patience 早停;每轮打印 (P+R+AP)、thresh |
| 6 | 评估 | 默认复用训练 val 流程(DataLoader),正负都评,指标准确一致;阈值 0.15~0.85 内按 (P+R+AP) 均值搜索;--no-threshold-search 时用 --incorrect-threshold |
| 7 | 演示视频 | -d dataset 与训练帧一致;--draw-boxes 绘制 YOLO 框 |
# 1. 数据生成(-w 12 多进程)
python data_generator.py --batches 10 --batch-size 100 -w 12
# 2. 渲染与导出 YOLO 数据集(640×640,全帧;-w 12 多进程)
python scripts/render_and_export.py --batches batches/ --output dataset --val-ratio 0.2 -w 12
# 3. (可选) YOLO 训练(5 类:snake_head, snake_body, food, x2, snake_head_dead;imgsz 与渲染一致)
yolo train model=yolov8n.pt data=dataset/data.yaml epochs=100 imgsz=640 batch=128
# 4. 序列准备(二选一)
# 路径 A:纯 label,无需 YOLO(-w 12 多进程)
python scripts/run_track_and_prepare.py --from-labels -d dataset -o sequences -w 12
# 路径 B:YOLO 跟踪(需先完成步骤 3)
python scripts/run_track_and_prepare.py -m runs/detect/train/weights/best.pt -d dataset -o sequences
# 5. 行为模型训练(best 按 7类(P+R+AP)均值 选取,每轮搜最优阈值并打印;(P+R+AP)、thresh;50 epoch 无提升早停)
python scripts/train_behavior.py --data sequences/track_sequences.json -o checkpoints/behavior \
--boost-incorrect --aug-multiscale --aug-frame-drop 0.1 --aug-noise 0.02 --epochs 1000 --patience 50
# 6. 评估(默认复用训练 val 流程,正负都评;阈值 0.15~0.85 内按 (P+R+AP) 搜索;--batch-size 256 提速)
python scripts/eval_behavior.py -c checkpoints/behavior/best.pt -d sequences/track_sequences.json --batch-size 256
# 7. 演示视频(-d dataset 保证帧与训练一致;路径 B 可加 -m YOLO权重 --draw-boxes)
python scripts/demo_video.py -b batches/batch_00000.json -e 0 -c checkpoints/behavior/best.pt -d dataset -o demo.mp4pip install pygame
python replay_ui.py回放控制:
- 打开文件 (O):点击按钮或按 O 键选择 JSON 数据文件
- 空格:播放/暂停
- ←/→:上一帧/下一帧
- A/D:上一局/下一局
- Home/End:跳到开头/结尾
# 640×640 全帧导出(不跳帧),多进程:-w 12
python scripts/render_and_export.py --batches batches/ --output dataset --val-ratio 0.2 -w 12输出 dataset/train、dataset/val(640×640 images + labels + metadata.json)。划分按局(episode):约 --val-ratio 比例的局整局进入 val,保证验证集样本量合理。run_track_and_prepare 会据此为每条序列写入 split,训练与评估的 train/val 与 dataset 一致。蛇头视觉:开局菱形、前进三角(尖指方向)、撞击圆形。
# 纯 label:直接从 YOLO label 提取蛇头,无需运行 YOLO(-w 12 多进程)
python scripts/run_track_and_prepare.py --from-labels -d dataset -o sequences -w 12
# YOLO 跟踪:需先训练 YOLO,再跑 track(GPU 时默认单进程)
python scripts/run_track_and_prepare.py -m yolov8n.pt -d dataset -o sequences模型结构:双向 LSTM + 自注意力(可用 --no-bidirectional --no-attention 禁用以兼容旧版)。best.pt 按 7类(P+R+AP)均值 选取:每轮在阈值 0.15~0.85 内搜索(F1≥0.25 约束下最大化该均值),正负样本都参与验证;每轮打印 (P+R+AP) 与 thresh。--patience 50 早停。
# 纯网格(推荐先验证)
python scripts/train_behavior.py --data grid --batches batches/ -o checkpoints/behavior
# 序列数据(--boost-incorrect 提升错误检测;best 按 7类(P+R+AP)均值 选取)
python scripts/train_behavior.py --data sequences/track_sequences.json -o checkpoints/behavior \
--boost-incorrect --patience 50 --epochs 1000
# 可选:--patience 0 禁用早停
python scripts/train_behavior.py -d sequences/track_sequences.json -o checkpoints/behavior --patience 30
# 禁用新结构(与旧 checkpoint 一致)
python scripts/train_behavior.py -d sequences/track_sequences.json -o checkpoints/behavior --no-bidirectional --no-attention# 行为标注(推荐 -d dataset 保证帧与训练一致)
python scripts/demo_video.py -b batches/batch_00000.json -e 0 -c checkpoints/behavior/best.pt -d dataset -o demo.mp4
# YOLO 框 + 行为联合标注(路径 B)
python scripts/demo_video.py -b batches/batch_00000.json -e 0 -m yolov8n.pt -c checkpoints/behavior/best.pt -d dataset -o demo.mp4 --draw-boxes默认仅评估验证集(--split val),且复用训练 val 流程(同一 DataLoader),正负样本都评,指标准确一致。阈值在 0.15~0.85 内按 7类(P+R+AP)均值 最大搜索(F1≥0.25 约束);输出表格(P/R/mAP,每类 + all)。
# 默认评估 val 集(与训练同一 val 流程,mAP 一致)
python scripts/eval_behavior.py -c checkpoints/behavior/best.pt -d sequences/track_sequences.json
# 评估全部数据:--split all(旧版无 split 的 JSON 也请用此选项)
python scripts/eval_behavior.py -c best.pt -d sequences/track_sequences.json --split all
# 使用固定阈值(不搜索):--no-threshold-search --incorrect-threshold 0.15
python scripts/eval_behavior.py -c best.pt -d sequences/track_sequences.json --no-threshold-search --incorrect-threshold 0.15
# 提升错误召回率:--reason-override
python scripts/eval_behavior.py -c best.pt -d sequences/track_sequences.json --reason-override
# 大批量评估提速:--batch-size 256(显存允许时)
python scripts/eval_behavior.py -c best.pt -d sequences/track_sequences.json --batch-size 256from data_generator import generate_dataset, run_episode
# 生成数据:100 个 batch,每 batch 100 局
generate_dataset(num_batches=100, batch_size=100, output_dir="batches")
# 单局运行(含 AI 犯错概率)
ep = run_episode(seed=42, max_foods=12, ai_mistake_rate=0.15)
print(ep["label"], ep["reason"], len(ep["scenes"]))| 用途 | 包 |
|---|---|
| 游戏 / 回放 / 渲染 | pygame |
| 目标检测 / 跟踪 | ultralytics (YOLO) |
| 行为模型 | torch |
| 视频输出 | opencv-python |
behavior_detection/
├── data_generator.py # 数据生成
├── game.py, ai.py # 游戏逻辑与 AI
├── replay_ui.py # 回放演示
├── scripts/
│ ├── render_and_export.py # 渲染 → YOLO 数据集
│ ├── preview_labels.py # 标注预览
│ ├── run_track_and_prepare.py # 序列构建
│ ├── train_behavior.py # 行为模型训练
│ ├── infer_behavior.py # 行为推理
│ └── demo_video.py # 实战演示视频
├── models/
│ └── behavior_correctness.py # 双向LSTM+注意力 行为/正确性模型
├── batches/ # 生成的对局数据
├── dataset/ # 渲染输出 (640×640 YOLO 数据集)
├── sequences/ # 序列特征
└── checkpoints/behavior/ # 模型权重