Skip to content

Commit

Permalink
Use overloading and type[] for typing instead of custom enum
Browse files Browse the repository at this point in the history
  • Loading branch information
lucc committed Jan 3, 2025
1 parent 8d61c33 commit 8686e3b
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 56 deletions.
53 changes: 23 additions & 30 deletions khard/contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@

from . import address_book # pylint: disable=unused-import # for type checking
from . import helpers
from .helpers.typing import (Date, ObjectType, PostAddress, StrList,
convert_to_vcard, list_to_string, string_to_date,
string_to_list)
from .helpers.typing import (Date, PostAddress, StrList, convert_to_vcard,
list_to_string, string_to_date, string_to_list)
from .query import AnyQuery, Query


Expand Down Expand Up @@ -271,7 +270,7 @@ def version(self, value: str) -> None:
# for version 4 but also makes sense for all other versions.
self._delete_vcard_object("VERSION")
version = self.vcard.add("version")
version.value = convert_to_vcard("version", value, ObjectType.str)
version.value = convert_to_vcard("version", value, str)

@property
def uid(self) -> str:
Expand All @@ -283,7 +282,7 @@ def uid(self, value: str) -> None:
# for version 4 but also makes sense for all other versions.
self._delete_vcard_object("UID")
uid = self.vcard.add('uid')
uid.value = convert_to_vcard("uid", value, ObjectType.str)
uid.value = convert_to_vcard("uid", value, str)

def _update_revision(self) -> None:
"""Generate a new REV field for the vCard, replace any existing
Expand Down Expand Up @@ -413,7 +412,7 @@ def _get_new_group(self, group_type: str = "") -> str:
def _add_labelled_property(
self, property: str, value: StrList, label: Optional[str] = None,
name_groups: bool = False,
allowed_object_type: ObjectType = ObjectType.str) -> None:
allowed_object_type: Union[None, type[str], type[list]] = str) -> None:
"""Add an object to the VCARD. If a label is given it will be added to
a group with an ABLABEL.
Expand Down Expand Up @@ -501,7 +500,7 @@ def formatted_name(self, value: str) -> None:
"""
self._delete_vcard_object("FN")
if value:
final = convert_to_vcard("FN", value, ObjectType.str)
final = convert_to_vcard("FN", value, str)
elif self._get_first_names() or self._get_last_names():
# autofill the FN field from the N field
names = [self._get_name_prefixes(), self._get_first_names(),
Expand Down Expand Up @@ -593,12 +592,11 @@ def _add_name(self, prefix: StrList, first_name: StrList,
"""
name_obj = self.vcard.add('n')
name_obj.value = vobject.vcard.Name(
prefix=convert_to_vcard("name prefix", prefix, ObjectType.both),
given=convert_to_vcard("first name", first_name, ObjectType.both),
additional=convert_to_vcard("additional name", additional_name,
ObjectType.both),
family=convert_to_vcard("last name", last_name, ObjectType.both),
suffix=convert_to_vcard("name suffix", suffix, ObjectType.both))
prefix=convert_to_vcard("name prefix", prefix, None),
given=convert_to_vcard("first name", first_name, None),
additional=convert_to_vcard("additional name", additional_name, None),
family=convert_to_vcard("last name", last_name, None),
suffix=convert_to_vcard("name suffix", suffix, None))

@property
def organisations(self) -> list[Union[list[str], dict[str, list[str]]]]:
Expand All @@ -613,8 +611,7 @@ def _add_organisation(self, organisation: StrList, label: Optional[str] = None)
:param organisation: the value to add
:param label: an optional label to add
"""
self._add_labelled_property("org", organisation, label, True,
ObjectType.list)
self._add_labelled_property("org", organisation, label, True, list)
# check if fn attribute is already present
if not self.vcard.getChildValue("fn") and self.organisations:
# if not, set fn to organisation name
Expand Down Expand Up @@ -680,8 +677,7 @@ def _add_category(self, categories: list[str]) -> None:
:param categories:
"""
categories_obj = self.vcard.add('categories')
categories_obj.value = convert_to_vcard("category", categories,
ObjectType.list)
categories_obj.value = convert_to_vcard("category", categories, list)

