Skip to content

Commit 9f4e022

Browse files
committed
Image Payload changes
1 parent 8b233e4 commit 9f4e022

File tree

4 files changed

+470
-58
lines changed

4 files changed

+470
-58
lines changed

forte/common/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
# The index storing entry type in the internal entry data of DataStore.
1212
ENTRY_TYPE_INDEX = 3
1313

14+
# The index storing the payload ID in internal entry data of DataStore
15+
PAYLOAD_INDEX = 0
16+
1417
# The index storing entry type (specific to Link and Group type). It is saved
1518
# in the `tid_idx_dict` in DataStore.
1619
ENTRY_DICT_TYPE_INDEX = 0

forte/data/ontology/top.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
PARENT_TID_INDEX,
4444
CHILD_TID_INDEX,
4545
MEMBER_TID_INDEX,
46+
PAYLOAD_INDEX,
4647
)
4748

4849
__all__ = [
@@ -72,8 +73,15 @@
7273
make sure it available across the ontology system:
7374
1. Create a new top level class that inherits from `Entry` or `MultiEntry`
7475
2. Add the new class to `SinglePackEntries` or `MultiPackEntries`
75-
3. Register a new method in `DataStore`: `add_<new_entry>_raw()`
76-
4. Insert a new conditional branch in `EntryConverter.save_entry_object()`
76+
3. Insert a new conditional branch in `EntryConverter.save_entry_object()`
77+
4. Decide two main attributes which will qualify as your `attribute_data`
78+
parameters. These parameters will be passes in your branch of
79+
`EntryConverter.save_entry_object()`. If there are no such parameters,
80+
you can pass None
81+
5. add `getter` and `setter` functions to update `attribute_data` parameters
82+
if you have any
83+
6. If additional attributes are required, make the class a `dataclass` and set
84+
`dataclass` attributes.
7785
"""
7886

7987

@@ -879,7 +887,9 @@ def image(self):
879887
"Cannot get image because image annotation is not "
880888
"attached to any data pack."
881889
)
882-
return self.pack.get_image_array(self._image_payload_idx)
890+
return self.pack.get_payload_data_at(
891+
Modality.Image, self._image_payload_idx
892+
)
883893

884894
@property
885895
def max_x(self):
@@ -1052,12 +1062,35 @@ def __init__(self, pack: PackType, image_payload_idx: int = 0):
10521062
else:
10531063
self._image_payload_idx = image_payload_idx
10541064

1065+
@property
1066+
def image_payload_idx(self):
1067+
r"""Getter function of ``image_payload_idx``. The function will first try to
1068+
retrieve the image_payload_idx index from ``DataStore`` in ``self.pack``. If
1069+
this attempt fails, it will directly return the value in ``_image_payload_idx``.
1070+
"""
1071+
try:
1072+
self._image_payload_idx = self.pack.get_entry_raw(self.tid)[
1073+
PAYLOAD_INDEX
1074+
]
1075+
except KeyError:
1076+
pass
1077+
return self._image_payload_idx
1078+
1079+
@image_payload_idx.setter
1080+
def image_payload_idx(self, val: int):
1081+
r"""Setter function of ``image_payload_idx``. The update will also be populated
1082+
into ``DataStore`` in ``self.pack``.
1083+
"""
1084+
self._image_payload_idx = val
1085+
self.pack.get_entry_raw(self.tid)[PAYLOAD_INDEX] = val
1086+
10551087
def compute_iou(self, other) -> int:
10561088
intersection = np.sum(np.logical_and(self.image, other.image))
10571089
union = np.sum(np.logical_or(self.image, other.image))
10581090
return intersection / union
10591091

10601092

1093+
@dataclass
10611094
class Box(Region):
10621095
"""
10631096
A box class with a center position and a box configuration.
@@ -1078,13 +1111,18 @@ class Box(Region):
10781111
width: the width of the box, the unit is one image array entry.
10791112
"""
10801113

1114+
_cy: int
1115+
_cx: int
1116+
_height: int
1117+
_width: int
1118+
10811119
def __init__(
10821120
self,
10831121
pack: PackType,
1084-
cy: int,
1085-
cx: int,
1086-
height: int,
1087-
width: int,
1122+
cy: int = 0,
1123+
cx: int = 0,
1124+
height: int = 1,
1125+
width: int = 1,
10881126
image_payload_idx: int = 0,
10891127
):
10901128
# assume Box is associated with Grids
@@ -1180,6 +1218,7 @@ def compute_iou(self, other):
11801218
return intersection / union
11811219

11821220

1221+
@dataclass
11831222
class BoundingBox(Box):
11841223
"""
11851224
A bounding box class that associates with image payload and grids and
@@ -1208,15 +1247,17 @@ class BoundingBox(Box):
12081247
12091248
"""
12101249

1250+
_grid_id: int
1251+
12111252
def __init__(
12121253
self,
12131254
pack: PackType,
1214-
height: int,
1215-
width: int,
1216-
grid_height: int,
1217-
grid_width: int,
1218-
grid_cell_h_idx: int,
1219-
grid_cell_w_idx: int,
1255+
height: int = 1,
1256+
width: int = 1,
1257+
grid_height: int = 1,
1258+
grid_width: int = 1,
1259+
grid_cell_h_idx: int = 0,
1260+
grid_cell_w_idx: int = 0,
12201261
image_payload_idx: int = 0,
12211262
):
12221263
self.grids = Grids(pack, grid_height, grid_width, image_payload_idx)
@@ -1228,6 +1269,8 @@ def __init__(
12281269
image_payload_idx,
12291270
)
12301271

1272+
self._grid_id = self.grids.tid
1273+
12311274

12321275
class Payload(Entry):
12331276
"""

0 commit comments

Comments
 (0)