88import json
99from logging import getLogger
1010from pathlib import Path
11- import tempfile
1211from typing import Any
1312from google .cloud import storage
1413from google .cloud .storage import transfer_manager
@@ -62,19 +61,6 @@ class GeminiModels(StrEnum):
6261available_models = list (GeminiModels )
6362
6463
65- @dataclass
66- class SingleTurnRequest :
67- media_files : list [Path ]
68- model_name : GeminiModels = GeminiModels .gemini_15_pro
69- prompt : str = "Describe this video in detail."
70- max_output_tokens : int = 1000
71- safety_filter_threshold : HarmBlockThreshold = HarmBlockThreshold .BLOCK_NONE
72- delete_files_after_use : bool = True
73-
74- def fetch_media_description (self ) -> str :
75- return _execute_single_turn_req (self )
76-
77-
7864@dataclass
7965class MultiTurnRequest :
8066 messages : list [Message ]
@@ -88,20 +74,6 @@ def fetch_media_description(self) -> str:
8874 return _execute_multi_turn_req (self )
8975
9076
91- @notify_bugsnag
92- def _execute_single_turn_req (req : SingleTurnRequest ) -> str :
93- # Prepare Inputs. Upload media files
94- blobs = upload_files (files = req .media_files )
95- contents = [* blobs_to_parts (blobs ), req .prompt ]
96- # Call Gemini
97- client = create_client ()
98- response : GenerateContentResponse = _call_gemini (client , req , contents )
99- # Cleanup
100- if req .delete_files_after_use :
101- delete_blobs (blobs )
102- return response .text
103-
104-
10577@notify_bugsnag
10678def _execute_multi_turn_req (req : MultiTurnRequest ) -> str :
10779 # Validation: Only the first message can have file(s)
@@ -202,7 +174,7 @@ def convert_to_gemini_format(msg: Message) -> tuple[Content, list[storage.Blob]]
202174
203175def _call_gemini (
204176 client : genai .Client ,
205- req : SingleTurnRequest | MultiTurnRequest ,
177+ req : MultiTurnRequest ,
206178 contents : list [Part ],
207179 cached_content : CachedContent | None = None ,
208180) -> GenerateContentResponse :
@@ -276,16 +248,24 @@ def mime_type(file_name: str) -> str:
276248def upload_files (files : list [Path ]) -> list [storage .Blob ]:
277249 if len (files ) == 0 :
278250 return []
279- logger .info ("Uploading %d file(s)" , len (files ))
280- bucket = _bucket (name = Buckets .temp )
251+ if len (files ) <= 3 :
252+ blobs = [_upload_single_file (file , Buckets .temp ) for file in files ]
253+ else :
254+ blobs = _upload_batchof_files (files , bucket_name = Buckets .temp )
255+ return blobs
256+
257+
258+ def _upload_batchof_files (files : list [Path ], bucket_name : str ) -> list [storage .Blob ]:
259+ logger .info ("Uploading %d file(s) in batch to bucket '%s'" , len (files ), bucket_name )
260+ bucket = _bucket (name = bucket_name )
281261 files_str = [str (f ) for f in files ]
282262 blobs = [bucket .blob (file .name ) for file in files ]
283263 transfer_manager .upload_many (
284264 file_blob_pairs = zip (files_str , blobs ),
285265 skip_if_exists = True ,
286266 raise_exception = True ,
287267 )
288- logger .info ("Completed file(s) upload" )
268+ logger .info ("Completed batch upload of %d file(s)" , len ( files ) )
289269 return blobs
290270
291271
@@ -294,12 +274,13 @@ def _bucket(name: str) -> storage.Bucket:
294274 return client .bucket (name )
295275
296276
297- def upload_single_file (file : Path , bucket : str , blob_name : str ) -> storage .Blob :
298- logger .info ("Uploading file '%s' to bucket '%s' as '%s'" , file , bucket , blob_name )
299- bucket : storage .Bucket = _bucket (name = bucket )
300- blob = bucket .blob (blob_name )
277+ def _upload_single_file (file : Path , bucket_name : str ) -> storage .Blob :
278+ bucket : storage .Bucket = _bucket (name = bucket_name )
279+ blob = bucket .blob (file .name )
301280 if blob .exists ():
302- logger .info ("Blob '%s' already exists. Overwriting it..." , blob_name )
281+ logger .info ("Blob '%s' already exists. Skipping upload..." , blob .name )
282+ return blob
283+ logger .info ("Uploading file '%s' to bucket '%s'" , file .name , bucket_name )
303284 blob .upload_from_filename (str (file ))
304285 return blob
305286
@@ -342,48 +323,27 @@ class GeminiAPI(LLM):
342323 model_ids = available_models
343324
344325 def complete_msgs (self , msgs : list [Message ]) -> str :
345- if len (msgs ) == 1 :
346- msg = msgs [0 ]
347- paths = filepaths (msg )
348- req = SingleTurnRequest (
349- model_name = self .model_id ,
350- media_files = paths ,
351- prompt = msg .msg ,
352- max_output_tokens = self .max_output_tokens ,
353- delete_files_after_use = self .delete_files_after_use ,
354- )
355- else :
356- delete_files_after_use = self .delete_files_after_use
357- if self .use_context_caching :
358- delete_files_after_use = False
359-
360- req = MultiTurnRequest (
361- model_name = self .model_id ,
362- messages = msgs ,
363- use_context_caching = self .use_context_caching ,
364- max_output_tokens = self .max_output_tokens ,
365- delete_files_after_use = delete_files_after_use ,
366- )
326+ delete_files_after_use = self .delete_files_after_use
327+ if self .use_context_caching :
328+ delete_files_after_use = False
329+
330+ req = MultiTurnRequest (
331+ model_name = self .model_id ,
332+ messages = msgs ,
333+ use_context_caching = self .use_context_caching ,
334+ max_output_tokens = self .max_output_tokens ,
335+ delete_files_after_use = delete_files_after_use ,
336+ )
367337 return req .fetch_media_description ()
368338
369339 @singledispatchmethod
370- def video_prompt (self , video , prompt : str ) -> str :
371- raise NotImplementedError (f"Unsupported video type: { type (video )} " )
372-
373- @video_prompt .register
374- def _ (self , video : Path , prompt : str ) -> str :
375- req = SingleTurnRequest (
376- model_name = self .model_id , media_files = [video ], prompt = prompt
340+ def video_prompt (self , video : Path | BytesIO , prompt : str ) -> str :
341+ req = MultiTurnRequest (
342+ model_name = self .model_id ,
343+ messages = [Message (role = "user" , msg = prompt , video = video )],
377344 )
378345 return req .fetch_media_description ()
379346
380- @video_prompt .register
381- def _ (self , video : BytesIO , prompt : str ) -> str :
382- path = tempfile .mktemp (suffix = ".mp4" )
383- with open (path , "wb" ) as f :
384- f .write (video .getvalue ())
385- return self .video_prompt (Path (path ), prompt )
386-
387347 @classmethod
388348 def get_warnings (cls ) -> list [str ]:
389349 return [
0 commit comments