Skip to content

Commit

Permalink
Create an estimator per session to support multiple analytical client…
Browse files Browse the repository at this point in the history
…s per front end
  • Loading branch information
geoffxy committed Nov 14, 2023
1 parent a4aa1d5 commit 77c2f29
Show file tree
Hide file tree
Showing 17 changed files with 114 additions and 57 deletions.
20 changes: 2 additions & 18 deletions src/brad/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def __init__(
self._config, self._blueprint_mgr, self._schema_name
)
self._daemon_messages_task: Optional[asyncio.Task[None]] = None
self._estimator: Optional[Estimator] = None

# Number of transactions that completed.
self._transaction_end_counter = Counter() # pylint: disable=global-statement
Expand Down Expand Up @@ -173,17 +172,7 @@ async def _run_setup(self) -> None:
self._monitor.set_up_metrics_sources()
await self._monitor.fetch_latest()

if (
self._routing_policy_override == RoutingPolicy.ForestTableSelectivity
or self._routing_policy_override == RoutingPolicy.Default
):
self._estimator = await PostgresEstimator.connect(
self._schema_name, self._config
)
await self._estimator.analyze(self._blueprint_mgr.get_blueprint())
else:
self._estimator = None
await self._set_up_router(self._estimator)
await self._set_up_router()

# Start the metrics reporting task.
self._brad_metrics_reporting_task = asyncio.create_task(
Expand All @@ -195,7 +184,7 @@ async def _run_setup(self) -> None:

self._qlogger_refresh_task = asyncio.create_task(self._refresh_qlogger())

async def _set_up_router(self, estimator: Optional[Estimator]) -> None:
async def _set_up_router(self) -> None:
# We have different routing policies for performance evaluation and
# testing purposes.
blueprint = self._blueprint_mgr.get_blueprint()
Expand Down Expand Up @@ -236,7 +225,6 @@ async def _set_up_router(self, estimator: Optional[Estimator]) -> None:
definite_policy, blueprint.table_locations_bitmap()
)

await self._router.run_setup(estimator)
self._router.log_policy()

async def _run_teardown(self):
Expand All @@ -258,10 +246,6 @@ async def _run_teardown(self):
self._qlogger_refresh_task.cancel()
self._qlogger_refresh_task = None

if self._estimator is not None:
await self._estimator.close()
self._estimator = None

async def start_session(self) -> SessionId:
rand_backoff = None
while True:
Expand Down
34 changes: 31 additions & 3 deletions src/brad/front_end/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
import pytz
from datetime import datetime
from typing import Dict, Tuple, Optional

from brad.config.engine import Engine
from brad.config.file import ConfigFile
from brad.config.session import SessionId
from .engine_connections import EngineConnections
from brad.blueprint.manager import BlueprintManager
from brad.front_end.engine_connections import EngineConnections
from brad.planner.estimator import Estimator
from brad.routing.policy import RoutingPolicy
from brad.data_stats.postgres_estimator import PostgresEstimator

logger = logging.getLogger(__name__)

