33"""
44
55from dataclasses import dataclass
6- from datetime import datetime
76from functools import singledispatchmethod
87from io import BytesIO
8+ import json
99from logging import getLogger
1010from pathlib import Path
1111import tempfile
@@ -56,7 +56,7 @@ class GeminiModels(StrEnum):
 
     gemini_15_pro = "gemini-1.5-pro"
     gemini_15_flash = "gemini-1.5-flash-002"
-    gemini_20_flash = "gemini-2.0-flash"
+    gemini_20_flash = "gemini-2.0-flash-001"
     gemini_20_flash_lite = "gemini-2.0-flash-lite-001"
 
 
@@ -115,23 +115,29 @@ def _execute_multi_turn_req(req: MultiTurnRequest) -> str:
         raise ValueError("Only the first message can have file(s)")
 
     # Prepare Inputs. Use context caching for media
-    paths = filepaths(msg=req.messages[0])
-    use_caching = req.use_context_caching and is_long_enough_to_cache(paths)
+    client = create_client()
+    contents = [convert_to_gemini_format(msg) for msg in req.messages]
+
+    files: list[Path] = filepaths(msg=req.messages[0])
+    use_caching = req.use_context_caching and is_long_enough_to_cache(files)
     if use_caching:
-        cached_content, blobs = cache_content(req.model_name, tuple(paths))
-    else:
+        # Assume caching was done before
+        cached_content, success = get_cached_content(client, req.model_name, files)
+        blobs = []
+        if not success:
+            cached_content, blobs = cache_content(client, req.model_name, files)
+    else:  # Add files to the content
+        blobs = upload_files(files=files)
+        contents = [*blobs_to_parts(blobs), *contents]
         cached_content = None
-        blobs = upload_files(files=paths)
-    contents = [convert_to_gemini_format(msg) for msg in req.messages]
 
     # Call Gemini
-    client = create_client()
     response: GenerateContentResponse = _call_gemini(
         client, req, contents, cached_content
     )
 
     # Cleanup
-    if req.delete_files_after_use and not use_caching:
+    if req.delete_files_after_use:
         delete_blobs(blobs)
     return response.text
 
@@ -161,26 +167,37 @@ def video_duration_in_sec(filename: Path) -> float:
     return duration
 
 
-@ttl_cache(ttl=60 * 60)
 def cache_content(
-    model_id: str, paths: list[Path]
+    client: genai.Client, model_id: str, paths: list[Path], ttl: str = f"{60 * 20}s"
 ) -> tuple[CachedContent, list[storage.Blob]]:
     """Caches the content on Google as described here: https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-create"""
     logger.info("Caching content for paths: %s", paths)
-    client = create_client()
     blobs = upload_files(files=paths)
     parts = blobs_to_parts(blobs)
     content = Content(role="user", parts=parts)
-    cached_content = client.caches.create(
-        model=model_id,
-        config=CreateCachedContentConfig(
-            contents=[content],
-            display_name="multiturn cache for req at %s" % datetime.now(),
-        ),
+    config = CreateCachedContentConfig(
+        contents=[content], display_name=cache_id(model_id, paths), ttl=ttl
     )
+    cached_content = client.caches.create(model=model_id, config=config)
     return cached_content, blobs
 
 
+def cache_id(model_id: str, paths: list[Path]) -> str:
+    return json.dumps(dict(model=model_id, paths=str(paths)))
+
+
+def get_cached_content(
+    client: genai.Client, model_id: str, paths: list[Path]
+) -> tuple[CachedContent, bool]:
+    for cache in client.caches.list():
+        if cache.display_name == cache_id(model_id, paths):
+            logger.info(
+                "Found cached content for model_id='%s' and paths='%s'", model_id, paths
+            )
+            return cache, True
+    return None, False
+
+
 def convert_to_gemini_format(msg: Message) -> tuple[Content, list[storage.Blob]]:
     role_map = dict(user="user", assistant="model")
     role = role_map[msg.role]
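
Note on the lookup introduced above (illustrative snippet, not part of the diff): get_cached_content recovers an earlier cache entry purely by comparing display_name strings, so cache_id must be deterministic for a given model and file list. A minimal sketch of the key it produces, using invented file names:

    import json
    from pathlib import Path

    # Same construction as cache_id() above: a JSON object over the model id and
    # the str() of the path list, so identical inputs always yield identical keys.
    paths = [Path("clip_a.mp4"), Path("clip_b.mp4")]
    key = json.dumps(dict(model="gemini-2.0-flash-001", paths=str(paths)))
    # On a POSIX system this produces:
    # {"model": "gemini-2.0-flash-001", "paths": "[PosixPath('clip_a.mp4'), PosixPath('clip_b.mp4')]"}
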
@@ -324,6 +341,7 @@ class GeminiAPI(LLM):
     model_id: str = GeminiModels.gemini_20_flash_lite
     max_output_tokens: int = 1000
     use_context_caching: bool = False
+    delete_files_after_use: bool = True
 
     requires_gpu_exclusively = False
     model_ids = available_models
@@ -333,13 +351,23 @@ def complete_msgs(self, msgs: list[Message]) -> str:
             msg = msgs[0]
             paths = filepaths(msg)
             req = SingleTurnRequest(
-                model_name=self.model_id, media_files=paths, prompt=msg.msg
+                model_name=self.model_id,
+                media_files=paths,
+                prompt=msg.msg,
+                max_output_tokens=self.max_output_tokens,
+                delete_files_after_use=self.delete_files_after_use,
             )
         else:
+            delete_files_after_use = self.delete_files_after_use
+            if self.use_context_caching:
+                delete_files_after_use = False
+
             req = MultiTurnRequest(
                 model_name=self.model_id,
                 messages=msgs,
                 use_context_caching=self.use_context_caching,
+                max_output_tokens=self.max_output_tokens,
+                delete_files_after_use=delete_files_after_use,
             )
         return req.fetch_media_description()
 
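
Taken together, the updated entry point can be exercised roughly as follows. This is an illustrative sketch, not part of the diff: GeminiAPI and GeminiModels are names from the code above, but the keyword-argument construction and the sample values are assumptions.

    # Hypothetical usage of the changed configuration (values are invented).
    llm = GeminiAPI(
        model_id=GeminiModels.gemini_20_flash,
        max_output_tokens=1000,
        use_context_caching=True,     # media from the first message is cached on Vertex AI
        delete_files_after_use=True,  # complete_msgs() forces this to False for cached multi-turn runs
    )
    description = llm.complete_msgs(msgs)  # msgs: list[Message], only the first message carries files

The design point is that when use_context_caching is enabled, complete_msgs overrides delete_files_after_use to False for the multi-turn request, presumably so the uploaded blobs referenced by the still-live cache entry are not deleted immediately after the call.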