diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index af02e83c..c8e9abce 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -29,8 +29,6 @@ NewBlueprint, NewBlueprintAck, ) -from brad.data_stats.estimator import Estimator -from brad.data_stats.postgres_estimator import PostgresEstimator from brad.front_end.brad_interface import BradInterface from brad.front_end.errors import QueryError from brad.front_end.grpc import BradGrpc @@ -118,7 +116,6 @@ def __init__( self._config, self._blueprint_mgr, self._schema_name ) self._daemon_messages_task: Optional[asyncio.Task[None]] = None - self._estimator: Optional[Estimator] = None # Number of transactions that completed. self._transaction_end_counter = Counter() # pylint: disable=global-statement @@ -173,17 +170,7 @@ async def _run_setup(self) -> None: self._monitor.set_up_metrics_sources() await self._monitor.fetch_latest() - if ( - self._routing_policy_override == RoutingPolicy.ForestTableSelectivity - or self._routing_policy_override == RoutingPolicy.Default - ): - self._estimator = await PostgresEstimator.connect( - self._schema_name, self._config - ) - await self._estimator.analyze(self._blueprint_mgr.get_blueprint()) - else: - self._estimator = None - await self._set_up_router(self._estimator) + await self._set_up_router() # Start the metrics reporting task. self._brad_metrics_reporting_task = asyncio.create_task( @@ -195,7 +182,7 @@ async def _run_setup(self) -> None: self._qlogger_refresh_task = asyncio.create_task(self._refresh_qlogger()) - async def _set_up_router(self, estimator: Optional[Estimator]) -> None: + async def _set_up_router(self) -> None: # We have different routing policies for performance evaluation and # testing purposes. 
blueprint = self._blueprint_mgr.get_blueprint() @@ -236,7 +223,6 @@ async def _set_up_router(self, estimator: Optional[Estimator]) -> None: definite_policy, blueprint.table_locations_bitmap() ) - await self._router.run_setup(estimator) self._router.log_policy() async def _run_teardown(self): @@ -258,10 +244,6 @@ async def _run_teardown(self): self._qlogger_refresh_task.cancel() self._qlogger_refresh_task = None - if self._estimator is not None: - await self._estimator.close() - self._estimator = None - async def start_session(self) -> SessionId: rand_backoff = None while True: diff --git a/src/brad/front_end/session.py b/src/brad/front_end/session.py index c44fcbda..bf99a8a3 100644 --- a/src/brad/front_end/session.py +++ b/src/brad/front_end/session.py @@ -3,11 +3,15 @@ import pytz from datetime import datetime from typing import Dict, Tuple, Optional + from brad.config.engine import Engine from brad.config.file import ConfigFile from brad.config.session import SessionId -from .engine_connections import EngineConnections from brad.blueprint.manager import BlueprintManager +from brad.front_end.engine_connections import EngineConnections +from brad.planner.estimator import Estimator +from brad.routing.policy import RoutingPolicy +from brad.data_stats.postgres_estimator import PostgresEstimator logger = logging.getLogger(__name__) @@ -19,12 +23,18 @@ class Session: `SessionManager`. 
""" - def __init__(self, session_id: SessionId, engines: EngineConnections): + def __init__( + self, + session_id: SessionId, + engines: EngineConnections, + estimator: Optional[Estimator], + ): self._session_id = session_id self._engines = engines self._in_txn = False self._closed = False self._txn_start_timestamp = datetime.now(tz=pytz.utc) + self._estimator = estimator @property def identifier(self) -> SessionId: @@ -38,6 +48,10 @@ def engines(self) -> EngineConnections: def in_transaction(self) -> bool: return self._in_txn + @property + def estimator(self) -> Optional[Estimator]: + return self._estimator + @property def closed(self) -> bool: return self._closed @@ -54,6 +68,9 @@ def txn_start_timestamp(self) -> datetime: async def close(self): self._closed = True await self._engines.close() + # The estimator is optional; only close it when one was provided. + if self._estimator is not None: + await self._estimator.close() class SessionManager: @@ -91,7 +108,20 @@ async def create_new_session(self) -> Tuple[SessionId, Session]: specific_engines=engines, connect_to_aurora_read_replicas=True, ) - session = Session(session_id, connections) + + # Create an estimator if needed. The estimator should be + # session-specific since it currently depends on a DB connection. 
+ routing_policy_override = self._config.routing_policy + if ( + routing_policy_override == RoutingPolicy.ForestTableSelectivity + or routing_policy_override == RoutingPolicy.Default + ): + estimator = await PostgresEstimator.connect(self._schema_name, self._config) + await estimator.analyze(self._blueprint_mgr.get_blueprint()) + else: + estimator = None + + session = Session(session_id, connections, estimator) self._sessions[session_id] = session logger.debug("Established a new session: %s", session_id) return (session_id, session) diff --git a/src/brad/planner/beam/query_based.py b/src/brad/planner/beam/query_based.py index 1790654e..61ec6bb5 100644 --- a/src/brad/planner/beam/query_based.py +++ b/src/brad/planner/beam/query_based.py @@ -130,7 +130,7 @@ async def _run_replan_impl( self._planner_config, ) planning_router = Router.create_from_blueprint(self._current_blueprint) - await planning_router.run_setup( + await planning_router.run_setup_for_standalone( self._providers.estimator_provider.get_estimator() ) await ctx.simulate_current_workload_routing(planning_router) diff --git a/src/brad/planner/beam/query_based_legacy.py b/src/brad/planner/beam/query_based_legacy.py index 4b65ba02..f3d1f695 100644 --- a/src/brad/planner/beam/query_based_legacy.py +++ b/src/brad/planner/beam/query_based_legacy.py @@ -119,7 +119,7 @@ async def _run_replan_impl( self._planner_config, ) planning_router = Router.create_from_blueprint(self._current_blueprint) - await planning_router.run_setup( + await planning_router.run_setup_for_standalone( self._providers.estimator_provider.get_estimator() ) await ctx.simulate_current_workload_routing(planning_router) diff --git a/src/brad/planner/beam/table_based.py b/src/brad/planner/beam/table_based.py index 5a2903af..4d7aaf20 100644 --- a/src/brad/planner/beam/table_based.py +++ b/src/brad/planner/beam/table_based.py @@ -127,7 +127,7 @@ async def _run_replan_impl( self._planner_config, ) planning_router = 
Router.create_from_blueprint(self._current_blueprint) - await planning_router.run_setup( + await planning_router.run_setup_for_standalone( self._providers.estimator_provider.get_estimator() ) await ctx.simulate_current_workload_routing(planning_router) diff --git a/src/brad/planner/neighborhood/neighborhood.py b/src/brad/planner/neighborhood/neighborhood.py index 12a2322a..518abb9d 100644 --- a/src/brad/planner/neighborhood/neighborhood.py +++ b/src/brad/planner/neighborhood/neighborhood.py @@ -28,6 +28,7 @@ from brad.planner.strategy import PlanningStrategy from brad.planner.workload import Workload from brad.provisioning.directory import Directory +from brad.routing.context import RoutingContext from brad.routing.rule_based import RuleBased from brad.front_end.engine_connections import EngineConnections from brad.utils.table_sizer import TableSizer @@ -195,6 +196,7 @@ def _estimate_current_data_accessed( self, engines: EngineConnections, current_workload: Workload ) -> Dict[Engine, int]: current_router = RuleBased() + ctx = RoutingContext() total_accessed_mb: Dict[Engine, int] = {} total_accessed_mb[Engine.Aurora] = 0 @@ -204,7 +206,7 @@ def _estimate_current_data_accessed( # Compute the total amount of data accessed on each engine in the # current workload (used to weigh the workload assigned to each engine). 
for q in current_workload.analytical_queries(): - current_engines = current_router.engine_for_sync(q) + current_engines = current_router.engine_for_sync(q, ctx) assert len(current_engines) > 0 current_engine = current_engines[0] q.populate_data_accessed_mb( diff --git a/src/brad/planner/neighborhood/scaling_scorer.py b/src/brad/planner/neighborhood/scaling_scorer.py index 9b7b7b5b..ded835e2 100644 --- a/src/brad/planner/neighborhood/scaling_scorer.py +++ b/src/brad/planner/neighborhood/scaling_scorer.py @@ -13,6 +13,7 @@ from brad.blueprint.provisioning import Provisioning from brad.config.engine import Engine from brad.config.planner import PlannerConfig +from brad.routing.context import RoutingContext from brad.routing.rule_based import RuleBased logger = logging.getLogger(__name__) @@ -58,10 +59,11 @@ def _simulate_next_workload(self, ctx: ScoringContext) -> None: # NOTE: The routing policy should be included in the blueprint. We # currently hardcode it here for engineering convenience. router = RuleBased() + rctx = RoutingContext() # See where each analytical query gets routed. for q in ctx.next_workload.analytical_queries(): - next_engines = router.engine_for_sync(q) + next_engines = router.engine_for_sync(q, rctx) assert len(next_engines) > 0 next_engine = next_engines[0] ctx.next_dest[next_engine].append(q) diff --git a/src/brad/planner/triggers/variable_costs.py b/src/brad/planner/triggers/variable_costs.py index 8d00c60c..d2e61c9c 100644 --- a/src/brad/planner/triggers/variable_costs.py +++ b/src/brad/planner/triggers/variable_costs.py @@ -83,7 +83,8 @@ async def should_replan(self) -> bool: if ratio > self._change_ratio: logger.info( - "Triggering replanning due to variable costs changing. Previously estimated: %.4f. Current estimated: %.4f. Change ratio: %.4f", + "Triggering replanning due to variable costs changing. Previously " + "estimated: %.4f. Current estimated: %.4f. 
Change ratio: %.4f", estimated_hourly_cost, current_hourly_cost, self._change_ratio, @@ -123,7 +124,7 @@ async def _estimate_current_scan_hourly_cost(self) -> Tuple[float, float]: athena_query_indices: List[int] = [] athena_queries: List[Query] = [] router = Router.create_from_blueprint(self._current_blueprint) - await router.run_setup(self._estimator_provider.get_estimator()) + await router.run_setup_for_standalone(self._estimator_provider.get_estimator()) for idx, q in enumerate(workload.analytical_queries()): maybe_engine = q.primary_execution_location() diff --git a/src/brad/routing/abstract_policy.py b/src/brad/routing/abstract_policy.py index c0b2306a..0672399d 100644 --- a/src/brad/routing/abstract_policy.py +++ b/src/brad/routing/abstract_policy.py @@ -1,8 +1,8 @@ -from typing import List, Optional +from typing import List from brad.config.engine import Engine -from brad.planner.estimator import Estimator from brad.query_rep import QueryRep +from brad.routing.context import RoutingContext class AbstractRoutingPolicy: @@ -13,15 +13,9 @@ class AbstractRoutingPolicy: def name(self) -> str: raise NotImplementedError - async def run_setup(self, estimator: Optional[Estimator] = None) -> None: - """ - Should be called before using this policy. This is used to set up any - dynamic state. - - If this routing policy needs an estimator, one should be provided here. - """ - - async def engine_for(self, query_rep: QueryRep) -> List[Engine]: + async def engine_for( + self, query_rep: QueryRep, ctx: RoutingContext + ) -> List[Engine]: """ Produces a preference order for query routing (the first element in the list is the most preferred engine, and so on). @@ -33,9 +27,9 @@ async def engine_for(self, query_rep: QueryRep) -> List[Engine]: You should override this method if the routing policy needs to depend on any asynchronous methods. 
""" - return self.engine_for_sync(query_rep) + return self.engine_for_sync(query_rep, ctx) - def engine_for_sync(self, query_rep: QueryRep) -> List[Engine]: + def engine_for_sync(self, query_rep: QueryRep, ctx: RoutingContext) -> List[Engine]: """ Produces a preference order for query routing (the first element in the list is the most preferred engine, and so on). @@ -62,17 +56,6 @@ def __init__( self.indefinite_policies = indefinite_policies self.definite_policy = definite_policy - async def run_setup(self, estimator: Optional[Estimator] = None) -> None: - """ - Should be called before using the policy. This is used to set up any - dynamic state. - - If this routing policy needs an estimator, one should be provided here. - """ - for policy in self.indefinite_policies: - await policy.run_setup(estimator) - await self.definite_policy.run_setup(estimator) - def __eq__(self, other: object): if not isinstance(other, FullRoutingPolicy): return False diff --git a/src/brad/routing/always_one.py b/src/brad/routing/always_one.py index cc51243b..6b117344 100644 --- a/src/brad/routing/always_one.py +++ b/src/brad/routing/always_one.py @@ -2,6 +2,7 @@ from brad.config.engine import Engine from brad.query_rep import QueryRep +from brad.routing.context import RoutingContext from brad.routing.abstract_policy import AbstractRoutingPolicy @@ -19,7 +20,7 @@ def __init__(self, db_type: Engine): def name(self) -> str: return f"AlwaysRouteTo({self._engine.name})" - def engine_for_sync(self, _query: QueryRep) -> List[Engine]: + def engine_for_sync(self, _query: QueryRep, _ctx: RoutingContext) -> List[Engine]: return self._always_route_to def __eq__(self, other: object) -> bool: diff --git a/src/brad/routing/cached.py b/src/brad/routing/cached.py index c9cb9a38..449e4ca9 100644 --- a/src/brad/routing/cached.py +++ b/src/brad/routing/cached.py @@ -4,6 +4,7 @@ from brad.planner.workload import Workload from brad.query_rep import QueryRep from brad.routing.abstract_policy import 
AbstractRoutingPolicy +from brad.routing.context import RoutingContext class CachedLocationPolicy(AbstractRoutingPolicy): @@ -30,7 +31,9 @@ def __init__(self, query_map: Dict[QueryRep, Engine]) -> None: def name(self) -> str: return "CachedLocationPolicy" - def engine_for_sync(self, query_rep: QueryRep) -> List[Engine]: + def engine_for_sync( + self, query_rep: QueryRep, _ctx: RoutingContext + ) -> List[Engine]: try: return [self._query_map[query_rep]] except KeyError: diff --git a/src/brad/routing/context.py b/src/brad/routing/context.py new file mode 100644 index 00000000..035cb1ac --- /dev/null +++ b/src/brad/routing/context.py @@ -0,0 +1,12 @@ +from typing import Optional + +from brad.planner.estimator import Estimator + + +class RoutingContext: + """ + A wrapper class that holds state that should be used for routing. + """ + + def __init__(self) -> None: + self.estimator: Optional[Estimator] = None diff --git a/src/brad/routing/round_robin.py b/src/brad/routing/round_robin.py index 8ba7d304..bc64d243 100644 --- a/src/brad/routing/round_robin.py +++ b/src/brad/routing/round_robin.py @@ -3,6 +3,7 @@ from brad.config.engine import Engine from brad.query_rep import QueryRep from brad.routing.abstract_policy import AbstractRoutingPolicy +from brad.routing.context import RoutingContext class RoundRobin(AbstractRoutingPolicy): @@ -16,7 +17,7 @@ def __init__(self): def name(self) -> str: return "RoundRobin" - def engine_for_sync(self, _query: QueryRep) -> List[Engine]: + def engine_for_sync(self, _query: QueryRep, _ctx: RoutingContext) -> List[Engine]: tmp = self._ordering[0] self._ordering[0] = self._ordering[1] self._ordering[1] = self._ordering[2] diff --git a/src/brad/routing/router.py b/src/brad/routing/router.py index b27f06ab..2318cd7c 100644 --- a/src/brad/routing/router.py +++ b/src/brad/routing/router.py @@ -2,11 +2,12 @@ import logging from typing import Dict, Optional, TYPE_CHECKING from brad.front_end.session import Session -from 
brad.routing.functionality_catalog import Functionality from brad.data_stats.estimator import Estimator from brad.config.engine import Engine, EngineBitmapValues from brad.query_rep import QueryRep from brad.routing.abstract_policy import AbstractRoutingPolicy, FullRoutingPolicy +from brad.routing.context import RoutingContext +from brad.routing.functionality_catalog import Functionality if TYPE_CHECKING: from brad.blueprint import Blueprint @@ -44,6 +45,9 @@ def __init__( self._use_future_blueprint_policies = use_future_blueprint_policies self.functionality_catalog = Functionality() + # This should only be used when the router is being used in the planner. + self._shared_estimator: Optional[Estimator] = None + def log_policy(self) -> None: logger.info("Routing policy:") logger.info(" Indefinite policies:") @@ -51,14 +55,17 @@ def log_policy(self) -> None: logger.info(" - %s", p.name()) logger.info(" Definite policy: %s", self._full_policy.definite_policy.name()) - async def run_setup(self, estimator: Optional[Estimator] = None) -> None: + async def run_setup_for_standalone( + self, estimator: Optional[Estimator] = None + ) -> None: """ - Should be called before using the router. This is used to set up any - dynamic state. + Should be called before using the router in "standalone" contexts (i.e., + outside the front end). This is used to set up any dynamic state that is + typically passed in via a `Session`. If the routing policy needs an estimator, one should be provided here. """ - await self._full_policy.run_setup(estimator) + self._shared_estimator = estimator def update_blueprint(self, blueprint: "Blueprint") -> None: """ @@ -114,16 +121,27 @@ async def engine_for( else: raise RuntimeError("Unsupported bitmap value " + str(valid_locations)) + # Right now, this context can be created once per session. But we may + # also want to include other shared state (e.g., metrics) that is not + # session-specific. 
+ ctx = RoutingContext() + if session is not None: + ctx.estimator = session.estimator + else: + # No session (standalone use): fall back to the estimator that was + # passed to `run_setup_for_standalone()`. + ctx.estimator = self._shared_estimator + # Go through the indefinite routing policies. These may not return a # routing location. for policy in self._full_policy.indefinite_policies: - locations = await policy.engine_for(query) + locations = await policy.engine_for(query, ctx) for loc in locations: if (EngineBitmapValues[loc] & valid_locations) != 0: return loc # Rely on the definite routing policy. - locations = await self._full_policy.definite_policy.engine_for(query) + locations = await self._full_policy.definite_policy.engine_for(query, ctx) for loc in locations: if (EngineBitmapValues[loc] & valid_locations) != 0: return loc diff --git a/src/brad/routing/rule_based.py b/src/brad/routing/rule_based.py index f0bf50ef..48d3fdb6 100644 --- a/src/brad/routing/rule_based.py +++ b/src/brad/routing/rule_based.py @@ -8,9 +8,10 @@ from brad.blueprint import Blueprint from brad.config.engine import Engine from brad.daemon.monitor import Monitor -from brad.routing.abstract_policy import AbstractRoutingPolicy -from brad.query_rep import QueryRep from brad.front_end.session import SessionManager +from brad.query_rep import QueryRep +from brad.routing.abstract_policy import AbstractRoutingPolicy +from brad.routing.context import RoutingContext logger = logging.getLogger(__name__) @@ -190,7 +191,9 @@ def check_engine_state( return True return not_overloaded - def engine_for_sync(self, query_rep: QueryRep) -> List[Engine]: + def engine_for_sync( + self, query_rep: QueryRep, _ctx: RoutingContext + ) -> List[Engine]: touched_tables = query_rep.tables() if ( len(touched_tables) diff --git a/src/brad/routing/tree_based/forest_policy.py b/src/brad/routing/tree_based/forest_policy.py index 6b90bcf7..5f55f7d2 100644 --- a/src/brad/routing/tree_based/forest_policy.py +++ b/src/brad/routing/tree_based/forest_policy.py @@ -1,11 +1,11 @@ import asyncio -from typing import Optional, List, Dict, Any +from typing import List, Dict, Any from 
brad.asset_manager import AssetManager from brad.config.engine import Engine -from brad.data_stats.estimator import Estimator from brad.query_rep import QueryRep from brad.routing.abstract_policy import AbstractRoutingPolicy +from brad.routing.context import RoutingContext from brad.routing.policy import RoutingPolicy from brad.routing.tree_based.model_wrap import ModelWrap @@ -27,7 +27,6 @@ def from_loaded_model( def __init__(self, policy: RoutingPolicy, model: ModelWrap) -> None: self._policy = policy self._model = model - self._estimator: Optional[Estimator] = None def __getstate__(self) -> Dict[Any, Any]: return { @@ -38,7 +37,6 @@ def __getstate__(self) -> Dict[Any, Any]: def __setstate__(self, d: Dict[Any, Any]) -> None: self._policy = d["policy"] self._model = d["model"] - self._estimator = None def name(self) -> str: return f"ForestPolicy({self._policy.name})" @@ -48,14 +46,13 @@ def __eq__(self, other: object) -> bool: return False return self._policy == other._policy and self._model == other._model - async def run_setup(self, estimator: Optional[Estimator] = None) -> None: - self._estimator = estimator + async def engine_for( + self, query_rep: QueryRep, ctx: RoutingContext + ) -> List[Engine]: + return await self._model.engine_for(query_rep, ctx.estimator) - async def engine_for(self, query_rep: QueryRep) -> List[Engine]: - return await self._model.engine_for(query_rep, self._estimator) - - def engine_for_sync(self, query_rep: QueryRep) -> List[Engine]: - return asyncio.run(self.engine_for(query_rep)) + def engine_for_sync(self, query_rep: QueryRep, ctx: RoutingContext) -> List[Engine]: + return asyncio.run(self.engine_for(query_rep, ctx)) # The methods below are used to save/load `ModelWrap` from S3. 
We # historically separated out the model's implementation details because the diff --git a/tests/test_always_routing.py b/tests/test_always_routing.py index 469de3b0..49891e9c 100644 --- a/tests/test_always_routing.py +++ b/tests/test_always_routing.py @@ -1,35 +1,39 @@ from brad.config.engine import Engine from brad.routing.always_one import AlwaysOneRouter +from brad.routing.context import RoutingContext def test_always_route_aurora(): db = Engine.Aurora router = AlwaysOneRouter(db) + ctx = RoutingContext() - pred_db = router.engine_for_sync("SELECT 1") + pred_db = router.engine_for_sync("SELECT 1", ctx) assert pred_db == [db] - pred_db = router.engine_for_sync("SELECT * FROM my_table") + pred_db = router.engine_for_sync("SELECT * FROM my_table", ctx) assert pred_db == [db] def test_always_route_athena(): db = Engine.Athena router = AlwaysOneRouter(db) + ctx = RoutingContext() - pred_db = router.engine_for_sync("SELECT 1") + pred_db = router.engine_for_sync("SELECT 1", ctx) assert pred_db == [db] - pred_db = router.engine_for_sync("SELECT * FROM my_table") + pred_db = router.engine_for_sync("SELECT * FROM my_table", ctx) assert pred_db == [db] def test_always_route_redshift(): db = Engine.Redshift router = AlwaysOneRouter(db) + ctx = RoutingContext() - pred_db = router.engine_for_sync("SELECT 1") + pred_db = router.engine_for_sync("SELECT 1", ctx) assert pred_db == [db] - pred_db = router.engine_for_sync("SELECT * FROM my_table") + pred_db = router.engine_for_sync("SELECT * FROM my_table", ctx) assert pred_db == [db] diff --git a/tests/test_forest_routing.py b/tests/test_forest_routing.py index 095dc810..2ab4752f 100644 --- a/tests/test_forest_routing.py +++ b/tests/test_forest_routing.py @@ -3,6 +3,7 @@ from sklearn.ensemble import RandomForestClassifier from brad.config.engine import Engine +from brad.routing.context import RoutingContext from brad.routing.tree_based.forest_policy import ForestPolicy from brad.routing.tree_based.model_wrap import ModelWrap 
from brad.routing.policy import RoutingPolicy @@ -21,9 +22,10 @@ def get_dummy_router(): def test_model_codepath_partial(): model = get_dummy_router() router = ForestPolicy.from_loaded_model(RoutingPolicy.ForestTablePresence, model) + ctx = RoutingContext() query = QueryRep("SELECT * FROM test1, test2") - loc = router.engine_for_sync(query) + loc = router.engine_for_sync(query, ctx) assert ( loc[0] == Engine.Aurora or loc[0] == Engine.Redshift or loc[0] == Engine.Athena ) @@ -32,9 +34,10 @@ def test_model_codepath_partial(): def test_model_codepath_all(): model = get_dummy_router() router = ForestPolicy.from_loaded_model(RoutingPolicy.ForestTablePresence, model) + ctx = RoutingContext() query = QueryRep("SELECT * FROM test1") - loc = router.engine_for_sync(query) + loc = router.engine_for_sync(query, ctx) assert ( loc[0] == Engine.Aurora or loc[0] == Engine.Redshift or loc[0] == Engine.Athena )