
Commit dcbcc07
Server-side compute: utilities and general mechanisms
1 parent 464eca8

File tree: 10 files changed, +779 −43 lines


integrations/server/test_covidcast.py

Lines changed: 8 additions & 4 deletions
@@ -902,7 +902,8 @@ def test_date_formats(self):
         response = response.json()
 
         # assert that the right data came back
-        self.assertEqual(len(response['epidata']), 4)
+        # any county with data in the time range will be padded to make a consistent range
+        self.assertEqual(len(response['epidata']), 3 * 2)
 
         # make the request
         response = requests.get(BASE_URL, params={
@@ -918,7 +919,8 @@ def test_date_formats(self):
         response = response.json()
 
         # assert that the right data came back
-        self.assertEqual(len(response['epidata']), 4)
+        # any county with data in the time range will be padded to make a consistent range
+        self.assertEqual(len(response['epidata']), 3 * 2)
 
         # make the request
         response = requests.get(BASE_URL, params={
@@ -934,7 +936,8 @@ def test_date_formats(self):
         response = response.json()
 
         # assert that the right data came back
-        self.assertEqual(len(response['epidata']), 6)
+        # any county with data in the time range will be padded to make a consistent range
+        self.assertEqual(len(response['epidata']), 4 * 3)
 
         # make the request
         response = requests.get(BASE_URL, params={
@@ -950,4 +953,5 @@ def test_date_formats(self):
         response = response.json()
 
         # assert that the right data came back
-        self.assertEqual(len(response['epidata']), 6)
+        # any county with data in the time range will be padded to make a consistent range
+        self.assertEqual(len(response['epidata']), 4 * 3)
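The updated counts follow from the reindexing mechanism introduced in this commit: any county with data in the requested window is padded onto the full contiguous day range, so the expected row count becomes (number of counties) × (number of days) rather than the raw number of stored rows. A quick sanity check of that arithmetic (county codes and dates here are illustrative, not the test fixtures):

# hypothetical: 3 counties, each padded to the same 2-day range
counties = ["01001", "01003", "01005"]
days = [20200401, 20200402]
assert len([(c, d) for c in counties for d in days]) == 3 * 2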

integrations/server/test_covidcast_endpoints.py

Lines changed: 37 additions & 10 deletions
@@ -1,7 +1,9 @@
 """Integration tests for the custom `covidcast/*` endpoints."""
 
 # standard library
-from typing import Iterable, Dict, Any
+from copy import copy
+from itertools import accumulate, chain
+from typing import Iterable, Dict, Any, List, Sequence
 import unittest
 from io import StringIO
 
@@ -10,15 +12,19 @@
 
 # third party
 import mysql.connector
+from more_itertools import interleave_longest, windowed
 import requests
 import pandas as pd
+import numpy as np
 from delphi_utils import Nans
 
 from delphi.epidata.acquisition.covidcast.covidcast_meta_cache_updater import main as update_cache
+from delphi.epidata.server.endpoints.covidcast_utils.model import DataSignal, DataSource
 
 
 # use the local instance of the Epidata API
 BASE_URL = "http://delphi_web_epidata/epidata/covidcast"
+BASE_URL_OLD = "http://delphi_web_epidata/epidata/api.php"
 
 
 @dataclass
@@ -128,27 +134,37 @@ def _insert_rows(self, rows: Iterable[CovidcastRow]):
         self.cur.execute(
             f"""
             INSERT INTO
-                `covidcast` (`id`, `source`, `signal`, `time_type`, `geo_type`,
-                `time_value`, `geo_value`, `value_updated_timestamp`,
-                `value`, `stderr`, `sample_size`, `direction_updated_timestamp`,
+            `covidcast` (`id`, `source`, `signal`, `time_type`, `geo_type`,
+            `time_value`, `geo_value`, `value_updated_timestamp`,
+            `value`, `stderr`, `sample_size`, `direction_updated_timestamp`,
             `direction`, `issue`, `lag`, `is_latest_issue`, `is_wip`,`missing_value`,
-                `missing_stderr`,`missing_sample_size`)
+            `missing_stderr`,`missing_sample_size`)
             VALUES
             {sql}
             """
         )
         self.cnx.commit()
         return rows
 
-    def _fetch(self, endpoint="/", **params):
+    def _fetch(self, endpoint="/", is_compatibility=False, **params):
         # make the request
-        response = requests.get(
-            f"{BASE_URL}{endpoint}",
-            params=params,
-        )
+        if is_compatibility:
+            url = BASE_URL_OLD
+            params.setdefault("endpoint", "covidcast")
+            if params.get("source"):
+                params.setdefault("data_source", params.get("source"))
+        else:
+            url = f"{BASE_URL}{endpoint}"
+        response = requests.get(url, params=params)
         response.raise_for_status()
         return response.json()
 
+    def _diff_rows(self, rows: Sequence[float]):
+        return [float(x - y) if x is not None and y is not None else None for x, y in zip(rows[1:], rows[:-1])]
+
+    def _smooth_rows(self, rows: Sequence[float]):
+        return [sum(e)/len(e) if None not in e else None for e in windowed(rows, 7)]
+
     def test_basic(self):
         """Request a signal from the / endpoint."""
 
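The two helpers added above mirror the server-side transforms under test: _diff_rows takes first differences and _smooth_rows a 7-day trailing mean, both propagating None through incomplete windows. A standalone sketch of the expected behavior (requires more_itertools):

from more_itertools import windowed

rows = [1, 2, 4, 7, 11, 16, 22, 29]
diffs = [float(x - y) if x is not None and y is not None else None for x, y in zip(rows[1:], rows[:-1])]
assert diffs == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
smoothed = [sum(e) / len(e) if None not in e else None for e in windowed(rows, 7)]
assert smoothed == [9.0, 13.0]  # means of the two full 7-day windows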
@@ -164,6 +180,17 @@ def test_basic(self):
             out = self._fetch("/", signal=first.signal_pair, geo=first.geo_pair, time="day:*")
             self.assertEqual(len(out["epidata"]), len(rows))
 
+        with self.subTest("unknown signal"):
+            rows = [CovidcastRow(source="jhu-csse", signal="confirmed_unknown", time_value=20200401 + i, value=i) for i in range(10)]
+            first = rows[0]
+            self._insert_rows(rows)
+
+            out = self._fetch("/", signal="jhu-csse:confirmed_unknown", geo=first.geo_pair, time="day:*")
+            self.assertEqual(len(out["epidata"]), len(rows))
+            out_values = [row["value"] for row in out["epidata"]]
+            expected_values = [float(row.value) for row in rows]
+            self.assertEqual(out_values, expected_values)
+
     def test_trend(self):
         """Request a signal from the /trend endpoint."""
 

src/server/_params.py

Lines changed: 10 additions & 1 deletion
@@ -100,7 +100,16 @@ def _combine_source_signal_pairs(source_signal_pairs: List[SourceSignalPair]) -> List[SourceSignalPair]:
     [SourceSignalPair("src", ["sig1", "sig2"]), SourceSignalPair("src", ["sig2", "sig3"])] will be merged
     into [SourceSignalPair("src", ["sig1", "sig2", "sig3"])].
     """
-    return [SourceSignalPair("src", ["sig1", "sig2", "sig3"])]
+    source_signal_pairs_grouped = groupby(sorted(source_signal_pairs, key=lambda x: x.source), lambda x: x.source)
+    source_signal_pairs_combined = []
+    for source, group in source_signal_pairs_grouped:
+        group = list(group)
+        if any(x.signal == True for x in group):
+            source_signal_pairs_combined.append(SourceSignalPair(source, True))
+            continue
+        combined_signals = sorted(list(set(chain(*[x.signal for x in group]))))
+        source_signal_pairs_combined.append(SourceSignalPair(source, combined_signals))
+    return source_signal_pairs_combined
 
 def parse_source_signal_arg(key: str = "signal") -> List[SourceSignalPair]:
     return _combine_source_signal_pairs([SourceSignalPair(source, signals) for [source, signals] in _parse_common_multi_arg(key)])
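A worked example of the merge behavior (sketch; assumes the SourceSignalPair dataclass from this module, where signal=True means "all signals"):

pairs = [
    SourceSignalPair("src", ["sig1", "sig2"]),
    SourceSignalPair("src", ["sig2", "sig3"]),
    SourceSignalPair("other", True),
]
# _combine_source_signal_pairs(pairs) should give, after sorting by source:
# [SourceSignalPair("other", True), SourceSignalPair("src", ["sig1", "sig2", "sig3"])]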

src/server/endpoints/covidcast_utils/model.py

Lines changed: 124 additions & 18 deletions
@@ -16,6 +16,9 @@
 
 
 IDENTITY: Callable = lambda rows, **kwargs: rows
+DIFF: Callable = lambda rows, **kwargs: generate_row_diffs(rows, **kwargs)
+SMOOTH: Callable = lambda rows, **kwargs: generate_smooth_rows(rows, **kwargs)
+DIFF_SMOOTH: Callable = lambda rows, **kwargs: generate_smooth_rows(generate_row_diffs(rows, **kwargs), **kwargs)
 
 
 class HighValuesAre(str, Enum):
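These four transforms compose two row generators defined elsewhere in this module. A minimal sketch of their intended semantics, simplified to plain value lists (the real generators operate on row dicts and handle missingness codes):

def generate_row_diffs(values, **kwargs):
    # first differences: recovers incidence from a cumulative series
    return [b - a for a, b in zip(values[:-1], values[1:])]

def generate_smooth_rows(values, smoother_window_length=7, **kwargs):
    # trailing moving average over the smoothing window
    return [sum(w) / len(w) for w in zip(*(values[i:] for i in range(smoother_window_length)))]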
@@ -245,7 +248,6 @@ def _load_data_signals(sources: List[DataSource]):
             data_signals_by_key[(source.db_source, d.signal)] = d
 
 
-
 def get_related_signals(signal: DataSignal) -> List[DataSignal]:
     return [s for s in data_signals if s != signal and s.signal_basename == signal.signal_basename]
 
@@ -316,8 +318,15 @@ def _resolve_all_signals(source_signals: Union[SourceSignalPair, List[SourceSignalPair]],
 
     Example: SourceSignalPair("jhu-csse", signal=True) would return SourceSignalPair("jhu-csse", [<list of all JHU signals>]).
     """
-    return [SourceSignalPair("src", ["sig1", "sig2"])]
-
+    if isinstance(source_signals, SourceSignalPair):
+        if source_signals.signal == True:
+            source = data_sources_by_id.get(source_signals.source)
+            if source:
+                return SourceSignalPair(source.source, [s.signal for s in source.signals])
+        return source_signals
+    if isinstance(source_signals, list):
+        return [_resolve_all_signals(pair, data_sources_by_id) for pair in source_signals]
+    raise TypeError("source_signals is not Union[SourceSignalPair, List[SourceSignalPair]].")
 
 
 def _reindex_iterable(iterable: Iterable[Dict], time_pairs: List[TimePair], fill_value: Optional[int] = None) -> Iterable:
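An illustrative call (source and signal names hypothetical; signal=True expands to every signal the registry lists for that source):

pair = SourceSignalPair("jhu-csse", True)
resolved = _resolve_all_signals(pair, data_sources_by_id)
# e.g. SourceSignalPair("jhu-csse", ["confirmed_cumulative_num", "confirmed_incidence_num", ...])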
@@ -331,17 +340,47 @@ def _reindex_iterable(iterable: Iterable[Dict], time_pairs: List[TimePair], fill_value: Optional[int] = None) -> Iterable:
     if time_pairs is None:
         return iterable
 
+    _iterable = peekable(iterable)
+    first_item = _iterable.peek()
     day_range_index = get_day_range(time_pairs)
     for day in day_range_index.time_values:
-        if day in next_value(iterable):
-            return next_value(iterable)
+        index_item = first_item.copy()
+        index_item.update({
+            "time_value": day,
+            "value": fill_value,
+            "stderr": None,
+            "sample_size": None,
+            "missing_value": Nans.NOT_MISSING if fill_value is not None else Nans.NOT_APPLICABLE,
+            "missing_stderr": Nans.NOT_APPLICABLE,
+            "missing_sample_size": Nans.NOT_APPLICABLE
+        })
+        new_item = _iterable.peek(default=index_item)
+        if day == new_item.get("time_value"):
+            yield next(_iterable, index_item)
         else:
-            yield updated_default_value
+            yield index_item
 
 
 def _get_base_signal_transform(signal: Union[DataSignal, Tuple[str, str]], data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key) -> Callable:
     """Given a DataSignal, return the transformation that needs to be applied to its base signal to derive the signal."""
-    return IDENTITY
+    if isinstance(signal, DataSignal):
+        parent_signal = data_signals_by_key.get((signal.source, signal.signal_basename))
+        if signal.format not in [SignalFormat.raw, SignalFormat.raw_count, SignalFormat.count] or not signal.compute_from_base or not parent_signal:
+            return IDENTITY
+        if signal.is_cumulative and signal.is_smoothed:
+            return SMOOTH
+        if not signal.is_cumulative and not signal.is_smoothed:
+            return DIFF if parent_signal.is_cumulative else IDENTITY
+        if not signal.is_cumulative and signal.is_smoothed:
+            return DIFF_SMOOTH if parent_signal.is_cumulative else SMOOTH
+        return IDENTITY
+    if isinstance(signal, tuple):
+        signal = data_signals_by_key.get(signal)
+        if signal:
+            return _get_base_signal_transform(signal, data_signals_by_key)
+        return IDENTITY
+
+    raise TypeError("signal must be either Tuple[str, str] or DataSignal.")
 
 
 def get_transform_types(source_signal_pairs: List[SourceSignalPair], data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key) -> Set[Callable]:
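Read as a decision table, the new logic picks the transform from how the derived signal's is_cumulative/is_smoothed flags relate to its parent (a hedged summary of the branches above):

# parent cumulative, derived raw incidence       -> DIFF
# parent cumulative, derived smoothed incidence  -> DIFF_SMOOTH
# derived smoothed cumulative                    -> SMOOTH
# non-count format, no parent, or
# compute_from_base unset                        -> IDENTITY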
@@ -352,7 +391,17 @@ def get_transform_types(source_signal_pairs: List[SourceSignalPair], data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key) -> Set[Callable]:
 
     Used to pad the user DB query with extra days.
     """
-    return set([IDENTITY])
+    source_signal_pairs = _resolve_all_signals(source_signal_pairs, data_sources_by_id)
+
+    transform_types = set()
+    for source_signal_pair in source_signal_pairs:
+        source_name = source_signal_pair.source
+        signal_names = source_signal_pair.signal
+        if isinstance(signal_names, bool):
+            continue
+        transform_types |= {_get_base_signal_transform((source_name, signal_name), data_signals_by_key=data_signals_by_key) for signal_name in signal_names}
+
+    return transform_types
 
 
 def get_pad_length(source_signal_pairs: List[SourceSignalPair], smoother_window_length: int, data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key):
@@ -364,10 +413,14 @@ def get_pad_length(source_signal_pairs: List[SourceSignalPair], smoother_window_length: int, data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key):
 
     Used to pad the user DB query with extra days.
     """
     transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=data_sources_by_id, data_signals_by_key=data_signals_by_key)
+    pad_length = [0]
+    if DIFF_SMOOTH in transform_types:
+        pad_length.append(smoother_window_length)
     if SMOOTH in transform_types:
-        return 7
-    else:
-        return 0
+        pad_length.append(smoother_window_length - 1)
+    if DIFF in transform_types:
+        pad_length.append(1)
+    return max(pad_length)
 
 
 def pad_time_pairs(time_pairs: List[TimePair], pad_length: int) -> List[TimePair]:
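The padding each transform needs follows directly: DIFF consumes one prior day, SMOOTH a full window minus one, and DIFF_SMOOTH a full window. For the 7-day smoother used elsewhere in this changeset:

# assuming smoother_window_length == 7
assert max([0, 7]) == 7     # DIFF_SMOOTH present
assert max([0, 6, 1]) == 6  # SMOOTH and DIFF present, no DIFF_SMOOTH
assert max([0]) == 0        # only IDENTITY transforms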
@@ -378,7 +431,15 @@ def pad_time_pairs(time_pairs: List[TimePair], pad_length: int) -> List[TimePair]:
 
     Used to pad the user DB query with extra days.
     """
-    return [TimePair("day", [(20210401, 20210407)])]
+    if pad_length < 0:
+        raise ValueError("pad_length should be non-negative.")
+    if pad_length == 0:
+        return time_pairs.copy()
+    min_time = min(time_value if isinstance(time_value, int) else time_value[0] for time_pair in time_pairs if not isinstance(time_pair.time_values, bool) for time_value in time_pair.time_values)
+    padded_time = TimePair("day", [(shift_time_value(min_time, -1 * pad_length), min_time)])
+    new_time_pairs = time_pairs.copy()
+    new_time_pairs.append(padded_time)
+    return new_time_pairs
 
 
 def pad_time_window(time_window: Tuple[int, int], pad_length: int) -> Tuple[int, int]:
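A worked example of both padding helpers (dates illustrative; shift_time_value is assumed to do calendar-aware day arithmetic):

# pad_time_pairs([TimePair("day", [(20210407, 20210413)])], 6) appends
# TimePair("day", [(20210401, 20210407)]), widening the DB query 6 days back.
# pad_time_window((20210407, 20210413), 6) similarly returns (20210401, 20210413).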
@@ -389,7 +450,12 @@ def pad_time_window(time_window: Tuple[int, int], pad_length: int) -> Tuple[int, int]:
 
     Used to pad the user DB query with extra days.
     """
-    return (20210401, 20210407)
+    if pad_length < 0:
+        raise ValueError("pad_length should be non-negative.")
+    if pad_length == 0:
+        return time_window
+    min_time, max_time = time_window
+    return (shift_time_value(min_time, -1 * pad_length), max_time)
 
 
 def get_day_range(time_pairs: Union[TimePair, List[TimePair]]) -> TimePair:
@@ -401,11 +467,23 @@ def get_day_range(time_pairs: Union[TimePair, List[TimePair]]) -> TimePair:
 
     Used to produce a contiguous time index for time series operations.
     """
-    return TimePair("day", [20210407, 20210408, 20210409, 20210410])
+    if isinstance(time_pairs, TimePair):
+        time_pair = time_pairs
+        time_values = sorted(list(set().union(*(set(time_value_range(time_value)) if isinstance(time_value, tuple) else {time_value} for time_value in time_pair.time_values))))
+        if True in time_values:
+            raise ValueError("TimePair.time_value should not be a bool when calling get_day_range.")
+        return TimePair(time_pair.time_type, time_values)
+    elif isinstance(time_pairs, list):
+        if not all(time_pair.time_type == "day" for time_pair in time_pairs):
+            raise ValueError("get_day_range only supports day time_type pairs.")
+        time_values = sorted(list(set().union(*(get_day_range(time_pair).time_values for time_pair in time_pairs))))
+        return TimePair("day", time_values)
+    else:
+        raise ValueError("get_day_range received an unsupported type as input.")
 
 
 def _generate_transformed_rows(
-    parsed_rows: Iterable[Dict], time_pairs: Optional[List[TimePair]] = None, transform_dict: Optional[Dict[str, List[Tuple[str, str]]]]=None, transform_args: Optional[Dict] = None, group_keyfunc: Optional[Callable] = None, data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key,
+    parsed_rows: Iterable[Dict], time_pairs: Optional[List[TimePair]] = None, transform_dict: Optional[Dict[Tuple[str, str], List[Tuple[str, str]]]] = None, transform_args: Optional[Dict] = None, group_keyfunc: Optional[Callable] = None, data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key,
 ) -> Iterable[Dict]:
     """Applies time-series transformations to streamed rows from a database.
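For instance, mixing scalar days with (start, end) ranges (sketch; time_value_range is assumed to enumerate the days inside a tuple range):

# get_day_range(TimePair("day", [20210407, (20210408, 20210410)]))
# == TimePair("day", [20210407, 20210408, 20210409, 20210410])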
@@ -441,13 +519,19 @@ def _generate_transformed_rows(
         for key, group in groupby(parsed_rows, group_keyfunc):
             _, _, source_name, signal_name = key
             # Extract the list of derived signals.
+            derived_signals: List[Tuple[str, str]] = transform_dict.get((source_name, signal_name), [(source_name, signal_name)])
             # Create a list of source-signal pairs along with the transformation required for the signal.
+            source_signal_pairs_and_group_transforms: List[Tuple[Tuple[str, str], Callable]] = [((derived_source, derived_signal), _get_base_signal_transform((derived_source, derived_signal), data_signals_by_key)) for (derived_source, derived_signal) in derived_signals]
             # Put the current time series on a contiguous time index.
+            group_continguous_time = _reindex_iterable(group, time_pairs, fill_value=transform_args.get("pad_fill_value")) if time_pairs else group
             # Create copies of the iterable, with smart memory usage.
+            group_iter_copies: Iterable[Iterable[Dict]] = tee(group_continguous_time, len(source_signal_pairs_and_group_transforms))
             # Create a list of transformed group iterables, remembering their derived name as needed.
+            transformed_group_rows: Iterable[Iterable[Dict]] = (zip(transform(rows, **transform_args), repeat(key)) for (key, transform), rows in zip(source_signal_pairs_and_group_transforms, group_iter_copies))
             # Traverse through the transformed iterables in an interleaved fashion, which makes sure that only a small window
             # of the original iterable (group) is stored in memory.
-            for row in transform_group(group):
+            for row, (_, derived_signal) in interleave_longest(*transformed_group_rows):
+                row["signal"] = derived_signal
                 yield row
     except Exception as e:
         print(f"Transformation encountered error of type {type(e)}, with message {e}. Yielding None and stopping.")
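The tee-plus-interleave pattern is what lets one streamed DB group feed several derived series without materializing the whole group. A self-contained sketch of the idea, with plain integers standing in for row dicts:

from itertools import tee
from more_itertools import interleave_longest

base = iter(range(5))          # one streamed group from the DB
copy_a, copy_b = tee(base, 2)  # lazily shared copies of the stream
doubled = (x * 2 for x in copy_a)  # one derived series
negated = (-x for x in copy_b)     # another derived series
print(list(interleave_longest(doubled, negated)))
# [0, 0, 2, -1, 4, -2, 6, -3, 8, -4]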
@@ -462,6 +546,28 @@ def get_basename_signals(source_signal_pairs: List[SourceSignalPair], data_sources_by_id
 
     SourceSignalPair("src", signal=["sig_base", "sig_smoothed"]) would return SourceSignalPair("src", signal=["sig_base"]) and a transformation function
     that will take the returned database query for "sig_base" and return both the base time series and the smoothed time series.
     """
-    transform_dict = {("src", "sig_base"): [("src", "sig_base"), ("src", "sig_smoothed")]}
+    source_signal_pairs = _resolve_all_signals(source_signal_pairs, data_sources_by_id)
+    base_signal_pairs: List[SourceSignalPair] = []
+    transform_dict: Dict[Tuple[str, str], List[Tuple[str, str]]] = dict()
+
+    for pair in source_signal_pairs:
+        if isinstance(pair.signal, bool):
+            base_signal_pairs.append(pair)
+            continue
+
+        source_name = pair.source
+        signal_names = pair.signal
+        signals = []
+        for signal_name in signal_names:
+            signal = data_signals_by_key.get((source_name, signal_name))
+            if not signal or not signal.compute_from_base:
+                signals.append(signal_name)
+                transform_dict.setdefault((source_name, signal_name), []).append((source_name, signal_name))
+            else:
+                signals.append(signal.signal_basename)
+                transform_dict.setdefault((source_name, signal.signal_basename), []).append((source_name, signal_name))
+        base_signal_pairs.append(SourceSignalPair(pair.source, signals))
+
     row_transform_generator = partial(_generate_transformed_rows, transform_dict=transform_dict, data_signals_by_key=data_signals_by_key)
-    return SourceSignalPair("src", signal=["sig_base"]), row_transform_generator
+
+    return base_signal_pairs, row_transform_generator
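End to end, the docstring's example plays out like this (sketch; assumes "sig_smoothed" is registered with compute_from_base=True and signal_basename "sig_base"):

pairs = [SourceSignalPair("src", ["sig_base", "sig_smoothed"])]
base_pairs, row_transform = get_basename_signals(pairs)
# The DB is queried only for "sig_base"; the transform_dict closed over by
# row_transform maps ("src", "sig_base") -> [("src", "sig_base"), ("src", "sig_smoothed")],
# so the generator re-derives both series from the single base query.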
