Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions wren-ai-service/eval/data_curation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ async def is_sql_valid(
timeout: float = TIMEOUT_SECONDS,
) -> Tuple[bool, str]:
sql = sql.rstrip(";") if sql.endswith(";") else sql
quoted_sql, no_error = add_quotes(sql)
assert no_error, f"Error in quoting SQL: {sql}"
quoted_sql, error = add_quotes(sql)
assert not error, f"Error in quoting SQL: {sql}, error: {error}"

if data_source == "duckdb":
async with aiohttp.request(
Expand Down
1 change: 1 addition & 0 deletions wren-ai-service/eval/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def parse_args() -> Tuple[str, str]:
db_name = parse_db_name(path)
if "spider_" in path:
settings.eval_data_db_path = "etc/spider1.0/database"
load_eval_data_db_to_postgres(db_name, settings.eval_data_db_path)
elif "bird_" in path:
settings.eval_data_db_path = "etc/bird/minidev/MINIDEV/dev_databases"
load_eval_data_db_to_postgres(db_name, settings.eval_data_db_path)
Expand Down
12 changes: 2 additions & 10 deletions wren-ai-service/eval/preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _get_columns_by_table_index(columns, table_index):
"primaryKey": (
tables_info["column_names_original"][primary_key_column_index][-1]
if primary_key_column_index
else "",
else ""
),
"columns": _build_mdl_columns(
tables_info, i, database_info.get(table, None)
Expand Down Expand Up @@ -410,7 +410,7 @@ def get_mdls_and_question_sql_pairs_by_common_db(mdl_by_db, question_sql_pairs_b
get_contexts_from_sql(
ground_truth["sql"],
values["mdl"],
WREN_ENGINE_API_URL,
api_endpoint=WREN_ENGINE_API_URL,
)
)

Expand Down Expand Up @@ -442,14 +442,6 @@ def get_mdls_and_question_sql_pairs_by_common_db(mdl_by_db, question_sql_pairs_b
"instructions": instructions,
}
)
# else:
# print(
# "Warning: context is empty, ignore this question sql pair as of now..."
# )
# print(f"database: {db}")
# print(f'question: {ground_truth["question"]}')
# print(f'sql: {ground_truth["sql"]}')
# print()

# save eval dataset
if candidate_eval_dataset:
Expand Down
14 changes: 8 additions & 6 deletions wren-ai-service/eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def get_data_from_wren_engine(
limit: Optional[int] = None,
):
quoted_sql, error = add_quotes(sql)
assert not error, f"Error in quoting SQL: {sql}"
assert not error, f"Error in quoting SQL: {sql}, error: {error}"

if data_source == "duckdb":
async with aiohttp.request(
Expand Down Expand Up @@ -157,9 +157,9 @@ async def _get_sql_analysis(
timeout: float = 300,
) -> List[dict]:
sql = sql.rstrip(";") if sql.endswith(";") else sql
quoted_sql, no_error = add_quotes(sql)
if not no_error:
print(f"Error in quoting SQL: {sql}")
quoted_sql, error = add_quotes(sql)
if error:
print(f"Error in quoting SQL: {sql}, error: {error}")
quoted_sql = sql

manifest_str = base64.b64encode(orjson.dumps(mdl_json)).decode()
Expand All @@ -175,7 +175,9 @@ async def _get_sql_analysis(
) as response:
return await response.json()

sql_analysis_results = await _get_sql_analysis(sql, mdl_json, api_endpoint, timeout)
sql_analysis_results = await _get_sql_analysis(
sql, mdl_json, api_endpoint, timeout=timeout
)
contexts = _get_contexts_from_sql_analysis_results(sql_analysis_results)
return contexts

Expand All @@ -190,7 +192,7 @@ def parse_db_name(path: str) -> str:
r"bird_(.+?)_eval_dataset\.toml|spider_(.+?)_eval_dataset\.toml", path
)
if match:
return match.group(1)
return match.group(1) or match.group(2)
else:
raise ValueError(
f"Invalid path format: {path}. Expected format: bird_<db_name>_eval_dataset.toml or spider_<db_name>_eval_dataset.toml"
Expand Down
1 change: 1 addition & 0 deletions wren-ai-service/src/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def remove_limit_statement(sql: str) -> str:

def add_quotes(sql: str) -> Tuple[str, str]:
try:
sql = sql.replace("`", '"')
quoted_sql = sqlglot.transpile(
sql,
read=None,
Expand Down
8 changes: 4 additions & 4 deletions wren-ai-service/tools/dev/.env
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ IBIS_SERVER_PORT=8000
# version
# CHANGE THIS TO THE LATEST VERSION
WREN_PRODUCT_VERSION=development
WREN_ENGINE_VERSION=0.17.1
WREN_AI_SERVICE_VERSION=0.24.3
IBIS_SERVER_VERSION=0.17.1
WREN_UI_VERSION=0.30.0
WREN_ENGINE_VERSION=0.20.2
WREN_AI_SERVICE_VERSION=0.27.14
IBIS_SERVER_VERSION=0.20.2
WREN_UI_VERSION=0.31.2
WREN_BOOTSTRAP_VERSION=0.1.5

LAUNCH_CLI_PATH=./launch-cli.sh
Expand Down
Loading