Skip to content

Commit 617b1ff

Browse files
committed
refactor: add parse_user_date_or_week, remove unneeded cast
1 parent b773b3f commit 617b1ff

File tree

1 file changed

+41
-6
lines changed

1 file changed

+41
-6
lines changed

epidatpy/_parse.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import date, datetime
2-
from typing import Callable, Optional, Sequence, Set, Union, cast
2+
from typing import Callable, Literal, Optional, Sequence, Set, Union
33

44
from epiweeks import Week
55

@@ -14,23 +14,58 @@ def parse_api_date(value: Union[str, int, float, None]) -> Optional[date]:
1414
def parse_api_week(value: Union[str, int, float, None]) -> Optional[date]:
1515
if value is None:
1616
return None
17-
return cast(date, Week.fromstring(str(value)).startdate())
17+
return Week.fromstring(str(value)).startdate()
1818

1919

2020
def parse_api_date_or_week(value: Union[str, int, float, None]) -> Optional[date]:
2121
if value is None:
2222
return None
2323
v = str(value)
2424
if len(v) == 6:
25-
d = cast(date, Week.fromstring(v).startdate())
25+
d = Week.fromstring(v).startdate()
2626
else:
2727
d = datetime.strptime(v, "%Y%m%d").date()
2828
return d
2929

3030

31-
def fields_to_predicate(
32-
fields: Optional[Sequence[str]] = None,
33-
) -> Callable[[str], bool]:
31+
def parse_user_date_or_week(
32+
value: Union[str, int, date, Week], out_type: Literal["day", "week", None] = None
33+
) -> Union[date, Week]:
34+
if isinstance(value, Week):
35+
if out_type == "day":
36+
return value.startdate()
37+
return value
38+
39+
if isinstance(value, date):
40+
if out_type == "week":
41+
return Week.fromdate(value)
42+
return value
43+
44+
value = str(value)
45+
if out_type == "week":
46+
if len(value) == 6:
47+
return Week.fromstring(value)
48+
if len(value) == 8:
49+
return Week.fromdate(datetime.strptime(value, "%Y%m%d").date())
50+
if len(value) == 10:
51+
return Week.fromdate(datetime.strptime(value, "%Y-%m-%d").date())
52+
if out_type == "day":
53+
if len(value) == 8:
54+
return datetime.strptime(value, "%Y%m%d").date()
55+
if len(value) == 10:
56+
return datetime.strptime(value, "%Y-%m-%d").date()
57+
if out_type is None:
58+
if len(value) == 6:
59+
return Week.fromstring(value)
60+
if len(value) == 8:
61+
return datetime.strptime(value, "%Y%m%d").date()
62+
if len(value) == 10:
63+
return datetime.strptime(value, "%Y-%m-%d").date()
64+
65+
raise ValueError(f"Cannot parse date or week from {value}")
66+
67+
68+
def fields_to_predicate(fields: Optional[Sequence[str]] = None) -> Callable[[str], bool]:
3469
if not fields:
3570
return lambda _: True
3671
to_include: Set[str] = set()

0 commit comments

Comments
 (0)