Skip to content

ZSHYC/TrackNet-V3-based-Badminton

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

74 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

TrackNet-V3-based-Badminton

基于 TrackNetV3 的羽毛球轨迹检测、轨迹修复、落点预测与关键事件检测项目。

本仓库以原始 TrackNetV3 为基础,保留“TrackNet 逐帧定位 + InpaintNet 轨迹修复”的主流程,并在此之上扩展了面向羽毛球比赛分析的后处理能力:落点预测、落地/击球事件检测、半自动标注工具、BounceNet 事件分类器、批量评估与可视化分析界面。

原论文:TrackNetV3: Enhancing Shuttlecock Tracking with Augmentations and Trajectory Rectification

论文链接:ACM Digital Library

TrackNetV3 network architecture

项目特点

  • TrackNetV3 轨迹检测主干:使用背景估计、热力图预测和时序集成输出羽毛球坐标。
  • InpaintNet 轨迹修复:对遮挡、漏检和不连续轨迹进行补全,使后续事件检测更稳定。
  • 坐标型落点预测:新增 landing_predictor.py,仅基于 (x, y) 轨迹判断落点候选,不把 visibility 作为检测依据。
  • 批量落点评估:新增 batch_predict_landings.py,支持对 train/test 等 split 批量生成 Top-K 落点候选,并在有标签时输出误差统计。
  • 落地/击球事件检测模块:新增 bounce_detection/,用运动学规则检测 landing、hit、out_of_frame 等关键事件。
  • BounceNet 二阶段分类器:在高召回规则候选基础上,用轨迹特征或视觉 patch 过滤误检、修正事件类型。
  • 半自动标注工具:新增 labeling_launcher.pybounce_detection/labeling/,支持 Phase 1 候选预填充、手动修正、批量标注、统计与导出。
  • 误差分析和可视化:保留并扩展 Dash 误差分析界面,同时新增落点预测视频可视化脚本。
Trajectory comparison

与原 TrackNetV3 的关系

原 TrackNetV3 主要解决“每一帧羽毛球在哪里”的问题,本项目进一步面向“比赛片段中关键事件在哪里发生”:

  1. 先用 TrackNet/InpaintNet 得到连续轨迹。
  2. 再用规则或轻量模型从轨迹中识别落地点、击球点、出画面等事件。
  3. 用标注工具快速修正候选结果,形成可训练的事件标签。
  4. 用 BounceNet 或坐标型落点预测器进行批量推理、评估和可视化。

这使仓库从单纯的轨迹检测模型,扩展为一个更完整的羽毛球轨迹分析与事件检测工具链。

环境安装

原始开发环境参考:

Ubuntu 16.04.7 LTS
Python 3.8.7
torch 1.10.0

安装依赖:

git clone https://github.com/ZSHYC/TrackNet-V3-based-Badminton.git
cd TrackNet-V3-based-Badminton
pip install -r requirements.txt

主要依赖包括 torchopencv-pythonnumpypandasdashplotlyPillow

数据准备

本项目沿用 Shuttlecock Trajectory Dataset 的组织方式。建议目录结构如下:

data/
  train/
    match1/
      csv/
      frame/
      video/
    match2/
    ...
  val/
  test/
    match1/
      csv/
      corrected_csv/
      frame/
      video/

预处理:

python preprocess.py

说明:

  • CSV 字段通常为 Frame, Visibility, X, Y
  • frame/val/ 可由预处理流程生成。
  • 背景中值图会保存到对应 match 或 rally 的 median 文件中。
  • 如果修改 dataset.py 中的数据索引逻辑,需要删除旧的 .npy 缓存后重新生成。

TrackNetV3 轨迹推理

下载原 TrackNetV3 检查点后放入 ckpts/

unzip TrackNetV3_ckpts.zip

视频推理并输出预测 CSV:

python predict.py \
  --video_file test.mp4 \
  --tracknet_file ckpts/TrackNet_best.pt \
  --inpaintnet_file ckpts/InpaintNet_best.pt \
  --save_dir prediction

同时输出带预测轨迹的视频:

python predict.py \
  --video_file test.mp4 \
  --tracknet_file ckpts/TrackNet_best.pt \
  --inpaintnet_file ckpts/InpaintNet_best.pt \
  --save_dir prediction \
  --output_video

长视频可使用 --large_video 降低内存压力:

