Skip to content

Commit e9e62a1

Browse files
author
Mika
committed
video prompting
1 parent e5fda53 commit e9e62a1

File tree

2 files changed

+70
-7
lines changed

2 files changed

+70
-7
lines changed

test_uralicnlp.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,14 @@
154154
#uralicApi.import_dictionary_to_db("sms")
155155

156156
#llm = get_llm("chatgpt", open_read(os.path.expanduser("~/.openaiapikey")).read().strip())
157-
#llm = get_llm("gemini", open_read(os.path.expanduser("~/.geminiapikey")).read().strip())
157+
llm = get_llm("gemini", open_read(os.path.expanduser("~/.geminiapikey")).read().strip())
158158
#llm = get_llm("mistral", open_read(os.path.expanduser("~/.mistralapikey")).read().strip())
159159

160160
#llm = get_llm("perplexity", open_read(os.path.expanduser("~/.perplexityapikey")).read().strip())
161161
#llm = get_llm("claude", open_read(os.path.expanduser("~/.claudeapikey")).read().strip())
162162

163+
print(llm.prompt_video("What is happening on this video?", "/Users/mikahama/Downloads/6830385-uhd_4096_2160_25fps.mp4"))
164+
163165
#print(llm.prompt_image("What is this image about?", "/Users/mikahama/Desktop/teams.jpg"))
164166

165167
#print(llm.prompt("I forgot where I put my hat..."))
@@ -203,7 +205,7 @@
203205
#t = TartuTranslator()
204206
#print(t.translate("Hello, how are you?", "eng", "fin"))
205207

206-
llm = get_llm("chatgpt", open_read(os.path.expanduser("~/.openaiapikey")).read().strip(), model="omni-moderation-latest")
207-
print(llm.moderate("those faggots punched idiots and fucked each other."))
208+
#llm = get_llm("chatgpt", open_read(os.path.expanduser("~/.openaiapikey")).read().strip(), model="omni-moderation-latest")
209+
#print(llm.moderate("those faggots punched idiots and fucked each other."))
208210

209211

uralicNLP/llm.py

+65-4
Original file line numberDiff line numberDiff line change
def _openai_format_prompt(self, prompt, extra_content):
    """Wrap a text prompt in the OpenAI chat "content parts" format.

    Args:
        prompt: The user's text prompt.
        extra_content: None, a single content-part dict (e.g. an image_url
            part), or a list of such dicts (e.g. video frames).

    Returns:
        The bare prompt string when there is no extra content; otherwise a
        list of content parts starting with the text part, followed by the
        extra part(s).
    """
    if extra_content is None:
        return prompt
    # isinstance (not type == list) so list subclasses are handled too.
    if isinstance(extra_content, list):
        return [{"type": "text", "text": prompt}] + extra_content
    return [{"type": "text", "text": prompt}, extra_content]
104107

105108
def _embed_cache(func):
106109
def inner(*args, **kwargs):
@@ -117,19 +120,61 @@ def inner(*args, **kwargs):
117120
def _prompt_cache(func):
    """Decorator memoizing prompt results in self._prompt_cache_dict.

    The cache key is the stringified positional arguments (excluding self)
    joined with "_"; str() is applied so non-string arguments such as frame
    lists are usable as keys. Keyword arguments are NOT part of the key —
    calls differing only in kwargs share a cache slot (preserved behavior).
    Caching is skipped entirely when self.cache is falsy.
    """
    def inner(*args, **kwargs):
        self = args[0]
        # Build the key once instead of re-joining on every lookup/store.
        key = "_".join(str(x) for x in args[1:])
        if self.cache and key in self._prompt_cache_dict:
            return self._prompt_cache_dict[key]
        r = func(*args, **kwargs)
        if self.cache:
            self._prompt_cache_dict[key] = r
        return r
    return inner
128131

129132
def prompt_image(self, text, image):
    """Prompt the LLM about an image.

    The image is first normalized by _convert_image, then handed to the
    cached/decorated backend entry point.
    """
    converted = self._convert_image(image)
    return self._prompt_image_decorated(text, converted)
132135

136+
def _prepare_video_frame(self, frame, size, i, b64):
    """Downscale one video frame so its longest side is at most `size`,
    then PNG-encode it.

    Args:
        frame: BGR image array as returned by cv2.VideoCapture.read().
        size: Maximum length (pixels) for the frame's longest side.
        i: Frame index (kept for interface compatibility; used only by the
           commented-out debug dump).
        b64: When True return a base64-encoded UTF-8 string, otherwise the
           raw PNG byte buffer.
    """
    height, width, channels = frame.shape
    # Fixed: square frames (height == width) larger than `size` were never
    # resized because both strict comparisons below were false.
    if height >= width and height > size:
        ratio = width / height
        frame = cv2.resize(frame, (round(size * ratio), size))
    elif width > height and width > size:
        ratio = height / width
        frame = cv2.resize(frame, (size, round(size * ratio)))
    #cv2.imwrite("tmp/img" + str(i) + ".png", frame)
    _, buffer = cv2.imencode(".png", frame)
    if b64:
        return base64.b64encode(buffer).decode("utf-8")
    else:
        return buffer
