Commit 464eca8

Server-side compute: stubs and flags

1 parent 5520961 commit 464eca8

8 files changed: +444 −24 lines changed

requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,6 @@
 Flask==1.1.2
 SQLAlchemy==1.3.22
+more_itertools==8.4.0
 mysqlclient==2.0.2
 python-dotenv==0.15.0
 orjson==3.4.7
@@ -8,3 +9,4 @@ scipy==1.6.2
 tenacity==7.0.0
 newrelic
 epiweeks==2.1.2
+delphi_utils

src/server/_params.py

Lines changed: 12 additions & 3 deletions
@@ -1,11 +1,11 @@
+from dataclasses import dataclass
+from itertools import groupby, chain
 from math import inf
 import re
-from dataclasses import dataclass
 from typing import List, Optional, Sequence, Tuple, Union

 from flask import request

-
 from ._exceptions import ValidationFailedException
 from .utils import days_in_range, weeks_in_range, guess_time_value_is_day

@@ -93,8 +93,17 @@ def count(self) -> float:
         return len(self.signal)


+def _combine_source_signal_pairs(source_signal_pairs: List[SourceSignalPair]) -> List[SourceSignalPair]:
+    """Combine SourceSignalPairs with the same source into a single SourceSignalPair object.
+
+    Example:
+        [SourceSignalPair("src", ["sig1", "sig2"]), SourceSignalPair("src", ["sig2", "sig3"])] will be merged
+        into [SourceSignalPair("src", ["sig1", "sig2", "sig3"])].
+    """
+    return [SourceSignalPair("src", ["sig1", "sig2", "sig3"])]
+
 def parse_source_signal_arg(key: str = "signal") -> List[SourceSignalPair]:
-    return [SourceSignalPair(source, signals) for [source, signals] in _parse_common_multi_arg(key)]
+    return _combine_source_signal_pairs([SourceSignalPair(source, signals) for [source, signals] in _parse_common_multi_arg(key)])


 def parse_single_source_signal_arg(key: str) -> SourceSignalPair:
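The merge itself is left as a stub in this commit. A minimal non-stub sketch of the behavior the docstring describes, assuming `signal` may also be `True` (meaning "all signals") and that first-seen order should be preserved; the stand-in dataclass mirrors the one in `_params.py`:

```python
from dataclasses import dataclass
from itertools import chain, groupby
from typing import List, Union

@dataclass
class SourceSignalPair:  # minimal stand-in for the dataclass in _params.py
    source: str
    signal: Union[bool, List[str]]

def _combine_source_signal_pairs(pairs: List[SourceSignalPair]) -> List[SourceSignalPair]:
    """Merge pairs that share a source, deduplicating signals in first-seen order."""
    combined = []
    for source, group in groupby(sorted(pairs, key=lambda p: p.source), key=lambda p: p.source):
        group = list(group)
        if any(p.signal is True for p in group):
            # An "all signals" request absorbs any explicit lists for that source.
            combined.append(SourceSignalPair(source, True))
            continue
        signals = list(dict.fromkeys(chain.from_iterable(p.signal for p in group)))
        combined.append(SourceSignalPair(source, signals))
    return combined

assert _combine_source_signal_pairs(
    [SourceSignalPair("src", ["sig1", "sig2"]), SourceSignalPair("src", ["sig2", "sig3"])]
) == [SourceSignalPair("src", ["sig1", "sig2", "sig3"])]
```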

src/server/endpoints/covidcast.py

Lines changed: 60 additions & 14 deletions
@@ -1,10 +1,11 @@
 from typing import List, Optional, Union, Tuple, Dict, Any
 from itertools import groupby
 from datetime import date, timedelta
+from bisect import bisect_right
 from epiweeks import Week
 from flask import Blueprint, request
 from flask.json import loads, jsonify
-from bisect import bisect_right
+from more_itertools import peekable
 from sqlalchemy import text
 from pandas import read_csv, to_datetime

@@ -36,11 +37,13 @@
 from .._pandas import as_pandas, print_pandas
 from .covidcast_utils import compute_trend, compute_trends, compute_correlations, compute_trend_value, CovidcastMetaEntry
 from ..utils import shift_time_value, date_to_time_value, time_value_to_iso, time_value_to_date, shift_week_value, week_value_to_week, guess_time_value_is_day, week_to_time_value
-from .covidcast_utils.model import TimeType, count_signal_time_types, data_sources, create_source_signal_alias_mapper
+from .covidcast_utils.model import TimeType, count_signal_time_types, data_sources, create_source_signal_alias_mapper, get_basename_signals, get_pad_length, pad_time_pairs, pad_time_window
+from .covidcast_utils.smooth_diff import PadFillValue, SmootherKernelValue

 # first argument is the endpoint name
 bp = Blueprint("covidcast", __name__)
 alias = None
+JIT_COMPUTE = True


 def parse_source_signal_pairs() -> List[SourceSignalPair]:
@@ -123,36 +126,53 @@ def guess_index_to_use(time: List[TimePair], geo: List[GeoPair], issues: Optiona
     return None


+# TODO: Write an actual smoother arg parser.
+def parse_transform_args():
+    return {"smoother_kernel": SmootherKernelValue.average, "smoother_window_length": 7, "pad_fill_value": None, "nans_fill_value": None}
+
+
+def parse_jit_bypass():
+    jit_bypass = request.values.get("jit_bypass")
+    if jit_bypass is not None:
+        return bool(jit_bypass)
+    else:
+        return False
+
+
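One caveat with `parse_jit_bypass` as written: `bool()` on a query-string value makes any non-empty string truthy, so `?jit_bypass=false` would still trigger the bypass. A stricter parse (a sketch, not part of this commit) might be:

```python
from flask import request

FALSY = {"", "0", "false", "no", "off"}

def parse_jit_bypass() -> bool:
    # Treat common "false" spellings as false instead of bool("false") == True.
    value = request.values.get("jit_bypass", "")
    return value.strip().lower() not in FALSY
```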
 @bp.route("/", methods=("GET", "POST"))
 def handle():
     source_signal_pairs = parse_source_signal_pairs()
     source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs)
     time_pairs = parse_time_pairs()
     geo_pairs = parse_geo_pairs()
+    transform_args = parse_transform_args()
+    jit_bypass = parse_jit_bypass()

     as_of = extract_date("as_of")
     issues = extract_dates("issues")
     lag = extract_integer("lag")
+    is_time_type_week = any(time_pair.time_type == "week" for time_pair in time_pairs)
+    is_time_value_true = any(isinstance(time_pair.time_values, bool) for time_pair in time_pairs)
+    use_server_side_compute = not any((issues, lag, is_time_type_week, is_time_value_true)) and JIT_COMPUTE and not jit_bypass
+    if use_server_side_compute:
+        pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length"))
+        source_signal_pairs, row_transform_generator = get_basename_signals(source_signal_pairs)
+        time_pairs = pad_time_pairs(time_pairs, pad_length)

     # build query
     q = QueryBuilder("covidcast", "t")

-    fields_string = ["geo_value", "signal"]
+    fields_string = ["geo_type", "geo_value", "source", "signal", "time_type"]
     fields_int = ["time_value", "direction", "issue", "lag", "missing_value", "missing_stderr", "missing_sample_size"]
     fields_float = ["value", "stderr", "sample_size"]
     is_compatibility = is_compatibility_mode()
-    if is_compatibility:
-        q.set_order("signal", "time_value", "geo_value", "issue")
-    else:
-        # transfer also the new detail columns
-        fields_string.extend(["source", "geo_type", "time_type"])
-        q.set_order("source", "signal", "time_type", "time_value", "geo_type", "geo_value", "issue")
+
+    q.set_order("geo_type", "geo_value", "source", "signal", "time_type", "time_value", "issue")
     q.set_fields(fields_string, fields_int, fields_float)

     # basic query info
     # data type of each field
     # build the source, signal, time, and location (type and id) filters
-
     q.where_source_signal_pairs("source", "signal", source_signal_pairs)
     q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
     q.where_time_pairs("time_type", "time_value", time_pairs)
@@ -161,14 +181,40 @@ def handle():

     _handle_lag_issues_as_of(q, issues, lag, as_of)

-    def transform_row(row, proxy):
+    p = create_printer()
+
+    def alias_row(row):
+        if is_compatibility:
+            # old api returned fewer fields
+            remove_fields = ["geo_type", "source", "time_type"]
+            for field in remove_fields:
+                if field in row:
+                    del row[field]
         if is_compatibility or not alias_mapper or "source" not in row:
             return row
-        row["source"] = alias_mapper(row["source"], proxy["signal"])
+        row["source"] = alias_mapper(row["source"], row["signal"])
         return row

-    # send query
-    return execute_query(str(q), q.params, fields_string, fields_int, fields_float, transform=transform_row)
+    if use_server_side_compute:
+        def gen_transform(rows):
+            parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
+            transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=time_pairs, transform_args=transform_args)
+            for row in transformed_rows:
+                yield alias_row(row)
+    else:
+        def gen_transform(rows):
+            parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
+            for row in parsed_rows:
+                yield alias_row(row)
+
+    # execute first query
+    try:
+        r = run_query(p, (str(q), q.params))
+    except Exception as e:
+        raise DatabaseErrorException(str(e))
+
+    # now use a generator for sending the rows and execute all the other queries
+    return p(filter_fields(gen_transform(r)))


 def _verify_argument_time_type_matches(is_day_argument: bool, count_daily_signal: int, count_weekly_signal: int) -> None:
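As wired above, JIT computation is on by default (`JIT_COMPUTE = True`) and is skipped when the query involves `issues`, `lag`, weekly time types, or wildcard time values, or when the caller bypasses it explicitly. An illustrative client call (hypothetical host, path, and signal names; under the parsing above, `jit_bypass` only needs to be non-empty):

```python
import requests

# Illustrative only: the mount point and signal names depend on deployment.
resp = requests.get(
    "https://example.com/epidata/covidcast/",
    params={
        "signal": "src:sig_smoothed",
        "time": "day:20210401-20210413",
        "geo": "county:*",
        "jit_bypass": "true",  # any non-empty value disables server-side compute
    },
)
rows = resp.json()
```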

src/server/endpoints/covidcast_utils/model.py

Lines changed: 171 additions & 5 deletions
@@ -1,12 +1,24 @@
 from dataclasses import asdict, dataclass, field
-from typing import Callable, Optional, Dict, List, Set, Tuple
 from enum import Enum
+from functools import partial
+from itertools import groupby, repeat, tee
+from typing import Callable, Generator, Optional, Dict, List, Set, Tuple, Iterable, Union
 from pathlib import Path
 import re
+from more_itertools import interleave_longest, peekable
 import pandas as pd
 import numpy as np

-from ..._params import SourceSignalPair
+from delphi_utils.nancodes import Nans
+from ..._params import SourceSignalPair, TimePair
+from .smooth_diff import generate_smooth_rows, generate_row_diffs
+from ...utils import time_value_range, shift_time_value
+
+
+IDENTITY: Callable = lambda rows, **kwargs: rows
+# Placeholder wiring so the SMOOTH/DIFF names referenced below resolve;
+# the eventual mapping to smooth_diff's generators may differ.
+DIFF: Callable = generate_row_diffs
+SMOOTH: Callable = generate_smooth_rows


 class HighValuesAre(str, Enum):
@@ -21,6 +30,7 @@ class SignalFormat(str, Enum):
     fraction = "fraction"
     raw_count = "raw_count"
     raw = "raw"
+    count = "count"


 class SignalCategory(str, Enum):
@@ -201,7 +211,7 @@ def _load_data_sources():


 data_sources, data_sources_df = _load_data_sources()
-data_source_by_id = {d.source: d for d in data_sources}
+data_sources_by_id = {d.source: d for d in data_sources}


 def _load_data_signals(sources: List[DataSource]):
@@ -230,7 +240,7 @@ def _load_data_signals(sources: List[DataSource]):
 data_signals_by_key = {d.key: d for d in data_signals}
 # also add the resolved signal version to the signal lookup
 for d in data_signals:
-    source = data_source_by_id.get(d.source)
+    source = data_sources_by_id.get(d.source)
     if source and source.uses_db_alias:
         data_signals_by_key[(source.db_source, d.signal)] = d

@@ -265,7 +275,7 @@ def create_source_signal_alias_mapper(source_signals: List[SourceSignalPair]) ->
     alias_to_data_sources: Dict[str, List[DataSource]] = {}
     transformed_pairs: List[SourceSignalPair] = []
     for pair in source_signals:
-        source = data_source_by_id.get(pair.source)
+        source = data_sources_by_id.get(pair.source)
         if not source or not source.uses_db_alias:
             transformed_pairs.append(pair)
             continue
@@ -299,3 +309,159 @@ def map_row(source: str, signal: str) -> str:
         return signal_source.source

     return transformed_pairs, map_row
+
+
+def _resolve_all_signals(source_signals: Union[SourceSignalPair, List[SourceSignalPair]], data_sources_by_id: Dict[str, DataSource]) -> Union[SourceSignalPair, List[SourceSignalPair]]:
+    """Expand a request for all signals to an explicit list of signal names.
+
+    Example: SourceSignalPair("jhu-csse", signal=True) would return SourceSignalPair("jhu-csse", [<list of all JHU signals>]).
+    """
+    return [SourceSignalPair("src", ["sig1", "sig2"])]
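This is stubbed. A sketch of the expansion, in module context, assuming each DataSource lists its DataSignals on a `signals` attribute (which `_load_data_signals(sources)` suggests but the diff does not confirm):

```python
def _resolve_all_signals(source_signals, data_sources_by_id):
    # Single pair: replace signal=True with the source's full signal list.
    if isinstance(source_signals, SourceSignalPair):
        if source_signals.signal is True:
            source = data_sources_by_id.get(source_signals.source)
            if source is None:
                raise ValueError(f"Unrecognized source: {source_signals.source}")
            return SourceSignalPair(source.source, [s.signal for s in source.signals])
        return source_signals
    # List of pairs: resolve each one.
    return [_resolve_all_signals(pair, data_sources_by_id) for pair in source_signals]
```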
+
+
+def _reindex_iterable(iterable: Iterable[Dict], time_pairs: List[TimePair], fill_value: Optional[int] = None) -> Iterable:
+    """Produce an iterable that fills in gaps in the time window of another iterable.
+
+    Used to produce an iterable with a contiguous time index for time series operations.
+
+    We iterate over the contiguous range of days made from time_pairs. If `iterable`, which is assumed to be sorted by its "time_value" key,
+    is missing a time_value in the range, a dummy row entry is yielded with the correct date and the value fields set appropriately.
+    """
+    if time_pairs is None:
+        yield from iterable
+        return
+
+    _iterable = peekable(iterable)
+    # Use the first row as a template for dummy rows on missing days.
+    template_row = _iterable.peek()
+    day_range_index = get_day_range(time_pairs)
+    for day in day_range_index.time_values:
+        if _iterable and _iterable.peek()["time_value"] == day:
+            yield next(_iterable)
+        else:
+            yield {**template_row, "time_value": day, "value": fill_value, "stderr": None, "sample_size": None}
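A quick illustration of the reindexing contract. With the stubbed `get_day_range` below, which always returns the days 20210407 through 20210410, the two gap days come back as dummy rows:

```python
rows = [
    {"time_value": 20210407, "value": 1.0, "stderr": None, "sample_size": None},
    {"time_value": 20210409, "value": 3.0, "stderr": None, "sample_size": None},
]
# 20210408 and 20210410 are missing, so they come back with value=fill_value (None here).
filled = list(_reindex_iterable(iter(rows), [TimePair("day", [(20210407, 20210410)])]))
[r["time_value"] for r in filled]  # -> [20210407, 20210408, 20210409, 20210410]
```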
+
+
+def _get_base_signal_transform(signal: Union[DataSignal, Tuple[str, str]], data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key) -> Callable:
+    """Given a DataSignal, return the transformation that needs to be applied to its base signal to derive the signal."""
+    return IDENTITY
+
+
+def get_transform_types(source_signal_pairs: List[SourceSignalPair], data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key) -> Set[Callable]:
+    """Return a collection of the unique transforms required for transforming a given source-signal pair list.
+
+    Example:
+        SourceSignalPair("src", ["sig", "sig_smoothed", "sig_diff"]) would return {IDENTITY, SMOOTH, DIFF}.
+
+    Used to pad the user DB query with extra days.
+    """
+    return set([IDENTITY])
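A sketch of the non-stub version, built only from the two helpers it already references (resolve wildcard requests, then look up each signal's base transform):

```python
def get_transform_types(source_signal_pairs, data_sources_by_id=data_sources_by_id, data_signals_by_key=data_signals_by_key):
    pairs = _resolve_all_signals(source_signal_pairs, data_sources_by_id)
    return {
        _get_base_signal_transform((pair.source, signal), data_signals_by_key)
        for pair in pairs
        for signal in pair.signal
    }
```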
+
+
+def get_pad_length(source_signal_pairs: List[SourceSignalPair], smoother_window_length: int, data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key):
+    """Return the size of the extra date padding needed, depending on the transformations the source-signal pair list requires.
+
+    Example:
+        If smoothing is required, we fetch an extra 6 days. If both diffing and smoothing are required on the same signal, then we fetch 7 extra days.
+
+    Used to pad the user DB query with extra days.
+    """
+    transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=data_sources_by_id, data_signals_by_key=data_signals_by_key)
+    if SMOOTH in transform_types:
+        return 7
+    else:
+        return 0
+
+
+def pad_time_pairs(time_pairs: List[TimePair], pad_length: int) -> List[TimePair]:
+    """Pad a list of TimePairs with another TimePair that extends the smallest time value by the pad_length, if needed.
+
+    Example:
+        [TimePair("day", [20210407])] with pad_length 6 would return [TimePair("day", [20210407]), TimePair("day", [(20210401, 20210407)])].
+
+    Used to pad the user DB query with extra days.
+    """
+    return [TimePair("day", [(20210401, 20210407)])]
+
+
+def pad_time_window(time_window: Tuple[int, int], pad_length: int) -> Tuple[int, int]:
+    """Extend a time window on the left by pad_length.
+
+    Example:
+        (20210407, 20210413) with pad_length 6 would return (20210401, 20210413).
+
+    Used to pad the user DB query with extra days.
+    """
+    return (20210401, 20210413)
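Non-stub sketches of the two padding helpers, assuming time values are YYYYMMDD ints, `TimePair.time_values` holds ints or `(start, end)` tuples, and `shift_time_value` (imported above from `...utils`) shifts a day value by a signed number of days. The date helper below is a local stand-in so the sketch runs on its own:

```python
from datetime import date, timedelta

def shift_time_value(time_value: int, days: int) -> int:
    # Stand-in for ...utils.shift_time_value: shift a YYYYMMDD int by `days`.
    d = date(time_value // 10000, time_value // 100 % 100, time_value % 100) + timedelta(days=days)
    return d.year * 10000 + d.month * 100 + d.day

def pad_time_window(time_window, pad_length):
    if pad_length < 0:
        raise ValueError("pad_length should be non-negative.")
    min_time, max_time = time_window
    return (shift_time_value(min_time, -pad_length), max_time)

def pad_time_pairs(time_pairs, pad_length):
    if pad_length == 0:
        return list(time_pairs)
    # Find the smallest day mentioned anywhere, then append a range reaching back pad_length days.
    min_time = min(
        tv if isinstance(tv, int) else tv[0]
        for pair in time_pairs
        for tv in pair.time_values
    )
    # TimePair is the dataclass from _params.py.
    return list(time_pairs) + [TimePair("day", [(shift_time_value(min_time, -pad_length), min_time)])]
```

Checked against the docstring examples: `pad_time_pairs([TimePair("day", [20210407])], 6)` appends `TimePair("day", [(20210401, 20210407)])`, and `pad_time_window((20210407, 20210413), 6)` returns `(20210401, 20210413)`.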
+
+
+def get_day_range(time_pairs: Union[TimePair, List[TimePair]]) -> TimePair:
+    """Combine a list of TimePairs into a single contiguous, explicit TimePair object.
+
+    Example:
+        [TimePair("day", [20210407, 20210408]), TimePair("day", [(20210408, 20210410)])] would return
+        TimePair("day", [20210407, 20210408, 20210409, 20210410]).
+
+    Used to produce a contiguous time index for time series operations.
+    """
+    return TimePair("day", [20210407, 20210408, 20210409, 20210410])
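A sketch of the real combination, in module context, assuming `time_value_range` (imported above from `...utils`) yields the inclusive list of day values between two YYYYMMDD endpoints:

```python
def get_day_range(time_pairs):
    if isinstance(time_pairs, TimePair):
        time_pairs = [time_pairs]
    # Collect every endpoint, then span the full min..max window contiguously.
    values = [tv for pair in time_pairs for tv in pair.time_values]
    min_day = min(tv if isinstance(tv, int) else tv[0] for tv in values)
    max_day = max(tv if isinstance(tv, int) else tv[1] for tv in values)
    return TimePair("day", list(time_value_range(min_day, max_day)))
```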
+
+
+def _generate_transformed_rows(
+    parsed_rows: Iterable[Dict],
+    time_pairs: Optional[List[TimePair]] = None,
+    transform_dict: Optional[Dict[Tuple[str, str], List[Tuple[str, str]]]] = None,
+    transform_args: Optional[Dict] = None,
+    group_keyfunc: Optional[Callable] = None,
+    data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key,
+) -> Iterable[Dict]:
+    """Apply time-series transformations to streamed rows from a database.
+
+    Parameters:
+    parsed_rows: Iterable[Dict]
+        Streamed rows from the database.
+    time_pairs: Optional[List[TimePair]], default None
+        A list of TimePairs, which can be used to create a contiguous time index for time-series operations.
+        The min and max dates in the TimePairs list are used.
+    transform_dict: Optional[Dict[Tuple[str, str], List[Tuple[str, str]]]], default None
+        A dictionary mapping base sources to a list of their derived signals that the user wishes to query.
+        For example, transform_dict may be {("jhu-csse", "confirmed_cumulative_num"): [("jhu-csse", "confirmed_incidence_num"), ("jhu-csse", "confirmed_7dav_incidence_num")]}.
+    transform_args: Optional[Dict], default None
+        A dictionary of keyword arguments for the transformer functions.
+    group_keyfunc: Optional[Callable], default None
+        The groupby function to use to order the streamed rows. Note that Python's groupby does not do any sorting, so
+        parsed_rows are assumed to be sorted in accord with this groupby.
+    data_signals_by_key: Dict[Tuple[str, str], DataSignal], default data_signals_by_key
+        The dictionary of DataSignals, used to find the base signal transforms.
+
+    Yields:
+    transformed rows: Dict
+        The transformed rows, returned in an interleaved fashion. Non-transformed rows have the IDENTITY operation applied.
+    """
+    if not transform_args:
+        transform_args = dict()
+    if not transform_dict:
+        transform_dict = dict()
+    if not group_keyfunc:
+        group_keyfunc = lambda row: (row["geo_type"], row["geo_value"], row["source"], row["signal"])
+
+    try:
+        for key, group in groupby(parsed_rows, group_keyfunc):
+            _, _, source_name, signal_name = key
+            # TODO (stub): the real implementation will
+            # - extract the list of derived signals for (source_name, signal_name) from transform_dict,
+            # - create a list of source-signal pairs along with the transformation required for each signal,
+            # - put the current time series on a contiguous time index,
+            # - create copies of the iterable with smart memory usage (tee), and
+            # - traverse the transformed iterables in an interleaved fashion, so that only a small window
+            #   of the original iterable (group) is stored in memory.
+            # For now, rows pass through untransformed.
+            for row in group:
+                yield row
+    except Exception as e:
+        print(f"Transformation encountered error of type {type(e)}, with message {e}. Yielding None and stopping.")
+        yield None
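The interleaving the comments describe, in miniature: `tee` copies the group stream, each copy feeds one transform, and `more_itertools.interleave_longest` drains them in lockstep so only a small window of the group stays in memory. Illustrative only; the derived-signal naming here is made up:

```python
from itertools import tee
from more_itertools import interleave_longest

base_rows = iter([
    {"signal": "sig_base", "time_value": 20210407, "value": 1.0},
    {"signal": "sig_base", "time_value": 20210408, "value": 2.0},
])
copy1, copy2 = tee(base_rows, 2)
# Stand-in "transform": a real one would smooth or diff the values.
derived = ({**row, "signal": "sig_base_smoothed"} for row in copy2)
for row in interleave_longest(copy1, derived):
    print(row["signal"], row["time_value"])
# sig_base 20210407 / sig_base_smoothed 20210407 / sig_base 20210408 / sig_base_smoothed 20210408
```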
+
+
+def get_basename_signals(source_signal_pairs: List[SourceSignalPair], data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key) -> Tuple[List[SourceSignalPair], Generator]:
+    """From a list of SourceSignalPairs, return the base signals required to derive them and a transformation function to take a stream
+    of the base signals and return the transformed signals.
+
+    Example:
+        SourceSignalPair("src", signal=["sig_base", "sig_smoothed"]) would return SourceSignalPair("src", signal=["sig_base"]) and a transformation function
+        that will take the returned database query for "sig_base" and return both the base time series and the smoothed time series.
+    """
+    transform_dict = {("src", "sig_base"): [("src", "sig_base"), ("src", "sig_smoothed")]}
+    row_transform_generator = partial(_generate_transformed_rows, transform_dict=transform_dict, data_signals_by_key=data_signals_by_key)
+    return [SourceSignalPair("src", signal=["sig_base"])], row_transform_generator
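The intended contract, end to end. This runs against the stubs in this commit, so rows simply pass through under IDENTITY for now:

```python
pairs = [SourceSignalPair("src", ["sig_base", "sig_smoothed"])]
base_pairs, row_transform_generator = get_basename_signals(pairs)
# base_pairs == [SourceSignalPair("src", signal=["sig_base"])]: only base signals hit the DB.

fake_db_rows = iter([
    {"geo_type": "county", "geo_value": "01001", "source": "src",
     "signal": "sig_base", "time_value": 20210407, "value": 1.0},
])
for row in row_transform_generator(parsed_rows=fake_db_rows, time_pairs=None, transform_args=None):
    print(row)  # with the stub transform, the base row comes back unchanged
```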