Expand All @@ -19,12 +23,18 @@ class Session:
`SessionManager`.
"""

def __init__(self, session_id: SessionId, engines: EngineConnections):
def __init__(
self,
session_id: SessionId,
engines: EngineConnections,
estimator: Optional[Estimator],
):
self._session_id = session_id
self._engines = engines
self._in_txn = False
self._closed = False
self._txn_start_timestamp = datetime.now(tz=pytz.utc)
self._estimator = estimator

@property
def identifier(self) -> SessionId:
Expand All @@ -38,6 +48,10 @@ def engines(self) -> EngineConnections:
def in_transaction(self) -> bool:
return self._in_txn

@property
def estimator(self) -> Optional[Estimator]:
return self._estimator

@property
def closed(self) -> bool:
return self._closed
Expand All @@ -54,6 +68,7 @@ def txn_start_timestamp(self) -> datetime:
async def close(self):
self._closed = True
await self._engines.close()
await self._estimator.close()


class SessionManager:
Expand Down Expand Up @@ -91,7 +106,20 @@ async def create_new_session(self) -> Tuple[SessionId, Session]:
specific_engines=engines,
connect_to_aurora_read_replicas=True,
)
session = Session(session_id, connections)

# Create an estimator if needed. The estimator should be
# session-specific since it currently depends on a DB connection.
routing_policy_override = self._config.routing_policy
if (
routing_policy_override == RoutingPolicy.ForestTableSelectivity
or routing_policy_override == RoutingPolicy.Default
):
estimator = await PostgresEstimator.connect(self._schema_name, self._config)
await estimator.analyze(self._blueprint_mgr.get_blueprint())
else:
estimator = None

session = Session(session_id, connections, estimator)
self._sessions[session_id] = session
logger.debug("Established a new session: %s", session_id)
return (session_id, session)
Expand Down
2 changes: 1 addition & 1 deletion src/brad/planner/beam/query_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def _run_replan_impl(
self._planner_config,
)
planning_router = Router.create_from_blueprint(self._current_blueprint)
await planning_router.run_setup(
await planning_router.run_setup_for_standalone(
self._providers.estimator_provider.get_estimator()
)
await ctx.simulate_current_workload_routing(planning_router)
Expand Down
2 changes: 1 addition & 1 deletion src/brad/planner/beam/query_based_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def _run_replan_impl(
self._planner_config,
)
planning_router = Router.create_from_blueprint(self._current_blueprint)
await planning_router.run_setup(
await planning_router.run_setup_for_standalone(
self._providers.estimator_provider.get_estimator()
)
await ctx.simulate_current_workload_routing(planning_router)
Expand Down
2 changes: 1 addition & 1 deletion src/brad/planner/beam/table_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def _run_replan_impl(
self._planner_config,
)
planning_router = Router.create_from_blueprint(self._current_blueprint)
await planning_router.run_setup(
await planning_router.run_setup_for_standalone(
self._providers.estimator_provider.get_estimator()
)
await ctx.simulate_current_workload_routing(planning_router)
Expand Down
4 changes: 3 additions & 1 deletion src/brad/planner/neighborhood/neighborhood.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from brad.planner.strategy import PlanningStrategy
from brad.planner.workload import Workload
from brad.provisioning.directory import Directory
from brad.routing.context import RoutingContext
from brad.routing.rule_based import RuleBased
from brad.front_end.engine_connections import EngineConnections
from brad.utils.table_sizer import TableSizer
Expand Down Expand Up @@ -195,6 +196,7 @@ def _estimate_current_data_accessed(
self, engines: EngineConnections, current_workload: Workload
) -> Dict[Engine, int]:
current_router = RuleBased()
ctx = RoutingContext()

total_accessed_mb: Dict[Engine, int] = {}
total_accessed_mb[Engine.Aurora] = 0
Expand All @@ -204,7 +206,7 @@ def _estimate_current_data_accessed(
# Compute the total amount of data accessed on each engine in the
# current workload (used to weigh the workload assigned to each engine).
for q in current_workload.analytical_queries():
current_engines = current_router.engine_for_sync(q)
current_engines = current_router.engine_for_sync(q, ctx)
assert len(current_engines) > 0
current_engine = current_engines[0]
q.populate_data_accessed_mb(
Expand Down
4 changes: 3 additions & 1 deletion src/brad/planner/neighborhood/scaling_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from brad.blueprint.provisioning import Provisioning
from brad.config.engine import Engine
from brad.config.planner import PlannerConfig
from brad.routing.context import RoutingContext
from brad.routing.rule_based import RuleBased

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -58,10 +59,11 @@ def _simulate_next_workload(self, ctx: ScoringContext) -> None:
# NOTE: The routing policy should be included in the blueprint. We
# currently hardcode it here for engineering convenience.
router = RuleBased()
rctx = RoutingContext()

# See where each analytical query gets routed.
for q in ctx.next_workload.analytical_queries():
next_engines = router.engine_for_sync(q)
next_engines = router.engine_for_sync(q, rctx)
assert len(next_engines) > 0
next_engine = next_engines[0]
ctx.next_dest[next_engine].append(q)
Expand Down
5 changes: 3 additions & 2 deletions src/brad/planner/triggers/variable_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ async def should_replan(self) -> bool:

if ratio > self._change_ratio:
logger.info(
"Triggering replanning due to variable costs changing. Previously estimated: %.4f. Current estimated: %.4f. Change ratio: %.4f",
"Triggering replanning due to variable costs changing. Previously "
"estimated: %.4f. Current estimated: %.4f. Change ratio: %.4f",
estimated_hourly_cost,
current_hourly_cost,
self._change_ratio,
Expand Down Expand Up @@ -123,7 +124,7 @@ async def _estimate_current_scan_hourly_cost(self) -> Tuple[float, float]:
athena_query_indices: List[int] = []
athena_queries: List[Query] = []
router = Router.create_from_blueprint(self._current_blueprint)
await router.run_setup(self._estimator_provider.get_estimator())
await router.run_setup_for_standalone(self._estimator_provider.get_estimator())

for idx, q in enumerate(workload.analytical_queries()):
maybe_engine = q.primary_execution_location()
Expand Down
9 changes: 6 additions & 3 deletions src/brad/routing/abstract_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from brad.config.engine import Engine
from brad.planner.estimator import Estimator
from brad.query_rep import QueryRep
from brad.routing.context import RoutingContext


class AbstractRoutingPolicy:
Expand All @@ -21,7 +22,9 @@ async def run_setup(self, estimator: Optional[Estimator] = None) -> None:
If this routing policy needs an estimator, one should be provided here.
"""

