Commit cbeca79
minor fix to load test data
wuziniu committed Nov 28, 2023
1 parent 6f4f1b3 commit cbeca79
Showing 3 changed files with 90 additions and 7 deletions.
30 changes: 23 additions & 7 deletions src/brad/cost_model/dataset/dataset_creation.py
@@ -328,6 +328,7 @@ def create_query_dataloader(
     limit_runtime=None,
     loss_class_name=None,
     eval_on_test=False,
+    apply_constraint_on_test=False,
 ):
     """
     Creates dataloaders that batch query featurizations to train the model in a distributed fashion.
@@ -376,13 +377,26 @@ def create_query_dataloader(
     if test_workload_run_paths is not None:
         test_loaders = []
         for p in test_workload_run_paths:
-            _, test_dataset, _, test_database_statistics = create_datasets(
-                [p],
-                False,
-                loss_class_name=loss_class_name,
-                val_ratio=0.0,
-                shuffle_before_split=False,
-            )
+            if apply_constraint_on_test:
+                _, test_dataset, _, test_database_statistics = create_datasets(
+                    [p],
+                    False,
+                    loss_class_name=loss_class_name,
+                    val_ratio=0.0,
+                    shuffle_before_split=False,
+                    limit_num_tables=limit_num_tables,
+                    limit_runtime=limit_runtime,
+                    lower_bound_num_tables=lower_bound_num_tables,
+                    lower_bound_runtime=lower_bound_runtime,
+                )
+            else:
+                _, test_dataset, _, test_database_statistics = create_datasets(
+                    [p],
+                    False,
+                    loss_class_name=loss_class_name,
+                    val_ratio=0.0,
+                    shuffle_before_split=False,
+                )
             # test dataset
             test_collate_fn = functools.partial(
                 query_collator,
Expand Down Expand Up @@ -443,6 +457,7 @@ def create_dataloader(
loss_class_name=None,
is_query=True,
eval_on_test=False,
apply_constraint_on_test=False,
):
if is_query:
return create_query_dataloader(
@@ -464,6 +479,7 @@
             limit_runtime,
             loss_class_name,
             eval_on_test,
+            apply_constraint_on_test,
         )
     else:
         return create_plan_dataloader(
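The hunks above are the core of the commit: when apply_constraint_on_test is set, the same four bounds used to filter training data (limit_num_tables, lower_bound_num_tables, limit_runtime, lower_bound_runtime) are also forwarded to create_datasets for each held-out test workload; otherwise the test sets are built unfiltered, as before. Below is a minimal self-contained sketch of what such bound filtering means. The internals of create_datasets are not part of this diff, so the record layout (the "num_tables" and "runtime" fields) is an assumption, not the repo's actual format.

# Hypothetical sketch of the bound filtering that the new flag forwards
# to the test datasets; field names here are illustrative assumptions.
def apply_constraints(
    queries,
    limit_num_tables=None,
    lower_bound_num_tables=None,
    limit_runtime=None,
    lower_bound_runtime=None,
):
    kept = []
    for q in queries:
        # Each bound is applied only when it is explicitly provided.
        if limit_num_tables is not None and q["num_tables"] > limit_num_tables:
            continue
        if lower_bound_num_tables is not None and q["num_tables"] < lower_bound_num_tables:
            continue
        if limit_runtime is not None and q["runtime"] > limit_runtime:
            continue
        if lower_bound_runtime is not None and q["runtime"] < lower_bound_runtime:
            continue
        kept.append(q)
    return kept

# Example: keep queries joining 2-5 tables that ran for at least 100 ms.
sample = [{"num_tables": 3, "runtime": 250}, {"num_tables": 8, "runtime": 90}]
print(apply_constraints(sample, limit_num_tables=5,
                        lower_bound_num_tables=2, lower_bound_runtime=100))
# -> [{'num_tables': 3, 'runtime': 250}]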
25 changes: 25 additions & 0 deletions src/brad/cost_model/training/test.py
@@ -40,6 +40,10 @@ def training_model_loader(
     database=None,
     limit_queries=None,
     limit_queries_affected_wl=None,
+    limit_num_tables=None,
+    lower_bound_num_tables=None,
+    lower_bound_runtime=None,
+    limit_runtime=None,
     skip_train=False,
     seed=0,
 ):
@@ -80,6 +84,11 @@ def training_model_loader(
         limit_queries=limit_queries,
         limit_queries_affected_wl=limit_queries_affected_wl,
         loss_class_name=loss_class_name,
+        limit_num_tables=limit_num_tables,
+        lower_bound_num_tables=lower_bound_num_tables,
+        lower_bound_runtime=lower_bound_runtime,
+        limit_runtime=limit_runtime,
+        apply_constraint_on_test=True,
     )

     if loss_class_name == "QLoss":
@@ -152,6 +161,10 @@ def load_model(
     seed=0,
     limit_queries=None,
     limit_queries_affected_wl=None,
+    limit_num_tables=None,
+    lower_bound_num_tables=None,
+    lower_bound_runtime=None,
+    limit_runtime=None,
     max_no_epochs=None,
     skip_train=False,
 ):
@@ -217,6 +230,10 @@ def load_model(
         limit_queries=limit_queries,
         limit_queries_affected_wl=limit_queries_affected_wl,
         skip_train=skip_train,
+        limit_num_tables=limit_num_tables,
+        lower_bound_num_tables=lower_bound_num_tables,
+        lower_bound_runtime=lower_bound_runtime,
+        limit_runtime=limit_runtime,
     )

     assert len(hyperparams) == 0, (
@@ -312,6 +329,10 @@ def test_one_model(
     hyperparameter_path,
     test_workload_runs,
     statistics_file,
+    limit_num_tables=None,
+    lower_bound_num_tables=None,
+    lower_bound_runtime=None,
+    limit_runtime=None,
 ):
     test_loaders, model = load_model(
         test_workload_runs,
@@ -321,6 +342,10 @@
         filename_model,
         hyperparameter_path,
         database=database,
+        limit_num_tables=limit_num_tables,
+        lower_bound_num_tables=lower_bound_num_tables,
+        lower_bound_runtime=lower_bound_runtime,
+        limit_runtime=limit_runtime,
     )
     true, pred = validate_model(test_loaders[0], model)
     qerror = np.maximum(true / pred, pred / true)
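In test.py the four bounds are simply threaded through each layer (test_one_model → load_model → training_model_loader), and training_model_loader now calls the dataloader with apply_constraint_on_test=True, so evaluation sees the same filtered query distribution as training. The metric computed at the end, qerror = np.maximum(true / pred, pred / true), is the standard q-error: it penalizes over- and under-prediction symmetrically and equals 1.0 for an exact estimate. A quick numeric check:

import numpy as np

# q-error treats a 2x over-prediction and a 2x under-prediction the same.
true = np.array([10.0, 10.0, 10.0])
pred = np.array([10.0, 20.0, 5.0])
qerror = np.maximum(true / pred, pred / true)
print(qerror)  # [1. 2. 2.]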
42 changes: 42 additions & 0 deletions workloads/IMDB_extended/training_data_collection_telemetry.py
@@ -1,4 +1,7 @@
+import copy
+import json
+import os.path

 from tqdm import tqdm
 import argparse
 from workloads.IMDB_extended.gen_telemetry_workload import generate_workload
@@ -8,6 +11,7 @@
 from workloads.cross_db_benchmark.benchmark_tools.redshift.database_connection import (
     RedshiftDatabaseConnection,
 )
+from workloads.cross_db_benchmark.benchmark_tools.utils import dumper


 def reset_data_athena(conn):
@@ -204,3 +208,41 @@ def collect_train_data(
         args.num_queries_per_template,
         args.save_path,
     )
+
+
+def simulate_query_on_larger_scale(
+    old_parsed_queries, current_scale, target_scale, target_path=None
+):
+    scale_factor = target_scale / current_scale
+    parsed_queries = copy.deepcopy(old_parsed_queries)
+    db_stats = parsed_queries["database_stats"]
+    for column_stats in db_stats["column_stats"]:
+        column_stats["table_size"] = int(column_stats["table_size"] * scale_factor)
+    for table_stats in db_stats["table_stats"]:
+        table_stats["reltuples"] = int(table_stats["reltuples"] * scale_factor)
+    parsed_queries["database_stats"] = db_stats
+    parsed_queries["sql_queries"] = old_parsed_queries["sql_queries"]
+    parsed_queries["run_kwargs"] = old_parsed_queries["run_kwargs"]
+    parsed_queries["skipped"] = old_parsed_queries["skipped"]
+    parsed_queries["parsed_plans"] = []
+
+    for q in parsed_queries["parsed_queries"]:
+        for table_num in q["scan_nodes"]:
+            scan_node_param = q["scan_nodes"][table_num]["plan_parameters"]
+            scan_node_param["est_card"] = int(
+                scan_node_param["est_card"] * scale_factor
+            )
+            scan_node_param["act_card"] = int(
+                scan_node_param["act_card"] * scale_factor
+            )
+            scan_node_param["est_children_card"] = int(
+                scan_node_param["est_children_card"] * scale_factor
+            )
+            scan_node_param["act_children_card"] = int(
+                scan_node_param["act_children_card"] * scale_factor
+            )
+
+    if target_path is not None:
+        with open(target_path + f"_epoch_{target_scale}.json", "w") as f:
+            json.dump(parsed_queries, f, default=dumper)
+    return parsed_queries
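The new simulate_query_on_larger_scale synthesizes training data for a larger database by multiplying table sizes, row counts, and per-scan estimated/actual cardinalities by target_scale / current_scale, without re-executing the workload; this implicitly assumes selectivities stay constant as the data grows. A hedged usage sketch follows: the input path is hypothetical, but the keys match what the function reads and writes.

import json

# Hypothetical input: a parsed telemetry workload collected at scale 1.
with open("telemetry_parsed_queries.json", "r") as f:
    parsed = json.load(f)

# Produce a simulated 10x workload; since target_path is given, the result
# is also written to telemetry_scaled_epoch_10.json (per the f-string above).
scaled = simulate_query_on_larger_scale(
    parsed, current_scale=1, target_scale=10, target_path="telemetry_scaled"
)
print(len(scaled["parsed_queries"]), "queries rescaled")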