@property
def phone_numbers(self) -> dict[str, list[str]]:
Expand Down Expand Up @@ -725,13 +721,12 @@ def _add_phone_number(self, type: str, number: str) -> None:
phone_obj = self.vcard.add('tel')
if self.version == "4.0":
phone_obj.value = "tel:{}".format(
convert_to_vcard("phone number", number, ObjectType.str))
convert_to_vcard("phone number", number, str))
phone_obj.params['VALUE'] = ["uri"]
if pref > 0:
phone_obj.params['PREF'] = str(pref)
else:
phone_obj.value = convert_to_vcard("phone number", number,
ObjectType.str)
phone_obj.value = convert_to_vcard("phone number", number, str)
if pref > 0:
standard_types.append("pref")
if standard_types:
Expand Down Expand Up @@ -777,8 +772,7 @@ def add_email(self, type: str, address: str) -> None:
"than one custom label: " +
list_to_string(custom_types, ", "))
email_obj = self.vcard.add('email')
email_obj.value = convert_to_vcard("email address", address,
ObjectType.str)
email_obj.value = convert_to_vcard("email address", address, str)
if self.version == "4.0":
if pref > 0:
email_obj.params['PREF'] = str(pref)
Expand Down Expand Up @@ -877,14 +871,13 @@ def _add_post_address(self, type: str, box: StrList, extended: StrList,
"label: " + list_to_string(custom_types, ", "))
adr_obj = self.vcard.add('adr')
adr_obj.value = vobject.vcard.Address(
box=convert_to_vcard("box address field", box, ObjectType.both),
extended=convert_to_vcard("extended address field", extended,
ObjectType.both),
street=convert_to_vcard("street", street, ObjectType.both),
code=convert_to_vcard("post code", code, ObjectType.both),
city=convert_to_vcard("city", city, ObjectType.both),
region=convert_to_vcard("region", region, ObjectType.both),
country=convert_to_vcard("country", country, ObjectType.both))
box=convert_to_vcard("box address field", box, None),
extended=convert_to_vcard("extended address field", extended, None),
street=convert_to_vcard("street", street, None),
code=convert_to_vcard("post code", code, None),
city=convert_to_vcard("city", city, None),
region=convert_to_vcard("region", region, None),
country=convert_to_vcard("country", country, None))
if self.version == "4.0":
if pref > 0:
adr_obj.params['PREF'] = str(pref)
Expand Down
26 changes: 12 additions & 14 deletions khard/helpers/typing.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
"""Helper code for type annotations and runtime type conversion."""

from datetime import datetime
from enum import Enum
from typing import Union


class ObjectType(Enum):
str = 1
list = 2
both = 3
from typing import Union, overload


# some type aliases
Expand All @@ -17,8 +10,13 @@ class ObjectType(Enum):
PostAddress = dict[str, str]


def convert_to_vcard(name: str, value: StrList, constraint: ObjectType
) -> StrList:
@overload
def convert_to_vcard(name: str, value: StrList, constraint: type[str]) -> str: ...
@overload
def convert_to_vcard(name: str, value: StrList, constraint: type[list]) -> list[str]: ...
@overload
def convert_to_vcard(name: str, value: StrList, constraint: None) -> StrList: ...
def convert_to_vcard(name: str, value: StrList, constraint: Union[None, type[str], type[list]]) -> StrList:
"""converts user input into vCard compatible data structures
:param name: object name, only required for error messages
Expand All @@ -27,19 +25,19 @@ def convert_to_vcard(name: str, value: StrList, constraint: ObjectType
:returns: cleaned user input, ready for vCard or a ValueError
"""
if isinstance(value, str):
if constraint == ObjectType.list:
if constraint is list:
return [value.strip()]
return value.strip()
if isinstance(value, list):
if constraint == ObjectType.str:
if constraint is str:
raise ValueError(f"{name} must contain a string.")
if not all(isinstance(entry, str) for entry in value):
raise ValueError(f"{name} must not contain a nested list")
# filter out empty list items and strip leading and trailing space
return [x.strip() for x in value if x.strip()]
if constraint == ObjectType.str:
if constraint is str:
raise ValueError(f"{name} must be a string.")
if constraint == ObjectType.list:
if constraint is list:
raise ValueError(f"{name} must be a list with strings.")
raise ValueError(f"{name} must be a string or a list with strings.")