async def engine_for(self, query_rep: QueryRep) -> List[Engine]:
async def engine_for(
self, query_rep: QueryRep, ctx: RoutingContext
) -> List[Engine]:
"""
Produces a preference order for query routing (the first element in the
list is the most preferred engine, and so on).
Expand All @@ -33,9 +36,9 @@ async def engine_for(self, query_rep: QueryRep) -> List[Engine]:
You should override this method if the routing policy needs to depend on
any asynchronous methods.
"""
return self.engine_for_sync(query_rep)
return self.engine_for_sync(query_rep, ctx)

def engine_for_sync(self, query_rep: QueryRep) -> List[Engine]:
def engine_for_sync(self, query_rep: QueryRep, ctx: RoutingContext) -> List[Engine]:
"""
Produces a preference order for query routing (the first element in the
list is the most preferred engine, and so on).
Expand Down
3 changes: 2 additions & 1 deletion src/brad/routing/always_one.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from brad.config.engine import Engine
from brad.query_rep import QueryRep
from brad.routing.context import RoutingContext
from brad.routing.abstract_policy import AbstractRoutingPolicy


Expand All @@ -19,7 +20,7 @@ def __init__(self, db_type: Engine):
def name(self) -> str:
return f"AlwaysRouteTo({self._engine.name})"

def engine_for_sync(self, _query: QueryRep) -> List[Engine]:
def engine_for_sync(self, _query: QueryRep, _ctx: RoutingContext) -> List[Engine]:
return self._always_route_to

def __eq__(self, other: object) -> bool:
Expand Down
5 changes: 4 additions & 1 deletion src/brad/routing/cached.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from brad.planner.workload import Workload
from brad.query_rep import QueryRep
from brad.routing.abstract_policy import AbstractRoutingPolicy
from brad.routing.context import RoutingContext


class CachedLocationPolicy(AbstractRoutingPolicy):
Expand All @@ -30,7 +31,9 @@ def __init__(self, query_map: Dict[QueryRep, Engine]) -> None:
def name(self) -> str:
return "CachedLocationPolicy"

def engine_for_sync(self, query_rep: QueryRep) -> List[Engine]:
def engine_for_sync(
self, query_rep: QueryRep, _ctx: RoutingContext
) -> List[Engine]:
try:
return [self._query_map[query_rep]]
except KeyError:
Expand Down
12 changes: 12 additions & 0 deletions src/brad/routing/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Optional

from brad.planner.estimator import Estimator


class RoutingContext:
"""
A wrapper class that holds state that should be used for routing.
"""

def __init__(self) -> None:
self.estimator: Optional[Estimator] = None
3 changes: 2 additions & 1 deletion src/brad/routing/round_robin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from brad.config.engine import Engine
from brad.query_rep import QueryRep
from brad.routing.abstract_policy import AbstractRoutingPolicy
from brad.routing.context import RoutingContext


class RoundRobin(AbstractRoutingPolicy):
Expand All @@ -16,7 +17,7 @@ def __init__(self):
def name(self) -> str:
return "RoundRobin"

def engine_for_sync(self, _query: QueryRep) -> List[Engine]:
def engine_for_sync(self, _query: QueryRep, _ctx: RoutingContext) -> List[Engine]:
tmp = self._ordering[0]
self._ordering[0] = self._ordering[1]
self._ordering[1] = self._ordering[2]
Expand Down
26 changes: 19 additions & 7 deletions src/brad/routing/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import logging
from typing import Dict, Optional, TYPE_CHECKING
from brad.front_end.session import Session
from brad.routing.functionality_catalog import Functionality
from brad.data_stats.estimator import Estimator
from brad.config.engine import Engine, EngineBitmapValues
from brad.query_rep import QueryRep
from brad.routing.abstract_policy import AbstractRoutingPolicy, FullRoutingPolicy
from brad.routing.context import RoutingContext
from brad.routing.functionality_catalog import Functionality

if TYPE_CHECKING:
from brad.blueprint import Blueprint
Expand Down Expand Up @@ -44,21 +45,25 @@ def __init__(
self._use_future_blueprint_policies = use_future_blueprint_policies
self.functionality_catalog = Functionality()

# This should only be used when the router is being used in the planner.
self._shared_estimator: Optional[Estimator] = None

def log_policy(self) -> None:
logger.info("Routing policy:")
logger.info(" Indefinite policies:")
for p in self._full_policy.indefinite_policies:
logger.info(" - %s", p.name())
logger.info(" Definite policy: %s", self._full_policy.definite_policy.name())

async def run_setup(self, estimator: Optional[Estimator] = None) -> None:
async def run_setup_for_standalone(self, estimator: Optional[Estimator] = None) -> None:
"""
Should be called before using the router. This is used to set up any
dynamic state.
Should be called before using the router "standalone" contexts (i.e.,
outside the front end). This is used to set up any dynamic state that is
typically passed in via a `Session`.
If the routing policy needs an estimator, one should be provided here.
"""
await self._full_policy.run_setup(estimator)
self._shared_estimator = estimator

def update_blueprint(self, blueprint: "Blueprint") -> None:
"""
Expand Down Expand Up @@ -114,16 +119,23 @@ async def engine_for(
else:
raise RuntimeError("Unsupported bitmap value " + str(valid_locations))

# Right now, this context can be created once per session. But we may
# also want to include other shared state (e.g., metrics) that is not
# session-specific.
ctx = RoutingContext()
if session is not None:
ctx.estimator = session.estimator

# Go through the indefinite routing policies. These may not return a
# routing location.
for policy in self._full_policy.indefinite_policies:
locations = await policy.engine_for(query)
locations = await policy.engine_for(query, ctx)
for loc in locations:
if (EngineBitmapValues[loc] & valid_locations) != 0:
return loc

# Rely on the definite routing policy.
locations = await self._full_policy.definite_policy.engine_for(query)
locations = await self._full_policy.definite_policy.engine_for(query, ctx)
for loc in locations:
if (EngineBitmapValues[loc] & valid_locations) != 0:
return loc
Expand Down
Loading

0 comments on commit 77c2f29

Please sign in to comment.