-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create BRAD client classes for various connections, run and visualize benchmark experiments (#510)
Co-authored-by: Sophie Zhang <[email protected]>
- Loading branch information
Showing
4 changed files
with
254 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import argparse | ||
import ast | ||
from collections import defaultdict | ||
from typing import Any | ||
import numpy as np | ||
import pandas as pd | ||
import time | ||
|
||
from brad.grpc_client import BradGrpcClient | ||
from brad.flight_sql_client_odbc import BradFlightSqlClientOdbc | ||
from brad.sqlite_client import BradSqliteClient | ||
|
||
|
||
def adjusted_data(data: list[float], drop_count: int = 1) -> list[float]:
    """Return `data` sorted ascending with the extremes trimmed off.

    The `drop_count` smallest and the `drop_count` largest values are removed.

    Args:
        data: Samples to trim. The input list itself is not modified.
        drop_count: Number of values to drop from each end of the sorted
            data. Must be non-negative.

    Returns:
        A new sorted list. It is empty when `data` holds no more than
        `2 * drop_count` values.

    Raises:
        ValueError: If `drop_count` is negative.
    """
    if drop_count < 0:
        # A negative count would flip the slice bounds and silently return an
        # arbitrary subset instead of a trimmed one.
        raise ValueError(f"drop_count must be non-negative, got {drop_count}")
    ordered = sorted(data)
    return ordered[drop_count : len(ordered) - drop_count]
|
||
|
||
def run_client(
    client: BradGrpcClient | BradFlightSqlClientOdbc | BradSqliteClient,
    trials: int,
    repetitions: int,
    query: str,
) -> tuple[np.floating[Any], np.floating[Any]]:
    """Measure the average per-query latency of `query` on `client`.

    Each trial issues `query` `repetitions` times back to back and records
    the elapsed time divided by `repetitions`. The per-trial averages are then
    passed through `adjusted_data` before the statistics are computed.

    Args:
        client: An already connected client. gRPC clients are driven through
            `run_query_json`; every other client through `run_query`. The
            value returned by each call is discarded.
        trials: Number of timed trials.
        repetitions: Number of queries issued per trial.
        query: The SQL text to execute.

    Returns:
        `(mean, standard deviation)` of the trimmed per-trial average
        latencies, in seconds. With too few trials nothing survives the
        trimming and NumPy reports NaN for both values.

    Raises:
        ZeroDivisionError: If `repetitions` is 0.
    """
    # Resolve the client-specific entry point once so the timed loop does not
    # pay for an `isinstance` check on every repetition.
    if isinstance(client, BradGrpcClient):
        execute = client.run_query_json
    else:
        execute = client.run_query

    average_latencies = []
    for _ in range(trials):
        # `perf_counter` is monotonic and high resolution; `time.time` follows
        # the wall clock and can jump when the system clock is adjusted.
        start = time.perf_counter()
        for _ in range(repetitions):
            execute(query)
        end = time.perf_counter()

        average_latencies.append((end - start) / repetitions)

    adjusted_average_latencies = adjusted_data(average_latencies)
    return np.mean(adjusted_average_latencies), np.std(adjusted_average_latencies)
|
||
|
||
def build_dataframe() -> pd.DataFrame:
    """Benchmark a fixed set of queries over every connection type.

    Command-line flags `--repetitions` and `--trials` control how many
    queries are issued per trial and how many trials are run.

    Returns:
        A DataFrame with one row per connection type. The "Connection Type"
        column holds the label; every other column is named after a query and
        holds the string form of a `(mean latency, standard deviation)` tuple
        in seconds.

    Raises:
        SystemExit: If the command-line arguments are invalid.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--repetitions", type=int, default=1000)
    parser.add_argument("--trials", type=int, default=10)
    # NOTE(review): `--query` is accepted but not consulted below; the fixed
    # `queries` list is what gets benchmarked.
    parser.add_argument("--query", type=str, default="SELECT 1")
    args = parser.parse_args()

    if args.repetitions <= 0:
        parser.error("--repetitions must be a positive integer")
    if args.trials < 3:
        # `run_client` trims the fastest and the slowest trial, so fewer than
        # three trials would leave nothing to average.
        parser.error("--trials must be at least 3")

    queries = [
        "SELECT 1;",
        "SELECT * FROM person_info;",
        "SELECT * FROM person_info WHERE id = 123456;",
    ]

    # Construction is deferred behind callables so that each client is only
    # created when its turn comes.
    connections = [
        ("gRPC", lambda: BradGrpcClient(host="localhost", port=6583)),
        (
            "Flight SQL ODBC",
            lambda: BradFlightSqlClientOdbc(host="localhost", port=31337),
        ),
        (
            "SQLite",
            lambda: BradSqliteClient(database="/tmp/sophiez_brad_stub_db.sqlite"),
        ),
    ]

    data = defaultdict(list)
    for label, make_client in connections:
        data["Connection Type"].append(label)
        with make_client() as client:
            for query in queries:
                lat_avg, lat_std_dev = run_client(
                    client, args.trials, args.repetitions, query
                )
                # Store builtin floats: NumPy 2 renders its scalars as
                # `np.float64(...)` inside a tuple's string form, which
                # `ast.literal_eval` cannot read back when plotting.
                data[query].append(str((float(lat_avg), float(lat_std_dev))))

    return pd.DataFrame.from_dict(data)
|
||
|
||
def print_to_csv(dataframe: pd.DataFrame, filename: str) -> None:
    """Write `dataframe` to `filename` as CSV, leaving out the index column."""
    dataframe.to_csv(filename, index=False)
|
||
|
||
def plot_from_csv(
    filename: str, output_path: str = "measurement_comparisons_plot.png"
) -> None:
    """Draw a grouped bar chart of the latencies stored in a results CSV.

    Args:
        filename: CSV file to read. It must contain a "Connection Type"
            column; the first column is skipped and every remaining column is
            treated as a query whose cells are the string form of a
            `(mean, standard deviation)` tuple.
        output_path: Where the rendered figure is saved.

    Raises:
        ValueError: If a cell is not a valid Python literal.
        SyntaxError: If a cell cannot be parsed as a Python expression.
    """
    dataframe = pd.read_csv(filename)

    lat_avgs = {}
    lat_std_devs = {}
    for query in dataframe.columns[1:]:
        # Parse each cell once and split it into its two statistics.
        statistics = [ast.literal_eval(cell) for cell in dataframe[query].tolist()]
        lat_avgs[query] = [statistic[0] for statistic in statistics]
        lat_std_devs[query] = [statistic[1] for statistic in statistics]

    # One [lower, upper] pair of error lists per query column; the same
    # deviations are used in both directions.
    std_dev_yerr = [[std_devs, std_devs] for std_devs in lat_std_devs.values()]

    df = pd.DataFrame(lat_avgs, index=dataframe["Connection Type"].tolist())
    ax = df.plot.bar(yerr=std_dev_yerr, rot=0)
    ax.legend(loc="upper center", bbox_to_anchor=(0.5, 1.25))
    ax.set_xlabel("Connection Type")
    ax.set_ylabel("Average Latency (s)")

    fig = ax.get_figure()
    fig.savefig(output_path, bbox_inches="tight")
|
||
|
||
def main() -> None:
    """Run the benchmarks, save the results as CSV, and plot them from it."""
    dataframe = build_dataframe()
    csv_filename = "measurement_comparisons.csv"
    print_to_csv(dataframe, csv_filename)
    # Plot from the file rather than the in-memory frame so the chart always
    # reflects exactly what was written to disk.
    plot_from_csv(csv_filename)


if __name__ == "__main__":
    main()
28 changes: 28 additions & 0 deletions
28
experiments/13-connect-overhead/measure_noop_flight_sql_odbc.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import argparse | ||
import time | ||
|
||
from brad.flight_sql_client_odbc import BradFlightSqlClientOdbc | ||
|
||
|
||
def main() -> None:
    """Time repeated `BRAD_NOOP` queries over the Flight SQL ODBC client.

    Reads `--repetitions`, `--host` and `--port` from the command line, issues
    the query `--repetitions` times over one connection, and prints a CSV
    header followed by one line with the repetition count, the total elapsed
    time in seconds, and the average latency in seconds.

    Raises:
        SystemExit: If `--repetitions` is not a positive integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--repetitions", type=int, default=1000)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=31337)
    args = parser.parse_args()

    if args.repetitions <= 0:
        # Reject this before connecting; the average below would otherwise
        # divide by zero.
        parser.error("--repetitions must be a positive integer")

    with BradFlightSqlClientOdbc(args.host, args.port) as client:
        # `perf_counter` is monotonic and high resolution; `time.time` follows
        # the wall clock and can jump when the system clock is adjusted.
        start = time.perf_counter()
        for _ in range(args.repetitions):
            client.run_query("BRAD_NOOP")
        end = time.perf_counter()

    total = end - start
    avg_lat = total / args.repetitions

    print("reps,total_time_s,avg_lat_s")
    print(f"{args.repetitions},{total},{avg_lat}")


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import pyodbc | ||
from typing import Generator, Optional, Self, Tuple, List, Any | ||
|
||
|
||
class BradFlightSqlClientOdbc: | ||
""" | ||
A client that communicates with BRAD via Arrow Flight SQL ODBC driver. | ||
Usage: | ||
``` | ||
with BradFlightSqlClientOdbc(host, port) as client: | ||
for row in client.run_query("SELECT 1"): | ||
print(row) | ||
``` | ||
""" | ||
|
||
RowList = List[Tuple[Any, ...]] | ||
|
||
def __init__(self, host="localhost", port=31337) -> None: | ||
self._host = host | ||
self._port = port | ||
self._connection: Optional[pyodbc.Connection] = None | ||
self._cursor: Optional[pyodbc.Cursor] = None | ||
|
||
def __enter__(self) -> Self: | ||
self.connect() | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_value, traceback) -> None: | ||
self.close() | ||
|
||
def connect(self) -> None: | ||
self._connection = pyodbc.connect( | ||
"DRIVER={Arrow Flight SQL ODBC Driver};USEENCRYPTION=false;" | ||
+ f"HOST={self._host};" | ||
+ f"PORT={self._port}", | ||
autocommit=True, | ||
) | ||
self._cursor = self._connection.cursor() | ||
|
||
def close(self) -> None: | ||
assert self._cursor | ||
assert self._connection | ||
self._cursor.close() | ||
self._connection.close() | ||
|
||
def run_query_generator(self, query: str) -> Generator[Tuple[Any, ...], None, None]: | ||
assert self._cursor | ||
for row in self._cursor.execute(query): | ||
yield row | ||
|
||
def run_query(self, query: str) -> RowList: | ||
assert self._cursor | ||
return self._cursor.execute(query) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import sqlite3 | ||
from typing import Generator, Optional, Self, Tuple, Any | ||
|
||
|
||
class BradSqliteClient: | ||
""" | ||
A client that communicates with BRAD directly against SQLite database. | ||
Usage: | ||
``` | ||
with BradSqliteClient(database) as client: | ||
for row in client.run_query(session_id, "SELECT 1"): | ||
print(row) | ||
``` | ||
""" | ||
|
||
def __init__(self, database: str) -> None: | ||
self._database = database | ||
self._connection: Optional[sqlite3.Connection] = None | ||
self._cursor: Optional[sqlite3.Cursor] = None | ||
|
||
def __enter__(self) -> Self: | ||
self.connect() | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_value, traceback) -> None: | ||
self.close() | ||
|
||
def connect(self) -> None: | ||
self._connection = sqlite3.connect(self._database) | ||
self._cursor = self._connection.cursor() | ||
|
||
def close(self) -> None: | ||
assert self._cursor | ||
assert self._connection | ||
self._cursor.close() | ||
self._connection.close() | ||
|
||
def run_query_generator(self, query: str) -> Generator[Tuple[Any, ...], None, None]: | ||
assert self._cursor | ||
for row in self._cursor.execute(query): | ||
yield row | ||
|
||
def run_query(self, query: str) -> sqlite3.Cursor: | ||
assert self._cursor | ||
return self._cursor.execute(query) |