from dataclasses import dataclass
from pathlib import Path
from llmlib.base_llm import LLM, validate_only_first_message_has_files
import torch
from llmlib.huggingface_inference import Message, is_img, is_video, video_to_imgs
from transformers import AutoProcessor, Gemma3ForConditionalGeneration


@dataclass
class Gemma3Local(LLM):
    """Local inference for Google's Gemma 3 vision-language models."""

    model_id: str
    max_n_frames_per_video: int = 100
    max_new_tokens: int = 500

    # Checkpoints this wrapper is intended to support.
    model_ids = [
        "google/gemma-3-4b-it",
        "google/gemma-3-27b-it",
    ]

    def __post_init__(self):
        # Load the weights in bfloat16 and let transformers place them on the
        # available device(s); eval() puts the model in inference mode.
        self.model = Gemma3ForConditionalGeneration.from_pretrained(
            self.model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        ).eval()
        self.processor = AutoProcessor.from_pretrained(self.model_id)

    def complete_msgs(self, msgs: list[Message]) -> str:
        """Complete a conversation with the model."""
        validate_only_first_message_has_files(msgs)

        messages: list[dict] = [
            convert_msg_to_gemma3_format(msg, self.max_n_frames_per_video)
            for msg in msgs
        ]

        # Tokenize the chat (text and images) and move the tensors to the
        # device the model weights live on.
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        with torch.inference_mode():
            outputs = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)

        # generate() returns the prompt followed by the completion, so slice
        # off the prompt tokens before decoding the response.
        input_len = len(inputs["input_ids"][0])
        response: str = self.processor.decode(
            outputs[0][input_len:], skip_special_tokens=True
        )
        return response


def convert_msg_to_gemma3_format(msg: Message, max_n_frames_per_video: int) -> dict:
    """Translate a llmlib Message into the content-dict format the Gemma 3
    chat template expects."""
    dict_msg = {"role": msg.role, "content": []}
    if msg.img is not None:
        image = msg.img
        # The chat template accepts file paths as strings, so normalize Path objects.
        if isinstance(image, Path):
            image = str(image)
        dict_msg["content"].append({"type": "image", "image": image})
    if msg.video is not None:
        # Videos are passed as a sequence of sampled frames, capped at
        # max_n_frames_per_video images.
        imgs: list = video_to_imgs(msg.video, max_n_frames_per_video)
        for img in imgs:
            dict_msg["content"].append({"type": "image", "image": img})
    if msg.files is not None:
        for filepath in msg.files:
            if is_img(filepath):
                dict_msg["content"].append({"type": "image", "image": str(filepath)})
            elif is_video(filepath):
                imgs = video_to_imgs(filepath, max_n_frames_per_video)
                for img in imgs:
                    dict_msg["content"].append({"type": "image", "image": img})
    # The text part, if any, is appended after the visual content.
    if msg.msg:
        dict_msg["content"].append({"type": "text", "text": msg.msg})
    return dict_msg
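
# Usage sketch (illustrative, not part of this commit): how the class might be
# driven. The Message keyword names below (role, msg, img) are assumptions
# inferred from the attribute access above; the real constructor may differ.
#
#   llm = Gemma3Local(model_id="google/gemma-3-4b-it")
#   message = Message(role="user", msg="Describe this image.", img=Path("cat.jpg"))
#   print(llm.complete_msgs([message]))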