Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store and use the routing policy in the blueprint #350

Merged
merged 11 commits into from
Nov 5, 2023
6 changes: 3 additions & 3 deletions src/brad/admin/drop_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from brad.front_end.engine_connections import EngineConnections
from brad.provisioning.directory import Directory
from brad.routing.policy import RoutingPolicy
from brad.routing.tree_based.forest_router import ForestRouter
from brad.routing.tree_based.forest_policy import ForestPolicy

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,10 +56,10 @@ def drop_schema(args):

# 4. Drop any serialized routers.
assets = AssetManager(config)
ForestRouter.static_drop_model_sync(
ForestPolicy.static_drop_model_sync(
args.schema_name, RoutingPolicy.ForestTableSelectivity, assets
)
ForestRouter.static_drop_model_sync(
ForestPolicy.static_drop_model_sync(
args.schema_name, RoutingPolicy.ForestTablePresence, assets
)

Expand Down
30 changes: 29 additions & 1 deletion src/brad/admin/modify_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
from brad.daemon.transition_orchestrator import TransitionOrchestrator
from brad.front_end.engine_connections import EngineConnections
from brad.planner.enumeration.blueprint import EnumeratedBlueprint
from brad.routing.abstract_policy import AbstractRoutingPolicy, FullRoutingPolicy
from brad.routing.always_one import AlwaysOneRouter
from brad.routing.policy import RoutingPolicy
from brad.routing.tree_based.forest_policy import ForestPolicy
from brad.routing.rule_based import RuleBased

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -68,6 +73,11 @@ def register_admin_action(subparser) -> None:
action="store_true",
help="Updates the blueprint's table placement and places tables on all engines.",
)
parser.add_argument(
"--set-routing-policy",
type=str,
help="Sets the serialized routing policy to a preconfigured default: {always_redshift, df_selectivity, rule_based}",
)
parser.add_argument(
"--add-indexes",
action="store_true",
Expand Down Expand Up @@ -184,7 +194,7 @@ async def run_transition(


# This method is called by `brad.exec.admin.main`.
def modify_blueprint(args):
def modify_blueprint(args) -> None:
# 1. Load the config.
config = ConfigFile.load(args.config_file)

Expand Down Expand Up @@ -253,6 +263,24 @@ def modify_blueprint(args):
new_placement[tbl] = Engine.from_bitmap(Engine.bitmap_all())
enum_blueprint.set_table_locations(new_placement)

if args.set_routing_policy is not None:
if args.set_routing_policy == "always_redshift":
definite_policy: AbstractRoutingPolicy = AlwaysOneRouter(Engine.Redshift)
elif args.set_routing_policy == "df_selectivity":
definite_policy = asyncio.run(
ForestPolicy.from_assets(
args.schema_name, RoutingPolicy.ForestTableSelectivity, assets
)
)
elif args.set_routing_policy == "rule_based":
definite_policy = RuleBased()
else:
raise RuntimeError(
f"Unknown routing policy preset: {args.set_routing_policy}"
)
full_policy = FullRoutingPolicy([], definite_policy)
enum_blueprint.set_routing_policy(full_policy)

# 3. Write the changes back.
modified_blueprint = enum_blueprint.to_blueprint()
if blueprint == modified_blueprint:
Expand Down
4 changes: 2 additions & 2 deletions src/brad/admin/train_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from brad.data_stats.estimator import Estimator
from brad.data_stats.postgres_estimator import PostgresEstimator
from brad.routing.policy import RoutingPolicy
from brad.routing.tree_based.forest_router import ForestRouter
from brad.routing.tree_based.forest_policy import ForestPolicy
from brad.routing.tree_based.trainer import ForestTrainer
from brad.blueprint.manager import BlueprintManager

Expand Down Expand Up @@ -163,7 +163,7 @@ def train_router(args):
response = input("Do you want to persist this model? (y/n): ").lower()
if response == "y":
assets = AssetManager(config)
ForestRouter.static_persist_sync(model, schema_name, assets)
ForestPolicy.static_persist_sync(model, schema_name, assets)
logger.info("Model persisted successfully.")
break
elif response == "n":
Expand Down
42 changes: 28 additions & 14 deletions src/brad/blueprint/blueprint.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import Callable, Dict, List, Set, Optional, Tuple, Any
from typing import Dict, List, Set, Tuple, Any

from brad.blueprint.provisioning import Provisioning
from brad.blueprint.table import Table
from brad.config.engine import Engine
from brad.routing.router import Router

RouterProvider = Callable[[], Router]
from brad.routing.abstract_policy import FullRoutingPolicy


class Blueprint:
Expand All @@ -16,14 +14,14 @@ def __init__(
table_locations: Dict[str, List[Engine]],
aurora_provisioning: Provisioning,
redshift_provisioning: Provisioning,
router_provider: Optional[RouterProvider],
full_routing_policy: FullRoutingPolicy,
):
self._schema_name = schema_name
self._table_schemas = table_schemas
self._table_locations = table_locations
self._aurora_provisioning = aurora_provisioning
self._redshift_provisioning = redshift_provisioning
self._router_provider = router_provider
self._full_routing_policy = full_routing_policy

# Derived properties used for the class' methods.
self._tables_by_name = {tbl.name: tbl for tbl in self._table_schemas}
Expand Down Expand Up @@ -57,11 +55,8 @@ def aurora_provisioning(self) -> Provisioning:
def redshift_provisioning(self) -> Provisioning:
return self._redshift_provisioning

def get_router(self) -> Optional[Router]:
return self._router_provider() if self._router_provider is not None else None

def router_provider(self) -> Optional[RouterProvider]:
return self._router_provider
def get_routing_policy(self) -> FullRoutingPolicy:
return self._full_routing_policy

def base_table_names(self) -> Set[str]:
return self._base_table_names
Expand All @@ -87,6 +82,7 @@ def __eq__(self, other: object) -> bool:
and self.table_locations() == other.table_locations()
and self.aurora_provisioning() == other.aurora_provisioning()
and self.redshift_provisioning() == other.redshift_provisioning()
and self.get_routing_policy() == other.get_routing_policy()
)

def _compute_base_tables(self) -> Set[str]:
Expand Down Expand Up @@ -118,15 +114,33 @@ def visit_table(table: Table) -> None:

def __repr__(self) -> str:
# Useful for debugging purposes.
aurora = "Aurora: " + str(self.aurora_provisioning())
redshift = "Redshift: " + str(self.redshift_provisioning())
aurora = "Aurora: " + str(self.aurora_provisioning())
redshift = "Redshift: " + str(self.redshift_provisioning())
tables = "\n ".join(
map(
lambda name_loc: "".join([name_loc[0], ": ", str(name_loc[1])]),
self.table_locations().items(),
)
)
return "\n ".join(["Blueprint:", tables, aurora, redshift])
routing_policy = self.get_routing_policy()
indefinite_policies = (
f"Indefinite routing policies: {len(routing_policy.indefinite_policies)}"
)
definite_policy = (
f"Definite routing policy: {routing_policy.definite_policy.name()}"
)
return "\n ".join(
[
"Blueprint:",
tables,
"---",
aurora,
redshift,
"---",
indefinite_policies,
definite_policy,
]
)

def as_dict(self) -> Dict[str, Any]:
"""
Expand Down
32 changes: 30 additions & 2 deletions src/brad/blueprint/serde.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import pickle
import logging
from typing import Tuple, List, Dict

from brad.blueprint import Blueprint
from brad.blueprint.provisioning import Provisioning
from brad.blueprint.table import Column, Table
from brad.config.engine import Engine
from brad.routing.abstract_policy import FullRoutingPolicy
from brad.routing.round_robin import RoundRobin

import brad.proto_gen.blueprint_pb2 as b

logger = logging.getLogger(__name__)

# We define the data blueprint serialization/deserialization functions
# separately from the blueprint classes to avoid mixing protobuf code (an
# implementation detail) with the blueprint classes.
Expand All @@ -24,7 +30,7 @@ def deserialize_blueprint(raw_data: bytes) -> Blueprint:
table_locations=dict(map(_table_locations_from_proto, proto.tables)),
aurora_provisioning=_provisioning_from_proto(proto.aurora),
redshift_provisioning=_provisioning_from_proto(proto.redshift),
router_provider=None,
full_routing_policy=_policy_from_proto(proto.policy),
)


Expand All @@ -34,7 +40,7 @@ def serialize_blueprint(blueprint: Blueprint) -> bytes:
tables=map(_tables_with_locations_to_proto, blueprint.tables_with_locations()),
aurora=_provisioning_to_proto(blueprint.aurora_provisioning()),
redshift=_provisioning_to_proto(blueprint.redshift_provisioning()),
policy=None,
policy=_policy_to_proto(blueprint.get_routing_policy()),
)
return proto.SerializeToString()

Expand Down Expand Up @@ -88,6 +94,13 @@ def _indexed_columns_to_proto(indexed_columns: Tuple[Column, ...]) -> b.Index:
return b.Index(column_name=map(lambda col: col.name, indexed_columns))


def _policy_to_proto(full_routing_policy: FullRoutingPolicy) -> b.RoutingPolicy:
# We just use Python pickle serialization. In the future, this should be
# something more robust.
bytes_str = pickle.dumps(full_routing_policy)
return b.RoutingPolicy(policy=bytes_str)


# Deserialization


Expand Down Expand Up @@ -139,3 +152,18 @@ def _indexed_columns_from_proto(
for col_name in indexed_columns.column_name:
col_list.append(col_map[col_name])
return tuple(col_list)


def _policy_from_proto(policy: b.RoutingPolicy) -> FullRoutingPolicy:
if len(policy.policy) == 0:
logger.warning(
"Did not find a routing policy in the serialized blueprint. "
"This likely means you are running with an older blueprint. "
"Falling back to round robin routing. If you want to use a "
"different policy, use `brad admin modify_blueprint` to set "
"a policy."
)
return FullRoutingPolicy(indefinite_policies=[], definite_policy=RoundRobin())
# We just use Python pickle serialization. In the future, this should be
# something more robust.
return pickle.loads(policy.policy)
Loading
Loading