From b5594b3e92d496d5e92da5124a9d23c9edb9ad9c Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Sun, 5 Nov 2023 01:07:32 -0400 Subject: [PATCH] Store and use the routing policy in the blueprint (#350) * Implement unified routing policy abstraction * WIP: Update the routing pipeline to allow modifications during planning * WIP: Overhaul router usage * Fix tests * Fix other test * Remove dependence on legacy ForestRouter * Add ability to manually set the routing policy, remove unused parts of the rule routing policy * Actually use the serialized policy * Better transition support * Define routing policy equality * Fix type error --- src/brad/admin/drop_schema.py | 6 +- src/brad/admin/modify_blueprint.py | 30 +++- src/brad/admin/train_router.py | 4 +- src/brad/blueprint/blueprint.py | 42 ++++-- src/brad/blueprint/serde.py | 32 +++- src/brad/front_end/front_end.py | 101 +++++++------ .../planner/beam/query_based_candidate.py | 2 +- .../planner/beam/table_based_candidate.py | 2 +- src/brad/planner/data.py | 5 +- src/brad/planner/enumeration/blueprint.py | 15 +- src/brad/planner/neighborhood/neighborhood.py | 6 +- .../planner/neighborhood/scaling_scorer.py | 6 +- src/brad/planner/router_provider.py | 18 ++- src/brad/routing/abstract_policy.py | 81 ++++++++++ src/brad/routing/always_one.py | 17 ++- .../routing/location_aware_round_robin.py | 30 ---- src/brad/routing/policy.py | 7 +- src/brad/routing/round_robin.py | 27 ++++ src/brad/routing/router.py | 81 ++++++++-- src/brad/routing/rule_based.py | 86 +++-------- src/brad/routing/tree_based/forest_policy.py | 92 ++++++++++++ src/brad/routing/tree_based/forest_router.py | 140 ------------------ src/brad/routing/tree_based/model_wrap.py | 14 ++ tests/test_always_routing.py | 12 +- tests/test_blueprint_diff.py | 4 +- tests/test_enumeration.py | 4 +- tests/test_forest_routing.py | 51 ++----- tests/test_location_routing.py | 5 +- tests/test_planner_filters.py | 20 +-- 29 files changed, 550 insertions(+), 390 deletions(-) create mode 100644 src/brad/routing/abstract_policy.py delete mode 100644 src/brad/routing/location_aware_round_robin.py create mode 100644 src/brad/routing/round_robin.py create mode 100644 src/brad/routing/tree_based/forest_policy.py delete mode 100644 src/brad/routing/tree_based/forest_router.py diff --git a/src/brad/admin/drop_schema.py b/src/brad/admin/drop_schema.py index 9f0d7b71..a5a306db 100644 --- a/src/brad/admin/drop_schema.py +++ b/src/brad/admin/drop_schema.py @@ -8,7 +8,7 @@ from brad.front_end.engine_connections import EngineConnections from brad.provisioning.directory import Directory from brad.routing.policy import RoutingPolicy -from brad.routing.tree_based.forest_router import ForestRouter +from brad.routing.tree_based.forest_policy import ForestPolicy logger = logging.getLogger(__name__) @@ -56,10 +56,10 @@ def drop_schema(args): # 4. Drop any serialized routers. assets = AssetManager(config) - ForestRouter.static_drop_model_sync( + ForestPolicy.static_drop_model_sync( args.schema_name, RoutingPolicy.ForestTableSelectivity, assets ) - ForestRouter.static_drop_model_sync( + ForestPolicy.static_drop_model_sync( args.schema_name, RoutingPolicy.ForestTablePresence, assets ) diff --git a/src/brad/admin/modify_blueprint.py b/src/brad/admin/modify_blueprint.py index 72bec1bb..d8369879 100644 --- a/src/brad/admin/modify_blueprint.py +++ b/src/brad/admin/modify_blueprint.py @@ -16,6 +16,11 @@ from brad.daemon.transition_orchestrator import TransitionOrchestrator from brad.front_end.engine_connections import EngineConnections from brad.planner.enumeration.blueprint import EnumeratedBlueprint +from brad.routing.abstract_policy import AbstractRoutingPolicy, FullRoutingPolicy +from brad.routing.always_one import AlwaysOneRouter +from brad.routing.policy import RoutingPolicy +from brad.routing.tree_based.forest_policy import ForestPolicy +from brad.routing.rule_based import RuleBased logger = logging.getLogger(__name__) @@ -68,6 +73,11 @@ def register_admin_action(subparser) -> None: action="store_true", help="Updates the blueprint's table placement and places tables on all engines.", ) + parser.add_argument( + "--set-routing-policy", + type=str, + help="Sets the serialized routing policy to a preconfigured default: {always_redshift, df_selectivity, rule_based}", + ) parser.add_argument( "--add-indexes", action="store_true", @@ -184,7 +194,7 @@ async def run_transition( # This method is called by `brad.exec.admin.main`. -def modify_blueprint(args): +def modify_blueprint(args) -> None: # 1. Load the config. config = ConfigFile.load(args.config_file) @@ -253,6 +263,24 @@ def modify_blueprint(args): new_placement[tbl] = Engine.from_bitmap(Engine.bitmap_all()) enum_blueprint.set_table_locations(new_placement) + if args.set_routing_policy is not None: + if args.set_routing_policy == "always_redshift": + definite_policy: AbstractRoutingPolicy = AlwaysOneRouter(Engine.Redshift) + elif args.set_routing_policy == "df_selectivity": + definite_policy = asyncio.run( + ForestPolicy.from_assets( + args.schema_name, RoutingPolicy.ForestTableSelectivity, assets + ) + ) + elif args.set_routing_policy == "rule_based": + definite_policy = RuleBased() + else: + raise RuntimeError( + f"Unknown routing policy preset: {args.set_routing_policy}" + ) + full_policy = FullRoutingPolicy([], definite_policy) + enum_blueprint.set_routing_policy(full_policy) + # 3. Write the changes back. modified_blueprint = enum_blueprint.to_blueprint() if blueprint == modified_blueprint: diff --git a/src/brad/admin/train_router.py b/src/brad/admin/train_router.py index b05c1eec..29053b0c 100644 --- a/src/brad/admin/train_router.py +++ b/src/brad/admin/train_router.py @@ -9,7 +9,7 @@ from brad.data_stats.estimator import Estimator from brad.data_stats.postgres_estimator import PostgresEstimator from brad.routing.policy import RoutingPolicy -from brad.routing.tree_based.forest_router import ForestRouter +from brad.routing.tree_based.forest_policy import ForestPolicy from brad.routing.tree_based.trainer import ForestTrainer from brad.blueprint.manager import BlueprintManager @@ -163,7 +163,7 @@ def train_router(args): response = input("Do you want to persist this model? (y/n): ").lower() if response == "y": assets = AssetManager(config) - ForestRouter.static_persist_sync(model, schema_name, assets) + ForestPolicy.static_persist_sync(model, schema_name, assets) logger.info("Model persisted successfully.") break elif response == "n": diff --git a/src/brad/blueprint/blueprint.py b/src/brad/blueprint/blueprint.py index acefc006..1304ea7b 100644 --- a/src/brad/blueprint/blueprint.py +++ b/src/brad/blueprint/blueprint.py @@ -1,11 +1,9 @@ -from typing import Callable, Dict, List, Set, Optional, Tuple, Any +from typing import Dict, List, Set, Tuple, Any from brad.blueprint.provisioning import Provisioning from brad.blueprint.table import Table from brad.config.engine import Engine -from brad.routing.router import Router - -RouterProvider = Callable[[], Router] +from brad.routing.abstract_policy import FullRoutingPolicy class Blueprint: @@ -16,14 +14,14 @@ def __init__( table_locations: Dict[str, List[Engine]], aurora_provisioning: Provisioning, redshift_provisioning: Provisioning, - router_provider: Optional[RouterProvider], + full_routing_policy: FullRoutingPolicy, ): self._schema_name = schema_name self._table_schemas = table_schemas self._table_locations = table_locations self._aurora_provisioning = aurora_provisioning self._redshift_provisioning = redshift_provisioning - self._router_provider = router_provider + self._full_routing_policy = full_routing_policy # Derived properties used for the class' methods. self._tables_by_name = {tbl.name: tbl for tbl in self._table_schemas} @@ -57,11 +55,8 @@ def aurora_provisioning(self) -> Provisioning: def redshift_provisioning(self) -> Provisioning: return self._redshift_provisioning - def get_router(self) -> Optional[Router]: - return self._router_provider() if self._router_provider is not None else None - - def router_provider(self) -> Optional[RouterProvider]: - return self._router_provider + def get_routing_policy(self) -> FullRoutingPolicy: + return self._full_routing_policy def base_table_names(self) -> Set[str]: return self._base_table_names @@ -87,6 +82,7 @@ def __eq__(self, other: object) -> bool: and self.table_locations() == other.table_locations() and self.aurora_provisioning() == other.aurora_provisioning() and self.redshift_provisioning() == other.redshift_provisioning() + and self.get_routing_policy() == other.get_routing_policy() ) def _compute_base_tables(self) -> Set[str]: @@ -118,15 +114,33 @@ def visit_table(table: Table) -> None: def __repr__(self) -> str: # Useful for debugging purposes. - aurora = "Aurora: " + str(self.aurora_provisioning()) - redshift = "Redshift: " + str(self.redshift_provisioning()) + aurora = "Aurora: " + str(self.aurora_provisioning()) + redshift = "Redshift: " + str(self.redshift_provisioning()) tables = "\n ".join( map( lambda name_loc: "".join([name_loc[0], ": ", str(name_loc[1])]), self.table_locations().items(), ) ) - return "\n ".join(["Blueprint:", tables, aurora, redshift]) + routing_policy = self.get_routing_policy() + indefinite_policies = ( + f"Indefinite routing policies: {len(routing_policy.indefinite_policies)}" + ) + definite_policy = ( + f"Definite routing policy: {routing_policy.definite_policy.name()}" + ) + return "\n ".join( + [ + "Blueprint:", + tables, + "---", + aurora, + redshift, + "---", + indefinite_policies, + definite_policy, + ] + ) def as_dict(self) -> Dict[str, Any]: """ diff --git a/src/brad/blueprint/serde.py b/src/brad/blueprint/serde.py index f821624a..2f768e5c 100644 --- a/src/brad/blueprint/serde.py +++ b/src/brad/blueprint/serde.py @@ -1,12 +1,18 @@ +import pickle +import logging from typing import Tuple, List, Dict from brad.blueprint import Blueprint from brad.blueprint.provisioning import Provisioning from brad.blueprint.table import Column, Table from brad.config.engine import Engine +from brad.routing.abstract_policy import FullRoutingPolicy +from brad.routing.round_robin import RoundRobin import brad.proto_gen.blueprint_pb2 as b +logger = logging.getLogger(__name__) + # We define the data blueprint serialization/deserialization functions # separately from the blueprint classes to avoid mixing protobuf code (an # implementation detail) with the blueprint classes. @@ -24,7 +30,7 @@ def deserialize_blueprint(raw_data: bytes) -> Blueprint: table_locations=dict(map(_table_locations_from_proto, proto.tables)), aurora_provisioning=_provisioning_from_proto(proto.aurora), redshift_provisioning=_provisioning_from_proto(proto.redshift), - router_provider=None, + full_routing_policy=_policy_from_proto(proto.policy), ) @@ -34,7 +40,7 @@ def serialize_blueprint(blueprint: Blueprint) -> bytes: tables=map(_tables_with_locations_to_proto, blueprint.tables_with_locations()), aurora=_provisioning_to_proto(blueprint.aurora_provisioning()), redshift=_provisioning_to_proto(blueprint.redshift_provisioning()), - policy=None, + policy=_policy_to_proto(blueprint.get_routing_policy()), ) return proto.SerializeToString() @@ -88,6 +94,13 @@ def _indexed_columns_to_proto(indexed_columns: Tuple[Column, ...]) -> b.Index: return b.Index(column_name=map(lambda col: col.name, indexed_columns)) +def _policy_to_proto(full_routing_policy: FullRoutingPolicy) -> b.RoutingPolicy: + # We just use Python pickle serialization. In the future, this should be + # something more robust. + bytes_str = pickle.dumps(full_routing_policy) + return b.RoutingPolicy(policy=bytes_str) + + # Deserialization @@ -139,3 +152,18 @@ def _indexed_columns_from_proto( for col_name in indexed_columns.column_name: col_list.append(col_map[col_name]) return tuple(col_list) + + +def _policy_from_proto(policy: b.RoutingPolicy) -> FullRoutingPolicy: + if len(policy.policy) == 0: + logger.warning( + "Did not find a routing policy in the serialized blueprint. " + "This likely means you are running with an older blueprint. " + "Falling back to round robin routing. If you want to use a " + "different policy, use `brad admin modify_blueprint` to set " + "a policy." + ) + return FullRoutingPolicy(indefinite_policies=[], definite_policy=RoundRobin()) + # We just use Python pickle serialization. In the future, this should be + # something more robust. + return pickle.loads(policy.policy) diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index 5c39aa87..09b3b222 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -35,12 +35,12 @@ from brad.front_end.grpc import BradGrpc from brad.front_end.session import SessionManager, SessionId, Session from brad.query_rep import QueryRep +from brad.routing.abstract_policy import AbstractRoutingPolicy from brad.routing.always_one import AlwaysOneRouter from brad.routing.rule_based import RuleBased -from brad.routing.location_aware_round_robin import LocationAwareRoundRobin from brad.routing.policy import RoutingPolicy from brad.routing.router import Router -from brad.routing.tree_based.forest_router import ForestRouter +from brad.routing.tree_based.forest_policy import ForestPolicy from brad.row_list import RowList from brad.utils import log_verbose from brad.utils.counter import Counter @@ -110,36 +110,9 @@ def __init__( self._config.front_end_query_latency_buffer_size ) - # We have different routing policies for performance evaluation and - # testing purposes. - routing_policy = self._config.routing_policy - self._routing_policy = routing_policy - if routing_policy == RoutingPolicy.Default: - self._router: Router = LocationAwareRoundRobin(self._blueprint_mgr) - elif routing_policy == RoutingPolicy.AlwaysAthena: - self._router = AlwaysOneRouter(Engine.Athena) - elif routing_policy == RoutingPolicy.AlwaysAurora: - self._router = AlwaysOneRouter(Engine.Aurora) - elif routing_policy == RoutingPolicy.AlwaysRedshift: - self._router = AlwaysOneRouter(Engine.Redshift) - elif routing_policy == RoutingPolicy.RuleBased: - self._monitor = Monitor(config, self._blueprint_mgr) - self._router = RuleBased( - blueprint_mgr=self._blueprint_mgr, monitor=self._monitor - ) - elif ( - routing_policy == RoutingPolicy.ForestTablePresence - or routing_policy == RoutingPolicy.ForestTableSelectivity - ): - self._router = ForestRouter.for_server( - routing_policy, self._schema_name, self._assets, self._blueprint_mgr - ) - else: - raise RuntimeError( - "Unsupported routing policy: {}".format(str(routing_policy)) - ) - logger.info("Using routing policy: %s", routing_policy) - + self._routing_policy_override = self._config.routing_policy + # This is set up as the front end starts up. + self._router: Optional[Router] = None self._sessions = SessionManager( self._config, self._blueprint_mgr, self._schema_name ) @@ -181,15 +154,8 @@ async def serve_forever(self): logger.info("The BRAD front end has successfully started.") logger.info("Listening on port %d.", port_to_use) - if self._routing_policy == RoutingPolicy.RuleBased: - assert ( - self._monitor is not None - ), "require monitor running for rule-based router" - await asyncio.gather( - self._monitor.run_forever(), grpc_server.wait_for_termination() - ) - else: - await grpc_server.wait_for_termination() + # N.B. If we need the Monitor, we should call `run_forever()` here. + await grpc_server.wait_for_termination() finally: # Not ideal, but we need to manually call this method to ensure # gRPC's internal shutdown process completes before we return from @@ -206,14 +172,17 @@ async def _run_setup(self) -> None: self._monitor.set_up_metrics_sources() await self._monitor.fetch_latest() - if self._routing_policy == RoutingPolicy.ForestTableSelectivity: + if ( + self._routing_policy_override == RoutingPolicy.ForestTableSelectivity + or self._routing_policy_override == RoutingPolicy.Default + ): self._estimator = await PostgresEstimator.connect( self._schema_name, self._config ) await self._estimator.analyze(self._blueprint_mgr.get_blueprint()) else: self._estimator = None - await self._router.run_setup(self._estimator) + await self._set_up_router(self._estimator) # Start the metrics reporting task. self._brad_metrics_reporting_task = asyncio.create_task( @@ -225,6 +194,50 @@ async def _run_setup(self) -> None: self._qlogger_refresh_task = asyncio.create_task(self._refresh_qlogger()) + async def _set_up_router(self, estimator: Optional[Estimator]) -> None: + # We have different routing policies for performance evaluation and + # testing purposes. + blueprint = self._blueprint_mgr.get_blueprint() + + if self._routing_policy_override == RoutingPolicy.Default: + # No override - use the blueprint's policy. + self._router = Router.create_from_blueprint(blueprint) + logger.info("Using blueprint-provided routing policy.") + + else: + if self._routing_policy_override == RoutingPolicy.AlwaysAthena: + definite_policy: AbstractRoutingPolicy = AlwaysOneRouter(Engine.Athena) + elif self._routing_policy_override == RoutingPolicy.AlwaysAurora: + definite_policy = AlwaysOneRouter(Engine.Aurora) + elif self._routing_policy_override == RoutingPolicy.AlwaysRedshift: + definite_policy = AlwaysOneRouter(Engine.Redshift) + elif self._routing_policy_override == RoutingPolicy.RuleBased: + # TODO: If we need metrics, re-create the monitor here. It's + # easier to not have it created. + definite_policy = RuleBased() + elif ( + self._routing_policy_override == RoutingPolicy.ForestTablePresence + or self._routing_policy_override == RoutingPolicy.ForestTableSelectivity + ): + definite_policy = await ForestPolicy.from_assets( + self._schema_name, + self._routing_policy_override, + self._assets, + ) + else: + raise RuntimeError( + f"Unsupported routing policy override: {self._routing_policy_override}" + ) + logger.info( + "Using routing policy override: %s", self._routing_policy_override.name + ) + self._router = Router.create_from_definite_policy( + definite_policy, blueprint.table_locations_bitmap() + ) + + await self._router.run_setup(estimator) + self._router.log_policy() + async def _run_teardown(self): logger.debug("Starting BRAD front end _run_teardown()") await self._sessions.end_all_sessions() @@ -299,6 +312,7 @@ async def _run_query_impl( elif self._route_redshift_only: engine_to_use = Engine.Redshift else: + assert self._router is not None engine_to_use = await self._router.engine_for(query_rep) log_verbose( @@ -601,6 +615,7 @@ async def _run_blueprint_update(self, version: int) -> None: if self._monitor is not None: self._monitor.update_metrics_sources() await self._sessions.add_connections() + assert self._router is not None self._router.update_blueprint(blueprint) # NOTE: This will cause any pending queries on the to-be-removed # connections to be cancelled. We consider this behavior to be diff --git a/src/brad/planner/beam/query_based_candidate.py b/src/brad/planner/beam/query_based_candidate.py index 008cbd37..18a5e813 100644 --- a/src/brad/planner/beam/query_based_candidate.py +++ b/src/brad/planner/beam/query_based_candidate.py @@ -117,7 +117,7 @@ def to_blueprint(self) -> Blueprint: self.get_table_placement(), self.aurora_provisioning.clone(), self.redshift_provisioning.clone(), - self._source_blueprint.router_provider(), + self._source_blueprint.get_routing_policy(), # TODO: Use chosen policy. ) def to_score(self) -> Score: diff --git a/src/brad/planner/beam/table_based_candidate.py b/src/brad/planner/beam/table_based_candidate.py index 9ee32fbf..cdb2abac 100644 --- a/src/brad/planner/beam/table_based_candidate.py +++ b/src/brad/planner/beam/table_based_candidate.py @@ -111,7 +111,7 @@ def to_blueprint(self) -> Blueprint: self.get_table_placement(), self.aurora_provisioning.clone(), self.redshift_provisioning.clone(), - self._source_blueprint.router_provider(), + self._source_blueprint.get_routing_policy(), # TODO: Use chosen policy. ) def to_score(self) -> Score: diff --git a/src/brad/planner/data.py b/src/brad/planner/data.py index 7c6ff612..966e4c9f 100644 --- a/src/brad/planner/data.py +++ b/src/brad/planner/data.py @@ -4,6 +4,8 @@ from brad.blueprint.user import UserProvidedBlueprint from brad.blueprint.table import Table from brad.config.engine import Engine +from brad.routing.abstract_policy import FullRoutingPolicy +from brad.routing.round_robin import RoundRobin def bootstrap_blueprint(user: UserProvidedBlueprint) -> Blueprint: @@ -90,5 +92,6 @@ def process_table(table: Table, expect_standalone_base_table: bool): table_locations, user.aurora_provisioning(), user.redshift_provisioning(), - None, + # TODO: Replace the default definite policy. + FullRoutingPolicy(indefinite_policies=[], definite_policy=RoundRobin()), ) diff --git a/src/brad/planner/enumeration/blueprint.py b/src/brad/planner/enumeration/blueprint.py index 416ccc52..8fff27dd 100644 --- a/src/brad/planner/enumeration/blueprint.py +++ b/src/brad/planner/enumeration/blueprint.py @@ -3,6 +3,7 @@ from brad.blueprint.blueprint import Blueprint from brad.blueprint.provisioning import Provisioning from brad.config.engine import Engine +from brad.routing.abstract_policy import FullRoutingPolicy class EnumeratedBlueprint(Blueprint): @@ -21,12 +22,13 @@ def __init__(self, base_blueprint: Blueprint) -> None: base_blueprint.table_locations(), base_blueprint.aurora_provisioning(), base_blueprint.redshift_provisioning(), - base_blueprint.router_provider(), + base_blueprint.get_routing_policy(), ) self._current_locations = base_blueprint.table_locations() self._current_aurora_provisioning = base_blueprint.aurora_provisioning() self._current_redshift_provisioning = base_blueprint.redshift_provisioning() self._current_table_locations_bitmap: Optional[Dict[str, int]] = None + self._current_routing_policy = base_blueprint.get_routing_policy() def set_table_locations( self, locations: Dict[str, List[Engine]] @@ -43,6 +45,12 @@ def set_redshift_provisioning(self, prov: Provisioning) -> "EnumeratedBlueprint" self._current_redshift_provisioning = prov return self + def set_routing_policy( + self, routing_policy: FullRoutingPolicy + ) -> "EnumeratedBlueprint": + self._current_routing_policy = routing_policy + return self + def to_blueprint(self) -> Blueprint: """ Makes a copy of this object as a `Blueprint`. @@ -57,7 +65,7 @@ def to_blueprint(self) -> Blueprint: }, aurora_provisioning=self._current_aurora_provisioning.clone(), redshift_provisioning=self._current_redshift_provisioning.clone(), - router_provider=self.router_provider(), + full_routing_policy=self._current_routing_policy, ) # Overridden getters. @@ -84,3 +92,6 @@ def get_table_locations(self, table_name: str) -> List[Engine]: return self._current_locations[table_name] except KeyError as ex: raise ValueError from ex + + def get_routing_policy(self) -> FullRoutingPolicy: + return self._current_routing_policy diff --git a/src/brad/planner/neighborhood/neighborhood.py b/src/brad/planner/neighborhood/neighborhood.py index 113eecc0..12a2322a 100644 --- a/src/brad/planner/neighborhood/neighborhood.py +++ b/src/brad/planner/neighborhood/neighborhood.py @@ -194,7 +194,7 @@ def _check_if_metrics_warrant_replanning(self) -> bool: def _estimate_current_data_accessed( self, engines: EngineConnections, current_workload: Workload ) -> Dict[Engine, int]: - current_router = RuleBased(blueprint=self._current_blueprint) + current_router = RuleBased() total_accessed_mb: Dict[Engine, int] = {} total_accessed_mb[Engine.Aurora] = 0 @@ -204,7 +204,9 @@ def _estimate_current_data_accessed( # Compute the total amount of data accessed on each engine in the # current workload (used to weigh the workload assigned to each engine). for q in current_workload.analytical_queries(): - current_engine = current_router.engine_for_sync(q) + current_engines = current_router.engine_for_sync(q) + assert len(current_engines) > 0 + current_engine = current_engines[0] q.populate_data_accessed_mb( current_engine, engines, self._current_blueprint ) diff --git a/src/brad/planner/neighborhood/scaling_scorer.py b/src/brad/planner/neighborhood/scaling_scorer.py index 462dec66..9b7b7b5b 100644 --- a/src/brad/planner/neighborhood/scaling_scorer.py +++ b/src/brad/planner/neighborhood/scaling_scorer.py @@ -57,11 +57,13 @@ def score(self, ctx: ScoringContext) -> Score: def _simulate_next_workload(self, ctx: ScoringContext) -> None: # NOTE: The routing policy should be included in the blueprint. We # currently hardcode it here for engineering convenience. - router = RuleBased(blueprint=ctx.next_blueprint) + router = RuleBased() # See where each analytical query gets routed. for q in ctx.next_workload.analytical_queries(): - next_engine = router.engine_for_sync(q) + next_engines = router.engine_for_sync(q) + assert len(next_engines) > 0 + next_engine = next_engines[0] ctx.next_dest[next_engine].append(q) q.populate_data_accessed_mb(next_engine, ctx.engines, ctx.current_blueprint) diff --git a/src/brad/planner/router_provider.py b/src/brad/planner/router_provider.py index 6836e4ec..1cd7407c 100644 --- a/src/brad/planner/router_provider.py +++ b/src/brad/planner/router_provider.py @@ -3,10 +3,11 @@ from brad.asset_manager import AssetManager from brad.config.file import ConfigFile from brad.planner.estimator import EstimatorProvider +from brad.routing.abstract_policy import AbstractRoutingPolicy from brad.routing.policy import RoutingPolicy from brad.routing.router import Router from brad.routing.rule_based import RuleBased -from brad.routing.tree_based.forest_router import ForestRouter +from brad.routing.tree_based.forest_policy import ForestPolicy from brad.routing.tree_based.model_wrap import ModelWrap @@ -38,19 +39,17 @@ async def get_router(self, table_bitmap: Dict[str, int]) -> Router: or self._routing_policy == RoutingPolicy.ForestTableSelectivity ): if self._model is None: - self._model = ForestRouter.static_load_model_sync( + self._model = ForestPolicy.static_load_model_sync( self._schema_name, self._routing_policy, self._assets, ) - router = ForestRouter.for_planner( - self._routing_policy, self._schema_name, self._model, table_bitmap + definite_policy: AbstractRoutingPolicy = ForestPolicy.from_loaded_model( + self._routing_policy, self._model ) - await router.run_setup(self._estimator_provider.get_estimator()) - return router elif self._routing_policy == RoutingPolicy.RuleBased: - return RuleBased(table_placement_bitmap=table_bitmap) + definite_policy = RuleBased() else: raise RuntimeError( @@ -59,5 +58,10 @@ async def get_router(self, table_bitmap: Dict[str, int]) -> Router: ) ) + # This is temporary and will be removed. + router = Router.create_from_definite_policy(definite_policy, table_bitmap) + await router.run_setup(self._estimator_provider.get_estimator()) + return router + def clear_cached(self) -> None: self._model = None diff --git a/src/brad/routing/abstract_policy.py b/src/brad/routing/abstract_policy.py new file mode 100644 index 00000000..c0b2306a --- /dev/null +++ b/src/brad/routing/abstract_policy.py @@ -0,0 +1,81 @@ +from typing import List, Optional + +from brad.config.engine import Engine +from brad.planner.estimator import Estimator +from brad.query_rep import QueryRep + + +class AbstractRoutingPolicy: + """ + Note that implementers must be serializable. + """ + + def name(self) -> str: + raise NotImplementedError + + async def run_setup(self, estimator: Optional[Estimator] = None) -> None: + """ + Should be called before using this policy. This is used to set up any + dynamic state. + + If this routing policy needs an estimator, one should be provided here. + """ + + async def engine_for(self, query_rep: QueryRep) -> List[Engine]: + """ + Produces a preference order for query routing (the first element in the + list is the most preferred engine, and so on). + + NOTE: Implementers currently do not need to consider DML queries. BRAD + routes all DML queries to Aurora before consulting the router. Thus the + query passed to this method will always be a read-only query. + + You should override this method if the routing policy needs to depend on + any asynchronous methods. + """ + return self.engine_for_sync(query_rep) + + def engine_for_sync(self, query_rep: QueryRep) -> List[Engine]: + """ + Produces a preference order for query routing (the first element in the + list is the most preferred engine, and so on). + + NOTE: Implementers currently do not need to consider DML queries. BRAD + routes all DML queries to Aurora before consulting the router. Thus the + query passed to this method will always be a read-only query. + """ + raise NotImplementedError + + +class FullRoutingPolicy: + """ + Captures a full routing policy for serialization purposes. Indefinite + policies are allowed to return empty preference lists (indicating no routing + decision). + """ + + def __init__( + self, + indefinite_policies: List[AbstractRoutingPolicy], + definite_policy: AbstractRoutingPolicy, + ) -> None: + self.indefinite_policies = indefinite_policies + self.definite_policy = definite_policy + + async def run_setup(self, estimator: Optional[Estimator] = None) -> None: + """ + Should be called before using the policy. This is used to set up any + dynamic state. + + If this routing policy needs an estimator, one should be provided here. + """ + for policy in self.indefinite_policies: + await policy.run_setup(estimator) + await self.definite_policy.run_setup(estimator) + + def __eq__(self, other: object): + if not isinstance(other, FullRoutingPolicy): + return False + return (self.indefinite_policies == other.indefinite_policies) and ( + self.definite_policy == other.definite_policy + ) diff --git a/src/brad/routing/always_one.py b/src/brad/routing/always_one.py index b52a79a4..cc51243b 100644 --- a/src/brad/routing/always_one.py +++ b/src/brad/routing/always_one.py @@ -1,9 +1,11 @@ +from typing import List + from brad.config.engine import Engine from brad.query_rep import QueryRep -from brad.routing.router import Router +from brad.routing.abstract_policy import AbstractRoutingPolicy -class AlwaysOneRouter(Router): +class AlwaysOneRouter(AbstractRoutingPolicy): """ This router always selects the same database engine for all queries. This router is useful for testing and benchmarking purposes. @@ -11,7 +13,14 @@ class AlwaysOneRouter(Router): def __init__(self, db_type: Engine): super().__init__() - self._always_route_to = db_type + self._engine = db_type + self._always_route_to = [db_type] + + def name(self) -> str: + return f"AlwaysRouteTo({self._engine.name})" - def engine_for_sync(self, _query: QueryRep) -> Engine: + def engine_for_sync(self, _query: QueryRep) -> List[Engine]: return self._always_route_to + + def __eq__(self, other: object) -> bool: + return isinstance(other, AlwaysOneRouter) and self._engine == other._engine diff --git a/src/brad/routing/location_aware_round_robin.py b/src/brad/routing/location_aware_round_robin.py deleted file mode 100644 index 5562f871..00000000 --- a/src/brad/routing/location_aware_round_robin.py +++ /dev/null @@ -1,30 +0,0 @@ -from brad.config.engine import Engine -from brad.blueprint.manager import BlueprintManager -from brad.routing.router import Router -from brad.query_rep import QueryRep - - -class LocationAwareRoundRobin(Router): - """ - Routes queries in a "roughly" round-robin fashion, taking into account the - locations of the tables referenced. - """ - - def __init__(self, blueprint_mgr: BlueprintManager): - self._blueprint_mgr = blueprint_mgr - self._curr_idx = 0 - - def engine_for_sync(self, query: QueryRep) -> Engine: - blueprint = self._blueprint_mgr.get_blueprint() - valid_locations, only_location = self._run_location_routing( - query, blueprint.table_locations_bitmap() - ) - if only_location is not None: - return only_location - - locations = Engine.from_bitmap(valid_locations) - self._curr_idx %= len(locations) - selected_location = locations[self._curr_idx] - self._curr_idx += 1 - - return selected_location diff --git a/src/brad/routing/policy.py b/src/brad/routing/policy.py index b2ba54ae..b621a089 100644 --- a/src/brad/routing/policy.py +++ b/src/brad/routing/policy.py @@ -2,6 +2,11 @@ class RoutingPolicy(str, enum.Enum): + """ + This is used to override the policy specified by the blueprint (usually for + testing purposes). + """ + Default = "default" AlwaysAthena = "always_athena" AlwaysAurora = "always_aurora" @@ -27,4 +32,4 @@ def from_str(candidate: str) -> "RoutingPolicy": elif candidate == RoutingPolicy.ForestTableSelectivity.value: return RoutingPolicy.ForestTableSelectivity else: - raise ValueError("Unrecognized DB type {}".format(candidate)) + raise ValueError("Unrecognized policy {}".format(candidate)) diff --git a/src/brad/routing/round_robin.py b/src/brad/routing/round_robin.py new file mode 100644 index 00000000..8ba7d304 --- /dev/null +++ b/src/brad/routing/round_robin.py @@ -0,0 +1,27 @@ +from typing import List + +from brad.config.engine import Engine +from brad.query_rep import QueryRep +from brad.routing.abstract_policy import AbstractRoutingPolicy + + +class RoundRobin(AbstractRoutingPolicy): + """ + Routes queries in a round-robin fashion. + """ + + def __init__(self): + self._ordering = [Engine.Athena, Engine.Aurora, Engine.Redshift] + + def name(self) -> str: + return "RoundRobin" + + def engine_for_sync(self, _query: QueryRep) -> List[Engine]: + tmp = self._ordering[0] + self._ordering[0] = self._ordering[1] + self._ordering[1] = self._ordering[2] + self._ordering[2] = tmp + return self._ordering + + def __eq__(self, other: object) -> bool: + return isinstance(other, RoundRobin) diff --git a/src/brad/routing/router.py b/src/brad/routing/router.py index 652b3d14..f94fc998 100644 --- a/src/brad/routing/router.py +++ b/src/brad/routing/router.py @@ -1,14 +1,54 @@ +import asyncio +import logging from typing import Dict, Tuple, Optional, TYPE_CHECKING from brad.data_stats.estimator import Estimator from brad.config.engine import Engine, EngineBitmapValues from brad.query_rep import QueryRep +from brad.routing.abstract_policy import AbstractRoutingPolicy, FullRoutingPolicy if TYPE_CHECKING: from brad.blueprint import Blueprint +logger = logging.getLogger(__name__) + class Router: + @classmethod + def create_from_blueprint(cls, blueprint: "Blueprint") -> "Router": + return cls( + blueprint.get_routing_policy(), + blueprint.table_locations_bitmap(), + use_future_blueprint_policies=True, + ) + + @classmethod + def create_from_definite_policy( + cls, policy: AbstractRoutingPolicy, table_placement_bitmap: Dict[str, int] + ) -> "Router": + return cls( + FullRoutingPolicy(indefinite_policies=[], definite_policy=policy), + table_placement_bitmap, + use_future_blueprint_policies=False, + ) + + def __init__( + self, + full_policy: FullRoutingPolicy, + table_placement_bitmap: Dict[str, int], + use_future_blueprint_policies: bool, + ) -> None: + self._full_policy = full_policy + self._table_placement_bitmap = table_placement_bitmap + self._use_future_blueprint_policies = use_future_blueprint_policies + + def log_policy(self) -> None: + logger.info("Routing policy:") + logger.info(" Indefinite policies:") + for p in self._full_policy.indefinite_policies: + logger.info(" - %s", p.name()) + logger.info(" Definite policy: %s") + async def run_setup(self, estimator: Optional[Estimator] = None) -> None: """ Should be called before using the router. This is used to set up any @@ -16,25 +56,47 @@ async def run_setup(self, estimator: Optional[Estimator] = None) -> None: If the routing policy needs an estimator, one should be provided here. """ + await self._full_policy.run_setup(estimator) def update_blueprint(self, blueprint: "Blueprint") -> None: """ Used to update any cached state that depends on the blueprint (e.g., location bitmaps). """ + self._table_placement_bitmap = blueprint.table_locations_bitmap() + if self._use_future_blueprint_policies: + self._full_policy = blueprint.get_routing_policy() async def engine_for(self, query: QueryRep) -> Engine: """ Selects an engine for the provided SQL query. - - NOTE: Implementers currently do not need to consider DML queries. BRAD - routes all DML queries to Aurora before consulting the router. Thus the - query passed to this method will always be a read-only query. - - You should override this method if the routing policy needs to depend on - any asynchronous methods. """ - return self.engine_for_sync(query) + + # Table placement constraints. + assert self._table_placement_bitmap is not None + valid_locations, only_location = self._run_location_routing( + query, self._table_placement_bitmap + ) + if only_location is not None: + return only_location + + # Go through the indefinite routing policies. These may not return a + # routing location. + for policy in self._full_policy.indefinite_policies: + locations = await policy.engine_for(query) + for loc in locations: + if (EngineBitmapValues[loc] & valid_locations) != 0: + return loc + + # Rely on the definite routing policy. + locations = await self._full_policy.definite_policy.engine_for(query) + for loc in locations: + if (EngineBitmapValues[loc] & valid_locations) != 0: + return loc + + # This should be unreachable. The definite policy must rank all engines, + # and we know >= 2 engines can support this query. + raise AssertionError def engine_for_sync(self, query: QueryRep) -> Engine: """ @@ -44,7 +106,8 @@ def engine_for_sync(self, query: QueryRep) -> Engine: routes all DML queries to Aurora before consulting the router. Thus the query passed to this method will always be a read-only query. """ - raise NotImplementedError + # Ideally we re-implement a sync version. + return asyncio.run(self.engine_for(query)) def _run_location_routing( self, query: QueryRep, location_bitmap: Dict[str, int] diff --git a/src/brad/routing/rule_based.py b/src/brad/routing/rule_based.py index 25169c88..f0bf50ef 100644 --- a/src/brad/routing/rule_based.py +++ b/src/brad/routing/rule_based.py @@ -1,15 +1,14 @@ import os.path import json import logging -from typing import List, Optional, Mapping, MutableMapping, Any, Dict +from typing import List, Optional, Mapping, MutableMapping, Any from importlib.resources import files, as_file import brad.routing from brad.blueprint import Blueprint from brad.config.engine import Engine -from brad.blueprint.manager import BlueprintManager from brad.daemon.monitor import Monitor -from brad.routing.router import Router +from brad.routing.abstract_policy import AbstractRoutingPolicy from brad.query_rep import QueryRep from brad.front_end.session import SessionManager @@ -54,21 +53,12 @@ def __init__(self) -> None: self.redshift_parameters_lower_limit = redshift_parameters_lower_limit -class RuleBased(Router): +class RuleBased(AbstractRoutingPolicy): def __init__( self, - # One of `blueprint_mgr` and `blueprint` must not be `None`. - blueprint_mgr: Optional[BlueprintManager] = None, - blueprint: Optional[Blueprint] = None, - table_placement_bitmap: Optional[Dict[str, int]] = None, monitor: Optional[Monitor] = None, catalog: Optional[MutableMapping[str, MutableMapping[str, Any]]] = None, - use_decision_tree: bool = False, - deterministic: bool = True, ): - self._blueprint_mgr = blueprint_mgr - self._blueprint = blueprint - self._table_placement_bitmap = table_placement_bitmap self._monitor = monitor # catalog contains all tables' number of rows and columns self._catalog = catalog @@ -78,18 +68,17 @@ def __init__( if os.path.exists(file): with open(file, "r", encoding="utf8") as f: self._catalog = json.load(f) - # use decision tree instead of rules - self._use_decision_tree = use_decision_tree - # deterministic routing guarantees the same decision for the same query and should be used online - # non-determinism will be used for offline training data exploration (not implemented) - self._deterministic = deterministic self._params = RuleBasedParams() - def update_blueprint(self, blueprint: Blueprint) -> None: - self._blueprint = blueprint - self._table_placement_bitmap = blueprint.table_locations_bitmap() + def name(self) -> str: + return "RuleBased" - async def recollect_catalog(self, sessions: SessionManager) -> None: + def __eq__(self, other: object) -> bool: + return isinstance(other, RuleBased) + + async def recollect_catalog( + self, sessions: SessionManager, blueprint: Blueprint + ) -> None: # recollect catalog stats; happens every maintenance window if self._catalog is None: self._catalog = dict() @@ -102,12 +91,6 @@ async def recollect_catalog(self, sessions: SessionManager) -> None: connection = session.engines.get_connection(Engine.Aurora) cursor = await connection.cursor() - if self._blueprint is not None: - blueprint = self._blueprint - else: - assert self._blueprint_mgr is not None - blueprint = self._blueprint_mgr.get_blueprint() - indexes_sql = ( "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' " "ORDER BY tablename, indexname;" @@ -207,25 +190,8 @@ def check_engine_state( return True return not_overloaded - def engine_for_sync(self, query: QueryRep) -> Engine: - if self._table_placement_bitmap is None: - if self._blueprint is not None: - blueprint = self._blueprint - else: - assert self._blueprint_mgr is not None - blueprint = self._blueprint_mgr.get_blueprint() - self._table_placement_bitmap = blueprint.table_locations_bitmap() - - valid_locations, only_location = self._run_location_routing( - query, self._table_placement_bitmap - ) - if only_location is not None: - return only_location - - locations = Engine.from_bitmap(valid_locations) - assert len(locations) > 1 - ideal_location_rank: List[Engine] = [] - touched_tables = query.tables() + def engine_for_sync(self, query_rep: QueryRep) -> List[Engine]: + touched_tables = query_rep.tables() if ( len(touched_tables) < self._params.ideal_location_lower_limit["redshift_num_table"] @@ -239,7 +205,7 @@ def engine_for_sync(self, query: QueryRep) -> Engine: if self._catalog: n_rows = [] n_cols = [] - for table_name in query.tables(): + for table_name in query_rep.tables(): if table_name in self._catalog: n_rows.append(self._catalog[table_name]["nrow"]) n_cols.append(self._catalog[table_name]["ncol"]) @@ -269,7 +235,7 @@ def engine_for_sync(self, query: QueryRep) -> Engine: ideal_location_rank = [Engine.Redshift, Engine.Athena, Engine.Aurora] if self._catalog: n_rows = [] - for table_name in query.tables(): + for table_name in query_rep.tables(): if table_name in self._catalog: n_rows.append(self._catalog[table_name]["nrow"]) if ( @@ -287,11 +253,8 @@ def engine_for_sync(self, query: QueryRep) -> Engine: ideal_location_rank = [Engine.Athena, Engine.Redshift, Engine.Aurora] if self._monitor is None: - for loc in ideal_location_rank: - if loc in locations: - return loc - # This should be unreachable since len(locations) > 0. - assert False + return ideal_location_rank + else: # Todo(Ziniu): this can be stored in this class to reduce latency raw_aurora_metrics = ( @@ -304,21 +267,14 @@ def engine_for_sync(self, query: QueryRep) -> Engine: logger.warning( "Routing without system metrics when we expect to have metrics." ) - return ideal_location_rank[0] + return ideal_location_rank aurora_metrics = raw_aurora_metrics.iloc[0].to_dict() redshift_metrics = raw_redshift_metrics.iloc[0].to_dict() for loc in ideal_location_rank: - if loc in locations and self.check_engine_state( - loc, aurora_metrics, redshift_metrics - ): - return loc + if self.check_engine_state(loc, aurora_metrics, redshift_metrics): + return [loc, *[il for il in ideal_location_rank if il != loc]] # In the case of all system are overloaded (time to trigger replan), # we assign it to the optimal one. But Athena should not be overloaded at any time - for loc in ideal_location_rank: - if loc in locations: - return loc - - # Should be unreachable since len(locations) > 0. - assert False + return ideal_location_rank diff --git a/src/brad/routing/tree_based/forest_policy.py b/src/brad/routing/tree_based/forest_policy.py new file mode 100644 index 00000000..6b90bcf7 --- /dev/null +++ b/src/brad/routing/tree_based/forest_policy.py @@ -0,0 +1,92 @@ +import asyncio +from typing import Optional, List, Dict, Any + +from brad.asset_manager import AssetManager +from brad.config.engine import Engine +from brad.data_stats.estimator import Estimator +from brad.query_rep import QueryRep +from brad.routing.abstract_policy import AbstractRoutingPolicy +from brad.routing.policy import RoutingPolicy +from brad.routing.tree_based.model_wrap import ModelWrap + + +class ForestPolicy(AbstractRoutingPolicy): + @classmethod + async def from_assets( + cls, schema_name: str, policy: RoutingPolicy, assets: AssetManager + ) -> "ForestPolicy": + model = cls.static_load_model_sync(schema_name, policy, assets) + return cls(policy, model) + + @classmethod + def from_loaded_model( + cls, policy: RoutingPolicy, model: ModelWrap + ) -> "ForestPolicy": + return cls(policy, model) + + def __init__(self, policy: RoutingPolicy, model: ModelWrap) -> None: + self._policy = policy + self._model = model + self._estimator: Optional[Estimator] = None + + def __getstate__(self) -> Dict[Any, Any]: + return { + "policy": self._policy, + "model": self._model, + } + + def __setstate__(self, d: Dict[Any, Any]) -> None: + self._policy = d["policy"] + self._model = d["model"] + self._estimator = None + + def name(self) -> str: + return f"ForestPolicy({self._policy.name})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ForestPolicy): + return False + return self._policy == other._policy and self._model == other._model + + async def run_setup(self, estimator: Optional[Estimator] = None) -> None: + self._estimator = estimator + + async def engine_for(self, query_rep: QueryRep) -> List[Engine]: + return await self._model.engine_for(query_rep, self._estimator) + + def engine_for_sync(self, query_rep: QueryRep) -> List[Engine]: + return asyncio.run(self.engine_for(query_rep)) + + # The methods below are used to save/load `ModelWrap` from S3. We + # historically separated out the model's implementation details because the + # router contained state that was not serializable. This separation is kept + # around for legacy reasons now (this policy class should be directly + # serialized). + + @staticmethod + def static_persist_sync( + model: ModelWrap, schema_name: str, assets: AssetManager + ) -> None: + key = _SERIALIZED_KEY.format( + schema_name=schema_name, policy=model.policy().value + ) + serialized = model.to_pickle() + assets.persist_sync(key, serialized) + + @staticmethod + def static_load_model_sync( + schema_name: str, policy: RoutingPolicy, assets: AssetManager + ) -> ModelWrap: + key = _SERIALIZED_KEY.format(schema_name=schema_name, policy=policy.value) + serialized = assets.load_sync(key) + return ModelWrap.from_pickle_bytes(serialized) + + @staticmethod + def static_drop_model_sync( + schema_name: str, policy: RoutingPolicy, assets: AssetManager + ) -> None: + key = _SERIALIZED_KEY.format(schema_name=schema_name, policy=policy.value) + assets.delete_sync(key) + + +_SERIALIZED_KEY = "{schema_name}/{policy}-router.pickle" diff --git a/src/brad/routing/tree_based/forest_router.py b/src/brad/routing/tree_based/forest_router.py deleted file mode 100644 index aae48877..00000000 --- a/src/brad/routing/tree_based/forest_router.py +++ /dev/null @@ -1,140 +0,0 @@ -import asyncio -from typing import Optional, Dict - -from .model_wrap import ModelWrap -from brad.asset_manager import AssetManager -from brad.blueprint import Blueprint -from brad.config.engine import Engine, EngineBitmapValues -from brad.data_stats.estimator import Estimator -from brad.query_rep import QueryRep -from brad.routing.policy import RoutingPolicy -from brad.routing.router import Router -from brad.blueprint.manager import BlueprintManager - - -class ForestRouter(Router): - @classmethod - def for_server( - cls, - policy: RoutingPolicy, - schema_name: str, - assets: AssetManager, - blueprint_mgr: BlueprintManager, - ) -> "ForestRouter": - return cls(policy, schema_name, assets=assets, blueprint_mgr=blueprint_mgr) - - @classmethod - def for_planner( - cls, - policy: RoutingPolicy, - schema_name: str, - model: ModelWrap, - table_bitmap: Dict[str, int], - ) -> "ForestRouter": - return cls( - policy, schema_name, model=model, table_placement_bitmap=table_bitmap - ) - - def __init__( - self, - policy: RoutingPolicy, - schema_name: str, - # One of `assets` and `model` most not be `None`. - assets: Optional[AssetManager] = None, - model: Optional[ModelWrap] = None, - # One of `blueprint_mgr`, `blueprint`, and `table_placement_bitmap` must not be `None`. - blueprint_mgr: Optional[BlueprintManager] = None, - blueprint: Optional[Blueprint] = None, - table_placement_bitmap: Optional[Dict[str, int]] = None, - ) -> None: - self._policy = policy - self._schema_name = schema_name - self._model = model - self._assets = assets - - self._blueprint_mgr = blueprint_mgr - self._blueprint = blueprint - self._table_placement_bitmap = table_placement_bitmap - self._estimator: Optional[Estimator] = None - - async def run_setup(self, estimator: Optional[Estimator] = None) -> None: - self._estimator = estimator - - # Load the model. - if self._model is None: - assert self._assets is not None - serialized_model = await self._assets.load( - _SERIALIZED_KEY.format( - schema_name=self._schema_name, policy=self._policy.value - ) - ) - self._model = ModelWrap.from_pickle_bytes(serialized_model) - - # Load the table placement if it was not provided. - if self._table_placement_bitmap is None: - if self._blueprint is None: - assert self._blueprint_mgr is not None - self._blueprint = self._blueprint_mgr.get_blueprint() - - self._table_placement_bitmap = self._blueprint.table_locations_bitmap() - - def update_blueprint(self, blueprint: Blueprint) -> None: - self._blueprint = blueprint - self._table_placement_bitmap = blueprint.table_locations_bitmap() - - async def engine_for(self, query: QueryRep) -> Engine: - # Compute valid locations. - assert self._table_placement_bitmap is not None - valid_locations, only_location = self._run_location_routing( - query, self._table_placement_bitmap - ) - if only_location is not None: - return only_location - - # Multiple locations possible. Use the model to figure out which location to use. - assert self._model is not None - preferred_locations = await self._model.engine_for(query, self._estimator) - - for loc in preferred_locations: - if (EngineBitmapValues[loc] & valid_locations) != 0: - return loc - - # This should be unreachable. The model must rank all engines, and we - # know >= 2 engines can support this query. - raise AssertionError - - def engine_for_sync(self, query: QueryRep) -> Engine: - return asyncio.run(self.engine_for(query)) - - def persist_sync(self) -> None: - assert self._assets is not None - assert self._model is not None - self.static_persist_sync(self._model, self._schema_name, self._assets) - - @staticmethod - def static_persist_sync( - model: ModelWrap, schema_name: str, assets: AssetManager - ) -> None: - key = _SERIALIZED_KEY.format( - schema_name=schema_name, policy=model.policy().value - ) - serialized = model.to_pickle() - assets.persist_sync(key, serialized) - - @staticmethod - def static_load_model_sync( - schema_name: str, policy: RoutingPolicy, assets: AssetManager - ) -> ModelWrap: - key = _SERIALIZED_KEY.format(schema_name=schema_name, policy=policy.value) - serialized = assets.load_sync(key) - return ModelWrap.from_pickle_bytes(serialized) - - @staticmethod - def static_drop_model_sync( - schema_name: str, policy: RoutingPolicy, assets: AssetManager - ) -> None: - key = _SERIALIZED_KEY.format(schema_name=schema_name, policy=policy.value) - assets.delete_sync(key) - - -_SERIALIZED_KEY = "{schema_name}/{policy}-router.pickle" diff --git a/src/brad/routing/tree_based/model_wrap.py b/src/brad/routing/tree_based/model_wrap.py index 2ff068fa..42a496ad 100644 --- a/src/brad/routing/tree_based/model_wrap.py +++ b/src/brad/routing/tree_based/model_wrap.py @@ -49,6 +49,20 @@ async def engine_for( low_to_high = np.argsort(preds) return [ENGINE_LABELS[label] for label in reversed(low_to_high)] + def __eq__(self, other: object) -> bool: + if not isinstance(other, ModelWrap): + return False + if self._policy != other._policy or self._table_order != other._table_order: + return False + + if id(self._model) == id(other._model): + return True + + # Not very ideal, but this will check for identical copies. + serialized = pickle.dumps(self._model) + other_serialized = pickle.dumps(other._model) + return serialized == other_serialized + def to_pickle(self) -> bytes: # TODO: Pickling might not be the best option. return pickle.dumps(self) diff --git a/tests/test_always_routing.py b/tests/test_always_routing.py index 7b5a6288..469de3b0 100644 --- a/tests/test_always_routing.py +++ b/tests/test_always_routing.py @@ -7,10 +7,10 @@ def test_always_route_aurora(): router = AlwaysOneRouter(db) pred_db = router.engine_for_sync("SELECT 1") - assert pred_db == db + assert pred_db == [db] pred_db = router.engine_for_sync("SELECT * FROM my_table") - assert pred_db == db + assert pred_db == [db] def test_always_route_athena(): @@ -18,10 +18,10 @@ def test_always_route_athena(): router = AlwaysOneRouter(db) pred_db = router.engine_for_sync("SELECT 1") - assert pred_db == db + assert pred_db == [db] pred_db = router.engine_for_sync("SELECT * FROM my_table") - assert pred_db == db + assert pred_db == [db] def test_always_route_redshift(): @@ -29,7 +29,7 @@ def test_always_route_redshift(): router = AlwaysOneRouter(db) pred_db = router.engine_for_sync("SELECT 1") - assert pred_db == db + assert pred_db == [db] pred_db = router.engine_for_sync("SELECT * FROM my_table") - assert pred_db == db + assert pred_db == [db] diff --git a/tests/test_blueprint_diff.py b/tests/test_blueprint_diff.py index a4215121..8ae9dc67 100644 --- a/tests/test_blueprint_diff.py +++ b/tests/test_blueprint_diff.py @@ -5,6 +5,8 @@ from brad.blueprint.user import UserProvidedBlueprint from brad.config.engine import Engine from brad.planner.data import bootstrap_blueprint +from brad.routing.abstract_policy import FullRoutingPolicy +from brad.routing.always_one import AlwaysOneRouter def test_no_diff(): @@ -72,7 +74,7 @@ def test_provisioning_change(): initial.table_locations(), initial.aurora_provisioning(), Provisioning(instance_type="dc2.large", num_nodes=4), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) diff = BlueprintDiff.of(initial, changed) assert diff is not None diff --git a/tests/test_enumeration.py b/tests/test_enumeration.py index 028310b5..fd65334c 100644 --- a/tests/test_enumeration.py +++ b/tests/test_enumeration.py @@ -4,6 +4,8 @@ from brad.planner.enumeration.provisioning import ProvisioningEnumerator from brad.planner.enumeration.table_locations import TableLocationEnumerator from brad.planner.enumeration.neighborhood import NeighborhoodBlueprintEnumerator +from brad.routing.abstract_policy import FullRoutingPolicy +from brad.routing.always_one import AlwaysOneRouter def test_provisioning_enumerate_aurora(): @@ -65,7 +67,7 @@ def test_blueprint_enumerate(): {"table1": [Engine.Aurora], "table2": [Engine.Redshift]}, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) # Simple sanity check only. diff --git a/tests/test_forest_routing.py b/tests/test_forest_routing.py index 4e03cabb..095dc810 100644 --- a/tests/test_forest_routing.py +++ b/tests/test_forest_routing.py @@ -2,8 +2,8 @@ from sklearn.ensemble import RandomForestClassifier -from brad.config.engine import Engine, EngineBitmapValues -from brad.routing.tree_based.forest_router import ForestRouter +from brad.config.engine import Engine +from brad.routing.tree_based.forest_policy import ForestPolicy from brad.routing.tree_based.model_wrap import ModelWrap from brad.routing.policy import RoutingPolicy from brad.query_rep import QueryRep @@ -18,54 +18,23 @@ def get_dummy_router(): return ModelWrap(RoutingPolicy.ForestTablePresence, ["test1", "test2"], model) -def test_location_constraints(): - model = get_dummy_router() - bitmap = { - "test1": EngineBitmapValues[Engine.Aurora], - "test2": EngineBitmapValues[Engine.Aurora] - | EngineBitmapValues[Engine.Redshift], - } - router = ForestRouter.for_planner( - RoutingPolicy.ForestTablePresence, "test_schema", model, bitmap - ) - - query1 = QueryRep("SELECT * FROM test1") - loc = router.engine_for_sync(query1) - assert loc == Engine.Aurora - - query2 = QueryRep("SELECT * FROM test1, test2") - loc = router.engine_for_sync(query2) - assert loc == Engine.Aurora - - def test_model_codepath_partial(): model = get_dummy_router() - bitmap = { - "test1": EngineBitmapValues[Engine.Aurora] - | EngineBitmapValues[Engine.Redshift], - "test2": EngineBitmapValues[Engine.Aurora] - | EngineBitmapValues[Engine.Redshift], - } - router = ForestRouter.for_planner( - RoutingPolicy.ForestTablePresence, "test_schema", model, bitmap - ) + router = ForestPolicy.from_loaded_model(RoutingPolicy.ForestTablePresence, model) query = QueryRep("SELECT * FROM test1, test2") loc = router.engine_for_sync(query) - assert loc == Engine.Aurora or loc == Engine.Redshift + assert ( + loc[0] == Engine.Aurora or loc[0] == Engine.Redshift or loc[0] == Engine.Athena + ) def test_model_codepath_all(): model = get_dummy_router() - bitmap = { - "test1": Engine.bitmap_all(), - "test2": EngineBitmapValues[Engine.Aurora] - | EngineBitmapValues[Engine.Redshift], - } - router = ForestRouter.for_planner( - RoutingPolicy.ForestTablePresence, "test_schema", model, bitmap - ) + router = ForestPolicy.from_loaded_model(RoutingPolicy.ForestTablePresence, model) query = QueryRep("SELECT * FROM test1") loc = router.engine_for_sync(query) - assert loc == Engine.Aurora or loc == Engine.Redshift or loc == Engine.Athena + assert ( + loc[0] == Engine.Aurora or loc[0] == Engine.Redshift or loc[0] == Engine.Athena + ) diff --git a/tests/test_location_routing.py b/tests/test_location_routing.py index 4f9b9eaa..eacfa16c 100644 --- a/tests/test_location_routing.py +++ b/tests/test_location_routing.py @@ -1,12 +1,13 @@ from brad.config.engine import Engine, EngineBitmapValues from brad.routing.router import Router +from brad.routing.round_robin import RoundRobin from brad.query_rep import QueryRep def test_only_one_location(): query = QueryRep("SELECT * FROM test") bitmap = {"test": EngineBitmapValues[Engine.Aurora]} - r = Router() + r = Router.create_from_definite_policy(RoundRobin(), bitmap) # pylint: disable-next=protected-access valid_locations, only_location = r._run_location_routing(query, bitmap) assert only_location is not None @@ -22,7 +23,7 @@ def test_multiple_locations(): EngineBitmapValues[Engine.Redshift] | EngineBitmapValues[Engine.Athena] ), } - r = Router() + r = Router.create_from_definite_policy(RoundRobin(), bitmap) # pylint: disable-next=protected-access valid_locations, only_location = r._run_location_routing(query, bitmap) assert only_location is None diff --git a/tests/test_planner_filters.py b/tests/test_planner_filters.py index 1c3d712f..be001d2a 100644 --- a/tests/test_planner_filters.py +++ b/tests/test_planner_filters.py @@ -13,6 +13,8 @@ from brad.planner.neighborhood.filters.table_on_engine import TableOnEngine from brad.planner.workload import Workload from brad.planner.workload.query import Query +from brad.routing.abstract_policy import FullRoutingPolicy +from brad.routing.always_one import AlwaysOneRouter def workload_from_queries(query_list: List[str]) -> Workload: @@ -42,7 +44,7 @@ def test_aurora_transactions(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) bp2 = Blueprint( "schema", @@ -56,7 +58,7 @@ def test_aurora_transactions(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) bp_filter1 = AuroraTransactions(workload1) @@ -90,7 +92,7 @@ def test_single_engine_execution(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) bp2 = Blueprint( "schema", @@ -106,7 +108,7 @@ def test_single_engine_execution(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) bp_filter1 = SingleEngineExecution(workload1) @@ -136,7 +138,7 @@ def test_table_on_engine(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) bp2 = Blueprint( "schema", @@ -152,7 +154,7 @@ def test_table_on_engine(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 0), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) bp3 = Blueprint( "schema", @@ -168,7 +170,7 @@ def test_table_on_engine(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 0), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) assert bp_filter.is_valid(bp1) @@ -192,7 +194,7 @@ def test_no_data_loss(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) bp2 = Blueprint( "schema", @@ -208,7 +210,7 @@ def test_no_data_loss(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) assert not ndl_filter.is_valid(bp1)