Skip to content

Commit a0d6dfe

Browse files
authored
Merge pull request #571 from cmu-delphi/sgratz/missing_columns_adapter
dynamically check if the missing_* columns exist
2 parents 72fc2bd + d14e42a commit a0d6dfe

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

src/server/_db.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
from sqlalchemy import MetaData, create_engine
1+
from typing import Dict, List
2+
from sqlalchemy import MetaData, create_engine, inspect
23
from sqlalchemy.engine import Engine
4+
from sqlalchemy.engine.reflection import Inspector
35

46
from ._config import SQLALCHEMY_DATABASE_URI, SQLALCHEMY_ENGINE_OPTIONS
57

@@ -12,3 +14,13 @@
1214
mysql_charset="utf8mb4",
1315
# mariadb_charset="utf8",
1416
)
17+
18+
19+
def sql_table_has_columns(table: str, columns: List[str]) -> bool:
20+
"""
21+
checks whether the given table has all the given columns defined
22+
"""
23+
inspector: Inspector = inspect(engine)
24+
table_columns: List[Dict] = inspector.get_columns(table)
25+
table_column_names = set(str(d.get("name", "")).lower() for d in table_columns)
26+
return all(c.lower() in table_column_names for c in columns)

src/server/endpoints/covidcast.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
require_all,
3232
require_any,
3333
)
34+
from .._db import sql_table_has_columns
3435
from .._pandas import as_pandas
3536
from .covidcast_utils import compute_trend, compute_trends, compute_correlations, compute_trend_value, CovidcastMetaEntry, AllSignalsMap
3637
from ..utils import shift_time_value, date_to_time_value, time_value_to_iso
@@ -134,7 +135,12 @@ def handle():
134135
q = QueryBuilder("covidcast", "t")
135136

136137
fields_string = ["geo_value", "signal"]
137-
fields_int = ["time_value", "direction", "issue", "lag", "missing_value", "missing_stderr", "missing_sample_size"]
138+
fields_int = ["time_value", "direction", "issue", "lag"]
139+
140+
missing_fields = ["missing_value", "missing_stderr", "missing_sample_size"]
141+
if sql_table_has_columns("covidcast", missing_fields):
142+
fields_int.extend(missing_fields)
143+
138144
fields_float = ["value", "stderr", "sample_size"]
139145
if is_compatibility_mode():
140146
q.set_order("signal", "time_value", "geo_value", "issue")
@@ -144,6 +150,7 @@ def handle():
144150
q.set_order("source", "signal", "time_type", "time_value", "geo_type", "geo_value", "issue")
145151
q.set_fields(fields_string, fields_int, fields_float)
146152

153+
147154
# basic query info
148155
# data type of each field
149156
# build the source, signal, time, and location (type and id) filters

0 commit comments

Comments
 (0)