feat: add agentic od tools (#344)
* add agentic od tools

* Update vision_agent/tools/tools.py

Co-authored-by: Hernan Payrumani <[email protected]>

---------

Co-authored-by: Hernan Payrumani <[email protected]>
camiloaz and hrnn authored Jan 15, 2025
1 parent 8359052 commit 7e42110
Showing 4 changed files with 325 additions and 0 deletions.
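For context, a minimal usage sketch of the three new tools, assembled from the tests and docstrings in this diff (the coin image and repeated frames mirror the integration tests; any RGB np.ndarray works, and running it requires the package's API credentials):

    import numpy as np
    import skimage as ski
    from PIL import Image
    from vision_agent.tools import (
        agentic_object_detection,
        agentic_sam2_instance_segmentation,
        agentic_sam2_video_tracking,
    )

    image = ski.data.coins()

    # Detect objects matching a comma-separated text prompt; each entry has
    # 'score', 'label', and a normalized 'bbox' [xmin, ymin, xmax, ymax].
    detections = agentic_object_detection(prompt="coin", image=image)

    # Same detections, plus a binary 2D 'mask' array per instance via SAM2.
    instances = agentic_sam2_instance_segmentation(prompt="coin", image=image)

    # Track instances across video frames; returns one list per frame, with
    # labels prefixed by an instance ID (e.g. '0: coin').
    frames = [
        np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
    ]
    tracks = agentic_sam2_video_tracking(prompt="coin", frames=frames)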
77 changes: 77 additions & 0 deletions tests/integ/test_tools.py
@@ -29,6 +29,9 @@
vit_image_classification,
vit_nsfw_classification,
custom_object_detection,
agentic_object_detection,
agentic_sam2_instance_segmentation,
agentic_sam2_video_tracking,
)

FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da"
@@ -108,6 +111,80 @@ def test_owlv2_sam2_video_tracking_fine_tune_id():
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]])


def test_agentic_object_detection():
img = ski.data.coins()
result = agentic_object_detection(
prompt="coin",
image=img,
)
assert 24 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_agentic_sam2_instance_segmentation():
img = ski.data.coins()
result = agentic_sam2_instance_segmentation(
prompt="coin",
image=img,
)
assert 24 <= len(result) <= 26
assert "mask" in result[0]
assert [res["label"] for res in result] == ["coin"] * len(result)
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_agentic_object_detection_empty():
result = agentic_object_detection(
prompt="coin",
image=np.zeros((0, 0, 3)).astype(np.uint8),
)
assert result == []


def test_agentic_fine_tune_id():
img = ski.data.coins()
result = agentic_object_detection(
prompt="coin",
image=img,
fine_tune_id=FINE_TUNE_ID,
)
# this calls a fine-tuned florence2 model which is going to be worse at this task
assert 13 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_agentic_video():
frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
]
result = agentic_sam2_video_tracking(
prompt="coin",
frames=frames,
)

assert len(result) == 10
assert 24 <= len([res["label"] for res in result[0]]) <= 26
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]])


def test_agentic_sam2_video_tracking_fine_tune_id():
frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
]
# this calls a fine-tuned florence2 model which is going to be worse at this task
result = agentic_sam2_video_tracking(
prompt="coin",
frames=frames,
fine_tune_id=FINE_TUNE_ID,
)

assert len(result) == 10
assert 12 <= len([res["label"] for res in result[0]]) <= 26
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]])


def test_florence2_object_detection():
img = ski.data.coins()
result = florence2_object_detection(
3 changes: 3 additions & 0 deletions vision_agent/tools/__init__.py
@@ -64,6 +64,9 @@
vit_image_classification,
vit_nsfw_classification,
custom_object_detection,
agentic_object_detection,
agentic_sam2_instance_segmentation,
agentic_sam2_video_tracking,
)

__new_tools__ = [
244 changes: 244 additions & 0 deletions vision_agent/tools/tools.py
@@ -290,6 +290,14 @@ def _apply_object_detection(  # inner method to avoid circular importing issues.
)
function_name = "florence2_object_detection"

elif od_model == ODModels.AGENTIC:
segment_results = agentic_object_detection(
prompt=prompt,
image=segment_frames[frame_number],
fine_tune_id=fine_tune_id,
)
function_name = "agentic_object_detection"

elif od_model == ODModels.CUSTOM:
segment_results = custom_object_detection(
deployment_id=fine_tune_id,
@@ -2140,6 +2148,242 @@ def siglip_classification(image: np.ndarray, labels: List[str]) -> Dict[str, Any
return response


# agentic od tools


def _agentic_object_detection(
prompt: str,
image: np.ndarray,
image_size: Tuple[int, ...],
image_bytes: Optional[bytes] = None,
fine_tune_id: Optional[str] = None,
) -> Dict[str, Any]:
if image_bytes is None:
image_bytes = numpy_to_bytes(image)

files = [("image", image_bytes)]
payload = {
"prompts": [s.strip() for s in prompt.split(",")],
"model": "agentic",
}
metadata = {"function_name": "agentic_object_detection"}

if fine_tune_id is not None:
landing_api = LandingPublicAPI()
status = landing_api.check_fine_tuning_job(UUID(fine_tune_id))
if status is not JobStatus.SUCCEEDED:
raise FineTuneModelIsNotReady(
f"Fine-tuned model {fine_tune_id} is not ready yet"
)

# we can only execute fine-tuned models with florence2
payload = {
"prompts": payload["prompts"],
"jobId": fine_tune_id,
"model": "florence2",
}

detections = send_task_inference_request(
payload,
"text-to-object-detection",
files=files,
metadata=metadata,
)

# get the first frame
bboxes = detections[0]
bboxes_formatted = [
{
"label": bbox["label"],
"bbox": normalize_bbox(bbox["bounding_box"], image_size),
"score": bbox["score"],
}
for bbox in bboxes
]
display_data = [
{
"label": bbox["label"],
"bbox": bbox["bounding_box"],
"score": bbox["score"],
}
for bbox in bboxes
]
return {
"files": files,
"return_data": bboxes_formatted,
"display_data": display_data,
}


def agentic_object_detection(
prompt: str,
image: np.ndarray,
fine_tune_id: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""'agentic_object_detection' is a tool that can detect and count multiple objects
given a text prompt such as category names or referring expressions on images. The
categories in the text prompt are separated by commas. It returns a list of bounding
boxes with normalized coordinates, label names and associated probability scores.
Parameters:
prompt (str): The prompt to ground to the image.
image (np.ndarray): The image to ground the prompt to.
fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
fine-tuned model ID here to use it.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
bounding box of the detected objects with normalized coordinates between 0
and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
top-left and xmax and ymax are the coordinates of the bottom-right of the
bounding box.
Example
-------
>>> agentic_object_detection("car", image)
[
{'score': 0.99, 'label': 'car', 'bbox': [0.1, 0.11, 0.35, 0.4]},
{'score': 0.98, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5]},
]
"""

image_size = image.shape[:2]
if image_size[0] < 1 or image_size[1] < 1:
return []

ret = _agentic_object_detection(
prompt, image, image_size, fine_tune_id=fine_tune_id
)

_display_tool_trace(
agentic_object_detection.__name__,
{"prompts": prompt},
ret["display_data"],
ret["files"],
)
return ret["return_data"] # type: ignore


def agentic_sam2_instance_segmentation(
prompt: str, image: np.ndarray
) -> List[Dict[str, Any]]:
"""'agentic_sam2_instance_segmentation' is a tool that can detect and count multiple
instances of objects given a text prompt such as category names or referring
expressions on images. The categories in the text prompt are separated by commas. It
returns a list of bounding boxes with normalized coordinates, label names, masks
and associated probability scores.
Parameters:
prompt (str): The object that needs to be counted.
image (np.ndarray): The image that contains multiple instances of the object.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the score, label,
bounding box, and mask of the detected objects with normalized coordinates
(xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left
and xmax and ymax are the coordinates of the bottom-right of the bounding box.
The mask is a binary 2D numpy array where 1 indicates the object and 0 indicates
the background.
Example
-------
>>> agentic_sam2_instance_segmentation("flower", image)
[
{
'score': 0.49,
'label': 'flower',
'bbox': [0.1, 0.11, 0.35, 0.4],
'mask': array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
},
]
"""

od_ret = _agentic_object_detection(prompt, image, image.shape[:2])
seg_ret = _sam2(
image, od_ret["return_data"], image.shape[:2], image_bytes=od_ret["files"][0][1]
)

_display_tool_trace(
agentic_sam2_instance_segmentation.__name__,
{
"prompts": prompt,
},
seg_ret["display_data"],
seg_ret["files"],
)

return seg_ret["return_data"] # type: ignore


def agentic_sam2_video_tracking(
prompt: str,
frames: List[np.ndarray],
chunk_length: Optional[int] = 10,
fine_tune_id: Optional[str] = None,
) -> List[List[Dict[str, Any]]]:
"""'agentic_sam2_video_tracking' is a tool that can track and segment multiple
objects in a video given a text prompt such as category names or referring
expressions. The categories in the text prompt are separated by commas. It returns
a list of bounding boxes, label names, masks and associated probability scores and
is useful for tracking and counting without duplicating counts.
Parameters:
prompt (str): The prompt to ground to the image.
frames (List[np.ndarray]): The list of frames to ground the prompt to.
chunk_length (Optional[int]): The number of frames between re-runs of agentic
object detection to find new objects.
fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
fine-tuned model ID here to use it.
Returns:
List[List[Dict[str, Any]]]: A list of lists of dictionaries containing the
label, segmentation mask and bounding boxes. The outer list represents each
frame and the inner list is the entities per frame. The detected objects
have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
and ymin are the coordinates of the top-left and xmax and ymax are the
coordinates of the bottom-right of the bounding box. The mask is a binary 2D
numpy array where 1 indicates the object and 0 indicates the background.
The label names are prefixed with an instance ID, which can be used to count
unique objects across frames.
Example
-------
>>> agentic_sam2_video_tracking("dinosaur", frames)
[
[
{
'label': '0: dinosaur',
'bbox': [0.1, 0.11, 0.35, 0.4],
'mask': array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
},
],
...
]
"""

ret = od_sam2_video_tracking(
ODModels.AGENTIC,
prompt=prompt,
frames=frames,
chunk_length=chunk_length,
fine_tune_id=fine_tune_id,
)
_display_tool_trace(
agentic_sam2_video_tracking.__name__,
{},
ret["display_data"],
ret["files"],
)
return ret["return_data"] # type: ignore


def minimum_distance(
det1: Dict[str, Any], det2: Dict[str, Any], image_size: Tuple[int, int]
) -> float:
1 change: 1 addition & 0 deletions vision_agent/utils/video_tracking.py
@@ -17,6 +17,7 @@ class ODModels(str, Enum):
COUNTGD = "countgd"
FLORENCE2 = "florence2"
OWLV2 = "owlv2"
AGENTIC = "agentic"
CUSTOM = "custom"


