Skip to content

Commit

Permalink
Fix lint/type/format errors
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffxy committed Nov 15, 2023
1 parent 77c2f29 commit e35ec91
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
2 changes: 0 additions & 2 deletions src/brad/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
NewBlueprint,
NewBlueprintAck,
)
from brad.data_stats.estimator import Estimator
from brad.data_stats.postgres_estimator import PostgresEstimator
from brad.front_end.brad_interface import BradInterface
from brad.front_end.errors import QueryError
from brad.front_end.grpc import BradGrpc
Expand Down
4 changes: 3 additions & 1 deletion src/brad/routing/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def log_policy(self) -> None:
logger.info(" - %s", p.name())
logger.info(" Definite policy: %s", self._full_policy.definite_policy.name())

async def run_setup_for_standalone(self, estimator: Optional[Estimator] = None) -> None:
async def run_setup_for_standalone(
self, estimator: Optional[Estimator] = None
) -> None:
"""
Should be called before using the router "standalone" contexts (i.e.,
outside the front end). This is used to set up any dynamic state that is
Expand Down
7 changes: 5 additions & 2 deletions tests/test_forest_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sklearn.ensemble import RandomForestClassifier

from brad.config.engine import Engine
from brad.routing.context import RoutingContext
from brad.routing.tree_based.forest_policy import ForestPolicy
from brad.routing.tree_based.model_wrap import ModelWrap
from brad.routing.policy import RoutingPolicy
Expand All @@ -21,9 +22,10 @@ def get_dummy_router():
def test_model_codepath_partial():
model = get_dummy_router()
router = ForestPolicy.from_loaded_model(RoutingPolicy.ForestTablePresence, model)
ctx = RoutingContext()

query = QueryRep("SELECT * FROM test1, test2")
loc = router.engine_for_sync(query)
loc = router.engine_for_sync(query, ctx)
assert (
loc[0] == Engine.Aurora or loc[0] == Engine.Redshift or loc[0] == Engine.Athena
)
Expand All @@ -32,9 +34,10 @@ def test_model_codepath_partial():
def test_model_codepath_all():
model = get_dummy_router()
router = ForestPolicy.from_loaded_model(RoutingPolicy.ForestTablePresence, model)
ctx = RoutingContext()

query = QueryRep("SELECT * FROM test1")
loc = router.engine_for_sync(query)
loc = router.engine_for_sync(query, ctx)
assert (
loc[0] == Engine.Aurora or loc[0] == Engine.Redshift or loc[0] == Engine.Athena
)

0 comments on commit e35ec91

Please sign in to comment.