Add query sequence runner (#386)
* WIP

* Continue work on the ad-hoc runner

* Delete the time scale stuff

* Run until the end with an abort option

* Add exponentially distributed wait times

* Fixes

* Daemon startup fix for new engines

* Initialization fix for missing tables

---------

Co-authored-by: ferdiko <[email protected]>
geoffxy and ferdiko authored Nov 23, 2023
1 parent 7ce4617 commit 31a1d62
Showing 3 changed files with 368 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/brad/daemon/daemon.py
@@ -280,7 +280,10 @@ async def _run_setup(self) -> None:
         for fe in self._front_ends:
             fe.process.start()
 
-        if self._config.routing_policy == RoutingPolicy.ForestTableSelectivity:
+        if (
+            self._config.routing_policy == RoutingPolicy.ForestTableSelectivity
+            or self._config.routing_policy == RoutingPolicy.Default
+        ):
             logger.info("Setting up the cardinality estimator...")
             estimator = await PostgresEstimator.connect(self._schema_name, self._config)
             await estimator.analyze(
5 changes: 5 additions & 0 deletions src/brad/data_stats/postgres_estimator.py
@@ -5,6 +5,7 @@
 from .estimator import Estimator
 
 from brad.blueprint import Blueprint
+from brad.config.engine import Engine
 from brad.config.file import ConfigFile
 from brad.config.strings import (
     base_table_name_from_source,
@@ -195,6 +196,10 @@ async def _get_table_sizes_direct(self) -> Dict[str, int]:
         table_counts = {}
 
         for table in self._blueprint.tables():
+            locations = self._blueprint.get_table_locations(table.name)
+            if Engine.Aurora not in locations:
+                logger.warning("Not fetching size of %s.", table.name)
+                continue
             query = f"SELECT COUNT(*) FROM {table.name}"
             logger.debug("PostgresEstimator running: %s", query)
             await self._cursor.execute(query)
359 changes: 359 additions & 0 deletions workloads/IMDB_extended/run_query_sequence.py
@@ -0,0 +1,359 @@
import argparse
import multiprocessing as mp
import time
import os
import pathlib
import random
import sys
import signal
import pytz
import logging
from typing import List
from datetime import datetime, timedelta

from workload_utils.connect import connect_to_db
from brad.config.engine import Engine
from brad.grpc_client import BradClientError
from brad.utils.rand_exponential_backoff import RandomizedExponentialBackoff


logger = logging.getLogger(__name__)

STARTUP_FAILED = "startup_failed"


def runner(
    runner_idx: int,
    start_queue: mp.Queue,
    control_semaphore: mp.Semaphore,  # type: ignore
    args,
    queries: List[str],
) -> None:
    # Check args.
    assert args.num_clients > runner_idx

    def noop(_signal, _frame):
        pass

    signal.signal(signal.SIGINT, noop)

    # For printing out results.
    if "COND_OUT" in os.environ:
        # pylint: disable-next=import-error
        import conductor.lib as cond

        out_dir = cond.get_output_path()
    else:
        out_dir = pathlib.Path(".")

    if args.engine is not None:
        engine = Engine.from_str(args.engine)
    else:
        engine = None

    try:
        database = connect_to_db(
            args,
            runner_idx,
            direct_engine=engine,
            # Ensure we disable the result cache if we are running directly on
            # Redshift.
            disable_direct_redshift_result_cache=True,
        )
    except BradClientError as ex:
        print(f"[Seq runner {runner_idx}] Failed to connect to BRAD:", str(ex))
        start_queue.put_nowait(STARTUP_FAILED)
        return

    # Query indexes the runner should execute.
    runner_qidx = [i for i in range(len(queries)) if i % args.num_clients == runner_idx]

    exec_count = 0
    file = open(
        out_dir / "seq_queries_{}.csv".format(runner_idx),
        "w",
        encoding="UTF-8",
    )

    try:
        print(
            "timestamp,query_idx,run_time_s,engine",
            file=file,
            flush=True,
        )

        prng = random.Random(args.seed ^ runner_idx)
        rand_backoff = None

        logger.info(
            "[Ad hoc Runner %d] Queries to run: %s",
            runner_idx,
            queries,
        )

        # Signal that we're ready to start and wait for the controller.
        print(
            f"Seq Runner {runner_idx} is ready to start running.",
            flush=True,
            file=sys.stderr,
        )
        start_queue.put_nowait("")
        control_semaphore.acquire()  # type: ignore

        last_run_time_s = None

        for qidx in runner_qidx:
            # Note that `False` means to not block.
            should_exit_early = control_semaphore.acquire(False)  # type: ignore
            if should_exit_early:
                print(
                    f"Seq Runner {runner_idx} is exiting early.",
                    file=sys.stderr,
                    flush=True,
                )
                break

            # Wait for some time before issuing, if requested.
            if args.avg_gap_s is not None:
                wait_for_s = prng.gauss(args.avg_gap_s, args.avg_gap_std_s)
            elif args.arrivals_per_s is not None:
                wait_for_s = prng.expovariate(args.arrivals_per_s)
                if last_run_time_s is not None:
                    wait_for_s -= last_run_time_s
            else:
                wait_for_s = 0.0

            if wait_for_s > 0.0:
                time.sleep(wait_for_s)

            logger.debug("Executing qidx: %d", qidx)
            query = queries[qidx]

            try:
                # Get time stamp for logging.
                now = datetime.now().astimezone(pytz.utc)

                # Execute query.
                start = time.time()
                _, engine = database.execute_sync_with_engine(query)
                end = time.time()

                # Log.
                engine_log = "xxx"
                if engine is not None:
                    engine_log = engine.value
                run_time_s = end - start
                print(
                    "{},{},{},{}".format(
                        now,
                        qidx,
                        run_time_s,
                        engine_log,
                    ),
                    file=file,
                    flush=True,
                )
                last_run_time_s = run_time_s

                if exec_count % 20 == 0:
                    # To avoid data loss if this script crashes.
                    os.fsync(file.fileno())

                exec_count += 1
                if rand_backoff is not None:
                    print(
                        f"[Seq Runner {runner_idx}] Continued after transient errors.",
                        flush=True,
                        file=sys.stderr,
                    )
                    rand_backoff = None

            except BradClientError as ex:
                if ex.is_transient():
                    # This is too verbose during a transition.
                    # print(
                    #     "Transient query error:",
                    #     ex.message(),
                    #     flush=True,
                    #     file=sys.stderr,
                    # )

                    if rand_backoff is None:
                        rand_backoff = RandomizedExponentialBackoff(
                            max_retries=100,
                            base_delay_s=1.0,
                            max_delay_s=timedelta(minutes=1).total_seconds(),
                        )
                        print(
                            f"[Seq Runner {runner_idx}] Backing off due to transient errors.",
                            flush=True,
                            file=sys.stderr,
                        )

                    # Delay retrying in the case of a transient error (this
                    # happens during blueprint transitions).
                    wait_s = rand_backoff.wait_time_s()
                    if wait_s is None:
                        print(
                            f"[Seq Runner {runner_idx}] Aborting benchmark. Too many transient errors.",
                            flush=True,
                            file=sys.stderr,
                        )
                        break
                    time.sleep(wait_s)

                else:
                    print(
                        "Unexpected seq query error:",
                        ex.message(),
                        flush=True,
                        file=sys.stderr,
                    )

    finally:
        os.fsync(file.fileno())
        file.close()
        database.close_sync()
        print(f"Seq runner {runner_idx} has exited.", flush=True, file=sys.stderr)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--brad-host", type=str, default="localhost")
    parser.add_argument("--brad-port", type=int, default=6583)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-front-ends", type=int, default=1)
    parser.add_argument(
        "--cstr-var",
        type=str,
        help="Set to connect via ODBC instead of the BRAD client (for use with other baselines).",
    )
    parser.add_argument(
        "--query-sequence-file",
        type=str,
        required=True,
        help="Path to a query sequence.",
    )
    # Use these to slice the file, if needed.
    parser.add_argument("--query-sequence-offset", type=int, default=0)
    parser.add_argument("--query-sequence-length", type=int)
    parser.add_argument("--num-clients", type=int, default=1)
    parser.add_argument("--client-offset", type=int, default=0)
    parser.add_argument("--avg-gap-s", type=float)
    parser.add_argument("--avg-gap-std-s", type=float, default=0.5)
    # Set this to use an exponential distribution for the gap times.
    # This value is per-client.
    parser.add_argument("--arrivals-per-s", type=float)
    parser.add_argument(
        "--brad-direct",
        action="store_true",
        help="Set to connect directly to Aurora via BRAD's config.",
    )
    parser.add_argument(
        "--config-file",
        type=str,
        help="The BRAD config file (if --brad-direct is used).",
    )
    parser.add_argument(
        "--schema-name",
        type=str,
        help="The schema name to use, if connecting directly.",
    )
    parser.add_argument(
        "--engine", type=str, help="The engine to use, if connecting directly."
    )
    args = parser.parse_args()

    with open(args.query_sequence_file, "r", encoding="UTF-8") as file:
        query_seq = [line.strip() for line in file]

    # Truncate according to the requested offset and sequence length.
    offset = args.query_sequence_offset
    seq_len = args.query_sequence_length
    query_seq = query_seq[offset:]
    if seq_len is not None:
        query_seq = query_seq[:seq_len]

    # Our control protocol is as follows.
    # - Runner processes write to their `start_queue` when they have finished
    #   setting up and are ready to start running. They then wait on the
    #   control semaphore.
    # - The control process blocks and waits on each `start_queue` to ensure
    #   the runners can start together (if needed).
    # - The control process signals the control semaphore twice: once to tell
    #   a runner to start, and once to tell it to stop.
    # - If there is an error, a runner is free to exit as long as it has
    #   written to `start_queue`.
    mgr = mp.Manager()
    start_queue = [mgr.Queue() for _ in range(args.num_clients)]
    # N.B. `value = 0` since we use this for synchronization, not mutual exclusion.
    # pylint: disable-next=no-member
    control_semaphore = [mgr.Semaphore(value=0) for _ in range(args.num_clients)]

    processes = []
    for idx in range(args.num_clients):
        p = mp.Process(
            target=runner,
            args=(
                idx,
                start_queue[idx],
                control_semaphore[idx],
                args,
                query_seq,
            ),
        )
        p.start()
        processes.append(p)

    print("Seq: Waiting for startup...", flush=True)
    one_startup_failed = False
    for i in range(args.num_clients):
        msg = start_queue[i].get()
        if msg == STARTUP_FAILED:
            one_startup_failed = True

    if one_startup_failed:
        print("At least one ad-hoc runner failed to start up. Aborting the experiment.")
        for i in range(args.num_clients):
            # Ideally we should be able to release twice atomically.
            control_semaphore[i].release()
            control_semaphore[i].release()
        for p in processes:
            p.join()
        print("Abort complete.")
        return

    print(
        "Telling all {} ad-hoc clients to start.".format(args.num_clients), flush=True
    )
    for i in range(args.num_clients):
        control_semaphore[i].release()

    # Run until completion, or until requested to stop early.
    print(
        "Seq queries running until completion. Hit Ctrl-C to stop early.",
        flush=True,
        file=sys.stderr,
    )

    def signal_handler(_signal, _frame):
        for i in range(args.num_clients):
            control_semaphore[i].release()
            control_semaphore[i].release()

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    print("Waiting for the seq clients to complete...", flush=True, file=sys.stderr)
    for p in processes:
        p.join()

    print("Done query sequence!", flush=True, file=sys.stderr)


if __name__ == "__main__":
    # On Unix platforms, the default way to start a process is by forking, which
    # is not ideal (we do not want to duplicate this process' file
    # descriptors!).
    mp.set_start_method("spawn")
    main()
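
For reference, a typical invocation of the new runner might look like the following. The file path and rate shown here are hypothetical; the flags themselves are defined in the argparse setup above.

    python workloads/IMDB_extended/run_query_sequence.py \
        --query-sequence-file sequence.sql \
        --num-clients 4 \
        --arrivals-per-s 0.5 \
        --brad-host localhost \
        --brad-port 6583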

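The start-queue/semaphore control protocol described in the comment in `main` can be distilled into a minimal standalone sketch. This is illustrative only; `worker` and the single-client setup below are not part of the commit.

    import multiprocessing as mp


    def worker(start_queue, control_sem) -> None:
        # Signal readiness, then block until the controller says "start".
        start_queue.put_nowait("")
        control_sem.acquire()
        # ... run work, periodically calling control_sem.acquire(False)
        # (non-blocking) to detect the second release, which means "stop" ...


    if __name__ == "__main__":
        mp.set_start_method("spawn")
        mgr = mp.Manager()
        q = mgr.Queue()
        # value=0: the semaphore is used for synchronization, not mutual exclusion.
        sem = mgr.Semaphore(value=0)
        p = mp.Process(target=worker, args=(q, sem))
        p.start()
        q.get()        # wait for the worker to report readiness
        sem.release()  # first release: start
        sem.release()  # second release: stop
        p.join()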