Skip to content

Commit 03bc328

Browse files
committed
refactor: center DataFrame return in EpidataCall
* remove json, csv, iter formats
* remove format_type option, always request classic
* consolidate DataFrame code
* parse types only if classic, otherwise let Pandas do it
1 parent 2a6c3f7 commit 03bc328

File tree

3 files changed

+47
-119
lines changed

3 files changed

+47
-119
lines changed

epidatpy/_model.py

+2 -58
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,6 @@
22
from datetime import date
33
from enum import Enum
44
from typing import (
5-
Any,
6-
Dict,
75
Final,
86
List,
97
Mapping,
@@ -18,10 +16,8 @@
1816
from urllib.parse import urlencode
1917

2018
from epiweeks import Week
21-
from pandas import CategoricalDtype, DataFrame, Series
2219

2320
from ._parse import (
24-
fields_to_predicate,
2521
parse_api_date,
2622
parse_api_date_or_week,
2723
parse_api_week,
@@ -90,17 +86,6 @@ def __str__(self) -> str:
9086
return f"{format_date(self.start)}-{format_date(self.end)}"
9187

9288

93-
class EpiDataFormatType(str, Enum):
94-
"""
95-
possible formatting options for API calls
96-
"""
97-
98-
json = "json"
99-
classic = "classic"
100-
csv = "csv"
101-
jsonl = "jsonl"
102-
103-
10489
class InvalidArgumentException(Exception):
10590
"""
10691
exception for an invalid argument
@@ -180,41 +165,36 @@ def _verify_parameters(self) -> None:
180165

181166
def _formatted_parameters(
182167
self,
183-
format_type: Optional[EpiDataFormatType] = None,
184168
fields: Optional[Sequence[str]] = None,
185169
) -> Mapping[str, str]:
186170
"""
187171
format this call into a [URL, Params] tuple
188172
"""
189173
all_params = dict(self._params)
190-
if format_type and format_type != EpiDataFormatType.classic:
191-
all_params["format"] = format_type
192174
if fields:
193175
all_params["fields"] = fields
194176
return {k: format_list(v) for k, v in all_params.items() if v is not None}
195177

196178
def request_arguments(
197179
self,
198-
format_type: Optional[EpiDataFormatType] = None,
199180
fields: Optional[Sequence[str]] = None,
200181
) -> Tuple[str, Mapping[str, str]]:
201182
"""
202183
format this call into a [URL, Params] tuple
203184
"""
204-
formatted_params = self._formatted_parameters(format_type, fields)
185+
formatted_params = self._formatted_parameters(fields)
205186
full_url = add_endpoint_to_url(self._base_url, self._endpoint)
206187
return full_url, formatted_params
207188

208189
def request_url(
209190
self,
210-
format_type: Optional[EpiDataFormatType] = None,
211191
fields: Optional[Sequence[str]] = None,
212192
) -> str:
213193
"""
214194
format this call into a full HTTP request url with encoded parameters
215195
"""
216196
self._verify_parameters()
217-
u, p = self.request_arguments(format_type, fields)
197+
u, p = self.request_arguments(fields)
218198
query = urlencode(p)
219199
if query:
220200
return f"{u}?{query}"
@@ -253,39 +233,3 @@ def _parse_row(
253233
if not self.meta:
254234
return row
255235
return {k: self._parse_value(k, v, disable_date_parsing) for k, v in row.items()}
256-
257-
def _as_df(
258-
self,
259-
rows: Sequence[Mapping[str, Union[str, float, int, date, None]]],
260-
fields: Optional[Sequence[str]] = None,
261-
disable_date_parsing: Optional[bool] = False,
262-
) -> DataFrame:
263-
pred = fields_to_predicate(fields)
264-
columns: List[str] = [info.name for info in self.meta if pred(info.name)]
265-
df = DataFrame(rows, columns=columns or None)
266-
267-
data_types: Dict[str, Any] = {}
268-
for info in self.meta:
269-
if not pred(info.name) or df[info.name].isnull().all():
270-
continue
271-
if info.type == EpidataFieldType.bool:
272-
data_types[info.name] = bool
273-
elif info.type == EpidataFieldType.categorical:
274-
data_types[info.name] = CategoricalDtype(
275-
categories=Series(info.categories) if info.categories else None, ordered=True
276-
)
277-
elif info.type == EpidataFieldType.int:
278-
data_types[info.name] = int
279-
elif info.type in (
280-
EpidataFieldType.date,
281-
EpidataFieldType.epiweek,
282-
EpidataFieldType.date_or_epiweek,
283-
):
284-
data_types[info.name] = int if disable_date_parsing else "datetime64[ns]"
285-
elif info.type == EpidataFieldType.float:
286-
data_types[info.name] = float
287-
else:
288-
data_types[info.name] = str
289-
if data_types:
290-
df = df.astype(data_types)
291-
return df

epidatpy/request.py

+45 -55
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
1-
from datetime import date
21
from typing import (
2+
Any,
3+
Dict,
34
Final,
45
List,
56
Mapping,
@@ -9,7 +10,7 @@
910
cast,
1011
)
1112

12-
from pandas import DataFrame
13+
from pandas import CategoricalDtype, DataFrame, Series
1314
from requests import Response, Session
1415
from requests.auth import HTTPBasicAuth
1516
from tenacity import retry, stop_after_attempt
@@ -21,13 +22,14 @@
2122
from ._model import (
2223
AEpiDataCall,
2324
EpidataFieldInfo,
24-
EpiDataFormatType,
25+
EpidataFieldType,
2526
EpiDataResponse,
2627
EpiRange,
2728
EpiRangeParam,
2829
OnlySupportsClassicFormatException,
2930
add_endpoint_to_url,
3031
)
32+
from ._parse import fields_to_predicate
3133

3234
# Make the linter happy about the unused variables
3335
__all__ = ["Epidata", "EpiDataCall", "EpiDataContext", "EpiRange", "CovidcastEpidata"]
@@ -83,23 +85,25 @@ def with_session(self, session: Session) -> "EpiDataCall":
8385

8486
def _call(
8587
self,
86-
format_type: Optional[EpiDataFormatType] = None,
8788
fields: Optional[Sequence[str]] = None,
8889
stream: bool = False,
8990
) -> Response:
90-
url, params = self.request_arguments(format_type, fields)
91+
url, params = self.request_arguments(fields)
9192
return _request_with_retry(url, params, self._session, stream)
9293

9394
def classic(
9495
self,
9596
fields: Optional[Sequence[str]] = None,
9697
disable_date_parsing: Optional[bool] = False,
98+
disable_type_parsing: Optional[bool] = False,
9799
) -> EpiDataResponse:
98100
"""Request and parse epidata in CLASSIC message format."""
99101
self._verify_parameters()
100102
try:
101-
response = self._call(None, fields)
103+
response = self._call(fields)
102104
r = cast(EpiDataResponse, response.json())
105+
if disable_type_parsing:
106+
return r
103107
epidata = r.get("epidata")
104108
if epidata and isinstance(epidata, list) and len(epidata) > 0 and isinstance(epidata[0], dict):
105109
r["epidata"] = [self._parse_row(row, disable_date_parsing=disable_date_parsing) for row in epidata]
@@ -111,25 +115,11 @@ def __call__(
111115
self,
112116
fields: Optional[Sequence[str]] = None,
113117
disable_date_parsing: Optional[bool] = False,
114-
) -> EpiDataResponse:
115-
"""Request and parse epidata in CLASSIC message format."""
116-
return self.classic(fields, disable_date_parsing=disable_date_parsing)
117-
118-
def json(
119-
self,
120-
fields: Optional[Sequence[str]] = None,
121-
disable_date_parsing: Optional[bool] = False,
122-
) -> List[Mapping[str, Union[str, int, float, date, None]]]:
123-
"""Request and parse epidata in JSON format"""
118+
) -> Union[EpiDataResponse, DataFrame]:
119+
"""Request and parse epidata in df message format."""
124120
if self.only_supports_classic:
125-
raise OnlySupportsClassicFormatException()
126-
self._verify_parameters()
127-
response = self._call(EpiDataFormatType.json, fields)
128-
response.raise_for_status()
129-
return [
130-
self._parse_row(row, disable_date_parsing=disable_date_parsing)
131-
for row in cast(List[Mapping[str, Union[str, int, float, None]]], response.json())
132-
]
121+
return self.classic(fields, disable_date_parsing=disable_date_parsing, disable_type_parsing=False)
122+
return self.df(fields, disable_date_parsing=disable_date_parsing)
133123

134124
def df(
135125
self,
@@ -140,37 +130,37 @@ def df(
140130
if self.only_supports_classic:
141131
raise OnlySupportsClassicFormatException()
142132
self._verify_parameters()
143-
r = self.json(fields, disable_date_parsing=disable_date_parsing)
144-
return self._as_df(r, fields, disable_date_parsing=disable_date_parsing)
145-
146-
def csv(self, fields: Optional[Iterable[str]] = None) -> str:
147-
"""Request and parse epidata in CSV format"""
148-
if self.only_supports_classic:
149-
raise OnlySupportsClassicFormatException()
150-
self._verify_parameters()
151-
response = self._call(EpiDataFormatType.csv, fields)
152-
response.raise_for_status()
153-
return response.text
154-
155-
def iter(
156-
self,
157-
fields: Optional[Iterable[str]] = None,
158-
disable_date_parsing: Optional[bool] = False,
159-
) -> Generator[Mapping[str, Union[str, int, float, date, None]], None, Response]:
160-
"""Request and streams epidata rows"""
161-
if self.only_supports_classic:
162-
raise OnlySupportsClassicFormatException()
163-
self._verify_parameters()
164-
response = self._call(EpiDataFormatType.jsonl, fields, stream=True)
165-
response.raise_for_status()
166-
for line in response.iter_lines():
167-
yield self._parse_row(loads(line), disable_date_parsing=disable_date_parsing)
168-
return response
169-
170-
def __iter__(
171-
self,
172-
) -> Generator[Mapping[str, Union[str, int, float, date, None]], None, Response]:
173-
return self.iter()
133+
json = self.classic(fields, disable_type_parsing=True)
134+
rows = json.get("epidata", [])
135+
pred = fields_to_predicate(fields)
136+
columns: List[str] = [info.name for info in self.meta if pred(info.name)]
137+
df = DataFrame(rows, columns=columns or None)
138+
139+
data_types: Dict[str, Any] = {}
140+
for info in self.meta:
141+
if not pred(info.name) or df[info.name].isnull().all():
142+
continue
143+
if info.type == EpidataFieldType.bool:
144+
data_types[info.name] = bool
145+
elif info.type == EpidataFieldType.categorical:
146+
data_types[info.name] = CategoricalDtype(
147+
categories=Series(info.categories) if info.categories else None, ordered=True
148+
)
149+
elif info.type == EpidataFieldType.int:
150+
data_types[info.name] = "Int64"
151+
elif info.type in (
152+
EpidataFieldType.date,
153+
EpidataFieldType.epiweek,
154+
EpidataFieldType.date_or_epiweek,
155+
):
156+
data_types[info.name] = "Int64" if disable_date_parsing else "datetime64[ns]"
157+
elif info.type == EpidataFieldType.float:
158+
data_types[info.name] = "Float64"
159+
else:
160+
data_types[info.name] = "string"
161+
if data_types:
162+
df = df.astype(data_types)
163+
return df
174164

175165

176166
class EpiDataContext(AEpiDataEndpoints[EpiDataCall]):

smoke_test.py

-6
Original file line number | Diff line number | Diff line change
@@ -13,9 +13,6 @@
1313
classic = apicall.classic()
1414
print(classic)
1515

16-
data = apicall.json()
17-
print(data[0])
18-
1916
df = apicall.df()
2017
print(df.columns)
2118
print(df.dtypes)
@@ -53,9 +50,6 @@
5350
classic = apicall.classic()
5451
print(classic)
5552

56-
data = apicall.json()
57-
print(data[0])
58-
5953
df = apicall.df()
6054
print(df.columns)
6155
print(df.dtypes)

0 commit comments

Comments
 (0)