Commit 7eae500

Improve tests on cached multiturn conversation and fix bugs
1 parent 43f6ecd commit 7eae500

File tree

.gitignore
llmlib/llmlib/gemini/gemini_code.py
tests/helpers.py
tests/test_gemini.py

4 files changed: +91 -30 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,9 @@ models/
 # C extensions
 *.so
 
+# MacOS
+.DS_Store
+
 # Distribution / packaging
 .Python
 build/

llmlib/llmlib/gemini/gemini_code.py

Lines changed: 48 additions & 20 deletions
@@ -3,9 +3,9 @@
 """
 
 from dataclasses import dataclass
-from datetime import datetime
 from functools import singledispatchmethod
 from io import BytesIO
+import json
 from logging import getLogger
 from pathlib import Path
 import tempfile
@@ -56,7 +56,7 @@ class GeminiModels(StrEnum):
 
     gemini_15_pro = "gemini-1.5-pro"
     gemini_15_flash = "gemini-1.5-flash-002"
-    gemini_20_flash = "gemini-2.0-flash"
+    gemini_20_flash = "gemini-2.0-flash-001"
     gemini_20_flash_lite = "gemini-2.0-flash-lite-001"
 
 
@@ -115,23 +115,29 @@ def _execute_multi_turn_req(req: MultiTurnRequest) -> str:
         raise ValueError("Only the first message can have file(s)")
 
     # Prepare Inputs. Use context caching for media
-    paths = filepaths(msg=req.messages[0])
-    use_caching = req.use_context_caching and is_long_enough_to_cache(paths)
+    client = create_client()
+    contents = [convert_to_gemini_format(msg) for msg in req.messages]
+
+    files: list[Path] = filepaths(msg=req.messages[0])
+    use_caching = req.use_context_caching and is_long_enough_to_cache(files)
     if use_caching:
-        cached_content, blobs = cache_content(req.model_name, tuple(paths))
-    else:
+        # Assume caching was done before
+        cached_content, success = get_cached_content(client, req.model_name, files)
+        blobs = []
+        if not success:
+            cached_content, blobs = cache_content(client, req.model_name, files)
+    else:  # Add files to the content
+        blobs = upload_files(files=files)
+        contents = [*blobs_to_parts(blobs), *contents]
         cached_content = None
-        blobs = upload_files(files=paths)
-    contents = [convert_to_gemini_format(msg) for msg in req.messages]
 
     # Call Gemini
-    client = create_client()
     response: GenerateContentResponse = _call_gemini(
         client, req, contents, cached_content
     )
 
     # Cleanup
-    if req.delete_files_after_use and not use_caching:
+    if req.delete_files_after_use:
         delete_blobs(blobs)
     return response.text
 
@@ -161,26 +167,37 @@ def video_duration_in_sec(filename: Path) -> float:
     return duration
 
 
-@ttl_cache(ttl=60 * 60)
 def cache_content(
-    model_id: str, paths: list[Path]
+    client: genai.Client, model_id: str, paths: list[Path], ttl: str = f"{60 * 20}s"
 ) -> tuple[CachedContent, list[storage.Blob]]:
     """Caches the content on Google as describe here: https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-create"""
     logger.info("Caching content for paths: %s", paths)
-    client = create_client()
     blobs = upload_files(files=paths)
     parts = blobs_to_parts(blobs)
     content = Content(role="user", parts=parts)
-    cached_content = client.caches.create(
-        model=model_id,
-        config=CreateCachedContentConfig(
-            contents=[content],
-            display_name="multiturn cache for req at %s" % datetime.now(),
-        ),
+    config = CreateCachedContentConfig(
+        contents=[content], display_name=cache_id(model_id, paths), ttl=ttl
     )
+    cached_content = client.caches.create(model=model_id, config=config)
    return cached_content, blobs
 
 
+def cache_id(model_id: str, paths: list[Path]) -> str:
+    return json.dumps(dict(model=model_id, paths=str(paths)))
+
+
+def get_cached_content(
+    client: genai.Client, model_id: str, paths: list[Path]
+) -> tuple[CachedContent, bool]:
+    for cache in client.caches.list():
+        if cache.display_name == cache_id(model_id, paths):
+            logger.info(
+                "Found cached content for model_id='%s' and paths='%s'", model_id, paths
+            )
+            return cache, True
+    return None, False
+
+
 def convert_to_gemini_format(msg: Message) -> tuple[Content, list[storage.Blob]]:
     role_map = dict(user="user", assistant="model")
     role = role_map[msg.role]
@@ -324,6 +341,7 @@ class GeminiAPI(LLM):
     model_id: str = GeminiModels.gemini_20_flash_lite
     max_output_tokens: int = 1000
     use_context_caching: bool = False
+    delete_files_after_use: bool = True
 
     requires_gpu_exclusively = False
     model_ids = available_models
@@ -333,13 +351,23 @@ def complete_msgs(self, msgs: list[Message]) -> str:
            msg = msgs[0]
            paths = filepaths(msg)
            req = SingleTurnRequest(
-                model_name=self.model_id, media_files=paths, prompt=msg.msg
+                model_name=self.model_id,
+                media_files=paths,
+                prompt=msg.msg,
+                max_output_tokens=self.max_output_tokens,
+                delete_files_after_use=self.delete_files_after_use,
            )
        else:
+            delete_files_after_use = self.delete_files_after_use
+            if self.use_context_caching:
+                delete_files_after_use = False
+
            req = MultiTurnRequest(
                model_name=self.model_id,
                messages=msgs,
                use_context_caching=self.use_context_caching,
+                max_output_tokens=self.max_output_tokens,
+                delete_files_after_use=delete_files_after_use,
            )
        return req.fetch_media_description()
 
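The multiturn path above now reuses an existing context cache when one matches and only creates it on a miss. Below is a minimal sketch of that lookup-or-create flow, assuming the helpers introduced in this file (create_client, get_cached_content, cache_content); the video path is illustrative, not from the repo.

# Hedged sketch: reuse an existing Gemini context cache if one matches,
# otherwise create it. Helpers come from gemini_code.py in this commit.
from pathlib import Path
from llmlib.gemini.gemini_code import (
    GeminiModels,
    cache_content,
    create_client,
    get_cached_content,
)

client = create_client()
model_id = GeminiModels.gemini_15_flash
paths = [Path("videos/example.mp4")]  # illustrative path

cached_content, found = get_cached_content(client, model_id=model_id, paths=paths)
if not found:
    # Creates the cache with a display_name derived from (model_id, paths),
    # so the next get_cached_content() call finds it until the TTL expires.
    cached_content, blobs = cache_content(
        client, model_id=model_id, paths=paths, ttl="1200s"
    )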

tests/helpers.py

Lines changed: 5 additions & 5 deletions
@@ -133,26 +133,26 @@ def assert_model_supports_multiturn_with_6min_video(model: LLM):
     video = file_for_test("tasting travel - rome italy.mp4")
     convo = [Message(role="user", msg="What country are they visiting?", video=video)]
     answer1 = model.complete_msgs(convo)
-    assert "italy" in answer1.lower()
+    assert "italy" in answer1.lower(), answer1
 
     convo.append(Message(role="assistant", msg=answer1))
     convo.append(Message(role="user", msg="What food do they eat?"))
     answer2 = model.complete_msgs(convo)
-    assert "lasagna" in answer2.lower()
+    assert "lasagna" in answer2.lower(), answer2
 
     convo.append(Message(role="assistant", msg=answer2))
     convo.append(
         Message(role="user", msg="What character appears in the middle of the video?")
     )
     answer3 = model.complete_msgs(convo)
-    assert "jesus" in answer3.lower()
+    assert "jesus" in answer3.lower(), answer3
 
 
 def assert_model_supports_multiturn_with_picture(model: LLM):
     q1_msg = pyramid_message()
     a1_txt = model.complete_msgs([q1_msg])
-    assert "pyramid" in a1_txt.lower()
+    assert "pyramid" in a1_txt.lower(), a1_txt
     a1_msg = Message(role="assistant", msg=a1_txt)
     q2_msg = Message(role="user", msg="What country is the picture in?")
     a2_txt = model.complete_msgs([q1_msg, a1_msg, q2_msg])
-    assert "egypt" in a2_txt.lower()
+    assert "egypt" in a2_txt.lower(), a2_txt

tests/test_gemini.py

Lines changed: 35 additions & 5 deletions
@@ -1,4 +1,12 @@
-from llmlib.gemini.gemini_code import GeminiAPI, GeminiModels
+from pathlib import Path
+from llmlib.gemini.gemini_code import (
+    GeminiAPI,
+    GeminiModels,
+    cache_content,
+    create_client,
+    get_cached_content,
+)
+from google.genai.types import CachedContent
 import pytest
 
 from tests.helpers import (
@@ -8,6 +16,7 @@
     assert_model_supports_multiturn,
     assert_model_supports_multiturn_with_6min_video,
     assert_model_supports_multiturn_with_picture,
+    file_for_test,
     is_ci,
 )
 
@@ -21,31 +30,52 @@ def test_gemini_vision_using_interface():
 
 
 @pytest.mark.skipif(condition=is_ci(), reason="Avoid costs")
-def test_multiturn_conversation():
+def test_multiturn_textonly_conversation():
     model = GeminiAPI(model_id=GeminiModels.gemini_20_flash_lite, max_output_tokens=50)
     assert_model_supports_multiturn(model)
 
 
 @pytest.mark.skipif(condition=is_ci(), reason="Avoid costs")
 @pytest.mark.parametrize("use_context_caching", [False, True])
-def test_multiturn_conversation_with_file(use_context_caching: bool):
+def test_multiturn_conversation_with_img(use_context_caching: bool):
     model = GeminiAPI(
         model_id=GeminiModels.gemini_20_flash_lite,
         max_output_tokens=50,
         use_context_caching=use_context_caching,
+        delete_files_after_use=False,
     )
     assert_model_supports_multiturn_with_picture(model)
 
 
 @pytest.mark.skipif(condition=is_ci(), reason="Avoid costs")
-def test_multiturn_conversation_with_file_and_context_caching():
+@pytest.mark.parametrize("use_context_caching", [False, True])
+def test_multiturn_conversation_with_6min_video_and_context_caching(
+    use_context_caching: bool,
+):
     """
     Context caching is supported only for Gemini 1.5 Pro and Flash
     https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-overview#supported_models
     """
     model = GeminiAPI(
         model_id=GeminiModels.gemini_15_flash,
         max_output_tokens=50,
-        use_context_caching=True,
+        use_context_caching=use_context_caching,
+        delete_files_after_use=False,
     )
     assert_model_supports_multiturn_with_6min_video(model)
+
+
+def test_get_cached_content():
+    """We can cache content and reuse the cache later"""
+    path: Path = file_for_test("tasting travel - rome italy.mp4")
+    client = create_client()
+    model_id = GeminiModels.gemini_15_flash
+    _, success = get_cached_content(client, model_id=model_id, paths=[path])
+    assert not success
+
+    cache_content(client, model_id=model_id, paths=[path], ttl="60s")
+    cached_content, success = get_cached_content(
+        client, model_id=model_id, paths=[path]
+    )
+    assert success
+    assert isinstance(cached_content, CachedContent)
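For reference, the lookup in test_get_cached_content matches on the cache display_name built by cache_id(): a JSON string over the model id and str(paths). A small illustration of that key format, assuming the cache_id implementation from this commit and a POSIX path representation:

import json
from pathlib import Path

def cache_id(model_id: str, paths: list[Path]) -> str:
    # Same construction as in gemini_code.py: JSON over model id and str(paths)
    return json.dumps(dict(model=model_id, paths=str(paths)))

print(cache_id("gemini-1.5-flash-002", [Path("tasting travel - rome italy.mp4")]))
# {"model": "gemini-1.5-flash-002", "paths": "[PosixPath('tasting travel - rome italy.mp4')]"}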