150+
151+
@_prompt_cache
def prompt_video(self, text, video, size=1000, n_frames=5):
    """Prompt the LLM about a video file.

    Samples roughly `n_frames` frames (each downscaled so its longest side
    is at most `size`), base64-encodes them as PNGs, and forwards them with
    the text to the backend-specific _prompt_video. Results are memoized by
    _prompt_cache when self.cache is enabled.
    """
    encoded_frames = self._process_video(video, size, n_frames, b64=True)
    return self._prompt_video(text, encoded_frames)
155+
156+
157+
def _process_video(self, video, size, n_frames, b64=True):
    """Read a video file and return ~`n_frames` evenly spaced, encoded frames.

    Args:
        video: Path (or capture source) accepted by cv2.VideoCapture.
        size: Maximum longest-side length passed to _prepare_video_frame.
        n_frames: Target number of sampled frames; non-positive values keep
            every frame (preserved behavior).
        b64: Forwarded to _prepare_video_frame (base64 string vs raw buffer).
    """
    capture = cv2.VideoCapture(video)
    raw_frames = []
    while capture.isOpened():
        grabbed, frame = capture.read()
        if not grabbed:
            break
        raw_frames.append(frame)
    capture.release()

    # A step of 1 keeps every frame: used when n_frames is non-positive or
    # already at least the number of frames available.
    if n_frames <= 0 or n_frames >= len(raw_frames):
        step = 1
    else:
        # NOTE(review): rounding means the sample count only approximates
        # n_frames (can be one more or fewer).
        step = round(len(raw_frames) / n_frames)

    sampled = raw_frames[::step]
    return [
        self._prepare_video_frame(frame, size, index, b64)
        for index, frame in enumerate(sampled)
    ]
176+
177+
133178
def _convert_image(self, image):
134179
prompt_image = None
135180
if type(image) == str:
@@ -170,6 +215,9 @@ def prompt(self, text):
170215
def _prompt(self, text):
171216
raise NotImplementedException("LLM does not support prompting")
172217

218+
def _prompt_video(self, text, frames):
    """Backend hook for video prompting; overridden by capable subclasses.

    Raises:
        NotImplementedException: always, for backends without video support.
    """
    # The message previously read "LLM does not support prompting" — copied
    # from _prompt — making video-capability errors indistinguishable from
    # plain text-prompting errors.
    raise NotImplementedException("LLM does not support video prompting")
220+
173221
@_embed_cache
174222
def embed(self, text):
175223
return self._embed(text)
@@ -249,6 +297,10 @@ def set_system_prompt(self, text):
249297
def _prompt_image(self, text, prompt_image):
250298
return self._prompt(text, extra_content = {"type":"image_url", "image_url": {"url": prompt_image} })
251299

300+
def _prompt_video(self, text, frames):
    """Send a video prompt to the OpenAI backend.

    Each frame is expected to be a base64-encoded PNG string; frames are
    attached as image_url data-URL content parts alongside the text.
    """
    attachments = []
    for encoded_png in frames:
        data_url = "data:image/png;base64," + encoded_png
        attachments.append({"type": "image_url", "image_url": {"url": data_url}})
    return self._prompt(text, extra_content=attachments)
303+
252304
def _prompt(self, prompt, temperature=1, extra_content=None):
253305
prompt = self._openai_format_prompt(prompt, extra_content)
254306
chat_completion = self.client.chat.completions.create(
@@ -296,6 +348,11 @@ def __init__(self, api_key, model="gemini-1.5-flash", task_type="retrieval_docum
296348
self.model_name = model
297349
self.task_type = task_type
298350

351+
def _prompt_video(self, text, frames):
    """Send a video prompt to the Gemini backend.

    Frames (base64 PNG data) are passed as inline image parts, followed by
    the text, to self.model.generate_content.
    """
    contents = []
    for frame_data in frames:
        contents.append({'mime_type': 'image/png', 'data': frame_data})
    contents.append(text)
    response = self.model.generate_content(contents)
    return response.text
355+
299356
def set_system_prompt(self, text):
    """Rebuild the Gemini model so `text` becomes its system instruction."""
    model = genai.GenerativeModel(self.model_name, system_instruction=text)
    self.model = model
301358

@@ -381,6 +438,10 @@ def _prompt_image(self, text, prompt_image):
381438
def set_system_prompt(self, text):
    """Remember `text`; it is sent as the system prompt on later calls."""
    self.system = text
383440

441+
def _prompt_video(self, text, frames):
    """Send a video prompt to the Claude backend.

    Each frame (a base64 PNG string) becomes an image content block in the
    Anthropic base64 source format, attached alongside the text prompt.
    """
    blocks = []
    for encoded_png in frames:
        source = {"data": encoded_png, "type": "base64", "media_type": "image/png"}
        blocks.append({"type": "image", "source": source})
    return self._prompt(text, extra_content=blocks)
444+
384445
def _prompt(self, prompt, temperature=1, extra_content=None):
385446
prompt = self._openai_format_prompt(prompt, extra_content)
386447
if self.system:

0 commit comments

Comments
 (0)