Skip to content

Commit

Permalink
Functionality catalog (#328)
Browse files Browse the repository at this point in the history
* Geospatial queries

* add geospatial queries

* add homes table to imdb_extended.yml

* fix formatting

* Geoff's comments

* add geospatial query that touches ticket_orders

* minor

* Functionality catalog, initial

* Add transactions to functionality catalog

* move keyword crawling script to tools

* minor merge main

* add session to engine_for tree_based router

* determine is query transaction in query rep, not routers

* minor fixes

* fixing tests

* fixing tests

* session in routing not query rep

* checks
  • Loading branch information
ferdiko authored Nov 10, 2023
1 parent cf28b57 commit e7732c2
Show file tree
Hide file tree
Showing 10 changed files with 534 additions and 44 deletions.
17 changes: 7 additions & 10 deletions src/brad/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,16 +304,10 @@ async def _run_query_impl(

# 2. Select an engine for the query.
query_rep = QueryRep(query)
transactional_query = (
session.in_transaction or query_rep.is_data_modification_query()
)
if transactional_query:
engine_to_use = Engine.Aurora
elif self._route_redshift_only:
engine_to_use = Engine.Redshift
else:
assert self._router is not None
engine_to_use = await self._router.engine_for(query_rep)
if query_rep.is_transaction_start():
session.set_in_transaction(True)
assert self._router is not None
engine_to_use = await self._router.engine_for(query_rep, session)

log_verbose(
logger,
Expand All @@ -326,6 +320,9 @@ async def _run_query_impl(

# 3. Actually execute the query.
try:
transactional_query: bool = (
session.in_transaction or query_rep.is_data_modification_query()
)
if transactional_query:
connection = session.engines.get_connection(engine_to_use)
cursor = connection.cursor_sync()
Expand Down
5 changes: 2 additions & 3 deletions src/brad/front_end/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import pytz
from datetime import datetime
from typing import Dict, Tuple, Optional

from brad.config.engine import Engine
from brad.blueprint.manager import BlueprintManager
from brad.config.file import ConfigFile
from brad.config.session import SessionId
from .engine_connections import EngineConnections
from brad.blueprint.manager import BlueprintManager

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,7 +58,7 @@ async def close(self):

class SessionManager:
def __init__(
self, config: ConfigFile, blueprint_mgr: BlueprintManager, schema_name: str
self, config: ConfigFile, blueprint_mgr: "BlueprintManager", schema_name: str
) -> None:
self._config = config
self._blueprint_mgr = blueprint_mgr
Expand Down
1 change: 0 additions & 1 deletion src/brad/planner/workload/query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
from collections import Counter
from typing import Dict, List, Tuple, Optional

from brad.blueprint import Blueprint
from brad.config.engine import Engine
from brad.query_rep import QueryRep
Expand Down
26 changes: 25 additions & 1 deletion src/brad/query_rep.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import sqlglot
import sqlglot.expressions as exp

import yaml
from importlib.resources import files, as_file
import brad.routing as routing
from brad.routing.functionality_catalog import Functionality
from typing import List, Optional

_DATA_MODIFICATION_PREFIXES = [
Expand All @@ -17,6 +20,13 @@
"TRUNCATE",
]

# Geospatial keyword list, read once at import time from the YAML resource
# bundled with the routing package. Entries are upper-cased so they can be
# compared against an upper-cased query string.
_GEOSPATIAL_KEYWORDS_PATH = files(routing).joinpath("geospatial_keywords.yml")
with as_file(_GEOSPATIAL_KEYWORDS_PATH) as file, open(
    file, "r", encoding="utf8"
) as f:
    _GEOSPATIAL_KEYWORDS = [k.upper() for k in yaml.safe_load(f)]


class QueryRep:
"""
Expand Down Expand Up @@ -63,6 +73,20 @@ def is_transaction_end(self) -> bool:
raw_sql = self._raw_sql_query.upper()
return raw_sql == "COMMIT" or raw_sql == "ROLLBACK"

def is_geospatial(self) -> bool:
    """
    Return True when the upper-cased raw SQL text contains at least one
    entry of `_GEOSPATIAL_KEYWORDS` as a substring, False otherwise.
    """
    upper_sql = self._raw_sql_query.upper()
    return any(kw in upper_sql for kw in _GEOSPATIAL_KEYWORDS)

def get_required_functionality(self) -> int:
    """
    Return the bitmap (as produced by `Functionality.to_bitmap`) of the
    functionalities this query needs. Only geospatial support is detected
    here.
    """
    needed: List[str] = (
        [Functionality.Geospatial] if self.is_geospatial() else []
    )
    return Functionality.to_bitmap(needed)

def tables(self) -> List[str]:
if self._tables is None:
if self._ast is None:
Expand Down
12 changes: 12 additions & 0 deletions src/brad/routing/engine_functionality.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
database_engines:
- name: Aurora
functionalities:
- geospatial
- transactions

- name: Athena
functionalities:
- geospatial

- name: Redshift
functionalities: []
66 changes: 66 additions & 0 deletions src/brad/routing/functionality_catalog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import List
import operator
import yaml
from functools import reduce
from typing import Dict
from importlib.resources import files, as_file
import brad.routing as routing


class Functionality:
    """
    Catalog of the special functionalities each database engine supports.

    Constructing an instance reads `engine_functionality.yml` from the routing
    package and converts each engine's list of functionality names into an
    integer bitmap (see `to_bitmap`).
    """

    # Functionality names, spelled exactly as in engine_functionality.yml.
    Geospatial = "geospatial"
    Transaction = "transactions"

    def __init__(self):
        """Load the YAML catalog and build one bitmap per engine."""
        functionality_yaml = files(routing).joinpath("engine_functionality.yml")
        with as_file(functionality_yaml) as file:
            with open(file, "r", encoding="utf8") as yaml_file:
                # The catalog is plain data (maps, lists, strings), so the
                # restricted safe loader is sufficient; this also matches how
                # the geospatial keyword file is loaded in query_rep.py.
                data = yaml.safe_load(yaml_file)

        # Map engine name -> declared functionality names. If an engine is
        # listed more than once, the last entry wins.
        declared: Dict[str, List[str]] = {}
        for engine in data["database_engines"]:
            # A bare `functionalities:` key parses as None; treat it (and a
            # missing key) as "supports nothing special".
            declared[engine["name"]] = engine.get("functionalities") or []

        # Bitmaps are stored in the fixed order Athena, Aurora, Redshift.
        # Engines missing from the file get an empty bitmap; engine names
        # other than these three are ignored.
        # NOTE(review): consumers presumably rely on this ordering - confirm
        # against the router that reads `get_engine_functionalities()`.
        self.engine_functionalities = [
            Functionality.to_bitmap(declared.get(name, []))
            for name in ("Athena", "Aurora", "Redshift")
        ]

    @staticmethod
    def to_bitmap(functionalities: List[str]) -> int:
        """
        Fold functionality names into a single integer bitmap.

        Each name is looked up in `FunctionalityBitmapValues` and the values
        are combined with bitwise OR. An empty input yields 0. Any iterable of
        names is accepted.

        Raises:
            KeyError: if a name is not present in `FunctionalityBitmapValues`.
        """
        return reduce(
            # Bitwise OR
            operator.or_,
            (FunctionalityBitmapValues[f] for f in functionalities),
            0,
        )

    def get_engine_functionalities(self) -> List[int]:
        """
        Return a bitmap for each engine that states what functionalities the
        engine supports
        """
        return self.engine_functionalities


# One distinct bit per functionality name. Defined after the class because the
# keys are the class-level name constants.
FunctionalityBitmapValues: Dict[str, int] = {
    Functionality.Geospatial: 0b01,
    Functionality.Transaction: 0b10,
}
Loading

0 comments on commit e7732c2

Please sign in to comment.