Skip to content

Commit c0d7f23

Browse files
committed
feat: dynamically check if the missing_* columsn exist
1 parent 72fc2bd commit c0d7f23

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

src/server/_db.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,26 @@
1-
from sqlalchemy import MetaData, create_engine
1+
from typing import Dict, List
2+
from sqlalchemy import MetaData, create_engine, inspect
23
from sqlalchemy.engine import Engine
4+
from sqlalchemy.engine.reflection import Inspector
35

46
from ._config import SQLALCHEMY_DATABASE_URI, SQLALCHEMY_ENGINE_OPTIONS
57

68
engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS)
79
metadata = MetaData(bind=engine)
10+
inspector: Inspector = inspect(engine)
811

912
TABLE_OPTIONS = dict(
1013
mysql_engine="InnoDB",
1114
# mariadb_engine="InnoDB",
1215
mysql_charset="utf8mb4",
1316
# mariadb_charset="utf8",
1417
)
18+
19+
20+
def sql_table_has_columns(table: str, columns: List[str]) -> bool:
21+
"""
22+
checks whether the given table has all the given columns defined
23+
"""
24+
table_columns: List[Dict] = inspector.get_columns(table)
25+
table_column_names = set(str(d.get("name", "")).lower() for d in table_columns)
26+
return all(c.lower() in table_column_names for c in columns)

src/server/endpoints/covidcast.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
require_all,
3232
require_any,
3333
)
34+
from .._db import sql_table_has_columns
3435
from .._pandas import as_pandas
3536
from .covidcast_utils import compute_trend, compute_trends, compute_correlations, compute_trend_value, CovidcastMetaEntry, AllSignalsMap
3637
from ..utils import shift_time_value, date_to_time_value, time_value_to_iso
@@ -134,7 +135,7 @@ def handle():
134135
q = QueryBuilder("covidcast", "t")
135136

136137
fields_string = ["geo_value", "signal"]
137-
fields_int = ["time_value", "direction", "issue", "lag", "missing_value", "missing_stderr", "missing_sample_size"]
138+
fields_int = ["time_value", "direction", "issue", "lag"]
138139
fields_float = ["value", "stderr", "sample_size"]
139140
if is_compatibility_mode():
140141
q.set_order("signal", "time_value", "geo_value", "issue")
@@ -144,6 +145,15 @@ def handle():
144145
q.set_order("source", "signal", "time_type", "time_value", "geo_type", "geo_value", "issue")
145146
q.set_fields(fields_string, fields_int, fields_float)
146147

148+
missing_fields = ["missing_value", "missing_stderr", "missing_sample_size"]
149+
fields_int.extend(missing_fields)
150+
if sql_table_has_columns("covidcast", missing_fields):
151+
# real fields
152+
q.fields.extend([f"{q.alias}.{field}" for field in missing_fields])
153+
else:
154+
# fake fields
155+
q.fields.extend([f"0 as {field}" for field in missing_fields])
156+
147157
# basic query info
148158
# data type of each field
149159
# build the source, signal, time, and location (type and id) filters

0 commit comments

Comments
 (0)