diff --git a/CHANGELOG.md b/CHANGELOG.md index a6864353..c520d08f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ * bug: check for None in Chipper bounding box reduction * chore: removes `install-detectron2` from the `Makefile` +* enhancement: allow table_agent to load its model from the path given in the environment variable `UNSTRUCTURED_TABLE_AGENT_MODEL_PATH` ## 0.7.24 diff --git a/test_unstructured_inference/models/test_tables.py b/test_unstructured_inference/models/test_tables.py index 2df6e157..81fb8360 100644 --- a/test_unstructured_inference/models/test_tables.py +++ b/test_unstructured_inference/models/test_tables.py @@ -1,4 +1,5 @@ import os +from unittest import mock import numpy as np import pytest @@ -14,16 +15,39 @@ skip_outside_ci = os.getenv("CI", "").lower() in {"", "false", "f", "0"} +def test_load_agent_with_env_variable(monkeypatch): + with mock.patch( + "unstructured_inference.models.tables.tables_agent", + autospec=True, + ) as mock_tables_agent: + mock_tables_agent.initialize.return_value = None + monkeypatch.setenv("UNSTRUCTURED_TABLE_AGENT_MODEL_PATH", "fake_model_path") + + tables.load_agent() + + mock_tables_agent.initialize.assert_called_once_with("fake_model_path") + + +def test_load_agent_without_env_variable(monkeypatch): + with mock.patch( + "unstructured_inference.models.tables.tables_agent", + autospec=True, + ) as mock_tables_agent: + mock_tables_agent.initialize.return_value = None + + tables.load_agent() + + mock_tables_agent.initialize.assert_called_once_with( + "microsoft/table-transformer-structure-recognition", + ) + + @pytest.fixture() def table_transformer(): tables.load_agent() return tables.tables_agent -def test_load_agent(table_transformer): - assert hasattr(table_transformer, "model") - - @pytest.fixture() def example_image(): return Image.open("./sample-docs/table-multi-row-column-cells.png").convert("RGB") @@ -567,7 +591,7 @@ def test_load_table_model_raises_when_not_available(model_path): @pytest.mark.parametrize( - "bbox1, bbox2, 
expected_result", + ("bbox1", "bbox2", "expected_result"), [ ((0, 0, 5, 5), (2, 2, 7, 7), 0.36), ((0, 0, 0, 0), (6, 6, 10, 10), 0), @@ -916,7 +940,9 @@ def test_table_prediction_output_format( ) if output_format: result = table_transformer.run_prediction( - example_image, result_format=output_format, ocr_tokens=mocked_ocr_tokens + example_image, + result_format=output_format, + ocr_tokens=mocked_ocr_tokens, ) else: result = table_transformer.run_prediction(example_image, ocr_tokens=mocked_ocr_tokens) @@ -1154,10 +1180,30 @@ def test_cells_to_html(): # | rows |sub cell 1|sub cell 2| # +----------+----------+----------+ cells = [ - {"row_nums": [0, 1], "column_nums": [0], "cell text": "two row", "column header": False}, - {"row_nums": [0], "column_nums": [1, 2], "cell text": "two cols", "column header": False}, - {"row_nums": [1], "column_nums": [1], "cell text": "sub cell 1", "column header": False}, - {"row_nums": [1], "column_nums": [2], "cell text": "sub cell 2", "column header": False}, + { + "row_nums": [0, 1], + "column_nums": [0], + "cell text": "two row", + "column header": False, + }, + { + "row_nums": [0], + "column_nums": [1, 2], + "cell text": "two cols", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "sub cell 1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [2], + "cell text": "sub cell 2", + "column header": False, + }, ] expected = ( '
two row | two ' diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py index ca7e75a4..7ea04dd2 100644 --- a/unstructured_inference/models/tables.py +++ b/unstructured_inference/models/tables.py @@ -1,5 +1,6 @@ # https://github.com/microsoft/table-transformer/blob/main/src/inference.py # https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Table%20Transformer/Using_Table_Transformer_for_table_detection_and_table_structure_recognition.ipynb +import os import xml.etree.ElementTree as ET from collections import defaultdict from pathlib import Path @@ -121,7 +122,11 @@ def load_agent(): if not hasattr(tables_agent, "model"): logger.info("Loading the Table agent ...") - tables_agent.initialize("microsoft/table-transformer-structure-recognition") + table_agent_model_path = os.environ.get("UNSTRUCTURED_TABLE_AGENT_MODEL_PATH") + if table_agent_model_path is not None: + tables_agent.initialize(table_agent_model_path) + else: + tables_agent.initialize("microsoft/table-transformer-structure-recognition") return |