Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make connection initialization more robust #366

Merged
merged 4 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion proto/brad.proto
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,14 @@ message StartSessionRequest {
}

message StartSessionResponse {
SessionId id = 1;
oneof result {
SessionId id = 1;
StartSessionError error = 2;
}
}

message StartSessionError {
string error_msg = 1;
}

message RunQueryRequest {
Expand Down
3 changes: 2 additions & 1 deletion src/brad/connection/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@ def is_connected(self) -> bool:

class ConnectionFailed(Exception):
"""
Used when
Used when an existing connection fails for any reason, or we failed to
establish a connection to an underlying engine.
"""
21 changes: 19 additions & 2 deletions src/brad/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from brad.blueprint.manager import BlueprintManager
from brad.config.engine import Engine
from brad.config.file import ConfigFile
from brad.connection.connection import ConnectionFailed
from brad.daemon.monitor import Monitor
from brad.daemon.messages import (
ShutdownFrontEnd,
Expand Down Expand Up @@ -262,8 +263,24 @@ async def _run_teardown(self):
self._estimator = None

async def start_session(self) -> SessionId:
session_id, _ = await self._sessions.create_new_session()
return session_id
rand_backoff = None
while True:
try:
session_id, _ = await self._sessions.create_new_session()
return session_id
except ConnectionFailed:
if rand_backoff is None:
rand_backoff = RandomizedExponentialBackoff(
max_retries=10, base_delay_s=0.5, max_delay_s=10.0
)
time_to_wait = rand_backoff.wait_time_s()
if time_to_wait is None:
logger.exception(
"Failed to start a new session due to a repeated "
"connection failure (10 retries)."
)
raise
await asyncio.sleep(time_to_wait)

async def end_session(self, session_id: SessionId) -> None:
await self._sessions.end_session(session_id)
Expand Down
10 changes: 8 additions & 2 deletions src/brad/front_end/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import brad.proto_gen.brad_pb2_grpc as rpc
from brad.config.engine import Engine
from brad.config.session import SessionId
from brad.connection.connection import ConnectionFailed
from brad.front_end.brad_interface import BradInterface
from brad.front_end.errors import QueryError

Expand All @@ -24,8 +25,13 @@ def __init__(self, brad: BradInterface):
async def StartSession(
self, _request: b.StartSessionRequest, _context
) -> b.StartSessionResponse:
new_session_id = await self._brad.start_session()
return b.StartSessionResponse(id=b.SessionId(id_value=new_session_id.value()))
try:
new_session_id = await self._brad.start_session()
return b.StartSessionResponse(
id=b.SessionId(id_value=new_session_id.value())
)
except ConnectionFailed as ex:
return b.StartSessionResponse(error=b.StartSessionError(error_msg=repr(ex)))

async def RunQuery(
self, request: b.RunQueryRequest, _context
Expand Down
14 changes: 13 additions & 1 deletion src/brad/grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,19 @@ def close(self) -> None:
def start_session(self) -> SessionId:
assert self._stub is not None
result = self._stub.StartSession(b.StartSessionRequest())
return SessionId(result.id.id_value)
msg_kind = result.WhichOneof("result")
if msg_kind is None:
raise BradClientError(
message="BRAD RPC error: Unspecified start session result."
)
elif msg_kind == "id":
return SessionId(result.id.id_value)
elif msg_kind == "error":
raise BradClientError(message=result.error.error_msg)
else:
raise BradClientError(
message="BRAD RPC error: Unknown start session result."
)

def end_session(self, session_id: SessionId) -> None:
assert self._stub is not None
Expand Down
40 changes: 20 additions & 20 deletions src/brad/proto_gen/blueprint_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

84 changes: 44 additions & 40 deletions src/brad/proto_gen/blueprint_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,68 +4,55 @@ from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union

ATHENA: Engine
AURORA: Engine
DESCRIPTOR: _descriptor.FileDescriptor
REDSHIFT: Engine

class Engine(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = [] # type: ignore
UNKNOWN: _ClassVar[Engine]
AURORA: _ClassVar[Engine]
REDSHIFT: _ClassVar[Engine]
ATHENA: _ClassVar[Engine]
UNKNOWN: Engine
AURORA: Engine
REDSHIFT: Engine
ATHENA: Engine

class Blueprint(_message.Message):
__slots__ = ["aurora", "policy", "redshift", "schema_name", "tables"]
AURORA_FIELD_NUMBER: _ClassVar[int]
POLICY_FIELD_NUMBER: _ClassVar[int]
REDSHIFT_FIELD_NUMBER: _ClassVar[int]
__slots__ = ["schema_name", "tables", "aurora", "redshift", "policy"]
SCHEMA_NAME_FIELD_NUMBER: _ClassVar[int]
TABLES_FIELD_NUMBER: _ClassVar[int]
aurora: Provisioning
policy: RoutingPolicy
redshift: Provisioning
AURORA_FIELD_NUMBER: _ClassVar[int]
REDSHIFT_FIELD_NUMBER: _ClassVar[int]
POLICY_FIELD_NUMBER: _ClassVar[int]
schema_name: str
tables: _containers.RepeatedCompositeFieldContainer[Table]
aurora: Provisioning
redshift: Provisioning
policy: RoutingPolicy
def __init__(self, schema_name: _Optional[str] = ..., tables: _Optional[_Iterable[_Union[Table, _Mapping]]] = ..., aurora: _Optional[_Union[Provisioning, _Mapping]] = ..., redshift: _Optional[_Union[Provisioning, _Mapping]] = ..., policy: _Optional[_Union[RoutingPolicy, _Mapping]] = ...) -> None: ...

class Index(_message.Message):
__slots__ = ["column_name"]
COLUMN_NAME_FIELD_NUMBER: _ClassVar[int]
column_name: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, column_name: _Optional[_Iterable[str]] = ...) -> None: ...

class Provisioning(_message.Message):
__slots__ = ["instance_type", "num_nodes"]
INSTANCE_TYPE_FIELD_NUMBER: _ClassVar[int]
NUM_NODES_FIELD_NUMBER: _ClassVar[int]
instance_type: str
num_nodes: int
def __init__(self, instance_type: _Optional[str] = ..., num_nodes: _Optional[int] = ...) -> None: ...

class RoutingPolicy(_message.Message):
__slots__ = ["policy"]
POLICY_FIELD_NUMBER: _ClassVar[int]
policy: bytes
def __init__(self, policy: _Optional[bytes] = ...) -> None: ...

class Table(_message.Message):
__slots__ = ["columns", "dependencies", "indexes", "locations", "table_name"]
__slots__ = ["table_name", "columns", "locations", "dependencies", "indexes"]
TABLE_NAME_FIELD_NUMBER: _ClassVar[int]
COLUMNS_FIELD_NUMBER: _ClassVar[int]
LOCATIONS_FIELD_NUMBER: _ClassVar[int]
DEPENDENCIES_FIELD_NUMBER: _ClassVar[int]
INDEXES_FIELD_NUMBER: _ClassVar[int]
LOCATIONS_FIELD_NUMBER: _ClassVar[int]
TABLE_NAME_FIELD_NUMBER: _ClassVar[int]
table_name: str
columns: _containers.RepeatedCompositeFieldContainer[TableColumn]
locations: _containers.RepeatedScalarFieldContainer[Engine]
dependencies: TableDependency
indexes: _containers.RepeatedCompositeFieldContainer[Index]
locations: _containers.RepeatedScalarFieldContainer[Engine]
table_name: str
def __init__(self, table_name: _Optional[str] = ..., columns: _Optional[_Iterable[_Union[TableColumn, _Mapping]]] = ..., locations: _Optional[_Iterable[_Union[Engine, str]]] = ..., dependencies: _Optional[_Union[TableDependency, _Mapping]] = ..., indexes: _Optional[_Iterable[_Union[Index, _Mapping]]] = ...) -> None: ...

class TableColumn(_message.Message):
__slots__ = ["data_type", "is_primary", "name"]
__slots__ = ["name", "data_type", "is_primary"]
NAME_FIELD_NUMBER: _ClassVar[int]
DATA_TYPE_FIELD_NUMBER: _ClassVar[int]
IS_PRIMARY_FIELD_NUMBER: _ClassVar[int]
NAME_FIELD_NUMBER: _ClassVar[int]
name: str
data_type: str
is_primary: bool
name: str
def __init__(self, name: _Optional[str] = ..., data_type: _Optional[str] = ..., is_primary: bool = ...) -> None: ...

class TableDependency(_message.Message):
Expand All @@ -76,5 +63,22 @@ class TableDependency(_message.Message):
transform: str
def __init__(self, source_table_names: _Optional[_Iterable[str]] = ..., transform: _Optional[str] = ...) -> None: ...

class Engine(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): # type: ignore
__slots__ = [] # type: ignore
class Provisioning(_message.Message):
__slots__ = ["instance_type", "num_nodes"]
INSTANCE_TYPE_FIELD_NUMBER: _ClassVar[int]
NUM_NODES_FIELD_NUMBER: _ClassVar[int]
instance_type: str
num_nodes: int
def __init__(self, instance_type: _Optional[str] = ..., num_nodes: _Optional[int] = ...) -> None: ...

class RoutingPolicy(_message.Message):
__slots__ = ["policy"]
POLICY_FIELD_NUMBER: _ClassVar[int]
policy: bytes
def __init__(self, policy: _Optional[bytes] = ...) -> None: ...

class Index(_message.Message):
__slots__ = ["column_name"]
COLUMN_NAME_FIELD_NUMBER: _ClassVar[int]
column_name: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, column_name: _Optional[_Iterable[str]] = ...) -> None: ...
Loading
Loading