
adding small features to cost model
wuziniu committed Nov 2, 2023
1 parent 1b048f4 commit 2560ca2
Showing 4 changed files with 96 additions and 17 deletions.
4 changes: 2 additions & 2 deletions readme_cost_model.md
@@ -113,7 +113,7 @@ python run_cost_model.py --augment_dataset --workload_runs ../data/imdb/parsed_q


## On Redshift
Provide connection details to Aurora.
Provide connection details to Redshift.
```angular2html
python run_cost_model.py \
--database redshift \
@@ -132,7 +132,7 @@ python run_cost_model.py \
```

## On Athena
Provide connection details to Aurora.
Provide connection details to Athena.
```angular2html
python run_cost_model.py \
--database athena \
11 changes: 10 additions & 1 deletion run_cost_model.py
@@ -37,7 +37,13 @@ def __call__(self, parser, namespace, values, option_string=None):


def parse_queries_wrapper(
database, source, source_aurora, target, cap_queries, db_name, is_brad
database: DatabaseSystem,
source: str,
source_aurora: str,
target: str,
cap_queries: int,
db_name: str,
is_brad: bool,
):
raw_plans = load_json(source)
if source_aurora is None or not os.path.exists(source_aurora):
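The wrapper's arguments are now explicitly typed. As a purely illustrative sketch of a call against the new signature (the enum member, file paths, and database name below are placeholders, not values from this commit):

```python
# Hypothetical invocation; assumes DatabaseSystem exposes an AURORA member and
# that these JSON paths exist. Illustrative only, not part of the commit.
parse_queries_wrapper(
    database=DatabaseSystem.AURORA,
    source="raw_plans.json",
    source_aurora="raw_plans_aurora.json",
    target="parsed_plans.json",
    cap_queries=5000,
    db_name="imdb",
    is_brad=False,
)
```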
@@ -163,6 +169,7 @@ def parse_queries_wrapper(
parser.add_argument("--lower_bound_runtime", type=int, default=None)
parser.add_argument("--gather_feature_statistics", action="store_true")
parser.add_argument("--skip_train", action="store_true")
parser.add_argument("--eval_on_test", action="store_true")
parser.add_argument("--save_best", action="store_true")
parser.add_argument("--train_model", action="store_true")
parser.add_argument("--is_query", action="store_true")
@@ -396,6 +403,7 @@ def parse_queries_wrapper(
skip_train=args.skip_train,
loss_class_name=args.loss_class_name,
save_best=args.save_best,
eval_on_test=args.eval_on_test,
)
else:
model = train_readout_hyperparams(
@@ -420,6 +428,7 @@
skip_train=args.skip_train,
loss_class_name=args.loss_class_name,
save_best=args.save_best,
eval_on_test=args.eval_on_test,
)

if args.infer_brad:
92 changes: 78 additions & 14 deletions src/brad/cost_model/dataset/dataset_creation.py
@@ -114,6 +114,7 @@ def create_datasets(
lower_bound_runtime=None,
shuffle_before_split=True,
loss_class_name=None,
eval_on_test=False,
):
"""
Creating dataset of query featurization. Set read_plans=True for plan datasets
@@ -131,12 +132,16 @@

no_plans = len(data)
plan_idxs = list(range(no_plans))
if shuffle_before_split:
np.random.shuffle(plan_idxs)

train_ratio = 1 - val_ratio
split_train = int(no_plans * train_ratio)
train_idxs = plan_idxs[:split_train]
if eval_on_test:
# we don't need to create an evaluation dataset
train_idxs = plan_idxs
split_train = len(train_idxs)
else:
if shuffle_before_split:
np.random.shuffle(plan_idxs)
train_ratio = 1 - val_ratio
split_train = int(no_plans * train_ratio)
train_idxs = plan_idxs[:split_train]
# Limit number of training samples. To have comparable batch sizes, replicate remaining indexes.
if cap_training_samples is not None:
prev_train_length = len(train_idxs)
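Conceptually, setting eval_on_test keeps every plan in the training split and skips carving out a validation slice, while the default path still shuffles and splits by val_ratio. A standalone sketch of that index logic (the function name and the toy numbers are illustrative, not from the source):

```python
import numpy as np

def split_indices(no_plans, val_ratio, eval_on_test, shuffle_before_split=True):
    # Illustrative restatement of the splitting logic above, not the library code.
    plan_idxs = list(range(no_plans))
    if eval_on_test:
        # Everything trains; validation will come from the test workloads instead.
        return plan_idxs, len(plan_idxs)
    if shuffle_before_split:
        np.random.shuffle(plan_idxs)
    split_train = int(no_plans * (1 - val_ratio))
    return plan_idxs[:split_train], split_train

train_idxs, split = split_indices(10, 0.15, eval_on_test=False)  # 8 of 10 plans train
train_idxs, split = split_indices(10, 0.15, eval_on_test=True)   # all 10 plans train
```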
@@ -150,12 +155,13 @@
train_dataset = QueryDataset([data[i] for i in train_idxs], train_idxs)

val_dataset = None
if val_ratio > 0:
val_idxs = plan_idxs[split_train:]
if read_plans:
val_dataset = PlanDataset([data[i] for i in val_idxs], val_idxs)
else:
val_dataset = QueryDataset([data[i] for i in val_idxs], val_idxs)
if not eval_on_test:
if val_ratio > 0:
val_idxs = plan_idxs[split_train:]
if read_plans:
val_dataset = PlanDataset([data[i] for i in val_idxs], val_idxs)
else:
val_dataset = QueryDataset([data[i] for i in val_idxs], val_idxs)

# derive label normalization
runtimes = np.array([p.plan_runtime / 1000 for p in data])
@@ -199,6 +205,7 @@ def create_plan_dataloader(
lower_bound_runtime=None,
limit_runtime=None,
loss_class_name=None,
eval_on_test=False,
):
"""
Creates dataloaders that batches physical plans to train the model in a distributed fashion.
@@ -223,6 +230,7 @@
limit_runtime=limit_runtime,
lower_bound_num_tables=lower_bound_num_tables,
lower_bound_runtime=lower_bound_runtime,
eval_on_test=eval_on_test,
)

# postgres_plan_collator does the heavy lifting of creating the graphs and extracting the features and thus requires both
@@ -244,7 +252,10 @@
pin_memory=pin_memory,
)
train_loader = DataLoader(train_dataset, **dataloader_args)
val_loader = DataLoader(val_dataset, **dataloader_args)
if val_dataset is not None:
val_loader = DataLoader(val_dataset, **dataloader_args)
else:
val_loader = None

# for each test workoad run create a distinct test loader
test_loaders = None
@@ -269,6 +280,28 @@
dataloader_args.update(collate_fn=test_collate_fn)
test_loader = DataLoader(test_dataset, **dataloader_args)
test_loaders.append(test_loader)
if eval_on_test:
_, val_dataset, _, val_database_statistics = create_datasets(
test_workload_run_paths,
True,
loss_class_name=loss_class_name,
val_ratio=0.0,
shuffle_before_split=False,
)
val_collate_fn = functools.partial(
plan_collator,
db_statistics=val_database_statistics,
feature_statistics=feature_statistics,
plan_featurization_name=plan_featurization_name,
)
dataloader_args = dict(
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
collate_fn=val_collate_fn,
pin_memory=pin_memory,
)
val_loader = DataLoader(val_dataset, **dataloader_args)

return label_norm, feature_statistics, train_loader, val_loader, test_loaders
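The net effect on the loaders: when no validation slice was split off, val_loader starts as None, and with eval_on_test set it is instead built from the test workload runs (val_ratio=0.0, no shuffling) using the same collator and batching arguments. A condensed, hypothetical sketch of that control flow, with dataset and loader construction stubbed out:

```python
# Hypothetical condensation; make_dataset/make_loader stand in for the
# create_datasets(...) and DataLoader(...) calls shown above.
def build_val_loader(val_dataset, test_workload_run_paths, eval_on_test,
                     make_dataset, make_loader):
    val_loader = make_loader(val_dataset) if val_dataset is not None else None
    if eval_on_test and test_workload_run_paths:
        # Validation examples come from the test workloads rather than from a
        # held-out slice of the training plans.
        val_loader = make_loader(make_dataset(test_workload_run_paths))
    return val_loader
```

create_query_dataloader below mirrors the same pattern with the query collator.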

@@ -291,6 +324,7 @@ def create_query_dataloader(
lower_bound_runtime=None,
limit_runtime=None,
loss_class_name=None,
eval_on_test=False,
):
"""
Creates dataloaders that batches query featurization to train the model in a distributed fashion.
@@ -307,6 +341,7 @@
limit_runtime=limit_runtime,
lower_bound_num_tables=lower_bound_num_tables,
lower_bound_runtime=lower_bound_runtime,
eval_on_test=eval_on_test,
)

# postgres_plan_collator does the heavy lifting of creating the graphs and extracting the features and thus requires both
@@ -328,7 +363,10 @@
pin_memory=pin_memory,
)
train_loader = DataLoader(train_dataset, **dataloader_args)
val_loader = DataLoader(val_dataset, **dataloader_args)
if val_dataset is not None:
val_loader = DataLoader(val_dataset, **dataloader_args)
else:
val_loader = None

# for each test workoad run create a distinct test loader
test_loaders = None
@@ -355,6 +393,29 @@
test_loader = DataLoader(test_dataset, **dataloader_args)
test_loaders.append(test_loader)

if eval_on_test:
_, val_dataset, _, val_database_statistics = create_datasets(
test_workload_run_paths,
False,
loss_class_name=loss_class_name,
val_ratio=0.0,
shuffle_before_split=False,
)
val_collate_fn = functools.partial(
query_collator,
db_statistics=val_database_statistics,
feature_statistics=feature_statistics,
plan_featurization_name=query_featurization_name,
)
dataloader_args = dict(
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
collate_fn=val_collate_fn,
pin_memory=pin_memory,
)
val_loader = DataLoader(val_dataset, **dataloader_args)

return label_norm, feature_statistics, train_loader, val_loader, test_loaders


@@ -377,6 +438,7 @@
limit_runtime=None,
loss_class_name=None,
is_query=True,
eval_on_test=False,
):
if is_query:
return create_query_dataloader(
@@ -397,6 +459,7 @@
lower_bound_runtime,
limit_runtime,
loss_class_name,
eval_on_test,
)
else:
return create_plan_dataloader(
Expand All @@ -417,6 +480,7 @@ def create_dataloader(
lower_bound_runtime,
limit_runtime,
loss_class_name,
eval_on_test,
)


6 changes: 6 additions & 0 deletions src/brad/cost_model/training/train.py
@@ -212,6 +212,7 @@ def train_model(
skip_train=False,
seed=0,
save_best=False,
eval_on_test=False,
):
if model_kwargs is None:
model_kwargs = dict()
@@ -257,6 +258,7 @@
lower_bound_num_tables=lower_bound_num_tables,
lower_bound_runtime=lower_bound_runtime,
loss_class_name=loss_class_name,
eval_on_test=eval_on_test,
)

if loss_class_name == "QLoss":
@@ -459,6 +461,7 @@ def train_default(
max_no_epochs=None,
skip_train=False,
save_best=False,
eval_on_test=False,
):
"""
Sets default parameters and trains model
@@ -518,6 +521,7 @@
limit_queries_affected_wl=limit_queries_affected_wl,
skip_train=skip_train,
save_best=save_best,
eval_on_test=eval_on_test,
)
param_dict = flatten_dict(train_kwargs)
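The training entry points just thread the new keyword through to dataloader creation, so one switch decides where validation data comes from. A toy illustration of that pass-through (stub function and placeholder values, not the real create_dataloader):

```python
# Toy stand-in for create_dataloader(); only the new keyword matters here.
def create_dataloader_stub(eval_on_test=False, **_other_kwargs):
    return "validate on test workloads" if eval_on_test else "validate on held-out split"

train_kwargs = dict(skip_train=False, save_best=True, eval_on_test=True)
print(create_dataloader_stub(**train_kwargs))  # -> "validate on test workloads"
```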

@@ -555,6 +559,7 @@ def train_readout_hyperparams(
max_no_epochs=None,
skip_train=False,
save_best=False,
eval_on_test=False,
):
"""
Reads out hyperparameters and trains model
@@ -623,6 +628,7 @@
lower_bound_runtime=lower_bound_runtime,
skip_train=skip_train,
save_best=save_best,
eval_on_test=eval_on_test,
)

assert len(hyperparams) == 0, (
