diff --git a/.gitignore b/.gitignore
index 65bcdef8..19840f80 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,6 +27,8 @@ src/dbgpt-hub-gql/wandb/*
!src/dbgpt-hub-gql/dbgpt_hub_gql/data/tugraph-db-example
!src/dbgpt-hub-gql/dbgpt_hub_gql/data/dataset_info.json
!src/dbgpt-hub-gql/dbgpt_hub_gql/data/example_text2sql.json
+# Ignore all servers' dataset folders under eval/evaluator/impl
+src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/datasets
# Ignore everything under dbgpt_hub_sql/ouput/ except the adapter directory
src/dbgpt-hub-sql/dbgpt_hub_sql/output/
diff --git a/src/dbgpt-hub-gql/README.zh.md b/src/dbgpt-hub-gql/README.zh.md
index c10cc04b..36dec994 100644
--- a/src/dbgpt-hub-gql/README.zh.md
+++ b/src/dbgpt-hub-gql/README.zh.md
@@ -12,22 +12,25 @@
Method |
Similarity |
Grammar |
+ Execution |
|
|
|
base |
- 0.769 |
- 0.703 |
+ 0.674 |
+ 0.653 |
+ 0.037 |
Cypher (tugraph-db) |
TuGraph-DB Cypher数据集 |
- CodeLlama-7B-Instruct |
+ CodeLlama-7B-Instruct |
lora |
- 0.928 |
- 0.946 |
+ 0.922 |
+ 0.987 |
+ 0.507 |
|
@@ -36,6 +39,7 @@
base |
0.493 |
0.002 |
+ none |
GQL(tugraph-analytics) |
@@ -44,6 +48,25 @@
lora |
0.935 |
0.984 |
+ none |
+
+
+ |
+ |
+ |
+ base |
+ 0.769 |
+ 0.703 |
+ 0.000 |
+
+
+ Cypher (tugraph-db-example) |
+ TuGraph-DB Cypher example数据集 |
+ CodeLlama-7B-Instruct |
+ lora |
+ 0.928 |
+ 0.946 |
+ 0.476 |
@@ -63,6 +86,8 @@
- [3.5、模型评估](#35模型评估)
- [3.5.1、文本相似度评估](#351文本相似度评估)
- [3.5.2、语法正确性评估](#352语法正确性评估)
+ - [3.5.3、执行结果一致性评估](#353执行结果一致性评估)
+ - [3.5.3.1、tugraph-db](#3531tugraph-db)
- [3.6、模型权重合并](#36模型权重合并)
# 一、简介
@@ -79,7 +104,7 @@ DB-GPT-GQL不仅支持了基于多个大模型的微调、预测流程,在翻
### 2.1、数据集
-本项目样例数据集为`Cypher(tugraph-db)`,其中包含tugraph-db提供的,可在tugraph-db上可执行的185条语料,存放在`/dbgpt_hub_gql/data/tugraph-db-example`文件夹中,当前可使用的数据集如下:
+本项目样例数据集为`Cypher(tugraph-db-example)`,其中包含tugraph-db提供的,可在tugraph-db上可执行的185条语料,存放在`/dbgpt_hub_gql/data/tugraph-db-example`文件夹中,仅作测试运行使用,当前可使用的完整数据集如下:
- [Cypher(tugraph-db)](https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/tugraph-db/tugraph-db.zip): 符合tugraph-db的Cypher语法的数据集,采用“ [语法制导的语料生成策略](https://mp.weixin.qq.com/s/rZdj8TEoHZg_f4C-V4lq2A)”,将查询语言模板结合多样化的schema生成查询语言,并使用大模型泛化与之对应的自然语言问题描述后生成的语料。[语料生成框架](https://github.com/TuGraph-contrib/Awesome-Text2GQL)现已开源,欢迎参与共建。
@@ -120,43 +145,39 @@ DB-GPT-GQL目前已经支持的base模型有:
### 3.1、环境准备
-克隆项目并创建 conda 环境,
```bash
+# 克隆项目并创建 conda 环境
git clone https://github.com/eosphoros-ai/DB-GPT-Hub.git
cd DB-GPT-Hub
conda create -n dbgpt_hub_gql python=3.10
conda activate dbgpt_hub_gql
-```
-进入DB-GPT-GQL项目目录,并使用poetry安装依赖
-```bash
+# 进入DB-GPT-GQL项目目录,并使用poetry安装依赖
cd src/dbgpt-hub-gql
pip install -e .
+
+# 创建输出及日志目录
+mkdir dbgpt_hub_gql/output
+mkdir dbgpt_hub_gql/output/logs
+mkdir dbgpt_hub_gql/output/pred
```
### 3.2、模型准备
-创建并进入codellama模型存放目录
-```bash
-mkdir codellama
-cd ./codellama
-```
-在`codellama`文件夹下创建`download.py`文件并将如下内容复制进入python文件中
+创建`download.py`文件并将如下内容复制进入python文件中
```python
from modelscope import snapshot_download
-model_dir = snapshot_download("AI-ModelScope/CodeLlama-7b-Instruct-hf")
+model_dir = snapshot_download("AI-ModelScope/CodeLlama-7b-Instruct-hf", cache_dir='./')
```
-
-安装python依赖并下载模型
+使用download.py下载模型
```bash
+# 安装python依赖并下载模型
pip install modelscope
python3 download.py
-```
-下载完成后,将模型文件软链接到`codellama`目录下
-```bash
-ln -s /root/.cache/modelscope/hub/AI-ModelScope/CodeLlama-7b-Instruct-hf ./
+# 下载完成后,将AI-ModelScope/CodeLlama-7b-Instruct-hf重命名为codellama/CodeLlama-7b-Instruct-hf
+mv ./AI-ModelScope ./codellama
```
### 3.3、模型微调
@@ -209,14 +230,14 @@ sh dbgpt_hub_gql/scripts/predict_sft.sh
### 3.5、模型评估
-目前版本支持两种预测结果评估方法,第一种是基于Jaro–Winkler distance的文本相似度评估,第二种是基于`.g4`语法文件或图数据库现有语法解析器的语法正确性评估。
+目前版本支持三种预测结果评估方法,第一种是基于Jaro–Winkler distance的文本相似度评估,第二种是基于`.g4`语法文件或图数据库现有语法解析器的语法正确性评估,第三种则是基于查询语句返回结果比较的执行结果一致性评估。
#### 3.5.1、文本相似度评估
文本相似度评估直接统计预测结果与标准结果的Jaro–Winkler distance
```bash
-python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/gold_dev.txt --etype similarity
+python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/dev.json --etype similarity
```
#### 3.5.2、语法正确性评估
@@ -224,9 +245,49 @@ python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tug
`tugraph-db-example`是符合`tugraph-db`的LCypher语法规则的语料数据集,语法正确性评估使用ANTLR4工具,基于`./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/Lcypher.g4`文件生成了语法解析器,用于评估模型预测结果的语法正确性。
```bash
-python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/gold_dev.txt --etype grammar --impl tugraph-db
+python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/dev.json --etype grammar --impl tugraph-db
```
+#### 3.5.3、执行结果一致性评估
+
+当前版本仅支持tugraph-db上的执行结果一致性评估,暂未支持tugraph-analytics上的执行结果一致性评估。
+
+##### 3.5.3.1、tugraph-db
+
+执行结果一致性评估需要实际运行tugraph-db,方便起见可以下载tugraph官方提供的runtime镜像。
+
+```bash
+# 下载并解压tugraph执行测试数据集
+wget -P ./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/tugraph-db-server/datasets.zip
+
+unzip -d ./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db ./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/datasets.zip
+
+# 下载并启动tugraph的runtime镜像
+docker pull tugraph/tugraph-runtime-centos7
+
+docker run -it -v ./:/root/dbgpt-hub-gql --name=tugraph-db_evaluation tugraph/tugraph-runtime-centos7 /bin/bash
+
+cd /root
+
+# 安装 miniconda
+mkdir -p ~/miniconda3
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh --no-check-certificate
+bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
+rm ~/miniconda3/miniconda.sh
+source ~/miniconda3/bin/activate
+
+# 准备运行环境
+cd /root/dbgpt-hub-gql/
+conda create -n dbgpt_hub_gql python=3.10
+conda activate dbgpt_hub_gql
+pip install -e .
+
+# 执行评测
+python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/dev.json --etype execution --impl tugraph-db
+```
+
+
+
### 3.6、模型权重合并
如果你需要将训练的基础模型和微调的Peft模块的权重合并,导出一个完整的模型。则运行如下模型导出脚本:
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation.py
index 8b172b6a..c8a75238 100644
--- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation.py
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation.py
@@ -7,19 +7,26 @@
from evaluator.evaluator import Evaluator
from evaluator.similarity_evaluator import SimilarityEvaluator
+# print(f"{os.path.dirname(os.path.abspath(__file__))}/evaluator/impl/tugraph-db")
+# sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/evaluator/impl/tugraph-db")
+
def evaluate(gold, predict, etype, impl):
log_file = open(f"{os.path.dirname(__file__)}/../output/logs/eval.log", "w")
log_lines = []
with open(gold) as f:
+ content = f.read()
+ gold_list = json.loads(content)
gseq_one = []
- for l in f.readlines():
- if len(l.strip()) == 0:
+ db_id_list = []
+ for gold_dic in gold_list:
+ if len(gold_dic["output"].strip()) == 0:
# when some predict is none, support it can continue work
gseq_one.append("no out")
else:
- gseq_one.append(l.strip())
+ gseq_one.append(gold_dic["output"].strip())
+ db_id_list.append(gold_dic["db_id"].strip())
with open(predict) as f:
plist = []
@@ -28,7 +35,6 @@ def evaluate(gold, predict, etype, impl):
if len(l.strip()) == 0:
# when some predict is none, support it can continue work
pseq_one.append("no out")
-
else:
pseq_one.append(l.strip())
@@ -38,16 +44,24 @@ def evaluate(gold, predict, etype, impl):
score_total = 0
if etype == "similarity":
+ # jaro-winkler distance score
evaluator = SimilarityEvaluator()
elif etype == "grammar":
+ # grammar check result, 1 if pass, 0 if fail
model_path = f"evaluator.impl.{impl}.grammar_evaluator"
m = importlib.import_module(model_path)
GrammarEvaluator = getattr(m, "GrammarEvaluator")
evaluator = GrammarEvaluator()
+ elif etype == "execution":
+ # execution result, 1 if same, 0 if not same
+ model_path = f"evaluator.impl.{impl}.execution_evaluator"
+ m = importlib.import_module(model_path)
+ ExecutionEvaluator = getattr(m, "ExecutionEvaluator")
+ evaluator = ExecutionEvaluator()
total = 0
for i in range(len(gseq_one)):
- score = evaluator.evaluate(pseq_one[i], gseq_one[i])
+ score = evaluator.evaluate(pseq_one[i], gseq_one[i], db_id_list[i])
if score != -1:
score_total += score
total += 1
@@ -57,6 +71,7 @@ def evaluate(gold, predict, etype, impl):
tmp_log["score"] = score
log_lines.append(tmp_log)
+
json.dump(log_lines, log_file, ensure_ascii=False, indent=4)
tb = pt.PrettyTable()
@@ -83,7 +98,7 @@ def evaluate(gold, predict, etype, impl):
type=str,
default="similarity",
help="evaluation type, exec for test suite accuracy, match for the original exact set match accuracy",
- choices=("similarity", "grammar"),
+ choices=("similarity", "grammar", "execution"),
)
parser.add_argument(
"--impl",
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/evaluator.py
index 8b97babf..4a0c24dc 100644
--- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/evaluator.py
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/evaluator.py
@@ -1,3 +1,3 @@
class Evaluator:
- def evaluate(self, query_predict, query_gold):
+ def evaluate(self, query_predict, query_gold, db_id):
return 1
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-analytics/grammar_evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-analytics/grammar_evaluator.py
index 9d474c1a..9c52e90e 100644
--- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-analytics/grammar_evaluator.py
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-analytics/grammar_evaluator.py
@@ -16,7 +16,7 @@ def __init__(self):
JDClass = jpype.JClass("com.antgroup.geaflow.dsl.parser.GeaFlowDSLParser")
self.jd = JDClass()
- def evaluate(self, query_predict, query_gold):
+ def evaluate(self, query_predict, query_gold, db_id):
try:
result_gold = self.jd.parseStatement(query_gold)
try:
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py
new file mode 100644
index 00000000..b81fad6e
--- /dev/null
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py
@@ -0,0 +1,164 @@
+import sys
+import logging
+import os.path
+import os
+import ctypes
+import subprocess
+import time
+import json
+import signal
+import jaro
+from neo4j import GraphDatabase
+
+current_dir = os.path.dirname(__file__)
+
+def handle_timeout(sig, frame):
+ raise TimeoutError('took too long')
+
+signal.signal(signal.SIGALRM, handle_timeout)
+
+class ExecutionEvaluator:
+ def __init__(self):
+ self.log = open('./exc_eval.log', 'w+')
+ # import datasets to 2 different data folders
+ dataset_list = os.listdir(f"{current_dir}/datasets")
+ self.dataset_list = dataset_list
+ try:
+ # iterate through all dataset folder under ./datasets
+ for dataset in dataset_list:
+ # import data to data folder for ground truth with cli command
+ self.process = subprocess.run([
+ 'sh',
+ f'{current_dir}/datasets/{dataset}/import.sh',
+ f'{current_dir}/server/server_gold/lgraph_db', f'{dataset}'
+ ], stdout=self.log, stderr=self.log, close_fds=True)
+
+ # import data to data folder for predicted result with cli command
+ self.process = subprocess.run([
+ 'sh',
+ f'{current_dir}/datasets/{dataset}/import.sh',
+ f'{current_dir}/server/server_predict/lgraph_db', f'{dataset}'
+ ], stdout=self.log, stderr=self.log, close_fds=True)
+ except Exception as e:
+ logging.debug(e)
+
+ # start 2 separate tugraph-db servers
+ try:
+ # start server for ground truth with cli command
+ self.process = subprocess.run([
+ 'sh',
+ f'{current_dir}/server/server_gold/start.sh'
+ ], stdout=self.log, stderr=self.log, close_fds=True)
+ time.sleep(10)
+
+ # start server for predicted result with cli command
+ self.process = subprocess.run([
+ 'sh',
+ f'{current_dir}/server/server_predict/start.sh'
+ ], stdout=self.log, stderr=self.log, close_fds=True)
+ time.sleep(10)
+ except Exception as e:
+ logging.debug(e)
+
+ # setup driver for ground truth
+ self.url_gold = f"bolt://localhost:9092"
+ self.driver_gold = GraphDatabase.driver(self.url_gold, auth=("admin", "73@TuGraph"))
+ self.driver_gold.verify_connectivity()
+
+ # setup driver for predicted result
+ self.url_predict = f"bolt://localhost:9094"
+ self.driver_predict = GraphDatabase.driver(self.url_predict, auth=("admin", "73@TuGraph"))
+ self.driver_predict.verify_connectivity()
+
+ self.session_pool = {}
+ for dataset in self.dataset_list:
+ self.session_pool[dataset] = []
+ self.session_pool[dataset].append(self.driver_gold.session(database=dataset))
+ self.session_pool[dataset].append(self.driver_predict.session(database=dataset))
+
+ def __del__(self):
+ # stop 2 separate tugraph-db servers
+ try:
+ # stop server for ground truth with cli command
+ self.process = subprocess.run([
+ 'sh',
+ f'{current_dir}/server/server_gold/stop.sh'
+ ], stdout=self.log, stderr=self.log, close_fds=True)
+
+ # stop server for predicted result with cli command
+ self.process = subprocess.run([
+ 'sh',
+ f'{current_dir}/server/server_predict/stop.sh'
+ ], stdout=self.log, stderr=self.log, close_fds=True)
+ except Exception as e:
+ logging.debug(e)
+
+ def restart_predict_server(self):
+ try:
+ # restart server for predicted result with cli command
+ self.process = subprocess.run([
+ 'sh',
+ f'{current_dir}/server/server_predict/start.sh'
+ ], stdout=self.log, stderr=self.log, close_fds=True)
+ time.sleep(10)
+
+ # setup driver for predicted result
+ self.driver_predict = GraphDatabase.driver(self.url_predict, auth=("admin", "73@TuGraph"))
+ self.driver_predict.verify_connectivity()
+ for dataset in self.dataset_list:
+ self.session_pool[dataset][1] = self.driver_predict.session(database=dataset)
+ except Exception as e:
+ logging.debug(e)
+
+ def evaluate(self, query_predict, query_gold, db_id):
+ if db_id not in self.session_pool.keys():
+ return -1
+
+ # run cypher on the server for ground truth
+ ret_gold = True
+ try:
+ res_gold = self.session_pool[db_id][0].run(query_gold).data()
+ except Exception as e:
+ ret_gold = False
+ res_gold = e
+
+ # run cypher on the server for predicted result
+ ret_predict = True
+ try:
+ signal.alarm(10)
+ res_predict = self.session_pool[db_id][1].run(query_predict).data()
+ signal.alarm(0)
+ except TimeoutError as e:
+ ret_predict = False
+ res_predict = e
+ self.restart_predict_server()
+ except Exception as e:
+ ret_predict = False
+ res_predict = e
+ if "Couldn't connect to localhost:9094 (resolved to ()):" in str(e):
+ self.restart_predict_server()
+
+ if ret_gold == False:
+ return -1
+ else:
+ if ret_predict == True:
+ if "SKIP" in query_gold or "LIMIT" in query_gold:
+ # if SKIP or LIMIT in cypher, only compare the size of query result
+ if len(res_gold) == len(res_predict):
+ return 1
+ else:
+ return 0
+ else:
+ # else, sort all query results then compare if two results are same
+ for i in range(len(res_gold)):
+ res_gold[i] = str(res_gold[i])
+ res_gold.sort()
+ for i in range(len(res_predict)):
+ res_predict[i] = str(res_predict[i])
+ res_predict.sort()
+ if res_predict == res_gold:
+ return 1
+ else:
+ return 0
+ else:
+ return 0
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/grammar_evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/grammar_evaluator.py
index 7c6331da..f1bbd7ba 100644
--- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/grammar_evaluator.py
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/grammar_evaluator.py
@@ -18,7 +18,7 @@ def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
class GrammarEvaluator:
- def evaluate(self, query_predict, query_gold):
+ def evaluate(self, query_predict, query_gold, db_id):
error_listener = MyErrorListener()
try:
input_stream = InputStream(query_gold)
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/lgraph_standalone.json b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/lgraph_standalone.json
new file mode 100644
index 00000000..890404aa
--- /dev/null
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/lgraph_standalone.json
@@ -0,0 +1,14 @@
+{
+ "host": "0.0.0.0",
+ "port": 7073,
+ "enable_rpc": true,
+ "rpc_port": 9093,
+ "verbose": 2,
+ "log_dir": "./log",
+ "directory": "./lgraph_db",
+ "bolt_port": "9092",
+ "web": "./output/resource",
+ "ssl_auth": false,
+ "server_key": "./server-key.pem",
+ "server_cert": "./server-cert.pem"
+}
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/start.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/start.sh
new file mode 100644
index 00000000..4bd358e7
--- /dev/null
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/start.sh
@@ -0,0 +1,3 @@
+cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"
+lgraph_server -d stop
+lgraph_server -c ./lgraph_standalone.json -d start
\ No newline at end of file
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/stop.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/stop.sh
new file mode 100644
index 00000000..bf10a3d0
--- /dev/null
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/stop.sh
@@ -0,0 +1,5 @@
+cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"
+lgraph_server -d stop
+rm -rf ./lgraph_db
+rm -rf ./log
+rm -rf ./core.*
\ No newline at end of file
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/lgraph_standalone.json b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/lgraph_standalone.json
new file mode 100644
index 00000000..1c523f24
--- /dev/null
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/lgraph_standalone.json
@@ -0,0 +1,14 @@
+{
+ "host": "0.0.0.0",
+ "port": 7075,
+ "enable_rpc": true,
+ "rpc_port": 9095,
+ "verbose": 2,
+ "log_dir": "./log",
+ "directory": "./lgraph_db",
+ "bolt_port": "9094",
+ "web": "./output/resource",
+ "ssl_auth": false,
+ "server_key": "./server-key.pem",
+ "server_cert": "./server-cert.pem"
+}
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/start.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/start.sh
new file mode 100644
index 00000000..4bd358e7
--- /dev/null
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/start.sh
@@ -0,0 +1,3 @@
+cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"
+lgraph_server -d stop
+lgraph_server -c ./lgraph_standalone.json -d start
\ No newline at end of file
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/stop.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/stop.sh
new file mode 100644
index 00000000..bf10a3d0
--- /dev/null
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/stop.sh
@@ -0,0 +1,5 @@
+cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"
+lgraph_server -d stop
+rm -rf ./lgraph_db
+rm -rf ./log
+rm -rf ./core.*
\ No newline at end of file
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py
index 60f1f64d..67c3e602 100644
--- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py
@@ -2,5 +2,5 @@
class SimilarityEvaluator:
- def evaluate(self, query_predict, query_gold):
+ def evaluate(self, query_predict, query_gold, db_id):
return jaro.jaro_winkler_metric(query_predict, query_gold)
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/adapter.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/adapter.py
index 322ad44c..dbf0b963 100644
--- a/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/adapter.py
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/adapter.py
@@ -3,7 +3,7 @@
import torch
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
-from peft.utils import CONFIG_NAME, WEIGHTS_NAME
+from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
from .config_parser import load_trainable_params
from .loggings import get_logger
@@ -59,7 +59,7 @@ def init_adapter(
if model_args.checkpoint_dir is not None:
assert os.path.exists(
- os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)
+ os.path.join(model_args.checkpoint_dir[0], SAFETENSORS_WEIGHTS_NAME)
), "Provided path ({}) does not contain a LoRA weight.".format(
model_args.checkpoint_dir[0]
)
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/load_tokenizer.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/load_tokenizer.py
index 5bd16fc3..38a414d3 100644
--- a/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/load_tokenizer.py
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/load_tokenizer.py
@@ -15,7 +15,7 @@
PreTrainedTokenizer,
PreTrainedTokenizerBase,
)
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from transformers.utils import cached_file, check_min_version
from transformers.utils.versions import require_version
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/scripts/predict_sft.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/scripts/predict_sft.sh
index 7b1b2db9..974f8ef4 100644
--- a/src/dbgpt-hub-gql/dbgpt_hub_gql/scripts/predict_sft.sh
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/scripts/predict_sft.sh
@@ -13,6 +13,14 @@ echo " Pred Start time: $(date -d @$start_time +'%Y-%m-%d %H:%M:%S')" >>${pred_l
# --checkpoint_dir dbgpt_hub_gql/output/adapter/CodeLlama-7b-gql-lora \
# --predicted_out_filename dbgpt_hub_gql/output/pred/tugraph_analytics_dev.txt >> ${pred_log}
+# CUDA_VISIBLE_DEVICES=0,1 python dbgpt_hub_gql/predict/predict.py \
+# --model_name_or_path codellama/CodeLlama-7b-Instruct-hf \
+# --template llama2 \
+# --finetuning_type lora \
+# --predicted_input_filename dbgpt_hub_gql/data/tugraph-db/dev.json \
+# --checkpoint_dir dbgpt_hub_gql/output/adapter/CodeLlama-7b-gql-lora \
+# --predicted_out_filename dbgpt_hub_gql/output/pred/tugraph_analytics_dev.txt >> ${pred_log}
+
CUDA_VISIBLE_DEVICES=0,1 python dbgpt_hub_gql/predict/predict.py \
--model_name_or_path codellama/CodeLlama-7b-Instruct-hf \
--template llama2 \
diff --git a/src/dbgpt-hub-gql/setup.py b/src/dbgpt-hub-gql/setup.py
index 4034de6e..cd05c88c 100644
--- a/src/dbgpt-hub-gql/setup.py
+++ b/src/dbgpt-hub-gql/setup.py
@@ -22,7 +22,7 @@ def unique_extras(self) -> dict[str, list[str]]:
def core_dependencies():
setup_spec.extras["core"] = [
"transformers>=4.41.2",
- "datasets>=2.14.6",
+ "datasets>=2.14.7",
"tiktoken>=0.7.0",
"torch>=2.0.0",
"peft>=0.4.0",
@@ -76,6 +76,8 @@ def core_dependencies():
"jaro-winkler==2.0.3",
"antlr4-python3-runtime==4.13.2",
"JPype1==1.5.0",
+ "neo4j>=5.26.0",
+ "protobuf==3.20.*",
]