diff --git a/.gitignore b/.gitignore index 65bcdef8..19840f80 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,8 @@ src/dbgpt-hub-gql/wandb/* !src/dbgpt-hub-gql/dbgpt_hub_gql/data/tugraph-db-example !src/dbgpt-hub-gql/dbgpt_hub_gql/data/dataset_info.json !src/dbgpt-hub-gql/dbgpt_hub_gql/data/example_text2sql.json +# Ignore all server's dataset folder under eval/evaluator/impl +src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/datasets # Ignore everything under dbgpt_hub_sql/ouput/ except the adapter directory src/dbgpt-hub-sql/dbgpt_hub_sql/output/ diff --git a/src/dbgpt-hub-gql/README.zh.md b/src/dbgpt-hub-gql/README.zh.md index c10cc04b..36dec994 100644 --- a/src/dbgpt-hub-gql/README.zh.md +++ b/src/dbgpt-hub-gql/README.zh.md @@ -12,22 +12,25 @@ Method Similarity Grammar + Execution base - 0.769 - 0.703 + 0.674 + 0.653 + 0.037 Cypher (tugraph-db) TuGraph-DB Cypher数据集 - CodeLlama-7B-Instruct + CodeLlama-7B-Instruct lora - 0.928 - 0.946 + 0.922 + 0.987 + 0.507 @@ -36,6 +39,7 @@ base 0.493 0.002 + none GQL(tugraph-analytics) @@ -44,6 +48,25 @@ lora 0.935 0.984 + none + + + + + + base + 0.769 + 0.703 + 0.000 + + + Cypher (tugraph-db-example) + TuGraph-DB Cypher example数据集 + CodeLlama-7B-Instruct + lora + 0.928 + 0.946 + 0.476 @@ -63,6 +86,8 @@ - [3.5、模型评估](#35模型评估) - [3.5.1、文本相似度评估](#351文本相似度评估) - [3.5.2、语法正确性评估](#352语法正确性评估) + - [3.5.3、执行结果一致性评估](#353执行结果一致性评估) + - [3.5.3.1、tugraph-db](#3531tugraph-db) - [3.6、模型权重合并](#36模型权重合并) # 一、简介 @@ -79,7 +104,7 @@ DB-GPT-GQL不仅支持了基于多个大模型的微调、预测流程,在翻 ### 2.1、数据集 -本项目样例数据集为`Cypher(tugraph-db)`,其中包含tugraph-db提供的,可在tugraph-db上可执行的185条语料,存放在`/dbgpt_hub_gql/data/tugraph-db-example`文件夹中,当前可使用的数据集如下: +本项目样例数据集为`Cypher(tugraph-db-example)`,其中包含tugraph-db提供的,可在tugraph-db上可执行的185条语料,存放在`/dbgpt_hub_gql/data/tugraph-db-example`文件夹中,仅作测试运行使用,当前可使用的完整数据集如下: - [Cypher(tugraph-db)](https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/tugraph-db/tugraph-db.zip): 符合tugraph-db的Cypher语法的数据集,采用“ 
[语法制导的语料生成策略](https://mp.weixin.qq.com/s/rZdj8TEoHZg_f4C-V4lq2A)”,将查询语言模板结合多样化的schema生成查询语言,并使用大模型泛化与之对应的自然语言问题描述后生成的语料。[语料生成框架](https://github.com/TuGraph-contrib/Awesome-Text2GQL)现已开源,欢迎参与共建。 @@ -120,43 +145,39 @@ DB-GPT-GQL目前已经支持的base模型有: ### 3.1、环境准备 -克隆项目并创建 conda 环境, ```bash +# 克隆项目并创建 conda 环境 git clone https://github.com/eosphoros-ai/DB-GPT-Hub.git cd DB-GPT-Hub conda create -n dbgpt_hub_gql python=3.10 conda activate dbgpt_hub_gql -``` -进入DB-GPT-GQL项目目录,并使用poetry安装依赖 -```bash +# 进入DB-GPT-GQL项目目录,并使用poetry安装依赖 cd src/dbgpt-hub-gql pip install -e . + +# 创建输出及日志目录 +mkdir dbgpt_hub_gql/output +mkdir dbgpt_hub_gql/output/logs +mkdir dbgpt_hub_gql/output/pred ``` ### 3.2、模型准备 -创建并进入codellama模型存放目录 -```bash -mkdir codellama -cd ./codellama -``` -在`codellama`文件夹下创建`download.py`文件并将如下内容复制进入python文件中 +创建`download.py`文件并将如下内容复制进入python文件中 ```python from modelscope import snapshot_download -model_dir = snapshot_download("AI-ModelScope/CodeLlama-7b-Instruct-hf") +model_dir = snapshot_download("AI-ModelScope/CodeLlama-7b-Instruct-hf", cache_dir='./') ``` - -安装python依赖并下载模型 +使用download.py下载模型 ```bash +# 安装python依赖并下载模型 pip install modelscope python3 download.py -``` -下载完成后,将模型文件软链接到`codellama`目录下 -```bash -ln -s /root/.cache/modelscope/hub/AI-ModelScope/CodeLlama-7b-Instruct-hf ./ +# 下载完成后,将AI-ModelScope/CodeLlama-7b-Instruct-hf重命名为codellama/CodeLlama-7b-Instruct-hf +mv ./AI-ModelScope ./codellama ``` ### 3.3、模型微调 @@ -209,14 +230,14 @@ sh dbgpt_hub_gql/scripts/predict_sft.sh ### 3.5、模型评估 -目前版本支持两种预测结果评估方法,第一种是基于Jaro–Winkler distance的文本相似度评估,第二种是基于`.g4`语法文件或图数据库现有语法解析器的语法正确性评估。 +目前版本支持三种预测结果评估方法,第一种是基于Jaro–Winkler distance的文本相似度评估,第二种是基于`.g4`语法文件或图数据库现有语法解析器的语法正确性评估,第三种则是基于查询语句返回结果比较大的执行一致性评估。 #### 3.5.1、文本相似度评估 文本相似度评估直接统计预测结果与标准结果的Jaro–Winkler distance ```bash -python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/gold_dev.txt --etype similarity +python 
dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/dev.json --etype similarity ``` #### 3.5.2、语法正确性评估 @@ -224,9 +245,49 @@ python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tug `tugraph-db-example`是符合`tugraph-db`的LCypher语法规则的语料数据集,语法正确性评估使用ANTLR4工具,基于`./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/Lcypher.g4`文件生成了语法解析器,用于评估模型预测结果的语法正确性。 ```bash -python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/gold_dev.txt --etype grammar --impl tugraph-db +python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/dev.json --etype grammar --impl tugraph-db ``` +#### 3.5.3、执行结果一致性评估 + +当前版本仅支持tugraph-db上的执行结果一致性评估,暂未支持tugraph-analytics上的执行结果一致性评估。 + +##### 3.5.3.1、tugraph-db + +执行结果一致性评估需要实际运行tugraph-db,方便起见可以下载tugraph官方提供的runtime镜像。 + +```bash +# 下载并解压tugraph执行测试数据集 +wget -P ./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/tugraph-db-server/datasets.zip + +unzip -d ./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db ./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/datasets.zip + +# 下载并启动tugraph的runtime镜像 +docker pull tugraph/tugraph-runtime-centos7 + +docker run -it -v ./:/root/dbgpt-hub-gql --name=tugraph-db_evaluation tugraph/tugraph-runtime-centos7 /bin/bash + +cd /root + +# 安装 miniconda +mkdir -p ~/miniconda3 +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh --no-check-certificate +bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 +rm ~/miniconda3/miniconda.sh +source ~/miniconda3/bin/activate + +# 准备运行环境 +cd /root/dbgpt-hub-gql/ +conda create -n dbgpt_hub_gql python=3.10 +conda activate dbgpt_hub_gql +pip install -e . 
+ +# 执行评测 +python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/dev.json --etype execution --impl tugraph-db +``` + + + ### 3.6、模型权重合并 如果你需要将训练的基础模型和微调的Peft模块的权重合并,导出一个完整的模型。则运行如下模型导出脚本: diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation.py index 8b172b6a..c8a75238 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation.py +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation.py @@ -7,19 +7,26 @@ from evaluator.evaluator import Evaluator from evaluator.similarity_evaluator import SimilarityEvaluator +# print(f"{os.path.dirname(os.path.abspath(__file__))}/evaluator/impl/tugraph-db") +# sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/evaluator/impl/tugraph-db") + def evaluate(gold, predict, etype, impl): log_file = open(f"{os.path.dirname(__file__)}/../output/logs/eval.log", "w") log_lines = [] with open(gold) as f: + content = f.read() + gold_list = json.loads(content) gseq_one = [] - for l in f.readlines(): - if len(l.strip()) == 0: + db_id_list = [] + for gold_dic in gold_list: + if len(gold_dic["output"].strip()) == 0: # when some predict is none, support it can continue work gseq_one.append("no out") else: - gseq_one.append(l.strip()) + gseq_one.append(gold_dic["output"].strip()) + db_id_list.append(gold_dic["db_id"].strip()) with open(predict) as f: plist = [] @@ -28,7 +35,6 @@ def evaluate(gold, predict, etype, impl): if len(l.strip()) == 0: # when some predict is none, support it can continue work pseq_one.append("no out") - else: pseq_one.append(l.strip()) @@ -38,16 +44,24 @@ def evaluate(gold, predict, etype, impl): score_total = 0 if etype == "similarity": + # jaro-winkler distance score evaluator = SimilarityEvaluator() elif etype == "grammar": + # grammar check result, 1 if pass, 0 if fail model_path = f"evaluator.impl.{impl}.grammar_evaluator" m = 
importlib.import_module(model_path) GrammarEvaluator = getattr(m, "GrammarEvaluator") evaluator = GrammarEvaluator() + elif etype == "execution": + # excution result, 1 if same, 0 if not same + model_path = f"evaluator.impl.{impl}.execution_evaluator" + m = importlib.import_module(model_path) + ExecutionEvaluator = getattr(m, "ExecutionEvaluator") + evaluator = ExecutionEvaluator() total = 0 for i in range(len(gseq_one)): - score = evaluator.evaluate(pseq_one[i], gseq_one[i]) + score = evaluator.evaluate(pseq_one[i], gseq_one[i], db_id_list[i]) if score != -1: score_total += score total += 1 @@ -57,6 +71,7 @@ def evaluate(gold, predict, etype, impl): tmp_log["score"] = score log_lines.append(tmp_log) + json.dump(log_lines, log_file, ensure_ascii=False, indent=4) tb = pt.PrettyTable() @@ -83,7 +98,7 @@ def evaluate(gold, predict, etype, impl): type=str, default="similarity", help="evaluation type, exec for test suite accuracy, match for the original exact set match accuracy", - choices=("similarity", "grammar"), + choices=("similarity", "grammar", "execution"), ) parser.add_argument( "--impl", diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/evaluator.py index 8b97babf..4a0c24dc 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/evaluator.py +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/evaluator.py @@ -1,3 +1,3 @@ class Evaluator: - def evaluate(self, query_predict, query_gold): + def evaluate(self, query_predict, query_gold, db_id): return 1 diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-analytics/grammar_evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-analytics/grammar_evaluator.py index 9d474c1a..9c52e90e 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-analytics/grammar_evaluator.py +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-analytics/grammar_evaluator.py @@ -16,7 +16,7 @@ 
current_dir = os.path.dirname(__file__)

# Seconds a predicted query may run before SIGALRM aborts it.
_QUERY_TIMEOUT_SECONDS = 10
# Seconds to wait after launching a server script before connecting to it.
_SERVER_STARTUP_SECONDS = 10
# Credentials handed to both bolt drivers.
_SERVER_AUTH = ("admin", "73@TuGraph")
# Sub-directories of ./server holding the ground-truth and predicted servers.
_SERVER_DIRS = ("server_gold", "server_predict")


def handle_timeout(sig, frame):
    """SIGALRM handler: interrupt the running code with ``TimeoutError``."""
    raise TimeoutError('took too long')


# Installed at import time so that ``signal.alarm`` in ``evaluate`` raises
# ``TimeoutError`` instead of terminating the process.
signal.signal(signal.SIGALRM, handle_timeout)


class ExecutionEvaluator:
    """Score a predicted query by executing it and comparing result sets.

    Two servers are driven through the shell scripts under ``./server``: one
    executes the gold queries (bolt port 9092) and one executes the predicted
    queries (bolt port 9094), so a bad predicted query cannot disturb the
    data the gold queries run against.

    ``evaluate`` returns ``1`` (results match), ``0`` (results differ or the
    predicted query failed) or ``-1`` (sample cannot be scored).
    """

    def __init__(self):
        """Import every dataset into both servers, start them and connect.

        Import/start failures are logged and otherwise ignored (best effort);
        a server that is not reachable surfaces in ``verify_connectivity``.
        """
        self.log = open('./exc_eval.log', 'w+')
        # Every sub-folder of ./datasets is one database, keyed by its name.
        self.dataset_list = os.listdir(f"{current_dir}/datasets")

        try:
            for dataset in self.dataset_list:
                # Load the same data into the gold and the predicted server.
                for server in _SERVER_DIRS:
                    self.process = self._run_script(
                        f'{current_dir}/datasets/{dataset}/import.sh',
                        f'{current_dir}/server/{server}/lgraph_db',
                        f'{dataset}',
                    )
        except Exception:
            logging.warning("importing datasets failed", exc_info=True)

        try:
            for server in _SERVER_DIRS:
                self.process = self._run_script(
                    f'{current_dir}/server/{server}/start.sh'
                )
                # Give the freshly launched server time to come up.
                time.sleep(_SERVER_STARTUP_SECONDS)
        except Exception:
            logging.warning("starting servers failed", exc_info=True)

        # Driver for the ground-truth server.
        self.url_gold = "bolt://localhost:9092"
        self.driver_gold = GraphDatabase.driver(self.url_gold, auth=_SERVER_AUTH)
        self.driver_gold.verify_connectivity()

        # Driver for the predicted-result server.
        self.url_predict = "bolt://localhost:9094"
        self.driver_predict = GraphDatabase.driver(
            self.url_predict, auth=_SERVER_AUTH
        )
        self.driver_predict.verify_connectivity()

        # session_pool[db_id] == [gold session, predicted session]
        self.session_pool = {
            dataset: [
                self.driver_gold.session(database=dataset),
                self.driver_predict.session(database=dataset),
            ]
            for dataset in self.dataset_list
        }

    def _run_script(self, script, *args):
        """Run ``sh script args...`` with its output sent to the log file."""
        return subprocess.run(
            ['sh', script, *args],
            stdout=self.log,
            stderr=self.log,
            close_fds=True,
        )

    def __del__(self):
        """Stop both servers (best effort) and close the log file."""
        log = getattr(self, "log", None)
        if log is None or log.closed:
            # __init__ never opened the log, so no server was started either.
            return
        try:
            for server in _SERVER_DIRS:
                self.process = self._run_script(
                    f'{current_dir}/server/{server}/stop.sh'
                )
        except Exception:
            logging.warning("stopping servers failed", exc_info=True)
        finally:
            log.close()

    def restart_predict_server(self):
        """Start the predicted-result server again and rebuild its sessions.

        Failures are logged and swallowed; in that case the old sessions stay
        in ``session_pool``.
        """
        try:
            self.process = self._run_script(
                f'{current_dir}/server/server_predict/start.sh'
            )
            time.sleep(_SERVER_STARTUP_SECONDS)

            self.driver_predict = GraphDatabase.driver(
                self.url_predict, auth=_SERVER_AUTH
            )
            self.driver_predict.verify_connectivity()
            for dataset in self.dataset_list:
                self.session_pool[dataset][1] = self.driver_predict.session(
                    database=dataset
                )
        except Exception:
            logging.warning("restarting predict server failed", exc_info=True)

    def evaluate(self, query_predict, query_gold, db_id):
        """Execute both queries on database ``db_id`` and compare the results.

        Returns:
            ``-1`` if ``db_id`` is unknown or the gold query raises,
            ``0`` if the predicted query raises, times out, or its result
            differs from the gold result, ``1`` otherwise.
        """
        sessions = self.session_pool.get(db_id)
        if sessions is None:
            return -1

        # Ground truth first; a failure here makes the sample unscorable.
        gold_ok = True
        res_gold = None
        try:
            res_gold = sessions[0].run(query_gold).data()
        except Exception:
            gold_ok = False

        # The predicted query still runs when gold failed, as before.
        predict_ok = True
        res_predict = None
        try:
            signal.alarm(_QUERY_TIMEOUT_SECONDS)
            try:
                res_predict = sessions[1].run(query_predict).data()
            finally:
                # Always disarm before any handler below runs: a pending
                # alarm would otherwise fire inside restart_predict_server()
                # (which sleeps) or during a later, unrelated query.
                signal.alarm(0)
        except TimeoutError:
            predict_ok = False
            self.restart_predict_server()
        except Exception as e:
            predict_ok = False
            if "Couldn't connect to localhost:9094 (resolved to ()):" in str(e):
                self.restart_predict_server()

        if not gold_ok:
            return -1
        if not predict_ok:
            return 0

        if "SKIP" in query_gold or "LIMIT" in query_gold:
            # With SKIP/LIMIT only the number of returned rows is compared.
            # The keyword test is a case-sensitive substring match.
            return 1 if len(res_gold) == len(res_predict) else 0

        # Order-insensitive comparison: stringify every row, then sort.
        gold_rows = sorted(str(row) for row in res_gold)
        predict_rows = sorted(str(row) for row in res_predict)
        return 1 if predict_rows == gold_rows else 0
-18,7 +18,7 @@ def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e): class GrammarEvaluator: - def evaluate(self, query_predict, query_gold): + def evaluate(self, query_predict, query_gold, db_id): error_listener = MyErrorListener() try: input_stream = InputStream(query_gold) diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/lgraph_standalone.json b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/lgraph_standalone.json new file mode 100644 index 00000000..890404aa --- /dev/null +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/lgraph_standalone.json @@ -0,0 +1,14 @@ +{ + "host": "0.0.0.0", + "port": 7073, + "enable_rpc": true, + "rpc_port": 9093, + "verbose": 2, + "log_dir": "./log", + "directory": "./lgraph_db", + "bolt_port": "9092", + "web": "./output/resource", + "ssl_auth": false, + "server_key": "./server-key.pem", + "server_cert": "./server-cert.pem" +} diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/start.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/start.sh new file mode 100644 index 00000000..4bd358e7 --- /dev/null +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/start.sh @@ -0,0 +1,3 @@ +cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")" +lgraph_server -d stop +lgraph_server -c ./lgraph_standalone.json -d start \ No newline at end of file diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/stop.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/stop.sh new file mode 100644 index 00000000..bf10a3d0 --- /dev/null +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/stop.sh @@ -0,0 +1,5 @@ +cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")" +lgraph_server -d stop +rm -rf ./lgraph_db +rm 
-rf ./log +rm -rf ./core.* \ No newline at end of file diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/lgraph_standalone.json b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/lgraph_standalone.json new file mode 100644 index 00000000..1c523f24 --- /dev/null +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/lgraph_standalone.json @@ -0,0 +1,14 @@ +{ + "host": "0.0.0.0", + "port": 7075, + "enable_rpc": true, + "rpc_port": 9095, + "verbose": 2, + "log_dir": "./log", + "directory": "./lgraph_db", + "bolt_port": "9094", + "web": "./output/resource", + "ssl_auth": false, + "server_key": "./server-key.pem", + "server_cert": "./server-cert.pem" +} diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/start.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/start.sh new file mode 100644 index 00000000..4bd358e7 --- /dev/null +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/start.sh @@ -0,0 +1,3 @@ +cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")" +lgraph_server -d stop +lgraph_server -c ./lgraph_standalone.json -d start \ No newline at end of file diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/stop.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/stop.sh new file mode 100644 index 00000000..bf10a3d0 --- /dev/null +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/stop.sh @@ -0,0 +1,5 @@ +cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")" +lgraph_server -d stop +rm -rf ./lgraph_db +rm -rf ./log +rm -rf ./core.* \ No newline at end of file diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py 
index 60f1f64d..67c3e602 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py @@ -2,5 +2,5 @@ class SimilarityEvaluator: - def evaluate(self, query_predict, query_gold): + def evaluate(self, query_predict, query_gold, db_id): return jaro.jaro_winkler_metric(query_predict, query_gold) diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/adapter.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/adapter.py index 322ad44c..dbf0b963 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/adapter.py +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/adapter.py @@ -3,7 +3,7 @@ import torch from peft import LoraConfig, PeftModel, TaskType, get_peft_model -from peft.utils import CONFIG_NAME, WEIGHTS_NAME +from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME from .config_parser import load_trainable_params from .loggings import get_logger @@ -59,7 +59,7 @@ def init_adapter( if model_args.checkpoint_dir is not None: assert os.path.exists( - os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME) + os.path.join(model_args.checkpoint_dir[0], SAFETENSORS_WEIGHTS_NAME) ), "Provided path ({}) does not contain a LoRA weight.".format( model_args.checkpoint_dir[0] ) diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/load_tokenizer.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/load_tokenizer.py index 5bd16fc3..38a414d3 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/load_tokenizer.py +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/load_tokenizer.py @@ -15,7 +15,7 @@ PreTrainedTokenizer, PreTrainedTokenizerBase, ) -from transformers.deepspeed import is_deepspeed_zero3_enabled +from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.trainer import SAFE_WEIGHTS_NAME, WEIGHTS_NAME from transformers.utils import cached_file, check_min_version from transformers.utils.versions import require_version diff --git 
a/src/dbgpt-hub-gql/dbgpt_hub_gql/scripts/predict_sft.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/scripts/predict_sft.sh index 7b1b2db9..974f8ef4 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/scripts/predict_sft.sh +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/scripts/predict_sft.sh @@ -13,6 +13,14 @@ echo " Pred Start time: $(date -d @$start_time +'%Y-%m-%d %H:%M:%S')" >>${pred_l # --checkpoint_dir dbgpt_hub_gql/output/adapter/CodeLlama-7b-gql-lora \ # --predicted_out_filename dbgpt_hub_gql/output/pred/tugraph_analytics_dev.txt >> ${pred_log} +# CUDA_VISIBLE_DEVICES=0,1 python dbgpt_hub_gql/predict/predict.py \ +# --model_name_or_path codellama/CodeLlama-7b-Instruct-hf \ +# --template llama2 \ +# --finetuning_type lora \ +# --predicted_input_filename dbgpt_hub_gql/data/tugraph-db/dev.json \ +# --checkpoint_dir dbgpt_hub_gql/output/adapter/CodeLlama-7b-gql-lora \ +# --predicted_out_filename dbgpt_hub_gql/output/pred/tugraph_analytics_dev.txt >> ${pred_log} + CUDA_VISIBLE_DEVICES=0,1 python dbgpt_hub_gql/predict/predict.py \ --model_name_or_path codellama/CodeLlama-7b-Instruct-hf \ --template llama2 \ diff --git a/src/dbgpt-hub-gql/setup.py b/src/dbgpt-hub-gql/setup.py index 4034de6e..cd05c88c 100644 --- a/src/dbgpt-hub-gql/setup.py +++ b/src/dbgpt-hub-gql/setup.py @@ -22,7 +22,7 @@ def unique_extras(self) -> dict[str, list[str]]: def core_dependencies(): setup_spec.extras["core"] = [ "transformers>=4.41.2", - "datasets>=2.14.6", + "datasets>=2.14.7", "tiktoken>=0.7.0", "torch>=2.0.0", "peft>=0.4.0", @@ -76,6 +76,8 @@ def core_dependencies(): "jaro-winkler==2.0.3", "antlr4-python3-runtime==4.13.2", "JPype1==1.5.0", + "neo4j>=5.26.0", + "protobuf==3.20.*", ]