From 7085e7545b5617cd6eed17e4c9f9b7b6c7ef3a1f Mon Sep 17 00:00:00 2001
From: Lyu Songlin
Date: Fri, 8 Nov 2024 17:10:19 +0800
Subject: [PATCH 01/12] add execution evaluator

---
 .../dbgpt_hub_gql/eval/evaluation.py          |  27 ++++-
 .../dbgpt_hub_gql/eval/evaluation_old.py      | 101 +++++++++++++++++
 .../dbgpt_hub_gql/eval/evaluator/evaluator.py |   2 +-
 .../tugraph-analytics/grammar_evaluator.py    |   2 +-
 .../impl/tugraph-db/execution_evaluator.py    | 106 ++++++++++++++++++
 .../impl/tugraph-db/grammar_evaluator.py      |   2 +-
 .../server/server_gold/lgraph_standalone.json |  14 +++
 .../tugraph-db/server/server_gold/start.sh    |   3 +
 .../server_predict/lgraph_standalone.json     |  14 +++
 .../tugraph-db/server/server_predict/start.sh |   3 +
 .../eval/evaluator/similarity_evaluator.py    |   2 +-
 src/dbgpt-hub-gql/setup.py                    |   1 +
 12 files changed, 267 insertions(+), 10 deletions(-)
 create mode 100644 src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation_old.py
 create mode 100644 src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py
 create mode 100644 src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/lgraph_standalone.json
 create mode 100644 src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/start.sh
 create mode 100644 src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/lgraph_standalone.json
 create mode 100644 src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/start.sh

diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation.py
index 8b172b6..c8a7523 100644
--- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation.py
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation.py
@@ -7,19 +7,26 @@ from evaluator.evaluator import Evaluator
 from evaluator.similarity_evaluator import SimilarityEvaluator
 
+# print(f"{os.path.dirname(os.path.abspath(__file__))}/evaluator/impl/tugraph-db")
+# sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/evaluator/impl/tugraph-db")
+
 
 def evaluate(gold, predict, etype, impl):
     log_file = open(f"{os.path.dirname(__file__)}/../output/logs/eval.log", "w")
     log_lines = []
 
     with open(gold) as f:
+        content = f.read()
+        gold_list = json.loads(content)
         gseq_one = []
-        for l in f.readlines():
-            if len(l.strip()) == 0:
+        db_id_list = []
+        for gold_dic in gold_list:
+            if len(gold_dic["output"].strip()) == 0:
                 # when some predict is none, support it can continue work
                 gseq_one.append("no out")
             else:
-                gseq_one.append(l.strip())
+                gseq_one.append(gold_dic["output"].strip())
+                db_id_list.append(gold_dic["db_id"].strip())
 
     with open(predict) as f:
         plist = []
@@ -28,7 +35,6 @@ def evaluate(gold, predict, etype, impl):
             if len(l.strip()) == 0:
                 # when some predict is none, support it can continue work
                 pseq_one.append("no out")
-
             else:
                 pseq_one.append(l.strip())
 
@@ -38,16 +44,24 @@ def evaluate(gold, predict, etype, impl):
 
     score_total = 0
     if etype == "similarity":
+        # jaro-winkler distance score
        evaluator = SimilarityEvaluator()
     elif etype == "grammar":
+        # grammar check result, 1 if pass, 0 if fail
         model_path = f"evaluator.impl.{impl}.grammar_evaluator"
         m = importlib.import_module(model_path)
         GrammarEvaluator = getattr(m, "GrammarEvaluator")
         evaluator = GrammarEvaluator()
+    elif etype == "execution":
+        # execution result, 1 if same, 0 if not same
+        model_path = f"evaluator.impl.{impl}.execution_evaluator"
+        m = importlib.import_module(model_path)
+        ExecutionEvaluator = getattr(m, "ExecutionEvaluator")
+        evaluator = 
ExecutionEvaluator() total = 0 for i in range(len(gseq_one)): - score = evaluator.evaluate(pseq_one[i], gseq_one[i]) + score = evaluator.evaluate(pseq_one[i], gseq_one[i], db_id_list[i]) if score != -1: score_total += score total += 1 @@ -57,6 +71,7 @@ def evaluate(gold, predict, etype, impl): tmp_log["score"] = score log_lines.append(tmp_log) + json.dump(log_lines, log_file, ensure_ascii=False, indent=4) tb = pt.PrettyTable() @@ -83,7 +98,7 @@ def evaluate(gold, predict, etype, impl): type=str, default="similarity", help="evaluation type, exec for test suite accuracy, match for the original exact set match accuracy", - choices=("similarity", "grammar"), + choices=("similarity", "grammar", "execution"), ) parser.add_argument( "--impl", diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation_old.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation_old.py new file mode 100644 index 0000000..8b172b6 --- /dev/null +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation_old.py @@ -0,0 +1,101 @@ +import os +import sys +import argparse +import importlib +import json +import prettytable as pt +from evaluator.evaluator import Evaluator +from evaluator.similarity_evaluator import SimilarityEvaluator + + +def evaluate(gold, predict, etype, impl): + log_file = open(f"{os.path.dirname(__file__)}/../output/logs/eval.log", "w") + log_lines = [] + + with open(gold) as f: + gseq_one = [] + for l in f.readlines(): + if len(l.strip()) == 0: + # when some predict is none, support it can continue work + gseq_one.append("no out") + else: + gseq_one.append(l.strip()) + + with open(predict) as f: + plist = [] + pseq_one = [] + for l in f.readlines(): + if len(l.strip()) == 0: + # when some predict is none, support it can continue work + pseq_one.append("no out") + + else: + pseq_one.append(l.strip()) + + assert len(gseq_one) == len( + pseq_one + ), "number of predicted queries and gold standard queries must equal" + + score_total = 0 + if etype == "similarity": + evaluator = SimilarityEvaluator() + elif etype == "grammar": + model_path = f"evaluator.impl.{impl}.grammar_evaluator" + m = importlib.import_module(model_path) + GrammarEvaluator = getattr(m, "GrammarEvaluator") + evaluator = GrammarEvaluator() + + total = 0 + for i in range(len(gseq_one)): + score = evaluator.evaluate(pseq_one[i], gseq_one[i]) + if score != -1: + score_total += score + total += 1 + tmp_log = {} + tmp_log["pred"] = pseq_one[i] + tmp_log["gold"] = gseq_one[i] + tmp_log["score"] = score + log_lines.append(tmp_log) + + json.dump(log_lines, log_file, ensure_ascii=False, indent=4) + + tb = pt.PrettyTable() + tb.field_names = ["Evaluation Type", "Total Count", "Accuracy"] + tb.add_row([etype, len(gseq_one), "{:.3f}".format(score_total / total)]) + print(tb) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--input", + dest="input", + type=str, + help="the path to the input file", + required=True, + ) + parser.add_argument( + "--gold", dest="gold", type=str, help="the path to the gold queries", default="" + ) + parser.add_argument( + "--etype", + dest="etype", + type=str, + default="similarity", + help="evaluation type, exec for test suite accuracy, match for the original exact set match accuracy", + choices=("similarity", "grammar"), + ) + parser.add_argument( + "--impl", + dest="impl", + type=str, + default="tugraph-analytics", + help="implementation folder for grammar evaluator", + ) + args = parser.parse_args() + + # Print args + print(f"params as fllows \n {args}") + + # Second, evaluate the 
predicted GQL queries + evaluate(args.gold, args.input, args.etype, args.impl) diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/evaluator.py index 8b97bab..4a0c24d 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/evaluator.py +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/evaluator.py @@ -1,3 +1,3 @@ class Evaluator: - def evaluate(self, query_predict, query_gold): + def evaluate(self, query_predict, query_gold, db_id): return 1 diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-analytics/grammar_evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-analytics/grammar_evaluator.py index 9d474c1..9c52e90 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-analytics/grammar_evaluator.py +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-analytics/grammar_evaluator.py @@ -16,7 +16,7 @@ def __init__(self): JDClass = jpype.JClass("com.antgroup.geaflow.dsl.parser.GeaFlowDSLParser") self.jd = JDClass() - def evaluate(self, query_predict, query_gold): + def evaluate(self, query_predict, query_gold, db_id): try: result_gold = self.jd.parseStatement(query_gold) try: diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py new file mode 100644 index 0000000..ff6416c --- /dev/null +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py @@ -0,0 +1,106 @@ +import sys +import logging +import os.path +import os +import ctypes +import subprocess +import time +import json +from neo4j import GraphDatabase + +current_dir = os.path.dirname(__file__) + +class ExecutionEvaluator: + def __init__(self): + self.log = open('./exc_eval.log', 'w+') + # import datasets to 2 different data folder + dataset_list = os.listdir(f"{current_dir}/datasets") + try: + # iterate through all dataset folder under ./datasets + for dataset in dataset_list: + # import data in ci command + self.process = subprocess.run([ + 'sh', + f'{current_dir}/datasets/{dataset}/import.sh', + f'{current_dir}/server/server_gold/lgraph_db', f'{dataset}' + ], stdout=self.log, stderr=self.log, close_fds=True) + + self.process = subprocess.run([ + 'sh', + f'{current_dir}/datasets/{dataset}/import.sh', + f'{current_dir}/server/server_predict/lgraph_db', f'{dataset}' + ], stdout=self.log, stderr=self.log, close_fds=True) + except Exception as e: + # in dev environment, start server before run tests + logging.debug(e) + + # python start 2 seperate tugraph-db server + try: + # start server in ci command + self.process = subprocess.run([ + 'sh', + f'{current_dir}/server/server_gold/start.sh' + ], stdout=self.log, stderr=self.log, close_fds=True) + time.sleep(10) + + self.process = subprocess.run([ + 'sh', + f'{current_dir}/server/server_predict/start.sh' + ], stdout=self.log, stderr=self.log, close_fds=True) + time.sleep(10) + except Exception as e: + # in dev environment, start server before run tests + logging.debug(e) + + self.url_gold = f"bolt://localhost:9092" + self.url_predict = f"bolt://localhost:9094" + # driver for ground truth + self.driver_gold = GraphDatabase.driver(self.url_gold, auth=("admin", "73@TuGraph")) + self.driver_gold.verify_connectivity() + # driver for predict result + self.driver_predict = GraphDatabase.driver(self.url_predict, auth=("admin", "73@TuGraph")) + self.driver_predict.verify_connectivity() + + 
self.session_pool = {} + for dataset in dataset_list: + self.session_pool[dataset] = [] + self.session_pool[dataset].append(self.driver_gold.session(database=dataset)) + self.session_pool[dataset].append(self.driver_predict.session(database=dataset)) + + + def evaluate(self, query_predict, query_gold, db_id): + if db_id not in self.session_pool.keys(): + return 0 + + # run cypher on the server for ground truth + ret_gold = True + try: + res_gold = self.session_pool[db_id][0].run(query_gold).data() + except Exception as e: + ret_gold = False + res_gold = e + + # run cypher on the server for predict result + ret_predict = True + try: + res_predict = self.session_pool[db_id][1].run(query_predict).data() + except Exception as e: + ret_predict = False + res_predict = e + + if ret_gold == False: + return 0 + else: + if ret_predict == True: + for i in range(len(res_gold)): + res_gold[i] = str(res_gold[i]) + res_gold.sort() + for i in range(len(res_predict)): + res_predict[i] = str(res_predict[i]) + res_predict.sort() + if res_predict == res_gold: + return 1 + else: + return 0 + else: + return 0 diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/grammar_evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/grammar_evaluator.py index 7c6331d..f1bbd7b 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/grammar_evaluator.py +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/grammar_evaluator.py @@ -18,7 +18,7 @@ def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e): class GrammarEvaluator: - def evaluate(self, query_predict, query_gold): + def evaluate(self, query_predict, query_gold, db_id): error_listener = MyErrorListener() try: input_stream = InputStream(query_gold) diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/lgraph_standalone.json b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/lgraph_standalone.json new file mode 100644 index 0000000..890404a --- /dev/null +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/lgraph_standalone.json @@ -0,0 +1,14 @@ +{ + "host": "0.0.0.0", + "port": 7073, + "enable_rpc": true, + "rpc_port": 9093, + "verbose": 2, + "log_dir": "./log", + "directory": "./lgraph_db", + "bolt_port": "9092", + "web": "./output/resource", + "ssl_auth": false, + "server_key": "./server-key.pem", + "server_cert": "./server-cert.pem" +} diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/start.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/start.sh new file mode 100644 index 0000000..4bd358e --- /dev/null +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/start.sh @@ -0,0 +1,3 @@ +cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")" +lgraph_server -d stop +lgraph_server -c ./lgraph_standalone.json -d start \ No newline at end of file diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/lgraph_standalone.json b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/lgraph_standalone.json new file mode 100644 index 0000000..1c523f2 --- /dev/null +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/lgraph_standalone.json @@ -0,0 +1,14 @@ +{ + "host": "0.0.0.0", + "port": 7075, + "enable_rpc": true, + "rpc_port": 9095, + "verbose": 2, + 
"log_dir": "./log", + "directory": "./lgraph_db", + "bolt_port": "9094", + "web": "./output/resource", + "ssl_auth": false, + "server_key": "./server-key.pem", + "server_cert": "./server-cert.pem" +} diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/start.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/start.sh new file mode 100644 index 0000000..4bd358e --- /dev/null +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/start.sh @@ -0,0 +1,3 @@ +cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")" +lgraph_server -d stop +lgraph_server -c ./lgraph_standalone.json -d start \ No newline at end of file diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py index 60f1f64..6185418 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py @@ -3,4 +3,4 @@ class SimilarityEvaluator: def evaluate(self, query_predict, query_gold): - return jaro.jaro_winkler_metric(query_predict, query_gold) + return jaro.jaro_winkler_metric(query_predict, query_gold, db_id) diff --git a/src/dbgpt-hub-gql/setup.py b/src/dbgpt-hub-gql/setup.py index 4034de6..a770637 100644 --- a/src/dbgpt-hub-gql/setup.py +++ b/src/dbgpt-hub-gql/setup.py @@ -76,6 +76,7 @@ def core_dependencies(): "jaro-winkler==2.0.3", "antlr4-python3-runtime==4.13.2", "JPype1==1.5.0", + "neo4j>=5.26.0", ] From 66f44e8bf6a7c4bb7f429bb992da07eced5177fa Mon Sep 17 00:00:00 2001 From: Lyu Songlin Date: Fri, 8 Nov 2024 17:36:42 +0800 Subject: [PATCH 02/12] compare result length when use SKIP or LIMIT --- .../impl/tugraph-db/execution_evaluator.py | 40 ++++++++++++++----- .../tugraph-db/server/server_gold/stop.sh | 3 ++ .../tugraph-db/server/server_predict/stop.sh | 3 ++ 3 files changed, 37 insertions(+), 9 deletions(-) create mode 100644 src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/stop.sh create mode 100644 src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/stop.sh diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py index ff6416c..d2fac86 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py @@ -67,6 +67,22 @@ def __init__(self): self.session_pool[dataset].append(self.driver_gold.session(database=dataset)) self.session_pool[dataset].append(self.driver_predict.session(database=dataset)) + def __del__(self): + # python stop 2 seperate tugraph-db server + try: + # stop server in ci command + self.process = subprocess.run([ + 'sh', + f'{current_dir}/server/server_gold/stop.sh' + ], stdout=self.log, stderr=self.log, close_fds=True) + + self.process = subprocess.run([ + 'sh', + f'{current_dir}/server/server_predict/stop.sh' + ], stdout=self.log, stderr=self.log, close_fds=True) + except Exception as e: + # in dev environment, start server before run tests + logging.debug(e) def evaluate(self, query_predict, query_gold, db_id): if db_id not in self.session_pool.keys(): @@ -92,15 +108,21 @@ def evaluate(self, query_predict, query_gold, db_id): return 0 else: if ret_predict == True: - for i in 
range(len(res_gold)): - res_gold[i] = str(res_gold[i]) - res_gold.sort() - for i in range(len(res_predict)): - res_predict[i] = str(res_predict[i]) - res_predict.sort() - if res_predict == res_gold: - return 1 + if "SKIP" in query_gold or "LIMIT" in query_gold: + if len(res_gold) == len(res_predict): + return 1 + else: + return 0 else: - return 0 + for i in range(len(res_gold)): + res_gold[i] = str(res_gold[i]) + res_gold.sort() + for i in range(len(res_predict)): + res_predict[i] = str(res_predict[i]) + res_predict.sort() + if res_predict == res_gold: + return 1 + else: + return 0 else: return 0 diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/stop.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/stop.sh new file mode 100644 index 0000000..170018a --- /dev/null +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/stop.sh @@ -0,0 +1,3 @@ +cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")" +lgraph_server -d stop +rm -rf ./lgraph_db \ No newline at end of file diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/stop.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/stop.sh new file mode 100644 index 0000000..170018a --- /dev/null +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/stop.sh @@ -0,0 +1,3 @@ +cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")" +lgraph_server -d stop +rm -rf ./lgraph_db \ No newline at end of file From ab1a3ea4a5c9b9676f7ec7c50075355359a338a3 Mon Sep 17 00:00:00 2001 From: Lyu Songlin Date: Thu, 9 Jan 2025 11:20:56 +0800 Subject: [PATCH 03/12] fix bug and add some comments for execution evaluator --- .gitignore | 2 + .../dbgpt_hub_gql/eval/evaluation_old.py | 101 ------------------ .../impl/tugraph-db/execution_evaluator.py | 78 +++++++++----- .../tugraph-db/server/server_gold/stop.sh | 4 +- .../tugraph-db/server/server_predict/stop.sh | 4 +- .../eval/evaluator/similarity_evaluator.py | 4 +- 6 files changed, 60 insertions(+), 133 deletions(-) delete mode 100644 src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation_old.py diff --git a/.gitignore b/.gitignore index 65bcdef..19840f8 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,8 @@ src/dbgpt-hub-gql/wandb/* !src/dbgpt-hub-gql/dbgpt_hub_gql/data/tugraph-db-example !src/dbgpt-hub-gql/dbgpt_hub_gql/data/dataset_info.json !src/dbgpt-hub-gql/dbgpt_hub_gql/data/example_text2sql.json +# Ignore all server's dataset folder under eval/evaluator/impl +src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/datasets # Ignore everything under dbgpt_hub_sql/ouput/ except the adapter directory src/dbgpt-hub-sql/dbgpt_hub_sql/output/ diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation_old.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation_old.py deleted file mode 100644 index 8b172b6..0000000 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluation_old.py +++ /dev/null @@ -1,101 +0,0 @@ -import os -import sys -import argparse -import importlib -import json -import prettytable as pt -from evaluator.evaluator import Evaluator -from evaluator.similarity_evaluator import SimilarityEvaluator - - -def evaluate(gold, predict, etype, impl): - log_file = open(f"{os.path.dirname(__file__)}/../output/logs/eval.log", "w") - log_lines = [] - - with open(gold) as f: - gseq_one = [] - for l in f.readlines(): - if len(l.strip()) == 0: - # when some predict is none, support it can continue work 
- gseq_one.append("no out") - else: - gseq_one.append(l.strip()) - - with open(predict) as f: - plist = [] - pseq_one = [] - for l in f.readlines(): - if len(l.strip()) == 0: - # when some predict is none, support it can continue work - pseq_one.append("no out") - - else: - pseq_one.append(l.strip()) - - assert len(gseq_one) == len( - pseq_one - ), "number of predicted queries and gold standard queries must equal" - - score_total = 0 - if etype == "similarity": - evaluator = SimilarityEvaluator() - elif etype == "grammar": - model_path = f"evaluator.impl.{impl}.grammar_evaluator" - m = importlib.import_module(model_path) - GrammarEvaluator = getattr(m, "GrammarEvaluator") - evaluator = GrammarEvaluator() - - total = 0 - for i in range(len(gseq_one)): - score = evaluator.evaluate(pseq_one[i], gseq_one[i]) - if score != -1: - score_total += score - total += 1 - tmp_log = {} - tmp_log["pred"] = pseq_one[i] - tmp_log["gold"] = gseq_one[i] - tmp_log["score"] = score - log_lines.append(tmp_log) - - json.dump(log_lines, log_file, ensure_ascii=False, indent=4) - - tb = pt.PrettyTable() - tb.field_names = ["Evaluation Type", "Total Count", "Accuracy"] - tb.add_row([etype, len(gseq_one), "{:.3f}".format(score_total / total)]) - print(tb) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--input", - dest="input", - type=str, - help="the path to the input file", - required=True, - ) - parser.add_argument( - "--gold", dest="gold", type=str, help="the path to the gold queries", default="" - ) - parser.add_argument( - "--etype", - dest="etype", - type=str, - default="similarity", - help="evaluation type, exec for test suite accuracy, match for the original exact set match accuracy", - choices=("similarity", "grammar"), - ) - parser.add_argument( - "--impl", - dest="impl", - type=str, - default="tugraph-analytics", - help="implementation folder for grammar evaluator", - ) - args = parser.parse_args() - - # Print args - print(f"params as fllows \n {args}") - - # Second, evaluate the predicted GQL queries - evaluate(args.gold, args.input, args.etype, args.impl) diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py index d2fac86..e747c62 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py @@ -6,82 +6,108 @@ import subprocess import time import json +import signal +import jaro from neo4j import GraphDatabase current_dir = os.path.dirname(__file__) +def handle_timeout(sig, frame): + raise TimeoutError('took too long') + +signal.signal(signal.SIGALRM, handle_timeout) + class ExecutionEvaluator: def __init__(self): self.log = open('./exc_eval.log', 'w+') # import datasets to 2 different data folder dataset_list = os.listdir(f"{current_dir}/datasets") + self.dataset_list = dataset_list try: # iterate through all dataset folder under ./datasets for dataset in dataset_list: - # import data in ci command + # import data to data folder for ground truth with cli command self.process = subprocess.run([ 'sh', f'{current_dir}/datasets/{dataset}/import.sh', f'{current_dir}/server/server_gold/lgraph_db', f'{dataset}' ], stdout=self.log, stderr=self.log, close_fds=True) + # import data to data folder for predicted result with cli command self.process = subprocess.run([ 'sh', 
f'{current_dir}/datasets/{dataset}/import.sh',
                     f'{current_dir}/server/server_predict/lgraph_db', f'{dataset}'
                 ], stdout=self.log, stderr=self.log, close_fds=True)
         except Exception as e:
-            # in dev environment, start server before run tests
             logging.debug(e)
 
-        # python start 2 seperate tugraph-db server
+        # start 2 separate tugraph-db servers
         try:
-            # start server in ci command
+            # start server for ground truth with cli command
             self.process = subprocess.run([
                 'sh',
                 f'{current_dir}/server/server_gold/start.sh'
             ], stdout=self.log, stderr=self.log, close_fds=True)
             time.sleep(10)
 
+            # start server for predicted result with cli command
             self.process = subprocess.run([
                 'sh',
                 f'{current_dir}/server/server_predict/start.sh'
             ], stdout=self.log, stderr=self.log, close_fds=True)
             time.sleep(10)
         except Exception as e:
-            # in dev environment, start server before run tests
             logging.debug(e)
 
+        # setup driver for ground truth
         self.url_gold = f"bolt://localhost:9092"
-        self.url_predict = f"bolt://localhost:9094"
-        # driver for ground truth
         self.driver_gold = GraphDatabase.driver(self.url_gold, auth=("admin", "73@TuGraph"))
         self.driver_gold.verify_connectivity()
-        # driver for predict result
+
+        # setup driver for predicted result
+        self.url_predict = f"bolt://localhost:9094"
         self.driver_predict = GraphDatabase.driver(self.url_predict, auth=("admin", "73@TuGraph"))
         self.driver_predict.verify_connectivity()
 
         self.session_pool = {}
-        for dataset in dataset_list:
+        for dataset in self.dataset_list:
             self.session_pool[dataset] = []
             self.session_pool[dataset].append(self.driver_gold.session(database=dataset))
             self.session_pool[dataset].append(self.driver_predict.session(database=dataset))
 
     def __del__(self):
-        # python stop 2 seperate tugraph-db server
+        # stop 2 separate tugraph-db servers
         try:
-            # stop server in ci command
+            # stop server for ground truth with cli command
             self.process = subprocess.run([
                 'sh',
                 f'{current_dir}/server/server_gold/stop.sh'
             ], stdout=self.log, stderr=self.log, close_fds=True)
 
+            # stop server for predicted result with cli command
             self.process = subprocess.run([
                 'sh',
                 f'{current_dir}/server/server_predict/stop.sh'
             ], stdout=self.log, stderr=self.log, close_fds=True)
         except Exception as e:
-            # in dev environment, start server before run tests
             logging.debug(e)
 
+    def restart_predict_server(self):
+        try:
+            # restart server for predicted result with cli command
+            self.process = subprocess.run([
+                'sh',
+                f'{current_dir}/server/server_predict/start.sh'
+            ], stdout=self.log, stderr=self.log, close_fds=True)
+            time.sleep(10)
+
+            # setup driver for predicted result
+            self.driver_predict = GraphDatabase.driver(self.url_predict, auth=("admin", "73@TuGraph"))
+            self.driver_predict.verify_connectivity()
+            for dataset in self.dataset_list:
+                self.session_pool[dataset][1] = self.driver_predict.session(database=dataset)
+        except Exception as e:
+            logging.debug(e)
 
     def evaluate(self, query_predict, query_gold, db_id):
         if db_id not in self.session_pool.keys():
             return 0
 
         # run cypher on the server for ground truth
         ret_gold = True
         try:
             res_gold = self.session_pool[db_id][0].run(query_gold).data()
         except Exception as e:
             ret_gold = False
             res_gold = e
 
         # run cypher on the server for predict result
         ret_predict = True
         try:
+            signal.alarm(10)
             res_predict = self.session_pool[db_id][1].run(query_predict).data()
+            signal.alarm(0)
+        except TimeoutError as e:
+            ret_predict = False
+            res_predict = e
+            self.restart_predict_server()
         except Exception as e:
             ret_predict = False
             res_predict = e
+            if "Couldn't connect to localhost:9094 (resolved to ()):" in str(e):
+                self.restart_predict_server()
 
         if ret_gold == False:
-            return 0
+            return -1
         else:
             if ret_predict == True:
-                if "SKIP" in 
query_gold or "LIMIT" in query_gold: - if len(res_gold) == len(res_predict): - return 1 - else: - return 0 + if len(res_gold) == len(res_predict): + return 1 else: - for i in range(len(res_gold)): - res_gold[i] = str(res_gold[i]) - res_gold.sort() - for i in range(len(res_predict)): - res_predict[i] = str(res_predict[i]) - res_predict.sort() - if res_predict == res_gold: - return 1 - else: - return 0 + return 0 else: return 0 diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/stop.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/stop.sh index 170018a..bf10a3d 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/stop.sh +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_gold/stop.sh @@ -1,3 +1,5 @@ cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")" lgraph_server -d stop -rm -rf ./lgraph_db \ No newline at end of file +rm -rf ./lgraph_db +rm -rf ./log +rm -rf ./core.* \ No newline at end of file diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/stop.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/stop.sh index 170018a..bf10a3d 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/stop.sh +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/server/server_predict/stop.sh @@ -1,3 +1,5 @@ cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")" lgraph_server -d stop -rm -rf ./lgraph_db \ No newline at end of file +rm -rf ./lgraph_db +rm -rf ./log +rm -rf ./core.* \ No newline at end of file diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py index 6185418..67c3e60 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/similarity_evaluator.py @@ -2,5 +2,5 @@ class SimilarityEvaluator: - def evaluate(self, query_predict, query_gold): - return jaro.jaro_winkler_metric(query_predict, query_gold, db_id) + def evaluate(self, query_predict, query_gold, db_id): + return jaro.jaro_winkler_metric(query_predict, query_gold) From fc1038fc7f470ed8adcbc9a5bd7e1a477589097a Mon Sep 17 00:00:00 2001 From: Lyu Songlin Date: Thu, 9 Jan 2025 11:32:50 +0800 Subject: [PATCH 04/12] only compare length when SKIP or LIMIT exist --- .../impl/tugraph-db/execution_evaluator.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py index e747c62..2da3fbc 100644 --- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py +++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py @@ -142,9 +142,23 @@ def evaluate(self, query_predict, query_gold, db_id): return -1 else: if ret_predict == True: - if len(res_gold) == len(res_predict): - return 1 + if "SKIP" in query_gold or "LIMIT" in query_gold: + # if SKIP or LIMIT in cypher, only compare the size of query result + if len(res_gold) == len(res_predict): + return 1 + else: + return 0 else: - return 0 + # else, sort all query results then compare if two results are same + for i in range(len(res_gold)): + res_gold[i] = str(res_gold[i]) + 
res_gold.sort()
+                    for i in range(len(res_predict)):
+                        res_predict[i] = str(res_predict[i])
+                    res_predict.sort()
+                    if res_predict == res_gold:
+                        return 1
+                    else:
+                        return 0
             else:
                 return 0
 

From 6d6171c49569e50475f969bf973d4a50a9570cea Mon Sep 17 00:00:00 2001
From: Lyu Songlin
Date: Thu, 9 Jan 2025 15:33:59 +0800
Subject: [PATCH 05/12] BUGFIX: datasets 3.0.0 and above incompatible with
 huggingface, and adapter.bin not generated with higher version of peft

---
 src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/adapter.py | 4 ++--
 src/dbgpt-hub-gql/setup.py                          | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/adapter.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/adapter.py
index 322ad44..dbf0b96 100644
--- a/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/adapter.py
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/adapter.py
@@ -3,7 +3,7 @@
 
 import torch
 from peft import LoraConfig, PeftModel, TaskType, get_peft_model
-from peft.utils import CONFIG_NAME, WEIGHTS_NAME
+from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
 
 from .config_parser import load_trainable_params
 from .loggings import get_logger
@@ -59,7 +59,7 @@ def init_adapter(
 
         if model_args.checkpoint_dir is not None:
             assert os.path.exists(
-                os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)
+                os.path.join(model_args.checkpoint_dir[0], SAFETENSORS_WEIGHTS_NAME)
             ), "Provided path ({}) does not contain a LoRA weight.".format(
                 model_args.checkpoint_dir[0]
             )
diff --git a/src/dbgpt-hub-gql/setup.py b/src/dbgpt-hub-gql/setup.py
index a770637..0e8a496 100644
--- a/src/dbgpt-hub-gql/setup.py
+++ b/src/dbgpt-hub-gql/setup.py
@@ -22,7 +22,7 @@ def unique_extras(self) -> dict[str, list[str]]:
 def core_dependencies():
     setup_spec.extras["core"] = [
         "transformers>=4.41.2",
-        "datasets>=2.14.6",
+        "datasets>=2.14.7",
         "tiktoken>=0.7.0",
         "torch>=2.0.0",
         "peft>=0.4.0",

From 658cc4a228e24fda1439380ba895b68d71cbf3b1 Mon Sep 17 00:00:00 2001
From: Lyu Songlin
Date: Thu, 9 Jan 2025 15:48:58 +0800
Subject: [PATCH 06/12] BUGFIX: higher version of protobuf incompatible

---
 src/dbgpt-hub-gql/setup.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/dbgpt-hub-gql/setup.py b/src/dbgpt-hub-gql/setup.py
index 0e8a496..cd05c88 100644
--- a/src/dbgpt-hub-gql/setup.py
+++ b/src/dbgpt-hub-gql/setup.py
@@ -77,6 +77,7 @@ def core_dependencies():
         "antlr4-python3-runtime==4.13.2",
         "JPype1==1.5.0",
         "neo4j>=5.26.0",
+        "protobuf==3.20.*",
     ]
 

From f4b36b86c74c988b24b8a51b8835c8d40a7db955 Mon Sep 17 00:00:00 2001
From: Lyu Songlin
Date: Thu, 9 Jan 2025 15:49:25 +0800
Subject: [PATCH 07/12] BUGFIX: deepspeed cannot be imported directly

---
 src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/load_tokenizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/load_tokenizer.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/load_tokenizer.py
index 5bd16fc..38a414d 100644
--- a/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/load_tokenizer.py
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/llm_base/load_tokenizer.py
@@ -15,7 +15,7 @@
     PreTrainedTokenizer,
     PreTrainedTokenizerBase,
 )
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.trainer import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from transformers.utils import cached_file, check_min_version
 from transformers.utils.versions import require_version

From bf02050ca3151aa606b0f78418334ad538f0f9b6 Mon Sep 17 00:00:00 2001
From: Lyu Songlin
Date: Thu, 9 Jan 2025 15:50:19 +0800
Subject: [PATCH 08/12] BUGFIX: cannot run scripts if /output/logs folder is not created.

---
 src/dbgpt-hub-gql/README.zh.md | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/dbgpt-hub-gql/README.zh.md b/src/dbgpt-hub-gql/README.zh.md
index c10cc04..aa22795 100644
--- a/src/dbgpt-hub-gql/README.zh.md
+++ b/src/dbgpt-hub-gql/README.zh.md
@@ -134,6 +134,12 @@ cd src/dbgpt-hub-gql
 pip install -e .
 ```
 
+Create the output and log directories
+```bash
+mkdir dbgpt_hub_gql/output
+mkdir dbgpt_hub_gql/output/logs
+```
+
 ### 3.2 Model Preparation
 Create and enter the codellama model directory
 ```bash

From 8ae969584fc3d1de4476a1299d26f5de0afd147b Mon Sep 17 00:00:00 2001
From: Lyu Songlin
Date: Thu, 9 Jan 2025 16:35:49 +0800
Subject: [PATCH 09/12] BUGFIX: return -1 when db_id is not in db

---
 src/dbgpt-hub-gql/README.zh.md                              | 5 +++--
 .../eval/evaluator/impl/tugraph-db/execution_evaluator.py   | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/dbgpt-hub-gql/README.zh.md b/src/dbgpt-hub-gql/README.zh.md
index aa22795..f92f8ce 100644
--- a/src/dbgpt-hub-gql/README.zh.md
+++ b/src/dbgpt-hub-gql/README.zh.md
@@ -138,6 +138,7 @@ pip install -e .
 ```bash
 mkdir dbgpt_hub_gql/output
 mkdir dbgpt_hub_gql/output/logs
+mkdir dbgpt_hub_gql/output/pred
 ```
 
 ### 3.2 Model Preparation
@@ -222,7 +223,7 @@ sh dbgpt_hub_gql/scripts/predict_sft.sh
 
 The text-similarity evaluation directly computes the Jaro–Winkler distance between the predicted results and the gold results
 ```bash
-python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/gold_dev.txt --etype similarity
+python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/dev.json --etype similarity
 ```
 
 #### 3.5.2 Grammar Correctness Evaluation
@@ -230,7 +231,7 @@ python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tug
 `tugraph-db-example` is a corpus dataset that conforms to the LCypher grammar of `tugraph-db`. The grammar-correctness evaluation uses ANTLR4 to generate a parser from `./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/Lcypher.g4`, which is used to check the grammatical correctness of the model predictions.
 
 ```bash
-python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/gold_dev.txt --etype grammar --impl tugraph-db
+python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/dev.json --etype grammar --impl tugraph-db
 ```
 
diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py
index 2da3fbc..b81fad6 100644
--- a/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/execution_evaluator.py
@@ -112,7 +112,7 @@ def restart_predict_server(self):
 
     def evaluate(self, query_predict, query_gold, db_id):
         if db_id not in self.session_pool.keys():
-            return 0
+            return -1

From 5bc719ee933bc3da8b14609dd19d35c73177c95b Mon Sep 17 00:00:00 2001
From: Lyu Songlin
Date: Thu, 9 Jan 2025 16:44:40 +0800
Subject: [PATCH 10/12] add tugraph-db predict config

---
 src/dbgpt-hub-gql/dbgpt_hub_gql/scripts/predict_sft.sh | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/dbgpt-hub-gql/dbgpt_hub_gql/scripts/predict_sft.sh b/src/dbgpt-hub-gql/dbgpt_hub_gql/scripts/predict_sft.sh
index 7b1b2db..974f8ef 100644
--- a/src/dbgpt-hub-gql/dbgpt_hub_gql/scripts/predict_sft.sh
+++ b/src/dbgpt-hub-gql/dbgpt_hub_gql/scripts/predict_sft.sh
@@ -13,6 +13,14 @@ echo " Pred Start time: $(date -d @$start_time +'%Y-%m-%d %H:%M:%S')" >>${pred_l
 #    --checkpoint_dir dbgpt_hub_gql/output/adapter/CodeLlama-7b-gql-lora \
 #    --predicted_out_filename dbgpt_hub_gql/output/pred/tugraph_analytics_dev.txt >> ${pred_log}
 
+# CUDA_VISIBLE_DEVICES=0,1 python dbgpt_hub_gql/predict/predict.py \
+#     --model_name_or_path codellama/CodeLlama-7b-Instruct-hf \
+#     --template llama2 \
+#     --finetuning_type lora \
+#     --predicted_input_filename dbgpt_hub_gql/data/tugraph-db/dev.json \
+#     --checkpoint_dir dbgpt_hub_gql/output/adapter/CodeLlama-7b-gql-lora \
+#     --predicted_out_filename dbgpt_hub_gql/output/pred/tugraph_analytics_dev.txt >> ${pred_log}
+
 CUDA_VISIBLE_DEVICES=0,1 python dbgpt_hub_gql/predict/predict.py \
     --model_name_or_path codellama/CodeLlama-7b-Instruct-hf \
     --template llama2 \

From 674c014ac6634b0d4b506f77ade43d8025d9a346 Mon Sep 17 00:00:00 2001
From: Lyu Songlin
Date: Wed, 15 Jan 2025 16:16:14 +0800
Subject: [PATCH 11/12] update test result on new dataset

---
 src/dbgpt-hub-gql/README.zh.md | 106 +++++++++++++++++++++++++--------
 1 file changed, 80 insertions(+), 26 deletions(-)

diff --git a/src/dbgpt-hub-gql/README.zh.md b/src/dbgpt-hub-gql/README.zh.md
index f92f8ce..8e8cc8a 100644
--- a/src/dbgpt-hub-gql/README.zh.md
+++ b/src/dbgpt-hub-gql/README.zh.md
@@ -12,22 +12,25 @@
         Method
         Similarity
         Grammar
+        Execution
     
     
         base
-        0.769
-        0.703
+        0.674
+        0.653
+        0.037
     
     
         Cypher (tugraph-db)
         TuGraph-DB Cypher dataset
         CodeLlama-7B-Instruct
         lora
-        0.928
-        0.946
+        0.922
+        0.987
+        0.507
     
     
         base
         0.493
         0.002
+        none
     
     
         GQL(tugraph-analytics)
         TuGraph-Analytics GQL dataset
         CodeLlama-7B-Instruct
         lora
         0.935
         0.984
+        none
+
+
+
+        base
+        0.769
+        0.703
+        0.000
+
+
+        Cypher (tugraph-db-example)
+        TuGraph-DB Cypher dataset
+        CodeLlama-7B-Instruct
+        lora
+        0.928
+        0.946
+        0.476
+
@@ -63,6 +86,8 @@
 - [3.5 Model Evaluation](#35模型评估)
   - [3.5.1 Text Similarity Evaluation](#351文本相似度评估)
   - [3.5.2 Grammar Correctness Evaluation](#352语法正确性评估)
+  - [3.5.3 Execution Result Consistency Evaluation](#353执行结果一致性评估)
+    - [3.5.3.1 tugraph-db](#3531tugraph-db)
 - [3.6 Model Weight Merging](#36模型权重合并)
 
 # 1. Introduction
@@ -79,7 +104,7 @@
 
 ### 2.1 Datasets
 
-The sample dataset of this project is `Cypher(tugraph-db)`, which contains 185 corpus entries provided by tugraph-db and executable on tugraph-db, stored in the `/dbgpt_hub_gql/data/tugraph-db-example` folder. The currently available datasets are:
+The sample dataset of this project is `Cypher(tugraph-db-example)`, which contains 185 corpus entries provided by tugraph-db and executable on tugraph-db, stored in the `/dbgpt_hub_gql/data/tugraph-db-example` folder. The currently available datasets are:
 
 - [Cypher(tugraph-db)](https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/tugraph-db/tugraph-db.zip): a dataset conforming to tugraph-db's Cypher grammar, built with a "[grammar-guided corpus generation strategy](https://mp.weixin.qq.com/s/rZdj8TEoHZg_f4C-V4lq2A)" that combines query-language templates with diversified schemas to generate queries, and then uses a large model to generalize the corresponding natural-language question descriptions. The [corpus generation framework](https://github.com/TuGraph-contrib/Awesome-Text2GQL) is now open source; contributions are welcome.
@@ -120,50 +145,39 @@
 
 ### 3.1 Environment Preparation
 
-Clone the project and create a conda environment,
 ```bash
+# clone the project and create a conda environment
 git clone https://github.com/eosphoros-ai/DB-GPT-Hub.git
 cd DB-GPT-Hub
 conda create -n dbgpt_hub_gql python=3.10
 conda activate dbgpt_hub_gql
-```
 
-Enter the DB-GPT-GQL project directory and install the dependencies with poetry
-```bash
+# enter the DB-GPT-GQL project directory and install the dependencies with poetry
 cd src/dbgpt-hub-gql
 pip install -e .
-```
 
-Create the output and log directories
-```bash
+
+# create the output and log directories
 mkdir dbgpt_hub_gql/output
 mkdir dbgpt_hub_gql/output/logs
 mkdir dbgpt_hub_gql/output/pred
 ```
 
 ### 3.2 Model Preparation
-Create and enter the codellama model directory
-```bash
-mkdir codellama
-cd ./codellama
-```
-Create a `download.py` file under the `codellama` folder and copy the following content into it
+Create a `download.py` file and copy the following content into it
 ```python
 from modelscope import snapshot_download
-model_dir = snapshot_download("AI-ModelScope/CodeLlama-7b-Instruct-hf")
+model_dir = snapshot_download("AI-ModelScope/CodeLlama-7b-Instruct-hf", cache_dir='./')
 ```
-
-Install the python dependency and download the model
+Download the model with download.py
 ```bash
+# install the python dependency and download the model
 pip install modelscope
 python3 download.py
-```
-After the download finishes, symlink the model files into the `codellama` directory
-```bash
-ln -s /root/.cache/modelscope/hub/AI-ModelScope/CodeLlama-7b-Instruct-hf ./
+
+# after the download finishes, rename AI-ModelScope/CodeLlama-7b-Instruct-hf to codellama/CodeLlama-7b-Instruct-hf
+mv ./AI-ModelScope ./codellama
 ```
 
 ### 3.3 Model Fine-tuning
@@ -216,7 +230,7 @@ sh dbgpt_hub_gql/scripts/predict_sft.sh
 
 ### 3.5 Model Evaluation
 
-The current version supports two evaluation methods for prediction results: the first is a text-similarity evaluation based on the Jaro–Winkler distance, and the second is a grammar-correctness evaluation based on a `.g4` grammar file or a graph database's existing parser.
+The current version supports three evaluation methods for prediction results: the first is a text-similarity evaluation based on the Jaro–Winkler distance, the second is a grammar-correctness evaluation based on a `.g4` grammar file or a graph database's existing parser, and the third is an execution-consistency evaluation based on comparing the results returned by the queries.
 
 #### 3.5.1 Text Similarity Evaluation
 
@@ -234,6 +248,46 @@ python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tug
 python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/dev.json --etype grammar --impl tugraph-db
 ```
 
+#### 3.5.3 Execution Result Consistency Evaluation
+
+The current version only supports execution-result consistency evaluation on tugraph-db; it is not yet supported on tugraph-analytics.
+
+##### 3.5.3.1 tugraph-db
+
+Execution-result consistency evaluation requires a running tugraph-db; for convenience, you can download the official tugraph runtime image.
+
+```bash
+# download and unzip the tugraph execution test datasets
+wget -P ./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db https://tugraph-web.oss-cn-beijing.aliyuncs.com/tugraph/datasets/text2gql/tugraph-db-server/datasets.zip
+
+unzip -d ./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db ./dbgpt_hub_gql/eval/evaluator/impl/tugraph-db/datasets.zip
+
+# download and start the tugraph runtime image
+docker pull tugraph/tugraph-runtime-centos7
+
+docker run -it -v ./:/root/dbgpt-hub-gql --name=tugraph-db_evaluation tugraph/tugraph-runtime-centos7 /bin/bash
+
+cd /root
+
+# install miniconda
+mkdir -p ~/miniconda3
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh --no-check-certificate
+bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
+rm ~/miniconda3/miniconda.sh
+source ~/miniconda3/bin/activate
+
+# prepare the runtime environment
+cd /root/dbgpt-hub-gql/
+conda create -n dbgpt_hub_gql python=3.10
+conda activate dbgpt_hub_gql
+pip install -e .
+
+# run the evaluation
+python dbgpt_hub_gql/eval/evaluation.py --input ./dbgpt_hub_gql/output/pred/tugraph_db_example_dev.txt --gold ./dbgpt_hub_gql/data/tugraph-db-example/dev.json --etype execution --impl tugraph-db
+```
+
+
+
 ### 3.6 Model Weight Merging
 
 If you need to merge the weights of the trained base model and the fine-tuned PEFT module to export a complete model, run the following model export script:

From b8595dd3b504f92e64a21456104ecd3932ce7347 Mon Sep 17 00:00:00 2001
From: Lyu Songlin
Date: Wed, 22 Jan 2025 15:06:49 +0800
Subject: [PATCH 12/12] update new model link

---
 src/dbgpt-hub-gql/README.zh.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/dbgpt-hub-gql/README.zh.md b/src/dbgpt-hub-gql/README.zh.md
index 8e8cc8a..36dec99 100644
--- a/src/dbgpt-hub-gql/README.zh.md
+++ b/src/dbgpt-hub-gql/README.zh.md
@@ -26,7 +26,7 @@
         Cypher (tugraph-db)
-        TuGraph-DB Cypher dataset
+        TuGraph-DB Cypher dataset
         CodeLlama-7B-Instruct
         lora
         0.922
         0.987
         0.507
@@ -61,7 +61,7 @@
         Cypher (tugraph-db-example)
-        TuGraph-DB Cypher dataset
+        TuGraph-DB Cypher example dataset
         CodeLlama-7B-Instruct
         lora
         0.928
         0.946
         0.476
@@ -104,7 +104,7 @@
 
 ### 2.1 Datasets
 
-The sample dataset of this project is `Cypher(tugraph-db-example)`, which contains 185 corpus entries provided by tugraph-db and executable on tugraph-db, stored in the `/dbgpt_hub_gql/data/tugraph-db-example` folder. The currently available datasets are:
+The sample dataset of this project is `Cypher(tugraph-db-example)`, which contains 185 corpus entries provided by tugraph-db and executable on tugraph-db, stored in the `/dbgpt_hub_gql/data/tugraph-db-example` folder; it is intended for test runs only. The complete datasets currently available are:
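
For reference, the result-consistency rule that the execution evaluator converges on across patches 02-04 reduces to the standalone sketch below. The bolt URL and the `admin`/`73@TuGraph` credentials are taken from the server configuration added in patch 01; the `movie` database name and the `results_match` helper are illustrative assumptions, not part of the patched code.

```python
# Minimal sketch of the consistency check used by ExecutionEvaluator.
# Assumes a tugraph-db server started as in server_gold above; "movie" is a
# placeholder for one of the imported evaluation datasets.
from neo4j import GraphDatabase


def results_match(res_gold: list, res_predict: list, query_gold: str) -> int:
    """Return 1 when the two result sets are considered equivalent, else 0."""
    if "SKIP" in query_gold or "LIMIT" in query_gold:
        # with SKIP/LIMIT only the number of returned rows is compared,
        # since row content may legitimately differ between valid answers
        return 1 if len(res_gold) == len(res_predict) else 0
    # otherwise stringify each row so mixed value types sort deterministically,
    # then compare the two sorted sequences
    return 1 if sorted(map(str, res_gold)) == sorted(map(str, res_predict)) else 0


if __name__ == "__main__":
    query = "MATCH (n) RETURN n LIMIT 3"
    driver = GraphDatabase.driver("bolt://localhost:9092", auth=("admin", "73@TuGraph"))
    with driver.session(database="movie") as session:
        gold = session.run(query).data()
        pred = session.run(query).data()
    driver.close()
    print(results_match(gold, pred, query))  # identical queries must agree -> 1
```

Sorting the stringified rows makes the comparison order-insensitive, which matters because semantically equivalent Cypher statements can return the same rows in different orders.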