-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfeaturize.py
228 lines (175 loc) · 6.51 KB
/
featurize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import numpy as np
JOIN_TYPES = ["Nested Loop", "Hash Join", "Merge Join"]
LEAF_TYPES = ["Seq Scan", "Index Scan", "Index Only Scan", "Bitmap Index Scan"]
ALL_TYPES = JOIN_TYPES + LEAF_TYPES
class TreeBuilderError(Exception):
def __init__(self, msg):
self.__msg = msg
def is_join(node):
return node["Node Type"] in JOIN_TYPES
def is_scan(node):
return node["Node Type"] in LEAF_TYPES
class TreeBuilder:
def __init__(self, stats_extractor, relations):
self.__stats = stats_extractor
self.__relations = sorted(relations, key=lambda x: len(x), reverse=True)
def get_stats(self):
return self.__stats
def __relation_name(self, node):
if "Relation Name" in node:
return node["Relation Name"]
if node["Node Type"] == "Bitmap Index Scan":
# find the first (longest) relation name that appears in the index name
name_key = "Index Name" if "Index Name" in node else "Relation Name"
if name_key not in node:
print(node)
raise TreeBuilderError("Bitmap operator did not have an index name or a relation name")
for rel in self.__relations:
if rel in node[name_key]:
return rel
raise TreeBuilderError("Could not find relation name for bitmap index scan")
raise TreeBuilderError("Cannot extract relation type from node")
# 连接结点特征编码成vector
def __featurize_join(self, node):
assert is_join(node)
arr = np.zeros(len(ALL_TYPES))
arr[ALL_TYPES.index(node["Node Type"])] = 1
# self.__stats(node))表示buffer、cost、row信息
return np.concatenate((arr, self.__stats(node)))
# 叶子结点特征编码成vector
def __featurize_scan(self, node):
assert is_scan(node)
arr = np.zeros(len(ALL_TYPES))
arr[ALL_TYPES.index(node["Node Type"])] = 1
return (np.concatenate((arr, self.__stats(node))),
self.__relation_name(node))
# 通过递归遍历查询计划的树状结构,将每个节点转换为特征树的节点
def plan_to_feature_tree(self, plan):
children = plan["Plans"] if "Plans" in plan else []
if len(children) == 1:
return self.plan_to_feature_tree(children[0])
if is_join(plan):
assert len(children) == 2
my_vec = self.__featurize_join(plan)
left = self.plan_to_feature_tree(children[0])
right = self.plan_to_feature_tree(children[1])
return (my_vec, left, right)
if is_scan(plan):
assert not children
return self.__featurize_scan(plan)
raise TreeBuilderError("Node wasn't transparent, a join, or a scan: " + str(plan))
def norm(x, lo, hi):
return (np.log(x + 1) - lo) / (hi - lo)
def get_buffer_count_for_leaf(leaf, buffers):
total = 0
if "Relation Name" in leaf:
total += buffers.get(leaf["Relation Name"], 0)
if "Index Name" in leaf:
total += buffers.get(leaf["Index Name"], 0)
return total
class StatExtractor:
def __init__(self, fields, mins, maxs):
self.__fields = fields
self.__mins = mins
self.__maxs = maxs
def __call__(self, inp):
res = []
for f, lo, hi in zip(self.__fields, self.__mins, self.__maxs):
if f not in inp:
res.append(0)
else:
res.append(norm(inp[f], lo, hi))
return res
def get_min_max_values(self):
min_max_values = {}
for field, min_val, max_val in zip(self.__fields, self.__mins, self.__maxs):
min_max_values[field] = (min_val, max_val)
return min_max_values
def get_plan_stats(data):
costs = []
rows = []
bufs = []
def recurse(n, buffers=None):
costs.append(n["Total Cost"])
rows.append(n["Plan Rows"])
if "Buffers" in n:
bufs.append(n["Buffers"])
if "Plans" in n:
for child in n["Plans"]:
recurse(child)
for plan in data:
recurse(plan["Plan"], buffers=plan.get("Buffers", None))
costs = np.array(costs)
rows = np.array(rows)
bufs = np.array(bufs)
costs = np.log(costs + 1)
rows = np.log(rows + 1)
bufs = np.log(bufs + 1)
costs_min = np.min(costs)
costs_max = np.max(costs)
rows_min = np.min(rows)
rows_max = np.max(rows)
bufs_min = np.min(bufs) if len(bufs) != 0 else 0
bufs_max = np.max(bufs) if len(bufs) != 0 else 0
if len(bufs) != 0:
return StatExtractor(
["Buffers", "Total Cost", "Plan Rows"],
[bufs_min, costs_min, rows_min],
[bufs_max, costs_max, rows_max]
)
else:
return StatExtractor(
["Total Cost", "Plan Rows"],
[costs_min, rows_min],
[costs_max, rows_max]
)
def get_all_relations(data):
all_rels = []
def recurse(plan):
if "Relation Name" in plan:
yield plan["Relation Name"]
if "Plans" in plan:
for child in plan["Plans"]:
yield from recurse(child)
for plan in data:
all_rels.extend(list(recurse(plan["Plan"])))
return set(all_rels)
def get_featurized_trees(data):
all_rels = get_all_relations(data)
stats_extractor = get_plan_stats(data)
t = TreeBuilder(stats_extractor, all_rels)
trees = []
for plan in data:
tree = t.plan_to_feature_tree(plan)
trees.append(tree)
return trees
# 计算每个节点的buffer数量
def _attach_buf_data(tree):
if "Buffers" not in tree:
return
buffers = tree["Buffers"]
def recurse(n):
if "Plans" in n:
for child in n["Plans"]:
recurse(child)
return
# it is a leaf
n["Buffers"] = get_buffer_count_for_leaf(n, buffers)
recurse(tree["Plan"])
class TreeFeaturizer:
def __init__(self):
self.__tree_builder = None
def get_tree_builder(self):
return self.__tree_builder
def fit(self, trees):
for t in trees:
_attach_buf_data(t)
all_rels = get_all_relations(trees)
stats_extractor = get_plan_stats(trees)
self.__tree_builder = TreeBuilder(stats_extractor, all_rels)
def transform(self, trees):
for t in trees:
_attach_buf_data(t)
return [self.__tree_builder.plan_to_feature_tree(x["Plan"]) for x in trees]
def num_operators(self):
return len(ALL_TYPES)