
Commit fcfe08c

Simplify code, upload files individually if #files <= 3
1 parent 322f6a7 commit fcfe08c

File tree: 1 file changed (+33, −73 lines)


llmlib/llmlib/gemini/gemini_code.py

Lines changed: 33 additions & 73 deletions
@@ -8,7 +8,6 @@
 import json
 from logging import getLogger
 from pathlib import Path
-import tempfile
 from typing import Any
 from google.cloud import storage
 from google.cloud.storage import transfer_manager
@@ -62,19 +61,6 @@ class GeminiModels(StrEnum):
 available_models = list(GeminiModels)
 
 
-@dataclass
-class SingleTurnRequest:
-    media_files: list[Path]
-    model_name: GeminiModels = GeminiModels.gemini_15_pro
-    prompt: str = "Describe this video in detail."
-    max_output_tokens: int = 1000
-    safety_filter_threshold: HarmBlockThreshold = HarmBlockThreshold.BLOCK_NONE
-    delete_files_after_use: bool = True
-
-    def fetch_media_description(self) -> str:
-        return _execute_single_turn_req(self)
-
-
 @dataclass
 class MultiTurnRequest:
     messages: list[Message]
@@ -88,20 +74,6 @@ def fetch_media_description(self) -> str:
         return _execute_multi_turn_req(self)
 
 
-@notify_bugsnag
-def _execute_single_turn_req(req: SingleTurnRequest) -> str:
-    # Prepare Inputs. Upload media files
-    blobs = upload_files(files=req.media_files)
-    contents = [*blobs_to_parts(blobs), req.prompt]
-    # Call Gemini
-    client = create_client()
-    response: GenerateContentResponse = _call_gemini(client, req, contents)
-    # Cleanup
-    if req.delete_files_after_use:
-        delete_blobs(blobs)
-    return response.text
-
-
 @notify_bugsnag
 def _execute_multi_turn_req(req: MultiTurnRequest) -> str:
     # Validation: Only the first message can have file(s)
@@ -202,7 +174,7 @@ def convert_to_gemini_format(msg: Message) -> tuple[Content, list[storage.Blob]]
 
 def _call_gemini(
     client: genai.Client,
-    req: SingleTurnRequest | MultiTurnRequest,
+    req: MultiTurnRequest,
     contents: list[Part],
     cached_content: CachedContent | None = None,
 ) -> GenerateContentResponse:
@@ -276,16 +248,24 @@ def mime_type(file_name: str) -> str:
 def upload_files(files: list[Path]) -> list[storage.Blob]:
     if len(files) == 0:
         return []
-    logger.info("Uploading %d file(s)", len(files))
-    bucket = _bucket(name=Buckets.temp)
+    if len(files) <= 3:
+        blobs = [_upload_single_file(file, Buckets.temp) for file in files]
+    else:
+        blobs = _upload_batchof_files(files, bucket_name=Buckets.temp)
+    return blobs
+
+
+def _upload_batchof_files(files: list[Path], bucket_name: str) -> list[storage.Blob]:
+    logger.info("Uploading %d file(s) in batch to bucket '%s'", len(files), bucket_name)
+    bucket = _bucket(name=bucket_name)
     files_str = [str(f) for f in files]
     blobs = [bucket.blob(file.name) for file in files]
     transfer_manager.upload_many(
         file_blob_pairs=zip(files_str, blobs),
         skip_if_exists=True,
         raise_exception=True,
     )
-    logger.info("Completed file(s) upload")
+    logger.info("Completed batch upload of %d file(s)", len(files))
     return blobs
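The new dispatch in upload_files() is small enough to sanity-check in isolation. Below is a minimal sketch of the routing rule only, with the GCS work replaced by hypothetical stand-ins for _upload_single_file and _upload_batchof_files; the 3-file threshold comes straight from this commit.

from pathlib import Path

SMALL_BATCH_LIMIT = 3  # threshold introduced by this commit

def route_uploads(files: list[Path]) -> list[str]:
    # Mirrors upload_files(): nothing to do, a few per-file uploads, or one batch.
    if len(files) == 0:
        return []
    if len(files) <= SMALL_BATCH_LIMIT:
        return [f"single:{f.name}" for f in files]  # stand-in for _upload_single_file
    return [f"batch:{f.name}" for f in files]       # stand-in for _upload_batchof_files

assert route_uploads([]) == []
assert route_uploads([Path("a.mp4"), Path("b.mp4")]) == ["single:a.mp4", "single:b.mp4"]
assert route_uploads([Path(f"clip{i}.mp4") for i in range(5)])[0] == "batch:clip0.mp4"

For three or fewer files this presumably trades the concurrency of transfer_manager for simpler, sequential uploads; larger sets keep the batched path.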

@@ -294,12 +274,13 @@ def _bucket(name: str) -> storage.Bucket:
     return client.bucket(name)
 
 
-def upload_single_file(file: Path, bucket: str, blob_name: str) -> storage.Blob:
-    logger.info("Uploading file '%s' to bucket '%s' as '%s'", file, bucket, blob_name)
-    bucket: storage.Bucket = _bucket(name=bucket)
-    blob = bucket.blob(blob_name)
+def _upload_single_file(file: Path, bucket_name: str) -> storage.Blob:
+    bucket: storage.Bucket = _bucket(name=bucket_name)
+    blob = bucket.blob(file.name)
     if blob.exists():
-        logger.info("Blob '%s' already exists. Overwriting it...", blob_name)
+        logger.info("Blob '%s' already exists. Skipping upload...", blob.name)
+        return blob
+    logger.info("Uploading file '%s' to bucket '%s'", file.name, bucket_name)
     blob.upload_from_filename(str(file))
     return blob
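Note the behavioral change alongside the rename: the old upload_single_file overwrote an existing blob, while _upload_single_file now returns the existing blob untouched, matching the skip_if_exists=True used in the batch path. A minimal standalone sketch of that idiom, assuming default application credentials and a hypothetical bucket name:

from pathlib import Path
from google.cloud import storage

def upload_once(file: Path, bucket_name: str = "my-temp-bucket") -> storage.Blob:
    # Upload the file under its basename unless an object with that name already exists.
    bucket = storage.Client().bucket(bucket_name)
    blob = bucket.blob(file.name)
    if blob.exists():
        return blob  # reuse the existing object instead of overwriting it
    blob.upload_from_filename(str(file))
    return blob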

@@ -342,48 +323,27 @@ class GeminiAPI(LLM):
     model_ids = available_models
 
     def complete_msgs(self, msgs: list[Message]) -> str:
-        if len(msgs) == 1:
-            msg = msgs[0]
-            paths = filepaths(msg)
-            req = SingleTurnRequest(
-                model_name=self.model_id,
-                media_files=paths,
-                prompt=msg.msg,
-                max_output_tokens=self.max_output_tokens,
-                delete_files_after_use=self.delete_files_after_use,
-            )
-        else:
-            delete_files_after_use = self.delete_files_after_use
-            if self.use_context_caching:
-                delete_files_after_use = False
-
-            req = MultiTurnRequest(
-                model_name=self.model_id,
-                messages=msgs,
-                use_context_caching=self.use_context_caching,
-                max_output_tokens=self.max_output_tokens,
-                delete_files_after_use=delete_files_after_use,
-            )
+        delete_files_after_use = self.delete_files_after_use
+        if self.use_context_caching:
+            delete_files_after_use = False
+
+        req = MultiTurnRequest(
+            model_name=self.model_id,
+            messages=msgs,
+            use_context_caching=self.use_context_caching,
+            max_output_tokens=self.max_output_tokens,
+            delete_files_after_use=delete_files_after_use,
+        )
         return req.fetch_media_description()
 
     @singledispatchmethod
-    def video_prompt(self, video, prompt: str) -> str:
-        raise NotImplementedError(f"Unsupported video type: {type(video)}")
-
-    @video_prompt.register
-    def _(self, video: Path, prompt: str) -> str:
-        req = SingleTurnRequest(
-            model_name=self.model_id, media_files=[video], prompt=prompt
+    def video_prompt(self, video: Path | BytesIO, prompt: str) -> str:
+        req = MultiTurnRequest(
+            model_name=self.model_id,
+            messages=[Message(role="user", msg=prompt, video=video)],
         )
         return req.fetch_media_description()
 
-    @video_prompt.register
-    def _(self, video: BytesIO, prompt: str) -> str:
-        path = tempfile.mktemp(suffix=".mp4")
-        with open(path, "wb") as f:
-            f.write(video.getvalue())
-        return self.video_prompt(Path(path), prompt)
-
     @classmethod
     def get_warnings(cls) -> list[str]:
         return [
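With SingleTurnRequest gone, a single user message and a full conversation both go through MultiTurnRequest, and video_prompt() simply wraps the video in a one-message conversation instead of writing a temp file for BytesIO input. A hedged usage sketch follows; the import paths and the no-argument GeminiAPI constructor are assumptions, not shown in this diff.

from pathlib import Path
# Assumed import path, based on the file shown above; Message is assumed to be
# importable from the same module, with only the fields used in this diff.
from llmlib.gemini.gemini_code import GeminiAPI, Message

llm = GeminiAPI()  # hypothetical: relies on default model and settings

# A one-message call now takes the same MultiTurnRequest path as a longer conversation.
answer = llm.complete_msgs(
    [Message(role="user", msg="Describe this video in detail.", video=Path("clip.mp4"))]
)

# video_prompt() accepts a Path or an in-memory BytesIO directly since this commit.
answer = llm.video_prompt(Path("clip.mp4"), prompt="Describe this video in detail.")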
