|
1 | 1 | from datetime import date
|
2 | 2 | from typing import (
|
| 3 | + Any, |
| 4 | + Dict, |
3 | 5 | Final,
|
4 | 6 | List,
|
5 | 7 | Mapping,
|
|
9 | 11 | cast,
|
10 | 12 | )
|
11 | 13 |
|
12 |
| -from pandas import DataFrame |
| 14 | +from pandas import CategoricalDtype, DataFrame, Series |
13 | 15 | from requests import Response, Session
|
14 | 16 | from requests.auth import HTTPBasicAuth
|
15 | 17 | from tenacity import retry, stop_after_attempt
|
|
21 | 23 | from ._model import (
|
22 | 24 | AEpiDataCall,
|
23 | 25 | EpidataFieldInfo,
|
| 26 | + EpidataFieldType, |
24 | 27 | EpiDataFormatType,
|
25 | 28 | EpiDataResponse,
|
26 | 29 | EpiRange,
|
27 | 30 | EpiRangeParam,
|
28 | 31 | OnlySupportsClassicFormatException,
|
29 | 32 | add_endpoint_to_url,
|
30 | 33 | )
|
| 34 | +from ._parse import fields_to_predicate |
31 | 35 |
|
32 | 36 | # Make the linter happy about the unused variables
|
33 | 37 | __all__ = ["Epidata", "EpiDataCall", "EpiDataContext", "EpiRange", "CovidcastEpidata"]
|
@@ -140,8 +144,36 @@ def df(
|
140 | 144 | if self.only_supports_classic:
|
141 | 145 | raise OnlySupportsClassicFormatException()
|
142 | 146 | self._verify_parameters()
|
143 |
| - r = self.json(fields, disable_date_parsing=disable_date_parsing) |
144 |
| - return self._as_df(r, fields, disable_date_parsing=disable_date_parsing) |
| 147 | + rows = self.json(fields, disable_date_parsing=disable_date_parsing) |
| 148 | + pred = fields_to_predicate(fields) |
| 149 | + columns: List[str] = [info.name for info in self.meta if pred(info.name)] |
| 150 | + df = DataFrame(rows, columns=columns or None) |
| 151 | + |
| 152 | + data_types: Dict[str, Any] = {} |
| 153 | + for info in self.meta: |
| 154 | + if not pred(info.name) or df[info.name].isnull().all(): |
| 155 | + continue |
| 156 | + if info.type == EpidataFieldType.bool: |
| 157 | + data_types[info.name] = bool |
| 158 | + elif info.type == EpidataFieldType.categorical: |
| 159 | + data_types[info.name] = CategoricalDtype( |
| 160 | + categories=Series(info.categories) if info.categories else None, ordered=True |
| 161 | + ) |
| 162 | + elif info.type == EpidataFieldType.int: |
| 163 | + data_types[info.name] = "Int64" |
| 164 | + elif info.type in ( |
| 165 | + EpidataFieldType.date, |
| 166 | + EpidataFieldType.epiweek, |
| 167 | + EpidataFieldType.date_or_epiweek, |
| 168 | + ): |
| 169 | + data_types[info.name] = "Int64" if disable_date_parsing else "datetime64[ns]" |
| 170 | + elif info.type == EpidataFieldType.float: |
| 171 | + data_types[info.name] = "Float64" |
| 172 | + else: |
| 173 | + data_types[info.name] = "string" |
| 174 | + if data_types: |
| 175 | + df = df.astype(data_types) |
| 176 | + return df |
145 | 177 |
|
146 | 178 |
|
147 | 179 | class EpiDataContext(AEpiDataEndpoints[EpiDataCall]):
|
|
0 commit comments