diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b85b710bf4a1..943496db6cc3 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -525,6 +525,8 @@ title: glm4 - local: model_doc/glm4_moe title: glm4_moe + - local: model_doc/glm_image + title: GlmImage - local: model_doc/openai-gpt title: GPT - local: model_doc/gpt_neo diff --git a/docs/source/en/model_doc/glm46v.md b/docs/source/en/model_doc/glm46v.md index ab62530a438a..bc5cbdc4ee43 100644 --- a/docs/source/en/model_doc/glm46v.md +++ b/docs/source/en/model_doc/glm46v.md @@ -1,22 +1,55 @@ - -*This model was released on {release_date} and added to Hugging Face Transformers on 2025-11-15.* +*This model was released on 2025-12-09 and added to Hugging Face Transformers on 2025-11-15.* # GLM-4.6V +## Overview + +The GLM-V model was proposed in [GLM-4.5V and GLM-4.1V-Thinking: Towards Versatile Multimodal Reasoning with Scalable Reinforcement Learning](https://huggingface.co/papers/2507.01006v6). + +The abstract from the paper is the following: + +> *We present GLM-4.1V-Thinking, GLM-4.5V, and GLM-4.6V, a family of vision-language models (VLMs) designed to advance +general-purpose multimodal understanding and reasoning. In this report, we share our key findings in the development of +the reasoning-centric training framework. We first develop a capable vision foundation model with significant potential +through large-scale pre-training, which arguably sets the upper bound for the final performance. We then propose +Reinforcement Learning with Curriculum Sampling (RLCS) to unlock the full potential of the model, leading to +comprehensive capability enhancement across a diverse range of tasks, including STEM problem solving, video +understanding, content recognition, coding, grounding, GUI-based agents, and long document interpretation. In a +comprehensive evaluation across 42 public benchmarks, GLM-4.5V achieves state-of-the-art performance on nearly all tasks +among open-source models of similar size, and demonstrates competitive or even superior results compared to +closed-source models such as Gemini-2.5-Flash on challenging tasks including Coding and GUI Agents. Meanwhile, the +smaller GLM-4.1V-9B-Thinking remains highly competitive-achieving superior results to the much larger Qwen2.5-VL-72B on +29 benchmarks. We open-source both GLM-4.1V-9B-Thinking and GLM-4.5V. We further introduce the GLM-4.6V series, +open-source multimodal models with native tool use and a 128K context window. A brief overview is available at this +https URL. Code, models and more information are released at https://github.com/zai-org/GLM-V* + +## Support Model + +This Model Processor support these model of zai-org: + ++ [GLM-4.6V-Flash](https://huggingface.co/zai-org/GLM-4.6V-Flash) ++ [GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V) + +This model was contributed by [Raushan Turganbay](https://huggingface.co/RaushanTurganbay) and [Yuxuan Zhang](https://huggingface.co/ZHANGYUXUAN-zR). + ## Glm46VConfig [[autodoc]] Glm46VConfig diff --git a/docs/source/en/model_doc/glm4v.md b/docs/source/en/model_doc/glm4v.md index 38e43ab5c5c8..206287f9d576 100644 --- a/docs/source/en/model_doc/glm4v.md +++ b/docs/source/en/model_doc/glm4v.md @@ -1,49 +1,61 @@ - *This model was released on 2025-07-01 and added to Hugging Face Transformers on 2025-06-25.* -
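The new GLM-4.6V page above lists its supported checkpoints but does not yet include a usage snippet. Below is a minimal, untested sketch of what loading one of them could look like, assuming the `zai-org/GLM-4.6V-Flash` checkpoint named on that page works with the same `image-text-to-text` pipeline API that the GLM-4.1V example later in this diff uses.

```python
import torch
from transformers import pipeline

# Hedged sketch: assumes GLM-4.6V-Flash is served by the standard
# image-text-to-text pipeline, like the GLM-4.1V example shown below.
pipe = pipeline(
    task="image-text-to-text",
    model="zai-org/GLM-4.6V-Flash",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
print(pipe(text=messages, max_new_tokens=64, return_full_text=False))
```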
-<!-- PyTorch, FlashAttention, SDPA badges -->
- -# GLM-4.1V +# GLM-V ## Overview -**GLM-4.1V-9B-Thinking** is a bilingual vision-language model optimized for reasoning, built on GLM-4-9B. It introduces -a "thinking paradigm" with reinforcement learning, achieving state-of-the-art results among 10B-class models and -rivaling 72B-scale models. It supports 64k context, 4K resolution, and arbitrary aspect ratios, with an open-source base -model for further research. You can check our paper [here](https://huggingface.co/papers/2507.01006). and below is a abstract. - -*We present GLM-4.1V-Thinking, a vision-language model (VLM) designed to advance general-purpose multimodal understanding -and reasoning. In this report, we share our key findings in the development of the reasoning-centric training framework. -We first develop a capable vision foundation model with significant potential through large-scale pre-training, which -arguably sets the upper bound for the final performance. We then propose Reinforcement Learning with Curriculum -Sampling (RLCS) to unlock the full potential of the model, leading to comprehensive capability enhancement across a -diverse range of tasks, including STEM problem solving, video understanding, content recognition, coding, grounding, -GUI-based agents, and long document understanding. We open-source GLM-4.1V-9B-Thinking, which achieves state-of-the-art -performance among models of comparable size. In a comprehensive evaluation across 28 public benchmarks, our model -outperforms Qwen2.5-VL-7B on nearly all tasks and achieves comparable or even superior performance on 18 benchmarks -relative to the significantly larger Qwen2.5-VL-72B. Notably, GLM-4.1V-9B-Thinking also demonstrates competitive or -superior performance compared to closed-source models such as GPT-4o on challenging tasks including long document -understanding and STEM reasoning, further underscoring its strong capabilities. Code, models and more information -are released at https://github.com/THUDM/GLM-4.1V-Thinking.* +The GLM-V model was proposed in [GLM-4.5V and GLM-4.1V-Thinking: Towards Versatile Multimodal Reasoning with Scalable Reinforcement Learning](https://huggingface.co/papers/2507.01006v6). + +The abstract from the paper is the following: + +> *We present GLM-4.1V-Thinking, GLM-4.5V, and GLM-4.6V, a family of vision-language models (VLMs) designed to advance +general-purpose multimodal understanding and reasoning. In this report, we share our key findings in the development of +the reasoning-centric training framework. We first develop a capable vision foundation model with significant potential +through large-scale pre-training, which arguably sets the upper bound for the final performance. We then propose +Reinforcement Learning with Curriculum Sampling (RLCS) to unlock the full potential of the model, leading to +comprehensive capability enhancement across a diverse range of tasks, including STEM problem solving, video +understanding, content recognition, coding, grounding, GUI-based agents, and long document interpretation. In a +comprehensive evaluation across 42 public benchmarks, GLM-4.5V achieves state-of-the-art performance on nearly all tasks +among open-source models of similar size, and demonstrates competitive or even superior results compared to +closed-source models such as Gemini-2.5-Flash on challenging tasks including Coding and GUI Agents. Meanwhile, the +smaller GLM-4.1V-9B-Thinking remains highly competitive-achieving superior results to the much larger Qwen2.5-VL-72B on +29 benchmarks. 
We open-source both GLM-4.1V-9B-Thinking and GLM-4.5V. We further introduce the GLM-4.6V series, +open-source multimodal models with native tool use and a 128K context window. A brief overview is available at this +https URL. Code, models and more information are released at https://github.com/zai-org/GLM-V* + +## Support Model + +This Model type support these model of zai-org: + ++ [GLM-4.1V-9B-Base](https://huggingface.co/zai-org/GLM-4.1V-9B-Base) ++ [GLM-4.1V-9B-Thinking](https://huggingface.co/zai-org/GLM-4.1V-9B-Thinking) ++ [GLM-4.6V-Flash](https://huggingface.co/zai-org/GLM-4.6V-Flash) ++ [AutoGLM-Phone-9B](https://huggingface.co/zai-org/AutoGLM-Phone-9B) ++ [AutoGLM-Phone-9B-Multilingual](https://huggingface.co/zai-org/AutoGLM-Phone-9B-Multilingual) ++ [Glyph](https://huggingface.co/zai-org/Glyph) ++ [WebVIA-Agent](https://huggingface.co/zai-org/WebVIA-Agent) ++ [UI2Code_N](https://huggingface.co/zai-org/UI2Code_N) + +This model was contributed by [Raushan Turganbay](https://huggingface.co/RaushanTurganbay) +and [Yuxuan Zhang](https://huggingface.co/ZHANGYUXUAN-zR). ## Usage @@ -55,6 +67,7 @@ The example below demonstrates how to generate text based on an image with [`Pip ```py import torch from transformers import pipeline + pipe = pipeline( task="image-text-to-text", model="THUDM/GLM-4.1V-9B-Thinking", @@ -69,11 +82,11 @@ messages = [ "type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg", }, - { "type": "text", "text": "Describe this image."}, + {"type": "text", "text": "Describe this image."}, ] } ] -pipe(text=messages,max_new_tokens=20, return_full_text=False) +pipe(text=messages, max_new_tokens=20, return_full_text=False) ``` @@ -92,15 +105,15 @@ model = Glm4vForConditionalGeneration.from_pretrained( processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking") messages = [ { - "role":"user", - "content":[ + "role": "user", + "content": [ { - "type":"image", + "type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" }, { - "type":"text", - "text":"Describe this image." + "type": "text", + "text": "Describe this image." 
} ] } @@ -117,10 +130,10 @@ inputs = processor.apply_chat_template( generated_ids = model.generate(**inputs, max_new_tokens=128) generated_ids_trimmed = [ - out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] output_text = processor.batch_decode( - generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False ) print(output_text) ``` @@ -160,9 +173,10 @@ messages = [ ], } ] -inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt", padding=True).to(model.device) +inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, + return_tensors="pt", padding=True).to(model.device) generated_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=1.0) -output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True) +output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True) print(output_text) ``` @@ -181,17 +195,17 @@ print(output_text) ## Glm4vImageProcessor [[autodoc]] Glm4vImageProcessor - - preprocess +- preprocess ## Glm4vVideoProcessor [[autodoc]] Glm4vVideoProcessor - - preprocess +- preprocess ## Glm4vImageProcessorFast [[autodoc]] Glm4vImageProcessorFast - - preprocess +- preprocess ## Glm4vProcessor @@ -201,19 +215,19 @@ print(output_text) ## Glm4vVisionModel [[autodoc]] Glm4vVisionModel - - forward +- forward ## Glm4vTextModel [[autodoc]] Glm4vTextModel - - forward +- forward ## Glm4vModel [[autodoc]] Glm4vModel - - forward +- forward ## Glm4vForConditionalGeneration [[autodoc]] Glm4vForConditionalGeneration - - forward +- forward diff --git a/docs/source/en/model_doc/glm4v_moe.md b/docs/source/en/model_doc/glm4v_moe.md index ffb6a3d85cb2..a67906ea001e 100644 --- a/docs/source/en/model_doc/glm4v_moe.md +++ b/docs/source/en/model_doc/glm4v_moe.md @@ -1,48 +1,54 @@ - -*This model was released on 2025-07-28 and added to Hugging Face Transformers on 2025-08-08.* +⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer. -
-<!-- PyTorch, FlashAttention, SDPA badges -->
+--> +*This model was released on 2025-08-12 and added to Hugging Face Transformers on 2025-08-08.* -# Glm4vMoeMoe +# Glm4vMoe ## Overview -Vision-language models (VLMs) have become a key cornerstone of intelligent systems. As real-world AI tasks grow increasingly complex, VLMs urgently need to enhance reasoning capabilities beyond basic multimodal perception — improving accuracy, comprehensiveness, and intelligence — to enable complex problem solving, long-context understanding, and multimodal agents. +The GLM-V model was proposed in [GLM-4.5V and GLM-4.1V-Thinking: Towards Versatile Multimodal Reasoning with Scalable Reinforcement Learning](https://huggingface.co/papers/2507.01006v6). -Through our open-source work, we aim to explore the technological frontier together with the community while empowering more developers to create exciting and innovative applications. +The abstract from the paper is the following: -[GLM-4.5V](https://huggingface.co/papers/2508.06471) ([Github repo](https://github.com/zai-org/GLM-V)) is based on ZhipuAI’s next-generation flagship text foundation model GLM-4.5-Air (106B parameters, 12B active). It continues the technical approach of [GLM-4.1V-Thinking](https://huggingface.co/papers/2507.01006), achieving SOTA performance among models of the same scale on 42 public vision-language benchmarks. It covers common tasks such as image, video, and document understanding, as well as GUI agent operations. +> *We present GLM-4.1V-Thinking, GLM-4.5V, and GLM-4.6V, a family of vision-language models (VLMs) designed to advance +general-purpose multimodal understanding and reasoning. In this report, we share our key findings in the development of +the reasoning-centric training framework. We first develop a capable vision foundation model with significant potential +through large-scale pre-training, which arguably sets the upper bound for the final performance. We then propose +Reinforcement Learning with Curriculum Sampling (RLCS) to unlock the full potential of the model, leading to +comprehensive capability enhancement across a diverse range of tasks, including STEM problem solving, video +understanding, content recognition, coding, grounding, GUI-based agents, and long document interpretation. In a +comprehensive evaluation across 42 public benchmarks, GLM-4.5V achieves state-of-the-art performance on nearly all tasks +among open-source models of similar size, and demonstrates competitive or even superior results compared to +closed-source models such as Gemini-2.5-Flash on challenging tasks including Coding and GUI Agents. Meanwhile, the +smaller GLM-4.1V-9B-Thinking remains highly competitive-achieving superior results to the much larger Qwen2.5-VL-72B on +29 benchmarks. We open-source both GLM-4.1V-9B-Thinking and GLM-4.5V. We further introduce the GLM-4.6V series, +open-source multimodal models with native tool use and a 128K context window. A brief overview is available at this +https URL. Code, models and more information are released at https://github.com/zai-org/GLM-V* -![bench_45](https://raw.githubusercontent.com/zai-org/GLM-V/refs/heads/main/resources/bench_45v.jpeg) +## Support Model -Beyond benchmark performance, GLM-4.5V focuses on real-world usability. 
Through efficient hybrid training, it can handle diverse types of visual content, enabling full-spectrum vision reasoning, including:
+This model type supports these zai-org models:
-- **Image reasoning** (scene understanding, complex multi-image analysis, spatial recognition)
-- **Video understanding** (long video segmentation and event recognition)
-- **GUI tasks** (screen reading, icon recognition, desktop operation assistance)
-- **Complex chart & long document parsing** (research report analysis, information extraction)
-- **Grounding** (precise visual element localization)
++ [GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V)
++ [GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)
-The model also introduces a **Thinking Mode** switch, allowing users to balance between quick responses and deep reasoning. This switch works the same as in the `GLM-4.5` language model.
+This model was contributed by [Raushan Turganbay](https://huggingface.co/RaushanTurganbay) and [Yuxuan Zhang](https://huggingface.co/ZHANGYUXUAN-zR).
## Glm4vMoeConfig
diff --git a/docs/source/en/model_doc/glm_image.md b/docs/source/en/model_doc/glm_image.md
new file mode 100644
index 000000000000..4b4ff609d2dd
--- /dev/null
+++ b/docs/source/en/model_doc/glm_image.md
@@ -0,0 +1,206 @@
+
+*This model was released on 2026-01-10 and added to Hugging Face Transformers on 2026-01-10.*
+
+# GlmImage
+
+## Overview
+
+GLM-Image is an image generation model that adopts a hybrid autoregressive + diffusion decoder architecture, effectively pushing the upper bound of visual fidelity and fine-grained details. In general image generation quality, it aligns with industry-standard LDM-based approaches, while demonstrating significant advantages in knowledge-intensive image generation scenarios.
+
+Model architecture: a hybrid autoregressive + diffusion decoder design:
+
++ Autoregressive generator: a 9B-parameter model initialized from [GLM-4-9B-0414](https://huggingface.co/zai-org/GLM-4-9B-0414), with an expanded vocabulary to incorporate visual tokens. The model first generates a compact encoding of approximately 256 tokens, then expands to 1K–4K tokens, corresponding to 1K–2K high-resolution image outputs.
++ Diffusion Decoder: a 7B-parameter decoder based on a single-stream DiT architecture for latent-space image decoding. It is equipped with a Glyph Encoder text module, significantly improving accurate text rendering within images.
+
+Post-training with decoupled reinforcement learning: the model introduces a fine-grained, modular feedback strategy using the GRPO algorithm, substantially enhancing both semantic understanding and visual detail quality.
+
++ Autoregressive module: provides low-frequency feedback signals focused on aesthetics and semantic alignment, improving instruction following and artistic expressiveness.
++ Decoder module: delivers high-frequency feedback targeting detail fidelity and text accuracy, resulting in highly realistic textures, lighting, and color reproduction, as well as more precise text rendering.
+
+GLM-Image supports both text-to-image and image-to-image generation within a single model:
+
++ Text-to-image: generates high-detail images from textual descriptions, with particularly strong performance in information-dense scenarios.
++ Image-to-image: supports a wide range of tasks, including image editing, style transfer, multi-subject consistency, and identity-preserving generation for people and objects.
+ ++ `GlmImageForConditionalGeneration` is the AR part of GLM-Image model, and for full image generation pipeline, please refer to [here](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/glm_image). + +This model was contributed by [Raushan Turganbay](https://huggingface.co/RaushanTurganbay) and [Yuxuan Zhang](https://huggingface.co/ZHANGYUXUAN-zR). + +## Usage examples + +Using GLM-Image with image input to generate vision token for DIT using. + +```python +from transformers import GlmImageForConditionalGeneration, AutoProcessor +import torch + +model = GlmImageForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path="zai-org/GLM-Image/vision_language_encoder", + dtype=torch.bfloat16, + device_map="cuda:0" +) +processor = AutoProcessor.from_pretrained( + pretrained_model_name_or_path="zai-org/GLM-Image/processor", + use_fast=True +) + +# Case1 T2I +prompt = "现代美食杂志风格的甜点制作教程图,主题为覆盆子慕斯蛋糕。整体布局干净明亮,分为四个主要区域:顶部左侧是黑色粗体标题“覆盆子慕斯蛋糕制作指南”,右侧搭配光线柔和的成品蛋糕特写照片,蛋糕呈淡粉色,表面点缀新鲜覆盆子与薄荷叶;左下方为配料清单区域,标题“配料”使用简洁字体,下方列有“面粉 150g”“鸡蛋 3个”“细砂糖 120g”“覆盆子果泥 200g”“明胶片 10g”“淡奶油 300ml”“新鲜覆盆子”等配料,每种配料旁配有简约线图标(如面粉袋、鸡蛋、糖罐等);右下方是四个等大的步骤方框,每个方框内含高清微距实拍图及对应操作说明,从上到下依次为:步骤1展示打蛋器打发白色泡沫(对应说明“打发蛋白至干性发泡”),步骤2展示红白相间的混合物被刮刀翻拌(对应说明“轻柔翻拌果泥与面糊”),步骤3展示粉色液体被倒入圆形模具(对应说明“倒入模具并冷藏4小时”),步骤4展示成品蛋糕表面装饰覆盆子与薄荷叶(对应说明“用覆盆子和薄荷装饰”);底部边缘设浅棕色信息条,左侧图标分别代表“准备时间:30分钟”“烹饪时间:20分钟”“份量:8人份”。整体色调以奶油白、淡粉色为主,背景带轻微纸质纹理,图文排版紧凑有序,信息层级分明。" +target_h, target_w = 1152, 768 +use_reference_images = False +reference_image_paths = None + +# ## Case2 +# prompt = "Replace the background of the snow forest with an underground station featuring an automatic escalator." +# cond_0 = "cond.jpg" +# target_h, target_w = 1152, 768 +# use_reference_images = True +# reference_image_paths = [cond_0] + +## Case3 +# prompt = "Make the man in the first figure and the child from the second image bow at the same time in a respectful KTV." 
+# cond_0 = "cond_0.jpg" +# cond_1 = "cond_1.jpg" +# target_h, target_w = 1152, 768 +# use_reference_images = True +# reference_image_paths = [cond_0, cond_1] + + +def build_messages(prompt, use_reference_images, reference_image_paths): + content = [] + if use_reference_images: + for img_path in reference_image_paths: + content.append({"type": "image", "url": img_path}) + content.append({"type": "text", "text": prompt}) + return [{"role": "user", "content": content}] + + +def compute_generation_params(image_grid_thw, use_reference_images): + grid_sizes = [] + for i in range(image_grid_thw.shape[0]): + t, h, w = image_grid_thw[i].tolist() + grid_sizes.append(int(h * w)) + + target_output_length = grid_sizes[0] + + if use_reference_images: + max_new_tokens = grid_sizes[-1] + 1 + output_start_offset = 0 + output_length = grid_sizes[-1] + else: + total_tokens = sum(grid_sizes) + max_new_tokens = total_tokens + 1 + output_start_offset = sum(grid_sizes[1:]) + output_length = target_output_length + + return max_new_tokens, output_start_offset, output_length + + +messages = build_messages(prompt, use_reference_images, reference_image_paths if use_reference_images else None) + +inputs = processor.apply_chat_template( + messages, + target_h=target_h, + target_w=target_w, + tokenize=True, + return_dict=True, + return_tensors="pt", +).to(model.device) + +image_grid_thw = inputs.get('image_grid_thw') +print(f"image_grid_thw: {image_grid_thw}") + +max_new_tokens, output_start_offset, output_length = compute_generation_params( + image_grid_thw, use_reference_images +) + +print(f"use_reference_images: {use_reference_images}") +print(f"max_new_tokens: {max_new_tokens}") +print(f"output_start_offset: {output_start_offset}") +print(f"output_length: {output_length}") + +outputs = model.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=True +) + +input_length = inputs["input_ids"].shape[-1] +output_tokens = outputs[0][input_length:][output_start_offset:output_start_offset + output_length] +print(f"Input length: {input_length}") +print(f"Total generated tokens: {outputs[0].shape[-1] - input_length}") +print(f"Extracted output tokens shape: {output_tokens.shape}") +print(f"Output tokens: {output_tokens}") +``` + +## GlmImageConfig + +[[autodoc]] GlmImageConfig + +## GlmImageVisionConfig + +[[autodoc]] GlmImageVisionConfig + +## GlmImageTextConfig + +[[autodoc]] GlmImageTextConfig + +## GlmImageVQVAEConfig + +[[autodoc]] GlmImageVQVAEConfig + +## GlmImageImageProcessor + +[[autodoc]] GlmImageImageProcessor + - preprocess + +## GlmImageImageProcessorFast + +[[autodoc]] GlmImageImageProcessorFast + - preprocess + +## GlmImageProcessor + +[[autodoc]] GlmImageProcessor + +## GlmImageVisionModel + +[[autodoc]] GlmImageVisionModel + - forward + +## GlmImageTextModel + +[[autodoc]] GlmImageTextModel + - forward + +## GlmImageVQVAE + +[[autodoc]] GlmImageVQVAE + - forward + +## GlmImageModel + +[[autodoc]] GlmImageModel + - forward + +## GlmImageForConditionalGeneration + +[[autodoc]] GlmImageForConditionalGeneration + - forward diff --git a/docs/source/en/model_doc/glmasr.md b/docs/source/en/model_doc/glmasr.md index a895b778bedb..d04be02a5f33 100644 --- a/docs/source/en/model_doc/glmasr.md +++ b/docs/source/en/model_doc/glmasr.md @@ -16,7 +16,7 @@ limitations under the License. ⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer. 
--> -*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-24.* +*This model was released on 2025-12-08 and added to Hugging Face Transformers on 2025-12-24.* # GlmAsr diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index edc2cfb32f94..0cd0b26035ba 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -155,6 +155,7 @@ from .glm4v import * from .glm4v_moe import * from .glm46v import * + from .glm_image import * from .glmasr import * from .glpn import * from .got_ocr2 import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 51ac332febbb..e85836dbe426 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -180,6 +180,10 @@ ("glm4v_moe_vision", "Glm4vMoeVisionConfig"), ("glm4v_text", "Glm4vTextConfig"), ("glm4v_vision", "Glm4vVisionConfig"), + ("glm_image", "GlmImageConfig"), + ("glm_image_text", "GlmImageTextConfig"), + ("glm_image_vision", "GlmImageVisionConfig"), + ("glm_image_vqmodel", "GlmImageVQVAEConfig"), ("glmasr", "GlmAsrConfig"), ("glmasr_encoder", "GlmAsrEncoderConfig"), ("glpn", "GLPNConfig"), @@ -638,6 +642,10 @@ ("glm4v_moe_vision", "Glm4vMoeVisionModel"), ("glm4v_text", "GLM4V"), ("glm4v_vision", "Glm4vVisionModel"), + ("glm_image", "GlmImage"), + ("glm_image_text", "GlmImageText"), + ("glm_image_vision", "GlmImageVisionModel"), + ("glm_image_vqmodel", "GlmImageVQVAE"), ("glmasr", "GLM-ASR"), ("glmasr_encoder", "GLM-ASR Encoder"), ("glpn", "GLPN"), @@ -984,6 +992,9 @@ ("glm4v_moe_vision", "glm4v_moe"), ("glm4v_text", "glm4v"), ("glm4v_moe_text", "glm4v_moe"), + ("glm_image_vision", "glm_image"), + ("glm_image_vqmodel", "glm_image"), + ("glm_image_text", "glm_image"), ("glmasr_encoder", "glmasr"), ("grounding-dino", "grounding_dino"), ("mm-grounding-dino", "mm_grounding_dino"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 5379386cc1df..8c5663dff3fa 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -109,6 +109,7 @@ ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("glm46v", ("Glm46VImageProcessor", "Glm46VImageProcessorFast")), ("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")), + ("glm_image", ("GlmImageImageProcessor", "GlmImageImageProcessorFast")), ("glpn", ("GLPNImageProcessor", "GLPNImageProcessorFast")), ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")), ("grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 583fbc8c1d3f..13eabcdc3405 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -183,6 +183,10 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("glm4v_moe_vision", "Glm4vMoeVisionModel"), ("glm4v_text", "Glm4vTextModel"), ("glm4v_vision", "Glm4vVisionModel"), + ("glm_image", "GlmImageModel"), + ("glm_image_text", "GlmImageTextModel"), + ("glm_image_vision", "GlmImageVisionModel"), + ("glm_image_vqmodel", "GlmImageVQVAE"), ("glmasr", "GlmAsrForConditionalGeneration"), ("glmasr_encoder", "GlmAsrEncoder"), ("glpn", "GLPNModel"), diff --git a/src/transformers/models/auto/processing_auto.py 
b/src/transformers/models/auto/processing_auto.py index 8ec4f0254d24..58cef02556b3 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -78,6 +78,7 @@ ("glm46v", "Glm46VProcessor"), ("glm4v", "Glm4vProcessor"), ("glm4v_moe", "Glm4vProcessor"), + ("glm_image", "Glm4vProcessor"), ("glmasr", "GlmAsrProcessor"), ("got_ocr2", "GotOcr2Processor"), ("granite_speech", "GraniteSpeechProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index ac7503053ba5..1f2d68cab65b 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -136,6 +136,7 @@ ("glm4_moe_lite", "TokenizersBackend" if is_tokenizers_available() else None), ("glm4v", "TokenizersBackend" if is_tokenizers_available() else None), ("glm4v_moe", "TokenizersBackend" if is_tokenizers_available() else None), + ("glm_image", "TokenizersBackend" if is_tokenizers_available() else None), ("glmasr", "TokenizersBackend" if is_tokenizers_available() else None), ("got_ocr2", "TokenizersBackend" if is_tokenizers_available() else None), ("gpt-sw3", "GPTSw3Tokenizer" if is_sentencepiece_available() else None), diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index f3eb66c734d0..7388fdca11f0 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -143,6 +143,7 @@ def __init__(self, config: Glm4vVisionConfig): self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.interpolated_method = "bicubic" def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor: """ @@ -161,57 +162,45 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc # Get position embedding parameters pos_embed_weight = self.position_embedding.weight hidden_size = pos_embed_weight.shape[1] - total_seq = h_coords.shape[0] device = pos_embed_weight.device - # Move coordinates to correct device - h_coords, w_coords = h_coords.to(device), w_coords.to(device) - - # Handle empty sequence case - if total_seq == 0: - adapted_pos_embed = torch.empty(0, hidden_size, device=device, dtype=pos_embed_weight.dtype) - else: - # Convert inputs to tensors if needed - if isinstance(lengths, list): - lengths = torch.tensor(lengths, device=device, dtype=torch.long) - if not isinstance(image_shapes, torch.Tensor): - image_shapes = torch.tensor(image_shapes, device=device, dtype=torch.long) - - # Prepare 2D position embedding - orig_size_sq = pos_embed_weight.shape[0] - orig_size = int(orig_size_sq**0.5) - pos_embed_2d = ( - pos_embed_weight.view(orig_size, orig_size, hidden_size) - .permute(2, 0, 1) - .unsqueeze(0) - .to(device=device, dtype=torch.float32) - ) + # Convert inputs to tensors if needed + if isinstance(lengths, list): + lengths = torch.tensor(lengths, device=device, dtype=torch.long) + + # Prepare 2D position embedding + orig_size_sq = pos_embed_weight.shape[0] + orig_size = int(orig_size_sq**0.5) + pos_embed_2d = ( + pos_embed_weight.view(orig_size, orig_size, hidden_size) + .permute(2, 0, 1) + .unsqueeze(0) + .to(device=device, dtype=torch.float32) + ) - # Calculate target dimensions for each patch - target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( - 
device=device, dtype=torch.float32 - ) - target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) + # Calculate target dimensions for each patch + target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) + target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) - # Normalize coordinates to [-1, 1] range for grid_sample - h_coords = h_coords.to(device=device, dtype=torch.float32) - w_coords = w_coords.to(device=device, dtype=torch.float32) - norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 - norm_h = ((h_coords + 0.5) / target_h) * 2 - 1 + # Normalize coordinates to [-1, 1] range for grid_sample + norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 + norm_h = ((h_coords + 0.5) / target_h) * 2 - 1 - # Create sampling grid - grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2) + # Create sampling grid + grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2) - # Perform bicubic interpolation - interpolated_embed_fp32 = F.grid_sample( - pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border" - ) + # Perform bicubic interpolation + interpolated_embed_fp32 = F.grid_sample( + pos_embed_2d, grid, mode=self.interpolated_method, align_corners=False, padding_mode="border" + ) - # Reshape and convert back to original dtype - adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0) - adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device) + # Reshape and convert back to original dtype + adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0) + adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device) # Add adapted position encoding to embeddings embeddings = embeddings + adapted_pos_embed @@ -405,6 +394,7 @@ def __init__(self, config: Glm4vTextConfig, device=None): self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + self.mrope_section = config.rope_parameters.get("mrope_section", [8, 12, 12]) @staticmethod def compute_default_rope_parameters( @@ -441,7 +431,7 @@ def compute_default_rope_parameters( @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): - # In contrast to other models, GLM4V different position ids for the grids + # In contrast to other models, GLM-V has different position ids for the grids # So we expand the inv_freq to shape (3, ...) 
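+        # Note: `position_ids` here has shape (3, batch, seq_len): one index stream each for the
+        # temporal, height and width axes. The per-axis frequencies computed below are merged by
+        # `apply_mrope`, which takes the first `mrope_section[0]` channels from the temporal axis,
+        # the next `mrope_section[1]` from height and the last `mrope_section[2]` from width, so
+        # the attention layers can apply a plain `apply_rotary_pos_emb`.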
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) @@ -449,12 +439,19 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_mrope(freqs, self.mrope_section) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + def apply_mrope(self, freqs, mrope_section): + section = mrope_section + chunks = freqs.split(section, dim=-1) + result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) + return result + def rotate_half_llm(x): """Rotates half the hidden dims of the input.""" @@ -463,25 +460,16 @@ def rotate_half_llm(x): return torch.stack((-x2, x1), dim=-1).flatten(-2) -def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). - - Explanation: - Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding - sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For - vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. - Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. - For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, - height and width) of text embedding is always the same, so the text embedding rotary position embedding has no - difference with modern LLMs. +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -492,13 +480,8 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - mrope_section = mrope_section * 2 - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) # Interleave them instead of usual shape cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) @@ -516,7 +499,6 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim # Concatenate back to full shape q_embed = torch.cat([q_embed, q_pass], dim=-1) k_embed = torch.cat([k_embed, k_pass], dim=-1) - return q_embed, k_embed @@ -566,9 +548,7 @@ def forward( value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb( # diff with Llama - query_states, key_states, cos, sin, self.rope_parameters["mrope_section"] - ) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models @@ -788,7 +768,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) """ hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) - rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -803,7 +782,13 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]) + hidden_states = self.embeddings( + hidden_states, + seqlens, + grid_thw, + image_type_ids[:, 0].to(hidden_states.device), + image_type_ids[:, 1].to(hidden_states.device), + ) for blk in self.blocks: hidden_states = blk( @@ -915,7 +900,7 @@ def forward( layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, - position_ids=position_ids, + position_ids=text_position_ids, past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 61edcdc7a390..77946b1310f9 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -408,6 +408,7 @@ def __init__(self, config: Glm4vVisionConfig): self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.interpolated_method = "bicubic" def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor: """ @@ -426,57 +427,45 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc # Get position embedding parameters pos_embed_weight = self.position_embedding.weight hidden_size = pos_embed_weight.shape[1] - total_seq = h_coords.shape[0] device = pos_embed_weight.device - # Move coordinates to correct device - h_coords, w_coords = h_coords.to(device), w_coords.to(device) - - # Handle empty sequence case - if total_seq == 0: - adapted_pos_embed = torch.empty(0, 
hidden_size, device=device, dtype=pos_embed_weight.dtype) - else: - # Convert inputs to tensors if needed - if isinstance(lengths, list): - lengths = torch.tensor(lengths, device=device, dtype=torch.long) - if not isinstance(image_shapes, torch.Tensor): - image_shapes = torch.tensor(image_shapes, device=device, dtype=torch.long) - - # Prepare 2D position embedding - orig_size_sq = pos_embed_weight.shape[0] - orig_size = int(orig_size_sq**0.5) - pos_embed_2d = ( - pos_embed_weight.view(orig_size, orig_size, hidden_size) - .permute(2, 0, 1) - .unsqueeze(0) - .to(device=device, dtype=torch.float32) - ) + # Convert inputs to tensors if needed + if isinstance(lengths, list): + lengths = torch.tensor(lengths, device=device, dtype=torch.long) + + # Prepare 2D position embedding + orig_size_sq = pos_embed_weight.shape[0] + orig_size = int(orig_size_sq**0.5) + pos_embed_2d = ( + pos_embed_weight.view(orig_size, orig_size, hidden_size) + .permute(2, 0, 1) + .unsqueeze(0) + .to(device=device, dtype=torch.float32) + ) - # Calculate target dimensions for each patch - target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) - target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) + # Calculate target dimensions for each patch + target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) + target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) - # Normalize coordinates to [-1, 1] range for grid_sample - h_coords = h_coords.to(device=device, dtype=torch.float32) - w_coords = w_coords.to(device=device, dtype=torch.float32) - norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 - norm_h = ((h_coords + 0.5) / target_h) * 2 - 1 + # Normalize coordinates to [-1, 1] range for grid_sample + norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 + norm_h = ((h_coords + 0.5) / target_h) * 2 - 1 - # Create sampling grid - grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2) + # Create sampling grid + grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2) - # Perform bicubic interpolation - interpolated_embed_fp32 = F.grid_sample( - pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border" - ) + # Perform bicubic interpolation + interpolated_embed_fp32 = F.grid_sample( + pos_embed_2d, grid, mode=self.interpolated_method, align_corners=False, padding_mode="border" + ) - # Reshape and convert back to original dtype - adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0) - adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device) + # Reshape and convert back to original dtype + adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0) + adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device) # Add adapted position encoding to embeddings embeddings = embeddings + adapted_pos_embed @@ -501,9 +490,12 @@ def __init__(self, config) -> None: class Glm4vTextRotaryEmbedding(Glm4RotaryEmbedding): - # Ignore copy + def __init__(self, config: Glm4vTextConfig, device=None): + super().__init__() + self.mrope_section = config.rope_parameters.get("mrope_section", [8, 12, 12]) + def forward(self, x, position_ids): - # In contrast to other models, GLM4V 
different position ids for the grids + # In contrast to other models, GLM-V has different position ids for the grids # So we expand the inv_freq to shape (3, ...) inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) @@ -511,12 +503,19 @@ def forward(self, x, position_ids): device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_mrope(freqs, self.mrope_section) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + def apply_mrope(self, freqs, mrope_section): + section = mrope_section + chunks = freqs.split(section, dim=-1) + result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) + return result + def rotate_half_llm(x): """Rotates half the hidden dims of the input.""" @@ -525,25 +524,16 @@ def rotate_half_llm(x): return torch.stack((-x2, x1), dim=-1).flatten(-2) -def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). - - Explanation: - Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding - sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For - vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. - Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. - For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, - height and width) of text embedding is always the same, so the text embedding rotary position embedding has no - difference with modern LLMs. +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -554,13 +544,8 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - mrope_section = mrope_section * 2 - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) # Interleave them instead of usual shape cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) @@ -578,7 +563,6 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim # Concatenate back to full shape q_embed = torch.cat([q_embed, q_pass], dim=-1) k_embed = torch.cat([k_embed, k_pass], dim=-1) - return q_embed, k_embed @@ -628,9 +612,7 @@ def forward( value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb( # diff with Llama - query_states, key_states, cos, sin, self.rope_parameters["mrope_section"] - ) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models @@ -805,7 +787,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) """ hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) - rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -820,7 +801,13 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]) + hidden_states = self.embeddings( + hidden_states, + seqlens, + grid_thw, + image_type_ids[:, 0].to(hidden_states.device), + image_type_ids[:, 1].to(hidden_states.device), + ) for blk in self.blocks: hidden_states = blk( @@ -922,7 +909,7 @@ def forward( layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, - position_ids=position_ids, + position_ids=text_position_ids, past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py b/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py index 29f0ed747f2e..5772ed2957ee 100644 --- a/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py @@ -21,99 +21,6 @@ from ...modeling_rope_utils import RopeParameters -class Glm4vMoeVisionConfig(PreTrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Glm4vMoeVisionModel`]. It is used to instantiate an Glm4vMoeVisionModel - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield - a similar configuration to that of - GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking). - - Args: - depth (`int`, *optional*, defaults to 24): - Number of layers (depth) in the model. - hidden_size (`int`, *optional*, defaults to 1536): - Dimensionality of the encoder layers and the pooler layer. 
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, - `"relu"`, `"selu"` and `"gelu_new"` are supported. - attention_bias (`bool`, *optional*, defaults to `False`): - Whether to add a bias to the queries, keys and values. - attention_dropout (`float`, *optional*, defaults to 0.0): - Dropout probability for attention weights. - num_heads (``, *optional*, defaults to 12): - in_channels (``, *optional*, defaults to 3): - image_size (`int` or `list[int]`, *optional*, defaults to 336): - The size (resolution) of each image. - patch_size (`int`, *optional*, defaults to 14): - The size (resolution) of each patch. - rms_norm_eps (`float`, *optional*, defaults to 1e-05): - The epsilon used by the rms normalization layers. - spatial_merge_size (`int`, *optional*, defaults to 2): - The size used for merging spatial dimensions. - temporal_patch_size (`int`, *optional*, defaults to 2): - The size used for patches along the temporal dimension. - out_hidden_size (`int`, *optional*, defaults to 4096): - The output hidden size of the vision model. - intermediate_size (`int`, *optional*, defaults to 13696): - Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - Example: - - ```python - >>> from transformers import Glm4vMoeVisionConfig, Glm4vMoeVisionModel - - >>> # Initializing a Glm4vMoeVisionConfig GLM-4.1V-9B style configuration - >>> configuration = Glm4vMoeVisionConfig() - - >>> # Initializing a model (with random weights) from the GLM-4.1V-9B configuration - >>> model = Glm4vMoeVisionModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "glm4v_moe_vision" - base_config_key = "vision_config" - - def __init__( - self, - depth=24, - hidden_size=1536, - hidden_act="silu", - attention_bias=False, - attention_dropout=0.0, - num_heads=12, - in_channels=3, - image_size=336, - patch_size=14, - rms_norm_eps=1e-05, - spatial_merge_size=2, - temporal_patch_size=2, - out_hidden_size=4096, - intermediate_size=13696, - initializer_range=0.02, - **kwargs, - ): - super().__init__(**kwargs) - - self.depth = depth - self.hidden_size = hidden_size - self.hidden_act = hidden_act - self.num_heads = num_heads - self.in_channels = in_channels - self.image_size = image_size - self.patch_size = patch_size - self.spatial_merge_size = spatial_merge_size - self.temporal_patch_size = temporal_patch_size - self.out_hidden_size = out_hidden_size - self.intermediate_size = intermediate_size - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - - class Glm4vMoeTextConfig(PreTrainedConfig): r""" This is the configuration class to store the configuration of a [`Glm4vMoeModel`]. It is used to instantiate a @@ -282,6 +189,99 @@ def __init__( ) +class Glm4vMoeVisionConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Glm4vMoeVisionModel`]. It is used to instantiate an Glm4vMoeVisionModel + model according to the specified arguments, defining the model architecture. 
Instantiating a configuration with the defaults will yield + a similar configuration to that of + GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking). + + Args: + depth (`int`, *optional*, defaults to 24): + Number of layers (depth) in the model. + hidden_size (`int`, *optional*, defaults to 1536): + Dimensionality of the encoder layers and the pooler layer. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the queries, keys and values. + attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for attention weights. + num_heads (``, *optional*, defaults to 12): + in_channels (``, *optional*, defaults to 3): + image_size (`int` or `list[int]`, *optional*, defaults to 336): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + spatial_merge_size (`int`, *optional*, defaults to 2): + The size used for merging spatial dimensions. + temporal_patch_size (`int`, *optional*, defaults to 2): + The size used for patches along the temporal dimension. + out_hidden_size (`int`, *optional*, defaults to 4096): + The output hidden size of the vision model. + intermediate_size (`int`, *optional*, defaults to 13696): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + Example: + + ```python + >>> from transformers import Glm4vMoeVisionConfig, Glm4vMoeVisionModel + + >>> # Initializing a Glm4vMoeVisionConfig GLM-4.1V-9B style configuration + >>> configuration = Glm4vMoeVisionConfig() + + >>> # Initializing a model (with random weights) from the GLM-4.1V-9B configuration + >>> model = Glm4vMoeVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "glm4v_moe_vision" + base_config_key = "vision_config" + + def __init__( + self, + depth=24, + hidden_size=1536, + hidden_act="silu", + attention_bias=False, + attention_dropout=0.0, + num_heads=12, + in_channels=3, + image_size=336, + patch_size=14, + rms_norm_eps=1e-05, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=4096, + intermediate_size=13696, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.num_heads = num_heads + self.in_channels = in_channels + self.image_size = image_size + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.intermediate_size = intermediate_size + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + class Glm4vMoeConfig(PreTrainedConfig): r""" This is the configuration class to store the configuration of a [`Glm4vMoeModel`]. 
It is used to instantiate a @@ -360,4 +360,4 @@ def __init__( super().__init__(**kwargs) -__all__ = ["Glm4vMoeConfig", "Glm4vMoeTextConfig", "Glm4vMoeVisionConfig"] +__all__ = ["Glm4vMoeConfig", "Glm4vMoeVisionConfig", "Glm4vMoeTextConfig"] diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 3ed571198353..ef95cf26e191 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -50,120 +50,6 @@ from .configuration_glm4v_moe import Glm4vMoeConfig, Glm4vMoeTextConfig, Glm4vMoeVisionConfig -@use_kernel_forward_from_hub("RMSNorm") -class Glm4vMoeRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Glm4vMoeRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -@dataclass -@auto_docstring( - custom_intro=""" - Base class for Llava outputs, with hidden states and attentions. - """ -) -class Glm4vMoeModelOutputWithPast(ModelOutput): - r""" - past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): - The rope index difference between sequence length and multimodal rope. - """ - - last_hidden_state: torch.FloatTensor | None = None - past_key_values: Cache | None = None - hidden_states: tuple[torch.FloatTensor] | None = None - attentions: tuple[torch.FloatTensor] | None = None - rope_deltas: torch.LongTensor | None = None - - -class Glm4vMoeTextRotaryEmbedding(nn.Module): - inv_freq: torch.Tensor # fix linting for `register_buffer` - - def __init__(self, config: Glm4vMoeTextConfig, device=None, layer_type=None): - super().__init__() - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - - self.rope_type = self.config.rope_parameters["rope_type"] - rope_init_fn: Callable = self.compute_default_rope_parameters - if self.rope_type != "default": - rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = rope_init_fn(self.config, device) - - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) - - @staticmethod - def compute_default_rope_parameters( - config: Glm4vMoeTextConfig | None = None, - device: Optional["torch.device"] = None, - seq_len: int | None = None, - ) -> tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies according to the original RoPE implementation - Args: - config ([`~transformers.PreTrainedConfig`]): - The model configuration. 
- device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). - """ - base = config.rope_parameters["rope_theta"] - partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) - head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - dim = int(head_dim * partial_rotary_factor) - - attention_factor = 1.0 # Unused in this type of RoPE - - # Compute the inverse frequencies - inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) - ) - return inv_freq, attention_factor - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - # In contrast to other models, GLM4V_MOE different position ids for the grids - # So we expand the inv_freq to shape (3, ...) - inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) - position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -232,10 +118,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) - # Interleave them instead of usual shape - cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1) - sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1) - # Keep half or full tensor for later concatenation rotary_dim = cos.shape[-1] q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] @@ -251,59 +133,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). - - Explanation: - Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding - sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For - vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. - Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. - For text embedding part, we just apply 1D rotary position embedding. 
The three rotary position index (temporal, - height and width) of text embedding is always the same, so the text embedding rotary position embedding has no - difference with modern LLMs. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - mrope_section = mrope_section * 2 - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - - # Keep half or full tensor for later concatenation - rotary_dim = cos.shape[-1] - q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] - k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - - # Apply rotary embeddings on the first half or full tensor - q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) - k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) - - # Concatenate back to full shape - q_embed = torch.cat([q_embed, q_pass], dim=-1) - k_embed = torch.cat([k_embed, k_pass], dim=-1) - - return q_embed, k_embed - - @use_kernelized_func(apply_rotary_pos_emb) class Glm4vMoeTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -351,9 +180,7 @@ def forward( value_states = value_states.transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb( # diff with Llama - query_states, key_states, cos, sin, self.rope_parameters["mrope_section"] - ) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache @@ -665,6 +492,27 @@ def forward(self, seqlen: int) -> torch.Tensor: return freqs +@use_kernel_forward_from_hub("RMSNorm") +class Glm4vMoeRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Glm4vMoeRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + class 
Glm4vMoeisionMlp(nn.Module): def __init__(self, config, bias: bool = False): super().__init__() @@ -727,6 +575,7 @@ def __init__(self, config: Glm4vMoeVisionConfig): self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.interpolated_method = "bicubic" def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor: """ @@ -745,57 +594,45 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc # Get position embedding parameters pos_embed_weight = self.position_embedding.weight hidden_size = pos_embed_weight.shape[1] - total_seq = h_coords.shape[0] device = pos_embed_weight.device - # Move coordinates to correct device - h_coords, w_coords = h_coords.to(device), w_coords.to(device) - - # Handle empty sequence case - if total_seq == 0: - adapted_pos_embed = torch.empty(0, hidden_size, device=device, dtype=pos_embed_weight.dtype) - else: - # Convert inputs to tensors if needed - if isinstance(lengths, list): - lengths = torch.tensor(lengths, device=device, dtype=torch.long) - if not isinstance(image_shapes, torch.Tensor): - image_shapes = torch.tensor(image_shapes, device=device, dtype=torch.long) - - # Prepare 2D position embedding - orig_size_sq = pos_embed_weight.shape[0] - orig_size = int(orig_size_sq**0.5) - pos_embed_2d = ( - pos_embed_weight.view(orig_size, orig_size, hidden_size) - .permute(2, 0, 1) - .unsqueeze(0) - .to(device=device, dtype=torch.float32) - ) + # Convert inputs to tensors if needed + if isinstance(lengths, list): + lengths = torch.tensor(lengths, device=device, dtype=torch.long) + + # Prepare 2D position embedding + orig_size_sq = pos_embed_weight.shape[0] + orig_size = int(orig_size_sq**0.5) + pos_embed_2d = ( + pos_embed_weight.view(orig_size, orig_size, hidden_size) + .permute(2, 0, 1) + .unsqueeze(0) + .to(device=device, dtype=torch.float32) + ) - # Calculate target dimensions for each patch - target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) - target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) + # Calculate target dimensions for each patch + target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) + target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) - # Normalize coordinates to [-1, 1] range for grid_sample - h_coords = h_coords.to(device=device, dtype=torch.float32) - w_coords = w_coords.to(device=device, dtype=torch.float32) - norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 - norm_h = ((h_coords + 0.5) / target_h) * 2 - 1 + # Normalize coordinates to [-1, 1] range for grid_sample + norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 + norm_h = ((h_coords + 0.5) / target_h) * 2 - 1 - # Create sampling grid - grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2) + # Create sampling grid + grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2) - # Perform bicubic interpolation - interpolated_embed_fp32 = F.grid_sample( - pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border" - ) + # Perform bicubic interpolation + interpolated_embed_fp32 = F.grid_sample( + pos_embed_2d, grid, mode=self.interpolated_method, 
align_corners=False, padding_mode="border" + ) - # Reshape and convert back to original dtype - adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0) - adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device) + # Reshape and convert back to original dtype + adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0) + adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device) # Add adapted position encoding to embeddings embeddings = embeddings + adapted_pos_embed @@ -1002,7 +839,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) """ hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) - rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -1017,7 +853,13 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]) + hidden_states = self.embeddings( + hidden_states, + seqlens, + grid_thw, + image_type_ids[:, 0].to(hidden_states.device), + image_type_ids[:, 1].to(hidden_states.device), + ) for blk in self.blocks: hidden_states = blk( @@ -1038,6 +880,83 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) return hidden_states +class Glm4vMoeTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Glm4vMoeTextConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + self.mrope_section = config.rope_parameters.get("mrope_section", [8, 12, 12]) + + @staticmethod + def compute_default_rope_parameters( + config: Glm4vMoeTextConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
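# Illustrative sketch (not part of this patch): how Glm4vMoeVisionEmbeddings.forward above
# adapts a learned 2D position embedding to an arbitrary patch grid with F.grid_sample.
# All sizes here are invented; only the coordinate normalization and sampling mirror the code.
import torch
import torch.nn.functional as F

hidden_size, orig_size = 16, 24                       # 24x24 learned grid of position embeddings
pos_embed = torch.randn(orig_size * orig_size, hidden_size)
pos_embed_2d = pos_embed.view(orig_size, orig_size, hidden_size).permute(2, 0, 1).unsqueeze(0)

# target image is a 6x8 patch grid; one (h, w) coordinate per output patch
target_h, target_w = 6, 8
h_coords, w_coords = torch.meshgrid(
    torch.arange(target_h), torch.arange(target_w), indexing="ij"
)
h_coords, w_coords = h_coords.flatten().float(), w_coords.flatten().float()

# normalize patch centres to [-1, 1], the coordinate range expected by grid_sample
norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)   # (1, N, 1, 2)

sampled = F.grid_sample(pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border")
adapted = sampled.squeeze(0).squeeze(-1).permute(1, 0)                   # (N, hidden_size)
assert adapted.shape == (target_h * target_w, hidden_size)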
+ """ + base = config.rope_parameters["rope_theta"] + partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, GLM-V has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def apply_mrope(self, freqs, mrope_section): + section = mrope_section + chunks = freqs.split(section, dim=-1) + result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) + return result + + @auto_docstring class Glm4vMoeTextModel(Glm4vMoePreTrainedModel): config: Glm4vMoeTextConfig @@ -1147,6 +1066,30 @@ def forward( ) +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Llava outputs, with hidden states and attentions. + """ +) +class Glm4vMoeModelOutputWithPast(ModelOutput): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + last_hidden_state: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None + + @auto_docstring class Glm4vMoeModel(Glm4vMoePreTrainedModel): base_model_prefix = "model" diff --git a/src/transformers/models/glm4v_moe/modular_glm4v_moe.py b/src/transformers/models/glm4v_moe/modular_glm4v_moe.py index 15a15c0fda9a..5e81dc84c6c9 100644 --- a/src/transformers/models/glm4v_moe/modular_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modular_glm4v_moe.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
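# Illustrative sketch (not part of this patch): the channel split performed by apply_mrope
# above. The three leading streams of `freqs` carry temporal, height and width position ids;
# mrope_section says how many rotary channels each stream contributes. Shapes below are
# invented; the default section value comes from this file.
import torch

mrope_section = [8, 12, 12]                              # default used above
batch, seq = 1, 10
freqs = torch.randn(3, batch, seq, sum(mrope_section))   # (streams, bs, seq, dim // 2)

chunks = freqs.split(mrope_section, dim=-1)
# chunk 0 keeps the temporal stream, chunk 1 the height stream, chunk 2 the width stream
merged = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)

assert merged.shape == (batch, seq, sum(mrope_section))
# forward() above then builds cos/sin from torch.cat((merged, merged), dim=-1).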
from collections.abc import Callable -from typing import Optional import torch import torch.nn as nn @@ -36,22 +35,19 @@ Glm4MoeMLP, Glm4MoeMoE, Glm4MoePreTrainedModel, - Glm4MoeRMSNorm, Glm4MoeTopkRouter, eager_attention_forward, ) -from ..glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig +from ..glm4v.configuration_glm4v import Glm4vConfig from ..glm4v.modeling_glm4v import ( Glm4vForConditionalGeneration, Glm4vTextModel, - Glm4vTextRotaryEmbedding, Glm4vVisionModel, Glm4vVisionRotaryEmbedding, - rotate_half, ) +from ..gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb from ..qwen3_vl_moe.modeling_qwen3_vl_moe import ( Qwen3VLMoeCausalLMOutputWithPast, - Qwen3VLMoeModelOutputWithPast, load_balancing_loss_func, ) @@ -59,14 +55,6 @@ logger = logging.get_logger(__name__) -class Glm4vMoeVisionConfig(Glm4vVisionConfig): - pass - - -class Glm4vMoeRMSNorm(Glm4MoeRMSNorm): - pass - - class Glm4vMoeTextConfig(Glm4MoeConfig, RotaryEmbeddingConfigMixin): r""" This is the configuration class to store the configuration of a [`Glm4vMoeModel`]. It is used to instantiate a @@ -289,100 +277,6 @@ def __init__( super().__init__() -class Glm4vMoeModelOutputWithPast(Qwen3VLMoeModelOutputWithPast): - pass - - -def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). - - Explanation: - Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding - sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For - vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. - Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. - For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, - height and width) of text embedding is always the same, so the text embedding rotary position embedding has no - difference with modern LLMs. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - mrope_section = mrope_section * 2 - cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( - unsqueeze_dim - ) - - # Keep half or full tensor for later concatenation - rotary_dim = cos.shape[-1] - q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] - k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - - # Apply rotary embeddings on the first half or full tensor - q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) - k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) - - # Concatenate back to full shape - q_embed = torch.cat([q_embed, q_pass], dim=-1) - k_embed = torch.cat([k_embed, k_pass], dim=-1) - - return q_embed, k_embed - - -class Glm4vMoeTextRotaryEmbedding(Glm4vTextRotaryEmbedding): - def __init__(self, config: Glm4vMoeTextConfig, device=None, layer_type=None): - super().__init__(config, device=device, layer_type=layer_type) - - @staticmethod - def compute_default_rope_parameters( - config: Glm4vMoeTextConfig | None = None, - device: Optional["torch.device"] = None, - seq_len: int | None = None, - ) -> tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies according to the original RoPE implementation - Args: - config ([`~transformers.PreTrainedConfig`]): - The model configuration. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). - """ - base = config.rope_parameters["rope_theta"] - partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) - head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - dim = int(head_dim * partial_rotary_factor) - - attention_factor = 1.0 # Unused in this type of RoPE - - # Compute the inverse frequencies - inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) - ) - return inv_freq, attention_factor - - class Glm4vMoeTextAttention(Glm4Attention): def __init__(self, config: Glm4vMoeTextConfig, layer_idx: int | None = None): super().__init__(config, layer_idx) @@ -409,9 +303,7 @@ def forward( value_states = value_states.transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb( # diff with Llama - query_states, key_states, cos, sin, self.rope_parameters["mrope_section"] - ) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache @@ -656,8 +548,8 @@ def forward( __all__ = [ "Glm4vMoeConfig", + "Glm4vMoeVisionConfig", # noqa: F822 "Glm4vMoeTextConfig", - "Glm4vMoeVisionConfig", "Glm4vMoeForConditionalGeneration", "Glm4vMoeModel", # noqa: F822 "Glm4vMoePreTrainedModel", diff --git a/src/transformers/models/glm_image/__init__.py b/src/transformers/models/glm_image/__init__.py new file mode 100644 index 000000000000..7c5108a29576 --- /dev/null +++ b/src/transformers/models/glm_image/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2025 the HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_glm_image import * + from .image_processing_glm_image import * + from .image_processing_glm_image_fast import * + from .modeling_glm_image import * + from .processing_glm_image import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/glm_image/configuration_glm_image.py b/src/transformers/models/glm_image/configuration_glm_image.py new file mode 100644 index 000000000000..403b5676dace --- /dev/null +++ b/src/transformers/models/glm_image/configuration_glm_image.py @@ -0,0 +1,352 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/glm_image/modular_glm_image.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_glm_image.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters + + +class GlmImageVQVAEConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GlmImageVQModel`]. It is used to instantiate a + `GlmImageVQModel` according to the specified arguments, defining the model architecture. + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. Instantiating a + configuration with the defaults will yield a similar configuration to the VQModel of the + [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image) architecture. + + Args: + embed_dim (`int`, *optional*, defaults to 2048): + Dimensionality of each embedding vector. + num_embeddings (`int`, *optional*, defaults to 16384): + Number of codebook embeddings. + latent_channels (`int`, *optional*, defaults to 1536): + Number of channels for the latent space. + in_channels (`int`, *optional*, defaults to 3): + Number of input channels. 
+ initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + """ + + model_type = "glm_image_vqmodel" + base_config_key = "vq_config" + + def __init__( + self, + embed_dim: int = 2048, + num_embeddings: int = 16384, + latent_channels: int = 1536, + in_channels: int = 3, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + self.num_embeddings = num_embeddings + self.latent_channels = latent_channels + self.in_channels = in_channels + self.initializer_range = initializer_range + + +class GlmImageVisionConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GlmImageVisionModel`]. It is used to instantiate an GlmImageVisionModel + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield + a similar configuration to that of + GLM-Image [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image). + + Args: + depth (`int`, *optional*, defaults to 40): + Number of layers (depth) in the model. + hidden_size (`int`, *optional*, defaults to 1536): + Dimensionality of the encoder layers and the pooler layer. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + attention_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for attention weights. + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer architecture. + in_channels (`int`, *optional*, defaults to 3): + Number of input channels. + image_size (`int` or `list[int]`, *optional*, defaults to 2048): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + spatial_merge_size (`int`, *optional*, defaults to 1): + The size used for merging spatial dimensions. + intermediate_size (`int`, *optional*, defaults to 6144): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ """ + + model_type = "glm_image_vision" + base_config_key = "vision_config" + + def __init__( + self, + depth=40, + hidden_size=1536, + hidden_act="gelu", + attention_bias=True, + attention_dropout=0.0, + num_heads=16, + in_channels=3, + image_size=2048, + patch_size=16, + layer_norm_eps=1e-06, + spatial_merge_size=1, + intermediate_size=6144, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.num_heads = num_heads + self.in_channels = in_channels + self.image_size = image_size + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.intermediate_size = intermediate_size + self.initializer_range = initializer_range + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + + +class GlmImageTextConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GlmImageTextModel`]. It is used to instantiate a + GLM-Image model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of + GLM-Image [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image). + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 168064): + Vocabulary size of the GlmImage model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GlmImageModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 13696): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 40): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). 
Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. + vision_vocab_size (`int`, *optional*, defaults to 16512): + Vision vocabulary size of the GlmImage model. Defines the number of different tokens that can be represented + by the `inputs_ids` passed when calling [`GlmImageVisionModel`] + attention_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + + ```python + >>> from transformers import GlmImageTextModel, GlmImageConfig + + >>> # Initializing a GlmImageConfig style configuration + >>> configuration = GlmImageConfig() + + >>> # Initializing a model from the GlmImageConfig style configuration + >>> model = GlmImageTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "glm_image_text" + base_config_key = "text_config" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `GlmImage` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation + "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size: int | None = 168064, + hidden_size: int | None = 4096, + intermediate_size: int | None = 13696, + num_hidden_layers: int | None = 40, + num_attention_heads: int | None = 32, + num_key_value_heads: int | None = 2, + hidden_act: str | None = "silu", + max_position_embeddings: int | None = 32768, + initializer_range: float | None = 0.02, + rms_norm_eps: int | None = 1e-05, + use_cache: bool | None = True, + tie_word_embeddings: bool | None = False, + attention_dropout: float | None = 0.0, + rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None, + vision_vocab_size: int | None = 16512, + attention_bias: bool | None = True, + **kwargs, + ): + self.vocab_size = vocab_size + self.vision_vocab_size = vision_vocab_size + self.attention_bias = attention_bias + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.attention_dropout = attention_dropout + 
self.rope_parameters = rope_parameters + + super().__init__( + tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs + ) + + +class GlmImageConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GlmImageModel`]. It is used to instantiate a + GLM-Image model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of + GLM-Image [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image) architecture. + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `GlmImageTextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `GlmImageVisionConfig`): + The config object or dictionary of the vision backbone. + vq_config (`Union[Dict, GlmImageVQVAEConfig]`, *optional*): + GlmImageVQVAEConfig instance containing the configuration for the VQ-VAE model. + image_token_id (`int`, *optional*, defaults to 167855): + The image token index to encode the image prompt. + image_start_token_id (`int`, *optional*, defaults to 16384): + The image start token index to encode the start of image. + image_end_token_id (`int`, *optional*, defaults to 16385): + The image end token index to encode the end of image. + + ```python + >>> from transformers import Glm4vForConditionalGeneration, Glm4vConfig + + >>> # Initializing a GLM-Image style configuration + >>> configuration = Glm4vConfig() + + >>> # Initializing a model from the GLM-Image style configuration + >>> model = Glm4vForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "glm_image" + sub_configs = { + "vision_config": GlmImageVisionConfig, + "text_config": GlmImageTextConfig, + "vq_config": GlmImageVQVAEConfig, + } + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + vq_config=None, + image_token_id=167855, + image_start_token_id=16384, + image_end_token_id=16385, + **kwargs, + ): + if isinstance(vision_config, dict): + vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + vision_config = self.sub_configs["vision_config"](**kwargs) + + if isinstance(vq_config, dict): + vq_config = self.sub_configs["vq_config"](**vq_config) + elif vq_config is None: + vq_config = self.sub_configs["vq_config"](**kwargs) + + if isinstance(text_config, dict): + text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + text_config = self.sub_configs["text_config"](**kwargs) + + self.image_token_id = image_token_id + self.image_start_token_id = image_start_token_id + self.image_end_token_id = image_end_token_id + self.text_config = text_config + self.vision_config = vision_config + self.vq_config = vq_config + super().__init__(**kwargs) + + +__all__ = ["GlmImageVQVAEConfig", "GlmImageVisionConfig", "GlmImageTextConfig", "GlmImageConfig"] diff --git a/src/transformers/models/glm_image/image_processing_glm_image.py b/src/transformers/models/glm_image/image_processing_glm_image.py new file mode 100644 index 000000000000..acf934b180a8 --- /dev/null +++ 
b/src/transformers/models/glm_image/image_processing_glm_image.py @@ -0,0 +1,503 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/glm_image/modular_glm_image.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_glm_image.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...image_processing_utils import BaseImageProcessor +from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_flat_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...processing_utils import ImagesKwargs +from ...utils import TensorType, logging +from ...video_utils import VideoInput + + +logger = logging.get_logger(__name__) + + +class GlmImageImageProcessorKwargs(ImagesKwargs, total=False): + r""" + min_pixels (`int`, *optional*, defaults to `56 * 56`): + The min pixels of the image to resize the image. + max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`): + The max pixels of the image to resize the image. + patch_size (`int`, *optional*, defaults to 14): + The spatial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to 2): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + The merge size of the vision encoder to llm encoder. 
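# Illustrative sketch (not part of this patch): GlmImageConfig, defined in
# configuration_glm_image.py above, promotes dict sub-configs to config objects in its
# __init__. The import path and the ability to instantiate with defaults are assumptions
# about this branch; the asserted values are the defaults declared in that file.
from transformers.models.glm_image.configuration_glm_image import (
    GlmImageConfig,
    GlmImageTextConfig,
    GlmImageVQVAEConfig,
    GlmImageVisionConfig,
)

config = GlmImageConfig(
    text_config={"hidden_size": 4096, "num_hidden_layers": 40},
    vision_config={"hidden_size": 1536, "depth": 40},
    vq_config={"embed_dim": 2048},
)
assert isinstance(config.text_config, GlmImageTextConfig)
assert isinstance(config.vision_config, GlmImageVisionConfig)
assert isinstance(config.vq_config, GlmImageVQVAEConfig)
assert config.image_token_id == 167855
assert config.vq_config.num_embeddings == 16384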
+ """ + + min_pixels: int + max_pixels: int + patch_size: int + temporal_patch_size: int + merge_size: int + + +def smart_resize( + height: int, + width: int, + factor: int = 16, + min_pixels: int = 512 * 512, + max_pixels: int = 2048 * 2048, +) -> tuple[int, int]: + if height < factor or width < factor: + raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}") + elif max(height, width) / min(height, width) > 4: + raise ValueError( + f"absolute aspect ratio must be smaller than 4, got {max(height, width) / min(height, width)}" + ) + + shortest_edge = int(round(math.sqrt(min_pixels))) + longest_edge = int(round(math.sqrt(max_pixels))) + min_side = min(height, width) + max_side = max(height, width) + + scale = 1.0 + + if min_side < shortest_edge: + scale = shortest_edge / min_side + + if max_side * scale > longest_edge: + scale = longest_edge / max_side + + height = height // 2 + width = width // 2 + + h_bar = max(factor, int(round(height * scale / factor)) * factor) + w_bar = max(factor, int(round(width * scale / factor)) * factor) + + if max(h_bar, w_bar) > longest_edge: + beta = max(h_bar, w_bar) / longest_edge + h_bar = max(factor, int(math.floor((h_bar / beta) / factor)) * factor) + w_bar = max(factor, int(math.floor((w_bar / beta) / factor)) * factor) + + return h_bar, w_bar + + +class GlmImageImageProcessor(BaseImageProcessor): + r""" + Constructs a Qwen2-VL image processor that dynamically resizes images based on the original images. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. + size (`dict[str, int]`, *optional*, defaults to `{"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}`): + Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. + image_std (`float` or `list[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + min_pixels (`int`, *optional*, defaults to `56 * 56`): + The min pixels of the image to resize the image. + max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`): + The max pixels of the image to resize the image. + patch_size (`int`, *optional*, defaults to 14): + The spatial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to 2): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to 2): + The merge size of the vision encoder to llm encoder. 
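# Illustrative sketch (not part of this patch): exercising the smart_resize helper defined
# above. The import assumes this new module is installed from this branch; the input sizes
# and the factor are arbitrary (the processor passes patch_size * merge_size as the factor).
from transformers.models.glm_image.image_processing_glm_image import smart_resize

factor = 32
h_bar, w_bar = smart_resize(height=1024, width=768, factor=factor)

# Both returned sides are snapped to multiples of `factor` and are at least `factor`,
# with the longer side capped by sqrt(max_pixels).
assert h_bar % factor == 0 and w_bar % factor == 0
assert h_bar >= factor and w_bar >= factor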
+ """ + + model_input_names = ["pixel_values", "image_grid_thw"] + valid_kwargs = GlmImageImageProcessorKwargs + + def __init__( + self, + do_resize: bool = True, + size: dict[str, int] | None = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: int | float = 1 / 255, + do_normalize: bool = True, + image_mean: float | list[float] | None = None, + image_std: float | list[float] | None = None, + do_convert_rgb: bool = True, + min_pixels: int | None = None, + max_pixels: int | None = None, + patch_size: int = 14, + temporal_patch_size: int = 2, + merge_size: int = 2, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if size is not None: + if "shortest_edge" not in size or "longest_edge" not in size: + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + else: + size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280} + # backward compatibility: override size with min_pixels and max_pixels if they are provided + if min_pixels is not None: + size["shortest_edge"] = min_pixels + if max_pixels is not None: + size["longest_edge"] = max_pixels + self.min_pixels = size["shortest_edge"] + self.max_pixels = size["longest_edge"] + self.size = size + + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.merge_size = merge_size + self.do_convert_rgb = do_convert_rgb + + def _preprocess( + self, + images: ImageInput | VideoInput, + do_resize: bool | None = None, + size: dict[str, int] | None = None, + resample: PILImageResampling | None = None, + do_rescale: bool | None = None, + rescale_factor: float | None = None, + do_normalize: bool | None = None, + image_mean: float | list[float] | None = None, + image_std: float | list[float] | None = None, + patch_size: int | None = None, + temporal_patch_size: int | None = None, + merge_size: int | None = None, + do_convert_rgb: bool | None = None, + data_format: ChannelDimension | None = ChannelDimension.FIRST, + input_data_format: str | ChannelDimension | None = None, + ): + """ + Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`. + vision_info (`list[Dict]`, *optional*): + Optional list of dictionaries containing additional information about vision inputs. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Scale factor to use if rescaling the image. 
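# Illustrative sketch (not part of this patch): the constructor above lets min_pixels /
# max_pixels override the `size` dict for backward compatibility. Assumes
# GlmImageImageProcessor from this file is importable on this branch.
from transformers.models.glm_image.image_processing_glm_image import GlmImageImageProcessor

processor = GlmImageImageProcessor(min_pixels=512 * 512, max_pixels=2048 * 2048)
assert processor.size == {"shortest_edge": 512 * 512, "longest_edge": 2048 * 2048}
assert processor.min_pixels == 512 * 512 and processor.max_pixels == 2048 * 2048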
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + patch_size (`int`, *optional*, defaults to `self.patch_size`): + The spatial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to `self.merge_size`): + The merge size of the vision encoder to llm encoder. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + images = make_flat_list_of_images(images) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + height, width = get_image_size(images[0], channel_dim=input_data_format) + resized_height, resized_width = height, width + processed_images = [] + for image in images: + if do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=patch_size * merge_size, + min_pixels=size["shortest_edge"], + max_pixels=size["longest_edge"], + ) + image = resize( + image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format + ) + + if do_rescale: + image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + processed_images.append(image) + + patches = np.array(processed_images) + if data_format == ChannelDimension.LAST: + patches = patches.transpose(0, 3, 1, 2) + if patches.shape[0] % temporal_patch_size != 0: + repeats = np.repeat( + patches[-1][np.newaxis], temporal_patch_size - (patches.shape[0] % temporal_patch_size), axis=0 + ) + patches = np.concatenate([patches, repeats], axis=0) + channel = patches.shape[1] + grid_t = patches.shape[0] // temporal_patch_size + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + patches = patches.reshape( + grid_t, + temporal_patch_size, + channel, + grid_h // merge_size, + merge_size, + patch_size, + grid_w // merge_size, + merge_size, + patch_size, + ) + patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) + flatten_patches = patches.reshape( + grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size + ) + + return flatten_patches, (grid_t, grid_h, grid_w) + + def preprocess( + self, + images: ImageInput, + do_resize: bool | None = None, + size: dict[str, int] | None = None, + min_pixels: int | None = None, + max_pixels: int | None = None, + resample: PILImageResampling | None = None, + do_rescale: bool | None = None, + rescale_factor: float | None = None, + do_normalize: bool | None = None, + image_mean: float | list[float] | None = None, + image_std: float | list[float] | None = None, + patch_size: int | None = None, + temporal_patch_size: int | None = None, + merge_size: int | None = None, + do_convert_rgb: bool | None = None, + return_tensors: str | TensorType | None = None, + data_format: ChannelDimension | None = ChannelDimension.FIRST, + input_data_format: str | ChannelDimension | None = None, + ): + """ + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. 
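# Illustrative sketch (not part of this patch): the patch flattening performed at the end of
# _preprocess above, reproduced on a dummy array so the output layout is visible. All sizes
# are invented; only the reshape/transpose order matches the code.
import numpy as np

temporal_patch_size, patch_size, merge_size, channel = 2, 14, 2, 3
grid_t, grid_h, grid_w = 1, 4, 4                        # 4x4 patch grid, one temporal group
resized_h, resized_w = grid_h * patch_size, grid_w * patch_size

patches = np.zeros((grid_t * temporal_patch_size, channel, resized_h, resized_w), dtype=np.float32)
patches = patches.reshape(
    grid_t, temporal_patch_size, channel,
    grid_h // merge_size, merge_size, patch_size,
    grid_w // merge_size, merge_size, patch_size,
)
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
flatten_patches = patches.reshape(
    grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
)
# One row per spatial patch, with each merge_size x merge_size neighbourhood kept adjacent.
assert flatten_patches.shape == (16, 3 * 2 * 14 * 14)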
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + min_pixels (`int`, *optional*, defaults to `self.min_pixels`): + The min pixels of the image to resize the image. + max_pixels (`int`, *optional*, defaults to `self.max_pixels`): + The max pixels of the image to resize the image. + patch_size (`int`, *optional*, defaults to `self.patch_size`): + The spatial patch size of the vision encoder. + temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`): + The temporal patch size of the vision encoder. + merge_size (`int`, *optional*, defaults to `self.merge_size`): + The merge size of the vision encoder to llm encoder. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 
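+
+        Example (an illustrative sketch, not tied to a released checkpoint; it assumes the usual top-level
+            `transformers` export and the processor's default configuration, and runs on a synthetic image):
+
+        ```python
+        import numpy as np
+
+        from transformers import GlmImageImageProcessor
+
+        processor = GlmImageImageProcessor()
+        image = np.random.randint(0, 256, (448, 448, 3), dtype=np.uint8)
+
+        # Returns the flattened vision patches and their (t, h, w) grid described above.
+        inputs = processor(images=image, return_tensors="pt")
+        pixel_values, image_grid_thw = inputs["pixel_values"], inputs["image_grid_thw"]
+        ```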
+ + """ + min_pixels = min_pixels if min_pixels is not None else self.min_pixels + max_pixels = max_pixels if max_pixels is not None else self.max_pixels + + if size is not None: + if "shortest_edge" not in size or "longest_edge" not in size: + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + min_pixels = size["shortest_edge"] + elif min_pixels is not None and max_pixels is not None: + # backward compatibility: override size with min_pixels and max_pixels if they are provided + size = {"shortest_edge": min_pixels, "longest_edge": max_pixels} + else: + size = {**self.size} + + do_resize = do_resize if do_resize is not None else self.do_resize + + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + patch_size = patch_size if patch_size is not None else self.patch_size + temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size + merge_size = merge_size if merge_size is not None else self.merge_size + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + if images is not None: + images = self.fetch_images(images) + images = make_flat_list_of_images(images) + + if images is not None and not valid_images(images): + raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor") + + validate_preprocess_arguments( + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + data = {} + pixel_values, vision_grid_thws = [], [] + for image in images: + patches, image_grid_thw = self._preprocess( + image, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + merge_size=merge_size, + data_format=data_format, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + ) + pixel_values.extend(patches) + vision_grid_thws.append(image_grid_thw) + pixel_values = np.array(pixel_values) + vision_grid_thws = np.array(vision_grid_thws) + data.update({"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}) + + return BatchFeature(data=data, tensor_type=return_tensors) + + def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None): + """ + A utility that returns number of image patches for a given image size. + + Args: + height (`int`): + Height of the input image. + width (`int`): + Width of the input image. + images_kwargs (`dict`, *optional*) + Any kwargs to override defaults of the image processor. + Returns: + `int`: Number of image patches per image. 
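+
+        Example (illustrative; the input size is arbitrary and `images_kwargs` is left empty so the defaults
+            described above are used):
+
+        ```python
+        from transformers import GlmImageImageProcessor
+
+        processor = GlmImageImageProcessor()
+        # Equals (resized_height // patch_size) * (resized_width // patch_size) after `smart_resize`.
+        num_patches = processor.get_number_of_image_patches(height=1024, width=768, images_kwargs={})
+        ```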
+ """ + min_pixels = images_kwargs["min_pixels"] if "min_pixels" in images_kwargs else self.size["shortest_edge"] + max_pixels = images_kwargs["max_pixels"] if "max_pixels" in images_kwargs else self.size["longest_edge"] + patch_size = images_kwargs.get("patch_size", self.patch_size) + merge_size = images_kwargs.get("merge_size", self.merge_size) + + factor = patch_size * merge_size + resized_height, resized_width = smart_resize( + height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels + ) + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + return grid_h * grid_w + + +__all__ = ["GlmImageImageProcessor"] diff --git a/src/transformers/models/glm_image/image_processing_glm_image_fast.py b/src/transformers/models/glm_image/image_processing_glm_image_fast.py new file mode 100644 index 000000000000..7ca16ce810b7 --- /dev/null +++ b/src/transformers/models/glm_image/image_processing_glm_image_fast.py @@ -0,0 +1,296 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/glm_image/modular_glm_image.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_glm_image.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import Optional, Union + +import torch.nn.functional as F + +from ...feature_extraction_utils import BatchFeature +from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, +) +from ...processing_utils import Unpack +from ...utils import TensorType, auto_docstring, is_torch_available +from .image_processing_glm_image import GlmImageImageProcessorKwargs + + +if is_torch_available(): + import torch + + +def smart_resize( + height: int, + width: int, + factor: int = 16, + min_pixels: int = 512 * 512, + max_pixels: int = 2048 * 2048, +) -> tuple[int, int]: + if height < factor or width < factor: + raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}") + elif max(height, width) / min(height, width) > 4: + raise ValueError( + f"absolute aspect ratio must be smaller than 4, got {max(height, width) / min(height, width)}" + ) + + shortest_edge = int(round(math.sqrt(min_pixels))) + longest_edge = int(round(math.sqrt(max_pixels))) + min_side = min(height, width) + max_side = max(height, width) + + scale = 1.0 + + if min_side < shortest_edge: + scale = shortest_edge / min_side + + if max_side * scale > longest_edge: + scale = longest_edge / max_side + + height = height // 2 + width = width // 2 + + h_bar = max(factor, int(round(height * scale / factor)) * factor) + w_bar = max(factor, int(round(width * scale / factor)) * factor) + + if max(h_bar, w_bar) > longest_edge: + beta = max(h_bar, w_bar) / longest_edge + h_bar = max(factor, int(math.floor((h_bar / beta) / factor)) * factor) + w_bar = max(factor, int(math.floor((w_bar / beta) / factor)) * factor) + + return h_bar, w_bar + + +@auto_docstring +class GlmImageImageProcessorFast(BaseImageProcessorFast): + do_resize = True + resample = PILImageResampling.BICUBIC + size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280} + do_rescale = True + do_normalize = True + image_mean = OPENAI_CLIP_MEAN + image_std = OPENAI_CLIP_STD + do_convert_rgb = True + patch_size = 14 + temporal_patch_size = 2 + merge_size = 2 + min_pixels = None + max_pixels = None + valid_kwargs = GlmImageImageProcessorKwargs + model_input_names = ["pixel_values", "image_grid_thw"] + + def __init__(self, **kwargs: Unpack[GlmImageImageProcessorKwargs]): + size = kwargs.pop("size", None) + min_pixels = kwargs.pop("min_pixels", None) + max_pixels = kwargs.pop("max_pixels", None) + # backward compatibility: override size with min_pixels and max_pixels if they are provided + size = self.size if size is None else size + if min_pixels is not None: + size["shortest_edge"] = min_pixels + size.pop("min_pixels", None) + if max_pixels is not None: + size["longest_edge"] = max_pixels + size.pop("max_pixels", None) + if "shortest_edge" not in size or "longest_edge" not in size: + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + + super().__init__(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs) + + def _further_process_kwargs( + self, + size: SizeDict | None = None, + min_pixels: int | None = None, + max_pixels: int | None = None, + **kwargs, + ) -> dict: + """ + Update kwargs that need further processing before being validated + Can be overridden by subclasses to customize the processing of kwargs. 
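+
+        Example (a minimal sketch; the pixel budgets below are illustrative, and the same mapping is applied
+            both at construction time and per `preprocess` call):
+
+        ```python
+        from transformers import GlmImageImageProcessorFast
+
+        # The legacy `min_pixels`/`max_pixels` kwargs are folded into `size` under the
+        # `shortest_edge`/`longest_edge` keys consumed by `smart_resize`.
+        proc_a = GlmImageImageProcessorFast(min_pixels=512 * 512, max_pixels=2048 * 2048)
+        proc_b = GlmImageImageProcessorFast(size={"shortest_edge": 512 * 512, "longest_edge": 2048 * 2048})
+        ```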
+ """ + if min_pixels is not None and max_pixels is not None: + size = {"shortest_edge": min_pixels, "longest_edge": max_pixels} + elif size is not None: + if "shortest_edge" not in size or "longest_edge" not in size: + raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") + min_pixels = size["shortest_edge"] + max_pixels = size["longest_edge"] + else: + size = {**self.size} + + return super()._further_process_kwargs(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs) + + @auto_docstring + def preprocess( + self, + images: ImageInput, + **kwargs: Unpack[GlmImageImageProcessorKwargs], + ) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + do_convert_rgb: bool, + input_data_format: ChannelDimension, + device: Union[str, "torch.device"] | None = None, + **kwargs: Unpack[GlmImageImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. + To be overridden by subclasses when image-like inputs other than images should be processed. + It can be used for segmentation maps, depth maps, etc. + """ + # Prepare input images + batch_feature = BatchFeature() + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + batch_feature = self._preprocess(images, **kwargs) + return batch_feature + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: float | list[float] | None, + image_std: float | list[float] | None, + patch_size: int, + temporal_patch_size: int, + merge_size: int, + disable_grouping: bool | None, + return_tensors: str | TensorType | None, + **kwargs, + ): + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + height, width = stacked_images.shape[-2:] + if do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=patch_size * merge_size, + min_pixels=size["shortest_edge"], + max_pixels=size["longest_edge"], + ) + stacked_images = self.resize( + image=stacked_images, + size=SizeDict(height=resized_height, width=resized_width), + interpolation=interpolation, + ) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + processed_grids = {} + for shape, stacked_images in grouped_images.items(): + resized_height, resized_width = stacked_images.shape[-2:] + # Fused rescale and normalize + patches = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + if patches.ndim == 4: + # add a temporal dimension if we have images + patches = patches.unsqueeze(1) + if patches.shape[1] % temporal_patch_size != 0: + repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1) + patches = torch.cat([patches, repeats], dim=1) + batch_size, grid_t, channel = patches.shape[:3] + grid_t = 
grid_t // temporal_patch_size + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + + patches = patches.view( + batch_size, + grid_t, + temporal_patch_size, + channel, + grid_h // merge_size, + merge_size, + patch_size, + grid_w // merge_size, + merge_size, + patch_size, + ) + # Reorder dimensions to group grid and patch information for subsequent flattening. + # (batch, grid_t, grid_h, grid_w, merge_h, merge_w, channel, temp_patch_size, patch_h, patch_w) + patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9) + flatten_patches = patches.reshape( + batch_size, + grid_t * grid_h * grid_w, + channel * temporal_patch_size * patch_size * patch_size, + ) + + processed_images_grouped[shape] = flatten_patches + processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_grids = reorder_images(processed_grids, grouped_images_index) + pixel_values = torch.cat(processed_images, dim=0) + image_grid_thw = torch.tensor(processed_grids) + + return BatchFeature( + data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors + ) + + def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None): + """ + A utility that returns number of image patches for a given image size. + + Note: Do not remove this method! It is used by vLLM to infer the number of patches and placeholders + without an image input. + + Args: + height (`int`): + Height of the input image. + width (`int`): + Width of the input image. + images_kwargs (`dict`, *optional*) + Any kwargs to override defaults of the image processor. + Returns: + `int`: Number of image patches per image. + """ + min_pixels = images_kwargs["min_pixels"] if "min_pixels" in images_kwargs else self.size["shortest_edge"] + max_pixels = images_kwargs["max_pixels"] if "max_pixels" in images_kwargs else self.size["longest_edge"] + patch_size = images_kwargs.get("patch_size", self.patch_size) + merge_size = images_kwargs.get("merge_size", self.merge_size) + + factor = patch_size * merge_size + resized_height, resized_width = smart_resize( + height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels + ) + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + return grid_h * grid_w + + +__all__ = ["GlmImageImageProcessorFast"] diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py new file mode 100644 index 000000000000..d3227c9f6b84 --- /dev/null +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -0,0 +1,1592 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/glm_image/modular_glm_image.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_glm_image.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Optional + +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub, use_kernelized_func +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available +from ...utils.generic import check_model_inputs, maybe_autocast +from .configuration_glm_image import GlmImageConfig, GlmImageTextConfig, GlmImageVisionConfig, GlmImageVQVAEConfig + + +if is_torch_available(): + import torch + + +class GlmImageVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class GlmImageVisionAttention(nn.Module): + def __init__(self, config: GlmImageVisionConfig) -> None: + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias) + self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = config.attention_dropout + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states, key_states, value_states = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + if "flash" in self.config._attn_implementation: + # Flash Attention: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training 
else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output + + +class GlmImageVisionPatchEmbed(nn.Module): + def __init__(self, config: GlmImageVisionConfig) -> None: + super().__init__() + self.patch_size = config.patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + kernel_size = [self.patch_size, self.patch_size] + self.proj = nn.Conv2d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size) + + def forward(self, hidden_states) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view(-1, self.in_channels, self.patch_size, self.patch_size) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class GlmImageVisionEmbeddings(nn.Module): + def __init__(self, config: GlmImageVisionConfig) -> None: + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.interpolated_method = "bilinear" + + def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor: + """ + Forward pass with integrated position encoding adaptation using 2D interpolation. + + Args: + embeddings: Input embeddings tensor + lengths (torch.Tensor): Sequence lengths for each image in the batch. + image_shapes (torch.Tensor): Tensor of shape [batch_size, 3] representing the image shapes (t, h, w). + h_coords (torch.Tensor): Tensor of shape [total_seq] representing the h coordinate for each patch. + w_coords (torch.Tensor): Tensor of shape [total_seq] representing the w coordinate for each patch. + + Returns: + torch.Tensor: Embeddings with adapted position encoding added. 
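+
+        Example (a self-contained sketch of the resampling trick used here, with illustrative shapes: the
+            square learned table is sampled at normalized patch-centre coordinates via `grid_sample`):
+
+        ```python
+        import torch
+        import torch.nn.functional as F
+
+        orig_size, hidden = 24, 32  # learned table covers an orig_size x orig_size patch grid
+        pos_table = torch.randn(orig_size * orig_size, hidden)
+        pos_2d = pos_table.view(orig_size, orig_size, hidden).permute(2, 0, 1).unsqueeze(0)  # (1, C, H, W)
+
+        h, w = 10, 16  # patch grid of the incoming image
+        h_coords = torch.arange(h).repeat_interleave(w).float()
+        w_coords = torch.arange(w).repeat(h).float()
+
+        # Normalize patch-centre coordinates to [-1, 1], then sample the table at those positions.
+        norm_w = ((w_coords + 0.5) / w) * 2 - 1
+        norm_h = ((h_coords + 0.5) / h) * 2 - 1
+        grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)  # (1, h*w, 1, 2)
+        sampled = F.grid_sample(pos_2d, grid, mode="bilinear", align_corners=False, padding_mode="border")
+        adapted_pos_embed = sampled.squeeze(0).squeeze(-1).permute(1, 0)  # (h*w, hidden)
+        ```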
+ """ + # Get position embedding parameters + pos_embed_weight = self.position_embedding.weight + hidden_size = pos_embed_weight.shape[1] + device = pos_embed_weight.device + + # Convert inputs to tensors if needed + if isinstance(lengths, list): + lengths = torch.tensor(lengths, device=device, dtype=torch.long) + + # Prepare 2D position embedding + orig_size_sq = pos_embed_weight.shape[0] + orig_size = int(orig_size_sq**0.5) + pos_embed_2d = ( + pos_embed_weight.view(orig_size, orig_size, hidden_size) + .permute(2, 0, 1) + .unsqueeze(0) + .to(device=device, dtype=torch.float32) + ) + + # Calculate target dimensions for each patch + target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) + target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( + device=device, dtype=torch.float32 + ) + + # Normalize coordinates to [-1, 1] range for grid_sample + norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 + norm_h = ((h_coords + 0.5) / target_h) * 2 - 1 + + # Create sampling grid + grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2) + + # Perform bicubic interpolation + interpolated_embed_fp32 = F.grid_sample( + pos_embed_2d, grid, mode=self.interpolated_method, align_corners=False, padding_mode="border" + ) + + # Reshape and convert back to original dtype + adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0) + adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device) + + # Add adapted position encoding to embeddings + embeddings = embeddings + adapted_pos_embed + return embeddings + + +class GlmImageVisionBlock(GradientCheckpointingLayer): + def __init__(self, config: GlmImageVisionConfig) -> None: + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attn = GlmImageVisionAttention(config) + self.mlp = GlmImageVisionMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + r""" + cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`): + The cumulative sequence lengths of each image or video feature. + position_embeddings (`tuple(torch.Tensor, torch.Tensor)` of shape `(num_patches, head_dim // 2)`): + The cosine and sine position embeddings for vision attention. + """ + residual = hidden_states + + hidden_states = self.norm1(hidden_states) + hidden_states = self.attn( + hidden_states, + cu_seqlens=cu_seqlens, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +@use_kernelized_func(apply_rotary_pos_emb) +class GlmImageTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GlmImageTextConfig, layer_idx: int | None = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + self.rope_parameters = config.rope_parameters + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + 
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +@use_kernel_forward_from_hub("RMSNorm") +class GlmImageRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + GlmImageRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class GlmImageTextMLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +class GlmImageTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: GlmImageTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = GlmImageTextAttention(config, layer_idx) + self.mlp = GlmImageTextMLP(config) + self.input_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_self_attn_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_mlp_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @auto_docstring + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + **kwargs, + ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + 
hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_mlp_layernorm(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +@auto_docstring +class GlmImagePreTrainedModel(PreTrainedModel): + config: GlmImageConfig + base_model_prefix = "model" + input_modalities = ("image", "text") + supports_gradient_checkpointing = True + _no_split_modules = ["GlmImageTextDecoderLayer", "GlmImageVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": GlmImageTextDecoderLayer, + "attentions": GlmImageTextAttention, + } + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Llava outputs, with hidden states and attentions. + """ +) +class GlmImageModelOutputWithPast(ModelOutput): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + last_hidden_state: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None + + +class GlmImageVQVAEVectorQuantizer(nn.Module): + """ + A module for vector quantization using learned embedding vectors. + + This module implements the quantization process similar to te one described in + the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous + input vectors into discrete codebook vectors, which are learned during training. + Current implementation improves over previous ones by avoiding costly matrix multiplications + and allowing for post-hoc remapping of indices. 
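+
+    Example (a standalone sketch of the nearest-codebook lookup performed in `forward`; the codebook size,
+        embedding dimension and number of latents are illustrative):
+
+    ```python
+    import torch
+    import torch.nn.functional as F
+
+    codebook = F.normalize(torch.randn(16384, 64), p=2, dim=-1)  # (num_embeddings, embed_dim)
+    latents = F.normalize(torch.randn(4096, 64), p=2, dim=-1)    # flattened (h * w, embed_dim) features
+
+    # (z - e)^2 = z^2 + e^2 - 2 * z @ e^T, minimized over the codebook entries.
+    distances = (
+        latents.pow(2).sum(dim=1, keepdim=True)
+        + codebook.pow(2).sum(dim=1)
+        - 2 * latents @ codebook.t()
+    )
+    token_indices = distances.argmin(dim=1)  # discrete image token ids
+    quantized = codebook[token_indices]      # nearest codebook vectors
+    ```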
+ """ + + def __init__(self, config: GlmImageVQVAEConfig): + super().__init__() + self.num_embeddings = config.num_embeddings + self.embedding_dim = config.embed_dim + self.beta = getattr(config, "beta", 0.25) + + self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) + + def forward(self, hidden_state: torch.Tensor): + hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() + hidden_state_flattened = hidden_state.view(-1, self.embedding_dim) + + # L2 normalize + hidden_state = F.normalize(hidden_state, p=2, dim=-1) + hidden_state_flattened = F.normalize(hidden_state_flattened, p=2, dim=-1) + embedding = F.normalize(self.embedding.weight, p=2, dim=-1) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + distances = ( + torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + + torch.sum(embedding**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, embedding.transpose(0, 1)) + ) + + min_encoding_indices = torch.argmin(distances, dim=1) + hidden_state_quant = embedding[min_encoding_indices].view(hidden_state.shape) + + # compute loss for embedding + loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean( + (hidden_state_quant - hidden_state.detach()) ** 2 + ) + + # preserve gradients + hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach() + + # reshape back to match original input shape + hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous() + + return hidden_state_quant, loss, min_encoding_indices + + +@auto_docstring( + custom_intro=""" + The VQ-VAE model used in GlmImage for encoding/decoding images into discrete tokens. + This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from + [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv + Taigman](https://huggingface.co/papers/2203.13131). 
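+
+    Example (a rough usage sketch; it assumes the default `GlmImageVQVAEConfig` values and imports both
+        classes from their module paths, since the VQ-VAE is normally driven by `GlmImageModel` rather than
+        used standalone):
+
+    ```python
+    import torch
+
+    from transformers.models.glm_image.configuration_glm_image import GlmImageVQVAEConfig
+    from transformers.models.glm_image.modeling_glm_image import GlmImageVQVAE
+
+    config = GlmImageVQVAEConfig()
+    vqvae = GlmImageVQVAE(config)
+
+    # Encode latent feature maps into quantized features, a commitment loss and discrete token indices.
+    latents = torch.randn(1, config.latent_channels, 32, 32)
+    quantized, embedding_loss, token_indices = vqvae.encode(latents)
+    ```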
+ """ +) +class GlmImageVQVAE(GlmImagePreTrainedModel): + config: GlmImageVQVAEConfig + _no_split_modules = [ + "GlmImageVQVAEVectorQuantizer", + ] + + def __init__(self, config: GlmImageVQVAEConfig): + super().__init__(config) + self.quantize = GlmImageVQVAEVectorQuantizer(config) + self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1) + self.eval() # GlmImage's VQ model is frozen + self.post_init() + + def encode(self, hidden_states): + hidden_states = self.quant_conv(hidden_states) + quant, emb_loss, indices = self.quantize(hidden_states) + return quant, emb_loss, indices + + +class GlmImageVisionModel(GlmImagePreTrainedModel): + config: GlmImageVisionConfig + input_modalities = ("image",) + _no_split_modules = ["GlmImageVisionBlock"] + main_input_name = "pixel_values" + + def __init__(self, config: GlmImageVisionConfig) -> None: + super().__init__(config) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + + self.embeddings = GlmImageVisionEmbeddings(config) + self.patch_embed = GlmImageVisionPatchEmbed(config) + + head_dim = config.hidden_size // config.num_heads + + self.blocks = nn.ModuleList([GlmImageVisionBlock(config) for _ in range(config.depth)]) + + self.gradient_checkpointing = False + self.head_dim = head_dim + self.post_init() + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + return pos_ids + + def forward(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + pixel_values (`torch.Tensor` of shape `(total_patches, num_channels * patch_size * patch_size)`): + Packed pixel values. + grid_thw (`torch.Tensor` of shape `(num_images, 3)`): + The temporal, height and width of feature shape of each image. + + Returns: + `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states. 
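+
+        Example (a small sketch of how the packed sequence boundaries are derived from `grid_thw`; the grid
+            sizes are illustrative):
+
+        ```python
+        import torch
+        import torch.nn.functional as F
+
+        # Two images with patch grids (t=1, h=16, w=16) and (t=1, h=12, w=20).
+        grid_thw = torch.tensor([[1, 16, 16], [1, 12, 20]])
+
+        # h * w patches per temporal slice, cumulated and padded with a leading zero: these are the
+        # per-image boundaries consumed by the variable-length attention blocks.
+        seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
+        cu_seqlens = F.pad(seq_lens.cumsum(0, dtype=torch.int32), (1, 0), value=0)
+        # cu_seqlens -> tensor([  0, 256, 496], dtype=torch.int32)
+        ```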
+ """ + + hidden_states = self.patch_embed(pixel_values) + image_type_ids = self.rot_pos_emb(grid_thw) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + hidden_states = self.embeddings( + hidden_states, + seqlens, + grid_thw, + image_type_ids[:, 0].to(hidden_states.device), + image_type_ids[:, 1].to(hidden_states.device), + ) + + # Transformer blocks (no position_embeddings needed, already added above) + for blk in self.blocks: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + ) + return hidden_states + + +class GlmImageTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: GlmImageTextConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + self.mrope_section = config.rope_parameters.get("mrope_section", [8, 12, 12]) + + @staticmethod + def compute_default_rope_parameters( + config: GlmImageTextConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, GLM-V has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) 
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def apply_mrope(self, freqs, mrope_section): + section = mrope_section + chunks = freqs.split(section, dim=-1) + result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) + return result + + +@auto_docstring +class GlmImageTextModel(GlmImagePreTrainedModel): + config: GlmImageTextConfig + input_modalities = ("text",) + + def __init__(self, config: GlmImageTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [GlmImageTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = GlmImageTextRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + @check_model_inputs + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple | BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions + # where each dim indicates visual spatial positions for temporal/height/width grids. + # There are two scenarios when FA2-like packed masking might be activated. + # 1. User specifically passed packed `position_ids` and no attention mask. + # In this case we expect the useer to create correct position ids for all 3 grids + # and prepend text-only position ids to it. 
The final tensor will be [4, bs, seq-len] + # 2. User runs forward with no attention mask and no position ids. In this case, position ids + # are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are + # prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass + # text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation` + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + # If inputs are not packed (usual 3D positions), do not prepare mask from position_ids + text_position_ids = None + + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": text_position_ids, + } + # Create the masks + causal_mask = create_causal_mask(**mask_kwargs) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for decoder_layer in self.layers: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = layer_outputs + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class GlmImageModel(GlmImagePreTrainedModel): + base_model_prefix = "model" + _checkpoint_conversion_mapping = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: GlmImageConfig + _no_split_modules = ["GlmImageTextDecoderLayer", "GlmImageVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = GlmImageVisionModel._from_config(config.vision_config) + self.language_model = GlmImageTextModel._from_config(config.text_config) + + self.rope_deltas = None # cache rope_deltas here + self.vqmodel = GlmImageVQVAE._from_config(config.vq_config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_rope_index( + self, + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + attention_mask: torch.LongTensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index for image generation task. + + Explanation: + Each embedding sequence may contain image tokens (for generation) and text tokens, + or just text tokens. + + Input format: + - Text-to-Image: [text tokens] + <|dit_token_16384|> + - Image-to-Image: <|dit_token_16384|> [image tokens] <|dit_token_16385|> + [text tokens] + <|dit_token_16384|> + + For pure text embedding sequence, the rotary position embedding is the same across all 3 dimensions. + Examples: + input_ids: [T T T T T], here T is for text. 
+ temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For sequences with image tokens, we use special markers to denote image regions: + - <|dit_token_16384|>: image start marker + - <|dit_token_16385|>: image end marker + - Image tokens between these markers use 2D spatial position encoding. + + For image tokens: + - temporal: stays constant at (image_start_pos + 1) + - height: increments every w tokens, representing row position + - width: cycles from 0 to w-1, representing column position + + After each image region, the next position jumps to: image_start_pos + 1 + max(h, w) + This ensures sufficient positional separation between images and subsequent tokens. + + Examples: + === Case 1: Image-to-Image Generation === + + Source image with grid [1, 3, 2], followed by text, then generation. + input_ids: [<|dit_token_16384|> V V V V V V <|dit_token_16385|> T T T T <|dit_token_16384|>] + image_grid_thw: [[1, 3, 2], [1, 4, 4]] # first is source, second is target + + For source image (h=3, w=2, 6 tokens): + Start marker at position 0 + Image tokens at temporal=1, height=[1,1,2,2,3,3], width=[1,2,1,2,1,2] + End marker at position 4 (= 0 + 1 + max(3,2)) + + Text tokens and trailing start marker continue from position 5. + + Full prefill position_ids: + temporal: [0, 1,1,1,1,1,1, 4, 5,6,7,8, 9] + height: [0, 1,1,2,2,3,3, 4, 5,6,7,8, 9] + width: [0, 1,2,1,2,1,2, 4, 5,6,7,8, 9] + + Decode stage: use image_grid_thw[-1] = [1, 4, 4] to build cached position_ids, + starting from gen_st_idx = 10. + + === Case 2: Text-to-Image Generation (multi-resolution) === + + Pure text input with two image_grids for progressive generation. + input_ids: [hello3 33 2<|dit_token_16384|>] + Assume "hello3 33 2" = 4 tokens (positions 0-3) + <|dit_token_16384|> at position 4 + image_grid_thw: [[1, 3, 3], [1, 3, 2]] + - image_grid_thw[-1] = [1, 3, 2]: first generated image (smaller/draft) + - image_grid_thw[-2] = [1, 3, 3]: second generated image (larger/final) + + Prefill position_ids (5 tokens: 4 text + 1 start marker): + temporal: [0, 1, 2, 3, 4] + height: [0, 1, 2, 3, 4] + width: [0, 1, 2, 3, 4] + + Decode stage builds position_ids in reverse order of image_grid_thw: + + First: image_grid_thw[-1] = [1, 3, 2] (6 tokens), starting at position 5: + temporal: [5, 5, 5, 5, 5, 5] + height: [5, 5, 6, 6, 7, 7] + width: [5, 6, 5, 6, 5, 6] + next_pos = 5 + max(3, 2) = 8 + + Then: image_grid_thw[-2] = [1, 3, 3] (9 tokens), starting at position 8: + temporal: [8, 8, 8, 8, 8, 8, 8, 8, 8] + height: [8, 8, 8, 9, 9, 9, 10, 10, 10] + width: [8, 9, 10, 8, 9, 10, 8, 9, 10] + next_pos = 8 + max(3, 3) = 11 + + Finally: <|dit_token_16385|> end marker at position 11 + + Full sequence position_ids (prefill + decode): + temporal: [0,1,2,3, 4, 5,5,5,5,5,5, 8,8,8,8,8,8,8,8,8, 11] + height: [0,1,2,3, 4, 5,5,6,6,7,7, 8,8,8,9,9,9,10,10,10, 11] + width: [0,1,2,3, 4, 5,6,5,6,5,6, 8,9,10,8,9,10,8,9,10, 11] + + _cached_decode_position_ids shape: [3, 6 + 9 + 1] = [3, 16] + (includes all generated image tokens + end marker) + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default + should you provide it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image. For image generation, + temporal is typically 1. 
+ - For image-to-image: includes source image grids + target image grid(s) + - For text-to-image with multi-resolution: includes multiple target grids, + processed in reverse order (last grid first, second-to-last grid second, etc.) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`): + Position IDs for temporal, height, and width dimensions. + mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`): + Position deltas for multi-modal rotary position embedding (zeros for this task). + """ + + batch_size, seq_len = input_ids.shape + device = input_ids.device + dtype = input_ids.dtype + + image_start_token_id = self.config.image_start_token_id + image_end_token_id = self.config.image_end_token_id + num_complete_images = (input_ids == image_end_token_id).sum().item() + + position_ids = torch.ones( + 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device + ) + text_positions = torch.arange(seq_len)[None, :].repeat(3, 1) + for batch_idx in range(batch_size): + curr_input_ids = input_ids[batch_idx] + if attention_mask is not None: + curr_input_ids = curr_input_ids[attention_mask[batch_idx] == 1] + + image_end = torch.where(curr_input_ids == image_end_token_id)[0] + image_start = torch.where(curr_input_ids == image_start_token_id)[0] + 1 + current_pos = 0 # track the current position value + prev_image_end = 0 + curr_position_ids = [] + for start, end, grid in zip(image_start, image_end, image_grid_thw): + _, num_width_grid, num_height_grid = grid + + # Create text position ids first if there are text tokens before image + llm_pos_length = start - prev_image_end + llm_position_ids = text_positions[:, current_pos : current_pos + llm_pos_length].to( + device=input_ids.device + ) + current_pos += llm_position_ids.shape[-1] + + # Now create image position ids for each grid + image_seq_length = num_height_grid * num_width_grid + h_grids = image_seq_length // num_height_grid + current_pos + w_grids = image_seq_length // num_width_grid + current_pos + position_width = torch.arange(current_pos, w_grids, device=input_ids.device).repeat(num_width_grid) + position_height = torch.arange(current_pos, h_grids, device=input_ids.device).repeat_interleave( + num_height_grid + ) + position_temporal = torch.full( + (image_seq_length,), current_pos, device=input_ids.device, dtype=torch.long + ) + vision_position_ids = torch.stack([position_temporal, position_height, position_width], dim=0) + current_pos += max(num_height_grid, num_width_grid) + + prev_image_end = end + curr_position_ids.append(torch.cat([llm_position_ids, vision_position_ids], dim=-1)) + + # Add position ids for the last text tokens if any + end_position = len(curr_input_ids) - prev_image_end + llm_position_ids = text_positions[:, current_pos : current_pos + end_position].to(device=input_ids.device) + current_pos += llm_position_ids.shape[-1] + curr_position_ids.append(llm_position_ids) + curr_position_ids = torch.cat(curr_position_ids, dim=-1) + if attention_mask is not None: + position_ids[:, batch_idx, attention_mask[batch_idx] == 1] = curr_position_ids.to(position_ids.device) + else: + position_ids[:, batch_idx, :] = curr_position_ids.to(position_ids.device) + + # Build and store position ids for 
tokens that will be generated. Later we will just + # slice these instead of computing each decoding step + self._prefill_len = seq_len + if image_grid_thw is not None and len(image_grid_thw) > 0: + num_decode_grids = len(image_grid_thw) - num_complete_images + num_decode_grids = max(num_decode_grids, 0) + decode_pos = current_pos + + decode_temporal_list = [] + decode_height_list = [] + decode_width_list = [] + + for i in range(1, num_decode_grids + 1): + grid_idx = -i + h = image_grid_thw[grid_idx, 1].item() + w = image_grid_thw[grid_idx, 2].item() + total_tokens = h * w + + h_indices = torch.arange(h, device=device).unsqueeze(1).expand(h, w).flatten() + w_indices = torch.arange(w, device=device).unsqueeze(0).expand(h, w).flatten() + + decode_temporal_list.append(torch.full((total_tokens,), decode_pos, device=device, dtype=torch.long)) + decode_height_list.append(decode_pos + h_indices) + decode_width_list.append(decode_pos + w_indices) + decode_pos = decode_pos + max(h, w) + + decode_temporal_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long)) + decode_height_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long)) + decode_width_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long)) + + self._cached_decode_position_ids = torch.stack( + [ + torch.cat(decode_temporal_list, dim=0), + torch.cat(decode_height_list, dim=0), + torch.cat(decode_width_list, dim=0), + ], + dim=0, + ) + else: + self._cached_decode_position_ids = None + + mrope_position_deltas = torch.zeros([batch_size, 1], dtype=dtype, device=device) + + return position_ids, mrope_position_deltas + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + return image_embeds + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + image_ids: torch.LongTensor, + ): + """ + Replace image placeholder tokens in input_ids with actual image token ids from VQVAE. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Input token ids with image placeholders. + image_ids (`torch.LongTensor` of shape `(num_images, num_tokens_per_image)` or flattened): + Discrete token indices from the VQVAE codebook. + + Returns: + special_image_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Mask indicating positions in input ids that will be replaced by actual image tokens. 
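+
+        Example (how the returned mask is consumed in `forward`):
+
+        ```python
+        # boolean mask over placeholder positions, then scatter in the discrete VQVAE token ids
+        special_image_mask = self.get_placeholder_mask(input_ids, image_ids)
+        input_ids = input_ids.masked_scatter(special_image_mask, image_ids)
+        ```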
+ """ + + special_image_mask = input_ids == self.config.image_token_id + n_placeholder_tokens = special_image_mask.sum().item() + n_image_tokens = image_ids.shape[0] + + if n_placeholder_tokens != n_image_tokens: + raise ValueError( + f"Number of image placeholder tokens ({n_placeholder_tokens}) does not match " + f"number of image tokens from VQVAE ({n_image_tokens})" + ) + + return special_image_mask + + @auto_docstring + @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + pixel_values: torch.Tensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + rope_deltas: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | GlmImageModelOutputWithPast: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw[:-1]) + image_embeds = torch.cat(image_embeds, dim=0) + image_ids = self.get_image_tokens(image_embeds, image_grid_thw[:-1]) + image_ids = image_ids.view(-1).to(input_ids.device) + special_image_mask = self.get_placeholder_mask(input_ids, image_ids) + input_ids = input_ids.masked_scatter(special_image_mask, image_ids) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if position_ids is None: + attention_mask_2d = attention_mask + if attention_mask is not None and attention_mask.ndim == 4: + attention_mask_2d = torch.diagonal(attention_mask[:, 0], dim1=1, dim2=2) + # Only apply conversion for floating point tensors (inverted masks) + if attention_mask_2d.dtype.is_floating_point: + attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min + attention_mask_2d = (1.0 - attention_mask_2d).int() + + # Calculate RoPE index once per generation in the pre-fill stage only. + # It is safe to assume that `length!=1` means we're in pre-fill because the + # model is used only by DiT pipeline without assisted decoding, etc. 
techniques + is_prefill_stage = (input_ids is not None and input_ids.shape[1] != 1) or ( + inputs_embeds is not None and inputs_embeds.shape[1] != 1 + ) + if is_prefill_stage or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + attention_mask=attention_mask_2d, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + # Use prefill token length, not position value + step = cache_position[0].item() - self._prefill_len + # Direct lookup - no tensor creation overhead + position_ids = self._cached_decode_position_ids[:, step : step + seq_length] + position_ids = position_ids.unsqueeze(1).expand(-1, batch_size, -1) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + return GlmImageModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def get_image_tokens( + self, + hidden_states: torch.FloatTensor, + image_grid_thw: torch.LongTensor, + ) -> torch.LongTensor: + """ + Tokenizes image features into discrete tokens with VQVAE module. + + Args: + hidden_states (`torch.FloatTensor` of shape `(total_patches, hidden_size)`): + The packed image features from vision encoder. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`): + The temporal, height and width of feature shape of each image. + + Returns: + image_tokens (`torch.LongTensor` of shape `(total_patches,)`): + Discrete token indices from the VQVAE codebook. + """ + hidden_size = hidden_states.shape[-1] + split_sizes = (image_grid_thw.prod(dim=-1)).tolist() + hidden_states_list = torch.split(hidden_states, split_sizes, dim=0) + + all_image_toks = [] + for i, hs in enumerate(hidden_states_list): + grid_t, grid_h, grid_w = image_grid_thw[i].tolist() + hs = hs.view(grid_t, grid_h, grid_w, hidden_size) + hs = hs.permute(0, 3, 1, 2).contiguous() + _, _, image_toks = self.vqmodel.encode(hs) + all_image_toks.append(image_toks) + return torch.cat(all_image_toks, dim=0) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for GlmImage causal language model (or autoregressive) outputs. + """ +) +class GlmImageCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. 
+ """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + rope_deltas: torch.LongTensor | None = None + + +class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + base_model_prefix = "model" + config: GlmImageConfig + + def __init__(self, config): + super().__init__(config) + self.model = GlmImageModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vision_vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None): + return self.model.get_image_features(pixel_values, image_grid_thw) + + def get_image_tokens(self, hidden_states: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None): + return self.model.get_image_tokens(hidden_states, image_grid_thw) + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | GlmImageCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, GlmImageForConditionalGeneration + + >>> model = GlmImageForConditionalGeneration.from_pretrained("zai-org/GLM-Image") + >>> processor = AutoProcessor.from_pretrained("zai-org/GLM-Image") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "Add a truck of this photo.28 40"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." 
+ ```""" + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return GlmImageCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + image_grid_thw=None, + is_first_iteration=False, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + is_first_iteration=is_first_iteration, + use_cache=use_cache, + **kwargs, + ) + + model_inputs["position_ids"] = None + + if not is_first_iteration and use_cache: + model_inputs["pixel_values"] = None + + return model_inputs + + def _get_image_nums( + self, + input_ids: torch.LongTensor | None, + ) -> torch.Tensor: + """ + Get the number of images for each sample. + For GLM-Image, only input_ids allow us to get the number of images. 
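+        The count is derived from the number of `image_start_token_id` occurrences in each row of `input_ids`.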
+ + Returns: + image_counts (`torch.LongTensor` of shape `(batch_size,)`) + """ + is_image = input_ids == self.config.image_start_token_id + + return is_image.sum(dim=1) + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: torch.LongTensor | None = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + image_nums = self._get_image_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw[: sum(image_nums)], list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample and +1 for the image being generated + lengths = list(image_nums) + last_image = dict_to_expand[key][:-1] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key][: sum(image_nums)], lengths=lengths, repeat_times=expand_size + ) + dict_to_expand[key] = torch.cat([dict_to_expand[key], last_image], dim=0) + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = [ + "GlmImagePreTrainedModel", + "GlmImageVQVAE", + "GlmImageVisionModel", + "GlmImageTextModel", + "GlmImageModel", + "GlmImageForConditionalGeneration", +] diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py new file mode 100644 index 000000000000..ac7d1c33b92d --- /dev/null +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -0,0 +1,1480 @@ +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Callable +from typing import Any + +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from ...cache_utils import Cache +from ...configuration_utils import PreTrainedConfig +from ...feature_extraction_utils import BatchFeature +from ...generation import GenerationMixin +from ...image_utils import ImageInput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import ImagesKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import TransformersKwargs, is_torch_available, logging +from ..chameleon.modeling_chameleon import ChameleonVQVAE, ChameleonVQVAEVectorQuantizer +from ..glm4v.configuration_glm4v import Glm4vTextConfig, Glm4vVisionConfig +from ..glm4v.modeling_glm4v import ( + Glm4vCausalLMOutputWithPast, + Glm4vModel, + Glm4vModelOutputWithPast, + Glm4vPreTrainedModel, + Glm4vTextModel, + Glm4vVisionAttention, + Glm4vVisionBlock, + Glm4vVisionEmbeddings, + Glm4vVisionModel, + Glm4vVisionPatchEmbed, +) +from ..glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextAttention, eager_attention_forward +from ..qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor +from ..qwen2_vl.image_processing_qwen2_vl_fast import Qwen2VLImageProcessorFast +from ..qwen2_vl.processing_qwen2_vl import Qwen2VLProcessorKwargs +from ..siglip.modeling_siglip import SiglipMLP + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class GlmImageVQVAEConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GlmImageVQModel`]. It is used to instantiate a + `GlmImageVQModel` according to the specified arguments, defining the model architecture. + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. Instantiating a + configuration with the defaults will yield a similar configuration to the VQModel of the + [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image) architecture. + + Args: + embed_dim (`int`, *optional*, defaults to 2048): + Dimensionality of each embedding vector. + num_embeddings (`int`, *optional*, defaults to 16384): + Number of codebook embeddings. + latent_channels (`int`, *optional*, defaults to 1536): + Number of channels for the latent space. + in_channels (`int`, *optional*, defaults to 3): + Number of input channels. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
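+
+    Example (a minimal sketch, mirroring the other config examples in this file and assuming `GlmImageVQVAEConfig` is exported at the top level):
+
+    ```python
+    >>> from transformers import GlmImageVQVAEConfig
+
+    >>> # Initializing a GlmImageVQVAEConfig with default values
+    >>> configuration = GlmImageVQVAEConfig()
+
+    >>> # Inspecting the codebook size
+    >>> configuration.num_embeddings
+    16384
+    ```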
+ """ + + model_type = "glm_image_vqmodel" + base_config_key = "vq_config" + + def __init__( + self, + embed_dim: int = 2048, + num_embeddings: int = 16384, + latent_channels: int = 1536, + in_channels: int = 3, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + self.num_embeddings = num_embeddings + self.latent_channels = latent_channels + self.in_channels = in_channels + self.initializer_range = initializer_range + + +class GlmImageVisionConfig(Glm4vVisionConfig): + r""" + This is the configuration class to store the configuration of a [`GlmImageVisionModel`]. It is used to instantiate an GlmImageVisionModel + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield + a similar configuration to that of + GLM-Image [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image). + + Args: + depth (`int`, *optional*, defaults to 40): + Number of layers (depth) in the model. + hidden_size (`int`, *optional*, defaults to 1536): + Dimensionality of the encoder layers and the pooler layer. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + attention_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for attention weights. + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer architecture. + in_channels (`int`, *optional*, defaults to 3): + Number of input channels. + image_size (`int` or `list[int]`, *optional*, defaults to 2048): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + spatial_merge_size (`int`, *optional*, defaults to 1): + The size used for merging spatial dimensions. + intermediate_size (`int`, *optional*, defaults to 6144): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + """ + + model_type = "glm_image_vision" + base_config_key = "vision_config" + + def __init__( + self, + depth=40, + hidden_size=1536, + hidden_act="gelu", + attention_bias=True, + attention_dropout=0.0, + num_heads=16, + in_channels=3, + image_size=2048, + patch_size=16, + layer_norm_eps=1e-06, + spatial_merge_size=1, + intermediate_size=6144, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + del self.out_hidden_size + del self.rms_norm_eps + del self.temporal_patch_size + self.layer_norm_eps = layer_norm_eps + + +class GlmImageTextConfig(Glm4vTextConfig): + r""" + This is the configuration class to store the configuration of a [`GlmImageTextModel`]. It is used to instantiate a + GLM-Image model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of + GLM-Image [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image). 
+ + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 168064): + Vocabulary size of the GlmImage model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GlmImageModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 13696): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 40): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. + vision_vocab_size (`int`, *optional*, defaults to 16512): + Vision vocabulary size of the GlmImage model. Defines the number of different tokens that can be represented + by the `inputs_ids` passed when calling [`GlmImageVisionModel`] + attention_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. 
+ + ```python + >>> from transformers import GlmImageTextModel, GlmImageConfig + + >>> # Initializing a GlmImageConfig style configuration + >>> configuration = GlmImageConfig() + + >>> # Initializing a model from the GlmImageConfig style configuration + >>> model = GlmImageTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + def __init__( + self, + vocab_size: int | None = 168064, + vision_vocab_size: int | None = 16512, + attention_bias: bool | None = True, + tie_word_embeddings: bool | None = False, + **super_kwargs, + ): + self.vocab_size = vocab_size + self.vision_vocab_size = vision_vocab_size + self.attention_bias = attention_bias + super().__init__( + tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **super_kwargs + ) + + +class GlmImageConfig(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GlmImageModel`]. It is used to instantiate a + GLM-Image model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of + GLM-Image [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image) architecture. + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `GlmImageTextConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `GlmImageVisionConfig`): + The config object or dictionary of the vision backbone. + vq_config (`Union[Dict, GlmImageVQVAEConfig]`, *optional*): + GlmImageVQVAEConfig instance containing the configuration for the VQ-VAE model. + image_token_id (`int`, *optional*, defaults to 167855): + The image token index to encode the image prompt. + image_start_token_id (`int`, *optional*, defaults to 16384): + The image start token index to encode the start of image. + image_end_token_id (`int`, *optional*, defaults to 16385): + The image end token index to encode the end of image. 
+ + ```python + >>> from transformers import Glm4vForConditionalGeneration, Glm4vConfig + + >>> # Initializing a GLM-Image style configuration + >>> configuration = Glm4vConfig() + + >>> # Initializing a model from the GLM-Image style configuration + >>> model = Glm4vForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "glm_image" + sub_configs = { + "vision_config": GlmImageVisionConfig, + "text_config": GlmImageTextConfig, + "vq_config": GlmImageVQVAEConfig, + } + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + text_config=None, + vision_config=None, + vq_config=None, + image_token_id=167855, + image_start_token_id=16384, + image_end_token_id=16385, + **kwargs, + ): + if isinstance(vision_config, dict): + vision_config = self.sub_configs["vision_config"](**vision_config) + elif vision_config is None: + vision_config = self.sub_configs["vision_config"](**kwargs) + + if isinstance(vq_config, dict): + vq_config = self.sub_configs["vq_config"](**vq_config) + elif vq_config is None: + vq_config = self.sub_configs["vq_config"](**kwargs) + + if isinstance(text_config, dict): + text_config = self.sub_configs["text_config"](**text_config) + elif text_config is None: + text_config = self.sub_configs["text_config"](**kwargs) + + self.image_token_id = image_token_id + self.image_start_token_id = image_start_token_id + self.image_end_token_id = image_end_token_id + self.text_config = text_config + self.vision_config = vision_config + self.vq_config = vq_config + super().__init__(**kwargs) + + +class GlmImageVisionMLP(SiglipMLP): + pass + + +class GlmImageVisionAttention(Glm4vVisionAttention): + def __init__(self, config: GlmImageVisionConfig) -> None: + super().__init__(config) + self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias) + self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states, key_states, value_states = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + if "flash" in self.config._attn_implementation: + # Flash Attention: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 
if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output + + +class GlmImageVisionPatchEmbed(Glm4vVisionPatchEmbed): + def __init__(self, config: GlmImageVisionConfig) -> None: + super().__init__(config) + + del self.temporal_patch_size + kernel_size = [self.patch_size, self.patch_size] + self.proj = nn.Conv2d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size) + + def forward(self, hidden_states): + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view(-1, self.in_channels, self.patch_size, self.patch_size) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class GlmImageVisionEmbeddings(Glm4vVisionEmbeddings): + def __init__(self, config: GlmImageVisionConfig) -> None: + super().__init__(config) + self.interpolated_method = "bilinear" + + +class GlmImageVisionBlock(Glm4vVisionBlock): + def __init__(self, config: GlmImageVisionConfig): + super().__init__(config) + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attn = GlmImageVisionAttention(config) + self.mlp = GlmImageVisionMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + r""" + cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`): + The cumulative sequence lengths of each image or video feature. + position_embeddings (`tuple(torch.Tensor, torch.Tensor)` of shape `(num_patches, head_dim // 2)`): + The cosine and sine position embeddings for vision attention. 
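+
+        Returns:
+            `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states after self-attention and the MLP, each followed by a residual connection.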
+ """ + residual = hidden_states + + hidden_states = self.norm1(hidden_states) + hidden_states = self.attn( + hidden_states, + cu_seqlens=cu_seqlens, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class GlmImageTextAttention(Glm4vMoeTextAttention): + pass + + +class GlmImagePreTrainedModel(Glm4vPreTrainedModel): + config: GlmImageConfig + input_modalities = ("image", "text") + + @torch.no_grad() + def _init_weights(self, module): + PreTrainedModel._init_weights(module) + + +class GlmImageModelOutputWithPast(Glm4vModelOutputWithPast): + pass + + +class GlmImageVQVAEVectorQuantizer(ChameleonVQVAEVectorQuantizer): + def __init__(self, config: GlmImageVQVAEConfig): + super().__init__(config) + self.num_embeddings = config.num_embeddings + self.embedding_dim = config.embed_dim + self.beta = getattr(config, "beta", 0.25) + + self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) + + def forward(self, hidden_state: torch.Tensor): + hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() + hidden_state_flattened = hidden_state.view(-1, self.embedding_dim) + + # L2 normalize + hidden_state = F.normalize(hidden_state, p=2, dim=-1) + hidden_state_flattened = F.normalize(hidden_state_flattened, p=2, dim=-1) + embedding = F.normalize(self.embedding.weight, p=2, dim=-1) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + distances = ( + torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + + torch.sum(embedding**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, embedding.transpose(0, 1)) + ) + + min_encoding_indices = torch.argmin(distances, dim=1) + hidden_state_quant = embedding[min_encoding_indices].view(hidden_state.shape) + + # compute loss for embedding + loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean( + (hidden_state_quant - hidden_state.detach()) ** 2 + ) + + # preserve gradients + hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach() + + # reshape back to match original input shape + hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous() + + return hidden_state_quant, loss, min_encoding_indices + + +class GlmImageVQVAE(ChameleonVQVAE): + _no_split_modules = [ + "GlmImageVQVAEVectorQuantizer", + ] + + def __init__(self, config: GlmImageVQVAEConfig): + super().__init__(config) + del self.encoder + + def encode(self, hidden_states): + hidden_states = self.quant_conv(hidden_states) + quant, emb_loss, indices = self.quantize(hidden_states) + return quant, emb_loss, indices + + +class GlmImageVisionModel(Glm4vVisionModel): + config: GlmImageVisionConfig + main_input_name = "pixel_values" + input_modalities = ("image",) + + def __init__(self, config: GlmImageVisionConfig): + super().__init__(config) + + head_dim = config.hidden_size // config.num_heads + self.head_dim = head_dim + + del self.merger + del self.rotary_pos_emb + del self.post_conv_layernorm + del self.downsample + del self.post_layernorm + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = 
hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + return pos_ids + + def forward(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + pixel_values (`torch.Tensor` of shape `(total_patches, num_channels * patch_size * patch_size)`): + Packed pixel values. + grid_thw (`torch.Tensor` of shape `(num_images, 3)`): + The temporal, height and width of feature shape of each image. + + Returns: + `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states. + """ + + hidden_states = self.patch_embed(pixel_values) + image_type_ids = self.rot_pos_emb(grid_thw) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + hidden_states = self.embeddings( + hidden_states, + seqlens, + grid_thw, + image_type_ids[:, 0].to(hidden_states.device), + image_type_ids[:, 1].to(hidden_states.device), + ) + + # Transformer blocks (no position_embeddings needed, already added above) + for blk in self.blocks: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + ) + return hidden_states + + +class GlmImageTextModel(Glm4vTextModel): + pass + + +class GlmImageModel(Glm4vModel): + def __init__(self, config): + super().__init__(config) + self.visual = GlmImageVisionModel._from_config(config.vision_config) + self.language_model = GlmImageTextModel._from_config(config.text_config) + self.vqmodel = GlmImageVQVAE._from_config(config.vq_config) + + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_rope_index( + self, + input_ids: torch.LongTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + attention_mask: torch.LongTensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index for image generation task. + + Explanation: + Each embedding sequence may contain image tokens (for generation) and text tokens, + or just text tokens. + + Input format: + - Text-to-Image: [text tokens] + <|dit_token_16384|> + - Image-to-Image: <|dit_token_16384|> [image tokens] <|dit_token_16385|> + [text tokens] + <|dit_token_16384|> + + For pure text embedding sequence, the rotary position embedding is the same across all 3 dimensions. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For sequences with image tokens, we use special markers to denote image regions: + - <|dit_token_16384|>: image start marker + - <|dit_token_16385|>: image end marker + - Image tokens between these markers use 2D spatial position encoding. 
+ + For image tokens: + - temporal: stays constant at (image_start_pos + 1) + - height: increments every w tokens, representing row position + - width: cycles from 0 to w-1, representing column position + + After each image region, the next position jumps to: image_start_pos + 1 + max(h, w) + This ensures sufficient positional separation between images and subsequent tokens. + + Examples: + === Case 1: Image-to-Image Generation === + + Source image with grid [1, 3, 2], followed by text, then generation. + input_ids: [<|dit_token_16384|> V V V V V V <|dit_token_16385|> T T T T <|dit_token_16384|>] + image_grid_thw: [[1, 3, 2], [1, 4, 4]] # first is source, second is target + + For source image (h=3, w=2, 6 tokens): + Start marker at position 0 + Image tokens at temporal=1, height=[1,1,2,2,3,3], width=[1,2,1,2,1,2] + End marker at position 4 (= 0 + 1 + max(3,2)) + + Text tokens and trailing start marker continue from position 5. + + Full prefill position_ids: + temporal: [0, 1,1,1,1,1,1, 4, 5,6,7,8, 9] + height: [0, 1,1,2,2,3,3, 4, 5,6,7,8, 9] + width: [0, 1,2,1,2,1,2, 4, 5,6,7,8, 9] + + Decode stage: use image_grid_thw[-1] = [1, 4, 4] to build cached position_ids, + starting from gen_st_idx = 10. + + === Case 2: Text-to-Image Generation (multi-resolution) === + + Pure text input with two image_grids for progressive generation. + input_ids: [hello3 33 2<|dit_token_16384|>] + Assume "hello3 33 2" = 4 tokens (positions 0-3) + <|dit_token_16384|> at position 4 + image_grid_thw: [[1, 3, 3], [1, 3, 2]] + - image_grid_thw[-1] = [1, 3, 2]: first generated image (smaller/draft) + - image_grid_thw[-2] = [1, 3, 3]: second generated image (larger/final) + + Prefill position_ids (5 tokens: 4 text + 1 start marker): + temporal: [0, 1, 2, 3, 4] + height: [0, 1, 2, 3, 4] + width: [0, 1, 2, 3, 4] + + Decode stage builds position_ids in reverse order of image_grid_thw: + + First: image_grid_thw[-1] = [1, 3, 2] (6 tokens), starting at position 5: + temporal: [5, 5, 5, 5, 5, 5] + height: [5, 5, 6, 6, 7, 7] + width: [5, 6, 5, 6, 5, 6] + next_pos = 5 + max(3, 2) = 8 + + Then: image_grid_thw[-2] = [1, 3, 3] (9 tokens), starting at position 8: + temporal: [8, 8, 8, 8, 8, 8, 8, 8, 8] + height: [8, 8, 8, 9, 9, 9, 10, 10, 10] + width: [8, 9, 10, 8, 9, 10, 8, 9, 10] + next_pos = 8 + max(3, 3) = 11 + + Finally: <|dit_token_16385|> end marker at position 11 + + Full sequence position_ids (prefill + decode): + temporal: [0,1,2,3, 4, 5,5,5,5,5,5, 8,8,8,8,8,8,8,8,8, 11] + height: [0,1,2,3, 4, 5,5,6,6,7,7, 8,8,8,9,9,9,10,10,10, 11] + width: [0,1,2,3, 4, 5,6,5,6,5,6, 8,9,10,8,9,10,8,9,10, 11] + + _cached_decode_position_ids shape: [3, 6 + 9 + 1] = [3, 16] + (includes all generated image tokens + end marker) + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default + should you provide it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image. For image generation, + temporal is typically 1. + - For image-to-image: includes source image grids + target image grid(s) + - For text-to-image with multi-resolution: includes multiple target grids, + processed in reverse order (last grid first, second-to-last grid second, etc.) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`): + Position IDs for temporal, height, and width dimensions. + mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`): + Position deltas for multi-modal rotary position embedding (zeros for this task). + """ + + batch_size, seq_len = input_ids.shape + device = input_ids.device + dtype = input_ids.dtype + + image_start_token_id = self.config.image_start_token_id + image_end_token_id = self.config.image_end_token_id + num_complete_images = (input_ids == image_end_token_id).sum().item() + + position_ids = torch.ones( + 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device + ) + text_positions = torch.arange(seq_len)[None, :].repeat(3, 1) + for batch_idx in range(batch_size): + curr_input_ids = input_ids[batch_idx] + if attention_mask is not None: + curr_input_ids = curr_input_ids[attention_mask[batch_idx] == 1] + + image_end = torch.where(curr_input_ids == image_end_token_id)[0] + image_start = torch.where(curr_input_ids == image_start_token_id)[0] + 1 + current_pos = 0 # track the current position value + prev_image_end = 0 + curr_position_ids = [] + for start, end, grid in zip(image_start, image_end, image_grid_thw): + _, num_width_grid, num_height_grid = grid + + # Create text position ids first if there are text tokens before image + llm_pos_length = start - prev_image_end + llm_position_ids = text_positions[:, current_pos : current_pos + llm_pos_length].to( + device=input_ids.device + ) + current_pos += llm_position_ids.shape[-1] + + # Now create image position ids for each grid + image_seq_length = num_height_grid * num_width_grid + h_grids = image_seq_length // num_height_grid + current_pos + w_grids = image_seq_length // num_width_grid + current_pos + position_width = torch.arange(current_pos, w_grids, device=input_ids.device).repeat(num_width_grid) + position_height = torch.arange(current_pos, h_grids, device=input_ids.device).repeat_interleave( + num_height_grid + ) + position_temporal = torch.full( + (image_seq_length,), current_pos, device=input_ids.device, dtype=torch.long + ) + vision_position_ids = torch.stack([position_temporal, position_height, position_width], dim=0) + current_pos += max(num_height_grid, num_width_grid) + + prev_image_end = end + curr_position_ids.append(torch.cat([llm_position_ids, vision_position_ids], dim=-1)) + + # Add position ids for the last text tokens if any + end_position = len(curr_input_ids) - prev_image_end + llm_position_ids = text_positions[:, current_pos : current_pos + end_position].to(device=input_ids.device) + current_pos += llm_position_ids.shape[-1] + curr_position_ids.append(llm_position_ids) + curr_position_ids = torch.cat(curr_position_ids, dim=-1) + if attention_mask is not None: + position_ids[:, batch_idx, attention_mask[batch_idx] == 1] = curr_position_ids.to(position_ids.device) + else: + position_ids[:, batch_idx, :] = curr_position_ids.to(position_ids.device) + + # Build and store position ids for tokens that will be generated. 
Later we will just + # slice these instead of computing each decoding step + self._prefill_len = seq_len + if image_grid_thw is not None and len(image_grid_thw) > 0: + num_decode_grids = len(image_grid_thw) - num_complete_images + num_decode_grids = max(num_decode_grids, 0) + decode_pos = current_pos + + decode_temporal_list = [] + decode_height_list = [] + decode_width_list = [] + + for i in range(1, num_decode_grids + 1): + grid_idx = -i + h = image_grid_thw[grid_idx, 1].item() + w = image_grid_thw[grid_idx, 2].item() + total_tokens = h * w + + h_indices = torch.arange(h, device=device).unsqueeze(1).expand(h, w).flatten() + w_indices = torch.arange(w, device=device).unsqueeze(0).expand(h, w).flatten() + + decode_temporal_list.append(torch.full((total_tokens,), decode_pos, device=device, dtype=torch.long)) + decode_height_list.append(decode_pos + h_indices) + decode_width_list.append(decode_pos + w_indices) + decode_pos = decode_pos + max(h, w) + + decode_temporal_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long)) + decode_height_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long)) + decode_width_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long)) + + self._cached_decode_position_ids = torch.stack( + [ + torch.cat(decode_temporal_list, dim=0), + torch.cat(decode_height_list, dim=0), + torch.cat(decode_width_list, dim=0), + ], + dim=0, + ) + else: + self._cached_decode_position_ids = None + + mrope_position_deltas = torch.zeros([batch_size, 1], dtype=dtype, device=device) + + return position_ids, mrope_position_deltas + + def get_image_tokens( + self, + hidden_states: torch.FloatTensor, + image_grid_thw: torch.LongTensor, + ) -> torch.LongTensor: + """ + Tokenizes image features into discrete tokens with VQVAE module. + + Args: + hidden_states (`torch.FloatTensor` of shape `(total_patches, hidden_size)`): + The packed image features from vision encoder. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`): + The temporal, height and width of feature shape of each image. + + Returns: + image_tokens (`torch.LongTensor` of shape `(total_patches,)`): + Discrete token indices from the VQVAE codebook. + """ + hidden_size = hidden_states.shape[-1] + split_sizes = (image_grid_thw.prod(dim=-1)).tolist() + hidden_states_list = torch.split(hidden_states, split_sizes, dim=0) + + all_image_toks = [] + for i, hs in enumerate(hidden_states_list): + grid_t, grid_h, grid_w = image_grid_thw[i].tolist() + hs = hs.view(grid_t, grid_h, grid_w, hidden_size) + hs = hs.permute(0, 3, 1, 2).contiguous() + _, _, image_toks = self.vqmodel.encode(hs) + all_image_toks.append(image_toks) + return torch.cat(all_image_toks, dim=0) + + def get_video_features(self): + raise AttributeError("Not needed for GlmImage") + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + image_ids: torch.LongTensor, + ): + """ + Replace image placeholder tokens in input_ids with actual image token ids from VQVAE. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Input token ids with image placeholders. + image_ids (`torch.LongTensor` of shape `(num_images, num_tokens_per_image)` or flattened): + Discrete token indices from the VQVAE codebook. + + Returns: + special_image_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`): + Mask indicating positions in input ids that will be replaced by actual image tokens. 
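+
+        Example (how the returned mask is consumed in `forward`):
+
+        ```python
+        # boolean mask over placeholder positions, then scatter in the discrete VQVAE token ids
+        special_image_mask = self.get_placeholder_mask(input_ids, image_ids)
+        input_ids = input_ids.masked_scatter(special_image_mask, image_ids)
+        ```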
+ """ + + special_image_mask = input_ids == self.config.image_token_id + n_placeholder_tokens = special_image_mask.sum().item() + n_image_tokens = image_ids.shape[0] + + if n_placeholder_tokens != n_image_tokens: + raise ValueError( + f"Number of image placeholder tokens ({n_placeholder_tokens}) does not match " + f"number of image tokens from VQVAE ({n_image_tokens})" + ) + + return special_image_mask + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + pixel_values: torch.Tensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + rope_deltas: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | GlmImageModelOutputWithPast: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw[:-1]) + image_embeds = torch.cat(image_embeds, dim=0) + image_ids = self.get_image_tokens(image_embeds, image_grid_thw[:-1]) + image_ids = image_ids.view(-1).to(input_ids.device) + special_image_mask = self.get_placeholder_mask(input_ids, image_ids) + input_ids = input_ids.masked_scatter(special_image_mask, image_ids) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if position_ids is None: + attention_mask_2d = attention_mask + if attention_mask is not None and attention_mask.ndim == 4: + attention_mask_2d = torch.diagonal(attention_mask[:, 0], dim1=1, dim2=2) + # Only apply conversion for floating point tensors (inverted masks) + if attention_mask_2d.dtype.is_floating_point: + attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min + attention_mask_2d = (1.0 - attention_mask_2d).int() + + # Calculate RoPE index once per generation in the pre-fill stage only. + # It is safe to assume that `length!=1` means we're in pre-fill because the + # model is used only by DiT pipeline without assisted decoding, etc. 
techniques + is_prefill_stage = (input_ids is not None and input_ids.shape[1] != 1) or ( + inputs_embeds is not None and inputs_embeds.shape[1] != 1 + ) + if is_prefill_stage or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + attention_mask=attention_mask_2d, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + # Use prefill token length, not position value + step = cache_position[0].item() - self._prefill_len + # Direct lookup - no tensor creation overhead + position_ids = self._cached_decode_position_ids[:, step : step + seq_length] + position_ids = position_ids.unsqueeze(1).expand(-1, batch_size, -1) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + return GlmImageModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + +class GlmImageCausalLMOutputWithPast(Glm4vCausalLMOutputWithPast): + pass + + +class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + base_model_prefix = "model" + config: GlmImageConfig + + def __init__(self, config): + super().__init__(config) + self.model = GlmImageModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vision_vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None): + return self.model.get_image_features(pixel_values, image_grid_thw) + + def get_image_tokens(self, hidden_states: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None): + return self.model.get_image_tokens(hidden_states, image_grid_thw) + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | GlmImageCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. 
+ + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, GlmImageForConditionalGeneration + + >>> model = GlmImageForConditionalGeneration.from_pretrained("zai-org/GLM-Image") + >>> processor = AutoProcessor.from_pretrained("zai-org/GLM-Image") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "Add a truck of this photo.28 40"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return GlmImageCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + image_grid_thw=None, + is_first_iteration=False, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + is_first_iteration=is_first_iteration, + use_cache=use_cache, + **kwargs, + ) + + model_inputs["position_ids"] = None + + if not is_first_iteration and use_cache: + model_inputs["pixel_values"] = None + + return model_inputs + + def _get_image_nums( + self, + input_ids: torch.LongTensor | None, + ) -> torch.Tensor: + """ + Get the number of images for each sample. + For GLM-Image, only input_ids allow us to get the number of images. 
+ + Returns: + image_counts (`torch.LongTensor` of shape `(batch_size,)`) + """ + is_image = input_ids == self.config.image_start_token_id + + return is_image.sum(dim=1) + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: torch.LongTensor | None = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + image_nums = self._get_image_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw[: sum(image_nums)], list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample and +1 for the image being generated + lengths = list(image_nums) + last_image = dict_to_expand[key][:-1] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key][: sum(image_nums)], lengths=lengths, repeat_times=expand_size + ) + dict_to_expand[key] = torch.cat([dict_to_expand[key], last_image], dim=0) + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +def smart_resize( + height: int, + width: int, + factor: int = 16, + min_pixels: int = 512 * 512, + max_pixels: int = 2048 * 2048, +) -> tuple[int, int]: + if height < factor or width < factor: + raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}") + elif max(height, width) / min(height, width) > 4: + raise ValueError( + f"absolute aspect ratio must be smaller than 4, got {max(height, width) / min(height, width)}" + ) + + shortest_edge = int(round(math.sqrt(min_pixels))) + longest_edge = int(round(math.sqrt(max_pixels))) + min_side = min(height, width) + max_side = max(height, width) + + scale = 1.0 + + if min_side < shortest_edge: + scale = 
shortest_edge / min_side + + if max_side * scale > longest_edge: + scale = longest_edge / max_side + + height = height // 2 + width = width // 2 + + h_bar = max(factor, int(round(height * scale / factor)) * factor) + w_bar = max(factor, int(round(width * scale / factor)) * factor) + + if max(h_bar, w_bar) > longest_edge: + beta = max(h_bar, w_bar) / longest_edge + h_bar = max(factor, int(math.floor((h_bar / beta) / factor)) * factor) + w_bar = max(factor, int(math.floor((w_bar / beta) / factor)) * factor) + + return h_bar, w_bar + + +class GlmImageImageProcessor(Qwen2VLImageProcessor): + pass + + +class GlmImageImageProcessorFast(Qwen2VLImageProcessorFast): + pass + + +class GlmImageImagesKwargs(ImagesKwargs, total=False): + """ + target_h (`int`): + Height of the target image to be generated. + target_w (`int`): + Width of the target image to be generated. + """ + + target_h: int + target_w: int + + +class GlmImageProcessorKwargs(Qwen2VLProcessorKwargs): + images_kwargs: GlmImageImagesKwargs + + _defaults = { + "text_kwargs": { + "padding": False, + "return_mm_token_type_ids": False, + }, + "images_kwargs": { + "target_h": 1152, + "target_w": 768, + }, + } + + +class GlmImageProcessor(ProcessorMixin): + r""" + Constructs a GLM-Image processor which wraps a GLM-Image image processor and a GLM-Image tokenizer into a single processor. + [`~GlmImageProcessor.__call__`] and [`~GlmImageProcessor.decode`] for more information. + Args: + image_processor ([`GlmImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`PreTrainedTokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + self.image_token = tokenizer.image_token + self.grid_bos_token = tokenizer.grid_bos_token + self.grid_eos_token = tokenizer.grid_eos_token + self.bos_token = tokenizer.bos_token + self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput | None = None, + text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None, + **kwargs: Unpack[GlmImageProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode + the text. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'pt'`: Return PyTorch `torch.Tensor` objects. 
+ - `'np'`: Return NumPy `np.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + """ + output_kwargs = self._merge_kwargs( + GlmImageProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + target_h = output_kwargs["images_kwargs"].pop("target_h", None) + target_w = output_kwargs["images_kwargs"].pop("target_w", None) + is_text_to_image = images is None + + if images is not None: + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if not isinstance(text, list): + text = [text] + + if len(text) > 1: + raise ValueError("The model does not support batch size > 1") + + text = text.copy() # below lines change text in-place + if not is_text_to_image: + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + grid = image_grid_thw[index] + num_image_tokens = int(grid[1] * grid[2]) + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + text[0], token_h, token_w, prev_h, prev_w = self._build_prompt_with_target_shape( + text[0], height=target_h, width=target_w, is_text_to_image=is_text_to_image + ) + image_inputs["image_grid_thw"] = self._build_target_image_grid_thw( + token_h=token_h, + token_w=token_w, + prev_token_h=prev_h, + prev_token_w=prev_w, + image_grid_thw=image_grid_thw if not is_text_to_image else None, + ) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(text, text_inputs, modalities=["image"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + + def _build_prompt_with_target_shape( + self, + prompt: str, + height: int, + width: int, + is_text_to_image: bool, + ) -> tuple[str, int, int, int, int]: + factor = 32 + height = (height // factor) * factor + width = (width // factor) * factor + token_h = height // factor + token_w = width // factor + ratio = token_h / token_w + prev_token_h = int(math.sqrt(ratio) * (factor // 2)) + prev_token_w = int(math.sqrt(1 / ratio) * (factor // 2)) + + if is_text_to_image: + expanded_prompt = f"{prompt}{self.grid_bos_token}{token_h} {token_w}{self.grid_eos_token}{self.grid_bos_token}{prev_token_h} {prev_token_w}{self.grid_eos_token}{self.bos_token}" + else: + expanded_prompt = f"{prompt}{self.grid_bos_token}{token_h} {token_w}{self.grid_eos_token}{self.bos_token}" + + return 
expanded_prompt, token_h, token_w, prev_token_h, prev_token_w + + @staticmethod + def _build_target_image_grid_thw( + token_h: int, + token_w: int, + prev_token_h: int, + prev_token_w: int, + image_grid_thw: None, + ): + if image_grid_thw is None: + return torch.tensor( + [ + [1, token_h, token_w], + [1, prev_token_h, prev_token_w], + ], + ) + else: + return torch.cat( + [image_grid_thw, torch.tensor([[1, token_h, token_w]], device=image_grid_thw.device)], dim=0 + ) + + +__all__ = [ + "GlmImageVQVAEConfig", + "GlmImageVisionConfig", + "GlmImageTextConfig", + "GlmImageConfig", + "GlmImagePreTrainedModel", + "GlmImageVQVAE", + "GlmImageVisionModel", + "GlmImageTextModel", + "GlmImageModel", + "GlmImageForConditionalGeneration", + "GlmImageImageProcessor", + "GlmImageImageProcessorFast", + "GlmImageProcessor", +] diff --git a/src/transformers/models/glm_image/processing_glm_image.py b/src/transformers/models/glm_image/processing_glm_image.py new file mode 100644 index 000000000000..8cbfabf4ecc2 --- /dev/null +++ b/src/transformers/models/glm_image/processing_glm_image.py @@ -0,0 +1,217 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/glm_image/modular_glm_image.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_glm_image.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import is_torch_available + + +if is_torch_available(): + import torch + + +class GlmImageImagesKwargs(ImagesKwargs, total=False): + """ + target_h (`int`): + Height of the target image to be generated. + target_w (`int`): + Width of the target image to be generated. + """ + + target_h: int + target_w: int + + +class GlmImageProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + "return_mm_token_type_ids": False, + }, + "images_kwargs": { + "target_h": 1152, + "target_w": 768, + }, + } + images_kwargs: GlmImageImagesKwargs + + +class GlmImageProcessor(ProcessorMixin): + r""" + Constructs a GLM-Image processor which wraps a GLM-Image image processor and a GLM-Image tokenizer into a single processor. + [`~GlmImageProcessor.__call__`] and [`~GlmImageProcessor.decode`] for more information. + Args: + image_processor ([`GlmImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`PreTrainedTokenizerFast`], *optional*): + The tokenizer is a required input. 
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + self.image_token = tokenizer.image_token + self.grid_bos_token = tokenizer.grid_bos_token + self.grid_eos_token = tokenizer.grid_eos_token + self.bos_token = tokenizer.bos_token + self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput | None = None, + text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None, + **kwargs: Unpack[GlmImageProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode + the text. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. 
+ """ + output_kwargs = self._merge_kwargs( + GlmImageProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + target_h = output_kwargs["images_kwargs"].pop("target_h", None) + target_w = output_kwargs["images_kwargs"].pop("target_w", None) + is_text_to_image = images is None + + if images is not None: + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if not isinstance(text, list): + text = [text] + + if len(text) > 1: + raise ValueError("The model does not support batch size > 1") + + text = text.copy() # below lines change text in-place + if not is_text_to_image: + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + grid = image_grid_thw[index] + num_image_tokens = int(grid[1] * grid[2]) + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + text[0], token_h, token_w, prev_h, prev_w = self._build_prompt_with_target_shape( + text[0], height=target_h, width=target_w, is_text_to_image=is_text_to_image + ) + image_inputs["image_grid_thw"] = self._build_target_image_grid_thw( + token_h=token_h, + token_w=token_w, + prev_token_h=prev_h, + prev_token_w=prev_w, + image_grid_thw=image_grid_thw if not is_text_to_image else None, + ) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(text, text_inputs, modalities=["image"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + + def _build_prompt_with_target_shape( + self, + prompt: str, + height: int, + width: int, + is_text_to_image: bool, + ) -> tuple[str, int, int, int, int]: + factor = 32 + height = (height // factor) * factor + width = (width // factor) * factor + token_h = height // factor + token_w = width // factor + ratio = token_h / token_w + prev_token_h = int(math.sqrt(ratio) * (factor // 2)) + prev_token_w = int(math.sqrt(1 / ratio) * (factor // 2)) + + if is_text_to_image: + expanded_prompt = f"{prompt}{self.grid_bos_token}{token_h} {token_w}{self.grid_eos_token}{self.grid_bos_token}{prev_token_h} {prev_token_w}{self.grid_eos_token}{self.bos_token}" + else: + expanded_prompt = f"{prompt}{self.grid_bos_token}{token_h} {token_w}{self.grid_eos_token}{self.bos_token}" + + return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w + + @staticmethod + def _build_target_image_grid_thw( + token_h: int, + token_w: int, + prev_token_h: int, + prev_token_w: int, + image_grid_thw: None, + ): + if image_grid_thw is None: + return torch.tensor( + [ + [1, token_h, token_w], + [1, prev_token_h, prev_token_w], + ], + ) + else: + return torch.cat( + [image_grid_thw, torch.tensor([[1, token_h, token_w]], device=image_grid_thw.device)], dim=0 + ) + + +__all__ = ["GlmImageProcessor"] diff --git a/tests/models/glm4_moe_lite/test_modeling_glm4_moe_lite.py b/tests/models/glm4_moe_lite/test_modeling_glm4_moe_lite.py index 
56c01c6b2fb0..698ceebb22c0 100644 --- a/tests/models/glm4_moe_lite/test_modeling_glm4_moe_lite.py +++ b/tests/models/glm4_moe_lite/test_modeling_glm4_moe_lite.py @@ -22,7 +22,6 @@ from transformers import Cache, is_torch_available from transformers.testing_utils import ( cleanup, - require_read_token, require_torch, require_torch_accelerator, slow, @@ -82,7 +81,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l @require_torch_accelerator -@require_read_token @slow class Glm4MoeIntegrationTest(unittest.TestCase): def tearDown(self): @@ -92,7 +90,6 @@ def tearDown(self): @slow @require_torch_accelerator @pytest.mark.torch_compile_test - @require_read_token def test_compile_static_cache(self): # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 # work as intended. See https://github.com/pytorch/pytorch/issues/121943 diff --git a/tests/models/glm_image/__init__.py b/tests/models/glm_image/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/glm_image/test_modeling_glm_image.py b/tests/models/glm_image/test_modeling_glm_image.py new file mode 100644 index 000000000000..5aeb255d25fd --- /dev/null +++ b/tests/models/glm_image/test_modeling_glm_image.py @@ -0,0 +1,484 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
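The processor code above derives two token grids from the requested output size: a full-resolution grid of `target_h // 32` by `target_w // 32` tokens and a smaller preview grid with roughly the same aspect ratio. Below is a minimal sketch of that arithmetic, mirroring `_build_prompt_with_target_shape` (the constants and formula come from the diff; the standalone helper name is ours, not part of the library):

```python
# Minimal sketch: reproduces the grid arithmetic of
# GlmImageProcessor._build_prompt_with_target_shape for a requested canvas.
import math


def target_grids(height: int = 1152, width: int = 768, factor: int = 32):
    # Snap the requested canvas to the 32-pixel factor, then count tokens per axis.
    height = (height // factor) * factor
    width = (width // factor) * factor
    token_h, token_w = height // factor, width // factor
    # The low-resolution "preview" grid keeps roughly the same aspect ratio
    # with about (factor // 2) ** 2 = 256 tokens in total.
    ratio = token_h / token_w
    prev_token_h = int(math.sqrt(ratio) * (factor // 2))
    prev_token_w = int(math.sqrt(1 / ratio) * (factor // 2))
    return token_h, token_w, prev_token_h, prev_token_w


print(target_grids())  # (36, 24, 19, 13) for the default 1152x768 target
```

For text-to-image prompts, `_build_target_image_grid_thw` then stacks the two grids as `[[1, 36, 24], [1, 19, 13]]`; for editing prompts only the full-resolution grid is appended after the grids of the input images.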
+"""Testing suite for the PyTorch GLM-Image model.""" + +import unittest + +from parameterized import parameterized + +from transformers import ( + AutoProcessor, + GlmImageConfig, + GlmImageForConditionalGeneration, + GlmImageModel, + is_torch_available, +) +from transformers.models.auto import get_values +from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES +from transformers.testing_utils import ( + cleanup, + require_flash_attn, + require_torch, + require_torch_accelerator, + run_first, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, + ModelTesterMixin, + floats_tensor, + ids_tensor, +) + + +if is_torch_available(): + import torch + + +class GlmImageVisionText2TextModelTester: + def __init__( + self, + parent, + batch_size=1, + seq_length=7, + num_channels=3, + ignore_index=-100, + image_size=128, + image_start_token_id=85, + image_end_token_id=86, + image_token_id=7, + is_training=True, + text_config={ + "vocab_size": 99, + "vision_vocab_size": 99, + "hidden_size": 16, + "intermediate_size": 22, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 1, + "output_channels": 64, + "hidden_act": "silu", + "max_position_embeddings": 512, + "rope_parameters": {"type": "default", "mrope_section": [2, 1, 1]}, + "rope_theta": 10000, + "tie_word_embeddings": True, + "bos_token_id": 0, + "eos_token_id": 0, + "pad_token_id": 0, + "n_routed_experts": 8, + "n_shared_experts": 1, + "n_group": 1, + "topk_group": 1, + "num_experts_per_tok": 8, + }, + vision_config={ + "depth": 2, + "hidden_act": "gelu", + "hidden_size": 32, + "out_hidden_size": 16, + "intermediate_size": 22, + "patch_size": 16, + "spatial_merge_size": 1, + "temporal_patch_size": 1, + }, + vq_config={ + "embed_dim": 48, + "in_channels": 3, + "initializer_range": 0.02, + "latent_channels": 32, + "num_embeddings": 32, + }, + ): + self.parent = parent + self.ignore_index = ignore_index + self.bos_token_id = text_config["bos_token_id"] + self.eos_token_id = text_config["eos_token_id"] + self.pad_token_id = text_config["pad_token_id"] + self.image_start_token_id = image_start_token_id + self.image_end_token_id = image_end_token_id + self.image_token_id = image_token_id + self.text_config = text_config + self.vision_config = vision_config + self.vq_config = vq_config + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.is_training = is_training + self.hidden_size = text_config["hidden_size"] + self.num_hidden_layers = text_config["num_hidden_layers"] + self.num_attention_heads = text_config["num_attention_heads"] + self.vision_vocab_size = text_config["vision_vocab_size"] + self.vocab_size = text_config["vocab_size"] + self.num_image_tokens = 64 + self.seq_length = seq_length + self.num_image_tokens + self.n_routed_experts = text_config["n_routed_experts"] + self.n_shared_experts = text_config["n_shared_experts"] + self.num_experts_per_tok = text_config["num_experts_per_tok"] + self.n_group = text_config["n_group"] + self.topk_group = text_config["topk_group"] + + def get_config(self): + return GlmImageConfig( + text_config=self.text_config, + vision_config=self.vision_config, + vq_config=self.vq_config, + image_token_id=self.image_token_id, + image_start_token_id=self.image_start_token_id, + image_end_token_id=self.image_end_token_id, + ) + + def 
prepare_config_and_inputs(self): + config = self.get_config() + patch_size = config.vision_config.patch_size + temporal_patch_size = config.vision_config.temporal_patch_size + pixel_values = floats_tensor( + [ + self.batch_size * (self.image_size**2) // (patch_size**2), + self.num_channels * (patch_size**2) * temporal_patch_size, + ] + ) + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + input_ids[input_ids == self.image_token_id] = self.pad_token_id + input_ids[input_ids == self.image_start_token_id] = self.pad_token_id + input_ids[input_ids == self.image_end_token_id] = self.pad_token_id + + input_ids[:, 0] = self.image_start_token_id + input_ids[:, 1 : 1 + self.num_image_tokens] = self.image_token_id + input_ids[:, 1 + self.num_image_tokens] = self.image_end_token_id + patch_size = config.vision_config.patch_size + patches_per_side = self.image_size // patch_size + + # Key fix: image_grid_thw should have batch_size rows for input images + # plus 1 extra row that will be skipped by model's [:-1] slicing + inputs_dict = { + "pixel_values": pixel_values, + "image_grid_thw": torch.tensor( + [[1, patches_per_side, patches_per_side]] * self.batch_size + + [[1, patches_per_side, patches_per_side]], # Extra row for model's [:-1] + device=torch_device, + ), + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class GlmImageModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (GlmImageModel, GlmImageForConditionalGeneration) if is_torch_available() else () + + model_split_percents = [0.7, 0.9] # model too big to split at 0.5 + _is_composite = True + + def setUp(self): + self.model_tester = GlmImageVisionText2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=GlmImageConfig, has_text_modality=False) + + def test_config(self): + self.config_tester.run_common_tests() + + # GlmImage has images shaped as (bs*patch_len, dim) so we can't slice to batches in generate + def prepare_config_and_inputs_for_generate(self, batch_size=2): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # We don't want a few model inputs in our model input dictionary for generation tests + input_keys_to_ignore = [ + # we don't want to mask attention heads + # we don't want encoder-decoder models to start from filled decoder ids + "decoder_input_ids", + "decoder_attention_mask", + # we'll set cache use in each test differently + "use_cache", + # Ignore labels if it is in the input dict + "labels", + # model-specific exceptions should overload/overwrite this function + ] + + # The diff from the general `prepare_config_and_inputs_for_generate` lies here + patch_size = config.vision_config.patch_size + filtered_image_length = batch_size * (self.model_tester.image_size**2) // (patch_size**2) + filtered_inputs_dict = { + k: v[:batch_size, ...] 
if isinstance(v, torch.Tensor) else v + for k, v in inputs_dict.items() + if k not in input_keys_to_ignore + } + filtered_inputs_dict["pixel_values"] = inputs_dict["pixel_values"][:filtered_image_length] + filtered_inputs_dict["image_grid_thw"] = inputs_dict["image_grid_thw"][: batch_size + 1] + + # It is important set `eos_token_id` to `None` to avoid early stopping (would break for length-based checks) + text_gen_config = config.get_text_config(decoder=True) + if text_gen_config.eos_token_id is not None and text_gen_config.pad_token_id is None: + text_gen_config.pad_token_id = ( + text_gen_config.eos_token_id + if isinstance(text_gen_config.eos_token_id, int) + else text_gen_config.eos_token_id[0] + ) + text_gen_config.eos_token_id = None + text_gen_config.forced_eos_token_id = None + + return config, filtered_inputs_dict + + def test_training(self): + # Model isn't in any auto-mapping so we need to build labels manually + if not self.model_tester.is_training: + self.skipTest(reason="ModelTester is not configured to run training tests") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + if model_class.__name__ in [ + *get_values(MODEL_MAPPING_NAMES), + ]: + continue + + model = model_class(config) + model.to(torch_device) + model.train() + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device + ) + loss = model(**inputs_dict).loss + loss.backward() + + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) + @unittest.skip("Needs special input preparation. Not important test for model, skip for now") + def test_eager_matches_sdpa_inference( + self, name, dtype, padding_side, use_attention_mask, output_attentions, enable_kernels + ): + pass + + @unittest.skip(reason="No available kernels - not supported") + def test_sdpa_can_dispatch_on_flash(self): + pass + + @unittest.skip(reason="Size mismatch") + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip("Error with compilation") + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + + @parameterized.expand([("greedy", 1), ("beam search", 2)]) + @unittest.skip(reason="GLM-Image does not use inputs_embeds") + def test_generate_from_inputs_embeds(self, _, num_beams): + pass + + @unittest.skip(reason="GLM-Image input embed is compare with inputs_ids and image_ids") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip(reason="GLM-Image does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="GLM-Image can't do text-only inference") + def test_generate_from_random_inputs_embeds(self): + pass + + @unittest.skip(reason="GLM-Image can't do and does not need assisted generation. Not worth fixing!") + def test_assisted_decoding_sample(self): + pass + + @unittest.skip(reason="GLM-Image can't do and does not need assisted generation. Not worth fixing!") + def test_prompt_lookup_decoding_matches_greedy_searc(self): + pass + + @parameterized.expand([("random",), ("same",)]) + @unittest.skip(reason="GLM-Image can't do and does not need assisted generation. 
Not worth fixing!") + def test_assisted_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip(reason="GlmImageVisionModel does not support training") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="GlmImageVision does not support output_hidden_states test") + def test_model_outputs_equivalence(self): + pass + + @unittest.skip( + reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="GlmImageVisionModel does not support training") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="GlmImage needs special input preparation to pass this test") + def test_generate_compile_model_forward_fullgraph(self): + pass + + +@require_torch +@slow +class GlmImageIntegrationTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = None + + @classmethod + def get_model(cls): + if cls.model is None: + cls.model = GlmImageForConditionalGeneration.from_pretrained( + "zai-org/GLM-4.5V", dtype="auto", device_map="auto" + ) + return cls.model + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "model"): + del cls.model + cleanup(torch_device, gc_collect=True) + + def setUp(self): + cleanup(torch_device, gc_collect=True) + self.processor = AutoProcessor.from_pretrained( + "zai-org/GLM-4.5V", size={"shortest_edge": 10800, "longest_edge": 10800} + ) + self.message = [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg", + }, + {"type": "text", "text": "What kind of dog is this?"}, + ], + } + ] + self.message2 = [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png", + }, + {"type": "text", "text": "What kind of dog is this?"}, + ], + } + ] + self.message_wo_image = [ + {"role": "user", "content": [{"type": "text", "text": "Who are you?"}]}, + ] + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def test_small_model_integration_test(self): + inputs = self.processor.apply_chat_template( + self.message, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt" + ) + expected_input_ids = [151331, 151333, 151336, 198, 151339, 151363, 151363, 151363, 151363, 151363, 151363, + 151340, 3838, 3093, 315, 5562, 374] # fmt: skip + assert expected_input_ids == inputs.input_ids[0].tolist()[:17] + + expected_pixel_slice = torch.tensor( + [ + [-0.1134, -0.4492, -0.8580], + [-0.6244, -1.1645, -0.7120], + [-0.3324, -0.7996, -0.7120], + [0.2077, 0.2223, 0.4121], + [0.4413, 0.1931, 0.4559], + [0.5873, 0.3099, 0.4851], + ], + dtype=torch.float32, + device="cpu", + ) + torch.testing.assert_close(expected_pixel_slice, inputs.pixel_values[:6, :3], atol=1e-4, rtol=1e-4) + + def test_small_model_integration_test_batch(self): + model = self.get_model() + batch_messages = [self.message, self.message2, self.message_wo_image] + inputs = self.processor.apply_chat_template( + batch_messages, + tokenize=True, + 
add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + padding=True, + ).to(torch_device) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=10) + + EXPECTED_DECODED_TEXT = [ + "\nWhat kind of dog is this?\nGot it, let's try to figure out", + "\nWhat kind of dog is this?\nGot it, let's see. The user", + '\nWho are you?\nThe user is asking "Who are you?"' + ] # fmt: skip + decoded = self.processor.batch_decode(output, skip_special_tokens=True) + decoded = [x.replace("<|image|>", "") for x in decoded] + self.assertEqual( + decoded, + EXPECTED_DECODED_TEXT, + ) + + @run_first + @require_flash_attn + @require_torch_accelerator + def test_small_model_integration_test_batch_flashatt2(self): + model = GlmImageForConditionalGeneration.from_pretrained( + "zai-org/GLM-4.5V", + dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map="auto", + ) + batch_messages = [self.message, self.message2, self.message_wo_image] + inputs = self.processor.apply_chat_template( + batch_messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + padding=True, + ).to(torch_device) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=3) + + EXPECTED_DECODED_TEXT = [ + "\nWhat kind of dog is this?\nGot it", + "\nWhat kind of dog is this?\nGot it", + "\nWho are you?\nThe user", + ] # fmt: skip + decoded = self.processor.batch_decode(output, skip_special_tokens=True) + decoded = [x.replace("<|image|>", "") for x in decoded] + self.assertEqual( + decoded, + EXPECTED_DECODED_TEXT, + ) diff --git a/tests/models/glm_image/test_processor_glm_image.py b/tests/models/glm_image/test_processor_glm_image.py new file mode 100644 index 000000000000..0475f92a8caa --- /dev/null +++ b/tests/models/glm_image/test_processor_glm_image.py @@ -0,0 +1,155 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
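The integration tests above exercise the full chat-template path end to end. For reference, a minimal usage sketch following the same pattern (the checkpoint id `zai-org/GLM-4.5V` is the placeholder used by these tests and may differ from the final GLM-Image release):

```python
# Usage sketch mirroring GlmImageIntegrationTest: build chat inputs with the
# processor, run generation, and decode. The checkpoint id is the tests' placeholder.
import torch
from transformers import AutoProcessor, GlmImageForConditionalGeneration

model = GlmImageForConditionalGeneration.from_pretrained(
    "zai-org/GLM-4.5V", dtype="auto", device_map="auto"
)
processor = AutoProcessor.from_pretrained("zai-org/GLM-4.5V")

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
            },
            {"type": "text", "text": "What kind of dog is this?"},
        ],
    }
]

inputs = processor.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
).to(model.device)

with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=10)
print(processor.batch_decode(output, skip_special_tokens=True)[0])
```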
+ +import unittest + +import numpy as np + +from transformers.testing_utils import require_av, require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from transformers import GlmImageProcessor + +if is_torch_available(): + import torch + + +@require_vision +@require_torch +@unittest.skip(reason="Model not released yet") +class GlmImageProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = GlmImageProcessor + model_id = "zai-org/GLM-Image" + + @classmethod + def _setup_test_attributes(cls, processor): + cls.image_token = processor.image_token + + @classmethod + def _setup_from_pretrained(cls, model_id, **kwargs): + return super()._setup_from_pretrained( + model_id, + do_sample_frames=False, + patch_size=4, + size={"shortest_edge": 12 * 12, "longest_edge": 18 * 18}, + **kwargs, + ) + + @require_torch + @require_av + def _test_apply_chat_template( + self, + modality: str, + batch_size: int, + return_tensors: str, + input_name: str, + processor_name: str, + input_data: list[str], + ): + processor = self.get_processor() + if processor.chat_template is None: + self.skipTest("Processor has no chat template") + + if processor_name not in self.processor_class.get_attributes(): + self.skipTest(f"{processor_name} attribute not present in {self.processor_class}") + + batch_messages = [ + [ + { + "role": "user", + "content": [{"type": "text", "text": "Describe this."}], + }, + ] + ] * batch_size + + # Test that jinja can be applied + formatted_prompt = processor.apply_chat_template(batch_messages, add_generation_prompt=True, tokenize=False) + self.assertEqual(len(formatted_prompt), batch_size) + + # Test that tokenizing with template and directly with `self.tokenizer` gives same output + formatted_prompt_tokenized = processor.apply_chat_template( + batch_messages, add_generation_prompt=True, tokenize=True, return_tensors=return_tensors + ) + add_special_tokens = True + if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token): + add_special_tokens = False + tok_output = processor.tokenizer( + formatted_prompt, return_tensors=return_tensors, add_special_tokens=add_special_tokens + ) + expected_output = tok_output.input_ids + self.assertListEqual(expected_output.tolist(), formatted_prompt_tokenized.tolist()) + + # Test that kwargs passed to processor's `__call__` are actually used + tokenized_prompt_100 = processor.apply_chat_template( + batch_messages, + add_generation_prompt=True, + tokenize=True, + padding="max_length", + truncation=True, + return_tensors=return_tensors, + max_length=100, + ) + self.assertEqual(len(tokenized_prompt_100[0]), 100) + + # Test that `return_dict=True` returns text related inputs in the dict + out_dict_text = processor.apply_chat_template( + batch_messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors=return_tensors, + ) + self.assertTrue(all(key in out_dict_text for key in ["input_ids", "attention_mask"])) + self.assertEqual(len(out_dict_text["input_ids"]), batch_size) + self.assertEqual(len(out_dict_text["attention_mask"]), batch_size) + + # Test that with modality URLs and `return_dict=True`, we get modality inputs in the dict + for idx, url in enumerate(input_data[:batch_size]): + batch_messages[idx][0]["content"] = [batch_messages[idx][0]["content"][0], {"type": modality, "url": url}] + + out_dict = 
processor.apply_chat_template( + batch_messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors=return_tensors, + fps=2 + if isinstance(input_data[0], str) + else None, # by default no more than 2 frames per second, otherwise too slow + ) + input_name = getattr(self, input_name) + self.assertTrue(input_name in out_dict) + self.assertEqual(len(out_dict["input_ids"]), batch_size) + self.assertEqual(len(out_dict["attention_mask"]), batch_size) + + mm_len = batch_size * 4 + self.assertEqual(len(out_dict[input_name]), mm_len) + + return_tensor_to_type = {"pt": torch.Tensor, "np": np.ndarray, None: list} + for k in out_dict: + self.assertIsInstance(out_dict[k], return_tensor_to_type[return_tensors]) + + def test_model_input_names(self): + processor = self.get_processor() + + text = self.prepare_text_inputs(modalities=["image"]) + image_input = self.prepare_image_inputs() + inputs_dict = {"text": text, "images": image_input} + inputs = processor(**inputs_dict, return_tensors="pt", do_sample_frames=False) + + self.assertSetEqual(set(inputs.keys()), set(processor.model_input_names)) diff --git a/utils/check_repo.py b/utils/check_repo.py index 7e0819a245c4..699fbf24d0f7 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -100,6 +100,7 @@ "Phi4MultimodalVisionModel", "Glm4vVisionModel", "Glm4vMoeVisionModel", + "GlmImageVisionModel", "EvollaSaProtPreTrainedModel", "BltLocalEncoder", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. "BltLocalDecoder", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM. @@ -125,6 +126,7 @@ "ErnieMForInformationExtraction", "FastSpeech2ConformerHifiGan", # Already tested by SpeechT5HifiGan (# Copied from) "FastSpeech2ConformerWithHifiGan", # Built with two smaller (tested) models. + "GlmImageVQVAE", # Building part of bigger (tested) model. "GraphormerDecoderHead", # Building part of bigger (tested) model. "JukeboxVQVAE", # Building part of bigger (tested) model. "JukeboxPrior", # Building part of bigger (tested) model. @@ -192,6 +194,7 @@ "Emu3TextModel", # Building part of bigger (tested) model "Glm4vTextModel", # Building part of bigger (tested) model "Glm4vMoeTextModel", # Building part of bigger (tested) model + "GlmImageTextModel", # Building part of bigger (tested) model "Qwen2VLTextModel", # Building part of bigger (tested) model "Qwen2_5_VLTextModel", # Building part of bigger (tested) model "InternVLVisionModel", # Building part of bigger (tested) model @@ -310,6 +313,7 @@ "FlavaTextModel", "FlavaImageModel", "FlavaMultimodalModel", + "GlmImageForConditionalGeneration", "GPT2DoubleHeadsModel", "GPTSw3DoubleHeadsModel", "InstructBlipVisionModel", diff --git a/utils/models_to_deprecate.py b/utils/models_to_deprecate.py index 24313f65419d..9d6c38cda388 100644 --- a/utils/models_to_deprecate.py +++ b/utils/models_to_deprecate.py @@ -93,6 +93,7 @@ "gemma3n": ["gemma3n_audio", "gemma3n_text", "gemma3n_vision"], "gpt2": ["cpm", "dialogpt", "gpt-sw3", "megatron_gpt2"], "glm4v_moe": ["glm4v_moe_text", "glm4v_moe_vision"], + "glm4_image": ["glm4_image_text", "glm4_image_vision"], "glm4v": ["glm4v_text", "glm4v_vision"], "idefics3": ["idefics3_vision"], "internvl": ["internvl_vision"],