From 2560ca2b93d9c7d1473359f8764811b2b6a81c12 Mon Sep 17 00:00:00 2001 From: wuziniu Date: Thu, 2 Nov 2023 16:03:19 -0400 Subject: [PATCH] adding small features to cost model --- readme_cost_model.md | 4 +- run_cost_model.py | 11 ++- .../cost_model/dataset/dataset_creation.py | 92 ++++++++++++++++--- src/brad/cost_model/training/train.py | 6 ++ 4 files changed, 96 insertions(+), 17 deletions(-) diff --git a/readme_cost_model.md b/readme_cost_model.md index c8ba8ca2..87d3cb06 100644 --- a/readme_cost_model.md +++ b/readme_cost_model.md @@ -113,7 +113,7 @@ python run_cost_model.py --augment_dataset --workload_runs ../data/imdb/parsed_q ## On Redshift -Provide connection details to Aurora. +Provide connection details to Redshift. ```angular2html python run_cost_model.py \ --database redshift \ @@ -132,7 +132,7 @@ python run_cost_model.py \ ``` ## On Athena -Provide connection details to Aurora. +Provide connection details to Athena. ```angular2html python run_cost_model.py \ --database athena \ diff --git a/run_cost_model.py b/run_cost_model.py index 329327e3..f3dc35a1 100644 --- a/run_cost_model.py +++ b/run_cost_model.py @@ -37,7 +37,13 @@ def __call__(self, parser, namespace, values, option_string=None): def parse_queries_wrapper( - database, source, source_aurora, target, cap_queries, db_name, is_brad + database: DatabaseSystem, + source: str, + source_aurora: str, + target: str, + cap_queries: int, + db_name: str, + is_brad: bool, ): raw_plans = load_json(source) if source_aurora is None or not os.path.exists(source_aurora): @@ -163,6 +169,7 @@ def parse_queries_wrapper( parser.add_argument("--lower_bound_runtime", type=int, default=None) parser.add_argument("--gather_feature_statistics", action="store_true") parser.add_argument("--skip_train", action="store_true") + parser.add_argument("--eval_on_test", action="store_true") parser.add_argument("--save_best", action="store_true") parser.add_argument("--train_model", action="store_true") parser.add_argument("--is_query", action="store_true") @@ -396,6 +403,7 @@ def parse_queries_wrapper( skip_train=args.skip_train, loss_class_name=args.loss_class_name, save_best=args.save_best, + eval_on_test=args.eval_on_test, ) else: model = train_readout_hyperparams( @@ -420,6 +428,7 @@ def parse_queries_wrapper( skip_train=args.skip_train, loss_class_name=args.loss_class_name, save_best=args.save_best, + eval_on_test=args.eval_on_test, ) if args.infer_brad: diff --git a/src/brad/cost_model/dataset/dataset_creation.py b/src/brad/cost_model/dataset/dataset_creation.py index eb4bca91..09c8990f 100644 --- a/src/brad/cost_model/dataset/dataset_creation.py +++ b/src/brad/cost_model/dataset/dataset_creation.py @@ -114,6 +114,7 @@ def create_datasets( lower_bound_runtime=None, shuffle_before_split=True, loss_class_name=None, + eval_on_test=False, ): """ Creating dataset of query featurization. Set read_plans=True for plan datasets @@ -131,12 +132,16 @@ def create_datasets( no_plans = len(data) plan_idxs = list(range(no_plans)) - if shuffle_before_split: - np.random.shuffle(plan_idxs) - - train_ratio = 1 - val_ratio - split_train = int(no_plans * train_ratio) - train_idxs = plan_idxs[:split_train] + if eval_on_test: + # we don't need to create an evaluation dataset + train_idxs = plan_idxs + split_train = len(train_idxs) + else: + if shuffle_before_split: + np.random.shuffle(plan_idxs) + train_ratio = 1 - val_ratio + split_train = int(no_plans * train_ratio) + train_idxs = plan_idxs[:split_train] # Limit number of training samples. 
To have comparable batch sizes, replicate remaining indexes. if cap_training_samples is not None: prev_train_length = len(train_idxs) @@ -150,12 +155,13 @@ def create_datasets( train_dataset = QueryDataset([data[i] for i in train_idxs], train_idxs) val_dataset = None - if val_ratio > 0: - val_idxs = plan_idxs[split_train:] - if read_plans: - val_dataset = PlanDataset([data[i] for i in val_idxs], val_idxs) - else: - val_dataset = QueryDataset([data[i] for i in val_idxs], val_idxs) + if not eval_on_test: + if val_ratio > 0: + val_idxs = plan_idxs[split_train:] + if read_plans: + val_dataset = PlanDataset([data[i] for i in val_idxs], val_idxs) + else: + val_dataset = QueryDataset([data[i] for i in val_idxs], val_idxs) # derive label normalization runtimes = np.array([p.plan_runtime / 1000 for p in data]) @@ -199,6 +205,7 @@ def create_plan_dataloader( lower_bound_runtime=None, limit_runtime=None, loss_class_name=None, + eval_on_test=False, ): """ Creates dataloaders that batches physical plans to train the model in a distributed fashion. @@ -223,6 +230,7 @@ def create_plan_dataloader( limit_runtime=limit_runtime, lower_bound_num_tables=lower_bound_num_tables, lower_bound_runtime=lower_bound_runtime, + eval_on_test=eval_on_test, ) # postgres_plan_collator does the heavy lifting of creating the graphs and extracting the features and thus requires both @@ -244,7 +252,10 @@ def create_plan_dataloader( pin_memory=pin_memory, ) train_loader = DataLoader(train_dataset, **dataloader_args) - val_loader = DataLoader(val_dataset, **dataloader_args) + if val_dataset is not None: + val_loader = DataLoader(val_dataset, **dataloader_args) + else: + val_loader = None # for each test workoad run create a distinct test loader test_loaders = None @@ -269,6 +280,28 @@ def create_plan_dataloader( dataloader_args.update(collate_fn=test_collate_fn) test_loader = DataLoader(test_dataset, **dataloader_args) test_loaders.append(test_loader) + if eval_on_test: + _, val_dataset, _, val_database_statistics = create_datasets( + test_workload_run_paths, + True, + loss_class_name=loss_class_name, + val_ratio=0.0, + shuffle_before_split=False, + ) + val_collate_fn = functools.partial( + plan_collator, + db_statistics=val_database_statistics, + feature_statistics=feature_statistics, + plan_featurization_name=plan_featurization_name, + ) + dataloader_args = dict( + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=val_collate_fn, + pin_memory=pin_memory, + ) + val_loader = DataLoader(val_dataset, **dataloader_args) return label_norm, feature_statistics, train_loader, val_loader, test_loaders @@ -291,6 +324,7 @@ def create_query_dataloader( lower_bound_runtime=None, limit_runtime=None, loss_class_name=None, + eval_on_test=False, ): """ Creates dataloaders that batches query featurization to train the model in a distributed fashion. 
@@ -307,6 +341,7 @@ def create_query_dataloader( limit_runtime=limit_runtime, lower_bound_num_tables=lower_bound_num_tables, lower_bound_runtime=lower_bound_runtime, + eval_on_test=eval_on_test, ) # postgres_plan_collator does the heavy lifting of creating the graphs and extracting the features and thus requires both @@ -328,7 +363,10 @@ def create_query_dataloader( pin_memory=pin_memory, ) train_loader = DataLoader(train_dataset, **dataloader_args) - val_loader = DataLoader(val_dataset, **dataloader_args) + if val_dataset is not None: + val_loader = DataLoader(val_dataset, **dataloader_args) + else: + val_loader = None # for each test workoad run create a distinct test loader test_loaders = None @@ -355,6 +393,29 @@ def create_query_dataloader( test_loader = DataLoader(test_dataset, **dataloader_args) test_loaders.append(test_loader) + if eval_on_test: + _, val_dataset, _, val_database_statistics = create_datasets( + test_workload_run_paths, + False, + loss_class_name=loss_class_name, + val_ratio=0.0, + shuffle_before_split=False, + ) + val_collate_fn = functools.partial( + query_collator, + db_statistics=val_database_statistics, + feature_statistics=feature_statistics, + plan_featurization_name=query_featurization_name, + ) + dataloader_args = dict( + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=val_collate_fn, + pin_memory=pin_memory, + ) + val_loader = DataLoader(val_dataset, **dataloader_args) + return label_norm, feature_statistics, train_loader, val_loader, test_loaders @@ -377,6 +438,7 @@ def create_dataloader( limit_runtime=None, loss_class_name=None, is_query=True, + eval_on_test=False, ): if is_query: return create_query_dataloader( @@ -397,6 +459,7 @@ def create_dataloader( lower_bound_runtime, limit_runtime, loss_class_name, + eval_on_test, ) else: return create_plan_dataloader( @@ -417,6 +480,7 @@ def create_dataloader( lower_bound_runtime, limit_runtime, loss_class_name, + eval_on_test, ) diff --git a/src/brad/cost_model/training/train.py b/src/brad/cost_model/training/train.py index fbf325c0..e7d49a4e 100644 --- a/src/brad/cost_model/training/train.py +++ b/src/brad/cost_model/training/train.py @@ -212,6 +212,7 @@ def train_model( skip_train=False, seed=0, save_best=False, + eval_on_test=False, ): if model_kwargs is None: model_kwargs = dict() @@ -257,6 +258,7 @@ def train_model( lower_bound_num_tables=lower_bound_num_tables, lower_bound_runtime=lower_bound_runtime, loss_class_name=loss_class_name, + eval_on_test=eval_on_test, ) if loss_class_name == "QLoss": @@ -459,6 +461,7 @@ def train_default( max_no_epochs=None, skip_train=False, save_best=False, + eval_on_test=False, ): """ Sets default parameters and trains model @@ -518,6 +521,7 @@ def train_default( limit_queries_affected_wl=limit_queries_affected_wl, skip_train=skip_train, save_best=save_best, + eval_on_test=eval_on_test, ) param_dict = flatten_dict(train_kwargs) @@ -555,6 +559,7 @@ def train_readout_hyperparams( max_no_epochs=None, skip_train=False, save_best=False, + eval_on_test=False, ): """ Reads out hyperparameters and trains model @@ -623,6 +628,7 @@ def train_readout_hyperparams( lower_bound_runtime=lower_bound_runtime, skip_train=skip_train, save_best=save_best, + eval_on_test=eval_on_test, ) assert len(hyperparams) == 0, (
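Usage note (reviewer addition, not part of the patch): with the new `--eval_on_test` flag set, `create_datasets` keeps every parsed query/plan in the training split and the validation loader is instead built from the test workload runs, so the validation metric tracked during training reflects the test workload rather than a held-out slice of the training data. Without the flag, the existing `val_ratio` split behavior is unchanged. A minimal invocation sketch follows, in the style of the readme examples; only `--eval_on_test` is the flag introduced by this patch, `--test_workload_runs` is assumed from the `test_workload_run_paths` plumbing, and the file paths are illustrative placeholders.

```angular2html
python run_cost_model.py \
    --database redshift \
    --train_model \
    --is_query \
    --eval_on_test \
    --workload_runs path/to/parsed_train_workload.json \
    --test_workload_runs path/to/parsed_test_workload.json \
    --save_best
```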