@@ -100,7 +100,10 @@ def _openai_format_prompt(self, prompt, extra_content):
         if extra_content is None:
             return prompt
         else:
-            return [{"type": "text", "text": prompt}, extra_content]
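+            # extra_content can now be a single content part or a list of parts
+            # (e.g. one image_url part per video frame from _prompt_video)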
+            if type(extra_content) == list:
+                return [{"type": "text", "text": prompt}] + extra_content
+            else:
+                return [{"type": "text", "text": prompt}, extra_content]
 
     def _embed_cache(func):
         def inner(*args, **kwargs):
@@ -117,19 +120,61 @@ def inner(*args, **kwargs):
     def _prompt_cache(func):
         def inner(*args, **kwargs):
             self = args[0]
-            if self.cache and "_".join(args[1:]) in self._prompt_cache_dict:
-                return self._prompt_cache_dict["_".join(args[1:])]
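+            # cast each arg to str so non-string arguments can still be joined into a cache key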
+            if self.cache and "_".join([str(x) for x in args[1:]]) in self._prompt_cache_dict:
+                return self._prompt_cache_dict["_".join([str(x) for x in args[1:]])]
             else:
                 r = func(*args, **kwargs)
                 if self.cache:
-                    self._prompt_cache_dict["_".join(args[1:])] = r
+                    self._prompt_cache_dict["_".join([str(x) for x in args[1:]])] = r
                 return r
         return inner
 
     def prompt_image(self, text, image):
         prompt_image = self._convert_image(image)
         return self._prompt_image_decorated(text, prompt_image)
 
+    def _prepare_video_frame(self, frame, size, i, b64):
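+        # downscale so the longer side is at most size, then PNG-encode the frame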
+        height, width, channels = frame.shape
+        if height > width and height > size:
+            ratio = width / height
+            frame = cv2.resize(frame, (round(size * ratio), size))
+        elif width > height and width > size:
+            ratio = height / width
+            frame = cv2.resize(frame, (size, round(size * ratio)))
+        #cv2.imwrite("tmp/img" + str(i) + ".png", frame)
+        _, buffer = cv2.imencode(".png", frame)
+        if b64:
+            return base64.b64encode(buffer).decode("utf-8")
+        else:
+            return buffer
+
+    @_prompt_cache
+    def prompt_video(self, text, video, size=1000, n_frames=5):
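+        # sample roughly n_frames base64-encoded PNG frames and pass them to the backend-specific _prompt_video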
+        frames = self._process_video(video, size, n_frames, b64=True)
+        return self._prompt_video(text, frames)
+
+
+    def _process_video(self, video, size, n_frames, b64=True):
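+        # read every frame, then keep roughly n_frames of them at evenly spaced intervals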
+        video = cv2.VideoCapture(video)
+        frames = []
+        while video.isOpened():
+            success, frame = video.read()
+            if not success:
+                break
+            frames.append(frame)
+
+        video.release()
+        if n_frames >= len(frames):
+            intervals = 1
+        elif n_frames <= 0:
+            intervals = 1
+        else:
+            intervals = round(len(frames) / n_frames)
+        frames = frames[0::intervals]
+        frames = [self._prepare_video_frame(x, size, i, b64) for i, x in enumerate(frames)]
+        return frames
+
+
     def _convert_image(self, image):
         prompt_image = None
         if type(image) == str:
@@ -170,6 +215,9 @@ def prompt(self, text):
     def _prompt(self, text):
         raise NotImplementedException("LLM does not support prompting")
 
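+    # backends that support video input override this (see the _prompt_video implementations below)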
+    def _prompt_video(self, text, frames):
+        raise NotImplementedException("LLM does not support video prompting")
+
     @_embed_cache
     def embed(self, text):
         return self._embed(text)
@@ -249,6 +297,10 @@ def set_system_prompt(self, text):
     def _prompt_image(self, text, prompt_image):
         return self._prompt(text, extra_content={"type": "image_url", "image_url": {"url": prompt_image}})
 
+    def _prompt_video(self, text, frames):
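+        # each base64-encoded frame is sent as its own image_url part via a data URL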
+        extra_content = [{"type": "image_url", "image_url": {"url": "data:image/png;base64," + prompt_image}} for prompt_image in frames]
+        return self._prompt(text, extra_content=extra_content)
+
     def _prompt(self, prompt, temperature=1, extra_content=None):
         prompt = self._openai_format_prompt(prompt, extra_content)
         chat_completion = self.client.chat.completions.create(
@@ -296,6 +348,11 @@ def __init__(self, api_key, model="gemini-1.5-flash", task_type="retrieval_docum
         self.model_name = model
         self.task_type = task_type
 
+    def _prompt_video(self, text, frames):
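+        # pass the frames as inline image parts, followed by the text prompt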
+        img_prompts = [{'mime_type': 'image/png', 'data': img} for img in frames]
+        response = self.model.generate_content(img_prompts + [text])
+        return response.text
+
     def set_system_prompt(self, text):
         self.model = genai.GenerativeModel(self.model_name, system_instruction=text)
 
@@ -381,6 +438,10 @@ def _prompt_image(self, text, prompt_image):
     def set_system_prompt(self, text):
         self.system = text
 
+    def _prompt_video(self, text, frames):
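+        # each base64-encoded frame becomes an image source block in the request content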
+        extra_content = [{"type": "image", "source": {"data": prompt_image, "type": "base64", "media_type": "image/png"}} for prompt_image in frames]
+        return self._prompt(text, extra_content=extra_content)
+
     def _prompt(self, prompt, temperature=1, extra_content=None):
         prompt = self._openai_format_prompt(prompt, extra_content)
         if self.system: