From a6c6f0b0cc0606f195b480411a7aa0763c60550d Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Sat, 4 Nov 2023 14:01:06 -0400 Subject: [PATCH 01/11] Implement unified routing policy abstraction --- src/brad/routing/abstract_policy.py | 74 +++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 src/brad/routing/abstract_policy.py diff --git a/src/brad/routing/abstract_policy.py b/src/brad/routing/abstract_policy.py new file mode 100644 index 00000000..66979abe --- /dev/null +++ b/src/brad/routing/abstract_policy.py @@ -0,0 +1,74 @@ +from typing import List, Optional + +from brad.config.engine import Engine +from brad.planner.estimator import Estimator +from brad.query_rep import QueryRep + + +class AbstractRoutingPolicy: + """ + Note that implementers must be serializable. + """ + + def name(self) -> str: + raise NotImplementedError + + async def run_setup(self, estimator: Optional[Estimator] = None) -> None: + """ + Should be called before using this policy. This is used to set up any + dynamic state. + + If this routing policy needs an estimator, one should be provided here. + """ + + async def engine_for(self, query_rep: QueryRep) -> List[Engine]: + """ + Produces a preference order for query routing (the first element in the + list is the most preferred engine, and so on). + + NOTE: Implementers currently do not need to consider DML queries. BRAD + routes all DML queries to Aurora before consulting the router. Thus the + query passed to this method will always be a read-only query. + + You should override this method if the routing policy needs to depend on + any asynchronous methods. + """ + return self.engine_for_sync(query_rep) + + def engine_for_sync(self, query_rep: QueryRep) -> List[Engine]: + """ + Produces a preference order for query routing (the first element in the + list is the most preferred engine, and so on). + + NOTE: Implementers currently do not need to consider DML queries. BRAD + routes all DML queries to Aurora before consulting the router. Thus the + query passed to this method will always be a read-only query. + """ + raise NotImplementedError + + +class FullRoutingPolicy: + """ + Captures a full routing policy for serialization purposes. Indefinite + policies are allowed to return empty preference lists (indicating no routing + decision). + """ + + def __init__( + self, + indefinite_policies: List[AbstractRoutingPolicy], + definite_policy: AbstractRoutingPolicy, + ) -> None: + self.indefinite_policies = indefinite_policies + self.definite_policy = definite_policy + + async def run_setup(self, estimator: Optional[Estimator] = None) -> None: + """ + Should be called before using the policy. This is used to set up any + dynamic state. + + If this routing policy needs an estimator, one should be provided here. + """ + for policy in self.indefinite_policies: + await policy.run_setup(estimator) + await self.definite_policy.run_setup(estimator) From abe5f56e668b7dae440048a1331f3a7b0098f52e Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Sat, 4 Nov 2023 15:10:49 -0400 Subject: [PATCH 02/11] WIP: Update the routing pipeline to allow modifications during planning --- src/brad/blueprint/blueprint.py | 17 ++--- src/brad/blueprint/serde.py | 19 ++++- .../planner/beam/query_based_candidate.py | 2 +- .../planner/beam/table_based_candidate.py | 2 +- src/brad/planner/data.py | 5 +- src/brad/planner/enumeration/blueprint.py | 4 +- .../routing/location_aware_round_robin.py | 30 -------- src/brad/routing/policy.py | 7 +- src/brad/routing/round_robin.py | 21 ++++++ src/brad/routing/router.py | 72 ++++++++++++++++--- 10 files changed, 121 insertions(+), 58 deletions(-) delete mode 100644 src/brad/routing/location_aware_round_robin.py create mode 100644 src/brad/routing/round_robin.py diff --git a/src/brad/blueprint/blueprint.py b/src/brad/blueprint/blueprint.py index acefc006..48b2dfe5 100644 --- a/src/brad/blueprint/blueprint.py +++ b/src/brad/blueprint/blueprint.py @@ -1,11 +1,9 @@ -from typing import Callable, Dict, List, Set, Optional, Tuple, Any +from typing import Dict, List, Set, Optional, Tuple, Any from brad.blueprint.provisioning import Provisioning from brad.blueprint.table import Table from brad.config.engine import Engine -from brad.routing.router import Router - -RouterProvider = Callable[[], Router] +from brad.routing.abstract_policy import FullRoutingPolicy class Blueprint: @@ -16,14 +14,14 @@ def __init__( table_locations: Dict[str, List[Engine]], aurora_provisioning: Provisioning, redshift_provisioning: Provisioning, - router_provider: Optional[RouterProvider], + full_routing_policy: FullRoutingPolicy, ): self._schema_name = schema_name self._table_schemas = table_schemas self._table_locations = table_locations self._aurora_provisioning = aurora_provisioning self._redshift_provisioning = redshift_provisioning - self._router_provider = router_provider + self._full_routing_policy = full_routing_policy # Derived properties used for the class' methods. self._tables_by_name = {tbl.name: tbl for tbl in self._table_schemas} @@ -57,11 +55,8 @@ def aurora_provisioning(self) -> Provisioning: def redshift_provisioning(self) -> Provisioning: return self._redshift_provisioning - def get_router(self) -> Optional[Router]: - return self._router_provider() if self._router_provider is not None else None - - def router_provider(self) -> Optional[RouterProvider]: - return self._router_provider + def get_routing_policy(self) -> FullRoutingPolicy: + return self._full_routing_policy def base_table_names(self) -> Set[str]: return self._base_table_names diff --git a/src/brad/blueprint/serde.py b/src/brad/blueprint/serde.py index f821624a..60ebf002 100644 --- a/src/brad/blueprint/serde.py +++ b/src/brad/blueprint/serde.py @@ -1,9 +1,11 @@ +import pickle from typing import Tuple, List, Dict from brad.blueprint import Blueprint from brad.blueprint.provisioning import Provisioning from brad.blueprint.table import Column, Table from brad.config.engine import Engine +from brad.routing.abstract_policy import FullRoutingPolicy import brad.proto_gen.blueprint_pb2 as b @@ -24,7 +26,7 @@ def deserialize_blueprint(raw_data: bytes) -> Blueprint: table_locations=dict(map(_table_locations_from_proto, proto.tables)), aurora_provisioning=_provisioning_from_proto(proto.aurora), redshift_provisioning=_provisioning_from_proto(proto.redshift), - router_provider=None, + full_routing_policy=_policy_from_proto(proto.policy), ) @@ -34,7 +36,7 @@ def serialize_blueprint(blueprint: Blueprint) -> bytes: tables=map(_tables_with_locations_to_proto, blueprint.tables_with_locations()), aurora=_provisioning_to_proto(blueprint.aurora_provisioning()), redshift=_provisioning_to_proto(blueprint.redshift_provisioning()), - policy=None, + policy=_policy_to_proto(blueprint.get_routing_policy()), ) return proto.SerializeToString() @@ -88,6 +90,13 @@ def _indexed_columns_to_proto(indexed_columns: Tuple[Column, ...]) -> b.Index: return b.Index(column_name=map(lambda col: col.name, indexed_columns)) +def _policy_to_proto(full_routing_policy: FullRoutingPolicy) -> b.RoutingPolicy: + # We just use Python pickle serialization. In the future, this should be + # something more robust. + bytes_str = pickle.dumps(full_routing_policy) + return b.RoutingPolicy(policy=bytes_str) + + # Deserialization @@ -139,3 +148,9 @@ def _indexed_columns_from_proto( for col_name in indexed_columns.column_name: col_list.append(col_map[col_name]) return tuple(col_list) + + +def _policy_from_proto(policy: b.RoutingPolicy) -> FullRoutingPolicy: + # We just use Python pickle serialization. In the future, this should be + # something more robust. + return pickle.loads(policy.policy) diff --git a/src/brad/planner/beam/query_based_candidate.py b/src/brad/planner/beam/query_based_candidate.py index 008cbd37..18a5e813 100644 --- a/src/brad/planner/beam/query_based_candidate.py +++ b/src/brad/planner/beam/query_based_candidate.py @@ -117,7 +117,7 @@ def to_blueprint(self) -> Blueprint: self.get_table_placement(), self.aurora_provisioning.clone(), self.redshift_provisioning.clone(), - self._source_blueprint.router_provider(), + self._source_blueprint.get_routing_policy(), # TODO: Use chosen policy. ) def to_score(self) -> Score: diff --git a/src/brad/planner/beam/table_based_candidate.py b/src/brad/planner/beam/table_based_candidate.py index 9ee32fbf..cdb2abac 100644 --- a/src/brad/planner/beam/table_based_candidate.py +++ b/src/brad/planner/beam/table_based_candidate.py @@ -111,7 +111,7 @@ def to_blueprint(self) -> Blueprint: self.get_table_placement(), self.aurora_provisioning.clone(), self.redshift_provisioning.clone(), - self._source_blueprint.router_provider(), + self._source_blueprint.get_routing_policy(), # TODO: Use chosen policy. ) def to_score(self) -> Score: diff --git a/src/brad/planner/data.py b/src/brad/planner/data.py index 7c6ff612..966e4c9f 100644 --- a/src/brad/planner/data.py +++ b/src/brad/planner/data.py @@ -4,6 +4,8 @@ from brad.blueprint.user import UserProvidedBlueprint from brad.blueprint.table import Table from brad.config.engine import Engine +from brad.routing.abstract_policy import FullRoutingPolicy +from brad.routing.round_robin import RoundRobin def bootstrap_blueprint(user: UserProvidedBlueprint) -> Blueprint: @@ -90,5 +92,6 @@ def process_table(table: Table, expect_standalone_base_table: bool): table_locations, user.aurora_provisioning(), user.redshift_provisioning(), - None, + # TODO: Replace the default definite policy. + FullRoutingPolicy(indefinite_policies=[], definite_policy=RoundRobin()), ) diff --git a/src/brad/planner/enumeration/blueprint.py b/src/brad/planner/enumeration/blueprint.py index 416ccc52..29b10e9a 100644 --- a/src/brad/planner/enumeration/blueprint.py +++ b/src/brad/planner/enumeration/blueprint.py @@ -21,7 +21,7 @@ def __init__(self, base_blueprint: Blueprint) -> None: base_blueprint.table_locations(), base_blueprint.aurora_provisioning(), base_blueprint.redshift_provisioning(), - base_blueprint.router_provider(), + base_blueprint.get_routing_policy(), ) self._current_locations = base_blueprint.table_locations() self._current_aurora_provisioning = base_blueprint.aurora_provisioning() @@ -57,7 +57,7 @@ def to_blueprint(self) -> Blueprint: }, aurora_provisioning=self._current_aurora_provisioning.clone(), redshift_provisioning=self._current_redshift_provisioning.clone(), - router_provider=self.router_provider(), + full_routing_policy=self.get_routing_policy(), ) # Overridden getters. diff --git a/src/brad/routing/location_aware_round_robin.py b/src/brad/routing/location_aware_round_robin.py deleted file mode 100644 index 5562f871..00000000 --- a/src/brad/routing/location_aware_round_robin.py +++ /dev/null @@ -1,30 +0,0 @@ -from brad.config.engine import Engine -from brad.blueprint.manager import BlueprintManager -from brad.routing.router import Router -from brad.query_rep import QueryRep - - -class LocationAwareRoundRobin(Router): - """ - Routes queries in a "roughly" round-robin fashion, taking into account the - locations of the tables referenced. - """ - - def __init__(self, blueprint_mgr: BlueprintManager): - self._blueprint_mgr = blueprint_mgr - self._curr_idx = 0 - - def engine_for_sync(self, query: QueryRep) -> Engine: - blueprint = self._blueprint_mgr.get_blueprint() - valid_locations, only_location = self._run_location_routing( - query, blueprint.table_locations_bitmap() - ) - if only_location is not None: - return only_location - - locations = Engine.from_bitmap(valid_locations) - self._curr_idx %= len(locations) - selected_location = locations[self._curr_idx] - self._curr_idx += 1 - - return selected_location diff --git a/src/brad/routing/policy.py b/src/brad/routing/policy.py index b2ba54ae..b621a089 100644 --- a/src/brad/routing/policy.py +++ b/src/brad/routing/policy.py @@ -2,6 +2,11 @@ class RoutingPolicy(str, enum.Enum): + """ + This is used to override the policy specified by the blueprint (usually for + testing purposes). + """ + Default = "default" AlwaysAthena = "always_athena" AlwaysAurora = "always_aurora" @@ -27,4 +32,4 @@ def from_str(candidate: str) -> "RoutingPolicy": elif candidate == RoutingPolicy.ForestTableSelectivity.value: return RoutingPolicy.ForestTableSelectivity else: - raise ValueError("Unrecognized DB type {}".format(candidate)) + raise ValueError("Unrecognized policy {}".format(candidate)) diff --git a/src/brad/routing/round_robin.py b/src/brad/routing/round_robin.py new file mode 100644 index 00000000..b780eed9 --- /dev/null +++ b/src/brad/routing/round_robin.py @@ -0,0 +1,21 @@ +from typing import List + +from brad.config.engine import Engine +from brad.query_rep import QueryRep +from brad.routing.abstract_policy import AbstractRoutingPolicy + + +class RoundRobin(AbstractRoutingPolicy): + """ + Routes queries in a round-robin fashion. + """ + + def __init__(self): + self._ordering = [Engine.Athena, Engine.Aurora, Engine.Redshift] + + def engine_for_sync(self, _query: QueryRep) -> List[Engine]: + tmp = self._ordering[0] + self._ordering[0] = self._ordering[1] + self._ordering[1] = self._ordering[2] + self._ordering[2] = tmp + return self._ordering diff --git a/src/brad/routing/router.py b/src/brad/routing/router.py index 652b3d14..f2f6af23 100644 --- a/src/brad/routing/router.py +++ b/src/brad/routing/router.py @@ -1,14 +1,44 @@ +import asyncio from typing import Dict, Tuple, Optional, TYPE_CHECKING from brad.data_stats.estimator import Estimator from brad.config.engine import Engine, EngineBitmapValues from brad.query_rep import QueryRep +from brad.routing.abstract_policy import AbstractRoutingPolicy, FullRoutingPolicy if TYPE_CHECKING: from brad.blueprint import Blueprint class Router: + @classmethod + def create_from_blueprint(cls, blueprint: "Blueprint") -> "Router": + return cls( + blueprint.get_routing_policy(), + blueprint.table_locations_bitmap(), + use_future_blueprint_policies=True, + ) + + @classmethod + def create_from_definite_policy( + cls, policy: AbstractRoutingPolicy, table_placement_bitmap: Dict[str, int] + ) -> "Router": + return cls( + FullRoutingPolicy(indefinite_policies=[], definite_policy=policy), + table_placement_bitmap, + use_future_blueprint_policies=False, + ) + + def __init__( + self, + full_policy: FullRoutingPolicy, + table_placement_bitmap: Dict[str, int], + use_future_blueprint_policies: bool, + ) -> None: + self._full_policy = full_policy + self._table_placement_bitmap = table_placement_bitmap + self._use_future_blueprint_policies = use_future_blueprint_policies + async def run_setup(self, estimator: Optional[Estimator] = None) -> None: """ Should be called before using the router. This is used to set up any @@ -16,25 +46,48 @@ async def run_setup(self, estimator: Optional[Estimator] = None) -> None: If the routing policy needs an estimator, one should be provided here. """ + await self._full_policy.run_setup(estimator) def update_blueprint(self, blueprint: "Blueprint") -> None: """ Used to update any cached state that depends on the blueprint (e.g., location bitmaps). """ + self._table_placement_bitmap = blueprint.table_locations_bitmap() + if self._use_future_blueprint_policies: + # TODO: Deserialize from the blueprint. + pass async def engine_for(self, query: QueryRep) -> Engine: """ Selects an engine for the provided SQL query. - - NOTE: Implementers currently do not need to consider DML queries. BRAD - routes all DML queries to Aurora before consulting the router. Thus the - query passed to this method will always be a read-only query. - - You should override this method if the routing policy needs to depend on - any asynchronous methods. """ - return self.engine_for_sync(query) + + # Table placement constraints. + assert self._table_placement_bitmap is not None + valid_locations, only_location = self._run_location_routing( + query, self._table_placement_bitmap + ) + if only_location is not None: + return only_location + + # Go through the indefinite routing policies. These may not return a + # routing location. + for policy in self._full_policy.indefinite_policies: + locations = await policy.engine_for(query) + for loc in locations: + if (EngineBitmapValues[loc] & valid_locations) != 0: + return loc + + # Rely on the definite routing policy. + locations = await self._full_policy.definite_policy.engine_for(query) + for loc in locations: + if (EngineBitmapValues[loc] & valid_locations) != 0: + return loc + + # This should be unreachable. The definite policy must rank all engines, + # and we know >= 2 engines can support this query. + raise AssertionError def engine_for_sync(self, query: QueryRep) -> Engine: """ @@ -44,7 +97,8 @@ def engine_for_sync(self, query: QueryRep) -> Engine: routes all DML queries to Aurora before consulting the router. Thus the query passed to this method will always be a read-only query. """ - raise NotImplementedError + # Ideally we re-implement a sync version. + return asyncio.run(self.engine_for(query)) def _run_location_routing( self, query: QueryRep, location_bitmap: Dict[str, int] From 1cb885e7b4b6c05129e90c8a70f7df44e5e6a7b7 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Sat, 4 Nov 2023 17:31:01 -0400 Subject: [PATCH 03/11] WIP: Overhaul router usage --- src/brad/blueprint/blueprint.py | 2 +- src/brad/front_end/front_end.py | 101 +++++++++++-------- src/brad/planner/router_provider.py | 18 ++-- src/brad/routing/always_one.py | 14 ++- src/brad/routing/round_robin.py | 3 + src/brad/routing/router.py | 10 ++ src/brad/routing/rule_based.py | 54 +++------- src/brad/routing/tree_based/forest_policy.py | 87 ++++++++++++++++ 8 files changed, 195 insertions(+), 94 deletions(-) create mode 100644 src/brad/routing/tree_based/forest_policy.py diff --git a/src/brad/blueprint/blueprint.py b/src/brad/blueprint/blueprint.py index 48b2dfe5..b5a57902 100644 --- a/src/brad/blueprint/blueprint.py +++ b/src/brad/blueprint/blueprint.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Set, Optional, Tuple, Any +from typing import Dict, List, Set, Tuple, Any from brad.blueprint.provisioning import Provisioning from brad.blueprint.table import Table diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index 5c39aa87..f0dd199c 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -35,12 +35,12 @@ from brad.front_end.grpc import BradGrpc from brad.front_end.session import SessionManager, SessionId, Session from brad.query_rep import QueryRep +from brad.routing.abstract_policy import AbstractRoutingPolicy from brad.routing.always_one import AlwaysOneRouter from brad.routing.rule_based import RuleBased -from brad.routing.location_aware_round_robin import LocationAwareRoundRobin from brad.routing.policy import RoutingPolicy from brad.routing.router import Router -from brad.routing.tree_based.forest_router import ForestRouter +from brad.routing.tree_based.forest_policy import ForestPolicy from brad.row_list import RowList from brad.utils import log_verbose from brad.utils.counter import Counter @@ -110,36 +110,9 @@ def __init__( self._config.front_end_query_latency_buffer_size ) - # We have different routing policies for performance evaluation and - # testing purposes. - routing_policy = self._config.routing_policy - self._routing_policy = routing_policy - if routing_policy == RoutingPolicy.Default: - self._router: Router = LocationAwareRoundRobin(self._blueprint_mgr) - elif routing_policy == RoutingPolicy.AlwaysAthena: - self._router = AlwaysOneRouter(Engine.Athena) - elif routing_policy == RoutingPolicy.AlwaysAurora: - self._router = AlwaysOneRouter(Engine.Aurora) - elif routing_policy == RoutingPolicy.AlwaysRedshift: - self._router = AlwaysOneRouter(Engine.Redshift) - elif routing_policy == RoutingPolicy.RuleBased: - self._monitor = Monitor(config, self._blueprint_mgr) - self._router = RuleBased( - blueprint_mgr=self._blueprint_mgr, monitor=self._monitor - ) - elif ( - routing_policy == RoutingPolicy.ForestTablePresence - or routing_policy == RoutingPolicy.ForestTableSelectivity - ): - self._router = ForestRouter.for_server( - routing_policy, self._schema_name, self._assets, self._blueprint_mgr - ) - else: - raise RuntimeError( - "Unsupported routing policy: {}".format(str(routing_policy)) - ) - logger.info("Using routing policy: %s", routing_policy) - + self._routing_policy_override = self._config.routing_policy + # This is set up as the front end starts up. + self._router: Optional[Router] = None self._sessions = SessionManager( self._config, self._blueprint_mgr, self._schema_name ) @@ -181,15 +154,8 @@ async def serve_forever(self): logger.info("The BRAD front end has successfully started.") logger.info("Listening on port %d.", port_to_use) - if self._routing_policy == RoutingPolicy.RuleBased: - assert ( - self._monitor is not None - ), "require monitor running for rule-based router" - await asyncio.gather( - self._monitor.run_forever(), grpc_server.wait_for_termination() - ) - else: - await grpc_server.wait_for_termination() + # N.B. If we need the Monitor, we should call `run_forever()` here. + await grpc_server.wait_for_termination() finally: # Not ideal, but we need to manually call this method to ensure # gRPC's internal shutdown process completes before we return from @@ -206,14 +172,17 @@ async def _run_setup(self) -> None: self._monitor.set_up_metrics_sources() await self._monitor.fetch_latest() - if self._routing_policy == RoutingPolicy.ForestTableSelectivity: + if ( + self._routing_policy_override == RoutingPolicy.ForestTableSelectivity + or self._routing_policy_override == RoutingPolicy.Default + ): self._estimator = await PostgresEstimator.connect( self._schema_name, self._config ) await self._estimator.analyze(self._blueprint_mgr.get_blueprint()) else: self._estimator = None - await self._router.run_setup(self._estimator) + await self._set_up_router(self._estimator) # Start the metrics reporting task. self._brad_metrics_reporting_task = asyncio.create_task( @@ -225,6 +194,50 @@ async def _run_setup(self) -> None: self._qlogger_refresh_task = asyncio.create_task(self._refresh_qlogger()) + async def _set_up_router(self, estimator: Optional[Estimator]) -> None: + # We have different routing policies for performance evaluation and + # testing purposes. + blueprint = self._blueprint_mgr.get_blueprint() + + if self._routing_policy_override == RoutingPolicy.Default: + # No override - use the blueprint's policy. + self._router = Router.create_from_blueprint(blueprint) + logger.info("Using blueprint-provided routing policy.") + + else: + if self._routing_policy_override == RoutingPolicy.AlwaysAthena: + definite_policy: AbstractRoutingPolicy = AlwaysOneRouter(Engine.Athena) + elif self._routing_policy_override == RoutingPolicy.AlwaysAurora: + definite_policy = AlwaysOneRouter(Engine.Aurora) + elif self._routing_policy_override == RoutingPolicy.AlwaysRedshift: + definite_policy = AlwaysOneRouter(Engine.Redshift) + elif self._routing_policy_override == RoutingPolicy.RuleBased: + # TODO: If we need metrics, re-create the monitor here. It's + # easier to not have it created. + definite_policy = RuleBased(blueprint=blueprint, monitor=None) + elif ( + self._routing_policy_override == RoutingPolicy.ForestTablePresence + or self._routing_policy_override == RoutingPolicy.ForestTableSelectivity + ): + definite_policy = await ForestPolicy.from_assets( + self._schema_name, + self._routing_policy_override, + self._assets, + ) + else: + raise RuntimeError( + f"Unsupported routing policy override: {self._routing_policy_override}" + ) + logger.info( + "Using routing policy override: %s", self._routing_policy_override.name + ) + self._router = Router.create_from_definite_policy( + definite_policy, blueprint.table_locations_bitmap() + ) + + await self._router.run_setup(estimator) + self._router.log_policy() + async def _run_teardown(self): logger.debug("Starting BRAD front end _run_teardown()") await self._sessions.end_all_sessions() @@ -299,6 +312,7 @@ async def _run_query_impl( elif self._route_redshift_only: engine_to_use = Engine.Redshift else: + assert self._router is not None engine_to_use = await self._router.engine_for(query_rep) log_verbose( @@ -601,6 +615,7 @@ async def _run_blueprint_update(self, version: int) -> None: if self._monitor is not None: self._monitor.update_metrics_sources() await self._sessions.add_connections() + assert self._router is not None self._router.update_blueprint(blueprint) # NOTE: This will cause any pending queries on the to-be-removed # connections to be cancelled. We consider this behavior to be diff --git a/src/brad/planner/router_provider.py b/src/brad/planner/router_provider.py index 6836e4ec..87dc19bd 100644 --- a/src/brad/planner/router_provider.py +++ b/src/brad/planner/router_provider.py @@ -3,10 +3,11 @@ from brad.asset_manager import AssetManager from brad.config.file import ConfigFile from brad.planner.estimator import EstimatorProvider +from brad.routing.abstract_policy import AbstractRoutingPolicy from brad.routing.policy import RoutingPolicy from brad.routing.router import Router from brad.routing.rule_based import RuleBased -from brad.routing.tree_based.forest_router import ForestRouter +from brad.routing.tree_based.forest_policy import ForestPolicy from brad.routing.tree_based.model_wrap import ModelWrap @@ -38,19 +39,17 @@ async def get_router(self, table_bitmap: Dict[str, int]) -> Router: or self._routing_policy == RoutingPolicy.ForestTableSelectivity ): if self._model is None: - self._model = ForestRouter.static_load_model_sync( + self._model = ForestPolicy.static_load_model_sync( self._schema_name, self._routing_policy, self._assets, ) - router = ForestRouter.for_planner( - self._routing_policy, self._schema_name, self._model, table_bitmap + definite_policy: AbstractRoutingPolicy = ForestPolicy.from_loaded_model( + self._routing_policy, self._model ) - await router.run_setup(self._estimator_provider.get_estimator()) - return router elif self._routing_policy == RoutingPolicy.RuleBased: - return RuleBased(table_placement_bitmap=table_bitmap) + definite_policy = RuleBased(table_placement_bitmap=table_bitmap) else: raise RuntimeError( @@ -59,5 +58,10 @@ async def get_router(self, table_bitmap: Dict[str, int]) -> Router: ) ) + # This is temporary and will be removed. + router = Router.create_from_definite_policy(definite_policy, table_bitmap) + await router.run_setup(self._estimator_provider.get_estimator()) + return router + def clear_cached(self) -> None: self._model = None diff --git a/src/brad/routing/always_one.py b/src/brad/routing/always_one.py index b52a79a4..8fba5653 100644 --- a/src/brad/routing/always_one.py +++ b/src/brad/routing/always_one.py @@ -1,9 +1,11 @@ +from typing import List + from brad.config.engine import Engine from brad.query_rep import QueryRep -from brad.routing.router import Router +from brad.routing.abstract_policy import AbstractRoutingPolicy -class AlwaysOneRouter(Router): +class AlwaysOneRouter(AbstractRoutingPolicy): """ This router always selects the same database engine for all queries. This router is useful for testing and benchmarking purposes. @@ -11,7 +13,11 @@ class AlwaysOneRouter(Router): def __init__(self, db_type: Engine): super().__init__() - self._always_route_to = db_type + self._engine = db_type + self._always_route_to = [db_type] + + def name(self) -> str: + return f"AlwaysRouteTo({self._engine.name})" - def engine_for_sync(self, _query: QueryRep) -> Engine: + def engine_for_sync(self, _query: QueryRep) -> List[Engine]: return self._always_route_to diff --git a/src/brad/routing/round_robin.py b/src/brad/routing/round_robin.py index b780eed9..11ae7e13 100644 --- a/src/brad/routing/round_robin.py +++ b/src/brad/routing/round_robin.py @@ -13,6 +13,9 @@ class RoundRobin(AbstractRoutingPolicy): def __init__(self): self._ordering = [Engine.Athena, Engine.Aurora, Engine.Redshift] + def name(self) -> str: + return "RoundRobin" + def engine_for_sync(self, _query: QueryRep) -> List[Engine]: tmp = self._ordering[0] self._ordering[0] = self._ordering[1] diff --git a/src/brad/routing/router.py b/src/brad/routing/router.py index f2f6af23..8b6d10d2 100644 --- a/src/brad/routing/router.py +++ b/src/brad/routing/router.py @@ -1,4 +1,5 @@ import asyncio +import logging from typing import Dict, Tuple, Optional, TYPE_CHECKING from brad.data_stats.estimator import Estimator @@ -9,6 +10,8 @@ if TYPE_CHECKING: from brad.blueprint import Blueprint +logger = logging.getLogger(__name__) + class Router: @classmethod @@ -39,6 +42,13 @@ def __init__( self._table_placement_bitmap = table_placement_bitmap self._use_future_blueprint_policies = use_future_blueprint_policies + def log_policy(self) -> None: + logger.info("Routing policy:") + logger.info(" Indefinite policies:") + for p in self._full_policy.indefinite_policies: + logger.info(" - %s", p.name()) + logger.info(" Definite policy: %s") + async def run_setup(self, estimator: Optional[Estimator] = None) -> None: """ Should be called before using the router. This is used to set up any diff --git a/src/brad/routing/rule_based.py b/src/brad/routing/rule_based.py index 25169c88..0c401901 100644 --- a/src/brad/routing/rule_based.py +++ b/src/brad/routing/rule_based.py @@ -9,7 +9,7 @@ from brad.config.engine import Engine from brad.blueprint.manager import BlueprintManager from brad.daemon.monitor import Monitor -from brad.routing.router import Router +from brad.routing.abstract_policy import AbstractRoutingPolicy from brad.query_rep import QueryRep from brad.front_end.session import SessionManager @@ -54,7 +54,7 @@ def __init__(self) -> None: self.redshift_parameters_lower_limit = redshift_parameters_lower_limit -class RuleBased(Router): +class RuleBased(AbstractRoutingPolicy): def __init__( self, # One of `blueprint_mgr` and `blueprint` must not be `None`. @@ -85,6 +85,9 @@ def __init__( self._deterministic = deterministic self._params = RuleBasedParams() + def name(self) -> str: + return "RuleBased" + def update_blueprint(self, blueprint: Blueprint) -> None: self._blueprint = blueprint self._table_placement_bitmap = blueprint.table_locations_bitmap() @@ -207,25 +210,8 @@ def check_engine_state( return True return not_overloaded - def engine_for_sync(self, query: QueryRep) -> Engine: - if self._table_placement_bitmap is None: - if self._blueprint is not None: - blueprint = self._blueprint - else: - assert self._blueprint_mgr is not None - blueprint = self._blueprint_mgr.get_blueprint() - self._table_placement_bitmap = blueprint.table_locations_bitmap() - - valid_locations, only_location = self._run_location_routing( - query, self._table_placement_bitmap - ) - if only_location is not None: - return only_location - - locations = Engine.from_bitmap(valid_locations) - assert len(locations) > 1 - ideal_location_rank: List[Engine] = [] - touched_tables = query.tables() + def engine_for_sync(self, query_rep: QueryRep) -> List[Engine]: + touched_tables = query_rep.tables() if ( len(touched_tables) < self._params.ideal_location_lower_limit["redshift_num_table"] @@ -239,7 +225,7 @@ def engine_for_sync(self, query: QueryRep) -> Engine: if self._catalog: n_rows = [] n_cols = [] - for table_name in query.tables(): + for table_name in query_rep.tables(): if table_name in self._catalog: n_rows.append(self._catalog[table_name]["nrow"]) n_cols.append(self._catalog[table_name]["ncol"]) @@ -269,7 +255,7 @@ def engine_for_sync(self, query: QueryRep) -> Engine: ideal_location_rank = [Engine.Redshift, Engine.Athena, Engine.Aurora] if self._catalog: n_rows = [] - for table_name in query.tables(): + for table_name in query_rep.tables(): if table_name in self._catalog: n_rows.append(self._catalog[table_name]["nrow"]) if ( @@ -287,11 +273,8 @@ def engine_for_sync(self, query: QueryRep) -> Engine: ideal_location_rank = [Engine.Athena, Engine.Redshift, Engine.Aurora] if self._monitor is None: - for loc in ideal_location_rank: - if loc in locations: - return loc - # This should be unreachable since len(locations) > 0. - assert False + return ideal_location_rank + else: # Todo(Ziniu): this can be stored in this class to reduce latency raw_aurora_metrics = ( @@ -304,21 +287,14 @@ def engine_for_sync(self, query: QueryRep) -> Engine: logger.warning( "Routing without system metrics when we expect to have metrics." ) - return ideal_location_rank[0] + return ideal_location_rank aurora_metrics = raw_aurora_metrics.iloc[0].to_dict() redshift_metrics = raw_redshift_metrics.iloc[0].to_dict() for loc in ideal_location_rank: - if loc in locations and self.check_engine_state( - loc, aurora_metrics, redshift_metrics - ): - return loc + if self.check_engine_state(loc, aurora_metrics, redshift_metrics): + return [loc, *[il for il in ideal_location_rank if il != loc]] # In the case of all system are overloaded (time to trigger replan), # we assign it to the optimal one. But Athena should not be overloaded at any time - for loc in ideal_location_rank: - if loc in locations: - return loc - - # Should be unreachable since len(locations) > 0. - assert False + return ideal_location_rank diff --git a/src/brad/routing/tree_based/forest_policy.py b/src/brad/routing/tree_based/forest_policy.py new file mode 100644 index 00000000..5a64de30 --- /dev/null +++ b/src/brad/routing/tree_based/forest_policy.py @@ -0,0 +1,87 @@ +import asyncio +from typing import Optional, List, Dict, Any + +from brad.asset_manager import AssetManager +from brad.config.engine import Engine +from brad.data_stats.estimator import Estimator +from brad.query_rep import QueryRep +from brad.routing.abstract_policy import AbstractRoutingPolicy +from brad.routing.policy import RoutingPolicy +from brad.routing.tree_based.model_wrap import ModelWrap + + +class ForestPolicy(AbstractRoutingPolicy): + @classmethod + async def from_assets( + cls, schema_name: str, policy: RoutingPolicy, assets: AssetManager + ) -> "ForestPolicy": + model = cls.static_load_model_sync(schema_name, policy, assets) + return cls(policy, model) + + @classmethod + def from_loaded_model( + cls, policy: RoutingPolicy, model: ModelWrap + ) -> "ForestPolicy": + return cls(policy, model) + + def __init__(self, policy: RoutingPolicy, model: ModelWrap) -> None: + self._policy = policy + self._model = model + self._estimator: Optional[Estimator] = None + + def __getstate__(self) -> Dict[Any, Any]: + return { + "policy": self._policy, + "model": self._model, + } + + def __setstate__(self, d: Dict[Any, Any]) -> None: + self._policy = d["policy"] + self._model = d["model"] + self._estimator = None + + def name(self) -> str: + return f"ForestPolicy({self._policy.name})" + + async def run_setup(self, estimator: Optional[Estimator] = None) -> None: + self._estimator = estimator + + async def engine_for(self, query: QueryRep) -> List[Engine]: + return await self._model.engine_for(query, self._estimator) + + def engine_for_sync(self, query: QueryRep) -> List[Engine]: + return asyncio.run(self.engine_for(query)) + + # The methods below are used to save/load `ModelWrap` from S3. We + # historically separated out the model's implementation details because the + # router contained state that was not serializable. This separation is kept + # around for legacy reasons now (this policy class should be directly + # serialized). + + @staticmethod + def static_persist_sync( + model: ModelWrap, schema_name: str, assets: AssetManager + ) -> None: + key = _SERIALIZED_KEY.format( + schema_name=schema_name, policy=model.policy().value + ) + serialized = model.to_pickle() + assets.persist_sync(key, serialized) + + @staticmethod + def static_load_model_sync( + schema_name: str, policy: RoutingPolicy, assets: AssetManager + ) -> ModelWrap: + key = _SERIALIZED_KEY.format(schema_name=schema_name, policy=policy.value) + serialized = assets.load_sync(key) + return ModelWrap.from_pickle_bytes(serialized) + + @staticmethod + def static_drop_model_sync( + schema_name: str, policy: RoutingPolicy, assets: AssetManager + ) -> None: + key = _SERIALIZED_KEY.format(schema_name=schema_name, policy=policy.value) + assets.delete_sync(key) + + +_SERIALIZED_KEY = "{schema_name}/{policy}-router.pickle" From 70f30d25d6f624af553a5efd174726af91007360 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Sat, 4 Nov 2023 17:45:46 -0400 Subject: [PATCH 04/11] Fix tests --- src/brad/planner/neighborhood/neighborhood.py | 4 +++- .../planner/neighborhood/scaling_scorer.py | 4 +++- src/brad/routing/tree_based/forest_policy.py | 8 ++++---- tests/test_blueprint_diff.py | 4 +++- tests/test_enumeration.py | 4 +++- tests/test_location_routing.py | 5 +++-- tests/test_planner_filters.py | 20 ++++++++++--------- 7 files changed, 30 insertions(+), 19 deletions(-) diff --git a/src/brad/planner/neighborhood/neighborhood.py b/src/brad/planner/neighborhood/neighborhood.py index 113eecc0..07ab41b3 100644 --- a/src/brad/planner/neighborhood/neighborhood.py +++ b/src/brad/planner/neighborhood/neighborhood.py @@ -204,7 +204,9 @@ def _estimate_current_data_accessed( # Compute the total amount of data accessed on each engine in the # current workload (used to weigh the workload assigned to each engine). for q in current_workload.analytical_queries(): - current_engine = current_router.engine_for_sync(q) + current_engines = current_router.engine_for_sync(q) + assert len(current_engines) > 0 + current_engine = current_engines[0] q.populate_data_accessed_mb( current_engine, engines, self._current_blueprint ) diff --git a/src/brad/planner/neighborhood/scaling_scorer.py b/src/brad/planner/neighborhood/scaling_scorer.py index 462dec66..eb4360da 100644 --- a/src/brad/planner/neighborhood/scaling_scorer.py +++ b/src/brad/planner/neighborhood/scaling_scorer.py @@ -61,7 +61,9 @@ def _simulate_next_workload(self, ctx: ScoringContext) -> None: # See where each analytical query gets routed. for q in ctx.next_workload.analytical_queries(): - next_engine = router.engine_for_sync(q) + next_engines = router.engine_for_sync(q) + assert len(next_engines) > 0 + next_engine = next_engines[0] ctx.next_dest[next_engine].append(q) q.populate_data_accessed_mb(next_engine, ctx.engines, ctx.current_blueprint) diff --git a/src/brad/routing/tree_based/forest_policy.py b/src/brad/routing/tree_based/forest_policy.py index 5a64de30..049f59b9 100644 --- a/src/brad/routing/tree_based/forest_policy.py +++ b/src/brad/routing/tree_based/forest_policy.py @@ -46,11 +46,11 @@ def name(self) -> str: async def run_setup(self, estimator: Optional[Estimator] = None) -> None: self._estimator = estimator - async def engine_for(self, query: QueryRep) -> List[Engine]: - return await self._model.engine_for(query, self._estimator) + async def engine_for(self, query_rep: QueryRep) -> List[Engine]: + return await self._model.engine_for(query_rep, self._estimator) - def engine_for_sync(self, query: QueryRep) -> List[Engine]: - return asyncio.run(self.engine_for(query)) + def engine_for_sync(self, query_rep: QueryRep) -> List[Engine]: + return asyncio.run(self.engine_for(query_rep)) # The methods below are used to save/load `ModelWrap` from S3. We # historically separated out the model's implementation details because the diff --git a/tests/test_blueprint_diff.py b/tests/test_blueprint_diff.py index a4215121..8ae9dc67 100644 --- a/tests/test_blueprint_diff.py +++ b/tests/test_blueprint_diff.py @@ -5,6 +5,8 @@ from brad.blueprint.user import UserProvidedBlueprint from brad.config.engine import Engine from brad.planner.data import bootstrap_blueprint +from brad.routing.abstract_policy import FullRoutingPolicy +from brad.routing.always_one import AlwaysOneRouter def test_no_diff(): @@ -72,7 +74,7 @@ def test_provisioning_change(): initial.table_locations(), initial.aurora_provisioning(), Provisioning(instance_type="dc2.large", num_nodes=4), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) diff = BlueprintDiff.of(initial, changed) assert diff is not None diff --git a/tests/test_enumeration.py b/tests/test_enumeration.py index 028310b5..fd65334c 100644 --- a/tests/test_enumeration.py +++ b/tests/test_enumeration.py @@ -4,6 +4,8 @@ from brad.planner.enumeration.provisioning import ProvisioningEnumerator from brad.planner.enumeration.table_locations import TableLocationEnumerator from brad.planner.enumeration.neighborhood import NeighborhoodBlueprintEnumerator +from brad.routing.abstract_policy import FullRoutingPolicy +from brad.routing.always_one import AlwaysOneRouter def test_provisioning_enumerate_aurora(): @@ -65,7 +67,7 @@ def test_blueprint_enumerate(): {"table1": [Engine.Aurora], "table2": [Engine.Redshift]}, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) # Simple sanity check only. diff --git a/tests/test_location_routing.py b/tests/test_location_routing.py index 4f9b9eaa..eacfa16c 100644 --- a/tests/test_location_routing.py +++ b/tests/test_location_routing.py @@ -1,12 +1,13 @@ from brad.config.engine import Engine, EngineBitmapValues from brad.routing.router import Router +from brad.routing.round_robin import RoundRobin from brad.query_rep import QueryRep def test_only_one_location(): query = QueryRep("SELECT * FROM test") bitmap = {"test": EngineBitmapValues[Engine.Aurora]} - r = Router() + r = Router.create_from_definite_policy(RoundRobin(), bitmap) # pylint: disable-next=protected-access valid_locations, only_location = r._run_location_routing(query, bitmap) assert only_location is not None @@ -22,7 +23,7 @@ def test_multiple_locations(): EngineBitmapValues[Engine.Redshift] | EngineBitmapValues[Engine.Athena] ), } - r = Router() + r = Router.create_from_definite_policy(RoundRobin(), bitmap) # pylint: disable-next=protected-access valid_locations, only_location = r._run_location_routing(query, bitmap) assert only_location is None diff --git a/tests/test_planner_filters.py b/tests/test_planner_filters.py index 1c3d712f..be001d2a 100644 --- a/tests/test_planner_filters.py +++ b/tests/test_planner_filters.py @@ -13,6 +13,8 @@ from brad.planner.neighborhood.filters.table_on_engine import TableOnEngine from brad.planner.workload import Workload from brad.planner.workload.query import Query +from brad.routing.abstract_policy import FullRoutingPolicy +from brad.routing.always_one import AlwaysOneRouter def workload_from_queries(query_list: List[str]) -> Workload: @@ -42,7 +44,7 @@ def test_aurora_transactions(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) bp2 = Blueprint( "schema", @@ -56,7 +58,7 @@ def test_aurora_transactions(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) bp_filter1 = AuroraTransactions(workload1) @@ -90,7 +92,7 @@ def test_single_engine_execution(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) bp2 = Blueprint( "schema", @@ -106,7 +108,7 @@ def test_single_engine_execution(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) bp_filter1 = SingleEngineExecution(workload1) @@ -136,7 +138,7 @@ def test_table_on_engine(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) bp2 = Blueprint( "schema", @@ -152,7 +154,7 @@ def test_table_on_engine(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 0), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) bp3 = Blueprint( "schema", @@ -168,7 +170,7 @@ def test_table_on_engine(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 0), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) assert bp_filter.is_valid(bp1) @@ -192,7 +194,7 @@ def test_no_data_loss(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) bp2 = Blueprint( "schema", @@ -208,7 +210,7 @@ def test_no_data_loss(): }, aurora_provisioning=Provisioning("db.r6g.large", 1), redshift_provisioning=Provisioning("dc2.large", 1), - router_provider=None, + full_routing_policy=FullRoutingPolicy([], AlwaysOneRouter(Engine.Aurora)), ) assert not ndl_filter.is_valid(bp1) From 73da9b5b5a773570da791ca1d1b2c8a060c133f8 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Sat, 4 Nov 2023 17:47:27 -0400 Subject: [PATCH 05/11] Fix other test --- tests/test_always_routing.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_always_routing.py b/tests/test_always_routing.py index 7b5a6288..469de3b0 100644 --- a/tests/test_always_routing.py +++ b/tests/test_always_routing.py @@ -7,10 +7,10 @@ def test_always_route_aurora(): router = AlwaysOneRouter(db) pred_db = router.engine_for_sync("SELECT 1") - assert pred_db == db + assert pred_db == [db] pred_db = router.engine_for_sync("SELECT * FROM my_table") - assert pred_db == db + assert pred_db == [db] def test_always_route_athena(): @@ -18,10 +18,10 @@ def test_always_route_athena(): router = AlwaysOneRouter(db) pred_db = router.engine_for_sync("SELECT 1") - assert pred_db == db + assert pred_db == [db] pred_db = router.engine_for_sync("SELECT * FROM my_table") - assert pred_db == db + assert pred_db == [db] def test_always_route_redshift(): @@ -29,7 +29,7 @@ def test_always_route_redshift(): router = AlwaysOneRouter(db) pred_db = router.engine_for_sync("SELECT 1") - assert pred_db == db + assert pred_db == [db] pred_db = router.engine_for_sync("SELECT * FROM my_table") - assert pred_db == db + assert pred_db == [db] From 754947081f040b5f6d1c946d599fd246ef2f53ce Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Sat, 4 Nov 2023 17:58:03 -0400 Subject: [PATCH 06/11] Remove dependence on legacy ForestRouter --- src/brad/admin/drop_schema.py | 6 +- src/brad/admin/train_router.py | 4 +- src/brad/routing/tree_based/forest_router.py | 140 ------------------- tests/test_forest_routing.py | 51 ++----- 4 files changed, 15 insertions(+), 186 deletions(-) delete mode 100644 src/brad/routing/tree_based/forest_router.py diff --git a/src/brad/admin/drop_schema.py b/src/brad/admin/drop_schema.py index 9f0d7b71..a5a306db 100644 --- a/src/brad/admin/drop_schema.py +++ b/src/brad/admin/drop_schema.py @@ -8,7 +8,7 @@ from brad.front_end.engine_connections import EngineConnections from brad.provisioning.directory import Directory from brad.routing.policy import RoutingPolicy -from brad.routing.tree_based.forest_router import ForestRouter +from brad.routing.tree_based.forest_policy import ForestPolicy logger = logging.getLogger(__name__) @@ -56,10 +56,10 @@ def drop_schema(args): # 4. Drop any serialized routers. assets = AssetManager(config) - ForestRouter.static_drop_model_sync( + ForestPolicy.static_drop_model_sync( args.schema_name, RoutingPolicy.ForestTableSelectivity, assets ) - ForestRouter.static_drop_model_sync( + ForestPolicy.static_drop_model_sync( args.schema_name, RoutingPolicy.ForestTablePresence, assets ) diff --git a/src/brad/admin/train_router.py b/src/brad/admin/train_router.py index b05c1eec..29053b0c 100644 --- a/src/brad/admin/train_router.py +++ b/src/brad/admin/train_router.py @@ -9,7 +9,7 @@ from brad.data_stats.estimator import Estimator from brad.data_stats.postgres_estimator import PostgresEstimator from brad.routing.policy import RoutingPolicy -from brad.routing.tree_based.forest_router import ForestRouter +from brad.routing.tree_based.forest_policy import ForestPolicy from brad.routing.tree_based.trainer import ForestTrainer from brad.blueprint.manager import BlueprintManager @@ -163,7 +163,7 @@ def train_router(args): response = input("Do you want to persist this model? (y/n): ").lower() if response == "y": assets = AssetManager(config) - ForestRouter.static_persist_sync(model, schema_name, assets) + ForestPolicy.static_persist_sync(model, schema_name, assets) logger.info("Model persisted successfully.") break elif response == "n": diff --git a/src/brad/routing/tree_based/forest_router.py b/src/brad/routing/tree_based/forest_router.py deleted file mode 100644 index aae48877..00000000 --- a/src/brad/routing/tree_based/forest_router.py +++ /dev/null @@ -1,140 +0,0 @@ -import asyncio -from typing import Optional, Dict - -from .model_wrap import ModelWrap -from brad.asset_manager import AssetManager -from brad.blueprint import Blueprint -from brad.config.engine import Engine, EngineBitmapValues -from brad.data_stats.estimator import Estimator -from brad.query_rep import QueryRep -from brad.routing.policy import RoutingPolicy -from brad.routing.router import Router -from brad.blueprint.manager import BlueprintManager - - -class ForestRouter(Router): - @classmethod - def for_server( - cls, - policy: RoutingPolicy, - schema_name: str, - assets: AssetManager, - blueprint_mgr: BlueprintManager, - ) -> "ForestRouter": - return cls(policy, schema_name, assets=assets, blueprint_mgr=blueprint_mgr) - - @classmethod - def for_planner( - cls, - policy: RoutingPolicy, - schema_name: str, - model: ModelWrap, - table_bitmap: Dict[str, int], - ) -> "ForestRouter": - return cls( - policy, schema_name, model=model, table_placement_bitmap=table_bitmap - ) - - def __init__( - self, - policy: RoutingPolicy, - schema_name: str, - # One of `assets` and `model` most not be `None`. - assets: Optional[AssetManager] = None, - model: Optional[ModelWrap] = None, - # One of `blueprint_mgr`, `blueprint`, and `table_placement_bitmap` must not be `None`. - blueprint_mgr: Optional[BlueprintManager] = None, - blueprint: Optional[Blueprint] = None, - table_placement_bitmap: Optional[Dict[str, int]] = None, - ) -> None: - self._policy = policy - self._schema_name = schema_name - self._model = model - self._assets = assets - - self._blueprint_mgr = blueprint_mgr - self._blueprint = blueprint - self._table_placement_bitmap = table_placement_bitmap - self._estimator: Optional[Estimator] = None - - async def run_setup(self, estimator: Optional[Estimator] = None) -> None: - self._estimator = estimator - - # Load the model. - if self._model is None: - assert self._assets is not None - serialized_model = await self._assets.load( - _SERIALIZED_KEY.format( - schema_name=self._schema_name, policy=self._policy.value - ) - ) - self._model = ModelWrap.from_pickle_bytes(serialized_model) - - # Load the table placement if it was not provided. - if self._table_placement_bitmap is None: - if self._blueprint is None: - assert self._blueprint_mgr is not None - self._blueprint = self._blueprint_mgr.get_blueprint() - - self._table_placement_bitmap = self._blueprint.table_locations_bitmap() - - def update_blueprint(self, blueprint: Blueprint) -> None: - self._blueprint = blueprint - self._table_placement_bitmap = blueprint.table_locations_bitmap() - - async def engine_for(self, query: QueryRep) -> Engine: - # Compute valid locations. - assert self._table_placement_bitmap is not None - valid_locations, only_location = self._run_location_routing( - query, self._table_placement_bitmap - ) - if only_location is not None: - return only_location - - # Multiple locations possible. Use the model to figure out which location to use. - assert self._model is not None - preferred_locations = await self._model.engine_for(query, self._estimator) - - for loc in preferred_locations: - if (EngineBitmapValues[loc] & valid_locations) != 0: - return loc - - # This should be unreachable. The model must rank all engines, and we - # know >= 2 engines can support this query. - raise AssertionError - - def engine_for_sync(self, query: QueryRep) -> Engine: - return asyncio.run(self.engine_for(query)) - - def persist_sync(self) -> None: - assert self._assets is not None - assert self._model is not None - self.static_persist_sync(self._model, self._schema_name, self._assets) - - @staticmethod - def static_persist_sync( - model: ModelWrap, schema_name: str, assets: AssetManager - ) -> None: - key = _SERIALIZED_KEY.format( - schema_name=schema_name, policy=model.policy().value - ) - serialized = model.to_pickle() - assets.persist_sync(key, serialized) - - @staticmethod - def static_load_model_sync( - schema_name: str, policy: RoutingPolicy, assets: AssetManager - ) -> ModelWrap: - key = _SERIALIZED_KEY.format(schema_name=schema_name, policy=policy.value) - serialized = assets.load_sync(key) - return ModelWrap.from_pickle_bytes(serialized) - - @staticmethod - def static_drop_model_sync( - schema_name: str, policy: RoutingPolicy, assets: AssetManager - ) -> None: - key = _SERIALIZED_KEY.format(schema_name=schema_name, policy=policy.value) - assets.delete_sync(key) - - -_SERIALIZED_KEY = "{schema_name}/{policy}-router.pickle" diff --git a/tests/test_forest_routing.py b/tests/test_forest_routing.py index 4e03cabb..095dc810 100644 --- a/tests/test_forest_routing.py +++ b/tests/test_forest_routing.py @@ -2,8 +2,8 @@ from sklearn.ensemble import RandomForestClassifier -from brad.config.engine import Engine, EngineBitmapValues -from brad.routing.tree_based.forest_router import ForestRouter +from brad.config.engine import Engine +from brad.routing.tree_based.forest_policy import ForestPolicy from brad.routing.tree_based.model_wrap import ModelWrap from brad.routing.policy import RoutingPolicy from brad.query_rep import QueryRep @@ -18,54 +18,23 @@ def get_dummy_router(): return ModelWrap(RoutingPolicy.ForestTablePresence, ["test1", "test2"], model) -def test_location_constraints(): - model = get_dummy_router() - bitmap = { - "test1": EngineBitmapValues[Engine.Aurora], - "test2": EngineBitmapValues[Engine.Aurora] - | EngineBitmapValues[Engine.Redshift], - } - router = ForestRouter.for_planner( - RoutingPolicy.ForestTablePresence, "test_schema", model, bitmap - ) - - query1 = QueryRep("SELECT * FROM test1") - loc = router.engine_for_sync(query1) - assert loc == Engine.Aurora - - query2 = QueryRep("SELECT * FROM test1, test2") - loc = router.engine_for_sync(query2) - assert loc == Engine.Aurora - - def test_model_codepath_partial(): model = get_dummy_router() - bitmap = { - "test1": EngineBitmapValues[Engine.Aurora] - | EngineBitmapValues[Engine.Redshift], - "test2": EngineBitmapValues[Engine.Aurora] - | EngineBitmapValues[Engine.Redshift], - } - router = ForestRouter.for_planner( - RoutingPolicy.ForestTablePresence, "test_schema", model, bitmap - ) + router = ForestPolicy.from_loaded_model(RoutingPolicy.ForestTablePresence, model) query = QueryRep("SELECT * FROM test1, test2") loc = router.engine_for_sync(query) - assert loc == Engine.Aurora or loc == Engine.Redshift + assert ( + loc[0] == Engine.Aurora or loc[0] == Engine.Redshift or loc[0] == Engine.Athena + ) def test_model_codepath_all(): model = get_dummy_router() - bitmap = { - "test1": Engine.bitmap_all(), - "test2": EngineBitmapValues[Engine.Aurora] - | EngineBitmapValues[Engine.Redshift], - } - router = ForestRouter.for_planner( - RoutingPolicy.ForestTablePresence, "test_schema", model, bitmap - ) + router = ForestPolicy.from_loaded_model(RoutingPolicy.ForestTablePresence, model) query = QueryRep("SELECT * FROM test1") loc = router.engine_for_sync(query) - assert loc == Engine.Aurora or loc == Engine.Redshift or loc == Engine.Athena + assert ( + loc[0] == Engine.Aurora or loc[0] == Engine.Redshift or loc[0] == Engine.Athena + ) From 859d655c72bb9366f687c110e971ed66c2c74e47 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Sat, 4 Nov 2023 18:28:24 -0400 Subject: [PATCH 07/11] Add ability to manually set the routing policy, remove unused parts of the rule routing policy --- src/brad/admin/modify_blueprint.py | 30 +++++++++++++++++- src/brad/front_end/front_end.py | 2 +- src/brad/planner/enumeration/blueprint.py | 13 +++++++- src/brad/planner/neighborhood/neighborhood.py | 2 +- .../planner/neighborhood/scaling_scorer.py | 2 +- src/brad/planner/router_provider.py | 2 +- src/brad/routing/rule_based.py | 31 +++---------------- 7 files changed, 49 insertions(+), 33 deletions(-) diff --git a/src/brad/admin/modify_blueprint.py b/src/brad/admin/modify_blueprint.py index 72bec1bb..d8369879 100644 --- a/src/brad/admin/modify_blueprint.py +++ b/src/brad/admin/modify_blueprint.py @@ -16,6 +16,11 @@ from brad.daemon.transition_orchestrator import TransitionOrchestrator from brad.front_end.engine_connections import EngineConnections from brad.planner.enumeration.blueprint import EnumeratedBlueprint +from brad.routing.abstract_policy import AbstractRoutingPolicy, FullRoutingPolicy +from brad.routing.always_one import AlwaysOneRouter +from brad.routing.policy import RoutingPolicy +from brad.routing.tree_based.forest_policy import ForestPolicy +from brad.routing.rule_based import RuleBased logger = logging.getLogger(__name__) @@ -68,6 +73,11 @@ def register_admin_action(subparser) -> None: action="store_true", help="Updates the blueprint's table placement and places tables on all engines.", ) + parser.add_argument( + "--set-routing-policy", + type=str, + help="Sets the serialized routing policy to a preconfigured default: {always_redshift, df_selectivity, rule_based}", + ) parser.add_argument( "--add-indexes", action="store_true", @@ -184,7 +194,7 @@ async def run_transition( # This method is called by `brad.exec.admin.main`. -def modify_blueprint(args): +def modify_blueprint(args) -> None: # 1. Load the config. config = ConfigFile.load(args.config_file) @@ -253,6 +263,24 @@ def modify_blueprint(args): new_placement[tbl] = Engine.from_bitmap(Engine.bitmap_all()) enum_blueprint.set_table_locations(new_placement) + if args.set_routing_policy is not None: + if args.set_routing_policy == "always_redshift": + definite_policy: AbstractRoutingPolicy = AlwaysOneRouter(Engine.Redshift) + elif args.set_routing_policy == "df_selectivity": + definite_policy = asyncio.run( + ForestPolicy.from_assets( + args.schema_name, RoutingPolicy.ForestTableSelectivity, assets + ) + ) + elif args.set_routing_policy == "rule_based": + definite_policy = RuleBased() + else: + raise RuntimeError( + f"Unknown routing policy preset: {args.set_routing_policy}" + ) + full_policy = FullRoutingPolicy([], definite_policy) + enum_blueprint.set_routing_policy(full_policy) + # 3. Write the changes back. modified_blueprint = enum_blueprint.to_blueprint() if blueprint == modified_blueprint: diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index f0dd199c..09b3b222 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -214,7 +214,7 @@ async def _set_up_router(self, estimator: Optional[Estimator]) -> None: elif self._routing_policy_override == RoutingPolicy.RuleBased: # TODO: If we need metrics, re-create the monitor here. It's # easier to not have it created. - definite_policy = RuleBased(blueprint=blueprint, monitor=None) + definite_policy = RuleBased() elif ( self._routing_policy_override == RoutingPolicy.ForestTablePresence or self._routing_policy_override == RoutingPolicy.ForestTableSelectivity diff --git a/src/brad/planner/enumeration/blueprint.py b/src/brad/planner/enumeration/blueprint.py index 29b10e9a..8fff27dd 100644 --- a/src/brad/planner/enumeration/blueprint.py +++ b/src/brad/planner/enumeration/blueprint.py @@ -3,6 +3,7 @@ from brad.blueprint.blueprint import Blueprint from brad.blueprint.provisioning import Provisioning from brad.config.engine import Engine +from brad.routing.abstract_policy import FullRoutingPolicy class EnumeratedBlueprint(Blueprint): @@ -27,6 +28,7 @@ def __init__(self, base_blueprint: Blueprint) -> None: self._current_aurora_provisioning = base_blueprint.aurora_provisioning() self._current_redshift_provisioning = base_blueprint.redshift_provisioning() self._current_table_locations_bitmap: Optional[Dict[str, int]] = None + self._current_routing_policy = base_blueprint.get_routing_policy() def set_table_locations( self, locations: Dict[str, List[Engine]] @@ -43,6 +45,12 @@ def set_redshift_provisioning(self, prov: Provisioning) -> "EnumeratedBlueprint" self._current_redshift_provisioning = prov return self + def set_routing_policy( + self, routing_policy: FullRoutingPolicy + ) -> "EnumeratedBlueprint": + self._current_routing_policy = routing_policy + return self + def to_blueprint(self) -> Blueprint: """ Makes a copy of this object as a `Blueprint`. @@ -57,7 +65,7 @@ def to_blueprint(self) -> Blueprint: }, aurora_provisioning=self._current_aurora_provisioning.clone(), redshift_provisioning=self._current_redshift_provisioning.clone(), - full_routing_policy=self.get_routing_policy(), + full_routing_policy=self._current_routing_policy, ) # Overridden getters. @@ -84,3 +92,6 @@ def get_table_locations(self, table_name: str) -> List[Engine]: return self._current_locations[table_name] except KeyError as ex: raise ValueError from ex + + def get_routing_policy(self) -> FullRoutingPolicy: + return self._current_routing_policy diff --git a/src/brad/planner/neighborhood/neighborhood.py b/src/brad/planner/neighborhood/neighborhood.py index 07ab41b3..12a2322a 100644 --- a/src/brad/planner/neighborhood/neighborhood.py +++ b/src/brad/planner/neighborhood/neighborhood.py @@ -194,7 +194,7 @@ def _check_if_metrics_warrant_replanning(self) -> bool: def _estimate_current_data_accessed( self, engines: EngineConnections, current_workload: Workload ) -> Dict[Engine, int]: - current_router = RuleBased(blueprint=self._current_blueprint) + current_router = RuleBased() total_accessed_mb: Dict[Engine, int] = {} total_accessed_mb[Engine.Aurora] = 0 diff --git a/src/brad/planner/neighborhood/scaling_scorer.py b/src/brad/planner/neighborhood/scaling_scorer.py index eb4360da..9b7b7b5b 100644 --- a/src/brad/planner/neighborhood/scaling_scorer.py +++ b/src/brad/planner/neighborhood/scaling_scorer.py @@ -57,7 +57,7 @@ def score(self, ctx: ScoringContext) -> Score: def _simulate_next_workload(self, ctx: ScoringContext) -> None: # NOTE: The routing policy should be included in the blueprint. We # currently hardcode it here for engineering convenience. - router = RuleBased(blueprint=ctx.next_blueprint) + router = RuleBased() # See where each analytical query gets routed. for q in ctx.next_workload.analytical_queries(): diff --git a/src/brad/planner/router_provider.py b/src/brad/planner/router_provider.py index 87dc19bd..1cd7407c 100644 --- a/src/brad/planner/router_provider.py +++ b/src/brad/planner/router_provider.py @@ -49,7 +49,7 @@ async def get_router(self, table_bitmap: Dict[str, int]) -> Router: ) elif self._routing_policy == RoutingPolicy.RuleBased: - definite_policy = RuleBased(table_placement_bitmap=table_bitmap) + definite_policy = RuleBased() else: raise RuntimeError( diff --git a/src/brad/routing/rule_based.py b/src/brad/routing/rule_based.py index 0c401901..7cab4e2a 100644 --- a/src/brad/routing/rule_based.py +++ b/src/brad/routing/rule_based.py @@ -1,13 +1,12 @@ import os.path import json import logging -from typing import List, Optional, Mapping, MutableMapping, Any, Dict +from typing import List, Optional, Mapping, MutableMapping, Any from importlib.resources import files, as_file import brad.routing from brad.blueprint import Blueprint from brad.config.engine import Engine -from brad.blueprint.manager import BlueprintManager from brad.daemon.monitor import Monitor from brad.routing.abstract_policy import AbstractRoutingPolicy from brad.query_rep import QueryRep @@ -57,18 +56,9 @@ def __init__(self) -> None: class RuleBased(AbstractRoutingPolicy): def __init__( self, - # One of `blueprint_mgr` and `blueprint` must not be `None`. - blueprint_mgr: Optional[BlueprintManager] = None, - blueprint: Optional[Blueprint] = None, - table_placement_bitmap: Optional[Dict[str, int]] = None, monitor: Optional[Monitor] = None, catalog: Optional[MutableMapping[str, MutableMapping[str, Any]]] = None, - use_decision_tree: bool = False, - deterministic: bool = True, ): - self._blueprint_mgr = blueprint_mgr - self._blueprint = blueprint - self._table_placement_bitmap = table_placement_bitmap self._monitor = monitor # catalog contains all tables' number of rows and columns self._catalog = catalog @@ -78,21 +68,14 @@ def __init__( if os.path.exists(file): with open(file, "r", encoding="utf8") as f: self._catalog = json.load(f) - # use decision tree instead of rules - self._use_decision_tree = use_decision_tree - # deterministic routing guarantees the same decision for the same query and should be used online - # non-determinism will be used for offline training data exploration (not implemented) - self._deterministic = deterministic self._params = RuleBasedParams() def name(self) -> str: return "RuleBased" - def update_blueprint(self, blueprint: Blueprint) -> None: - self._blueprint = blueprint - self._table_placement_bitmap = blueprint.table_locations_bitmap() - - async def recollect_catalog(self, sessions: SessionManager) -> None: + async def recollect_catalog( + self, sessions: SessionManager, blueprint: Blueprint + ) -> None: # recollect catalog stats; happens every maintenance window if self._catalog is None: self._catalog = dict() @@ -105,12 +88,6 @@ async def recollect_catalog(self, sessions: SessionManager) -> None: connection = session.engines.get_connection(Engine.Aurora) cursor = await connection.cursor() - if self._blueprint is not None: - blueprint = self._blueprint - else: - assert self._blueprint_mgr is not None - blueprint = self._blueprint_mgr.get_blueprint() - indexes_sql = ( "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' " "ORDER BY tablename, indexname;" From af5411647fc62a55a53c873de6c1289e71a7d3b7 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Sat, 4 Nov 2023 18:33:35 -0400 Subject: [PATCH 08/11] Actually use the serialized policy --- src/brad/routing/router.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/brad/routing/router.py b/src/brad/routing/router.py index 8b6d10d2..f94fc998 100644 --- a/src/brad/routing/router.py +++ b/src/brad/routing/router.py @@ -65,8 +65,7 @@ def update_blueprint(self, blueprint: "Blueprint") -> None: """ self._table_placement_bitmap = blueprint.table_locations_bitmap() if self._use_future_blueprint_policies: - # TODO: Deserialize from the blueprint. - pass + self._full_policy = blueprint.get_routing_policy() async def engine_for(self, query: QueryRep) -> Engine: """ From a4e5ed39e9d615d5da877182f3dc0e134ea1065a Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Sun, 5 Nov 2023 00:23:49 -0400 Subject: [PATCH 09/11] Better transition support --- src/brad/blueprint/blueprint.py | 25 ++++++++++++++++++++++--- src/brad/blueprint/serde.py | 13 +++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/src/brad/blueprint/blueprint.py b/src/brad/blueprint/blueprint.py index b5a57902..5c1dd06d 100644 --- a/src/brad/blueprint/blueprint.py +++ b/src/brad/blueprint/blueprint.py @@ -82,6 +82,7 @@ def __eq__(self, other: object) -> bool: and self.table_locations() == other.table_locations() and self.aurora_provisioning() == other.aurora_provisioning() and self.redshift_provisioning() == other.redshift_provisioning() + # TODO: Do we want to check for routing policy equality? ) def _compute_base_tables(self) -> Set[str]: @@ -113,15 +114,33 @@ def visit_table(table: Table) -> None: def __repr__(self) -> str: # Useful for debugging purposes. - aurora = "Aurora: " + str(self.aurora_provisioning()) - redshift = "Redshift: " + str(self.redshift_provisioning()) + aurora = "Aurora: " + str(self.aurora_provisioning()) + redshift = "Redshift: " + str(self.redshift_provisioning()) tables = "\n ".join( map( lambda name_loc: "".join([name_loc[0], ": ", str(name_loc[1])]), self.table_locations().items(), ) ) - return "\n ".join(["Blueprint:", tables, aurora, redshift]) + routing_policy = self.get_routing_policy() + indefinite_policies = ( + f"Indefinite routing policies: {len(routing_policy.indefinite_policies)}" + ) + definite_policy = ( + f"Definite routing policy: {routing_policy.definite_policy.name()}" + ) + return "\n ".join( + [ + "Blueprint:", + tables, + "", + aurora, + redshift, + "", + indefinite_policies, + definite_policy, + ] + ) def as_dict(self) -> Dict[str, Any]: """ diff --git a/src/brad/blueprint/serde.py b/src/brad/blueprint/serde.py index 60ebf002..2f768e5c 100644 --- a/src/brad/blueprint/serde.py +++ b/src/brad/blueprint/serde.py @@ -1,4 +1,5 @@ import pickle +import logging from typing import Tuple, List, Dict from brad.blueprint import Blueprint @@ -6,9 +7,12 @@ from brad.blueprint.table import Column, Table from brad.config.engine import Engine from brad.routing.abstract_policy import FullRoutingPolicy +from brad.routing.round_robin import RoundRobin import brad.proto_gen.blueprint_pb2 as b +logger = logging.getLogger(__name__) + # We define the data blueprint serialization/deserialization functions # separately from the blueprint classes to avoid mixing protobuf code (an # implementation detail) with the blueprint classes. @@ -151,6 +155,15 @@ def _indexed_columns_from_proto( def _policy_from_proto(policy: b.RoutingPolicy) -> FullRoutingPolicy: + if len(policy.policy) == 0: + logger.warning( + "Did not find a routing policy in the serialized blueprint. " + "This likely means you are running with an older blueprint. " + "Falling back to round robin routing. If you want to use a " + "different policy, use `brad admin modify_blueprint` to set " + "a policy." + ) + return FullRoutingPolicy(indefinite_policies=[], definite_policy=RoundRobin()) # We just use Python pickle serialization. In the future, this should be # something more robust. return pickle.loads(policy.policy) From f6bed9b7e903d8981cbc1632310930bb8d8721ce Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Sun, 5 Nov 2023 00:34:20 -0400 Subject: [PATCH 10/11] Define routing policy equality --- src/brad/blueprint/blueprint.py | 6 +++--- src/brad/routing/abstract_policy.py | 7 +++++++ src/brad/routing/always_one.py | 3 +++ src/brad/routing/round_robin.py | 3 +++ src/brad/routing/rule_based.py | 3 +++ src/brad/routing/tree_based/forest_policy.py | 5 +++++ src/brad/routing/tree_based/model_wrap.py | 14 ++++++++++++++ 7 files changed, 38 insertions(+), 3 deletions(-) diff --git a/src/brad/blueprint/blueprint.py b/src/brad/blueprint/blueprint.py index 5c1dd06d..1304ea7b 100644 --- a/src/brad/blueprint/blueprint.py +++ b/src/brad/blueprint/blueprint.py @@ -82,7 +82,7 @@ def __eq__(self, other: object) -> bool: and self.table_locations() == other.table_locations() and self.aurora_provisioning() == other.aurora_provisioning() and self.redshift_provisioning() == other.redshift_provisioning() - # TODO: Do we want to check for routing policy equality? + and self.get_routing_policy() == other.get_routing_policy() ) def _compute_base_tables(self) -> Set[str]: @@ -133,10 +133,10 @@ def __repr__(self) -> str: [ "Blueprint:", tables, - "", + "---", aurora, redshift, - "", + "---", indefinite_policies, definite_policy, ] diff --git a/src/brad/routing/abstract_policy.py b/src/brad/routing/abstract_policy.py index 66979abe..c0b2306a 100644 --- a/src/brad/routing/abstract_policy.py +++ b/src/brad/routing/abstract_policy.py @@ -72,3 +72,10 @@ async def run_setup(self, estimator: Optional[Estimator] = None) -> None: for policy in self.indefinite_policies: await policy.run_setup(estimator) await self.definite_policy.run_setup(estimator) + + def __eq__(self, other: object): + if not isinstance(other, FullRoutingPolicy): + return False + return (self.indefinite_policies == other.indefinite_policies) and ( + self.definite_policy == other.definite_policy + ) diff --git a/src/brad/routing/always_one.py b/src/brad/routing/always_one.py index 8fba5653..cc51243b 100644 --- a/src/brad/routing/always_one.py +++ b/src/brad/routing/always_one.py @@ -21,3 +21,6 @@ def name(self) -> str: def engine_for_sync(self, _query: QueryRep) -> List[Engine]: return self._always_route_to + + def __eq__(self, other: object) -> bool: + return isinstance(other, AlwaysOneRouter) and self._engine == other._engine diff --git a/src/brad/routing/round_robin.py b/src/brad/routing/round_robin.py index 11ae7e13..8ba7d304 100644 --- a/src/brad/routing/round_robin.py +++ b/src/brad/routing/round_robin.py @@ -22,3 +22,6 @@ def engine_for_sync(self, _query: QueryRep) -> List[Engine]: self._ordering[1] = self._ordering[2] self._ordering[2] = tmp return self._ordering + + def __eq__(self, other: object) -> bool: + return isinstance(other, RoundRobin) diff --git a/src/brad/routing/rule_based.py b/src/brad/routing/rule_based.py index 7cab4e2a..f0bf50ef 100644 --- a/src/brad/routing/rule_based.py +++ b/src/brad/routing/rule_based.py @@ -73,6 +73,9 @@ def __init__( def name(self) -> str: return "RuleBased" + def __eq__(self, other: object) -> bool: + return isinstance(other, RuleBased) + async def recollect_catalog( self, sessions: SessionManager, blueprint: Blueprint ) -> None: diff --git a/src/brad/routing/tree_based/forest_policy.py b/src/brad/routing/tree_based/forest_policy.py index 049f59b9..cfeab413 100644 --- a/src/brad/routing/tree_based/forest_policy.py +++ b/src/brad/routing/tree_based/forest_policy.py @@ -43,6 +43,11 @@ def __setstate__(self, d: Dict[Any, Any]) -> None: def name(self) -> str: return f"ForestPolicy({self._policy.name})" + def __eq__(self, other: object) -> bool: + if not isinstance(other, ForestPolicy): + return False + return self._policy == other._policy and self._model == other.model + async def run_setup(self, estimator: Optional[Estimator] = None) -> None: self._estimator = estimator diff --git a/src/brad/routing/tree_based/model_wrap.py b/src/brad/routing/tree_based/model_wrap.py index 2ff068fa..42a496ad 100644 --- a/src/brad/routing/tree_based/model_wrap.py +++ b/src/brad/routing/tree_based/model_wrap.py @@ -49,6 +49,20 @@ async def engine_for( low_to_high = np.argsort(preds) return [ENGINE_LABELS[label] for label in reversed(low_to_high)] + def __eq__(self, other: object) -> bool: + if not isinstance(other, ModelWrap): + return False + if self._policy != other._policy or self._table_order != other._table_order: + return False + + if id(self._model) == id(other._model): + return True + + # Not very ideal, but this will check for identical copies. + serialized = pickle.dumps(self._model) + other_serialized = pickle.dumps(other._model) + return serialized == other_serialized + def to_pickle(self) -> bytes: # TODO: Pickling might not be the best option. return pickle.dumps(self) From 9c121e38c2fb63b9c17dd5fce8779567e78131a5 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Sun, 5 Nov 2023 01:03:18 -0400 Subject: [PATCH 11/11] Fix type error --- src/brad/routing/tree_based/forest_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brad/routing/tree_based/forest_policy.py b/src/brad/routing/tree_based/forest_policy.py index cfeab413..6b90bcf7 100644 --- a/src/brad/routing/tree_based/forest_policy.py +++ b/src/brad/routing/tree_based/forest_policy.py @@ -46,7 +46,7 @@ def name(self) -> str: def __eq__(self, other: object) -> bool: if not isinstance(other, ForestPolicy): return False - return self._policy == other._policy and self._model == other.model + return self._policy == other._policy and self._model == other._model async def run_setup(self, estimator: Optional[Estimator] = None) -> None: self._estimator = estimator