Skip to content

Commit c21b940

Browse files
authored
API server code health pass - reduce use of request global (#1069)
1 parent 54722a4 commit c21b940

39 files changed

+123
-122
lines changed

src/server/_pandas.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Dict, Any, Optional
22
import pandas as pd
33

4+
from flask import request
45
from sqlalchemy import text
56
from sqlalchemy.engine.base import Engine
67

@@ -20,7 +21,7 @@ def as_pandas(query: str, params: Dict[str, Any], db_engine: Engine = engine, pa
2021

2122

2223
def print_pandas(df: pd.DataFrame):
23-
p = create_printer()
24+
p = create_printer(request.values.get("format"))
2425

2526
def gen():
2627
for row in df.to_dict(orient="records"):

src/server/_params.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def parse_source_signal_sets() -> List[SourceSignalSet]:
446446
ds = request.values.get("data_source")
447447
if ds:
448448
# old version
449-
require_any("signal", "signals", empty=True)
449+
require_any(request, "signal", "signals", empty=True)
450450
signals = extract_strings(("signals", "signal"))
451451
if len(signals) == 1 and signals[0] == "*":
452452
return [SourceSignalSet(ds, True)]
@@ -462,7 +462,7 @@ def parse_geo_sets() -> List[GeoSet]:
462462
geo_type = request.values.get("geo_type")
463463
if geo_type:
464464
# old version
465-
require_any("geo_value", "geo_values", empty=True)
465+
require_any(request, "geo_value", "geo_values", empty=True)
466466
geo_values = extract_strings(("geo_values", "geo_value"))
467467
if len(geo_values) == 1 and geo_values[0] == "*":
468468
return [GeoSet(geo_type, True)]
@@ -478,7 +478,7 @@ def parse_time_set() -> TimeSet:
478478
time_type = request.values.get("time_type")
479479
if time_type:
480480
# old version
481-
require_all("time_type", "time_values")
481+
require_all(request, "time_type", "time_values")
482482
time_values = extract_dates("time_values")
483483
if time_values == ["*"]:
484484
return TimeSet(time_type, True)

src/server/_printer.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from io import StringIO
33
from typing import Any, Dict, Iterable, List, Optional, Union
44

5-
from flask import Response, jsonify, request, stream_with_context
5+
from flask import Response, jsonify, stream_with_context
66
from flask.json import dumps
77
import orjson
88

@@ -11,12 +11,10 @@
1111
from .utils.logger import get_structured_logger
1212

1313

14-
def print_non_standard(data):
14+
def print_non_standard(format: str, data):
1515
"""
1616
prints a non standard JSON message
1717
"""
18-
19-
format = request.values.get("format", "classic")
2018
if format == "json":
2119
return jsonify(data)
2220

@@ -250,8 +248,9 @@ def _end(self):
250248
return b""
251249

252250

253-
def create_printer() -> APrinter:
254-
format: str = request.values.get("format", "classic")
251+
def create_printer(format: str) -> APrinter:
252+
if format is None:
253+
return ClassicPrinter()
255254
if format == "tree":
256255
return ClassicTreePrinter("signal")
257256
if format.startswith("tree-"):

src/server/_query.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
cast
1313
)
1414

15+
from flask import request
1516
from sqlalchemy import text
1617
from sqlalchemy.engine import Row
1718

@@ -259,7 +260,7 @@ def execute_queries(
259260
execute the given queries and return the response to send them
260261
"""
261262

262-
p = create_printer()
263+
p = create_printer(request.values.get("format"))
263264

264265
fields_to_send = set(extract_strings("fields") or [])
265266
if fields_to_send:

src/server/_validate.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
from typing import List, Optional, Sequence, Tuple, Union
1+
from typing import Optional
22

3-
from flask import request
3+
from flask import Request
44

55
from ._exceptions import UnAuthenticatedException, ValidationFailedException
6-
from .utils import IntRange, TimeValues
76

87

9-
def resolve_auth_token() -> Optional[str]:
8+
def resolve_auth_token(request: Request) -> Optional[str]:
109
# auth request param
1110
if "auth" in request.values:
1211
return request.values["auth"]
@@ -20,8 +19,8 @@ def resolve_auth_token() -> Optional[str]:
2019
return None
2120

2221

23-
def check_auth_token(token: str, optional=False) -> bool:
24-
value = resolve_auth_token()
22+
def check_auth_token(request: Request, token: str, optional=False) -> bool:
23+
value = resolve_auth_token(request)
2524

2625
if value is None:
2726
if optional:
@@ -35,7 +34,7 @@ def check_auth_token(token: str, optional=False) -> bool:
3534
return valid_token
3635

3736

38-
def require_all(*values: str) -> bool:
37+
def require_all(request: Request, *values: str) -> bool:
3938
"""
4039
returns true if all fields are present in the request otherwise raises an exception
4140
:returns bool
@@ -46,7 +45,7 @@ def require_all(*values: str) -> bool:
4645
return True
4746

4847

49-
def require_any(*values: str, empty=False) -> bool:
48+
def require_any(request: Request, *values: str, empty=False) -> bool:
5049
"""
5150
returns true if any fields are present in the request otherwise raises an exception
5251
:returns bool

src/server/endpoints/afhsb.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Dict, List
22

3-
from flask import Blueprint
3+
from flask import Blueprint, request
44

55
from .._config import AUTH
66
from .._params import extract_integers, extract_strings
@@ -54,8 +54,8 @@ def _split_flu_types(flu_types: List[str]):
5454

5555
@bp.route("/", methods=("GET", "POST"))
5656
def handle():
57-
check_auth_token(AUTH["afhsb"])
58-
require_all("locations", "epiweeks", "flu_types")
57+
check_auth_token(request, AUTH["afhsb"])
58+
require_all(request, "locations", "epiweeks", "flu_types")
5959

6060
locations = extract_strings("locations")
6161
epiweeks = extract_integers("epiweeks")

src/server/endpoints/cdc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from flask import Blueprint
1+
from flask import Blueprint, request
22

33
from .._config import AUTH, NATION_REGION, REGION_TO_STATE
44
from .._params import extract_strings, extract_integers
@@ -12,8 +12,8 @@
1212

1313
@bp.route("/", methods=("GET", "POST"))
1414
def handle():
15-
check_auth_token(AUTH["cdc"])
16-
require_all("locations", "epiweeks")
15+
check_auth_token(request, AUTH["cdc"])
16+
require_all(request, "locations", "epiweeks")
1717

1818
# parse the request
1919
locations = extract_strings("locations")

src/server/endpoints/covid_hosp_facility.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from flask import Blueprint
1+
from flask import Blueprint, request
22

33
from .._params import extract_integers, extract_strings
44
from .._query import execute_query, QueryBuilder
@@ -10,7 +10,7 @@
1010

1111
@bp.route("/", methods=("GET", "POST"))
1212
def handle():
13-
require_all("hospital_pks", "collection_weeks")
13+
require_all(request, "hospital_pks", "collection_weeks")
1414
hospital_pks = extract_strings("hospital_pks")
1515
collection_weeks = extract_integers("collection_weeks")
1616
publication_dates = extract_integers("publication_dates")

src/server/endpoints/covid_hosp_facility_lookup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from flask import Blueprint
1+
from flask import Blueprint, request
22

33
from .._params import extract_strings
44
from .._query import execute_query, QueryBuilder
@@ -10,7 +10,7 @@
1010

1111
@bp.route("/", methods=("GET", "POST"))
1212
def handle():
13-
require_any("state", "ccn", "city", "zip", "fips_code")
13+
require_any(request, "state", "ccn", "city", "zip", "fips_code")
1414
state = extract_strings("state")
1515
ccn = extract_strings("ccn")
1616
city = extract_strings("city")

src/server/endpoints/covid_hosp_state_timeseries.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from flask import Blueprint
1+
from flask import Blueprint, request
22

33
from .._params import extract_integers, extract_strings, extract_date
44
from .._query import execute_query, QueryBuilder
@@ -11,7 +11,7 @@
1111

1212
@bp.route("/", methods=("GET", "POST"))
1313
def handle():
14-
require_all("states", "dates")
14+
require_all(request, "states", "dates")
1515
states = extract_strings("states")
1616
dates = extract_integers("dates")
1717
issues = extract_integers("issues")

0 commit comments

Comments
 (0)