diff --git a/src/brad/cost_model/dataset/dataset_creation.py b/src/brad/cost_model/dataset/dataset_creation.py
index e721432d..4ad6c56c 100644
--- a/src/brad/cost_model/dataset/dataset_creation.py
+++ b/src/brad/cost_model/dataset/dataset_creation.py
@@ -328,6 +328,7 @@ def create_query_dataloader(
     limit_runtime=None,
     loss_class_name=None,
     eval_on_test=False,
+    apply_constraint_on_test=False,
 ):
     """
     Creates dataloaders that batches query featurization to train the model in a distributed fashion.
@@ -376,13 +377,26 @@
     if test_workload_run_paths is not None:
         test_loaders = []
         for p in test_workload_run_paths:
-            _, test_dataset, _, test_database_statistics = create_datasets(
-                [p],
-                False,
-                loss_class_name=loss_class_name,
-                val_ratio=0.0,
-                shuffle_before_split=False,
-            )
+            if apply_constraint_on_test:
+                _, test_dataset, _, test_database_statistics = create_datasets(
+                    [p],
+                    False,
+                    loss_class_name=loss_class_name,
+                    val_ratio=0.0,
+                    shuffle_before_split=False,
+                    limit_num_tables=limit_num_tables,
+                    limit_runtime=limit_runtime,
+                    lower_bound_num_tables=lower_bound_num_tables,
+                    lower_bound_runtime=lower_bound_runtime,
+                )
+            else:
+                _, test_dataset, _, test_database_statistics = create_datasets(
+                    [p],
+                    False,
+                    loss_class_name=loss_class_name,
+                    val_ratio=0.0,
+                    shuffle_before_split=False,
+                )
             # test dataset
             test_collate_fn = functools.partial(
                 query_collator,
@@ -443,6 +457,7 @@ def create_dataloader(
     loss_class_name=None,
     is_query=True,
     eval_on_test=False,
+    apply_constraint_on_test=False,
 ):
     if is_query:
         return create_query_dataloader(
@@ -464,6 +479,7 @@
             limit_runtime,
             loss_class_name,
             eval_on_test,
+            apply_constraint_on_test,
         )
     else:
         return create_plan_dataloader(
diff --git a/src/brad/cost_model/training/test.py b/src/brad/cost_model/training/test.py
index f4c22e43..0d4fa6c6 100644
--- a/src/brad/cost_model/training/test.py
+++ b/src/brad/cost_model/training/test.py
@@ -40,6 +40,10 @@ def training_model_loader(
     database=None,
     limit_queries=None,
     limit_queries_affected_wl=None,
+    limit_num_tables=None,
+    lower_bound_num_tables=None,
+    lower_bound_runtime=None,
+    limit_runtime=None,
     skip_train=False,
     seed=0,
 ):
@@ -80,6 +84,11 @@
         limit_queries=limit_queries,
         limit_queries_affected_wl=limit_queries_affected_wl,
         loss_class_name=loss_class_name,
+        limit_num_tables=limit_num_tables,
+        lower_bound_num_tables=lower_bound_num_tables,
+        lower_bound_runtime=lower_bound_runtime,
+        limit_runtime=limit_runtime,
+        apply_constraint_on_test=True,
     )
 
     if loss_class_name == "QLoss":
@@ -152,6 +161,10 @@ def load_model(
     seed=0,
     limit_queries=None,
     limit_queries_affected_wl=None,
+    limit_num_tables=None,
+    lower_bound_num_tables=None,
+    lower_bound_runtime=None,
+    limit_runtime=None,
     max_no_epochs=None,
     skip_train=False,
 ):
@@ -217,6 +230,10 @@
         limit_queries=limit_queries,
         limit_queries_affected_wl=limit_queries_affected_wl,
         skip_train=skip_train,
+        limit_num_tables=limit_num_tables,
+        lower_bound_num_tables=lower_bound_num_tables,
+        lower_bound_runtime=lower_bound_runtime,
+        limit_runtime=limit_runtime,
     )
 
     assert len(hyperparams) == 0, (
@@ -312,6 +329,10 @@ def test_one_model(
     hyperparameter_path,
     test_workload_runs,
     statistics_file,
+    limit_num_tables=None,
+    lower_bound_num_tables=None,
+    lower_bound_runtime=None,
+    limit_runtime=None,
 ):
     test_loaders, model = load_model(
         test_workload_runs,
@@ -321,6 +342,10 @@
         filename_model,
         hyperparameter_path,
         database=database,
+        limit_num_tables=limit_num_tables,
+        lower_bound_num_tables=lower_bound_num_tables,
+        lower_bound_runtime=lower_bound_runtime,
+        limit_runtime=limit_runtime,
     )
     true, pred = validate_model(test_loaders[0], model)
     qerror = np.maximum(true / pred, pred / true)
diff --git a/workloads/IMDB_extended/training_data_collection_telemetry.py b/workloads/IMDB_extended/training_data_collection_telemetry.py
index ae711a77..666d0e42 100644
--- a/workloads/IMDB_extended/training_data_collection_telemetry.py
+++ b/workloads/IMDB_extended/training_data_collection_telemetry.py
@@ -1,4 +1,7 @@
+import copy
 import json
+import os.path
+
 from tqdm import tqdm
 import argparse
 from workloads.IMDB_extended.gen_telemetry_workload import generate_workload
@@ -8,6 +11,7 @@
 from workloads.cross_db_benchmark.benchmark_tools.redshift.database_connection import (
     RedshiftDatabaseConnection,
 )
+from workloads.cross_db_benchmark.benchmark_tools.utils import dumper
 
 
 def reset_data_athena(conn):
@@ -204,3 +208,41 @@
         args.num_queries_per_template,
         args.save_path,
     )
+
+
+def simulate_query_on_larger_scale(
+    old_parsed_queries, current_scale, target_scale, target_path=None
+):
+    scale_factor = target_scale / current_scale
+    parsed_queries = copy.deepcopy(old_parsed_queries)
+    db_stats = parsed_queries["database_stats"]
+    for column_stats in db_stats["column_stats"]:
+        column_stats["table_size"] = int(column_stats["table_size"] * scale_factor)
+    for table_stats in db_stats["table_stats"]:
+        table_stats["reltuples"] = int(table_stats["reltuples"] * scale_factor)
+    parsed_queries["database_stats"] = db_stats
+    parsed_queries["sql_queries"] = old_parsed_queries["sql_queries"]
+    parsed_queries["run_kwargs"] = old_parsed_queries["run_kwargs"]
+    parsed_queries["skipped"] = old_parsed_queries["skipped"]
+    parsed_queries["parsed_plans"] = []
+
+    for q in parsed_queries["parsed_queries"]:
+        for table_num in q["scan_nodes"]:
+            scan_node_param = q["scan_nodes"][table_num]["plan_parameters"]
+            scan_node_param["est_card"] = int(
+                scan_node_param["est_card"] * scale_factor
+            )
+            scan_node_param["act_card"] = int(
+                scan_node_param["act_card"] * scale_factor
+            )
+            scan_node_param["est_children_card"] = int(
+                scan_node_param["est_children_card"] * scale_factor
+            )
+            scan_node_param["act_children_card"] = int(
+                scan_node_param["act_children_card"] * scale_factor
+            )
+
+    if target_path is not None:
+        with open(target_path + f"_epoch_{target_scale}.json", "w") as f:
+            json.dump(parsed_queries, f, default=dumper)
+    return parsed_queries
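
A minimal usage sketch of the new scaling helper, for reviewers only (not part of the patch). The input/output paths and scale values below are hypothetical; the only things taken from the diff are the `simulate_query_on_larger_scale(old_parsed_queries, current_scale, target_scale, target_path=None)` signature and the `_epoch_{target_scale}.json` output suffix.

```python
import json

# Hypothetical driver; assumes the repository root is on PYTHONPATH.
from workloads.IMDB_extended.training_data_collection_telemetry import (
    simulate_query_on_larger_scale,
)

# Parsed telemetry workload collected at scale factor 1 (hypothetical path).
with open("parsed_queries_scale_1.json") as f:
    parsed_sf1 = json.load(f)

# Extrapolate table sizes, reltuples, and scan-node cardinalities to a 10x
# larger database. Because target_path is given, the helper also writes
# "parsed_queries_scaled_epoch_10.json" in addition to returning the dict.
scaled = simulate_query_on_larger_scale(
    parsed_sf1,
    current_scale=1,
    target_scale=10,
    target_path="parsed_queries_scaled",
)
print(f"Rescaled {len(scaled['parsed_queries'])} parsed queries")
```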