Skip to content

Commit 21bc67f

Browse files
authored
rfctr: improve element typing (#2247)
In preparation for work on generalized chunking, including `chunk_by_character()` and overlap, get the `elements` module and its tests passing strict type-checking.
1 parent 76efcf4 commit 21bc67f

File tree

5 files changed

+116
-70
lines changed

5 files changed

+116
-70
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
## 0.11.4-dev7
1+
## 0.11.4-dev8
22

33
### Enhancements
44

test_unstructured/documents/test_elements.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Element,
2626
ElementMetadata,
2727
NoID,
28+
Points,
2829
RegexMetadata,
2930
Text,
3031
)
@@ -37,9 +38,13 @@ def test_text_id():
3738

3839
def test_text_uuid():
3940
text_element = Text(text="hello there!", element_id=UUID())
40-
assert len(text_element.id) == 36
41-
assert text_element.id.count("-") == 4
42-
# Test that the element is JSON serializable. This should run without an error
41+
42+
id = text_element.id
43+
44+
assert isinstance(id, str)
45+
assert len(id) == 36
46+
assert id.count("-") == 4
47+
# -- Test that the element is JSON serializable. This should run without an error --
4348
json.dumps(text_element.to_dict())
4449

4550

@@ -71,9 +76,13 @@ def test_text_element_apply_multiple_cleaners():
7176

7277

7378
def test_apply_raises_if_func_does_not_produce_string():
79+
def bad_cleaner(s: str):
80+
return 1
81+
7482
text_element = Text(text="[1] A Textbook on Crocodile Habitats")
75-
with pytest.raises(ValueError):
76-
text_element.apply(lambda s: 1)
83+
84+
with pytest.raises(ValueError, match="Cleaner produced a non-string output."):
85+
text_element.apply(bad_cleaner) # pyright: ignore[reportGeneralTypeIssues]
7786

7887

7988
@pytest.mark.parametrize(
@@ -106,22 +115,27 @@ def test_apply_raises_if_func_does_not_produce_string():
106115
],
107116
)
108117
def test_convert_coordinates_to_new_system(
109-
coordinates,
110-
orientation1,
111-
orientation2,
112-
expected_coords,
118+
coordinates: Points,
119+
orientation1: Orientation,
120+
orientation2: Orientation,
121+
expected_coords: Points,
113122
):
114123
coord1 = CoordinateSystem(100, 200)
115124
coord1.orientation = orientation1
116125
coord2 = CoordinateSystem(1000, 2000)
117126
coord2.orientation = orientation2
118127
element = Element(coordinates=coordinates, coordinate_system=coord1)
128+
119129
new_coords = element.convert_coordinates_to_new_system(coord2)
120-
for new_coord, expected_coord in zip(new_coords, expected_coords):
121-
new_coord == pytest.approx(expected_coord)
130+
131+
assert new_coords is not None
132+
for new_coord, expected in zip(new_coords, expected_coords):
133+
assert new_coord == pytest.approx(expected) # pyright: ignore[reportUnknownMemberType]
122134
element.convert_coordinates_to_new_system(coord2, in_place=True)
123-
for new_coord, expected_coord in zip(element.metadata.coordinates.points, expected_coords):
124-
assert new_coord == pytest.approx(expected_coord)
135+
assert element.metadata.coordinates is not None
136+
assert element.metadata.coordinates.points is not None
137+
for new_coord, expected in zip(element.metadata.coordinates.points, expected_coords):
138+
assert new_coord == pytest.approx(expected) # pyright: ignore[reportUnknownMemberType]
125139
assert element.metadata.coordinates.system == coord2
126140

127141

unstructured/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.11.4-dev7" # pragma: no cover
1+
__version__ = "0.11.4-dev8" # pragma: no cover

unstructured/cleaners/translate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def _validate_language_code(language_code: str):
2121
)
2222

2323

