From e35ec91700a339bcf29779fdb9e523c028d143e3 Mon Sep 17 00:00:00 2001 From: Geoffrey Yu Date: Wed, 15 Nov 2023 15:38:01 -0500 Subject: [PATCH] Fix lint/type/format errors --- src/brad/front_end/front_end.py | 2 -- src/brad/routing/router.py | 4 +++- tests/test_forest_routing.py | 7 +++++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index a58bfc0d..c8e9abce 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -29,8 +29,6 @@ NewBlueprint, NewBlueprintAck, ) -from brad.data_stats.estimator import Estimator -from brad.data_stats.postgres_estimator import PostgresEstimator from brad.front_end.brad_interface import BradInterface from brad.front_end.errors import QueryError from brad.front_end.grpc import BradGrpc diff --git a/src/brad/routing/router.py b/src/brad/routing/router.py index cf75969d..2318cd7c 100644 --- a/src/brad/routing/router.py +++ b/src/brad/routing/router.py @@ -55,7 +55,9 @@ def log_policy(self) -> None: logger.info(" - %s", p.name()) logger.info(" Definite policy: %s", self._full_policy.definite_policy.name()) - async def run_setup_for_standalone(self, estimator: Optional[Estimator] = None) -> None: + async def run_setup_for_standalone( + self, estimator: Optional[Estimator] = None + ) -> None: """ Should be called before using the router "standalone" contexts (i.e., outside the front end). This is used to set up any dynamic state that is diff --git a/tests/test_forest_routing.py b/tests/test_forest_routing.py index 095dc810..2ab4752f 100644 --- a/tests/test_forest_routing.py +++ b/tests/test_forest_routing.py @@ -3,6 +3,7 @@ from sklearn.ensemble import RandomForestClassifier from brad.config.engine import Engine +from brad.routing.context import RoutingContext from brad.routing.tree_based.forest_policy import ForestPolicy from brad.routing.tree_based.model_wrap import ModelWrap from brad.routing.policy import RoutingPolicy @@ -21,9 +22,10 @@ def get_dummy_router(): def test_model_codepath_partial(): model = get_dummy_router() router = ForestPolicy.from_loaded_model(RoutingPolicy.ForestTablePresence, model) + ctx = RoutingContext() query = QueryRep("SELECT * FROM test1, test2") - loc = router.engine_for_sync(query) + loc = router.engine_for_sync(query, ctx) assert ( loc[0] == Engine.Aurora or loc[0] == Engine.Redshift or loc[0] == Engine.Athena ) @@ -32,9 +34,10 @@ def test_model_codepath_partial(): def test_model_codepath_all(): model = get_dummy_router() router = ForestPolicy.from_loaded_model(RoutingPolicy.ForestTablePresence, model) + ctx = RoutingContext() query = QueryRep("SELECT * FROM test1") - loc = router.engine_for_sync(query) + loc = router.engine_for_sync(query, ctx) assert ( loc[0] == Engine.Aurora or loc[0] == Engine.Redshift or loc[0] == Engine.Athena )