Skip to content

Commit

Permalink
Store and use the routing policy in the blueprint (#350)
Browse files Browse the repository at this point in the history
* Implement unified routing policy abstraction

* WIP: Update the routing pipeline to allow modifications during planning

* WIP: Overhaul router usage

* Fix tests

* Fix other test

* Remove dependence on legacy ForestRouter

* Add ability to manually set the routing policy, remove unused parts of the rule routing policy

* Actually use the serialized policy

* Better transition support

* Define routing policy equality

* Fix type error
  • Loading branch information
geoffxy authored Nov 5, 2023
1 parent 68eb4c5 commit b5594b3
Show file tree
Hide file tree
Showing 29 changed files with 550 additions and 390 deletions.
6 changes: 3 additions & 3 deletions src/brad/admin/drop_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from brad.front_end.engine_connections import EngineConnections
from brad.provisioning.directory import Directory
from brad.routing.policy import RoutingPolicy
from brad.routing.tree_based.forest_router import ForestRouter
from brad.routing.tree_based.forest_policy import ForestPolicy

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,10 +56,10 @@ def drop_schema(args):

# 4. Drop any serialized routers.
assets = AssetManager(config)
ForestRouter.static_drop_model_sync(
ForestPolicy.static_drop_model_sync(
args.schema_name, RoutingPolicy.ForestTableSelectivity, assets
)
ForestRouter.static_drop_model_sync(
ForestPolicy.static_drop_model_sync(
args.schema_name, RoutingPolicy.ForestTablePresence, assets
)

Expand Down
30 changes: 29 additions & 1 deletion src/brad/admin/modify_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
from brad.daemon.transition_orchestrator import TransitionOrchestrator
from brad.front_end.engine_connections import EngineConnections
from brad.planner.enumeration.blueprint import EnumeratedBlueprint
from brad.routing.abstract_policy import AbstractRoutingPolicy, FullRoutingPolicy
from brad.routing.always_one import AlwaysOneRouter
from brad.routing.policy import RoutingPolicy
from brad.routing.tree_based.forest_policy import ForestPolicy
from brad.routing.rule_based import RuleBased

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -68,6 +73,11 @@ def register_admin_action(subparser) -> None:
action="store_true",
help="Updates the blueprint's table placement and places tables on all engines.",
)
parser.add_argument(
"--set-routing-policy",
type=str,
help="Sets the serialized routing policy to a preconfigured default: {always_redshift, df_selectivity, rule_based}",
)
parser.add_argument(
"--add-indexes",
action="store_true",
Expand Down Expand Up @@ -184,7 +194,7 @@ async def run_transition(


# This method is called by `brad.exec.admin.main`.
def modify_blueprint(args):
def modify_blueprint(args) -> None:
# 1. Load the config.
config = ConfigFile.load(args.config_file)

Expand Down Expand Up @@ -253,6 +263,24 @@ def modify_blueprint(args):
new_placement[tbl] = Engine.from_bitmap(Engine.bitmap_all())
enum_blueprint.set_table_locations(new_placement)

if args.set_routing_policy is not None:
if args.set_routing_policy == "always_redshift":
definite_policy: AbstractRoutingPolicy = AlwaysOneRouter(Engine.Redshift)
elif args.set_routing_policy == "df_selectivity":
definite_policy = asyncio.run(
ForestPolicy.from_assets(
args.schema_name, RoutingPolicy.ForestTableSelectivity, assets
)
)
elif args.set_routing_policy == "rule_based":
definite_policy = RuleBased()
else:
raise RuntimeError(
f"Unknown routing policy preset: {args.set_routing_policy}"
)
full_policy = FullRoutingPolicy([], definite_policy)
enum_blueprint.set_routing_policy(full_policy)

# 3. Write the changes back.
modified_blueprint = enum_blueprint.to_blueprint()
if blueprint == modified_blueprint:
Expand Down
4 changes: 2 additions & 2 deletions src/brad/admin/train_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from brad.data_stats.estimator import Estimator
from brad.data_stats.postgres_estimator import PostgresEstimator
from brad.routing.policy import RoutingPolicy
from brad.routing.tree_based.forest_router import ForestRouter
from brad.routing.tree_based.forest_policy import ForestPolicy
from brad.routing.tree_based.trainer import ForestTrainer
from brad.blueprint.manager import BlueprintManager

Expand Down Expand Up @@ -163,7 +163,7 @@ def train_router(args):
response = input("Do you want to persist this model? (y/n): ").lower()
if response == "y":
assets = AssetManager(config)
ForestRouter.static_persist_sync(model, schema_name, assets)
ForestPolicy.static_persist_sync(model, schema_name, assets)
logger.info("Model persisted successfully.")
break
elif response == "n":
Expand Down
42 changes: 28 additions & 14 deletions src/brad/blueprint/blueprint.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import Callable, Dict, List, Set, Optional, Tuple, Any
from typing import Dict, List, Set, Tuple, Any

from brad.blueprint.provisioning import Provisioning
from brad.blueprint.table import Table
from brad.config.engine import Engine
from brad.routing.router import Router

RouterProvider = Callable[[], Router]
from brad.routing.abstract_policy import FullRoutingPolicy


class Blueprint:
Expand All @@ -16,14 +14,14 @@ def __init__(
table_locations: Dict[str, List[Engine]],
aurora_provisioning: Provisioning,
redshift_provisioning: Provisioning,
router_provider: Optional[RouterProvider],
full_routing_policy: FullRoutingPolicy,
):
self._schema_name = schema_name
self._table_schemas = table_schemas
self._table_locations = table_locations
self._aurora_provisioning = aurora_provisioning
self._redshift_provisioning = redshift_provisioning
self._router_provider = router_provider
self._full_routing_policy = full_routing_policy

# Derived properties used for the class' methods.
self._tables_by_name = {tbl.name: tbl for tbl in self._table_schemas}
Expand Down Expand Up @@ -57,11 +55,8 @@ def aurora_provisioning(self) -> Provisioning:
def redshift_provisioning(self) -> Provisioning:
return self._redshift_provisioning

def get_router(self) -> Optional[Router]:
return self._router_provider() if self._router_provider is not None else None

def router_provider(self) -> Optional[RouterProvider]:
return self._router_provider
def get_routing_policy(self) -> FullRoutingPolicy:
return self._full_routing_policy

def base_table_names(self) -> Set[str]:
return self._base_table_names
Expand All @@ -87,6 +82,7 @@ def __eq__(self, other: object) -> bool:
and self.table_locations() == other.table_locations()
and self.aurora_provisioning() == other.aurora_provisioning()
and self.redshift_provisioning() == other.redshift_provisioning()
and self.get_routing_policy() == other.get_routing_policy()
)

def _compute_base_tables(self) -> Set[str]:
Expand Down Expand Up @@ -118,15 +114,33 @@ def visit_table(table: Table) -> None:

def __repr__(self) -> str:
# Useful for debugging purposes.
aurora = "Aurora: " + str(self.aurora_provisioning())
redshift = "Redshift: " + str(self.redshift_provisioning())
aurora = "Aurora: " + str(self.aurora_provisioning())
redshift = "Redshift: " + str(self.redshift_provisioning())
tables = "\n ".join(
map(
lambda name_loc: "".join([name_loc[0], ": ", str(name_loc[1])]),
self.table_locations().items(),
)
)
return "\n ".join(["Blueprint:", tables, aurora, redshift])
routing_policy = self.get_routing_policy()
indefinite_policies = (
f"Indefinite routing policies: {len(routing_policy.indefinite_policies)}"
)
definite_policy = (
f"Definite routing policy: {routing_policy.definite_policy.name()}"
)
return "\n ".join(
[
"Blueprint:",
tables,
"---",
aurora,
redshift,
"---",
indefinite_policies,
definite_policy,
]
)

def as_dict(self) -> Dict[str, Any]:
"""
Expand Down
32 changes: 30 additions & 2 deletions src/brad/blueprint/serde.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import pickle
import logging
from typing import Tuple, List, Dict

from brad.blueprint import Blueprint
from brad.blueprint.provisioning import Provisioning
from brad.blueprint.table import Column, Table
from brad.config.engine import Engine
from brad.routing.abstract_policy import FullRoutingPolicy
from brad.routing.round_robin import RoundRobin

import brad.proto_gen.blueprint_pb2 as b

logger = logging.getLogger(__name__)

# We define the data blueprint serialization/deserialization functions
# separately from the blueprint classes to avoid mixing protobuf code (an
# implementation detail) with the blueprint classes.
Expand All @@ -24,7 +30,7 @@ def deserialize_blueprint(raw_data: bytes) -> Blueprint:
table_locations=dict(map(_table_locations_from_proto, proto.tables)),
aurora_provisioning=_provisioning_from_proto(proto.aurora),
redshift_provisioning=_provisioning_from_proto(proto.redshift),
router_provider=None,
full_routing_policy=_policy_from_proto(proto.policy),
)


Expand All @@ -34,7 +40,7 @@ def serialize_blueprint(blueprint: Blueprint) -> bytes:
tables=map(_tables_with_locations_to_proto, blueprint.tables_with_locations()),
aurora=_provisioning_to_proto(blueprint.aurora_provisioning()),
redshift=_provisioning_to_proto(blueprint.redshift_provisioning()),
policy=None,
policy=_policy_to_proto(blueprint.get_routing_policy()),
)
return proto.SerializeToString()

Expand Down Expand Up @@ -88,6 +94,13 @@ def _indexed_columns_to_proto(indexed_columns: Tuple[Column, ...]) -> b.Index:
return b.Index(column_name=map(lambda col: col.name, indexed_columns))


def _policy_to_proto(full_routing_policy: FullRoutingPolicy) -> b.RoutingPolicy:
# We just use Python pickle serialization. In the future, this should be
# something more robust.
bytes_str = pickle.dumps(full_routing_policy)
return b.RoutingPolicy(policy=bytes_str)


# Deserialization


Expand Down Expand Up @@ -139,3 +152,18 @@ def _indexed_columns_from_proto(
for col_name in indexed_columns.column_name:
col_list.append(col_map[col_name])
return tuple(col_list)


def _policy_from_proto(policy: b.RoutingPolicy) -> FullRoutingPolicy:
if len(policy.policy) == 0:
logger.warning(
"Did not find a routing policy in the serialized blueprint. "
"This likely means you are running with an older blueprint. "
"Falling back to round robin routing. If you want to use a "
"different policy, use `brad admin modify_blueprint` to set "
"a policy."
)
return FullRoutingPolicy(indefinite_policies=[], definite_policy=RoundRobin())
# We just use Python pickle serialization. In the future, this should be
# something more robust.
return pickle.loads(policy.policy)
Loading

0 comments on commit b5594b3

Please sign in to comment.