Add support for data accessed predictions
geoffxy committed Nov 28, 2023
1 parent cd153fe commit c8b87fb
Showing 2 changed files with 91 additions and 41 deletions.
17 changes: 8 additions & 9 deletions src/brad/daemon/daemon.py
@@ -2,9 +2,10 @@
 import logging
 import queue
 import os
+import pathlib
 import multiprocessing as mp
 import numpy as np
-from typing import Optional, List, Set
+from typing import Optional, List, Set, Tuple
 
 from brad.asset_manager import AssetManager
 from brad.blueprint import Blueprint
@@ -163,18 +164,16 @@ async def _run_setup(self) -> None:
         if self._temp_config is not None:
             # TODO: Actually call into the models. We avoid doing so for now to
             # avoid having to implement model loading, etc.
-            std_dataset_path = self._temp_config.std_dataset_path()
             std_datasets = self._temp_config.std_datasets()
             if len(std_datasets) > 0:
+                datasets: List[Tuple[str, str | pathlib.Path]] = [
+                    (dataset["name"], dataset["path"]) for dataset in std_datasets
+                ]
                 latency_scorer: AnalyticsLatencyScorer = (
-                    PrecomputedPredictions.load_from_standard_dataset(
-                        [(dataset["name"], dataset["path"]) for dataset in std_datasets]
-                    )
+                    PrecomputedPredictions.load_from_standard_dataset(datasets)
                 )
-                data_access_provider = (
-                    PrecomputedDataAccessProvider.load_from_standard_dataset(
-                        dataset_path=std_dataset_path,
-                    )
+                data_access_provider: DataAccessProvider = (
+                    PrecomputedDataAccessProvider.load_from_standard_dataset(datasets)
                 )
             else:
                 latency_scorer = PrecomputedPredictions.load(
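For reference, the new `datasets` list assumes each entry returned by the temp config's `std_datasets()` is a mapping with "name" and "path" keys. A minimal sketch of the transformation, using hypothetical dataset names and paths:

    import pathlib
    from typing import List, Tuple

    # Hypothetical entries, mirroring the shape the daemon expects from
    # std_datasets(): one mapping per standard dataset.
    std_datasets = [
        {"name": "regular", "path": "predictions/regular/"},
        {"name": "adhoc", "path": pathlib.Path("predictions/adhoc/")},
    ]

    datasets: List[Tuple[str, str | pathlib.Path]] = [
        (dataset["name"], dataset["path"]) for dataset in std_datasets
    ]
    # datasets == [("regular", "predictions/regular/"),
    #              ("adhoc", PosixPath("predictions/adhoc/"))]

Both the latency scorer and the new data access provider now consume the same list, so one configuration block drives both sets of precomputed predictions.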
115 changes: 83 additions & 32 deletions src/brad/planner/scoring/data_access/precomputed_values.py
@@ -2,25 +2,21 @@
 import pathlib
 import numpy as np
 import numpy.typing as npt
-from typing import Dict
+from typing import Dict, Tuple, List
 
 from .provider import DataAccessProvider
 from brad.planner.workload import Workload
 
 logger = logging.getLogger(__name__)
 
 
-class PrecomputedDataAccessProvider(DataAccessProvider):
-    """
-    Provides predictions for a fixed workload using precomputed values.
-    Used for debugging purposes.
-    """
-
+class QueryMap:
     @classmethod
     def load_from_standard_dataset(
         cls,
+        name: str,
         dataset_path: str | pathlib.Path,
-    ):
+    ) -> "QueryMap":
         if isinstance(dataset_path, pathlib.Path):
             dsp = dataset_path
         else:
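In short: the hunk above retires the single-dataset `PrecomputedDataAccessProvider` loader in favor of a new `QueryMap` helper that carries a dataset name alongside its precomputed statistics. The provider itself is reintroduced below as a thin wrapper over a list of `QueryMap` instances.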
@@ -46,7 +42,60 @@ def load_from_standard_dataset(
         assert len(aurora.shape) == 1
         assert len(athena.shape) == 1
 
-        return cls(queries_map, aurora, athena)
+        return cls(name, queries_map, aurora, athena)
+
+    def __init__(
+        self,
+        name: str,
+        queries_map: Dict[str, int],
+        aurora_accessed_pages: npt.NDArray,
+        athena_accessed_bytes: npt.NDArray,
+    ) -> None:
+        self.name = name
+        self.queries_map = queries_map
+        self.aurora_accessed_pages = aurora_accessed_pages
+        self.athena_accessed_bytes = athena_accessed_bytes
+
+    def extract_access_statistics(
+        self, workload: Workload
+    ) -> Tuple[List[int], List[int], npt.NDArray, npt.NDArray]:
+        workload_query_index = []
+        indices_in_dataset = []
+        for wqi, query in enumerate(workload.analytical_queries()):
+            try:
+                query_str = query.raw_query.strip()
+                if query_str.endswith(";"):
+                    query_str = query_str[:-1]
+                indices_in_dataset.append(self.queries_map[query_str])
+                workload_query_index.append(wqi)
+            except KeyError:
+                continue
+
+        return (
+            workload_query_index,
+            indices_in_dataset,
+            self.aurora_accessed_pages[indices_in_dataset],
+            self.athena_accessed_bytes[indices_in_dataset],
+        )
+
+
+class PrecomputedDataAccessProvider(DataAccessProvider):
+    """
+    Provides predictions for a fixed workload using precomputed values.
+    Used for debugging purposes.
+    """
+
+    @classmethod
+    def load_from_standard_dataset(
+        cls,
+        datasets: List[Tuple[str, str | pathlib.Path]],
+    ) -> "PrecomputedDataAccessProvider":
+        return cls(
+            [
+                QueryMap.load_from_standard_dataset(name, dataset_path)
+                for name, dataset_path in datasets
+            ]
+        )
 
     @classmethod
     def load(
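Note that `QueryMap.extract_access_statistics` skips queries it cannot match rather than failing: it returns parallel lists of workload positions and dataset row indices, so each dataset claims only the queries it recognizes. The lookup key is the stripped query text with any trailing semicolon removed. A small self-contained illustration of that normalization (the query text and index are hypothetical):

    # Hypothetical precomputed map: normalized query text -> row index.
    queries_map = {"SELECT COUNT(*) FROM title": 0}

    raw_query = "  SELECT COUNT(*) FROM title;  "
    query_str = raw_query.strip()
    if query_str.endswith(";"):
        query_str = query_str[:-1]  # drop the trailing semicolon

    assert queries_map[query_str] == 0  # normalized text matches row 0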
Expand All @@ -72,35 +121,37 @@ def load(
assert len(athena.shape) == 1
assert aurora.shape[0] == athena.shape[0]

return cls(queries_map, aurora, athena)
return cls([QueryMap("custom", queries_map, aurora, athena)])

def __init__(
self,
queries_map: Dict[str, int],
aurora_accessed_pages: npt.NDArray,
athena_accessed_bytes: npt.NDArray,
predictions: List[QueryMap],
) -> None:
self._queries_map = queries_map
self._aurora_accessed_pages = aurora_accessed_pages
self._athena_accessed_bytes = athena_accessed_bytes
self._predictions = predictions

def apply_access_statistics(self, workload: Workload) -> None:
query_indices = []
has_unmatched = False
for query in workload.analytical_queries():
try:
query_str = query.raw_query.strip()
if query_str.endswith(";"):
query_str = query_str[:-1]
query_indices.append(self._queries_map[query_str])
except KeyError:
logger.warning("Cannot match query:\n%s", query.raw_query.strip())
query_indices.append(-1)
has_unmatched = True
all_queries = workload.analytical_queries()
applied_aurora = np.ones(len(all_queries)) * np.nan
applied_athena = np.ones(len(all_queries)) * np.nan

for qm in self._predictions:
workload_indices, _, aurora, athena = qm.extract_access_statistics(workload)
applied_aurora[workload_indices] = aurora
applied_athena[workload_indices] = athena

# Special case: vector similarity queries.
special_vector_queries = []
for wqi, q in enumerate(all_queries):
if "<=>" in q.raw_query:
special_vector_queries.append(wqi)
applied_athena[special_vector_queries] = 0.0

# Check for unmatched queries.
num_unmatched_athena = np.isnan(applied_athena).sum()
if num_unmatched_athena > 0:
raise RuntimeError("Unmatched queries: " + num_unmatched_athena)

if has_unmatched:
raise RuntimeError("Workload contains unmatched queries.")
workload.set_predicted_data_access_statistics(
aurora_pages=self._aurora_accessed_pages[query_indices],
athena_bytes=self._athena_accessed_bytes[query_indices],
aurora_pages=applied_aurora,
athena_bytes=applied_athena,
)
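The rewritten `apply_access_statistics` merges predictions from multiple `QueryMap`s by starting from NaN-filled arrays and letting each dataset overwrite the positions it matched; any slot still NaN afterwards (other than a vector similarity query, recognized by the `<=>` operator and assigned zero Athena bytes) is unmatched and raises an error. A standalone sketch of the merge strategy, with hypothetical indices and values:

    import numpy as np

    num_queries = 4
    applied = np.ones(num_queries) * np.nan  # NaN marks "no prediction yet"

    # Each (indices, values) pair stands in for one QueryMap's matches.
    for indices, values in [([0, 2], [10.0, 20.0]), ([1], [5.0])]:
        applied[indices] = values

    print(applied)                  # [10.  5. 20. nan]
    print(np.isnan(applied).sum())  # 1 -> query 3 has no prediction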