24-
def translate_text(text, source_lang: Optional[str] = None, target_lang: str = "en") -> str:
24+
def translate_text(text: str, source_lang: Optional[str] = None, target_lang: str = "en") -> str:
2525
"""Translates the foreign language text. If the source language is not specified, the
2626
function will attempt to detect it using langdetect.
2727

unstructured/documents/elements.py

Lines changed: 86 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
import re
1313
import uuid
1414
from types import MappingProxyType
15-
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple, Union
15+
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Sequence, Tuple, Union, cast
1616

17-
from typing_extensions import ParamSpec, TypedDict
17+
from typing_extensions import ParamSpec, TypeAlias, TypedDict
1818

1919
from unstructured.documents.coordinates import (
2020
TYPE_TO_COORDINATE_SYSTEM_MAP,
@@ -24,6 +24,9 @@
2424
from unstructured.partition.utils.constants import UNSTRUCTURED_INCLUDE_DEBUG_METADATA
2525
from unstructured.utils import lazyproperty
2626

27+
Point: TypeAlias = Tuple[float, float]
28+
Points: TypeAlias = Tuple[Point, ...]
29+
2730

2831
class NoID(abc.ABC):
2932
"""Class to indicate that an element does not have an ID."""
@@ -61,10 +64,10 @@ def from_dict(cls, input_dict: Dict[str, Any]):
6164
class CoordinatesMetadata:
6265
"""Metadata fields that pertain to the coordinates of the element."""
6366

64-
points: Tuple[Tuple[float, float], ...]
65-
system: CoordinateSystem
67+
points: Optional[Points]
68+
system: Optional[CoordinateSystem]
6669

67-
def __init__(self, points, system):
70+
def __init__(self, points: Optional[Points], system: Optional[CoordinateSystem]):
6871
# Both `points` and `system` must be present; one is not meaningful without the other.
6972
if (points is None and system is not None) or (points is not None and system is None):
7073
raise ValueError(
@@ -94,30 +97,38 @@ def to_dict(self):
9497
@classmethod
9598
def from_dict(cls, input_dict: Dict[str, Any]):
9699
# `input_dict` may contain a tuple of tuples or a list of lists
97-
def convert_to_tuple_of_tuples(sequence_of_sequences):
98-
subsequences = []
100+
def convert_to_points(sequence_of_sequences: Sequence[Sequence[float]]) -> Points:
101+
points: List[Point] = []
99102
for seq in sequence_of_sequences:
100103
if isinstance(seq, list):
101-
subsequences.append(tuple(seq))
104+
points.append(cast(Point, tuple(seq)))
102105
elif isinstance(seq, tuple):
103-
subsequences.append(seq)
104-
return tuple(subsequences)
105-
106-
input_points = input_dict.get("points", None)
107-
points = convert_to_tuple_of_tuples(input_points) if input_points is not None else None
108-
width = input_dict.get("layout_width", None)
109-
height = input_dict.get("layout_height", None)
110-
system = None
111-
if input_dict.get("system", None) == "RelativeCoordinateSystem":
112-
system = RelativeCoordinateSystem()
113-
elif (
114-
width is not None
115-
and height is not None
116-
and input_dict.get("system", None) in TYPE_TO_COORDINATE_SYSTEM_MAP
117-
):
118-
system = TYPE_TO_COORDINATE_SYSTEM_MAP[input_dict["system"]](width, height)
119-
constructor_args = {"points": points, "system": system}
120-
return cls(**constructor_args)
106+
points.append(cast(Point, seq))
107+
return tuple(points)
108+
109+
# -- parse points --
110+
input_points = input_dict.get("points")
111+
points = convert_to_points(input_points) if input_points is not None else None
112+
113+
# -- parse system --
114+
system_name = input_dict.get("system")
115+
width = input_dict.get("layout_width")
116+
height = input_dict.get("layout_height")
117+
system = (
118+
None
119+
if system_name is None
120+
else RelativeCoordinateSystem()
121+
if system_name == "RelativeCoordinateSystem"
122+
else TYPE_TO_COORDINATE_SYSTEM_MAP[system_name](width, height)
123+
if (
124+
width is not None
125+
and height is not None
126+
and system_name in TYPE_TO_COORDINATE_SYSTEM_MAP
127+
)
128+
else None
129+
)
130+
131+
return cls(points=points, system=system)
121132

122133

123134
class RegexMetadata(TypedDict):
@@ -637,14 +648,19 @@ def to_dict(self) -> Dict[str, Any]:
637648
}
638649

639650
def convert_coordinates_to_new_system(
640-
self,
641-
new_system: CoordinateSystem,
642-
in_place=True,
643-
) -> Optional[Tuple[Tuple[Union[int, float], Union[int, float]], ...]]:
644-
"""Converts the element location coordinates to a new coordinate system. If inplace is true,
645-
changes the coordinates in place and updates the coordinate system."""
646-
if self.metadata.coordinates is None:
651+
self, new_system: CoordinateSystem, in_place: bool = True
652+
) -> Optional[Points]:
653+
"""Converts the element location coordinates to a new coordinate system.
654+
655+
If `in_place` is true, changes the coordinates in place and updates the coordinate system.
656+
"""
657+
if (
658+
self.metadata.coordinates is None
659+
or self.metadata.coordinates.system is None
660+
or self.metadata.coordinates.points is None
661+
):
647662
return None
663+
648664
new_coordinates = tuple(
649665
self.metadata.coordinates.system.convert_coordinates_to_new_system(
650666
new_system=new_system,
@@ -653,15 +669,19 @@ def convert_coordinates_to_new_system(
653669
)
654670
for x, y in self.metadata.coordinates.points
655671
)
672+
656673
if in_place:
657674
self.metadata.coordinates.points = new_coordinates
658675
self.metadata.coordinates.system = new_system
676+
659677
return new_coordinates
660678

661679

662680
class CheckBox(Element):
663-
"""A checkbox with an attribute indicating whether its checked or not. Primarily used
664-
in documents that are forms"""
681+
"""A checkbox with an attribute indicating whether it's checked or not.
682+
683+
Primarily used in documents that are forms.
684+
"""
665685

666686
def __init__(
667687
self,
@@ -682,12 +702,18 @@ def __init__(
682702
)
683703
self.checked: bool = checked
684704

685-
def __eq__(self, other):
686-
return (self.checked == other.checked) and (
687-
self.metadata.coordinates == other.metadata.coordinates
705+
def __eq__(self, other: object) -> bool:
706+
if not isinstance(other, CheckBox):
707+
return False
708+
return all(
709+
(
710+
self.checked == other.checked,
711+
self.metadata.coordinates == other.metadata.coordinates,
712+
)
688713
)
689714

690-
def to_dict(self) -> dict:
715+
def to_dict(self) -> Dict[str, Any]:
716+
"""Serialize to JSON-compatible (str keys) dict."""
691717
out = super().to_dict()
692718
out["type"] = "CheckBox"
693719
out["checked"] = self.checked
@@ -729,20 +755,23 @@ def __init__(
729755
detection_origin=detection_origin,
730756
)
731757

732-
def __str__(self):
733-
return self.text
734-
735-
def __eq__(self, other):
758+
def __eq__(self, other: object):
759+
if not isinstance(other, Text):
760+
return False
736761
return all(
737-
[
738-
(self.text == other.text),
739-
(self.metadata.coordinates == other.metadata.coordinates),
740-
(self.category == other.category),
741-
(self.embeddings == other.embeddings),
742-
],
762+
(
763+
self.text == other.text,
764+
self.metadata.coordinates == other.metadata.coordinates,
765+
self.category == other.category,
766+
self.embeddings == other.embeddings,
767+
),
743768
)
744769

745-
def to_dict(self) -> dict:
770+
def __str__(self):
771+
return self.text
772+
773+
def to_dict(self) -> Dict[str, Any]:
774+
"""Serialize to JSON-compatible (str keys) dict."""
746775
out = super().to_dict()
747776
out["element_id"] = self.id
748777
out["type"] = self.category
@@ -751,14 +780,17 @@ def to_dict(self) -> dict:
751780
out["embeddings"] = self.embeddings
752781
return out
753782

754-
def apply(self, *cleaners: Callable):
755-
"""Applies a cleaning brick to the text element. The function that's passed in
756-
should take a string as input and produce a string as output."""
783+
def apply(self, *cleaners: Callable[[str], str]):
784+
"""Applies a cleaning brick to the text element.
785+
786+
The function that's passed in should take a string as input and produce a string as
787+
output.
788+
"""
757789
cleaned_text = self.text
758790
for cleaner in cleaners:
759791
cleaned_text = cleaner(cleaned_text)
760792

761-
if not isinstance(cleaned_text, str):
793+
if not isinstance(cleaned_text, str): # pyright: ignore[reportUnnecessaryIsInstance]
762794
raise ValueError("Cleaner produced a non-string output.")
763795

764796
self.text = cleaned_text

0 commit comments

Comments
 (0)