Skip to content

Commit

Permalink
Create an estimator per session to support multiple analytical client…
Browse files Browse the repository at this point in the history
…s per front end (#368)
  • Loading branch information
geoffxy authored Nov 15, 2023
1 parent cacd841 commit a8c8e15
Show file tree
Hide file tree
Showing 18 changed files with 122 additions and 86 deletions.
22 changes: 2 additions & 20 deletions src/brad/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
NewBlueprint,
NewBlueprintAck,
)
from brad.data_stats.estimator import Estimator
from brad.data_stats.postgres_estimator import PostgresEstimator
from brad.front_end.brad_interface import BradInterface
from brad.front_end.errors import QueryError
from brad.front_end.grpc import BradGrpc
Expand Down Expand Up @@ -118,7 +116,6 @@ def __init__(
self._config, self._blueprint_mgr, self._schema_name
)
self._daemon_messages_task: Optional[asyncio.Task[None]] = None
self._estimator: Optional[Estimator] = None

# Number of transactions that completed.
self._transaction_end_counter = Counter() # pylint: disable=global-statement
Expand Down Expand Up @@ -173,17 +170,7 @@ async def _run_setup(self) -> None:
self._monitor.set_up_metrics_sources()
await self._monitor.fetch_latest()

if (
self._routing_policy_override == RoutingPolicy.ForestTableSelectivity
or self._routing_policy_override == RoutingPolicy.Default
):
self._estimator = await PostgresEstimator.connect(
self._schema_name, self._config
)
await self._estimator.analyze(self._blueprint_mgr.get_blueprint())
else:
self._estimator = None
await self._set_up_router(self._estimator)
await self._set_up_router()

# Start the metrics reporting task.
self._brad_metrics_reporting_task = asyncio.create_task(
Expand All @@ -195,7 +182,7 @@ async def _run_setup(self) -> None:

self._qlogger_refresh_task = asyncio.create_task(self._refresh_qlogger())

async def _set_up_router(self, estimator: Optional[Estimator]) -> None:
async def _set_up_router(self) -> None:
# We have different routing policies for performance evaluation and
# testing purposes.
blueprint = self._blueprint_mgr.get_blueprint()
Expand Down Expand Up @@ -236,7 +223,6 @@ async def _set_up_router(self, estimator: Optional[Estimator]) -> None:
definite_policy, blueprint.table_locations_bitmap()
)

await self._router.run_setup(estimator)
self._router.log_policy()

async def _run_teardown(self):
Expand All @@ -258,10 +244,6 @@ async def _run_teardown(self):
self._qlogger_refresh_task.cancel()
self._qlogger_refresh_task = None

if self._estimator is not None:
await self._estimator.close()
self._estimator = None

async def start_session(self) -> SessionId:
rand_backoff = None
while True:
Expand Down
34 changes: 31 additions & 3 deletions src/brad/front_end/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
import pytz
from datetime import datetime
from typing import Dict, Tuple, Optional

from brad.config.engine import Engine
from brad.config.file import ConfigFile
from brad.config.session import SessionId
from .engine_connections import EngineConnections
from brad.blueprint.manager import BlueprintManager
from brad.front_end.engine_connections import EngineConnections
from brad.planner.estimator import Estimator
from brad.routing.policy import RoutingPolicy
from brad.data_stats.postgres_estimator import PostgresEstimator

logger = logging.getLogger(__name__)