python predict.py \
  --video_file test.mp4 \
  --tracknet_file ckpts/TrackNet_best.pt \
  --inpaintnet_file ckpts/InpaintNet_best.pt \
  --save_dir prediction \
  --large_video \
  --video_range 324,330

坐标型落点预测

新增的落点预测器位于:

  • landing_predictor.py:单条轨迹的 Top-K 落点候选生成。
  • batch_predict_landings.py:按 split/match 批量推理和评估。
  • visualize_predictions.py:把预测落点画回视频。

核心策略:

  • 不依赖 visibility 判断落点。
  • 使用稳定窗口、低速度、Y 坐标稳定性和转向信号筛选候选。
  • 输出 Top-K 候选,按转向、平均速度和 Y 方向波动排序。
  • 有标注时计算帧误差、坐标误差和命中率。

训练集评估并保存预测:

python batch_predict_landings.py --split train --evaluate --save_preds

测试集批量预测:

python batch_predict_landings.py --split test --save_preds

常用参数:

python batch_predict_landings.py \
  --split test \
  --window 5 \
  --dy 10 \
  --v_th 5 \
  --top_k 5 \
  --save_preds

输出示例:

data/<split>/<match>/pred_landing/
  *_pred.json
  metrics.json   # evaluate 模式下生成
  detail.json    # evaluate 模式下生成

预测结果可视化:

python visualize_predictions.py \
  --match_dir data/train/match1 \
  --pred_dir pred_landing \
  --output_dir pred_landing_vis

落地/击球事件检测

bounce_detection/ 是本项目最主要的扩展模块,用于从 TrackNetV3 输出轨迹中检测关键事件。

bounce_detection/
  kinematics.py              # 速度、加速度、方向、曲率等运动学特征
  candidate_generator.py     # 规则候选生成
  detector.py                # 统一检测接口
  visual_features.py         # 视频 patch 与运动历史图特征
  bouncenet.py               # BounceNet 分类网络
  dataset.py                 # BounceNet 训练数据集
  labeling/                  # 半自动标注工具

快速使用:

from bounce_detection import BounceDetector

detector = BounceDetector()
events = detector.detect_from_csv("data/test/match1/csv/1_05_02_ball.csv")

for event in events:
    print(event["frame"], event["event_type"], event["rule"])

事件类型:

事件 标识 说明
落地点 landing 球落地、停止或轨迹结束
击球点 hit 球被击打后方向或速度突变
出画面 out_of_frame 球飞出画面边缘
非事件 none BounceNet 过滤后的误检

规则检测关注高召回率,典型规则包括:

  • speed_drop:速度骤降。
  • trajectory_end:轨迹结束。
  • visibility_drop:可见性消失。
  • vy_reversal / vx_reversal:速度方向反转。
  • acceleration_peak:加速度峰值。
  • y_local_max / speed_local_max:局部极值辅助规则。

详细说明见:

半自动标注工具

标注工具用于把规则检测结果快速转成可训练标签。默认会在新文件上运行 Phase 1 检测进行候选预填充,也可以关闭自动检测进行纯手工标注。

单文件标注:

python labeling_launcher.py --csv data/test/match1/csv/1_05_02_ball.csv

指定视频:

python labeling_launcher.py \
  --csv data/test/match1/csv/1_05_02_ball.csv \
  --video data/test/match1/video/1_05_02.mp4

批量标注一个 match:

python labeling_launcher.py --match_dir data/test/match1

关闭 Phase 1 自动预填充:

python labeling_launcher.py \
  --csv data/test/match1/csv/1_05_02_ball.csv \
  --no-auto-detect

导出和统计:

python labeling_launcher.py --export data/test/match1/labels --output training_events.csv
python labeling_launcher.py --stats data/test/match1/labels

标注结果默认保存为:

data/<split>/<match>/labels/*_labels.json

详见 bounce_detection/labeling/README_labeling_tool.md

BounceNet 事件分类器

BounceNet 是二阶段事件分类器,用于过滤规则候选中的误检,并在 landing/hit/none 之间重新分类。

支持三种模式:

模式 输入 适用场景
trajectory_only 轨迹窗口 快速、无需视频,默认推荐
visual_only 视频 patch 依赖局部视觉线索
fusion 轨迹 + 视频 信息最完整,训练成本更高

训练:

python train_bouncenet.py \
  --label_dir labels/ \
  --mode trajectory_only \
  --epochs 100 \
  --batch_size 32 \
  --lr 1e-3 \
  --early_stopping 15 \
  --save_dir ckpts/bouncenet

推理集成:

from bounce_detection import BounceDetector

detector = BounceDetector(bouncenet_ckpt="ckpts/bouncenet/best.pt")
events = detector.detect(x, y, visibility, frames=frames, use_bouncenet=True)

训练细节见 README_train_bouncenet.md

模型训练与评估

训练 TrackNet:

python train.py \
  --model_name TrackNet \
  --seq_len 8 \
  --epochs 30 \
  --batch_size 10 \
  --bg_mode concat \
  --alpha 0.5 \
  --save_dir exp \
  --verbose

生成 InpaintNet 训练用轨迹和 mask:

python generate_mask_data.py \
  --tracknet_file ckpts/TrackNet_best.pt \
  --batch_size 16

训练 InpaintNet:

python train.py \
  --model_name InpaintNet \
  --seq_len 16 \
  --epochs 300 \
  --batch_size 32 \
  --lr_scheduler StepLR \
  --mask_ratio 0.3 \
  --save_dir exp \
  --verbose

评估 TrackNetV3:

python generate_mask_data.py --tracknet_file ckpts/TrackNet_best.pt --split_list test
python test.py --tracknet_file ckpts/TrackNet_best.pt --inpaintnet_file ckpts/InpaintNet_best.pt --save_dir eval

仅评估 TrackNet:

python test.py --tracknet_file ckpts/TrackNet_best.pt --save_dir eval

生成用于误差分析界面的详细预测:

python test.py \
  --tracknet_file ckpts/TrackNet_best.pt \
  --inpaintnet_file ckpts/InpaintNet_best.pt \
  --save_dir eval \
  --output_pred

误差分析界面

error_analysis.py 提供 Dash 可视化界面,用于比较不同模型或不同结果文件的逐帧误差。

python error_analysis.py --split test --host 127.0.0.1

界面支持:

  • rally 级误差分布查看。
  • 逐帧预测与标签对比。
  • 不同结果文件之间的可视化比较。
  • test.py --output_pred 生成的 JSON 联动。
Error analysis UI

项目结构

.
├── model.py                         # TrackNet / InpaintNet 模型
├── dataset.py                       # 数据集与序列构造
├── train.py                         # TrackNet / InpaintNet 训练
├── predict.py                       # 视频轨迹推理
├── test.py                          # 评估与误差分析数据导出
├── landing_predictor.py             # 坐标型落点预测
├── batch_predict_landings.py        # 批量落点预测与评估
├── visualize_predictions.py         # 落点预测视频可视化
├── predict_bounce.py                # 带事件结果的视频输出
├── labeling_launcher.py             # 标注工具入口
├── train_bouncenet.py               # BounceNet 训练入口
├── error_analysis.py                # Dash 误差分析界面
├── bounce_detection/                # 落地/击球事件检测模块
├── utils/                           # 通用函数、指标、可视化
└── figure/                          # README 与论文图示

原始 TrackNetV3 性能参考

以下为原 TrackNetV3 在 Shuttlecock Trajectory Dataset test split 上报告的结果,用作轨迹检测主干的背景参考:

Model Accuracy Precision Recall F1 FPS
YOLOv7 57.82% 78.53% 59.96% 68.00% 34.77
TrackNetV2 94.98% 99.64% 94.56% 97.03% 27.70
TrackNetV3 97.51% 97.79% 99.33% 98.56% 25.11

本项目新增的落点预测和事件检测模块是后处理扩展,评估方式与原表中的逐帧轨迹检测指标不同,应分别查看对应输出的 metrics.jsondetail.json 或标注评估结果。

参考

License

本仓库保留原项目许可证。详见 LICENSE

About

TrackNet-V3-based-Badminton是一个基于深度学习的羽毛球轨迹检测与事件分析系统,旨在通过增强的轨迹预测和后处理模块,实现对羽毛球比赛中关键事件(如落地点和击球点)的精准检测。该项目在原始 TrackNetV3 的基础上,新增了多个功能模块,包括 BounceNet 事件分类器和半自动标注工具,进一步提升了系统的实用性和智能化水平。

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages