Skip to content

Commit 312eeb4

Browse files
committed
Make compatible with Python 3.10 on Lightning AI workspace
1 parent 1efc64a commit 312eeb4

File tree

7 files changed

+27
-19
lines changed

7 files changed

+27
-19
lines changed

llmlib/llmlib/base_llm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from io import BytesIO
22
from pathlib import Path
3-
from typing import Literal, Self
3+
from typing import Literal
4+
from typing_extensions import Self
45
from PIL import Image
56

67

llmlib/llmlib/gemini/gemini_code.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from dataclasses import dataclass
6+
from enum import Enum
67
from functools import singledispatchmethod
78
from io import BytesIO
89
import json
@@ -23,7 +24,6 @@
2324
)
2425
import cv2
2526
from google import genai
26-
from enum import StrEnum
2727
from ..base_llm import LLM, Message, validate_only_first_message_has_files
2828
from ..error_handling import notify_bugsnag
2929

@@ -43,7 +43,7 @@ def storage_uri(bucket: str, blob_name: str) -> str:
4343
return "gs://%s/%s" % (bucket, blob_name)
4444

4545

46-
class GeminiModels(StrEnum):
46+
class GeminiModels(str, Enum):
4747
"""
4848
The 3 trailing digits indicate the stable version
4949
https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versions#stable-version

llmlib/llmlib/huggingface_inference.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
from pathlib import Path
66
from dataclasses import dataclass
77
import PIL
8-
from enum import StrEnum
9-
8+
from enum import Enum
109
import openai
1110
from .base_llm import LLM, Message, validate_only_first_message_has_files
1211
import cv2
@@ -21,8 +20,12 @@ def get_image_as_base64(image_bytes: bytes):
2120
return base64.b64encode(image_bytes).decode("utf-8")
2221

2322

24-
def convert_message_to_hf_format(message: Message, max_n_frames_per_video: int) -> dict:
25-
"""Convert a Message to HuggingFace chat format."""
23+
def convert_message_to_openai_format(message: Message, max_n_frames_per_video: int) -> dict:
24+
"""
25+
Convert a Message to OpenAI chat format.
26+
Images become base64 encoded strings.
27+
Videos are processed like a list of images, each of which becomes a base64 encoded string.
28+
"""
2629
content = []
2730

2831
# Add text content if present
@@ -54,8 +57,9 @@ def convert_message_to_hf_format(message: Message, max_n_frames_per_video: int)
5457

5558

5659
def video_to_imgs(video_path: Path, max_n_frames: int) -> list[PIL.Image.Image]:
57-
assert isinstance(video_path, Path), video_path
5860
"""From https://github.com/agustoslu/simple-inference-benchmark/blob/5cec55787d34af65f0d11efc429c3d4de92f051a/utils.py#L79"""
61+
assert isinstance(video_path, Path), video_path
62+
assert video_path.exists(), video_path
5963
cap = cv2.VideoCapture(str(video_path))
6064
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
6165
fps = int(cap.get(cv2.CAP_PROP_FPS))
@@ -129,7 +133,7 @@ def extract_bytes(img: PIL.Image.Image | str | Path) -> bytes:
129133
raise ValueError(f"Unsupported image type: {type(img)}")
130134

131135

132-
class HuggingFaceVLMs(StrEnum):
136+
class HuggingFaceVLMs(str, Enum):
133137
gemma_3_27b_it = "google/gemma-3-27b-it"
134138

135139

@@ -171,7 +175,7 @@ def complete_msgs(self, msgs: list[Message]) -> str:
171175
"""Complete a conversation with the model."""
172176
validate_only_first_message_has_files(msgs)
173177
hf_messages = [
174-
convert_message_to_hf_format(
178+
convert_message_to_openai_format(
175179
msg, max_n_frames_per_video=self.max_n_frames_per_video
176180
)
177181
for msg in msgs

llmlib/llmlib/openai/openai_completion.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
_default_model = "gpt-4o-mini"
88

9-
client = OpenAI() # must be outside of the class to avoid pickling issues
9+
10+
def create_client() -> OpenAI:
11+
return OpenAI() # must be outside of the class to avoid pickling issues
1012

1113

1214
class OpenAIModel(LLM):
@@ -43,6 +45,7 @@ def complete(model: str, prompt: str) -> str:
4345

4446

4547
def complete_msgs(model: str, messages: list[dict]) -> str:
48+
client = create_client()
4649
completion: ChatCompletion = client.chat.completions.create(
4750
model=model, temperature=0.0, messages=messages
4851
)

llmlib/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ authors = [
66
{name = "Tomas Ruiz", email = "[email protected]"}
77
]
88
readme = "README.md"
9-
requires-python = ">=3.11"
9+
requires-python = ">=3.10"
1010
dependencies = [
1111
"bugsnag>=4.7.1",
1212
"decorator>=5.1.1",

tests/test_huggingface_vlm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import cv2
2-
from llmlib.huggingface_inference import convert_message_to_hf_format
2+
from llmlib.huggingface_inference import convert_message_to_openai_format
33
import pytest
44
from llmlib.huggingface_inference import HuggingFaceVLM, HuggingFaceVLMs
55
from .helpers import (
@@ -66,14 +66,14 @@ def test_huggingface_vlm_multi_turn_with_6min_video(gemma3):
6666

6767

6868
@pytest.mark.skipif(condition=is_ci(), reason="Files are not available on CI")
69-
def test_convert_to_huggingface_format():
69+
def test_convert_to_openai_format():
7070
img_msg1 = pyramid_message(load_img=True)
7171
img_msg2 = pyramid_message(load_img=False)
7272
max_n_frames_per_video = 200
73-
b64_enc1 = convert_message_to_hf_format(img_msg1, max_n_frames_per_video)[
73+
b64_enc1 = convert_message_to_openai_format(img_msg1, max_n_frames_per_video)[
7474
"content"
7575
][1]["image_url"]["url"]
76-
b64_enc2 = convert_message_to_hf_format(img_msg2, max_n_frames_per_video)[
76+
b64_enc2 = convert_message_to_openai_format(img_msg2, max_n_frames_per_video)[
7777
"content"
7878
][1]["image_url"]["url"]
7979
# assert b64_enc1 == b64_enc2
@@ -84,5 +84,5 @@ def test_convert_to_huggingface_format():
8484
cv2.imwrite(file_for_test("generated_pyramid_2.jpeg"), array2)
8585

8686
msg = video_message()
87-
hf_msg = convert_message_to_hf_format(msg, max_n_frames_per_video)
87+
hf_msg = convert_message_to_openai_format(msg, max_n_frames_per_video)
8888
assert len(hf_msg["content"]) > 10

tests/test_model_registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55

66
@dataclass
7-
class TestLLM(LLM):
7+
class LLMForTest(LLM):
88
model_id: str
99
model_ids = ["id1", "id2"]
1010

1111

1212
def test_model_entries_from_mult_ids():
13-
e1, e2 = model_entries_from_mult_ids(TestLLM)
13+
e1, e2 = model_entries_from_mult_ids(LLMForTest)
1414
assert e1.model_id == "id1"
1515
assert e2.model_id == "id2"
1616

0 commit comments

Comments
 (0)