Expand All @@ -19,12 +23,18 @@ class Session:
`SessionManager`.
"""

def __init__(self, session_id: SessionId, engines: EngineConnections):
def __init__(
self,
session_id: SessionId,
engines: EngineConnections,
estimator: Optional[Estimator],
):
self._session_id = session_id
self._engines = engines
self._in_txn = False
self._closed = False
self._txn_start_timestamp = datetime.now(tz=pytz.utc)
self._estimator = estimator

@property
def identifier(self) -> SessionId:
Expand All @@ -38,6 +48,10 @@ def engines(self) -> EngineConnections:
def in_transaction(self) -> bool:
return self._in_txn

@property
def estimator(self) -> Optional[Estimator]:
return self._estimator

@property
def closed(self) -> bool:
return self._closed
Expand All @@ -54,6 +68,7 @@ def txn_start_timestamp(self) -> datetime:
async def close(self):
self._closed = True
await self._engines.close()
await self._estimator.close()


class SessionManager:
Expand Down Expand Up @@ -91,7 +106,20 @@ async def create_new_session(self) -> Tuple[SessionId, Session]:
specific_engines=engines,
connect_to_aurora_read_replicas=True,
)
session = Session(session_id, connections)

# Create an estimator if needed. The estimator should be
# session-specific since it currently depends on a DB connection.
routing_policy_override = self._config.routing_policy
if (
routing_policy_override == RoutingPolicy.ForestTableSelectivity
or routing_policy_override == RoutingPolicy.Default
):
estimator = await PostgresEstimator.connect(self._schema_name, self._config)
await estimator.analyze(self._blueprint_mgr.get_blueprint())
else:
estimator = None

session = Session(session_id, connections, estimator)
self._sessions[session_id] = session
logger.debug("Established a new session: %s", session_id)
return (session_id, session)
Expand Down
2 changes: 1 addition & 1 deletion src/brad/planner/beam/query_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def _run_replan_impl(
self._planner_config,
)
planning_router = Router.create_from_blueprint(self._current_blueprint)
await planning_router.run_setup(
await planning_router.run_setup_for_standalone(
self._providers.estimator_provider.get_estimator()
)
await ctx.simulate_current_workload_routing(planning_router)
Expand Down
2 changes: 1 addition & 1 deletion src/brad/planner/beam/query_based_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def _run_replan_impl(
self._planner_config,
)
planning_router = Router.create_from_blueprint(self._current_blueprint)
await planning_router.run_setup(
await planning_router.run_setup_for_standalone(
self._providers.estimator_provider.get_estimator()
)
await ctx.simulate_current_workload_routing(planning_router)
Expand Down
2 changes: 1 addition & 1 deletion src/brad/planner/beam/table_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def _run_replan_impl(
self._planner_config,
)
planning_router = Router.create_from_blueprint(self._current_blueprint)
await planning_router.run_setup(
await planning_router.run_setup_for_standalone(
self._providers.estimator_provider.get_estimator()
)
await ctx.simulate_current_workload_routing(planning_router)
Expand Down
4 changes: 3 additions & 1 deletion src/brad/planner/neighborhood/neighborhood.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from brad.planner.strategy import PlanningStrategy
from brad.planner.workload import Workload
from brad.provisioning.directory import Directory
from brad.routing.context import RoutingContext
from brad.routing.rule_based import RuleBased
from brad.front_end.engine_connections import EngineConnections
from brad.utils.table_sizer import TableSizer
Expand Down Expand Up @@ -195,6 +196,7 @@ def _estimate_current_data_accessed(
self, engines: EngineConnections, current_workload: Workload
) -> Dict[Engine, int]:
current_router = RuleBased()
ctx = RoutingContext()

total_accessed_mb: Dict[Engine, int] = {}
total_accessed_mb[Engine.Aurora] = 0
Expand All @@ -204,7 +206,7 @@ def _estimate_current_data_accessed(
# Compute the total amount of data accessed on each engine in the
# current workload (used to weigh the workload assigned to each engine).
for q in current_workload.analytical_queries():
current_engines = current_router.engine_for_sync(q)
current_engines = current_router.engine_for_sync(q, ctx)
assert len(current_engines) > 0
current_engine = current_engines[0]
q.populate_data_accessed_mb(
Expand Down
4 changes: 3 additions & 1 deletion src/brad/planner/neighborhood/scaling_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from brad.blueprint.provisioning import Provisioning
from brad.config.engine import Engine
from brad.config.planner import PlannerConfig
from brad.routing.context import RoutingContext
from brad.routing.rule_based import RuleBased

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -58,10 +59,11 @@ def _simulate_next_workload(self, ctx: ScoringContext) -> None:
# NOTE: The routing policy should be included in the blueprint. We
# currently hardcode it here for engineering convenience.
router = RuleBased()
rctx = RoutingContext()

# See where each analytical query gets routed.
for q in ctx.next_workload.analytical_queries():
next_engines = router.engine_for_sync(q)
next_engines = router.engine_for_sync(q, rctx)
assert len(next_engines) > 0
next_engine = next_engines[0]
ctx.next_dest[next_engine].append(q)
Expand Down
5 changes: 3 additions & 2 deletions src/brad/planner/triggers/variable_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ async def should_replan(self) -> bool:

if ratio > self._change_ratio:
logger.info(
"Triggering replanning due to variable costs changing. Previously estimated: %.4f. Current estimated: %.4f. Change ratio: %.4f",
"Triggering replanning due to variable costs changing. Previously "
"estimated: %.4f. Current estimated: %.4f. Change ratio: %.4f",
estimated_hourly_cost,
current_hourly_cost,
self._change_ratio,
Expand Down Expand Up @@ -123,7 +124,7 @@ async def _estimate_current_scan_hourly_cost(self) -> Tuple[float, float]:
athena_query_indices: List[int] = []
athena_queries: List[Query] = []
router = Router.create_from_blueprint(self._current_blueprint)
await router.run_setup(self._estimator_provider.get_estimator())
await router.run_setup_for_standalone(self._estimator_provider.get_estimator())

for idx, q in enumerate(workload.analytical_queries()):
maybe_engine = q.primary_execution_location()
Expand Down
31 changes: 7 additions & 24 deletions src/brad/routing/abstract_policy.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List, Optional
from typing import List

from brad.config.engine import Engine
from brad.planner.estimator import Estimator
from brad.query_rep import QueryRep
from brad.routing.context import RoutingContext


class AbstractRoutingPolicy:
Expand All @@ -13,15 +13,9 @@ class AbstractRoutingPolicy:
def name(self) -> str:
raise NotImplementedError

async def run_setup(self, estimator: Optional[Estimator] = None) -> None:
"""
Should be called before using this policy. This is used to set up any
dynamic state.
If this routing policy needs an estimator, one should be provided here.
"""

async def engine_for(self, query_rep: QueryRep) -> List[Engine]:
async def engine_for(
self, query_rep: QueryRep, ctx: RoutingContext
) -> List[Engine]:
"""
Produces a preference order for query routing (the first element in the
list is the most preferred engine, and so on).
Expand All @@ -33,9 +27,9 @@ async def engine_for(self, query_rep: QueryRep) -> List[Engine]:
You should override this method if the routing policy needs to depend on
any asynchronous methods.
"""
return self.engine_for_sync(query_rep)
return self.engine_for_sync(query_rep, ctx)

def engine_for_sync(self, query_rep: QueryRep) -> List[Engine]:
def engine_for_sync(self, query_rep: QueryRep, ctx: RoutingContext) -> List[Engine]:
"""
Produces a preference order for query routing (the first element in the
list is the most preferred engine, and so on).
Expand All @@ -62,17 +56,6 @@ def __init__(
self.indefinite_policies = indefinite_policies
self.definite_policy = definite_policy

async def run_setup(self, estimator: Optional[Estimator] = None) -> None:
"""
Should be called before using the policy. This is used to set up any
dynamic state.
If this routing policy needs an estimator, one should be provided here.
"""
for policy in self.indefinite_policies:
await policy.run_setup(estimator)
await self.definite_policy.run_setup(estimator)

def __eq__(self, other: object):
if not isinstance(other, FullRoutingPolicy):
return False
Expand Down
3 changes: 2 additions & 1 deletion src/brad/routing/always_one.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from brad.config.engine import Engine
from brad.query_rep import QueryRep
from brad.routing.context import RoutingContext
from brad.routing.abstract_policy import AbstractRoutingPolicy


Expand All @@ -19,7 +20,7 @@ def __init__(self, db_type: Engine):
def name(self) -> str:
return f"AlwaysRouteTo({self._engine.name})"

def engine_for_sync(self, _query: QueryRep) -> List[Engine]:
def engine_for_sync(self, _query: QueryRep, _ctx: RoutingContext) -> List[Engine]:
return self._always_route_to

def __eq__(self, other: object) -> bool:
Expand Down
5 changes: 4 additions & 1 deletion src/brad/routing/cached.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from brad.planner.workload import Workload
from brad.query_rep import QueryRep
from brad.routing.abstract_policy import AbstractRoutingPolicy
from brad.routing.context import RoutingContext


class CachedLocationPolicy(AbstractRoutingPolicy):
Expand All @@ -30,7 +31,9 @@ def __init__(self, query_map: Dict[QueryRep, Engine]) -> None:
def name(self) -> str:
return "CachedLocationPolicy"

def engine_for_sync(self, query_rep: QueryRep) -> List[Engine]:
def engine_for_sync(
self, query_rep: QueryRep, _ctx: RoutingContext
) -> List[Engine]:
try:
return [self._query_map[query_rep]]
except KeyError:
Expand Down
12 changes: 12 additions & 0 deletions src/brad/routing/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Optional

from brad.planner.estimator import Estimator


class RoutingContext:
"""
A wrapper class that holds state that should be used for routing.
"""

def __init__(self) -> None:
self.estimator: Optional[Estimator] = None
3 changes: 2 additions & 1 deletion src/brad/routing/round_robin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from brad.config.engine import Engine
from brad.query_rep import QueryRep
from brad.routing.abstract_policy import AbstractRoutingPolicy
from brad.routing.context import RoutingContext


class RoundRobin(AbstractRoutingPolicy):
Expand All @@ -16,7 +17,7 @@ def __init__(self):
def name(self) -> str:
return "RoundRobin"

def engine_for_sync(self, _query: QueryRep) -> List[Engine]:
def engine_for_sync(self, _query: QueryRep, _ctx: RoutingContext) -> List[Engine]:
tmp = self._ordering[0]
self._ordering[0] = self._ordering[1]
self._ordering[1] = self._ordering[2]
Expand Down
Loading

0 comments on commit a8c8e15

Please sign in to comment.