Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 42 additions & 32 deletions docext/benchmark/metrics/grits.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from difflib import SequenceMatcher

import numpy as np
from fitz import Rect
from shapely.geometry import box



def compute_fscore(num_true_positives, num_true, num_positives):
Expand Down Expand Up @@ -235,17 +236,14 @@ def iou(bbox1, bbox2):
"""
Compute the intersection-over-union of two bounding boxes.
"""
intersection = Rect(x0=bbox1[0], y0=bbox1[1], x1=bbox1[2], y1=bbox1[3]).intersect(
Rect(x0=bbox2[0], y0=bbox2[1], x1=bbox2[2], y1=bbox2[3])
)
union = Rect(x0=bbox1[0], y0=bbox1[1], x1=bbox1[2], y1=bbox1[3]).include_rect(
Rect(x0=bbox2[0], y0=bbox2[1], x1=bbox2[2], y1=bbox2[3])
)

union_area = union.get_area()
if union_area > 0:
return intersection.get_area() / union.get_area()

box1 = box(bbox1[0], bbox1[1], bbox1[2], bbox1[3])
box2 = box(bbox2[0], bbox2[1], bbox2[2], bbox2[3])

intersection = box1.intersection(box2).area
union = box1.union(box2).area

if union > 0:
return intersection / union
return 0


Expand Down Expand Up @@ -309,24 +307,24 @@ def get_spanning_cell_rows_and_columns(spanning_cells, rows, columns):
row_matches = set()
column_matches = set()
for row_num, row in enumerate(rows):
bbox1 = [
bbox1 = box(
spanning_cell["bbox"][0],
row["bbox"][1],
spanning_cell["bbox"][2],
row["bbox"][3],
]
bbox2 = Rect(spanning_cell["bbox"]).intersect(bbox1)
if bbox2.get_area() / Rect(bbox1).get_area() >= 0.5:
row["bbox"][3]
)
bbox2 = box(*spanning_cell["bbox"]).intersection(bbox1)
if bbox2.area / bbox1.area >= 0.5:
row_matches.add(row_num)
for column_num, column in enumerate(columns):
bbox1 = [
bbox1 = box(
column["bbox"][0],
spanning_cell["bbox"][1],
column["bbox"][2],
spanning_cell["bbox"][3],
]
bbox2 = Rect(spanning_cell["bbox"]).intersect(bbox1)
if bbox2.get_area() / Rect(bbox1).get_area() >= 0.5:
spanning_cell["bbox"][3]
)
bbox2 = box(*spanning_cell["bbox"]).intersection(bbox1)
if bbox2.area / bbox1.area >= 0.5:
column_matches.add(column_num)
already_taken = False
this_matches = []
Expand All @@ -341,13 +339,23 @@ def get_spanning_cell_rows_and_columns(spanning_cells, rows, columns):
matches_by_spanning_cell.append(this_matches)
row_nums = [elem[0] for elem in this_matches]
column_nums = [elem[1] for elem in this_matches]
row_rect = Rect()
for row_num in row_nums:
row_rect.include_rect(rows[row_num]["bbox"])
column_rect = Rect()
for column_num in column_nums:
column_rect.include_rect(columns[column_num]["bbox"])
spanning_cell["bbox"] = list(row_rect.intersect(column_rect))

# Create union of all row boxes
row_boxes = [box(*rows[row_num]["bbox"]) for row_num in row_nums]
row_union = row_boxes[0]
for b in row_boxes[1:]:
row_union = row_union.union(b)

# Create union of all column boxes
column_boxes = [box(*columns[column_num]["bbox"]) for column_num in column_nums]
column_union = column_boxes[0]
for b in column_boxes[1:]:
column_union = column_union.union(b)

# Get intersection of row and column unions
intersection = row_union.intersection(column_union)
spanning_cell["bbox"] = [intersection.bounds[0], intersection.bounds[1],
intersection.bounds[2], intersection.bounds[3]]
else:
matches_by_spanning_cell.append([])

Expand Down Expand Up @@ -380,15 +388,17 @@ def output_to_dilatedbbox_grid(bboxes, labels, scores):
for row_num, row in enumerate(rows):
column_grid = []
for column_num, column in enumerate(columns):
bbox = Rect(row["bbox"]).intersect(column["bbox"])
column_grid.append(list(bbox))
bbox = box(row["bbox"][0], row["bbox"][1], row["bbox"][2], row["bbox"][3]).intersection(
box(column["bbox"][0], column["bbox"][1], column["bbox"][2], column["bbox"][3])
)
column_grid.append(list(bbox.bounds))
cell_grid.append(column_grid)
matches_by_spanning_cell = get_spanning_cell_rows_and_columns(
spanning_cells, rows, columns
)
for matches, spanning_cell in zip(matches_by_spanning_cell, spanning_cells):
for match in matches:
cell_grid[match[0]][match[1]] = spanning_cell["bbox"]
cell_grid[match[0]][match[1]] = list(spanning_cell["bbox"])

return cell_grid

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ loguru
mdpd
numpy
pandas
PyMuPDF
python-dotenv
python-levenshtein==0.27.1
requests
setuptools
shapely
tabulate
tenacity
types-requests
Expand Down