Expand Down
17 changes: 8 additions & 9 deletions test/test_helpers_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import unittest

from khard.helpers.typing import (
ObjectType,
convert_to_vcard,
list_to_string,
string_to_date,
Expand All @@ -14,42 +13,42 @@
class ConvertToVcard(unittest.TestCase):
def test_returns_strings(self):
value = "some text"
actual = convert_to_vcard("test", value, ObjectType.str)
actual = convert_to_vcard("test", value, str)
self.assertEqual(value, actual)

def test_returns_lists(self):
value = ["some", "text"]
actual = convert_to_vcard("test", value, ObjectType.list)
actual = convert_to_vcard("test", value, list)
self.assertListEqual(value, actual)

def test_fail_if_not_string(self):
value = ["some", "text"]
with self.assertRaises(ValueError):
convert_to_vcard("test", value, ObjectType.str)
convert_to_vcard("test", value, str)

def test_upgrades_string_to_list(self):
value = "some text"
actual = convert_to_vcard("test", value, ObjectType.list)
actual = convert_to_vcard("test", value, list)
self.assertListEqual([value], actual)

def test_fails_if_string_lists_are_not_homogeneous(self):
value = ["some", ["nested", "list"]]
with self.assertRaises(ValueError):
convert_to_vcard("test", value, ObjectType.list)
convert_to_vcard("test", value, list)

def test_empty_list_items_are_filtered(self):
value = ["some", "", "text", "", "more text"]
actual = convert_to_vcard("test", value, ObjectType.list)
actual = convert_to_vcard("test", value, list)
self.assertListEqual(["some", "text", "more text"], actual)

def test_strings_are_stripped(self):
value = " some text "
actual = convert_to_vcard("test", value, ObjectType.str)
actual = convert_to_vcard("test", value, str)
self.assertEqual("some text", actual)

def test_strings_in_lists_are_stripped(self):
value = [" some ", " text "]
actual = convert_to_vcard("test", value, ObjectType.list)
actual = convert_to_vcard("test", value, list)
self.assertListEqual(["some", "text"], actual)


Expand Down
5 changes: 2 additions & 3 deletions test/test_vcard_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import vobject

from khard.contacts import VCardWrapper
from khard.helpers.typing import ObjectType

from .helpers import vCard, TestVCardWrapper

Expand Down Expand Up @@ -524,7 +523,7 @@ def test_add_several_strings(self):
def test_add_a_list_of_strings(self):
with self.assertTitle([["foo","bar"]]) as wrapper:
wrapper._add_labelled_property("title", ["foo", "bar"],
allowed_object_type=ObjectType.list)
allowed_object_type=list)

def test_add_string_with_label(self):
with self.assertTitle([{"foo": "bar"}]) as wrapper:
Expand All @@ -543,7 +542,7 @@ def test_add_strings_with_different_label(self):
def test_add_a_list_with_label(self):
with self.assertTitle([{"foo": ["bar", "baz"]}]) as wrapper:
wrapper._add_labelled_property("title", ["bar", "baz"], "foo",
allowed_object_type=ObjectType.list)
allowed_object_type=list)


class GetFirst(unittest.TestCase):
Expand Down

0 comments on commit 8686e3b

Please sign in to comment.