Support execution result evaluation for text2gql #302

Open · wants to merge 12 commits into base: main
2 changes: 2 additions & 0 deletions .gitignore
@@ -27,6 +27,8 @@ src/dbgpt-hub-gql/wandb/*
!src/dbgpt-hub-gql/dbgpt_hub_gql/data/tugraph-db-example
!src/dbgpt-hub-gql/dbgpt_hub_gql/data/dataset_info.json
!src/dbgpt-hub-gql/dbgpt_hub_gql/data/example_text2sql.json
# Ignore every server's dataset folder under eval/evaluator/impl
src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/datasets

# Ignore everything under dbgpt_hub_sql/output/ except the adapter directory
src/dbgpt-hub-sql/dbgpt_hub_sql/output/
113 changes: 87 additions & 26 deletions src/dbgpt-hub-gql/README.zh.md
@@ -12,22 +12,25 @@
<th>Method</th>
<th>Similarity</th>
<th>Grammar</th>
<th>Execution</th>
</tr>
<tr >
<td></td>
<td></td>
<td></td>
<td>base</td>
<td>0.769</td>
<td>0.703</td>
<td>0.674</td>
<td>0.653</td>
<td>0.037</td>
</tr>
<tr>
<td>Cypher <a href="https://github.com/TuGraph-family/tugraph-db">(tugraph-db)</a></td>
<td><a href="https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/tugraph-db/tugraph-db.zip">TuGraph-DB Cypher dataset</a></td>
<td><a href="https://huggingface.co/tugraph/CodeLlama-7b-Cypher-hf/tree/1.0">CodeLlama-7B-Instruct</a></td>
<td><a href="https://huggingface.co/tugraph/CodeLlama-7b-Cypher-hf/tree/1.1">CodeLlama-7B-Instruct</a></td>
<td>lora</td>
<td>0.928</td>
<td>0.946</td>
<td>0.922</td>
<td>0.987</td>
<td>0.507</td>
</tr>
<tr >
<td></td>
@@ -36,6 +39,7 @@
<td>base</td>
<td>0.493</td>
<td>0.002</td>
<td>none</td>
</tr>
<tr>
<td>GQL<a href="https://github.com/TuGraph-family/tugraph-analytics">(tugraph-analytics)</a></td>
@@ -44,6 +48,25 @@
<td>lora</td>
<td>0.935</td>
<td>0.984</td>
<td>none</td>
</tr>
<tr >
<td></td>
<td></td>
<td></td>
<td>base</td>
<td>0.769</td>
<td>0.703</td>
<td>0.000</td>
</tr>
<tr>
<td>Cypher <a href="https://github.com/TuGraph-family/tugraph-db">(tugraph-db-example)</a></td>
<td><a href="https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/tugraph-db-example/tugraph-db-example.zip">TuGraph-DB Cypher example dataset</a></td>
<td><a href="https://huggingface.co/tugraph/CodeLlama-7b-Cypher-hf/tree/1.0">CodeLlama-7B-Instruct</a></td>
<td>lora</td>
<td>0.928</td>
<td>0.946</td>
<td>0.476</td>
</tr>
</table>

@@ -63,6 +86,8 @@
- [3.5. Model Evaluation](#35模型评估)
- [3.5.1. Text Similarity Evaluation](#351文本相似度评估)
- [3.5.2. Grammar Correctness Evaluation](#352语法正确性评估)
- [3.5.3. Execution Result Consistency Evaluation](#353执行结果一致性评估)
- [3.5.3.1. tugraph-db](#3531tugraph-db)
- [3.6. Model Weight Merging](#36模型权重合并)

# 1. Introduction
@@ -79,7 +104,7 @@ DB-GPT-GQL不仅支持了基于多个大模型的微调、预测流程,在翻

### 2.1. Datasets

The example dataset of this project is `Cypher(tugraph-db)`, containing 185 corpus entries provided by tugraph-db and executable on tugraph-db, stored in the `/dbgpt_hub_gql/data/tugraph-db-example` folder. The currently available datasets are:
The example dataset of this project is `Cypher(tugraph-db-example)`, containing 185 corpus entries provided by tugraph-db and executable on tugraph-db, stored in the `/dbgpt_hub_gql/data/tugraph-db-example` folder; it is intended for test runs only. The complete datasets currently available are:

- [Cypher(tugraph-db)](https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/tugraph-db/tugraph-db.zip): a dataset conforming to tugraph-db's Cypher syntax. It was built with the "[grammar-directed corpus generation strategy](https://mp.weixin.qq.com/s/rZdj8TEoHZg_f4C-V4lq2A)", which combines query-language templates with diversified schemas to generate queries, then uses an LLM to generalize the corresponding natural-language question descriptions. The [corpus generation framework](https://github.com/TuGraph-contrib/Awesome-Text2GQL) is now open source; contributions are welcome.
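The corpus and gold files are JSON arrays. A hedged sketch of reading one entry, based on how the evaluation script in this PR consumes them: each entry carries at least an `output` (the gold query) and a `db_id` (the graph to run it against); the `input` field shown here is an assumption about the remaining fields.

```python
# Hypothetical sample entry; only "output" and "db_id" are confirmed by the
# evaluation script, other fields are illustrative assumptions.
import json

sample = '[{"input": "Count all nodes.", "output": "MATCH (n) RETURN count(n)", "db_id": "movie"}]'
gold_list = json.loads(sample)

# empty gold outputs are replaced by a placeholder so evaluation can continue
gseq_one = [g["output"].strip() or "no out" for g in gold_list]
db_id_list = [g["db_id"].strip() for g in gold_list]
```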

@@ -120,43 +145,39 @@ The base models currently supported by DB-GPT-GQL include:

### 3.1. Environment Setup

Clone the project and create a conda environment:
```bash
# Clone the project and create a conda environment
git clone https://github.com/eosphoros-ai/DB-GPT-Hub.git
cd DB-GPT-Hub
conda create -n dbgpt_hub_gql python=3.10
conda activate dbgpt_hub_gql
```

Enter the DB-GPT-GQL project directory and install the dependencies with `pip`:
```bash
# Enter the DB-GPT-GQL project directory and install the dependencies
cd src/dbgpt-hub-gql
pip install -e .

# Create the output and log directories
mkdir dbgpt_hub_gql/output
mkdir dbgpt_hub_gql/output/logs
mkdir dbgpt_hub_gql/output/pred
```

### 3.2. Model Preparation
Create and enter the codellama model directory:
```bash
mkdir codellama
cd ./codellama
```

Create a `download.py` file under the `codellama` folder and copy the following content into it
Create a `download.py` file and copy the following content into it
```python
from modelscope import snapshot_download

model_dir = snapshot_download("AI-ModelScope/CodeLlama-7b-Instruct-hf")
model_dir = snapshot_download("AI-ModelScope/CodeLlama-7b-Instruct-hf", cache_dir='./')
```

Install the Python dependency and download the model
Download the model with download.py
```bash
# Install the Python dependency and download the model
pip install modelscope
python3 download.py
```

After the download finishes, symlink the model files into the `codellama` directory
```bash
ln -s /root/.cache/modelscope/hub/AI-ModelScope/CodeLlama-7b-Instruct-hf ./
# After the download finishes, rename AI-ModelScope/CodeLlama-7b-Instruct-hf to codellama/CodeLlama-7b-Instruct-hf
mv ./AI-ModelScope ./codellama
```

### 3.3. Model Fine-tuning
@@ -209,24 +230,64 @@ sh dbgpt_hub_gql/scripts/predict_sft.sh

### 3.5. Model Evaluation

The current version supports two evaluation methods for prediction results: the first is text-similarity evaluation based on the Jaro–Winkler distance, and the second is grammar-correctness evaluation based on a `.g4` grammar file or the graph database's existing parser.
The current version supports three evaluation methods for prediction results: the first is text-similarity evaluation based on the Jaro–Winkler distance; the second is grammar-correctness evaluation based on a `.g4` grammar file or the graph database's existing parser; the third is execution-consistency evaluation based on comparing the results returned by executing the queries.

#### 3.5.1. Text Similarity Evaluation

Text similarity evaluation directly computes the Jaro–Winkler distance between the predicted and the gold query.
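The metric can be sketched in pure Python as follows; this is an illustrative reimplementation of the standard Jaro–Winkler formula, not the project's actual code, which may use a library implementation.

```python
# Illustrative pure-Python Jaro-Winkler similarity (score in [0, 1]).

def jaro(s1: str, s2: str) -> float:
    """Jaro similarity: matched characters and transpositions."""
    if s1 == s2:
        return 1.0
    if not s1 or not s2:
        return 0.0
    # characters match if equal and within half the longer length of each other
    window = max(max(len(s1), len(s2)) // 2 - 1, 0)
    s1_flags = [False] * len(s1)
    s2_flags = [False] * len(s2)
    matches = 0
    for i, c in enumerate(s1):
        for j in range(max(0, i - window), min(i + window + 1, len(s2))):
            if not s2_flags[j] and s2[j] == c:
                s1_flags[i] = s2_flags[j] = True
                matches += 1
                break
    if matches == 0:
        return 0.0
    # half the number of matched characters that appear in a different order
    transpositions, k = 0, 0
    for i in range(len(s1)):
        if s1_flags[i]:
            while not s2_flags[k]:
                k += 1
            if s1[i] != s2[k]:
                transpositions += 1
            k += 1
    transpositions //= 2
    m = matches
    return (m / len(s1) + m / len(s2) + (m - transpositions) / m) / 3


def jaro_winkler(s1: str, s2: str, p: float = 0.1, max_prefix: int = 4) -> float:
    """Jaro similarity boosted by a common prefix of up to 4 characters."""
    j = jaro(s1, s2)
    prefix = 0
    for a, b in zip(s1, s2):
        if a != b or prefix == max_prefix:
            break
        prefix += 1
    return j + prefix * p * (1.0 - j)
```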

```bash
python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/gold_dev.txt --etype similarity
python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/dev.json --etype similarity
```

#### 3.5.2. Grammar Correctness Evaluation

`tugraph-db-example` is a corpus dataset conforming to tugraph-db's LCypher grammar. Grammar-correctness evaluation uses the ANTLR4 tool to generate a parser from `./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/Lcypher.g4`, which is then used to check the grammatical correctness of the model's predictions.

```bash
python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/gold_dev.txt --etype grammar --impl tugraph-db
python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/dev.json --etype grammar --impl tugraph-db
```
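Evaluator implementations are loaded dynamically according to `--impl`, and all of them follow a small contract: `evaluate(query_predict, query_gold, db_id)` returns 1 (pass), 0 (fail), or -1 (exclude the sample from the score). The sketch below is hypothetical; `StubGrammarEvaluator` stands in for the real ANTLR4-based parser check and only illustrates the scoring convention and the accumulation loop.

```python
# Hypothetical stub following the evaluator contract; the real grammar
# evaluator parses predictions with the generated Lcypher parser instead.

class StubGrammarEvaluator:
    def evaluate(self, query_predict: str, query_gold: str, db_id: str) -> int:
        if not query_gold.strip():
            return -1  # unusable gold query: skip this sample entirely
        # stand-in check for a real grammar parse
        return 1 if query_predict.strip().upper().startswith("MATCH") else 0


def score(evaluator, preds, golds, db_ids):
    # mirrors the accumulation loop in evaluation.py: -1 samples are skipped
    score_total, total = 0, 0
    for p, g, db in zip(preds, golds, db_ids):
        s = evaluator.evaluate(p, g, db)
        if s != -1:
            score_total += s
            total += 1
    return score_total / total if total else 0.0
```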

#### 3.5.3. Execution Result Consistency Evaluation

The current version only supports execution-result consistency evaluation on tugraph-db; tugraph-analytics is not yet supported.

##### 3.5.3.1. tugraph-db

Execution-result consistency evaluation requires a running tugraph-db; the easiest way is to download the official tugraph runtime image.

```bash
# Download and unzip the tugraph execution test datasets
wget -P ./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/tugraph-db-server/datasets.zip

unzip -d ./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db ./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/datasets.zip

# Pull and start the tugraph runtime image
docker pull tugraph/tugraph-runtime-centos7

docker run -it -v ./:/root/dbgpt-hub-gql --name=tugraph-db_evaluation tugraph/tugraph-runtime-centos7 /bin/bash

cd /root

# Install miniconda
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh --no-check-certificate
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
source ~/miniconda3/bin/activate

# Prepare the runtime environment
cd /root/dbgpt-hub-gql/
conda create -n dbgpt_hub_gql python=3.10
conda activate dbgpt_hub_gql
pip install -e .

# Run the evaluation
python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/dev.json --etype execution --impl tugraph-db
```
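Conceptually, execution evaluation runs both the predicted and the gold query against the same graph and compares their results. A hedged sketch of the comparison step; the JSON row format and the order-insensitive policy here are assumptions for illustration, not the project's actual implementation.

```python
# Assumed comparison step: results are JSON arrays of rows, and row order
# is ignored when deciding whether two executions agree.
import json

def results_match(result_predict: str, result_gold: str) -> int:
    """Return 1 if two query results are equivalent, else 0."""
    try:
        rows_p = json.loads(result_predict)
        rows_g = json.loads(result_gold)
    except (json.JSONDecodeError, TypeError):
        return 0  # unparseable result counts as a mismatch
    # canonicalize each row so ordering differences are not mismatches
    canon = lambda rows: sorted(json.dumps(r, sort_keys=True) for r in rows)
    return 1 if canon(rows_p) == canon(rows_g) else 0
```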



### 3.6. Model Weight Merging

If you need to merge the weights of the trained base model and the fine-tuned PEFT module to export a complete model, run the following model export script:
27 changes: 21 additions & 6 deletions src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation.py
@@ -7,19 +7,26 @@
from evaluator.evaluator import Evaluator
from evaluator.similarity_evaluator import SimilarityEvaluator

# print(f"{os.path.dirname(os.path.abspath(__file__))}/evaluator/impl/tugraph-db")
# sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/evaluator/impl/tugraph-db")


def evaluate(gold, predict, etype, impl):
log_file = open(f"{os.path.dirname(__file__)}/../output/logs/eval.log", "w")
log_lines = []

with open(gold) as f:
content = f.read()
gold_list = json.loads(content)
gseq_one = []
for l in f.readlines():
if len(l.strip()) == 0:
db_id_list = []
for gold_dic in gold_list:
if len(gold_dic["output"].strip()) == 0:
# if the gold output is empty, substitute a placeholder so evaluation can continue
gseq_one.append("no out")
else:
gseq_one.append(l.strip())
gseq_one.append(gold_dic["output"].strip())
db_id_list.append(gold_dic["db_id"].strip())

with open(predict) as f:
plist = []
@@ -28,7 +35,6 @@ def evaluate(gold, predict, etype, impl):
if len(l.strip()) == 0:
# if a prediction is empty, substitute a placeholder so evaluation can continue
pseq_one.append("no out")

else:
pseq_one.append(l.strip())

@@ -38,16 +44,24 @@

score_total = 0
if etype == "similarity":
# Jaro-Winkler distance score
evaluator = SimilarityEvaluator()
elif etype == "grammar":
# grammar check result, 1 if pass, 0 if fail
model_path = f"evaluator.impl.{impl}.grammar_evaluator"
m = importlib.import_module(model_path)
GrammarEvaluator = getattr(m, "GrammarEvaluator")
evaluator = GrammarEvaluator()
elif etype == "execution":
# execution result: 1 if the results match, 0 otherwise
model_path = f"evaluator.impl.{impl}.execution_evaluator"
m = importlib.import_module(model_path)
ExecutionEvaluator = getattr(m, "ExecutionEvaluator")
evaluator = ExecutionEvaluator()

total = 0
for i in range(len(gseq_one)):
score = evaluator.evaluate(pseq_one[i], gseq_one[i])
score = evaluator.evaluate(pseq_one[i], gseq_one[i], db_id_list[i])
if score != -1:
score_total += score
total += 1
@@ -57,6 +71,7 @@
tmp_log["score"] = score
log_lines.append(tmp_log)


json.dump(log_lines, log_file, ensure_ascii=False, indent=4)

tb = pt.PrettyTable()
@@ -83,7 +98,7 @@ def evaluate(gold, predict, etype, impl):
type=str,
default="similarity",
help="evaluation type: similarity, grammar, or execution",
choices=("similarity", "grammar"),
choices=("similarity", "grammar", "execution"),
)
parser.add_argument(
"--impl",
@@ -1,3 +1,3 @@
class Evaluator:
def evaluate(self, query_predict, query_gold):
def evaluate(self, query_predict, query_gold, db_id):
return 1
@@ -16,7 +16,7 @@ def __init__(self):
JDClass = jpype.JClass("com.antgroup.geaflow.dsl.parser.GeaFlowDSLParser")
self.jd = JDClass()

def evaluate(self, query_predict, query_gold):
def evaluate(self, query_predict, query_gold, db_id):
try:
result_gold = self.jd.parseStatement(query_gold)
try: