diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index b85b710bf4a1..943496db6cc3 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -525,6 +525,8 @@
title: glm4
- local: model_doc/glm4_moe
title: glm4_moe
+ - local: model_doc/glm_image
+ title: GlmImage
- local: model_doc/openai-gpt
title: GPT
- local: model_doc/gpt_neo
diff --git a/docs/source/en/model_doc/glm46v.md b/docs/source/en/model_doc/glm46v.md
index ab62530a438a..bc5cbdc4ee43 100644
--- a/docs/source/en/model_doc/glm46v.md
+++ b/docs/source/en/model_doc/glm46v.md
@@ -1,22 +1,55 @@
-
-*This model was released on {release_date} and added to Hugging Face Transformers on 2025-11-15.*
+*This model was released on 2025-12-09 and added to Hugging Face Transformers on 2025-11-15.*
# GLM-4.6V
+## Overview
+
+The GLM-V model was proposed in [GLM-4.5V and GLM-4.1V-Thinking: Towards Versatile Multimodal Reasoning with Scalable Reinforcement Learning](https://huggingface.co/papers/2507.01006v6).
+
+The abstract from the paper is the following:
+
+> *We present GLM-4.1V-Thinking, GLM-4.5V, and GLM-4.6V, a family of vision-language models (VLMs) designed to advance
+general-purpose multimodal understanding and reasoning. In this report, we share our key findings in the development of
+the reasoning-centric training framework. We first develop a capable vision foundation model with significant potential
+through large-scale pre-training, which arguably sets the upper bound for the final performance. We then propose
+Reinforcement Learning with Curriculum Sampling (RLCS) to unlock the full potential of the model, leading to
+comprehensive capability enhancement across a diverse range of tasks, including STEM problem solving, video
+understanding, content recognition, coding, grounding, GUI-based agents, and long document interpretation. In a
+comprehensive evaluation across 42 public benchmarks, GLM-4.5V achieves state-of-the-art performance on nearly all tasks
+among open-source models of similar size, and demonstrates competitive or even superior results compared to
+closed-source models such as Gemini-2.5-Flash on challenging tasks including Coding and GUI Agents. Meanwhile, the
+smaller GLM-4.1V-9B-Thinking remains highly competitive-achieving superior results to the much larger Qwen2.5-VL-72B on
+29 benchmarks. We open-source both GLM-4.1V-9B-Thinking and GLM-4.5V. We further introduce the GLM-4.6V series,
+open-source multimodal models with native tool use and a 128K context window. A brief overview is available at this
+https URL. Code, models and more information are released at https://github.com/zai-org/GLM-V*
+
+## Supported models
+
+The following zai-org checkpoints are supported:
+
++ [GLM-4.6V-Flash](https://huggingface.co/zai-org/GLM-4.6V-Flash)
++ [GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)
+
+This model was contributed by [Raushan Turganbay](https://huggingface.co/RaushanTurganbay) and [Yuxuan Zhang](https://huggingface.co/ZHANGYUXUAN-zR).
+
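+## Usage
+
+The snippet below is a minimal sketch of image-text-to-text generation with the [`Pipeline`] API, mirroring the GLM-4.1V example. It assumes the GLM-4.6V checkpoints listed above expose the same interface; the checkpoint name and generation settings are illustrative.
+
+```py
+from transformers import pipeline
+
+pipe = pipeline(
+    task="image-text-to-text",
+    model="zai-org/GLM-4.6V-Flash",  # illustrative checkpoint, see the list above
+    device_map="auto",
+)
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {
+                "type": "image",
+                "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
+            },
+            {"type": "text", "text": "Describe this image."},
+        ],
+    }
+]
+pipe(text=messages, max_new_tokens=128, return_full_text=False)
+```
+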
## Glm46VConfig
[[autodoc]] Glm46VConfig
diff --git a/docs/source/en/model_doc/glm4v.md b/docs/source/en/model_doc/glm4v.md
index 38e43ab5c5c8..206287f9d576 100644
--- a/docs/source/en/model_doc/glm4v.md
+++ b/docs/source/en/model_doc/glm4v.md
@@ -1,49 +1,61 @@
-
*This model was released on 2025-07-01 and added to Hugging Face Transformers on 2025-06-25.*
-
-
-# GLM-4.1V
+# GLM-V
## Overview
-**GLM-4.1V-9B-Thinking** is a bilingual vision-language model optimized for reasoning, built on GLM-4-9B. It introduces
-a "thinking paradigm" with reinforcement learning, achieving state-of-the-art results among 10B-class models and
-rivaling 72B-scale models. It supports 64k context, 4K resolution, and arbitrary aspect ratios, with an open-source base
-model for further research. You can check our paper [here](https://huggingface.co/papers/2507.01006). and below is a abstract.
-
-*We present GLM-4.1V-Thinking, a vision-language model (VLM) designed to advance general-purpose multimodal understanding
-and reasoning. In this report, we share our key findings in the development of the reasoning-centric training framework.
-We first develop a capable vision foundation model with significant potential through large-scale pre-training, which
-arguably sets the upper bound for the final performance. We then propose Reinforcement Learning with Curriculum
-Sampling (RLCS) to unlock the full potential of the model, leading to comprehensive capability enhancement across a
-diverse range of tasks, including STEM problem solving, video understanding, content recognition, coding, grounding,
-GUI-based agents, and long document understanding. We open-source GLM-4.1V-9B-Thinking, which achieves state-of-the-art
-performance among models of comparable size. In a comprehensive evaluation across 28 public benchmarks, our model
-outperforms Qwen2.5-VL-7B on nearly all tasks and achieves comparable or even superior performance on 18 benchmarks
-relative to the significantly larger Qwen2.5-VL-72B. Notably, GLM-4.1V-9B-Thinking also demonstrates competitive or
-superior performance compared to closed-source models such as GPT-4o on challenging tasks including long document
-understanding and STEM reasoning, further underscoring its strong capabilities. Code, models and more information
-are released at https://github.com/THUDM/GLM-4.1V-Thinking.*
+The GLM-V model was proposed in [GLM-4.5V and GLM-4.1V-Thinking: Towards Versatile Multimodal Reasoning with Scalable Reinforcement Learning](https://huggingface.co/papers/2507.01006v6).
+
+The abstract from the paper is the following:
+
+> *We present GLM-4.1V-Thinking, GLM-4.5V, and GLM-4.6V, a family of vision-language models (VLMs) designed to advance
+general-purpose multimodal understanding and reasoning. In this report, we share our key findings in the development of
+the reasoning-centric training framework. We first develop a capable vision foundation model with significant potential
+through large-scale pre-training, which arguably sets the upper bound for the final performance. We then propose
+Reinforcement Learning with Curriculum Sampling (RLCS) to unlock the full potential of the model, leading to
+comprehensive capability enhancement across a diverse range of tasks, including STEM problem solving, video
+understanding, content recognition, coding, grounding, GUI-based agents, and long document interpretation. In a
+comprehensive evaluation across 42 public benchmarks, GLM-4.5V achieves state-of-the-art performance on nearly all tasks
+among open-source models of similar size, and demonstrates competitive or even superior results compared to
+closed-source models such as Gemini-2.5-Flash on challenging tasks including Coding and GUI Agents. Meanwhile, the
+smaller GLM-4.1V-9B-Thinking remains highly competitive-achieving superior results to the much larger Qwen2.5-VL-72B on
+29 benchmarks. We open-source both GLM-4.1V-9B-Thinking and GLM-4.5V. We further introduce the GLM-4.6V series,
+open-source multimodal models with native tool use and a 128K context window. A brief overview is available at this
+https URL. Code, models and more information are released at https://github.com/zai-org/GLM-V*
+
+## Supported models
+
+The following zai-org checkpoints are supported by this model type:
+
++ [GLM-4.1V-9B-Base](https://huggingface.co/zai-org/GLM-4.1V-9B-Base)
++ [GLM-4.1V-9B-Thinking](https://huggingface.co/zai-org/GLM-4.1V-9B-Thinking)
++ [GLM-4.6V-Flash](https://huggingface.co/zai-org/GLM-4.6V-Flash)
++ [AutoGLM-Phone-9B](https://huggingface.co/zai-org/AutoGLM-Phone-9B)
++ [AutoGLM-Phone-9B-Multilingual](https://huggingface.co/zai-org/AutoGLM-Phone-9B-Multilingual)
++ [Glyph](https://huggingface.co/zai-org/Glyph)
++ [WebVIA-Agent](https://huggingface.co/zai-org/WebVIA-Agent)
++ [UI2Code_N](https://huggingface.co/zai-org/UI2Code_N)
+
+This model was contributed by [Raushan Turganbay](https://huggingface.co/RaushanTurganbay)
+and [Yuxuan Zhang](https://huggingface.co/ZHANGYUXUAN-zR).
## Usage
@@ -55,6 +67,7 @@ The example below demonstrates how to generate text based on an image with [`Pip
```py
import torch
from transformers import pipeline
+
pipe = pipeline(
task="image-text-to-text",
model="THUDM/GLM-4.1V-9B-Thinking",
@@ -69,11 +82,11 @@ messages = [
"type": "image",
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
},
- { "type": "text", "text": "Describe this image."},
+ {"type": "text", "text": "Describe this image."},
]
}
]
-pipe(text=messages,max_new_tokens=20, return_full_text=False)
+pipe(text=messages, max_new_tokens=20, return_full_text=False)
```
@@ -92,15 +105,15 @@ model = Glm4vForConditionalGeneration.from_pretrained(
processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking")
messages = [
{
- "role":"user",
- "content":[
+ "role": "user",
+ "content": [
{
- "type":"image",
+ "type": "image",
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
},
{
- "type":"text",
- "text":"Describe this image."
+ "type": "text",
+ "text": "Describe this image."
}
]
}
@@ -117,10 +130,10 @@ inputs = processor.apply_chat_template(
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
- out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```
@@ -160,9 +173,10 @@ messages = [
],
}
]
-inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt", padding=True).to(model.device)
+inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True,
+ return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=1.0)
-output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True)
+output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(output_text)
```
@@ -181,17 +195,17 @@ print(output_text)
## Glm4vImageProcessor
[[autodoc]] Glm4vImageProcessor
- - preprocess
+- preprocess
## Glm4vVideoProcessor
[[autodoc]] Glm4vVideoProcessor
- - preprocess
+- preprocess
## Glm4vImageProcessorFast
[[autodoc]] Glm4vImageProcessorFast
- - preprocess
+- preprocess
## Glm4vProcessor
@@ -201,19 +215,19 @@ print(output_text)
## Glm4vVisionModel
[[autodoc]] Glm4vVisionModel
- - forward
+- forward
## Glm4vTextModel
[[autodoc]] Glm4vTextModel
- - forward
+- forward
## Glm4vModel
[[autodoc]] Glm4vModel
- - forward
+- forward
## Glm4vForConditionalGeneration
[[autodoc]] Glm4vForConditionalGeneration
- - forward
+- forward
diff --git a/docs/source/en/model_doc/glm4v_moe.md b/docs/source/en/model_doc/glm4v_moe.md
index ffb6a3d85cb2..a67906ea001e 100644
--- a/docs/source/en/model_doc/glm4v_moe.md
+++ b/docs/source/en/model_doc/glm4v_moe.md
@@ -1,48 +1,54 @@
-
-*This model was released on 2025-07-28 and added to Hugging Face Transformers on 2025-08-08.*
+⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer.
-
+-->
+*This model was released on 2025-08-12 and added to Hugging Face Transformers on 2025-08-08.*
-# Glm4vMoeMoe
+# Glm4vMoe
## Overview
-Vision-language models (VLMs) have become a key cornerstone of intelligent systems. As real-world AI tasks grow increasingly complex, VLMs urgently need to enhance reasoning capabilities beyond basic multimodal perception — improving accuracy, comprehensiveness, and intelligence — to enable complex problem solving, long-context understanding, and multimodal agents.
+The GLM-V model was proposed in [GLM-4.5V and GLM-4.1V-Thinking: Towards Versatile Multimodal Reasoning with Scalable Reinforcement Learning](https://huggingface.co/papers/2507.01006v6).
-Through our open-source work, we aim to explore the technological frontier together with the community while empowering more developers to create exciting and innovative applications.
+The abstract from the paper is the following:
-[GLM-4.5V](https://huggingface.co/papers/2508.06471) ([Github repo](https://github.com/zai-org/GLM-V)) is based on ZhipuAI’s next-generation flagship text foundation model GLM-4.5-Air (106B parameters, 12B active). It continues the technical approach of [GLM-4.1V-Thinking](https://huggingface.co/papers/2507.01006), achieving SOTA performance among models of the same scale on 42 public vision-language benchmarks. It covers common tasks such as image, video, and document understanding, as well as GUI agent operations.
+> *We present GLM-4.1V-Thinking, GLM-4.5V, and GLM-4.6V, a family of vision-language models (VLMs) designed to advance
+general-purpose multimodal understanding and reasoning. In this report, we share our key findings in the development of
+the reasoning-centric training framework. We first develop a capable vision foundation model with significant potential
+through large-scale pre-training, which arguably sets the upper bound for the final performance. We then propose
+Reinforcement Learning with Curriculum Sampling (RLCS) to unlock the full potential of the model, leading to
+comprehensive capability enhancement across a diverse range of tasks, including STEM problem solving, video
+understanding, content recognition, coding, grounding, GUI-based agents, and long document interpretation. In a
+comprehensive evaluation across 42 public benchmarks, GLM-4.5V achieves state-of-the-art performance on nearly all tasks
+among open-source models of similar size, and demonstrates competitive or even superior results compared to
+closed-source models such as Gemini-2.5-Flash on challenging tasks including Coding and GUI Agents. Meanwhile, the
+smaller GLM-4.1V-9B-Thinking remains highly competitive-achieving superior results to the much larger Qwen2.5-VL-72B on
+29 benchmarks. We open-source both GLM-4.1V-9B-Thinking and GLM-4.5V. We further introduce the GLM-4.6V series,
+open-source multimodal models with native tool use and a 128K context window. A brief overview is available at this
+https URL. Code, models and more information are released at https://github.com/zai-org/GLM-V*
-
+## Supported models
-Beyond benchmark performance, GLM-4.5V focuses on real-world usability. Through efficient hybrid training, it can handle diverse types of visual content, enabling full-spectrum vision reasoning, including:
+The following zai-org checkpoints are supported by this model type:
-- **Image reasoning** (scene understanding, complex multi-image analysis, spatial recognition)
-- **Video understanding** (long video segmentation and event recognition)
-- **GUI tasks** (screen reading, icon recognition, desktop operation assistance)
-- **Complex chart & long document parsing** (research report analysis, information extraction)
-- **Grounding** (precise visual element localization)
++ [GLM-4.5V](https://huggingface.co/zai-org/GLM-4.5V)
++ [GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)
-The model also introduces a **Thinking Mode** switch, allowing users to balance between quick responses and deep reasoning. This switch works the same as in the `GLM-4.5` language model.
+This model was contributed by [Raushan Turganbay](https://huggingface.co/RaushanTurganbay) and [Yuxuan Zhang](https://huggingface.co/ZHANGYUXUAN-zR).
## Glm4vMoeConfig
diff --git a/docs/source/en/model_doc/glm_image.md b/docs/source/en/model_doc/glm_image.md
new file mode 100644
index 000000000000..4b4ff609d2dd
--- /dev/null
+++ b/docs/source/en/model_doc/glm_image.md
@@ -0,0 +1,206 @@
+
+*This model was released on 2026-01-10 and added to Hugging Face Transformers on 2026-01-10.*
+
+# GlmImage
+
+## Overview
+
+GLM-Image is an image generation model that adopts a hybrid autoregressive + diffusion decoder architecture, effectively pushing the upper bound of visual fidelity and fine-grained detail. In general image generation quality it is on par with industry-standard LDM-based approaches, while demonstrating significant advantages in knowledge-intensive image generation scenarios.
+
+The hybrid architecture consists of two components:
+
++ Autoregressive generator: a 9B-parameter model initialized from [GLM-4-9B-0414](https://huggingface.co/zai-org/GLM-4-9B-0414), with an expanded vocabulary that incorporates visual tokens. The model first generates a compact encoding of approximately 256 tokens, then expands it to 1K–4K tokens, corresponding to high-resolution outputs in the 1K–2K range.
++ Diffusion decoder: a 7B-parameter decoder based on a single-stream DiT architecture for latent-space image decoding. It is equipped with a Glyph Encoder text module, significantly improving the accuracy of text rendered within images.
+
+Post-training with decoupled reinforcement learning: the model introduces a fine-grained, modular feedback strategy using the GRPO algorithm, substantially enhancing both semantic understanding and visual detail quality.
+
++ Autoregressive module: provides low-frequency feedback signals focused on aesthetics and semantic alignment, improving instruction following and artistic expressiveness.
++ Decoder module: delivers high-frequency feedback targeting detail fidelity and text accuracy, resulting in highly realistic textures, lighting, and color reproduction, as well as more precise text rendering.
+
+GLM-Image supports both text-to-image and image-to-image generation within a single model:
+
++ Text-to-image: generates high-detail images from textual descriptions, with particularly strong performance in information-dense scenarios.
++ Image-to-image: supports a wide range of tasks, including image editing, style transfer, multi-subject consistency, and identity-preserving generation for people and objects.
+
++ `GlmImageForConditionalGeneration` is the autoregressive part of GLM-Image; for the full image generation pipeline, please refer to the [diffusers GLM-Image pipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/glm_image).
+
+This model was contributed by [Raushan Turganbay](https://huggingface.co/RaushanTurganbay) and [Yuxuan Zhang](https://huggingface.co/ZHANGYUXUAN-zR).
+
+## Usage examples
+
+The example below uses GLM-Image to generate the vision tokens of a target image, either from text alone or conditioned on reference images; these tokens are then decoded into pixels by the DiT diffusion decoder.
+
+```python
+from transformers import GlmImageForConditionalGeneration, AutoProcessor
+import torch
+
+model = GlmImageForConditionalGeneration.from_pretrained(
+ pretrained_model_name_or_path="zai-org/GLM-Image/vision_language_encoder",
+ dtype=torch.bfloat16,
+ device_map="cuda:0"
+)
+processor = AutoProcessor.from_pretrained(
+ pretrained_model_name_or_path="zai-org/GLM-Image/processor",
+ use_fast=True
+)
+
+# Case 1: text-to-image
+prompt = "现代美食杂志风格的甜点制作教程图,主题为覆盆子慕斯蛋糕。整体布局干净明亮,分为四个主要区域:顶部左侧是黑色粗体标题“覆盆子慕斯蛋糕制作指南”,右侧搭配光线柔和的成品蛋糕特写照片,蛋糕呈淡粉色,表面点缀新鲜覆盆子与薄荷叶;左下方为配料清单区域,标题“配料”使用简洁字体,下方列有“面粉 150g”“鸡蛋 3个”“细砂糖 120g”“覆盆子果泥 200g”“明胶片 10g”“淡奶油 300ml”“新鲜覆盆子”等配料,每种配料旁配有简约线图标(如面粉袋、鸡蛋、糖罐等);右下方是四个等大的步骤方框,每个方框内含高清微距实拍图及对应操作说明,从上到下依次为:步骤1展示打蛋器打发白色泡沫(对应说明“打发蛋白至干性发泡”),步骤2展示红白相间的混合物被刮刀翻拌(对应说明“轻柔翻拌果泥与面糊”),步骤3展示粉色液体被倒入圆形模具(对应说明“倒入模具并冷藏4小时”),步骤4展示成品蛋糕表面装饰覆盆子与薄荷叶(对应说明“用覆盆子和薄荷装饰”);底部边缘设浅棕色信息条,左侧图标分别代表“准备时间:30分钟”“烹饪时间:20分钟”“份量:8人份”。整体色调以奶油白、淡粉色为主,背景带轻微纸质纹理,图文排版紧凑有序,信息层级分明。"
+target_h, target_w = 1152, 768
+use_reference_images = False
+reference_image_paths = None
+
+# Case 2: image editing with a single reference image (uncomment to use)
+# prompt = "Replace the background of the snow forest with an underground station featuring an automatic escalator."
+# cond_0 = "cond.jpg"
+# target_h, target_w = 1152, 768
+# use_reference_images = True
+# reference_image_paths = [cond_0]
+
+# Case 3: image editing with multiple reference images (uncomment to use)
+# prompt = "Make the man in the first figure and the child from the second image bow at the same time in a respectful KTV."
+# cond_0 = "cond_0.jpg"
+# cond_1 = "cond_1.jpg"
+# target_h, target_w = 1152, 768
+# use_reference_images = True
+# reference_image_paths = [cond_0, cond_1]
+
+
+def build_messages(prompt, use_reference_images, reference_image_paths):
+ content = []
+ if use_reference_images:
+ for img_path in reference_image_paths:
+ content.append({"type": "image", "url": img_path})
+ content.append({"type": "text", "text": prompt})
+ return [{"role": "user", "content": content}]
+
+
+def compute_generation_params(image_grid_thw, use_reference_images):
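+ # Use the per-image patch-grid sizes from image_grid_thw to decide how many tokens
+ # to generate and which slice of the generated sequence corresponds to the target image.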
+ grid_sizes = []
+ for i in range(image_grid_thw.shape[0]):
+ t, h, w = image_grid_thw[i].tolist()
+ grid_sizes.append(int(h * w))
+
+ target_output_length = grid_sizes[0]
+
+ if use_reference_images:
+ max_new_tokens = grid_sizes[-1] + 1
+ output_start_offset = 0
+ output_length = grid_sizes[-1]
+ else:
+ total_tokens = sum(grid_sizes)
+ max_new_tokens = total_tokens + 1
+ output_start_offset = sum(grid_sizes[1:])
+ output_length = target_output_length
+
+ return max_new_tokens, output_start_offset, output_length
+
+
+messages = build_messages(prompt, use_reference_images, reference_image_paths if use_reference_images else None)
+
+inputs = processor.apply_chat_template(
+ messages,
+ target_h=target_h,
+ target_w=target_w,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+).to(model.device)
+
+image_grid_thw = inputs.get("image_grid_thw")
+print(f"image_grid_thw: {image_grid_thw}")
+
+max_new_tokens, output_start_offset, output_length = compute_generation_params(
+ image_grid_thw, use_reference_images
+)
+
+print(f"use_reference_images: {use_reference_images}")
+print(f"max_new_tokens: {max_new_tokens}")
+print(f"output_start_offset: {output_start_offset}")
+print(f"output_length: {output_length}")
+
+outputs = model.generate(
+ **inputs,
+ max_new_tokens=max_new_tokens,
+ do_sample=True
+)
+
+input_length = inputs["input_ids"].shape[-1]
+output_tokens = outputs[0][input_length:][output_start_offset:output_start_offset + output_length]
+print(f"Input length: {input_length}")
+print(f"Total generated tokens: {outputs[0].shape[-1] - input_length}")
+print(f"Extracted output tokens shape: {output_tokens.shape}")
+print(f"Output tokens: {output_tokens}")
+```
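+
+The extracted `output_tokens` are the visual tokens of the target image. Decoding them into pixels is handled by the diffusion decoder in the [diffusers GLM-Image pipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/glm_image) mentioned above.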
+
+## GlmImageConfig
+
+[[autodoc]] GlmImageConfig
+
+## GlmImageVisionConfig
+
+[[autodoc]] GlmImageVisionConfig
+
+## GlmImageTextConfig
+
+[[autodoc]] GlmImageTextConfig
+
+## GlmImageVQVAEConfig
+
+[[autodoc]] GlmImageVQVAEConfig
+
+## GlmImageImageProcessor
+
+[[autodoc]] GlmImageImageProcessor
+ - preprocess
+
+## GlmImageImageProcessorFast
+
+[[autodoc]] GlmImageImageProcessorFast
+ - preprocess
+
+## GlmImageProcessor
+
+[[autodoc]] GlmImageProcessor
+
+## GlmImageVisionModel
+
+[[autodoc]] GlmImageVisionModel
+ - forward
+
+## GlmImageTextModel
+
+[[autodoc]] GlmImageTextModel
+ - forward
+
+## GlmImageVQVAE
+
+[[autodoc]] GlmImageVQVAE
+ - forward
+
+## GlmImageModel
+
+[[autodoc]] GlmImageModel
+ - forward
+
+## GlmImageForConditionalGeneration
+
+[[autodoc]] GlmImageForConditionalGeneration
+ - forward
diff --git a/docs/source/en/model_doc/glmasr.md b/docs/source/en/model_doc/glmasr.md
index a895b778bedb..d04be02a5f33 100644
--- a/docs/source/en/model_doc/glmasr.md
+++ b/docs/source/en/model_doc/glmasr.md
@@ -16,7 +16,7 @@ limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer.
-->
-*This model was released on {release_date} and added to Hugging Face Transformers on 2025-12-24.*
+*This model was released on 2025-12-08 and added to Hugging Face Transformers on 2025-12-24.*
# GlmAsr
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index edc2cfb32f94..0cd0b26035ba 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -155,6 +155,7 @@
from .glm4v import *
from .glm4v_moe import *
from .glm46v import *
+ from .glm_image import *
from .glmasr import *
from .glpn import *
from .got_ocr2 import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 51ac332febbb..e85836dbe426 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -180,6 +180,10 @@
("glm4v_moe_vision", "Glm4vMoeVisionConfig"),
("glm4v_text", "Glm4vTextConfig"),
("glm4v_vision", "Glm4vVisionConfig"),
+ ("glm_image", "GlmImageConfig"),
+ ("glm_image_text", "GlmImageTextConfig"),
+ ("glm_image_vision", "GlmImageVisionConfig"),
+ ("glm_image_vqmodel", "GlmImageVQVAEConfig"),
("glmasr", "GlmAsrConfig"),
("glmasr_encoder", "GlmAsrEncoderConfig"),
("glpn", "GLPNConfig"),
@@ -638,6 +642,10 @@
("glm4v_moe_vision", "Glm4vMoeVisionModel"),
("glm4v_text", "GLM4V"),
("glm4v_vision", "Glm4vVisionModel"),
+ ("glm_image", "GlmImage"),
+ ("glm_image_text", "GlmImageText"),
+ ("glm_image_vision", "GlmImageVisionModel"),
+ ("glm_image_vqmodel", "GlmImageVQVAE"),
("glmasr", "GLM-ASR"),
("glmasr_encoder", "GLM-ASR Encoder"),
("glpn", "GLPN"),
@@ -984,6 +992,9 @@
("glm4v_moe_vision", "glm4v_moe"),
("glm4v_text", "glm4v"),
("glm4v_moe_text", "glm4v_moe"),
+ ("glm_image_vision", "glm_image"),
+ ("glm_image_vqmodel", "glm_image"),
+ ("glm_image_text", "glm_image"),
("glmasr_encoder", "glmasr"),
("grounding-dino", "grounding_dino"),
("mm-grounding-dino", "mm_grounding_dino"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 5379386cc1df..8c5663dff3fa 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -109,6 +109,7 @@
("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("glm46v", ("Glm46VImageProcessor", "Glm46VImageProcessorFast")),
("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")),
+ ("glm_image", ("GlmImageImageProcessor", "GlmImageImageProcessorFast")),
("glpn", ("GLPNImageProcessor", "GLPNImageProcessorFast")),
("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
("grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 583fbc8c1d3f..13eabcdc3405 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -183,6 +183,10 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("glm4v_moe_vision", "Glm4vMoeVisionModel"),
("glm4v_text", "Glm4vTextModel"),
("glm4v_vision", "Glm4vVisionModel"),
+ ("glm_image", "GlmImageModel"),
+ ("glm_image_text", "GlmImageTextModel"),
+ ("glm_image_vision", "GlmImageVisionModel"),
+ ("glm_image_vqmodel", "GlmImageVQVAE"),
("glmasr", "GlmAsrForConditionalGeneration"),
("glmasr_encoder", "GlmAsrEncoder"),
("glpn", "GLPNModel"),
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index 8ec4f0254d24..58cef02556b3 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -78,6 +78,7 @@
("glm46v", "Glm46VProcessor"),
("glm4v", "Glm4vProcessor"),
("glm4v_moe", "Glm4vProcessor"),
+ ("glm_image", "Glm4vProcessor"),
("glmasr", "GlmAsrProcessor"),
("got_ocr2", "GotOcr2Processor"),
("granite_speech", "GraniteSpeechProcessor"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index ac7503053ba5..1f2d68cab65b 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -136,6 +136,7 @@
("glm4_moe_lite", "TokenizersBackend" if is_tokenizers_available() else None),
("glm4v", "TokenizersBackend" if is_tokenizers_available() else None),
("glm4v_moe", "TokenizersBackend" if is_tokenizers_available() else None),
+ ("glm_image", "TokenizersBackend" if is_tokenizers_available() else None),
("glmasr", "TokenizersBackend" if is_tokenizers_available() else None),
("got_ocr2", "TokenizersBackend" if is_tokenizers_available() else None),
("gpt-sw3", "GPTSw3Tokenizer" if is_sentencepiece_available() else None),
diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py
index f3eb66c734d0..7388fdca11f0 100644
--- a/src/transformers/models/glm4v/modeling_glm4v.py
+++ b/src/transformers/models/glm4v/modeling_glm4v.py
@@ -143,6 +143,7 @@ def __init__(self, config: Glm4vVisionConfig):
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
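+ # Interpolation mode passed to F.grid_sample when adapting the learned position embeddings to each image grid.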
+ self.interpolated_method = "bicubic"
def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
"""
@@ -161,57 +162,45 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc
# Get position embedding parameters
pos_embed_weight = self.position_embedding.weight
hidden_size = pos_embed_weight.shape[1]
- total_seq = h_coords.shape[0]
device = pos_embed_weight.device
- # Move coordinates to correct device
- h_coords, w_coords = h_coords.to(device), w_coords.to(device)
-
- # Handle empty sequence case
- if total_seq == 0:
- adapted_pos_embed = torch.empty(0, hidden_size, device=device, dtype=pos_embed_weight.dtype)
- else:
- # Convert inputs to tensors if needed
- if isinstance(lengths, list):
- lengths = torch.tensor(lengths, device=device, dtype=torch.long)
- if not isinstance(image_shapes, torch.Tensor):
- image_shapes = torch.tensor(image_shapes, device=device, dtype=torch.long)
-
- # Prepare 2D position embedding
- orig_size_sq = pos_embed_weight.shape[0]
- orig_size = int(orig_size_sq**0.5)
- pos_embed_2d = (
- pos_embed_weight.view(orig_size, orig_size, hidden_size)
- .permute(2, 0, 1)
- .unsqueeze(0)
- .to(device=device, dtype=torch.float32)
- )
+ # Convert inputs to tensors if needed
+ if isinstance(lengths, list):
+ lengths = torch.tensor(lengths, device=device, dtype=torch.long)
+
+ # Prepare 2D position embedding
+ orig_size_sq = pos_embed_weight.shape[0]
+ orig_size = int(orig_size_sq**0.5)
+ pos_embed_2d = (
+ pos_embed_weight.view(orig_size, orig_size, hidden_size)
+ .permute(2, 0, 1)
+ .unsqueeze(0)
+ .to(device=device, dtype=torch.float32)
+ )
- # Calculate target dimensions for each patch
- target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to(
- device=device, dtype=torch.float32
- )
- target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to(
- device=device, dtype=torch.float32
- )
+ # Calculate target dimensions for each patch
+ target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to(
+ device=device, dtype=torch.float32
+ )
+ target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to(
+ device=device, dtype=torch.float32
+ )
- # Normalize coordinates to [-1, 1] range for grid_sample
- h_coords = h_coords.to(device=device, dtype=torch.float32)
- w_coords = w_coords.to(device=device, dtype=torch.float32)
- norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
- norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
+ # Normalize coordinates to [-1, 1] range for grid_sample
+ norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
+ norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
- # Create sampling grid
- grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
+ # Create sampling grid
+ grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
- # Perform bicubic interpolation
- interpolated_embed_fp32 = F.grid_sample(
- pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border"
- )
+ # Perform bicubic interpolation
+ interpolated_embed_fp32 = F.grid_sample(
+ pos_embed_2d, grid, mode=self.interpolated_method, align_corners=False, padding_mode="border"
+ )
- # Reshape and convert back to original dtype
- adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
- adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)
+ # Reshape and convert back to original dtype
+ adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
+ adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)
# Add adapted position encoding to embeddings
embeddings = embeddings + adapted_pos_embed
@@ -405,6 +394,7 @@ def __init__(self, config: Glm4vTextConfig, device=None):
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+ self.mrope_section = config.rope_parameters.get("mrope_section", [8, 12, 12])
@staticmethod
def compute_default_rope_parameters(
@@ -441,7 +431,7 @@ def compute_default_rope_parameters(
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
- # In contrast to other models, GLM4V different position ids for the grids
+ # In contrast to other models, GLM-V has different position ids for the grids
# So we expand the inv_freq to shape (3, ...)
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
@@ -449,12 +439,19 @@ def forward(self, x, position_ids):
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ freqs = self.apply_mrope(freqs, self.mrope_section)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+ def apply_mrope(self, freqs, mrope_section):
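+ # `freqs` is indexed as (t/h/w position stream, batch, seq_len, rotary channels). Split the
+ # channel dim into `mrope_section` chunks and keep the temporal, height or width stream
+ # for each chunk in turn, producing a single interleaved frequency table.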
+ section = mrope_section
+ chunks = freqs.split(section, dim=-1)
+ result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
+ return result
+
def rotate_half_llm(x):
"""Rotates half the hidden dims of the input."""
@@ -463,25 +460,16 @@ def rotate_half_llm(x):
return torch.stack((-x2, x1), dim=-1).flatten(-2)
-def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
- """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
-
- Explanation:
- Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
- sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
- vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
- Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
- For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
- height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
- difference with modern LLMs.
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
- mrope_section(`List(int)`):
- Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -492,13 +480,8 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
- mrope_section = mrope_section * 2
- cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
- unsqueeze_dim
- )
- sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
- unsqueeze_dim
- )
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
# Interleave them instead of usual shape
cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
@@ -516,7 +499,6 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
# Concatenate back to full shape
q_embed = torch.cat([q_embed, q_pass], dim=-1)
k_embed = torch.cat([k_embed, k_pass], dim=-1)
-
return q_embed, k_embed
@@ -566,9 +548,7 @@ def forward(
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
- query_states, key_states = apply_multimodal_rotary_pos_emb( # diff with Llama
- query_states, key_states, cos, sin, self.rope_parameters["mrope_section"]
- )
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_values is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
@@ -788,7 +768,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
"""
hidden_states = self.patch_embed(hidden_states)
hidden_states = self.post_conv_layernorm(hidden_states)
-
rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
@@ -803,7 +782,13 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
+ hidden_states = self.embeddings(
+ hidden_states,
+ seqlens,
+ grid_thw,
+ image_type_ids[:, 0].to(hidden_states.device),
+ image_type_ids[:, 1].to(hidden_states.device),
+ )
for blk in self.blocks:
hidden_states = blk(
@@ -915,7 +900,7 @@ def forward(
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
- position_ids=position_ids,
+ position_ids=text_position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py
index 61edcdc7a390..77946b1310f9 100644
--- a/src/transformers/models/glm4v/modular_glm4v.py
+++ b/src/transformers/models/glm4v/modular_glm4v.py
@@ -408,6 +408,7 @@ def __init__(self, config: Glm4vVisionConfig):
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
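+ # Interpolation mode passed to F.grid_sample when adapting the learned position embeddings to each image grid.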
+ self.interpolated_method = "bicubic"
def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
"""
@@ -426,57 +427,45 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc
# Get position embedding parameters
pos_embed_weight = self.position_embedding.weight
hidden_size = pos_embed_weight.shape[1]
- total_seq = h_coords.shape[0]
device = pos_embed_weight.device
- # Move coordinates to correct device
- h_coords, w_coords = h_coords.to(device), w_coords.to(device)
-
- # Handle empty sequence case
- if total_seq == 0:
- adapted_pos_embed = torch.empty(0, hidden_size, device=device, dtype=pos_embed_weight.dtype)
- else:
- # Convert inputs to tensors if needed
- if isinstance(lengths, list):
- lengths = torch.tensor(lengths, device=device, dtype=torch.long)
- if not isinstance(image_shapes, torch.Tensor):
- image_shapes = torch.tensor(image_shapes, device=device, dtype=torch.long)
-
- # Prepare 2D position embedding
- orig_size_sq = pos_embed_weight.shape[0]
- orig_size = int(orig_size_sq**0.5)
- pos_embed_2d = (
- pos_embed_weight.view(orig_size, orig_size, hidden_size)
- .permute(2, 0, 1)
- .unsqueeze(0)
- .to(device=device, dtype=torch.float32)
- )
+ # Convert inputs to tensors if needed
+ if isinstance(lengths, list):
+ lengths = torch.tensor(lengths, device=device, dtype=torch.long)
+
+ # Prepare 2D position embedding
+ orig_size_sq = pos_embed_weight.shape[0]
+ orig_size = int(orig_size_sq**0.5)
+ pos_embed_2d = (
+ pos_embed_weight.view(orig_size, orig_size, hidden_size)
+ .permute(2, 0, 1)
+ .unsqueeze(0)
+ .to(device=device, dtype=torch.float32)
+ )
- # Calculate target dimensions for each patch
- target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to(
- device=device, dtype=torch.float32
- )
- target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to(
- device=device, dtype=torch.float32
- )
+ # Calculate target dimensions for each patch
+ target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to(
+ device=device, dtype=torch.float32
+ )
+ target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to(
+ device=device, dtype=torch.float32
+ )
- # Normalize coordinates to [-1, 1] range for grid_sample
- h_coords = h_coords.to(device=device, dtype=torch.float32)
- w_coords = w_coords.to(device=device, dtype=torch.float32)
- norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
- norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
+ # Normalize coordinates to [-1, 1] range for grid_sample
+ norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
+ norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
- # Create sampling grid
- grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
+ # Create sampling grid
+ grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
- # Perform bicubic interpolation
- interpolated_embed_fp32 = F.grid_sample(
- pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border"
- )
+ # Perform bicubic interpolation
+ interpolated_embed_fp32 = F.grid_sample(
+ pos_embed_2d, grid, mode=self.interpolated_method, align_corners=False, padding_mode="border"
+ )
- # Reshape and convert back to original dtype
- adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
- adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)
+ # Reshape and convert back to original dtype
+ adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
+ adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)
# Add adapted position encoding to embeddings
embeddings = embeddings + adapted_pos_embed
@@ -501,9 +490,12 @@ def __init__(self, config) -> None:
class Glm4vTextRotaryEmbedding(Glm4RotaryEmbedding):
- # Ignore copy
+ def __init__(self, config: Glm4vTextConfig, device=None):
+ super().__init__()
+ self.mrope_section = config.rope_parameters.get("mrope_section", [8, 12, 12])
+
def forward(self, x, position_ids):
- # In contrast to other models, GLM4V different position ids for the grids
+ # In contrast to other models, GLM-V has different position ids for the grids
# So we expand the inv_freq to shape (3, ...)
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
@@ -511,12 +503,19 @@ def forward(self, x, position_ids):
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ freqs = self.apply_mrope(freqs, self.mrope_section)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+ def apply_mrope(self, freqs, mrope_section):
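+ # `freqs` is indexed as (t/h/w position stream, batch, seq_len, rotary channels). Split the
+ # channel dim into `mrope_section` chunks and keep the temporal, height or width stream
+ # for each chunk in turn, producing a single interleaved frequency table.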
+ section = mrope_section
+ chunks = freqs.split(section, dim=-1)
+ result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
+ return result
+
def rotate_half_llm(x):
"""Rotates half the hidden dims of the input."""
@@ -525,25 +524,16 @@ def rotate_half_llm(x):
return torch.stack((-x2, x1), dim=-1).flatten(-2)
-def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
- """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
-
- Explanation:
- Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
- sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
- vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
- Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
- For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
- height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
- difference with modern LLMs.
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
- mrope_section(`List(int)`):
- Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -554,13 +544,8 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
- mrope_section = mrope_section * 2
- cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
- unsqueeze_dim
- )
- sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
- unsqueeze_dim
- )
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
# Interleave them instead of usual shape
cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
@@ -578,7 +563,6 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
# Concatenate back to full shape
q_embed = torch.cat([q_embed, q_pass], dim=-1)
k_embed = torch.cat([k_embed, k_pass], dim=-1)
-
return q_embed, k_embed
@@ -628,9 +612,7 @@ def forward(
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
- query_states, key_states = apply_multimodal_rotary_pos_emb( # diff with Llama
- query_states, key_states, cos, sin, self.rope_parameters["mrope_section"]
- )
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_values is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
@@ -805,7 +787,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
"""
hidden_states = self.patch_embed(hidden_states)
hidden_states = self.post_conv_layernorm(hidden_states)
-
rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
@@ -820,7 +801,13 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
+ hidden_states = self.embeddings(
+ hidden_states,
+ seqlens,
+ grid_thw,
+ image_type_ids[:, 0].to(hidden_states.device),
+ image_type_ids[:, 1].to(hidden_states.device),
+ )
for blk in self.blocks:
hidden_states = blk(
@@ -922,7 +909,7 @@ def forward(
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
- position_ids=position_ids,
+ position_ids=text_position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
diff --git a/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py b/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py
index 29f0ed747f2e..5772ed2957ee 100644
--- a/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py
+++ b/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py
@@ -21,99 +21,6 @@
from ...modeling_rope_utils import RopeParameters
-class Glm4vMoeVisionConfig(PreTrainedConfig):
- r"""
- This is the configuration class to store the configuration of a [`Glm4vMoeVisionModel`]. It is used to instantiate an Glm4vMoeVisionModel
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
- a similar configuration to that of
- GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking).
-
- Args:
- depth (`int`, *optional*, defaults to 24):
- Number of layers (depth) in the model.
- hidden_size (`int`, *optional*, defaults to 1536):
- Dimensionality of the encoder layers and the pooler layer.
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
- The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
- `"relu"`, `"selu"` and `"gelu_new"` are supported.
- attention_bias (`bool`, *optional*, defaults to `False`):
- Whether to add a bias to the queries, keys and values.
- attention_dropout (`float`, *optional*, defaults to 0.0):
- Dropout probability for attention weights.
- num_heads (``, *optional*, defaults to 12):
- in_channels (``, *optional*, defaults to 3):
- image_size (`int` or `list[int]`, *optional*, defaults to 336):
- The size (resolution) of each image.
- patch_size (`int`, *optional*, defaults to 14):
- The size (resolution) of each patch.
- rms_norm_eps (`float`, *optional*, defaults to 1e-05):
- The epsilon used by the rms normalization layers.
- spatial_merge_size (`int`, *optional*, defaults to 2):
- The size used for merging spatial dimensions.
- temporal_patch_size (`int`, *optional*, defaults to 2):
- The size used for patches along the temporal dimension.
- out_hidden_size (`int`, *optional*, defaults to 4096):
- The output hidden size of the vision model.
- intermediate_size (`int`, *optional*, defaults to 13696):
- Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
- initializer_range (`float`, *optional*, defaults to 0.02):
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- Example:
-
- ```python
- >>> from transformers import Glm4vMoeVisionConfig, Glm4vMoeVisionModel
-
- >>> # Initializing a Glm4vMoeVisionConfig GLM-4.1V-9B style configuration
- >>> configuration = Glm4vMoeVisionConfig()
-
- >>> # Initializing a model (with random weights) from the GLM-4.1V-9B configuration
- >>> model = Glm4vMoeVisionModel(configuration)
-
- >>> # Accessing the model configuration
- >>> configuration = model.config
- ```"""
-
- model_type = "glm4v_moe_vision"
- base_config_key = "vision_config"
-
- def __init__(
- self,
- depth=24,
- hidden_size=1536,
- hidden_act="silu",
- attention_bias=False,
- attention_dropout=0.0,
- num_heads=12,
- in_channels=3,
- image_size=336,
- patch_size=14,
- rms_norm_eps=1e-05,
- spatial_merge_size=2,
- temporal_patch_size=2,
- out_hidden_size=4096,
- intermediate_size=13696,
- initializer_range=0.02,
- **kwargs,
- ):
- super().__init__(**kwargs)
-
- self.depth = depth
- self.hidden_size = hidden_size
- self.hidden_act = hidden_act
- self.num_heads = num_heads
- self.in_channels = in_channels
- self.image_size = image_size
- self.patch_size = patch_size
- self.spatial_merge_size = spatial_merge_size
- self.temporal_patch_size = temporal_patch_size
- self.out_hidden_size = out_hidden_size
- self.intermediate_size = intermediate_size
- self.initializer_range = initializer_range
- self.rms_norm_eps = rms_norm_eps
- self.attention_bias = attention_bias
- self.attention_dropout = attention_dropout
-
-
class Glm4vMoeTextConfig(PreTrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Glm4vMoeModel`]. It is used to instantiate a
@@ -282,6 +189,99 @@ def __init__(
)
+class Glm4vMoeVisionConfig(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Glm4vMoeVisionModel`]. It is used to instantiate an Glm4vMoeVisionModel
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
+ a similar configuration to that of
+ GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking).
+
+ Args:
+ depth (`int`, *optional*, defaults to 24):
+ Number of layers (depth) in the model.
+ hidden_size (`int`, *optional*, defaults to 1536):
+ Dimensionality of the encoder layers and the pooler layer.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to add a bias to the queries, keys and values.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ Dropout probability for attention weights.
+ num_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the vision encoder.
+ in_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ image_size (`int` or `list[int]`, *optional*, defaults to 336):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ spatial_merge_size (`int`, *optional*, defaults to 2):
+ The size used for merging spatial dimensions.
+ temporal_patch_size (`int`, *optional*, defaults to 2):
+ The size used for patches along the temporal dimension.
+ out_hidden_size (`int`, *optional*, defaults to 4096):
+ The output hidden size of the vision model.
+ intermediate_size (`int`, *optional*, defaults to 13696):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ Example:
+
+ ```python
+ >>> from transformers import Glm4vMoeVisionConfig, Glm4vMoeVisionModel
+
+ >>> # Initializing a Glm4vMoeVisionConfig GLM-4.1V-9B style configuration
+ >>> configuration = Glm4vMoeVisionConfig()
+
+ >>> # Initializing a model (with random weights) from the GLM-4.1V-9B configuration
+ >>> model = Glm4vMoeVisionModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "glm4v_moe_vision"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ depth=24,
+ hidden_size=1536,
+ hidden_act="silu",
+ attention_bias=False,
+ attention_dropout=0.0,
+ num_heads=12,
+ in_channels=3,
+ image_size=336,
+ patch_size=14,
+ rms_norm_eps=1e-05,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ out_hidden_size=4096,
+ intermediate_size=13696,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.out_hidden_size = out_hidden_size
+ self.intermediate_size = intermediate_size
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+
class Glm4vMoeConfig(PreTrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Glm4vMoeModel`]. It is used to instantiate a
@@ -360,4 +360,4 @@ def __init__(
super().__init__(**kwargs)
-__all__ = ["Glm4vMoeConfig", "Glm4vMoeTextConfig", "Glm4vMoeVisionConfig"]
+__all__ = ["Glm4vMoeConfig", "Glm4vMoeVisionConfig", "Glm4vMoeTextConfig"]
diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
index 3ed571198353..ef95cf26e191 100644
--- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
+++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
@@ -50,120 +50,6 @@
from .configuration_glm4v_moe import Glm4vMoeConfig, Glm4vMoeTextConfig, Glm4vMoeVisionConfig
-@use_kernel_forward_from_hub("RMSNorm")
-class Glm4vMoeRMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- Glm4vMoeRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
- def extra_repr(self):
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
-@dataclass
-@auto_docstring(
- custom_intro="""
- Base class for Llava outputs, with hidden states and attentions.
- """
-)
-class Glm4vMoeModelOutputWithPast(ModelOutput):
- r"""
- past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
-
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
- `past_key_values` input) to speed up sequential decoding.
- rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
- The rope index difference between sequence length and multimodal rope.
- """
-
- last_hidden_state: torch.FloatTensor | None = None
- past_key_values: Cache | None = None
- hidden_states: tuple[torch.FloatTensor] | None = None
- attentions: tuple[torch.FloatTensor] | None = None
- rope_deltas: torch.LongTensor | None = None
-
-
-class Glm4vMoeTextRotaryEmbedding(nn.Module):
- inv_freq: torch.Tensor # fix linting for `register_buffer`
-
- def __init__(self, config: Glm4vMoeTextConfig, device=None, layer_type=None):
- super().__init__()
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
-
- self.config = config
-
- self.rope_type = self.config.rope_parameters["rope_type"]
- rope_init_fn: Callable = self.compute_default_rope_parameters
- if self.rope_type != "default":
- rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
- inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
-
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
-
- @staticmethod
- def compute_default_rope_parameters(
- config: Glm4vMoeTextConfig | None = None,
- device: Optional["torch.device"] = None,
- seq_len: int | None = None,
- ) -> tuple["torch.Tensor", float]:
- """
- Computes the inverse frequencies according to the original RoPE implementation
- Args:
- config ([`~transformers.PreTrainedConfig`]):
- The model configuration.
- device (`torch.device`):
- The device to use for initialization of the inverse frequencies.
- seq_len (`int`, *optional*):
- The current sequence length. Unused for this type of RoPE.
- Returns:
- Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
- post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
- """
- base = config.rope_parameters["rope_theta"]
- partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
- head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
- dim = int(head_dim * partial_rotary_factor)
-
- attention_factor = 1.0 # Unused in this type of RoPE
-
- # Compute the inverse frequencies
- inv_freq = 1.0 / (
- base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
- )
- return inv_freq, attention_factor
-
- @torch.no_grad()
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
- def forward(self, x, position_ids):
- # In contrast to other models, GLM4V_MOE different position ids for the grids
- # So we expand the inv_freq to shape (3, ...)
- inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
- position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
-
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos() * self.attention_scaling
- sin = emb.sin() * self.attention_scaling
-
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -232,10 +118,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
- # Interleave them instead of usual shape
- cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
- sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
-
# Keep half or full tensor for later concatenation
rotary_dim = cos.shape[-1]
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
@@ -251,59 +133,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed
-def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
- """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
-
- Explanation:
- Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
- sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
- vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
- Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
- For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
- height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
- difference with modern LLMs.
-
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- mrope_section(`List(int)`):
- Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- mrope_section = mrope_section * 2
- cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
- unsqueeze_dim
- )
- sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
- unsqueeze_dim
- )
-
- # Keep half or full tensor for later concatenation
- rotary_dim = cos.shape[-1]
- q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
- k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
-
- # Apply rotary embeddings on the first half or full tensor
- q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
- k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
-
- # Concatenate back to full shape
- q_embed = torch.cat([q_embed, q_pass], dim=-1)
- k_embed = torch.cat([k_embed, k_pass], dim=-1)
-
- return q_embed, k_embed
-
-
@use_kernelized_func(apply_rotary_pos_emb)
class Glm4vMoeTextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -351,9 +180,7 @@ def forward(
value_states = value_states.transpose(1, 2)
cos, sin = position_embeddings
- query_states, key_states = apply_multimodal_rotary_pos_emb( # diff with Llama
- query_states, key_states, cos, sin, self.rope_parameters["mrope_section"]
- )
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_values is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
@@ -665,6 +492,27 @@ def forward(self, seqlen: int) -> torch.Tensor:
return freqs
+@use_kernel_forward_from_hub("RMSNorm")
+class Glm4vMoeRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Glm4vMoeRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
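For orientation, a tiny standalone numeric check of the normalization `Glm4vMoeRMSNorm` performs (plain `torch`, not the module itself; shapes are arbitrary):

```python
import torch

# Same computation as Glm4vMoeRMSNorm.forward with weight = 1: upcast to fp32,
# divide by the root mean square over the last dimension, cast back.
eps = 1e-6
hidden = torch.randn(2, 5, 8, dtype=torch.float16)

x = hidden.to(torch.float32)
variance = x.pow(2).mean(-1, keepdim=True)
normed = x * torch.rsqrt(variance + eps)

# Each vector along the last dimension now has (approximately) unit RMS.
print(normed.pow(2).mean(-1).sqrt())  # ~1.0 everywhere
out = normed.to(hidden.dtype)         # the module would additionally scale by its learned weight
```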
+
class Glm4vMoeisionMlp(nn.Module):
def __init__(self, config, bias: bool = False):
super().__init__()
@@ -727,6 +575,7 @@ def __init__(self, config: Glm4vMoeVisionConfig):
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.interpolated_method = "bicubic"
def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
"""
@@ -745,57 +594,45 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc
# Get position embedding parameters
pos_embed_weight = self.position_embedding.weight
hidden_size = pos_embed_weight.shape[1]
- total_seq = h_coords.shape[0]
device = pos_embed_weight.device
- # Move coordinates to correct device
- h_coords, w_coords = h_coords.to(device), w_coords.to(device)
-
- # Handle empty sequence case
- if total_seq == 0:
- adapted_pos_embed = torch.empty(0, hidden_size, device=device, dtype=pos_embed_weight.dtype)
- else:
- # Convert inputs to tensors if needed
- if isinstance(lengths, list):
- lengths = torch.tensor(lengths, device=device, dtype=torch.long)
- if not isinstance(image_shapes, torch.Tensor):
- image_shapes = torch.tensor(image_shapes, device=device, dtype=torch.long)
-
- # Prepare 2D position embedding
- orig_size_sq = pos_embed_weight.shape[0]
- orig_size = int(orig_size_sq**0.5)
- pos_embed_2d = (
- pos_embed_weight.view(orig_size, orig_size, hidden_size)
- .permute(2, 0, 1)
- .unsqueeze(0)
- .to(device=device, dtype=torch.float32)
- )
+ # Convert inputs to tensors if needed
+ if isinstance(lengths, list):
+ lengths = torch.tensor(lengths, device=device, dtype=torch.long)
+
+ # Prepare 2D position embedding
+ orig_size_sq = pos_embed_weight.shape[0]
+ orig_size = int(orig_size_sq**0.5)
+ pos_embed_2d = (
+ pos_embed_weight.view(orig_size, orig_size, hidden_size)
+ .permute(2, 0, 1)
+ .unsqueeze(0)
+ .to(device=device, dtype=torch.float32)
+ )
- # Calculate target dimensions for each patch
- target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to(
- device=device, dtype=torch.float32
- )
- target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to(
- device=device, dtype=torch.float32
- )
+ # Calculate target dimensions for each patch
+ target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to(
+ device=device, dtype=torch.float32
+ )
+ target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to(
+ device=device, dtype=torch.float32
+ )
- # Normalize coordinates to [-1, 1] range for grid_sample
- h_coords = h_coords.to(device=device, dtype=torch.float32)
- w_coords = w_coords.to(device=device, dtype=torch.float32)
- norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
- norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
+ # Normalize coordinates to [-1, 1] range for grid_sample
+ norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
+ norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
- # Create sampling grid
- grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
+ # Create sampling grid
+ grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
- # Perform bicubic interpolation
- interpolated_embed_fp32 = F.grid_sample(
- pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border"
- )
+ # Perform bicubic interpolation
+ interpolated_embed_fp32 = F.grid_sample(
+ pos_embed_2d, grid, mode=self.interpolated_method, align_corners=False, padding_mode="border"
+ )
- # Reshape and convert back to original dtype
- adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
- adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)
+ # Reshape and convert back to original dtype
+ adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
+ adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)
# Add adapted position encoding to embeddings
embeddings = embeddings + adapted_pos_embed
@@ -1002,7 +839,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
"""
hidden_states = self.patch_embed(hidden_states)
hidden_states = self.post_conv_layernorm(hidden_states)
-
rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
@@ -1017,7 +853,13 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
+ hidden_states = self.embeddings(
+ hidden_states,
+ seqlens,
+ grid_thw,
+ image_type_ids[:, 0].to(hidden_states.device),
+ image_type_ids[:, 1].to(hidden_states.device),
+ )
for blk in self.blocks:
hidden_states = blk(
@@ -1038,6 +880,83 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
return hidden_states
+class Glm4vMoeTextRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Glm4vMoeTextConfig, device=None):
+ super().__init__()
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+
+ self.rope_type = self.config.rope_parameters["rope_type"]
+ rope_init_fn: Callable = self.compute_default_rope_parameters
+ if self.rope_type != "default":
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+ self.mrope_section = config.rope_parameters.get("mrope_section", [8, 12, 12])
+
+ @staticmethod
+ def compute_default_rope_parameters(
+ config: Glm4vMoeTextConfig | None = None,
+ device: Optional["torch.device"] = None,
+ seq_len: int | None = None,
+ ) -> tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies according to the original RoPE implementation
+ Args:
+ config ([`~transformers.PreTrainedConfig`]):
+ The model configuration.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+ """
+ base = config.rope_parameters["rope_theta"]
+ partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+ dim = int(head_dim * partial_rotary_factor)
+
+ attention_factor = 1.0 # Unused in this type of RoPE
+
+ # Compute the inverse frequencies
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+ )
+ return inv_freq, attention_factor
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ # In contrast to other models, GLM-V has different position ids for the grids
+ # So we expand the inv_freq to shape (3, ...)
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ freqs = self.apply_mrope(freqs, self.mrope_section)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+ def apply_mrope(self, freqs, mrope_section):
+ section = mrope_section
+ chunks = freqs.split(section, dim=-1)
+ result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
+ return result
+
+
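To make the `apply_mrope` selection above concrete, a small sketch with hypothetical shapes (`mrope_section=[8, 12, 12]`, i.e. 32 rotary frequencies): the first chunk of channels is taken from the temporal plane, the second from the height plane, and the third from the width plane.

```python
import torch

mrope_section = [8, 12, 12]  # temporal / height / width channel splits, summing to 32
# freqs carries one plane per axis: shape (3, batch, seq_len, num_freqs)
batch, seq_len, num_freqs = 1, 4, sum(mrope_section)
freqs = torch.stack([torch.full((batch, seq_len, num_freqs), float(p)) for p in range(3)])

# Same selection as Glm4vMoeTextRotaryEmbedding.apply_mrope: chunk i along the
# frequency dimension is taken from plane i % 3.
chunks = freqs.split(mrope_section, dim=-1)
result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)

assert result.shape == (batch, seq_len, num_freqs)
assert (result[..., :8] == 0).all()    # first 8 channels come from the temporal plane
assert (result[..., 8:20] == 1).all()  # next 12 channels come from the height plane
assert (result[..., 20:] == 2).all()   # last 12 channels come from the width plane
```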
@auto_docstring
class Glm4vMoeTextModel(Glm4vMoePreTrainedModel):
config: Glm4vMoeTextConfig
@@ -1147,6 +1066,30 @@ def forward(
)
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Glm4vMoe model outputs, with hidden states and attentions.
+ """
+)
+class Glm4vMoeModelOutputWithPast(ModelOutput):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ last_hidden_state: torch.FloatTensor | None = None
+ past_key_values: Cache | None = None
+ hidden_states: tuple[torch.FloatTensor] | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+ rope_deltas: torch.LongTensor | None = None
+
+
@auto_docstring
class Glm4vMoeModel(Glm4vMoePreTrainedModel):
base_model_prefix = "model"
diff --git a/src/transformers/models/glm4v_moe/modular_glm4v_moe.py b/src/transformers/models/glm4v_moe/modular_glm4v_moe.py
index 15a15c0fda9a..5e81dc84c6c9 100644
--- a/src/transformers/models/glm4v_moe/modular_glm4v_moe.py
+++ b/src/transformers/models/glm4v_moe/modular_glm4v_moe.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
-from typing import Optional
import torch
import torch.nn as nn
@@ -36,22 +35,19 @@
Glm4MoeMLP,
Glm4MoeMoE,
Glm4MoePreTrainedModel,
- Glm4MoeRMSNorm,
Glm4MoeTopkRouter,
eager_attention_forward,
)
-from ..glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
+from ..glm4v.configuration_glm4v import Glm4vConfig
from ..glm4v.modeling_glm4v import (
Glm4vForConditionalGeneration,
Glm4vTextModel,
- Glm4vTextRotaryEmbedding,
Glm4vVisionModel,
Glm4vVisionRotaryEmbedding,
- rotate_half,
)
+from ..gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb
from ..qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeCausalLMOutputWithPast,
- Qwen3VLMoeModelOutputWithPast,
load_balancing_loss_func,
)
@@ -59,14 +55,6 @@
logger = logging.get_logger(__name__)
-class Glm4vMoeVisionConfig(Glm4vVisionConfig):
- pass
-
-
-class Glm4vMoeRMSNorm(Glm4MoeRMSNorm):
- pass
-
-
class Glm4vMoeTextConfig(Glm4MoeConfig, RotaryEmbeddingConfigMixin):
r"""
This is the configuration class to store the configuration of a [`Glm4vMoeModel`]. It is used to instantiate a
@@ -289,100 +277,6 @@ def __init__(
super().__init__()
-class Glm4vMoeModelOutputWithPast(Qwen3VLMoeModelOutputWithPast):
- pass
-
-
-def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
- """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
-
- Explanation:
- Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
- sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
- vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
- Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
- For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
- height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
- difference with modern LLMs.
-
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- mrope_section(`List(int)`):
- Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- mrope_section = mrope_section * 2
- cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
- unsqueeze_dim
- )
- sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
- unsqueeze_dim
- )
-
- # Keep half or full tensor for later concatenation
- rotary_dim = cos.shape[-1]
- q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
- k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
-
- # Apply rotary embeddings on the first half or full tensor
- q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
- k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
-
- # Concatenate back to full shape
- q_embed = torch.cat([q_embed, q_pass], dim=-1)
- k_embed = torch.cat([k_embed, k_pass], dim=-1)
-
- return q_embed, k_embed
-
-
-class Glm4vMoeTextRotaryEmbedding(Glm4vTextRotaryEmbedding):
- def __init__(self, config: Glm4vMoeTextConfig, device=None, layer_type=None):
- super().__init__(config, device=device, layer_type=layer_type)
-
- @staticmethod
- def compute_default_rope_parameters(
- config: Glm4vMoeTextConfig | None = None,
- device: Optional["torch.device"] = None,
- seq_len: int | None = None,
- ) -> tuple["torch.Tensor", float]:
- """
- Computes the inverse frequencies according to the original RoPE implementation
- Args:
- config ([`~transformers.PreTrainedConfig`]):
- The model configuration.
- device (`torch.device`):
- The device to use for initialization of the inverse frequencies.
- seq_len (`int`, *optional*):
- The current sequence length. Unused for this type of RoPE.
- Returns:
- Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
- post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
- """
- base = config.rope_parameters["rope_theta"]
- partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
- head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
- dim = int(head_dim * partial_rotary_factor)
-
- attention_factor = 1.0 # Unused in this type of RoPE
-
- # Compute the inverse frequencies
- inv_freq = 1.0 / (
- base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
- )
- return inv_freq, attention_factor
-
-
class Glm4vMoeTextAttention(Glm4Attention):
def __init__(self, config: Glm4vMoeTextConfig, layer_idx: int | None = None):
super().__init__(config, layer_idx)
@@ -409,9 +303,7 @@ def forward(
value_states = value_states.transpose(1, 2)
cos, sin = position_embeddings
- query_states, key_states = apply_multimodal_rotary_pos_emb( # diff with Llama
- query_states, key_states, cos, sin, self.rope_parameters["mrope_section"]
- )
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_values is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
@@ -656,8 +548,8 @@ def forward(
__all__ = [
"Glm4vMoeConfig",
+ "Glm4vMoeVisionConfig", # noqa: F822
"Glm4vMoeTextConfig",
- "Glm4vMoeVisionConfig",
"Glm4vMoeForConditionalGeneration",
"Glm4vMoeModel", # noqa: F822
"Glm4vMoePreTrainedModel",
diff --git a/src/transformers/models/glm_image/__init__.py b/src/transformers/models/glm_image/__init__.py
new file mode 100644
index 000000000000..7c5108a29576
--- /dev/null
+++ b/src/transformers/models/glm_image/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_glm_image import *
+ from .image_processing_glm_image import *
+ from .image_processing_glm_image_fast import *
+ from .modeling_glm_image import *
+ from .processing_glm_image import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/glm_image/configuration_glm_image.py b/src/transformers/models/glm_image/configuration_glm_image.py
new file mode 100644
index 000000000000..403b5676dace
--- /dev/null
+++ b/src/transformers/models/glm_image/configuration_glm_image.py
@@ -0,0 +1,352 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/glm_image/modular_glm_image.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_glm_image.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PreTrainedConfig
+from ...modeling_rope_utils import RopeParameters
+
+
+class GlmImageVQVAEConfig(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GlmImageVQModel`]. It is used to instantiate a
+ `GlmImageVQModel` according to the specified arguments, defining the model architecture.
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information. Instantiating a
+ configuration with the defaults will yield a similar configuration to the VQModel of the
+ [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image) architecture.
+
+ Args:
+ embed_dim (`int`, *optional*, defaults to 2048):
+ Dimensionality of each embedding vector.
+ num_embeddings (`int`, *optional*, defaults to 16384):
+ Number of codebook embeddings.
+ latent_channels (`int`, *optional*, defaults to 1536):
+ Number of channels for the latent space.
+ in_channels (`int`, *optional*, defaults to 3):
+ Number of input channels.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ """
+
+ model_type = "glm_image_vqmodel"
+ base_config_key = "vq_config"
+
+ def __init__(
+ self,
+ embed_dim: int = 2048,
+ num_embeddings: int = 16384,
+ latent_channels: int = 1536,
+ in_channels: int = 3,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+ self.num_embeddings = num_embeddings
+ self.latent_channels = latent_channels
+ self.in_channels = in_channels
+ self.initializer_range = initializer_range
+
+
+class GlmImageVisionConfig(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GlmImageVisionModel`]. It is used to instantiate a GlmImageVisionModel
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
+ a similar configuration to that of
+ GLM-Image [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image).
+
+ Args:
+ depth (`int`, *optional*, defaults to 40):
+ Number of layers (depth) in the model.
+ hidden_size (`int`, *optional*, defaults to 1536):
+ Dimensionality of the encoder layers and the pooler layer.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ Dropout probability for attention weights.
+ num_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer architecture.
+ in_channels (`int`, *optional*, defaults to 3):
+ Number of input channels.
+ image_size (`int` or `list[int]`, *optional*, defaults to 2048):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ spatial_merge_size (`int`, *optional*, defaults to 1):
+ The size used for merging spatial dimensions.
+ intermediate_size (`int`, *optional*, defaults to 6144):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ """
+
+ model_type = "glm_image_vision"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ depth=40,
+ hidden_size=1536,
+ hidden_act="gelu",
+ attention_bias=True,
+ attention_dropout=0.0,
+ num_heads=16,
+ in_channels=3,
+ image_size=2048,
+ patch_size=16,
+ layer_norm_eps=1e-06,
+ spatial_merge_size=1,
+ intermediate_size=6144,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.intermediate_size = intermediate_size
+ self.initializer_range = initializer_range
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+
+
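With this branch of Transformers installed, the two configs above behave like any other `PreTrainedConfig`; a minimal sketch (assuming they are exported from the top-level `transformers` namespace, as the `__all__` at the end of this file suggests):

```python
from transformers import GlmImageVQVAEConfig, GlmImageVisionConfig

# Defaults mirror the zai-org/GLM-Image architecture described above.
vision_config = GlmImageVisionConfig()
print(vision_config.depth, vision_config.patch_size, vision_config.hidden_size)  # 40 16 1536

# Fields can be overridden as usual, e.g. a smaller vision tower and codebook.
small_vision = GlmImageVisionConfig(depth=8, hidden_size=512, num_heads=8, intermediate_size=2048)
small_vq = GlmImageVQVAEConfig(num_embeddings=8192, embed_dim=1024)
```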
+class GlmImageTextConfig(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GlmImageTextModel`]. It is used to instantiate a
+ GLM-Image model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of
+ GLM-Image [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image).
+
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 168064):
+ Vocabulary size of the GlmImage model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`GlmImageModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 13696):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 40):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 2):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ rope_parameters (`RopeParameters`, *optional*):
+ Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
+ a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
+ with longer `max_position_embeddings`.
+ vision_vocab_size (`int`, *optional*, defaults to 16512):
+ Vision vocabulary size of the GlmImage model. Defines the number of different tokens that can be represented
+ by the `inputs_ids` passed when calling [`GlmImageVisionModel`]
+ attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+
+ ```python
+ >>> from transformers import GlmImageTextModel, GlmImageTextConfig
+
+ >>> # Initializing a GlmImageTextConfig style configuration
+ >>> configuration = GlmImageTextConfig()
+
+ >>> # Initializing a model from the GlmImageTextConfig style configuration
+ >>> model = GlmImageTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "glm_image_text"
+ base_config_key = "text_config"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ # Default tensor parallel plan for base model `GlmImage`
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
+ "layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size: int | None = 168064,
+ hidden_size: int | None = 4096,
+ intermediate_size: int | None = 13696,
+ num_hidden_layers: int | None = 40,
+ num_attention_heads: int | None = 32,
+ num_key_value_heads: int | None = 2,
+ hidden_act: str | None = "silu",
+ max_position_embeddings: int | None = 32768,
+ initializer_range: float | None = 0.02,
+ rms_norm_eps: float | None = 1e-05,
+ use_cache: bool | None = True,
+ tie_word_embeddings: bool | None = False,
+ attention_dropout: float | None = 0.0,
+ rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
+ vision_vocab_size: int | None = 16512,
+ attention_bias: bool | None = True,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.vision_vocab_size = vision_vocab_size
+ self.attention_bias = attention_bias
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.attention_dropout = attention_dropout
+ self.rope_parameters = rope_parameters
+
+ super().__init__(
+ tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **kwargs
+ )
+
+
+class GlmImageConfig(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GlmImageModel`]. It is used to instantiate a
+ GLM-Image model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of
+ GLM-Image [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image) architecture.
+
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information.
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `GlmImageTextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `GlmImageVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ vq_config (`Union[Dict, GlmImageVQVAEConfig]`, *optional*):
+ GlmImageVQVAEConfig instance containing the configuration for the VQ-VAE model.
+ image_token_id (`int`, *optional*, defaults to 167855):
+ The image token index to encode the image prompt.
+ image_start_token_id (`int`, *optional*, defaults to 16384):
+ The image start token index to encode the start of image.
+ image_end_token_id (`int`, *optional*, defaults to 16385):
+ The image end token index to encode the end of image.
+
+ ```python
+ >>> from transformers import GlmImageModel, GlmImageConfig
+
+ >>> # Initializing a GLM-Image style configuration
+ >>> configuration = GlmImageConfig()
+
+ >>> # Initializing a model from the GLM-Image style configuration
+ >>> model = GlmImageModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "glm_image"
+ sub_configs = {
+ "vision_config": GlmImageVisionConfig,
+ "text_config": GlmImageTextConfig,
+ "vq_config": GlmImageVQVAEConfig,
+ }
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ vq_config=None,
+ image_token_id=167855,
+ image_start_token_id=16384,
+ image_end_token_id=16385,
+ **kwargs,
+ ):
+ if isinstance(vision_config, dict):
+ vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ vision_config = self.sub_configs["vision_config"](**kwargs)
+
+ if isinstance(vq_config, dict):
+ vq_config = self.sub_configs["vq_config"](**vq_config)
+ elif vq_config is None:
+ vq_config = self.sub_configs["vq_config"](**kwargs)
+
+ if isinstance(text_config, dict):
+ text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ text_config = self.sub_configs["text_config"](**kwargs)
+
+ self.image_token_id = image_token_id
+ self.image_start_token_id = image_start_token_id
+ self.image_end_token_id = image_end_token_id
+ self.text_config = text_config
+ self.vision_config = vision_config
+ self.vq_config = vq_config
+ super().__init__(**kwargs)
+
+
+__all__ = ["GlmImageVQVAEConfig", "GlmImageVisionConfig", "GlmImageTextConfig", "GlmImageConfig"]
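A short sketch of how the composite `GlmImageConfig` resolves its sub-configs (again assuming this branch is installed): dicts are promoted to the corresponding config classes, and omitted sub-configs fall back to their defaults.

```python
from transformers import GlmImageConfig

# Sub-configs may be passed as plain dicts; __init__ promotes them to
# GlmImageTextConfig / GlmImageVisionConfig / GlmImageVQVAEConfig.
config = GlmImageConfig(
    text_config={"num_hidden_layers": 8, "hidden_size": 1024, "num_attention_heads": 8},
    vision_config={"depth": 4, "hidden_size": 512, "num_heads": 8, "intermediate_size": 2048},
)

print(type(config.text_config).__name__)  # GlmImageTextConfig
print(config.vq_config.num_embeddings)    # 16384 (default, since vq_config was omitted)
print(config.image_token_id)              # 167855
```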
diff --git a/src/transformers/models/glm_image/image_processing_glm_image.py b/src/transformers/models/glm_image/image_processing_glm_image.py
new file mode 100644
index 000000000000..acf934b180a8
--- /dev/null
+++ b/src/transformers/models/glm_image/image_processing_glm_image.py
@@ -0,0 +1,503 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/glm_image/modular_glm_image.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_glm_image.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_processing_utils import BaseImageProcessor
+from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_flat_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...processing_utils import ImagesKwargs
+from ...utils import TensorType, logging
+from ...video_utils import VideoInput
+
+
+logger = logging.get_logger(__name__)
+
+
+class GlmImageImageProcessorKwargs(ImagesKwargs, total=False):
+ r"""
+ min_pixels (`int`, *optional*, defaults to `56 * 56`):
+ The min pixels of the image to resize the image.
+ max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
+ The max pixels of the image to resize the image.
+ patch_size (`int`, *optional*, defaults to 14):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to 2):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to 2):
+ The merge size of the vision encoder to llm encoder.
+ """
+
+ min_pixels: int
+ max_pixels: int
+ patch_size: int
+ temporal_patch_size: int
+ merge_size: int
+
+
+def smart_resize(
+ height: int,
+ width: int,
+ factor: int = 16,
+ min_pixels: int = 512 * 512,
+ max_pixels: int = 2048 * 2048,
+) -> tuple[int, int]:
+ if height < factor or width < factor:
+ raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
+ elif max(height, width) / min(height, width) > 4:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than 4, got {max(height, width) / min(height, width)}"
+ )
+
+ shortest_edge = int(round(math.sqrt(min_pixels)))
+ longest_edge = int(round(math.sqrt(max_pixels)))
+ min_side = min(height, width)
+ max_side = max(height, width)
+
+ scale = 1.0
+
+ if min_side < shortest_edge:
+ scale = shortest_edge / min_side
+
+ if max_side * scale > longest_edge:
+ scale = longest_edge / max_side
+
+ height = height // 2
+ width = width // 2
+
+ h_bar = max(factor, int(round(height * scale / factor)) * factor)
+ w_bar = max(factor, int(round(width * scale / factor)) * factor)
+
+ if max(h_bar, w_bar) > longest_edge:
+ beta = max(h_bar, w_bar) / longest_edge
+ h_bar = max(factor, int(math.floor((h_bar / beta) / factor)) * factor)
+ w_bar = max(factor, int(math.floor((w_bar / beta) / factor)) * factor)
+
+ return h_bar, w_bar
+
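A quick usage sketch for `smart_resize` above (the module path comes from this diff; the printed values are illustrative, only the `factor` alignment is asserted):

```python
from transformers.models.glm_image.image_processing_glm_image import smart_resize

# Target dimensions for a 1024x768 image with 16-pixel patches merged 2x2.
h_bar, w_bar = smart_resize(height=1024, width=768, factor=16 * 2)
assert h_bar % 32 == 0 and w_bar % 32 == 0  # results are always multiples of `factor`
print(h_bar, w_bar)

# Degenerate inputs are rejected: sides smaller than `factor` or aspect ratio > 4.
try:
    smart_resize(height=16, width=2048, factor=32)
except ValueError as err:
    print(err)
```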
+
+class GlmImageImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a GLM-Image image processor that dynamically resizes images based on the original images.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions.
+ size (`dict[str, int]`, *optional*, defaults to `{"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}`):
+ Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+ Resampling filter to use when resizing the image.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
+ image_std (`float` or `list[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
+ min_pixels (`int`, *optional*, defaults to `56 * 56`):
+ The min pixels of the image to resize the image.
+ max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
+ The max pixels of the image to resize the image.
+ patch_size (`int`, *optional*, defaults to 14):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to 2):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to 2):
+ The merge size of the vision encoder to llm encoder.
+ """
+
+ model_input_names = ["pixel_values", "image_grid_thw"]
+ valid_kwargs = GlmImageImageProcessorKwargs
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: dict[str, int] | None = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_rescale: bool = True,
+ rescale_factor: int | float = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: float | list[float] | None = None,
+ image_std: float | list[float] | None = None,
+ do_convert_rgb: bool = True,
+ min_pixels: int | None = None,
+ max_pixels: int | None = None,
+ patch_size: int = 14,
+ temporal_patch_size: int = 2,
+ merge_size: int = 2,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ if size is not None:
+ if "shortest_edge" not in size or "longest_edge" not in size:
+ raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+ else:
+ size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
+ # backward compatibility: override size with min_pixels and max_pixels if they are provided
+ if min_pixels is not None:
+ size["shortest_edge"] = min_pixels
+ if max_pixels is not None:
+ size["longest_edge"] = max_pixels
+ self.min_pixels = size["shortest_edge"]
+ self.max_pixels = size["longest_edge"]
+ self.size = size
+
+ self.do_resize = do_resize
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.merge_size = merge_size
+ self.do_convert_rgb = do_convert_rgb
+
+ def _preprocess(
+ self,
+ images: ImageInput | VideoInput,
+ do_resize: bool | None = None,
+ size: dict[str, int] | None = None,
+ resample: PILImageResampling | None = None,
+ do_rescale: bool | None = None,
+ rescale_factor: float | None = None,
+ do_normalize: bool | None = None,
+ image_mean: float | list[float] | None = None,
+ image_std: float | list[float] | None = None,
+ patch_size: int | None = None,
+ temporal_patch_size: int | None = None,
+ merge_size: int | None = None,
+ do_convert_rgb: bool | None = None,
+ data_format: ChannelDimension | None = ChannelDimension.FIRST,
+ input_data_format: str | ChannelDimension | None = None,
+ ):
+ """
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
+
+ Args:
+ images (`ImageInput`):
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
+ vision_info (`list[Dict]`, *optional*):
+ Optional list of dictionaries containing additional information about vision inputs.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Scale factor to use if rescaling the image.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to `self.merge_size`):
+ The merge size of the vision encoder to llm encoder.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ images = make_flat_list_of_images(images)
+
+ if do_convert_rgb:
+ images = [convert_to_rgb(image) for image in images]
+
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
+ resized_height, resized_width = height, width
+ processed_images = []
+ for image in images:
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=patch_size * merge_size,
+ min_pixels=size["shortest_edge"],
+ max_pixels=size["longest_edge"],
+ )
+ image = resize(
+ image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
+ )
+
+ if do_rescale:
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
+ )
+
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ processed_images.append(image)
+
+ patches = np.array(processed_images)
+ if data_format == ChannelDimension.LAST:
+ patches = patches.transpose(0, 3, 1, 2)
+ if patches.shape[0] % temporal_patch_size != 0:
+ repeats = np.repeat(
+ patches[-1][np.newaxis], temporal_patch_size - (patches.shape[0] % temporal_patch_size), axis=0
+ )
+ patches = np.concatenate([patches, repeats], axis=0)
+ channel = patches.shape[1]
+ grid_t = patches.shape[0] // temporal_patch_size
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+ patches = patches.reshape(
+ grid_t,
+ temporal_patch_size,
+ channel,
+ grid_h // merge_size,
+ merge_size,
+ patch_size,
+ grid_w // merge_size,
+ merge_size,
+ patch_size,
+ )
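+ # Reorder dimensions to group grid and patch information for the subsequent flattening:
+ # (grid_t, grid_h // merge_size, grid_w // merge_size, merge_h, merge_w, channel, temporal_patch_size, patch_h, patch_w)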
+ patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
+ flatten_patches = patches.reshape(
+ grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
+ )
+
+ return flatten_patches, (grid_t, grid_h, grid_w)
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: bool | None = None,
+ size: dict[str, int] | None = None,
+ min_pixels: int | None = None,
+ max_pixels: int | None = None,
+ resample: PILImageResampling | None = None,
+ do_rescale: bool | None = None,
+ rescale_factor: float | None = None,
+ do_normalize: bool | None = None,
+ image_mean: float | list[float] | None = None,
+ image_std: float | list[float] | None = None,
+ patch_size: int | None = None,
+ temporal_patch_size: int | None = None,
+ merge_size: int | None = None,
+ do_convert_rgb: bool | None = None,
+ return_tensors: str | TensorType | None = None,
+ data_format: ChannelDimension | None = ChannelDimension.FIRST,
+ input_data_format: str | ChannelDimension | None = None,
+ ):
+ """
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after resizing. `size["shortest_edge"]` and `size["longest_edge"]` are used as the
+ minimum and maximum pixel budgets for `smart_resize`, which approximately preserves the input aspect ratio.
+ resample (`int`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+ has an effect if `do_resize` is set to `True`.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
+ `True`.
+ min_pixels (`int`, *optional*, defaults to `self.min_pixels`):
+ The minimum number of pixels allowed for the resized image.
+ max_pixels (`int`, *optional*, defaults to `self.max_pixels`):
+ The maximum number of pixels allowed for the resized image.
+ patch_size (`int`, *optional*, defaults to `self.patch_size`):
+ The spatial patch size of the vision encoder.
+ temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
+ The temporal patch size of the vision encoder.
+ merge_size (`int`, *optional*, defaults to `self.merge_size`):
+ The spatial merge size used when mapping vision encoder patches to LLM embeddings.
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+ Whether to convert the image to RGB.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ """
+ min_pixels = min_pixels if min_pixels is not None else self.min_pixels
+ max_pixels = max_pixels if max_pixels is not None else self.max_pixels
+
+ if size is not None:
+ if "shortest_edge" not in size or "longest_edge" not in size:
+ raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+ min_pixels = size["shortest_edge"]
+ elif min_pixels is not None and max_pixels is not None:
+ # backward compatibility: override size with min_pixels and max_pixels if they are provided
+ size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
+ else:
+ size = {**self.size}
+
+ do_resize = do_resize if do_resize is not None else self.do_resize
+
+ resample = resample if resample is not None else self.resample
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+ patch_size = patch_size if patch_size is not None else self.patch_size
+ temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size
+ merge_size = merge_size if merge_size is not None else self.merge_size
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+ if images is not None:
+ images = self.fetch_images(images)
+ images = make_flat_list_of_images(images)
+
+ if images is not None and not valid_images(images):
+ raise ValueError("Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, or torch.Tensor")
+
+ validate_preprocess_arguments(
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ data = {}
+ pixel_values, vision_grid_thws = [], []
+ for image in images:
+ patches, image_grid_thw = self._preprocess(
+ image,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ patch_size=patch_size,
+ temporal_patch_size=temporal_patch_size,
+ merge_size=merge_size,
+ data_format=data_format,
+ do_convert_rgb=do_convert_rgb,
+ input_data_format=input_data_format,
+ )
+ pixel_values.extend(patches)
+ vision_grid_thws.append(image_grid_thw)
+ pixel_values = np.array(pixel_values)
+ vision_grid_thws = np.array(vision_grid_thws)
+ data.update({"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws})
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+ """
+ A utility that returns number of image patches for a given image size.
+
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+ images_kwargs (`dict`, *optional*):
+ Any kwargs to override defaults of the image processor.
+ Returns:
+ `int`: Number of image patches per image.
+ """
+ min_pixels = images_kwargs["min_pixels"] if "min_pixels" in images_kwargs else self.size["shortest_edge"]
+ max_pixels = images_kwargs["max_pixels"] if "max_pixels" in images_kwargs else self.size["longest_edge"]
+ patch_size = images_kwargs.get("patch_size", self.patch_size)
+ merge_size = images_kwargs.get("merge_size", self.merge_size)
+
+ factor = patch_size * merge_size
+ resized_height, resized_width = smart_resize(
+ height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
+ )
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+ return grid_h * grid_w
+
+
+__all__ = ["GlmImageImageProcessor"]
diff --git a/src/transformers/models/glm_image/image_processing_glm_image_fast.py b/src/transformers/models/glm_image/image_processing_glm_image_fast.py
new file mode 100644
index 000000000000..7ca16ce810b7
--- /dev/null
+++ b/src/transformers/models/glm_image/image_processing_glm_image_fast.py
@@ -0,0 +1,296 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/glm_image/modular_glm_image.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_glm_image.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional, Union
+
+import torch.nn.functional as F
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+)
+from ...processing_utils import Unpack
+from ...utils import TensorType, auto_docstring, is_torch_available
+from .image_processing_glm_image import GlmImageImageProcessorKwargs
+
+
+if is_torch_available():
+ import torch
+
+
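+# Snaps (height, width) to multiples of `factor` (patch_size * merge_size), scaling so that the image
+# edges stay roughly within [sqrt(min_pixels), sqrt(max_pixels)]; note that the dimensions are halved
+# before snapping, so the fast processor targets a downscaled resolution.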
+def smart_resize(
+ height: int,
+ width: int,
+ factor: int = 16,
+ min_pixels: int = 512 * 512,
+ max_pixels: int = 2048 * 2048,
+) -> tuple[int, int]:
+ if height < factor or width < factor:
+ raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
+ elif max(height, width) / min(height, width) > 4:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than 4, got {max(height, width) / min(height, width)}"
+ )
+
+ shortest_edge = int(round(math.sqrt(min_pixels)))
+ longest_edge = int(round(math.sqrt(max_pixels)))
+ min_side = min(height, width)
+ max_side = max(height, width)
+
+ scale = 1.0
+
+ if min_side < shortest_edge:
+ scale = shortest_edge / min_side
+
+ if max_side * scale > longest_edge:
+ scale = longest_edge / max_side
+
+ height = height // 2
+ width = width // 2
+
+ h_bar = max(factor, int(round(height * scale / factor)) * factor)
+ w_bar = max(factor, int(round(width * scale / factor)) * factor)
+
+ if max(h_bar, w_bar) > longest_edge:
+ beta = max(h_bar, w_bar) / longest_edge
+ h_bar = max(factor, int(math.floor((h_bar / beta) / factor)) * factor)
+ w_bar = max(factor, int(math.floor((w_bar / beta) / factor)) * factor)
+
+ return h_bar, w_bar
+
+
+@auto_docstring
+class GlmImageImageProcessorFast(BaseImageProcessorFast):
+ do_resize = True
+ resample = PILImageResampling.BICUBIC
+ size = {"shortest_edge": 56 * 56, "longest_edge": 28 * 28 * 1280}
+ do_rescale = True
+ do_normalize = True
+ image_mean = OPENAI_CLIP_MEAN
+ image_std = OPENAI_CLIP_STD
+ do_convert_rgb = True
+ patch_size = 14
+ temporal_patch_size = 2
+ merge_size = 2
+ min_pixels = None
+ max_pixels = None
+ valid_kwargs = GlmImageImageProcessorKwargs
+ model_input_names = ["pixel_values", "image_grid_thw"]
+
+ def __init__(self, **kwargs: Unpack[GlmImageImageProcessorKwargs]):
+ size = kwargs.pop("size", None)
+ min_pixels = kwargs.pop("min_pixels", None)
+ max_pixels = kwargs.pop("max_pixels", None)
+ # backward compatibility: override size with min_pixels and max_pixels if they are provided
+ size = self.size if size is None else size
+ if min_pixels is not None:
+ size["shortest_edge"] = min_pixels
+ size.pop("min_pixels", None)
+ if max_pixels is not None:
+ size["longest_edge"] = max_pixels
+ size.pop("max_pixels", None)
+ if "shortest_edge" not in size or "longest_edge" not in size:
+ raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+
+ super().__init__(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)
+
+ def _further_process_kwargs(
+ self,
+ size: SizeDict | None = None,
+ min_pixels: int | None = None,
+ max_pixels: int | None = None,
+ **kwargs,
+ ) -> dict:
+ """
+ Update kwargs that need further processing before being validated.
+ Can be overridden by subclasses to customize the processing of kwargs.
+ """
+ if min_pixels is not None and max_pixels is not None:
+ size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
+ elif size is not None:
+ if "shortest_edge" not in size or "longest_edge" not in size:
+ raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
+ min_pixels = size["shortest_edge"]
+ max_pixels = size["longest_edge"]
+ else:
+ size = {**self.size}
+
+ return super()._further_process_kwargs(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)
+
+ @auto_docstring
+ def preprocess(
+ self,
+ images: ImageInput,
+ **kwargs: Unpack[GlmImageImageProcessorKwargs],
+ ) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+ def _preprocess_image_like_inputs(
+ self,
+ images: ImageInput,
+ do_convert_rgb: bool,
+ input_data_format: ChannelDimension,
+ device: Union[str, "torch.device"] | None = None,
+ **kwargs: Unpack[GlmImageImageProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Preprocess image-like inputs.
+ To be overridden by subclasses when image-like inputs other than images should be processed.
+ It can be used for segmentation maps, depth maps, etc.
+ """
+ # Prepare input images
+ batch_feature = BatchFeature()
+ images = self._prepare_image_like_inputs(
+ images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
+ )
+ batch_feature = self._preprocess(images, **kwargs)
+ return batch_feature
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: float | list[float] | None,
+ image_std: float | list[float] | None,
+ patch_size: int,
+ temporal_patch_size: int,
+ merge_size: int,
+ disable_grouping: bool | None,
+ return_tensors: str | TensorType | None,
+ **kwargs,
+ ):
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ height, width = stacked_images.shape[-2:]
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=patch_size * merge_size,
+ min_pixels=size["shortest_edge"],
+ max_pixels=size["longest_edge"],
+ )
+ stacked_images = self.resize(
+ image=stacked_images,
+ size=SizeDict(height=resized_height, width=resized_width),
+ interpolation=interpolation,
+ )
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+ processed_images_grouped = {}
+ processed_grids = {}
+ for shape, stacked_images in grouped_images.items():
+ resized_height, resized_width = stacked_images.shape[-2:]
+ # Fused rescale and normalize
+ patches = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+ )
+ if patches.ndim == 4:
+ # add a temporal dimension if we have images
+ patches = patches.unsqueeze(1)
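+ # Repeat the last frame so that single images fill a full temporal patch of `temporal_patch_size` frames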
+ if patches.shape[1] % temporal_patch_size != 0:
+ repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
+ patches = torch.cat([patches, repeats], dim=1)
+ batch_size, grid_t, channel = patches.shape[:3]
+ grid_t = grid_t // temporal_patch_size
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+
+ patches = patches.view(
+ batch_size,
+ grid_t,
+ temporal_patch_size,
+ channel,
+ grid_h // merge_size,
+ merge_size,
+ patch_size,
+ grid_w // merge_size,
+ merge_size,
+ patch_size,
+ )
+ # Reorder dimensions to group grid and patch information for subsequent flattening.
+ # (batch, grid_t, grid_h, grid_w, merge_h, merge_w, channel, temp_patch_size, patch_h, patch_w)
+ patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
+ flatten_patches = patches.reshape(
+ batch_size,
+ grid_t * grid_h * grid_w,
+ channel * temporal_patch_size * patch_size * patch_size,
+ )
+
+ processed_images_grouped[shape] = flatten_patches
+ processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_grids = reorder_images(processed_grids, grouped_images_index)
+ pixel_values = torch.cat(processed_images, dim=0)
+ image_grid_thw = torch.tensor(processed_grids)
+
+ return BatchFeature(
+ data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors
+ )
+
+ def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
+ """
+ A utility that returns number of image patches for a given image size.
+
+ Note: Do not remove this method! It is used by vLLM to infer the number of patches and placeholders
+ without an image input.
+
+ Args:
+ height (`int`):
+ Height of the input image.
+ width (`int`):
+ Width of the input image.
+ images_kwargs (`dict`, *optional*):
+ Any kwargs to override defaults of the image processor.
+ Returns:
+ `int`: Number of image patches per image.
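+
+ Example (illustrative, assuming the class defaults above): a 448x448 image is snapped by
+ `smart_resize` to 224x224, giving `(224 // 14) * (224 // 14) = 256` patches.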
+ """
+ min_pixels = images_kwargs["min_pixels"] if "min_pixels" in images_kwargs else self.size["shortest_edge"]
+ max_pixels = images_kwargs["max_pixels"] if "max_pixels" in images_kwargs else self.size["longest_edge"]
+ patch_size = images_kwargs.get("patch_size", self.patch_size)
+ merge_size = images_kwargs.get("merge_size", self.merge_size)
+
+ factor = patch_size * merge_size
+ resized_height, resized_width = smart_resize(
+ height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
+ )
+ grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
+ return grid_h * grid_w
+
+
+__all__ = ["GlmImageImageProcessorFast"]
diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py
new file mode 100644
index 000000000000..d3227c9f6b84
--- /dev/null
+++ b/src/transformers/models/glm_image/modeling_glm_image.py
@@ -0,0 +1,1592 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/glm_image/modular_glm_image.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_glm_image.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available
+from ...utils.generic import check_model_inputs, maybe_autocast
+from .configuration_glm_image import GlmImageConfig, GlmImageTextConfig, GlmImageVisionConfig, GlmImageVQVAEConfig
+
+
+if is_torch_available():
+ import torch
+
+
+class GlmImageVisionMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
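+# Reference (eager) attention implementation: grouped-query key/value heads are expanded with
+# `repeat_kv`, the causal mask slice is added, the softmax is computed in float32, and the output is
+# returned as (batch, seq_len, num_heads, head_dim) together with the attention weights.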
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: torch.Tensor | None,
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class GlmImageVisionAttention(nn.Module):
+ def __init__(self, config: GlmImageVisionConfig) -> None:
+ super().__init__()
+ self.dim = config.hidden_size
+ self.num_heads = config.num_heads
+ self.head_dim = self.dim // self.num_heads
+ self.num_key_value_groups = 1 # needed for eager attention
+ self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
+ self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
+ self.scaling = self.head_dim**-0.5
+ self.config = config
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ **kwargs,
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ query_states, key_states, value_states = (
+ self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ )
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
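+ # Each tensor is now (1, num_heads, seq_length, head_dim): the packed patch sequence is treated as a
+ # single batch element, with per-image boundaries given by `cu_seqlens`.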
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ if "flash" in self.config._attn_implementation:
+ # Flash Attention: Use cu_seqlens for variable length attention
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+ attn_output, _ = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ cu_seq_lens_q=cu_seqlens,
+ cu_seq_lens_k=cu_seqlens,
+ max_length_q=max_seqlen,
+ max_length_k=max_seqlen,
+ is_causal=False,
+ **kwargs,
+ )
+ else:
+ # Other implementations: Process each chunk separately
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+ splits = [
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+ ]
+
+ attn_outputs = [
+ attention_interface(
+ self,
+ q,
+ k,
+ v,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ is_causal=False,
+ **kwargs,
+ )[0]
+ for q, k, v in zip(*splits)
+ ]
+ attn_output = torch.cat(attn_outputs, dim=1)
+
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class GlmImageVisionPatchEmbed(nn.Module):
+ def __init__(self, config: GlmImageVisionConfig) -> None:
+ super().__init__()
+ self.patch_size = config.patch_size
+ self.in_channels = config.in_channels
+ self.embed_dim = config.hidden_size
+ kernel_size = [self.patch_size, self.patch_size]
+ self.proj = nn.Conv2d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size)
+
+ def forward(self, hidden_states) -> torch.Tensor:
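+ # `hidden_states` arrives as flattened patches; each row is reshaped back to
+ # (in_channels, patch_size, patch_size) and projected with a convolution whose kernel and stride equal
+ # the patch size, yielding one embedding vector per patch.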
+ target_dtype = self.proj.weight.dtype
+ hidden_states = hidden_states.view(-1, self.in_channels, self.patch_size, self.patch_size)
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class GlmImageVisionEmbeddings(nn.Module):
+ def __init__(self, config: GlmImageVisionConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.interpolated_method = "bilinear"
+
+ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
+ """
+ Forward pass with integrated position encoding adaptation using 2D interpolation.
+
+ Args:
+ embeddings (`torch.Tensor`): Input embeddings tensor.
+ lengths (torch.Tensor): Sequence lengths for each image in the batch.
+ image_shapes (torch.Tensor): Tensor of shape [batch_size, 3] representing the image shapes (t, h, w).
+ h_coords (torch.Tensor): Tensor of shape [total_seq] representing the h coordinate for each patch.
+ w_coords (torch.Tensor): Tensor of shape [total_seq] representing the w coordinate for each patch.
+
+ Returns:
+ torch.Tensor: Embeddings with adapted position encoding added.
+ """
+ # Get position embedding parameters
+ pos_embed_weight = self.position_embedding.weight
+ hidden_size = pos_embed_weight.shape[1]
+ device = pos_embed_weight.device
+
+ # Convert inputs to tensors if needed
+ if isinstance(lengths, list):
+ lengths = torch.tensor(lengths, device=device, dtype=torch.long)
+
+ # Prepare 2D position embedding
+ orig_size_sq = pos_embed_weight.shape[0]
+ orig_size = int(orig_size_sq**0.5)
+ pos_embed_2d = (
+ pos_embed_weight.view(orig_size, orig_size, hidden_size)
+ .permute(2, 0, 1)
+ .unsqueeze(0)
+ .to(device=device, dtype=torch.float32)
+ )
+
+ # Calculate target dimensions for each patch
+ target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to(
+ device=device, dtype=torch.float32
+ )
+ target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to(
+ device=device, dtype=torch.float32
+ )
+
+ # Normalize coordinates to [-1, 1] range for grid_sample
+ norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
+ norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
+
+ # Create sampling grid
+ grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
+
+ # Sample the 2D position embedding at each patch coordinate (interpolation mode set by `self.interpolated_method`)
+ interpolated_embed_fp32 = F.grid_sample(
+ pos_embed_2d, grid, mode=self.interpolated_method, align_corners=False, padding_mode="border"
+ )
+
+ # Reshape and convert back to original dtype
+ adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
+ adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)
+
+ # Add adapted position encoding to embeddings
+ embeddings = embeddings + adapted_pos_embed
+ return embeddings
+
+
+class GlmImageVisionBlock(GradientCheckpointingLayer):
+ def __init__(self, config: GlmImageVisionConfig) -> None:
+ super().__init__()
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.attn = GlmImageVisionAttention(config)
+ self.mlp = GlmImageVisionMLP(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ r"""
+ cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`):
+ The cumulative sequence lengths of each image or video feature.
+ """
+ residual = hidden_states
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.attn(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+
+ # Keep half or full tensor for later concatenation
+ rotary_dim = cos.shape[-1]
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+ # Apply rotary embeddings on the first half or full tensor
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
+
+ # Concatenate back to full shape
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
+ return q_embed, k_embed
+
+
+@use_kernelized_func(apply_rotary_pos_emb)
+class GlmImageTextAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: GlmImageTextConfig, layer_idx: int | None = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+ self.rope_parameters = config.rope_parameters
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: torch.Tensor | None,
+ past_key_values: Cache | None = None,
+ cache_position: torch.LongTensor | None = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape)
+ key_states = self.k_proj(hidden_states).view(hidden_shape)
+ value_states = self.v_proj(hidden_states).view(hidden_shape)
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class GlmImageRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ GlmImageRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class GlmImageTextMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+ self.activation_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ up_states = self.gate_up_proj(hidden_states)
+
+ gate, up_states = up_states.chunk(2, dim=-1)
+ up_states = up_states * self.activation_fn(gate)
+
+ return self.down_proj(up_states)
+
+
+class GlmImageTextDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: GlmImageTextConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = GlmImageTextAttention(config, layer_idx)
+ self.mlp = GlmImageTextMLP(config)
+ self.input_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_self_attn_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_mlp_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
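+ # GLM-style "sandwich" normalization: in addition to the usual pre-norms, the attention and MLP
+ # outputs are normalized again (post_self_attn / post_mlp) before being added to the residual stream.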
+
+ @auto_docstring
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ use_cache: bool | None = False,
+ cache_position: torch.LongTensor | None = None,
+ **kwargs,
+ ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = self.post_self_attn_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = self.post_mlp_layernorm(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+@auto_docstring
+class GlmImagePreTrainedModel(PreTrainedModel):
+ config: GlmImageConfig
+ base_model_prefix = "model"
+ input_modalities = ("image", "text")
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["GlmImageTextDecoderLayer", "GlmImageVisionBlock"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": GlmImageTextDecoderLayer,
+ "attentions": GlmImageTextAttention,
+ }
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ super()._init_weights(module)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for GlmImage outputs, with hidden states and attentions.
+ """
+)
+class GlmImageModelOutputWithPast(ModelOutput):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ last_hidden_state: torch.FloatTensor | None = None
+ past_key_values: Cache | None = None
+ hidden_states: tuple[torch.FloatTensor] | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+ rope_deltas: torch.LongTensor | None = None
+
+
+class GlmImageVQVAEVectorQuantizer(nn.Module):
+ """
+ A module for vector quantization using learned embedding vectors.
+
+ This module implements the quantization process similar to the one described in
+ the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
+ input vectors into discrete codebook vectors, which are learned during training.
+ Current implementation improves over previous ones by avoiding costly matrix multiplications
+ and allowing for post-hoc remapping of indices.
+ """
+
+ def __init__(self, config: GlmImageVQVAEConfig):
+ super().__init__()
+ self.num_embeddings = config.num_embeddings
+ self.embedding_dim = config.embed_dim
+ self.beta = getattr(config, "beta", 0.25)
+
+ self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
+
+ def forward(self, hidden_state: torch.Tensor):
+ hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+ hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)
+
+ # L2 normalize
+ hidden_state = F.normalize(hidden_state, p=2, dim=-1)
+ hidden_state_flattened = F.normalize(hidden_state_flattened, p=2, dim=-1)
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ distances = (
+ torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
+ + torch.sum(embedding**2, dim=1)
+ - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, embedding.transpose(0, 1))
+ )
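+ # Both the inputs and the codebook are L2-normalized above, so minimizing this squared distance is
+ # equivalent to selecting the codebook entry with the highest cosine similarity.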
+
+ min_encoding_indices = torch.argmin(distances, dim=1)
+ hidden_state_quant = embedding[min_encoding_indices].view(hidden_state.shape)
+
+ # compute loss for embedding
+ loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean(
+ (hidden_state_quant - hidden_state.detach()) ** 2
+ )
+
+ # preserve gradients
+ hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()
+
+ # reshape back to match original input shape
+ hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
+
+ return hidden_state_quant, loss, min_encoding_indices
+
+
+@auto_docstring(
+ custom_intro="""
+ The VQ-VAE model used in GlmImage for encoding/decoding images into discrete tokens.
+ This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
+ [Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
+ Taigman](https://huggingface.co/papers/2203.13131).
+ """
+)
+class GlmImageVQVAE(GlmImagePreTrainedModel):
+ config: GlmImageVQVAEConfig
+ _no_split_modules = [
+ "GlmImageVQVAEVectorQuantizer",
+ ]
+
+ def __init__(self, config: GlmImageVQVAEConfig):
+ super().__init__(config)
+ self.quantize = GlmImageVQVAEVectorQuantizer(config)
+ self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
+ self.eval() # GlmImage's VQ model is frozen
+ self.post_init()
+
+ def encode(self, hidden_states):
+ hidden_states = self.quant_conv(hidden_states)
+ quant, emb_loss, indices = self.quantize(hidden_states)
+ return quant, emb_loss, indices
+
+
+class GlmImageVisionModel(GlmImagePreTrainedModel):
+ config: GlmImageVisionConfig
+ input_modalities = ("image",)
+ _no_split_modules = ["GlmImageVisionBlock"]
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: GlmImageVisionConfig) -> None:
+ super().__init__(config)
+ self.spatial_merge_size = config.spatial_merge_size
+ self.patch_size = config.patch_size
+
+ self.embeddings = GlmImageVisionEmbeddings(config)
+ self.patch_embed = GlmImageVisionPatchEmbed(config)
+
+ head_dim = config.hidden_size // config.num_heads
+
+ self.blocks = nn.ModuleList([GlmImageVisionBlock(config) for _ in range(config.depth)])
+
+ self.gradient_checkpointing = False
+ self.head_dim = head_dim
+ self.post_init()
+
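+ # Returns the (h, w) index of every patch, ordered by spatial merge window; these indices are consumed
+ # by `GlmImageVisionEmbeddings` to interpolate the learned position embedding at each patch location.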
+ def rot_pos_emb(self, grid_thw):
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ return pos_ids
+
+ def forward(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Args:
+ pixel_values (`torch.Tensor` of shape `(total_patches, num_channels * patch_size * patch_size)`):
+ Packed pixel values.
+ grid_thw (`torch.Tensor` of shape `(num_images, 3)`):
+ The temporal, height and width of feature shape of each image.
+
+ Returns:
+ `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states.
+ """
+
+ hidden_states = self.patch_embed(pixel_values)
+ image_type_ids = self.rot_pos_emb(grid_thw)
+
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ hidden_states = self.embeddings(
+ hidden_states,
+ seqlens,
+ grid_thw,
+ image_type_ids[:, 0].to(hidden_states.device),
+ image_type_ids[:, 1].to(hidden_states.device),
+ )
+
+ # Transformer blocks (no position_embeddings needed, already added above)
+ for blk in self.blocks:
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ )
+ return hidden_states
+
+
+class GlmImageTextRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: GlmImageTextConfig, device=None):
+ super().__init__()
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+
+ self.rope_type = self.config.rope_parameters["rope_type"]
+ rope_init_fn: Callable = self.compute_default_rope_parameters
+ if self.rope_type != "default":
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
+
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+ self.mrope_section = config.rope_parameters.get("mrope_section", [8, 12, 12])
+
+ @staticmethod
+ def compute_default_rope_parameters(
+ config: GlmImageTextConfig | None = None,
+ device: Optional["torch.device"] = None,
+ seq_len: int | None = None,
+ ) -> tuple["torch.Tensor", float]:
+ """
+ Computes the inverse frequencies according to the original RoPE implementation
+ Args:
+ config ([`~transformers.PreTrainedConfig`]):
+ The model configuration.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+ """
+ base = config.rope_parameters["rope_theta"]
+ partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+ dim = int(head_dim * partial_rotary_factor)
+
+ attention_factor = 1.0 # Unused in this type of RoPE
+
+ # Compute the inverse frequencies
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
+ )
+ return inv_freq, attention_factor
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ # In contrast to other models, GLM-V has different position ids for the grids
+ # So we expand the inv_freq to shape (3, ...)
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ freqs = self.apply_mrope(freqs, self.mrope_section)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
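+ # `freqs` stacks the temporal/height/width frequencies along dim 0; `apply_mrope` splits the last
+ # dimension into `mrope_section`-sized chunks and takes each chunk from the t/h/w copy in turn,
+ # interleaving the three position grids into a single rotary table.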
+ def apply_mrope(self, freqs, mrope_section):
+ section = mrope_section
+ chunks = freqs.split(section, dim=-1)
+ result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
+ return result
+
+
+@auto_docstring
+class GlmImageTextModel(GlmImagePreTrainedModel):
+ config: GlmImageTextConfig
+ input_modalities = ("text",)
+
+ def __init__(self, config: GlmImageTextConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [GlmImageTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = GlmImageTextRotaryEmbedding(config=config)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ @check_model_inputs
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ use_cache: bool | None = None,
+ cache_position: torch.LongTensor | None = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple | BaseModelOutputWithPast:
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ # torch.jit.trace() doesn't support cache objects in the output
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
+ past_key_values = DynamicCache(config=self.config)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ # the hard coded `3` is for temporal, height and width.
+ if position_ids is None:
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+ elif position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+ # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
+ # where each dim indicates visual spatial positions for temporal/height/width grids.
+ # There are two scenarios when FA2-like packed masking might be activated.
+ # 1. User specifically passed packed `position_ids` and no attention mask.
+ # In this case we expect the user to create correct position ids for all 3 grids
+ # and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
+ # 2. User runs forward with no attention mask and no position ids. In this case, position ids
+ # are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
+ # prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
+ # text-only positions will cause incorrect mask construction, do not change `prepare_inputs_for_generation`
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
+ text_position_ids = position_ids[0]
+ position_ids = position_ids[1:]
+ else:
+ # If inputs are not packed (usual 3D positions), do not prepare mask from position_ids
+ text_position_ids = None
+
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": text_position_ids,
+ }
+ # Create the masks
+ causal_mask = create_causal_mask(**mask_kwargs)
+
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+
+ for decoder_layer in self.layers:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=text_position_ids,
+ past_key_values=past_key_values,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = layer_outputs
+
+ hidden_states = self.norm(hidden_states)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values,
+ )
+
+
+@auto_docstring
+class GlmImageModel(GlmImagePreTrainedModel):
+ base_model_prefix = "model"
+ _checkpoint_conversion_mapping = {}
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ config: GlmImageConfig
+ _no_split_modules = ["GlmImageTextDecoderLayer", "GlmImageVisionBlock"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.visual = GlmImageVisionModel._from_config(config.vision_config)
+ self.language_model = GlmImageTextModel._from_config(config.text_config)
+
+ self.rope_deltas = None # cache rope_deltas here
+ self.vqmodel = GlmImageVQVAE._from_config(config.vq_config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def get_rope_index(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ image_grid_thw: torch.LongTensor | None = None,
+ attention_mask: torch.LongTensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Calculate the 3D rope index for the image generation task.
+
+ Explanation:
+ Each embedding sequence may contain image tokens (for generation) and text tokens,
+ or just text tokens.
+
+ Input format:
+ - Text-to-Image: [text tokens] + <|dit_token_16384|>
+ - Image-to-Image: <|dit_token_16384|> [image tokens] <|dit_token_16385|> + [text tokens] + <|dit_token_16384|>
+
+ For pure text embedding sequence, the rotary position embedding is the same across all 3 dimensions.
+ Examples:
+ input_ids: [T T T T T], here T is for text.
+ temporal position_ids: [0, 1, 2, 3, 4]
+ height position_ids: [0, 1, 2, 3, 4]
+ width position_ids: [0, 1, 2, 3, 4]
+
+ For sequences with image tokens, we use special markers to denote image regions:
+ - <|dit_token_16384|>: image start marker
+ - <|dit_token_16385|>: image end marker
+ - Image tokens between these markers use 2D spatial position encoding.
+
+ For image tokens:
+ - temporal: stays constant at (image_start_pos + 1)
+ - height: increments every w tokens, representing row position
+ - width: cycles from 0 to w-1, representing column position
+
+ After each image region, the next position jumps to: image_start_pos + 1 + max(h, w)
+ This ensures sufficient positional separation between images and subsequent tokens.
+
+ Examples:
+ === Case 1: Image-to-Image Generation ===
+
+ Source image with grid [1, 3, 2], followed by text, then generation.
+ input_ids: [<|dit_token_16384|> V V V V V V <|dit_token_16385|> T T T T <|dit_token_16384|>]
+ image_grid_thw: [[1, 3, 2], [1, 4, 4]] # first is source, second is target
+
+ For source image (h=3, w=2, 6 tokens):
+ Start marker at position 0
+ Image tokens at temporal=1, height=[1,1,2,2,3,3], width=[1,2,1,2,1,2]
+ End marker at position 4 (= 0 + 1 + max(3,2))
+
+ Text tokens and trailing start marker continue from position 5.
+
+ Full prefill position_ids:
+ temporal: [0, 1,1,1,1,1,1, 4, 5,6,7,8, 9]
+ height: [0, 1,1,2,2,3,3, 4, 5,6,7,8, 9]
+ width: [0, 1,2,1,2,1,2, 4, 5,6,7,8, 9]
+
+ Decode stage: use image_grid_thw[-1] = [1, 4, 4] to build cached position_ids,
+ starting from gen_st_idx = 10.
+
+ === Case 2: Text-to-Image Generation (multi-resolution) ===
+
+ Pure text input with two image_grids for progressive generation.
+ input_ids: [hello3 33 2<|dit_token_16384|>]
+ Assume "hello3 33 2" = 4 tokens (positions 0-3)
+ <|dit_token_16384|> at position 4
+ image_grid_thw: [[1, 3, 3], [1, 3, 2]]
+ - image_grid_thw[-1] = [1, 3, 2]: first generated image (smaller/draft)
+ - image_grid_thw[-2] = [1, 3, 3]: second generated image (larger/final)
+
+ Prefill position_ids (5 tokens: 4 text + 1 start marker):
+ temporal: [0, 1, 2, 3, 4]
+ height: [0, 1, 2, 3, 4]
+ width: [0, 1, 2, 3, 4]
+
+ Decode stage builds position_ids in reverse order of image_grid_thw:
+
+ First: image_grid_thw[-1] = [1, 3, 2] (6 tokens), starting at position 5:
+ temporal: [5, 5, 5, 5, 5, 5]
+ height: [5, 5, 6, 6, 7, 7]
+ width: [5, 6, 5, 6, 5, 6]
+ next_pos = 5 + max(3, 2) = 8
+
+ Then: image_grid_thw[-2] = [1, 3, 3] (9 tokens), starting at position 8:
+ temporal: [8, 8, 8, 8, 8, 8, 8, 8, 8]
+ height: [8, 8, 8, 9, 9, 9, 10, 10, 10]
+ width: [8, 9, 10, 8, 9, 10, 8, 9, 10]
+ next_pos = 8 + max(3, 3) = 11
+
+ Finally: <|dit_token_16385|> end marker at position 11
+
+ Full sequence position_ids (prefill + decode):
+ temporal: [0,1,2,3, 4, 5,5,5,5,5,5, 8,8,8,8,8,8,8,8,8, 11]
+ height: [0,1,2,3, 4, 5,5,6,6,7,7, 8,8,8,9,9,9,10,10,10, 11]
+ width: [0,1,2,3, 4, 5,6,5,6,5,6, 8,9,10,8,9,10,8,9,10, 11]
+
+ _cached_decode_position_ids shape: [3, 6 + 9 + 1] = [3, 16]
+ (includes all generated image tokens + end marker)
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default
+ should you provide it.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image. For image generation,
+ temporal is typically 1.
+ - For image-to-image: includes source image grids + target image grid(s)
+ - For text-to-image with multi-resolution: includes multiple target grids,
+ processed in reverse order (last grid first, second-to-last grid second, etc.)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ Returns:
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`):
+ Position IDs for temporal, height, and width dimensions.
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`):
+ Position deltas for multi-modal rotary position embedding (zeros for this task).
+ """
+
+ batch_size, seq_len = input_ids.shape
+ device = input_ids.device
+ dtype = input_ids.dtype
+
+ image_start_token_id = self.config.image_start_token_id
+ image_end_token_id = self.config.image_end_token_id
+ num_complete_images = (input_ids == image_end_token_id).sum().item()
+
+ position_ids = torch.ones(
+ 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
+ )
+ text_positions = torch.arange(seq_len)[None, :].repeat(3, 1)
+ for batch_idx in range(batch_size):
+ curr_input_ids = input_ids[batch_idx]
+ if attention_mask is not None:
+ curr_input_ids = curr_input_ids[attention_mask[batch_idx] == 1]
+
+ image_end = torch.where(curr_input_ids == image_end_token_id)[0]
+ image_start = torch.where(curr_input_ids == image_start_token_id)[0] + 1
+ current_pos = 0 # track the current position value
+ prev_image_end = 0
+ curr_position_ids = []
+ for start, end, grid in zip(image_start, image_end, image_grid_thw):
+ _, num_width_grid, num_height_grid = grid
+
+ # Create text position ids first if there are text tokens before image
+ llm_pos_length = start - prev_image_end
+ llm_position_ids = text_positions[:, current_pos : current_pos + llm_pos_length].to(
+ device=input_ids.device
+ )
+ current_pos += llm_position_ids.shape[-1]
+
+ # Now create image position ids for each grid
+ image_seq_length = num_height_grid * num_width_grid
+ h_grids = image_seq_length // num_height_grid + current_pos
+ w_grids = image_seq_length // num_width_grid + current_pos
+ position_width = torch.arange(current_pos, w_grids, device=input_ids.device).repeat(num_width_grid)
+ position_height = torch.arange(current_pos, h_grids, device=input_ids.device).repeat_interleave(
+ num_height_grid
+ )
+ position_temporal = torch.full(
+ (image_seq_length,), current_pos, device=input_ids.device, dtype=torch.long
+ )
+ vision_position_ids = torch.stack([position_temporal, position_height, position_width], dim=0)
+ current_pos += max(num_height_grid, num_width_grid)
+
+ prev_image_end = end
+ curr_position_ids.append(torch.cat([llm_position_ids, vision_position_ids], dim=-1))
+
+ # Add position ids for the last text tokens if any
+ end_position = len(curr_input_ids) - prev_image_end
+ llm_position_ids = text_positions[:, current_pos : current_pos + end_position].to(device=input_ids.device)
+ current_pos += llm_position_ids.shape[-1]
+ curr_position_ids.append(llm_position_ids)
+ curr_position_ids = torch.cat(curr_position_ids, dim=-1)
+ if attention_mask is not None:
+ position_ids[:, batch_idx, attention_mask[batch_idx] == 1] = curr_position_ids.to(position_ids.device)
+ else:
+ position_ids[:, batch_idx, :] = curr_position_ids.to(position_ids.device)
+
+ # Build and store position ids for tokens that will be generated. Later we will just
+ # slice these instead of computing each decoding step
+ self._prefill_len = seq_len
+ if image_grid_thw is not None and len(image_grid_thw) > 0:
+ num_decode_grids = len(image_grid_thw) - num_complete_images
+ num_decode_grids = max(num_decode_grids, 0)
+ decode_pos = current_pos
+
+ decode_temporal_list = []
+ decode_height_list = []
+ decode_width_list = []
+
+ for i in range(1, num_decode_grids + 1):
+ grid_idx = -i
+ h = image_grid_thw[grid_idx, 1].item()
+ w = image_grid_thw[grid_idx, 2].item()
+ total_tokens = h * w
+
+ h_indices = torch.arange(h, device=device).unsqueeze(1).expand(h, w).flatten()
+ w_indices = torch.arange(w, device=device).unsqueeze(0).expand(h, w).flatten()
+
+ decode_temporal_list.append(torch.full((total_tokens,), decode_pos, device=device, dtype=torch.long))
+ decode_height_list.append(decode_pos + h_indices)
+ decode_width_list.append(decode_pos + w_indices)
+ decode_pos = decode_pos + max(h, w)
+
+ decode_temporal_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
+ decode_height_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
+ decode_width_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
+
+ self._cached_decode_position_ids = torch.stack(
+ [
+ torch.cat(decode_temporal_list, dim=0),
+ torch.cat(decode_height_list, dim=0),
+ torch.cat(decode_width_list, dim=0),
+ ],
+ dim=0,
+ )
+ else:
+ self._cached_decode_position_ids = None
+
+ mrope_position_deltas = torch.zeros([batch_size, 1], dtype=dtype, device=device)
+
+ return position_ids, mrope_position_deltas
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ pixel_values = pixel_values.type(self.visual.dtype)
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
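+ # Each image contributes prod(t, h, w) // spatial_merge_size**2 merged patches to the packed
+ # output, so the flat embedding tensor is split back into one chunk per image below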
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
+ image_embeds = torch.split(image_embeds, split_sizes)
+ return image_embeds
+
+ def get_placeholder_mask(
+ self,
+ input_ids: torch.LongTensor,
+ image_ids: torch.LongTensor,
+ ):
+ """
+ Get the mask of image placeholder tokens in `input_ids`; these positions are later replaced with image token ids from the VQVAE.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, seq_len)`):
+ Input token ids with image placeholders.
+ image_ids (`torch.LongTensor` of shape `(num_images, num_tokens_per_image)` or flattened):
+ Discrete token indices from the VQVAE codebook.
+
+ Returns:
+ special_image_mask (`torch.BoolTensor` of shape `(batch_size, seq_len)`):
+ Mask indicating positions in input ids that will be replaced by actual image tokens.
+ """
+
+ special_image_mask = input_ids == self.config.image_token_id
+ n_placeholder_tokens = special_image_mask.sum().item()
+ n_image_tokens = image_ids.shape[0]
+
+ if n_placeholder_tokens != n_image_tokens:
+ raise ValueError(
+ f"Number of image placeholder tokens ({n_placeholder_tokens}) does not match "
+ f"number of image tokens from VQVAE ({n_image_tokens})"
+ )
+
+ return special_image_mask
+
+ @auto_docstring
+ @can_return_tuple
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ pixel_values: torch.Tensor | None = None,
+ image_grid_thw: torch.LongTensor | None = None,
+ rope_deltas: torch.LongTensor | None = None,
+ cache_position: torch.LongTensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | GlmImageModelOutputWithPast:
+ r"""
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if pixel_values is not None:
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw[:-1])
+ image_embeds = torch.cat(image_embeds, dim=0)
+ image_ids = self.get_image_tokens(image_embeds, image_grid_thw[:-1])
+ image_ids = image_ids.view(-1).to(input_ids.device)
+ special_image_mask = self.get_placeholder_mask(input_ids, image_ids)
+ input_ids = input_ids.masked_scatter(special_image_mask, image_ids)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if position_ids is None:
+ attention_mask_2d = attention_mask
+ if attention_mask is not None and attention_mask.ndim == 4:
+ attention_mask_2d = torch.diagonal(attention_mask[:, 0], dim1=1, dim2=2)
+ # Only apply conversion for floating point tensors (inverted masks)
+ if attention_mask_2d.dtype.is_floating_point:
+ attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
+ attention_mask_2d = (1.0 - attention_mask_2d).int()
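+ # e.g. (illustrative): a padded row whose 4D additive-mask diagonal is [0, 0, min, min]
+ # maps back to the 2D mask [1, 1, 0, 0] expected by `get_rope_index`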
+
+ # Calculate RoPE index once per generation in the pre-fill stage only.
+ # It is safe to assume that `length!=1` means we're in pre-fill because the
+ # model is used only by DiT pipeline without assisted decoding, etc. techniques
+ is_prefill_stage = (input_ids is not None and input_ids.shape[1] != 1) or (
+ inputs_embeds is not None and inputs_embeds.shape[1] != 1
+ )
+ if is_prefill_stage or self.rope_deltas is None:
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids,
+ image_grid_thw,
+ attention_mask=attention_mask_2d,
+ )
+ self.rope_deltas = rope_deltas
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
+ else:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ # Use prefill token length, not position value
+ step = cache_position[0].item() - self._prefill_len
+ # Direct lookup - no tensor creation overhead
+ position_ids = self._cached_decode_position_ids[:, step : step + seq_length]
+ position_ids = position_ids.unsqueeze(1).expand(-1, batch_size, -1)
+
+ outputs = self.language_model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return GlmImageModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+
+ def get_image_tokens(
+ self,
+ hidden_states: torch.FloatTensor,
+ image_grid_thw: torch.LongTensor,
+ ) -> torch.LongTensor:
+ """
+ Tokenizes image features into discrete tokens with VQVAE module.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(total_patches, hidden_size)`):
+ The packed image features from vision encoder.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
+ The temporal, height and width of feature shape of each image.
+
+ Returns:
+ image_tokens (`torch.LongTensor` of shape `(total_patches,)`):
+ Discrete token indices from the VQVAE codebook.
+ """
+ hidden_size = hidden_states.shape[-1]
+ split_sizes = (image_grid_thw.prod(dim=-1)).tolist()
+ hidden_states_list = torch.split(hidden_states, split_sizes, dim=0)
+
+ all_image_toks = []
+ for i, hs in enumerate(hidden_states_list):
+ grid_t, grid_h, grid_w = image_grid_thw[i].tolist()
+ hs = hs.view(grid_t, grid_h, grid_w, hidden_size)
+ hs = hs.permute(0, 3, 1, 2).contiguous()
+ _, _, image_toks = self.vqmodel.encode(hs)
+ all_image_toks.append(image_toks)
+ return torch.cat(all_image_toks, dim=0)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for GlmImage causal language model (or autoregressive) outputs.
+ """
+)
+class GlmImageCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ loss: torch.FloatTensor | None = None
+ logits: torch.FloatTensor | None = None
+ past_key_values: Cache | None = None
+ hidden_states: tuple[torch.FloatTensor] | None = None
+ attentions: tuple[torch.FloatTensor] | None = None
+ rope_deltas: torch.LongTensor | None = None
+
+
+class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {}
+ _tied_weights_keys = {}
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ base_model_prefix = "model"
+ config: GlmImageConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = GlmImageModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vision_vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None):
+ return self.model.get_image_features(pixel_values, image_grid_thw)
+
+ def get_image_tokens(self, hidden_states: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None):
+ return self.model.get_image_tokens(hidden_states, image_grid_thw)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ pixel_values: torch.Tensor | None = None,
+ image_grid_thw: torch.LongTensor | None = None,
+ cache_position: torch.LongTensor | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | GlmImageCausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, GlmImageForConditionalGeneration
+
+ >>> model = GlmImageForConditionalGeneration.from_pretrained("zai-org/GLM-Image")
+ >>> processor = AutoProcessor.from_pretrained("zai-org/GLM-Image")
+
+ >>> messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image"},
+ {"type": "text", "text": "Add a truck to this photo.28 40"},
+ ],
+ },
+ ]
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs, max_length=30)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+ ```"""
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+ return GlmImageCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=outputs.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ image_grid_thw=None,
+ is_first_iteration=False,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ is_first_iteration=is_first_iteration,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
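+ # GLM-Image builds its own 3D mrope positions inside the model's forward (`get_rope_index`
+ # during prefill, the cached decode positions afterwards), so positions prepared by the base
+ # `prepare_inputs_for_generation` are dropped here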
+ model_inputs["position_ids"] = None
+
+ if not is_first_iteration and use_cache:
+ model_inputs["pixel_values"] = None
+
+ return model_inputs
+
+ def _get_image_nums(
+ self,
+ input_ids: torch.LongTensor | None,
+ ) -> torch.Tensor:
+ """
+ Get the number of images for each sample.
+ For GLM-Image, the number of images can only be inferred from `input_ids`.
+
+ Returns:
+ image_counts (`torch.LongTensor` of shape `(batch_size,)`)
+ """
+ is_image = input_ids == self.config.image_start_token_id
+
+ return is_image.sum(dim=1)
+
+ def _expand_inputs_for_generation(
+ self,
+ expand_size: int = 1,
+ is_encoder_decoder: bool = False,
+ input_ids: torch.LongTensor | None = None,
+ **model_kwargs,
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
+ # Overwritten -- Support for expanding tensors without a batch size dimension
+ # e.g., pixel_values, image_grid_thw
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
+ # image_grid_thw.shape[0] is sum(num_images for samples)
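+ # e.g. (illustrative): two samples with 1 and 2 source images contribute 3 source rows to
+ # image_grid_thw (followed by the grid(s) of the image being generated), while pixel_values
+ # stacks all of their patches along dim 0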
+
+ if expand_size == 1:
+ return input_ids, model_kwargs
+
+ visual_keys = ["pixel_values", "image_grid_thw"]
+
+ def _expand_dict_for_generation_visual(dict_to_expand):
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
+ image_nums = self._get_image_nums(input_ids)
+
+ def _repeat_interleave_samples(x, lengths, repeat_times):
+ samples = torch.split(x, lengths)
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+ return result
+
+ for key in dict_to_expand:
+ if key == "pixel_values":
+ # split images into samples
+ samples = torch.split(image_grid_thw[: sum(image_nums)], list(image_nums))
+ # compute the sequence length of images for each sample
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "image_grid_thw":
+ # get the num of images for each sample and +1 for the image being generated
+ lengths = list(image_nums)
+ last_image = dict_to_expand[key][:-1]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key][: sum(image_nums)], lengths=lengths, repeat_times=expand_size
+ )
+ dict_to_expand[key] = torch.cat([dict_to_expand[key], last_image], dim=0)
+ return dict_to_expand
+
+ def _expand_dict_for_generation(dict_to_expand):
+ for key in dict_to_expand:
+ if (
+ key != "cache_position"
+ and dict_to_expand[key] is not None
+ and isinstance(dict_to_expand[key], torch.Tensor)
+ and key not in visual_keys
+ ):
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+ return dict_to_expand
+
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+ if input_ids is not None:
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+ if is_encoder_decoder:
+ if model_kwargs.get("encoder_outputs") is None:
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+ return input_ids, model_kwargs
+
+
+__all__ = [
+ "GlmImagePreTrainedModel",
+ "GlmImageVQVAE",
+ "GlmImageVisionModel",
+ "GlmImageTextModel",
+ "GlmImageModel",
+ "GlmImageForConditionalGeneration",
+]
diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py
new file mode 100644
index 000000000000..ac7d1c33b92d
--- /dev/null
+++ b/src/transformers/models/glm_image/modular_glm_image.py
@@ -0,0 +1,1480 @@
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections.abc import Callable
+from typing import Any
+
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...cache_utils import Cache
+from ...configuration_utils import PreTrainedConfig
+from ...feature_extraction_utils import BatchFeature
+from ...generation import GenerationMixin
+from ...image_utils import ImageInput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import ImagesKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import TransformersKwargs, is_torch_available, logging
+from ..chameleon.modeling_chameleon import ChameleonVQVAE, ChameleonVQVAEVectorQuantizer
+from ..glm4v.configuration_glm4v import Glm4vTextConfig, Glm4vVisionConfig
+from ..glm4v.modeling_glm4v import (
+ Glm4vCausalLMOutputWithPast,
+ Glm4vModel,
+ Glm4vModelOutputWithPast,
+ Glm4vPreTrainedModel,
+ Glm4vTextModel,
+ Glm4vVisionAttention,
+ Glm4vVisionBlock,
+ Glm4vVisionEmbeddings,
+ Glm4vVisionModel,
+ Glm4vVisionPatchEmbed,
+)
+from ..glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextAttention, eager_attention_forward
+from ..qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor
+from ..qwen2_vl.image_processing_qwen2_vl_fast import Qwen2VLImageProcessorFast
+from ..qwen2_vl.processing_qwen2_vl import Qwen2VLProcessorKwargs
+from ..siglip.modeling_siglip import SiglipMLP
+
+
+if is_torch_available():
+ import torch
+
+logger = logging.get_logger(__name__)
+
+
+class GlmImageVQVAEConfig(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GlmImageVQModel`]. It is used to instantiate a
+ `GlmImageVQModel` according to the specified arguments, defining the model architecture.
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information. Instantiating a
+ configuration with the defaults will yield a similar configuration to the VQModel of the
+ [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image) architecture.
+
+ Args:
+ embed_dim (`int`, *optional*, defaults to 2048):
+ Dimensionality of each embedding vector.
+ num_embeddings (`int`, *optional*, defaults to 16384):
+ Number of codebook embeddings.
+ latent_channels (`int`, *optional*, defaults to 1536):
+ Number of channels for the latent space.
+ in_channels (`int`, *optional*, defaults to 3):
+ Number of input channels.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
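+
+ Example (a minimal sketch, assuming the class is exposed by the generated `configuration_glm_image` module):
+
+ ```python
+ >>> from transformers.models.glm_image.configuration_glm_image import GlmImageVQVAEConfig
+
+ >>> # Initializing a GLM-Image VQ-VAE configuration with the default values
+ >>> configuration = GlmImageVQVAEConfig()
+
+ >>> # Codebook size and embedding width
+ >>> configuration.num_embeddings, configuration.embed_dim
+ (16384, 2048)
+ ```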
+ """
+
+ model_type = "glm_image_vqmodel"
+ base_config_key = "vq_config"
+
+ def __init__(
+ self,
+ embed_dim: int = 2048,
+ num_embeddings: int = 16384,
+ latent_channels: int = 1536,
+ in_channels: int = 3,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+ self.num_embeddings = num_embeddings
+ self.latent_channels = latent_channels
+ self.in_channels = in_channels
+ self.initializer_range = initializer_range
+
+
+class GlmImageVisionConfig(Glm4vVisionConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GlmImageVisionModel`]. It is used to instantiate a GlmImageVisionModel
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
+ a similar configuration to that of
+ GLM-Image [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image).
+
+ Args:
+ depth (`int`, *optional*, defaults to 40):
+ Number of layers (depth) in the model.
+ hidden_size (`int`, *optional*, defaults to 1536):
+ Dimensionality of the encoder layers and the pooler layer.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ Dropout probability for attention weights.
+ num_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer architecture.
+ in_channels (`int`, *optional*, defaults to 3):
+ Number of input channels.
+ image_size (`int` or `list[int]`, *optional*, defaults to 2048):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ spatial_merge_size (`int`, *optional*, defaults to 1):
+ The size used for merging spatial dimensions.
+ intermediate_size (`int`, *optional*, defaults to 6144):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ """
+
+ model_type = "glm_image_vision"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ depth=40,
+ hidden_size=1536,
+ hidden_act="gelu",
+ attention_bias=True,
+ attention_dropout=0.0,
+ num_heads=16,
+ in_channels=3,
+ image_size=2048,
+ patch_size=16,
+ layer_norm_eps=1e-06,
+ spatial_merge_size=1,
+ intermediate_size=6144,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ del self.out_hidden_size
+ del self.rms_norm_eps
+ del self.temporal_patch_size
+ self.layer_norm_eps = layer_norm_eps
+
+
+class GlmImageTextConfig(Glm4vTextConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GlmImageTextModel`]. It is used to instantiate a
+ GLM-Image model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of
+ GLM-Image [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image).
+
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 168064):
+ Vocabulary size of the GlmImage model. Defines the number of different tokens that can be represented by the
+ `input_ids` passed when calling [`GlmImageModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 13696):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 40):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 2):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ rope_parameters (`RopeParameters`, *optional*):
+ Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
+ a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
+ with longer `max_position_embeddings`.
+ vision_vocab_size (`int`, *optional*, defaults to 16512):
+ Vision vocabulary size of the GlmImage model. Defines the number of different tokens that can be represented
+ by the `input_ids` passed when calling [`GlmImageVisionModel`]
+ attention_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+
+ ```python
+ >>> from transformers import GlmImageTextModel, GlmImageConfig
+
+ >>> # Initializing a GlmImageConfig style configuration
+ >>> configuration = GlmImageConfig()
+
+ >>> # Initializing a model from the GlmImageConfig style configuration
+ >>> model = GlmImageTextModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ def __init__(
+ self,
+ vocab_size: int | None = 168064,
+ vision_vocab_size: int | None = 16512,
+ attention_bias: bool | None = True,
+ tie_word_embeddings: bool | None = False,
+ **super_kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.vision_vocab_size = vision_vocab_size
+ self.attention_bias = attention_bias
+ super().__init__(
+ tie_word_embeddings=tie_word_embeddings, ignore_keys_at_rope_validation={"mrope_section"}, **super_kwargs
+ )
+
+
+class GlmImageConfig(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GlmImageModel`]. It is used to instantiate a
+ GLM-Image model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of
+ GLM-Image [zai-org/GLM-Image](https://huggingface.co/zai-org/GLM-Image) architecture.
+
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PreTrainedConfig`] for more information.
+
+ Args:
+ text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `GlmImageTextConfig`):
+ The config object or dictionary of the text backbone.
+ vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `GlmImageVisionConfig`):
+ The config object or dictionary of the vision backbone.
+ vq_config (`Union[Dict, GlmImageVQVAEConfig]`, *optional*):
+ GlmImageVQVAEConfig instance containing the configuration for the VQ-VAE model.
+ image_token_id (`int`, *optional*, defaults to 167855):
+ The image token index to encode the image prompt.
+ image_start_token_id (`int`, *optional*, defaults to 16384):
+ The image start token index to encode the start of image.
+ image_end_token_id (`int`, *optional*, defaults to 16385):
+ The image end token index to encode the end of image.
+
+ ```python
+ >>> from transformers import GlmImageForConditionalGeneration, GlmImageConfig
+
+ >>> # Initializing a GLM-Image style configuration
+ >>> configuration = GlmImageConfig()
+
+ >>> # Initializing a model from the GLM-Image style configuration
+ >>> model = GlmImageForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "glm_image"
+ sub_configs = {
+ "vision_config": GlmImageVisionConfig,
+ "text_config": GlmImageTextConfig,
+ "vq_config": GlmImageVQVAEConfig,
+ }
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ text_config=None,
+ vision_config=None,
+ vq_config=None,
+ image_token_id=167855,
+ image_start_token_id=16384,
+ image_end_token_id=16385,
+ **kwargs,
+ ):
+ if isinstance(vision_config, dict):
+ vision_config = self.sub_configs["vision_config"](**vision_config)
+ elif vision_config is None:
+ vision_config = self.sub_configs["vision_config"](**kwargs)
+
+ if isinstance(vq_config, dict):
+ vq_config = self.sub_configs["vq_config"](**vq_config)
+ elif vq_config is None:
+ vq_config = self.sub_configs["vq_config"](**kwargs)
+
+ if isinstance(text_config, dict):
+ text_config = self.sub_configs["text_config"](**text_config)
+ elif text_config is None:
+ text_config = self.sub_configs["text_config"](**kwargs)
+
+ self.image_token_id = image_token_id
+ self.image_start_token_id = image_start_token_id
+ self.image_end_token_id = image_end_token_id
+ self.text_config = text_config
+ self.vision_config = vision_config
+ self.vq_config = vq_config
+ super().__init__(**kwargs)
+
+
+class GlmImageVisionMLP(SiglipMLP):
+ pass
+
+
+class GlmImageVisionAttention(Glm4vVisionAttention):
+ def __init__(self, config: GlmImageVisionConfig) -> None:
+ super().__init__(config)
+ self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
+ self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ **kwargs,
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ query_states, key_states, value_states = (
+ self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ )
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ if "flash" in self.config._attn_implementation:
+ # Flash Attention: Use cu_seqlens for variable length attention
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+ attn_output, _ = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ cu_seq_lens_q=cu_seqlens,
+ cu_seq_lens_k=cu_seqlens,
+ max_length_q=max_seqlen,
+ max_length_k=max_seqlen,
+ is_causal=False,
+ **kwargs,
+ )
+ else:
+ # Other implementations: Process each chunk separately
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+ splits = [
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+ ]
+
+ attn_outputs = [
+ attention_interface(
+ self,
+ q,
+ k,
+ v,
+ attention_mask=None,
+ scaling=self.scaling,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ is_causal=False,
+ **kwargs,
+ )[0]
+ for q, k, v in zip(*splits)
+ ]
+ attn_output = torch.cat(attn_outputs, dim=1)
+
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class GlmImageVisionPatchEmbed(Glm4vVisionPatchEmbed):
+ def __init__(self, config: GlmImageVisionConfig) -> None:
+ super().__init__(config)
+
+ del self.temporal_patch_size
+ kernel_size = [self.patch_size, self.patch_size]
+ self.proj = nn.Conv2d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size)
+
+ def forward(self, hidden_states):
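+ # Pixel values arrive packed as (total_patches, in_channels * patch_size * patch_size)
+ # and leave as (total_patches, embed_dim) after the per-patch projection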
+ target_dtype = self.proj.weight.dtype
+ hidden_states = hidden_states.view(-1, self.in_channels, self.patch_size, self.patch_size)
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class GlmImageVisionEmbeddings(Glm4vVisionEmbeddings):
+ def __init__(self, config: GlmImageVisionConfig) -> None:
+ super().__init__(config)
+ self.interpolated_method = "bilinear"
+
+
+class GlmImageVisionBlock(Glm4vVisionBlock):
+ def __init__(self, config: GlmImageVisionConfig):
+ super().__init__(config)
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.attn = GlmImageVisionAttention(config)
+ self.mlp = GlmImageVisionMLP(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ r"""
+ cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`):
+ The cumulative sequence lengths of each image or video feature.
+ position_embeddings (`tuple(torch.Tensor, torch.Tensor)` of shape `(num_patches, head_dim // 2)`):
+ The cosine and sine position embeddings for vision attention.
+ """
+ residual = hidden_states
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.attn(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class GlmImageTextAttention(Glm4vMoeTextAttention):
+ pass
+
+
+class GlmImagePreTrainedModel(Glm4vPreTrainedModel):
+ config: GlmImageConfig
+ input_modalities = ("image", "text")
+
+ @torch.no_grad()
+ def _init_weights(self, module):
+ PreTrainedModel._init_weights(self, module)
+
+
+class GlmImageModelOutputWithPast(Glm4vModelOutputWithPast):
+ pass
+
+
+class GlmImageVQVAEVectorQuantizer(ChameleonVQVAEVectorQuantizer):
+ def __init__(self, config: GlmImageVQVAEConfig):
+ super().__init__(config)
+ self.num_embeddings = config.num_embeddings
+ self.embedding_dim = config.embed_dim
+ self.beta = getattr(config, "beta", 0.25)
+
+ self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
+
+ def forward(self, hidden_state: torch.Tensor):
+ hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+ hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)
+
+ # L2 normalize
+ hidden_state = F.normalize(hidden_state, p=2, dim=-1)
+ hidden_state_flattened = F.normalize(hidden_state_flattened, p=2, dim=-1)
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ distances = (
+ torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
+ + torch.sum(embedding**2, dim=1)
+ - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, embedding.transpose(0, 1))
+ )
+
+ min_encoding_indices = torch.argmin(distances, dim=1)
+ hidden_state_quant = embedding[min_encoding_indices].view(hidden_state.shape)
+
+ # compute loss for embedding
+ loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean(
+ (hidden_state_quant - hidden_state.detach()) ** 2
+ )
+
+ # preserve gradients
+ hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()
+
+ # reshape back to match original input shape
+ hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
+
+ return hidden_state_quant, loss, min_encoding_indices
+
+
+class GlmImageVQVAE(ChameleonVQVAE):
+ _no_split_modules = [
+ "GlmImageVQVAEVectorQuantizer",
+ ]
+
+ def __init__(self, config: GlmImageVQVAEConfig):
+ super().__init__(config)
+ del self.encoder
+
+ def encode(self, hidden_states):
+ hidden_states = self.quant_conv(hidden_states)
+ quant, emb_loss, indices = self.quantize(hidden_states)
+ return quant, emb_loss, indices
+
+
+class GlmImageVisionModel(Glm4vVisionModel):
+ config: GlmImageVisionConfig
+ main_input_name = "pixel_values"
+ input_modalities = ("image",)
+
+ def __init__(self, config: GlmImageVisionConfig):
+ super().__init__(config)
+
+ head_dim = config.hidden_size // config.num_heads
+ self.head_dim = head_dim
+
+ del self.merger
+ del self.rotary_pos_emb
+ del self.post_conv_layernorm
+ del self.downsample
+ del self.post_layernorm
+
+ def rot_pos_emb(self, grid_thw):
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ return pos_ids
+
+ def forward(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Args:
+ pixel_values (`torch.Tensor` of shape `(total_patches, num_channels * patch_size * patch_size)`):
+ Packed pixel values.
+ grid_thw (`torch.Tensor` of shape `(num_images, 3)`):
+ The temporal, height and width of feature shape of each image.
+
+ Returns:
+ `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states.
+ """
+
+ hidden_states = self.patch_embed(pixel_values)
+ image_type_ids = self.rot_pos_emb(grid_thw)
+
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
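+ # e.g. (illustrative): grid_thw [[1, 3, 2], [1, 4, 4]] -> per-image lengths [6, 16] -> cu_seqlens [0, 6, 22]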
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ hidden_states = self.embeddings(
+ hidden_states,
+ seqlens,
+ grid_thw,
+ image_type_ids[:, 0].to(hidden_states.device),
+ image_type_ids[:, 1].to(hidden_states.device),
+ )
+
+ # Transformer blocks (no position_embeddings needed, already added above)
+ for blk in self.blocks:
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ )
+ return hidden_states
+
+
+class GlmImageTextModel(Glm4vTextModel):
+ pass
+
+
+class GlmImageModel(Glm4vModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.visual = GlmImageVisionModel._from_config(config.vision_config)
+ self.language_model = GlmImageTextModel._from_config(config.text_config)
+ self.vqmodel = GlmImageVQVAE._from_config(config.vq_config)
+
+ self.rope_deltas = None # cache rope_deltas here
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_rope_index(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ image_grid_thw: torch.LongTensor | None = None,
+ attention_mask: torch.LongTensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Calculate the 3D rope index for image generation task.
+
+ Explanation:
+ Each embedding sequence may contain image tokens (for generation) and text tokens,
+ or just text tokens.
+
+ Input format:
+ - Text-to-Image: [text tokens] + <|dit_token_16384|>
+ - Image-to-Image: <|dit_token_16384|> [image tokens] <|dit_token_16385|> + [text tokens] + <|dit_token_16384|>
+
+ For pure text embedding sequence, the rotary position embedding is the same across all 3 dimensions.
+ Examples:
+ input_ids: [T T T T T], here T is for text.
+ temporal position_ids: [0, 1, 2, 3, 4]
+ height position_ids: [0, 1, 2, 3, 4]
+ width position_ids: [0, 1, 2, 3, 4]
+
+ For sequences with image tokens, we use special markers to denote image regions:
+ - <|dit_token_16384|>: image start marker
+ - <|dit_token_16385|>: image end marker
+ - Image tokens between these markers use 2D spatial position encoding.
+
+ For image tokens:
+ - temporal: stays constant at (image_start_pos + 1)
+ - height: increments every w tokens, representing row position
+ - width: cycles from 0 to w-1, representing column position
+
+ After each image region, the next position jumps to: image_start_pos + 1 + max(h, w)
+ This ensures sufficient positional separation between images and subsequent tokens.
+
+ Examples:
+ === Case 1: Image-to-Image Generation ===
+
+ Source image with grid [1, 3, 2], followed by text, then generation.
+ input_ids: [<|dit_token_16384|> V V V V V V <|dit_token_16385|> T T T T <|dit_token_16384|>]
+ image_grid_thw: [[1, 3, 2], [1, 4, 4]] # first is source, second is target
+
+ For source image (h=3, w=2, 6 tokens):
+ Start marker at position 0
+ Image tokens at temporal=1, height=[1,1,2,2,3,3], width=[1,2,1,2,1,2]
+ End marker at position 4 (= 0 + 1 + max(3,2))
+
+ Text tokens and trailing start marker continue from position 5.
+
+ Full prefill position_ids:
+ temporal: [0, 1,1,1,1,1,1, 4, 5,6,7,8, 9]
+ height: [0, 1,1,2,2,3,3, 4, 5,6,7,8, 9]
+ width: [0, 1,2,1,2,1,2, 4, 5,6,7,8, 9]
+
+ Decode stage: use image_grid_thw[-1] = [1, 4, 4] to build cached position_ids,
+ starting from gen_st_idx = 10.
+
+ === Case 2: Text-to-Image Generation (multi-resolution) ===
+
+ Pure text input with two image_grids for progressive generation.
+ input_ids: [hello3 33 2<|dit_token_16384|>]
+ Assume "hello3 33 2" = 4 tokens (positions 0-3)
+ <|dit_token_16384|> at position 4
+ image_grid_thw: [[1, 3, 3], [1, 3, 2]]
+ - image_grid_thw[-1] = [1, 3, 2]: first generated image (smaller/draft)
+ - image_grid_thw[-2] = [1, 3, 3]: second generated image (larger/final)
+
+ Prefill position_ids (5 tokens: 4 text + 1 start marker):
+ temporal: [0, 1, 2, 3, 4]
+ height: [0, 1, 2, 3, 4]
+ width: [0, 1, 2, 3, 4]
+
+ Decode stage builds position_ids in reverse order of image_grid_thw:
+
+ First: image_grid_thw[-1] = [1, 3, 2] (6 tokens), starting at position 5:
+ temporal: [5, 5, 5, 5, 5, 5]
+ height: [5, 5, 6, 6, 7, 7]
+ width: [5, 6, 5, 6, 5, 6]
+ next_pos = 5 + max(3, 2) = 8
+
+ Then: image_grid_thw[-2] = [1, 3, 3] (9 tokens), starting at position 8:
+ temporal: [8, 8, 8, 8, 8, 8, 8, 8, 8]
+ height: [8, 8, 8, 9, 9, 9, 10, 10, 10]
+ width: [8, 9, 10, 8, 9, 10, 8, 9, 10]
+ next_pos = 8 + max(3, 3) = 11
+
+ Finally: <|dit_token_16385|> end marker at position 11
+
+ Full sequence position_ids (prefill + decode):
+ temporal: [0,1,2,3, 4, 5,5,5,5,5,5, 8,8,8,8,8,8,8,8,8, 11]
+ height: [0,1,2,3, 4, 5,5,6,6,7,7, 8,8,8,9,9,9,10,10,10, 11]
+ width: [0,1,2,3, 4, 5,6,5,6,5,6, 8,9,10,8,9,10,8,9,10, 11]
+
+ _cached_decode_position_ids shape: [3, 6 + 9 + 1] = [3, 16]
+ (includes all generated image tokens + end marker)
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default
+ should you provide it.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image. For image generation,
+ temporal is typically 1.
+ - For image-to-image: includes source image grids + target image grid(s)
+ - For text-to-image with multi-resolution: includes multiple target grids,
+ processed in reverse order (last grid first, second-to-last grid second, etc.)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ Returns:
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`):
+ Position IDs for temporal, height, and width dimensions.
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`):
+ Position deltas for multi-modal rotary position embedding (zeros for this task).
+ """
+
+ batch_size, seq_len = input_ids.shape
+ device = input_ids.device
+ dtype = input_ids.dtype
+
+ image_start_token_id = self.config.image_start_token_id
+ image_end_token_id = self.config.image_end_token_id
+ num_complete_images = (input_ids == image_end_token_id).sum().item()
+
+ position_ids = torch.ones(
+ 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
+ )
+ text_positions = torch.arange(seq_len)[None, :].repeat(3, 1)
+ for batch_idx in range(batch_size):
+ curr_input_ids = input_ids[batch_idx]
+ if attention_mask is not None:
+ curr_input_ids = curr_input_ids[attention_mask[batch_idx] == 1]
+
+ image_end = torch.where(curr_input_ids == image_end_token_id)[0]
+ image_start = torch.where(curr_input_ids == image_start_token_id)[0] + 1
+ current_pos = 0 # track the current position value
+ prev_image_end = 0
+ curr_position_ids = []
+ for start, end, grid in zip(image_start, image_end, image_grid_thw):
+ _, num_width_grid, num_height_grid = grid
+
+ # Create text position ids first if there are text tokens before image
+ llm_pos_length = start - prev_image_end
+ llm_position_ids = text_positions[:, current_pos : current_pos + llm_pos_length].to(
+ device=input_ids.device
+ )
+ current_pos += llm_position_ids.shape[-1]
+
+ # Now create image position ids for each grid
+ image_seq_length = num_height_grid * num_width_grid
+ h_grids = image_seq_length // num_height_grid + current_pos
+ w_grids = image_seq_length // num_width_grid + current_pos
+ position_width = torch.arange(current_pos, w_grids, device=input_ids.device).repeat(num_width_grid)
+ position_height = torch.arange(current_pos, h_grids, device=input_ids.device).repeat_interleave(
+ num_height_grid
+ )
+ position_temporal = torch.full(
+ (image_seq_length,), current_pos, device=input_ids.device, dtype=torch.long
+ )
+ vision_position_ids = torch.stack([position_temporal, position_height, position_width], dim=0)
+ current_pos += max(num_height_grid, num_width_grid)
+
+ prev_image_end = end
+ curr_position_ids.append(torch.cat([llm_position_ids, vision_position_ids], dim=-1))
+
+ # Add position ids for the last text tokens if any
+ end_position = len(curr_input_ids) - prev_image_end
+ llm_position_ids = text_positions[:, current_pos : current_pos + end_position].to(device=input_ids.device)
+ current_pos += llm_position_ids.shape[-1]
+ curr_position_ids.append(llm_position_ids)
+ curr_position_ids = torch.cat(curr_position_ids, dim=-1)
+ if attention_mask is not None:
+ position_ids[:, batch_idx, attention_mask[batch_idx] == 1] = curr_position_ids.to(position_ids.device)
+ else:
+ position_ids[:, batch_idx, :] = curr_position_ids.to(position_ids.device)
+
+ # Build and store position ids for tokens that will be generated. Later we will just
+ # slice these instead of computing each decoding step
+ self._prefill_len = seq_len
+ if image_grid_thw is not None and len(image_grid_thw) > 0:
+ num_decode_grids = len(image_grid_thw) - num_complete_images
+ num_decode_grids = max(num_decode_grids, 0)
+ decode_pos = current_pos
+
+ decode_temporal_list = []
+ decode_height_list = []
+ decode_width_list = []
+
+ for i in range(1, num_decode_grids + 1):
+ grid_idx = -i
+ h = image_grid_thw[grid_idx, 1].item()
+ w = image_grid_thw[grid_idx, 2].item()
+ total_tokens = h * w
+
+ h_indices = torch.arange(h, device=device).unsqueeze(1).expand(h, w).flatten()
+ w_indices = torch.arange(w, device=device).unsqueeze(0).expand(h, w).flatten()
+
+ decode_temporal_list.append(torch.full((total_tokens,), decode_pos, device=device, dtype=torch.long))
+ decode_height_list.append(decode_pos + h_indices)
+ decode_width_list.append(decode_pos + w_indices)
+ decode_pos = decode_pos + max(h, w)
+
+ decode_temporal_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
+ decode_height_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
+ decode_width_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
+
+ self._cached_decode_position_ids = torch.stack(
+ [
+ torch.cat(decode_temporal_list, dim=0),
+ torch.cat(decode_height_list, dim=0),
+ torch.cat(decode_width_list, dim=0),
+ ],
+ dim=0,
+ )
+ else:
+ self._cached_decode_position_ids = None
+
+ mrope_position_deltas = torch.zeros([batch_size, 1], dtype=dtype, device=device)
+
+ return position_ids, mrope_position_deltas
+
+ def get_image_tokens(
+ self,
+ hidden_states: torch.FloatTensor,
+ image_grid_thw: torch.LongTensor,
+ ) -> torch.LongTensor:
+ """
+        Tokenizes image features into discrete tokens with the VQVAE module.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(total_patches, hidden_size)`):
+ The packed image features from vision encoder.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
+ The temporal, height and width of feature shape of each image.
+
+ Returns:
+ image_tokens (`torch.LongTensor` of shape `(total_patches,)`):
+ Discrete token indices from the VQVAE codebook.
+ """
+ hidden_size = hidden_states.shape[-1]
+ split_sizes = (image_grid_thw.prod(dim=-1)).tolist()
+ hidden_states_list = torch.split(hidden_states, split_sizes, dim=0)
+
+ all_image_toks = []
+ for i, hs in enumerate(hidden_states_list):
+ grid_t, grid_h, grid_w = image_grid_thw[i].tolist()
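+            # reshape packed patches to (t, h, w, C) and move channels first for the VQVAE encoder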
+ hs = hs.view(grid_t, grid_h, grid_w, hidden_size)
+ hs = hs.permute(0, 3, 1, 2).contiguous()
+ _, _, image_toks = self.vqmodel.encode(hs)
+ all_image_toks.append(image_toks)
+ return torch.cat(all_image_toks, dim=0)
+
+ def get_video_features(self):
+ raise AttributeError("Not needed for GlmImage")
+
+ def get_placeholder_mask(
+ self,
+ input_ids: torch.LongTensor,
+ image_ids: torch.LongTensor,
+ ):
+ """
+        Get the mask of image placeholder tokens in input_ids, marking the positions that will be replaced
+        with the actual image token ids from the VQVAE.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, seq_len)`):
+ Input token ids with image placeholders.
+ image_ids (`torch.LongTensor` of shape `(num_images, num_tokens_per_image)` or flattened):
+ Discrete token indices from the VQVAE codebook.
+
+ Returns:
+ special_image_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`):
+ Mask indicating positions in input ids that will be replaced by actual image tokens.
+ """
+
+ special_image_mask = input_ids == self.config.image_token_id
+ n_placeholder_tokens = special_image_mask.sum().item()
+ n_image_tokens = image_ids.shape[0]
+
+ if n_placeholder_tokens != n_image_tokens:
+ raise ValueError(
+ f"Number of image placeholder tokens ({n_placeholder_tokens}) does not match "
+ f"number of image tokens from VQVAE ({n_image_tokens})"
+ )
+
+ return special_image_mask
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ pixel_values: torch.Tensor | None = None,
+ image_grid_thw: torch.LongTensor | None = None,
+ rope_deltas: torch.LongTensor | None = None,
+ cache_position: torch.LongTensor | None = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | GlmImageModelOutputWithPast:
+ r"""
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if pixel_values is not None:
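+            # the last image_grid_thw row is the target image to be generated (no pixels), hence the [:-1] slice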
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw[:-1])
+ image_embeds = torch.cat(image_embeds, dim=0)
+ image_ids = self.get_image_tokens(image_embeds, image_grid_thw[:-1])
+ image_ids = image_ids.view(-1).to(input_ids.device)
+ special_image_mask = self.get_placeholder_mask(input_ids, image_ids)
+ input_ids = input_ids.masked_scatter(special_image_mask, image_ids)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if position_ids is None:
+ attention_mask_2d = attention_mask
+ if attention_mask is not None and attention_mask.ndim == 4:
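+                # recover a 2D padding mask from the 4D additive mask via its diagonal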
+ attention_mask_2d = torch.diagonal(attention_mask[:, 0], dim1=1, dim2=2)
+ # Only apply conversion for floating point tensors (inverted masks)
+ if attention_mask_2d.dtype.is_floating_point:
+ attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
+ attention_mask_2d = (1.0 - attention_mask_2d).int()
+
+ # Calculate RoPE index once per generation in the pre-fill stage only.
+            # It is safe to assume that `length != 1` means we're in pre-fill because the
+            # model is only used by the DiT pipeline, without assisted decoding or similar techniques
+ is_prefill_stage = (input_ids is not None and input_ids.shape[1] != 1) or (
+ inputs_embeds is not None and inputs_embeds.shape[1] != 1
+ )
+ if is_prefill_stage or self.rope_deltas is None:
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids,
+ image_grid_thw,
+ attention_mask=attention_mask_2d,
+ )
+ self.rope_deltas = rope_deltas
+            # otherwise, slice the decode position ids cached during pre-fill to get the correct position ids
+ else:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ # Use prefill token length, not position value
+ step = cache_position[0].item() - self._prefill_len
+ # Direct lookup - no tensor creation overhead
+ position_ids = self._cached_decode_position_ids[:, step : step + seq_length]
+ position_ids = position_ids.unsqueeze(1).expand(-1, batch_size, -1)
+
+ outputs = self.language_model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return GlmImageModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+
+
+class GlmImageCausalLMOutputWithPast(Glm4vCausalLMOutputWithPast):
+ pass
+
+
+class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {}
+ _tied_weights_keys = {}
+ # Reference: fix gemma3 grad acc #37208
+ accepts_loss_kwargs = False
+ base_model_prefix = "model"
+ config: GlmImageConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = GlmImageModel(config)
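+        # the LM head predicts tokens from the VQVAE vision codebook (image generation), not the text vocabulary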
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vision_vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None):
+ return self.model.get_image_features(pixel_values, image_grid_thw)
+
+ def get_image_tokens(self, hidden_states: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None):
+ return self.model.get_image_tokens(hidden_states, image_grid_thw)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor | None = None,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+ past_key_values: Cache | None = None,
+ inputs_embeds: torch.FloatTensor | None = None,
+ labels: torch.LongTensor | None = None,
+ pixel_values: torch.Tensor | None = None,
+ image_grid_thw: torch.LongTensor | None = None,
+ cache_position: torch.LongTensor | None = None,
+ logits_to_keep: int | torch.Tensor = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> tuple | GlmImageCausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, GlmImageForConditionalGeneration
+
+ >>> model = GlmImageForConditionalGeneration.from_pretrained("zai-org/GLM-Image")
+ >>> processor = AutoProcessor.from_pretrained("zai-org/GLM-Image")
+
+ >>> messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image"},
+                {"type": "text", "text": "Add a truck to this photo.28 40"},
+ ],
+ },
+ ]
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+
+ >>> # Generate
+    >>> generate_ids = model.generate(**inputs, max_length=30)
+    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+ ```"""
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.config.text_config.vision_vocab_size
+            )
+
+ return GlmImageCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=outputs.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ image_grid_thw=None,
+ is_first_iteration=False,
+ **kwargs,
+ ):
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw,
+ is_first_iteration=is_first_iteration,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
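+        # position ids are rebuilt inside the model forward (M-RoPE); pixel values are only needed at pre-fill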
+ model_inputs["position_ids"] = None
+
+ if not is_first_iteration and use_cache:
+ model_inputs["pixel_values"] = None
+
+ return model_inputs
+
+ def _get_image_nums(
+ self,
+ input_ids: torch.LongTensor | None,
+ ) -> torch.Tensor:
+ """
+ Get the number of images for each sample.
+        For GLM-Image, the number of images can only be inferred from input_ids.
+
+ Returns:
+ image_counts (`torch.LongTensor` of shape `(batch_size,)`)
+ """
+ is_image = input_ids == self.config.image_start_token_id
+
+ return is_image.sum(dim=1)
+
+ def _expand_inputs_for_generation(
+ self,
+ expand_size: int = 1,
+ is_encoder_decoder: bool = False,
+ input_ids: torch.LongTensor | None = None,
+ **model_kwargs,
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
+ # Overwritten -- Support for expanding tensors without a batch size dimension
+ # e.g., pixel_values, image_grid_thw
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
+ # image_grid_thw.shape[0] is sum(num_images for samples)
+
+ if expand_size == 1:
+ return input_ids, model_kwargs
+
+ visual_keys = ["pixel_values", "image_grid_thw"]
+
+ def _expand_dict_for_generation_visual(dict_to_expand):
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
+ image_nums = self._get_image_nums(input_ids)
+
+ def _repeat_interleave_samples(x, lengths, repeat_times):
+ samples = torch.split(x, lengths)
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+ return result
+
+ for key in dict_to_expand:
+ if key == "pixel_values":
+ # split images into samples
+ samples = torch.split(image_grid_thw[: sum(image_nums)], list(image_nums))
+ # compute the sequence length of images for each sample
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "image_grid_thw":
+                    # repeat the grids of the input images; the last grid (the image being generated)
+                    # is kept aside and appended back after expansion
+                    lengths = list(image_nums)
+                    last_image = dict_to_expand[key][-1:]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key][: sum(image_nums)], lengths=lengths, repeat_times=expand_size
+ )
+ dict_to_expand[key] = torch.cat([dict_to_expand[key], last_image], dim=0)
+ return dict_to_expand
+
+ def _expand_dict_for_generation(dict_to_expand):
+ for key in dict_to_expand:
+ if (
+ key != "cache_position"
+ and dict_to_expand[key] is not None
+ and isinstance(dict_to_expand[key], torch.Tensor)
+ and key not in visual_keys
+ ):
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+ return dict_to_expand
+
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+ if input_ids is not None:
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+ if is_encoder_decoder:
+ if model_kwargs.get("encoder_outputs") is None:
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+ return input_ids, model_kwargs
+
+
+def smart_resize(
+ height: int,
+ width: int,
+ factor: int = 16,
+ min_pixels: int = 512 * 512,
+ max_pixels: int = 2048 * 2048,
+) -> tuple[int, int]:
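+    # Scale (height, width) so the shorter side reaches ~sqrt(min_pixels) and the longer side stays within
+    # ~sqrt(max_pixels), halve both sides, then snap them to multiples of `factor` (re-capping if needed).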
+ if height < factor or width < factor:
+ raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
+ elif max(height, width) / min(height, width) > 4:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than 4, got {max(height, width) / min(height, width)}"
+ )
+
+ shortest_edge = int(round(math.sqrt(min_pixels)))
+ longest_edge = int(round(math.sqrt(max_pixels)))
+ min_side = min(height, width)
+ max_side = max(height, width)
+
+ scale = 1.0
+
+ if min_side < shortest_edge:
+ scale = shortest_edge / min_side
+
+ if max_side * scale > longest_edge:
+ scale = longest_edge / max_side
+
+ height = height // 2
+ width = width // 2
+
+ h_bar = max(factor, int(round(height * scale / factor)) * factor)
+ w_bar = max(factor, int(round(width * scale / factor)) * factor)
+
+ if max(h_bar, w_bar) > longest_edge:
+ beta = max(h_bar, w_bar) / longest_edge
+ h_bar = max(factor, int(math.floor((h_bar / beta) / factor)) * factor)
+ w_bar = max(factor, int(math.floor((w_bar / beta) / factor)) * factor)
+
+ return h_bar, w_bar
+
+
+class GlmImageImageProcessor(Qwen2VLImageProcessor):
+ pass
+
+
+class GlmImageImageProcessorFast(Qwen2VLImageProcessorFast):
+ pass
+
+
+class GlmImageImagesKwargs(ImagesKwargs, total=False):
+ """
+ target_h (`int`):
+ Height of the target image to be generated.
+ target_w (`int`):
+ Width of the target image to be generated.
+ """
+
+ target_h: int
+ target_w: int
+
+
+class GlmImageProcessorKwargs(Qwen2VLProcessorKwargs):
+ images_kwargs: GlmImageImagesKwargs
+
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "return_mm_token_type_ids": False,
+ },
+ "images_kwargs": {
+ "target_h": 1152,
+ "target_w": 768,
+ },
+ }
+
+
+class GlmImageProcessor(ProcessorMixin):
+ r"""
+ Constructs a GLM-Image processor which wraps a GLM-Image image processor and a GLM-Image tokenizer into a single processor.
+    See [`~GlmImageProcessor.__call__`] and [`~GlmImageProcessor.decode`] for more information.
+    Args:
+        image_processor ([`GlmImageImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`PreTrainedTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+ self.image_token = tokenizer.image_token
+ self.grid_bos_token = tokenizer.grid_bos_token
+ self.grid_eos_token = tokenizer.grid_eos_token
+ self.bos_token = tokenizer.bos_token
+ self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ images: ImageInput | None = None,
+ text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
+ **kwargs: Unpack[GlmImageProcessorKwargs],
+ ) -> BatchFeature:
+ """
+        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
+ and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
+ the text.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ GlmImageProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ target_h = output_kwargs["images_kwargs"].pop("target_h", None)
+ target_w = output_kwargs["images_kwargs"].pop("target_w", None)
+ is_text_to_image = images is None
+
+ if images is not None:
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+ image_grid_thw = image_inputs["image_grid_thw"]
+ else:
+ image_inputs = {}
+ image_grid_thw = None
+
+ if not isinstance(text, list):
+ text = [text]
+
+ if len(text) > 1:
+ raise ValueError("The model does not support batch size > 1")
+
+        text = text.copy()  # the lines below modify text in place
+ if not is_text_to_image:
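+            # expand each image token into one placeholder per vision token (grid height * grid width)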
+ index = 0
+ for i in range(len(text)):
+ while self.image_token in text[i]:
+ grid = image_grid_thw[index]
+ num_image_tokens = int(grid[1] * grid[2])
+ text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+ text[0], token_h, token_w, prev_h, prev_w = self._build_prompt_with_target_shape(
+ text[0], height=target_h, width=target_w, is_text_to_image=is_text_to_image
+ )
+ image_inputs["image_grid_thw"] = self._build_target_image_grid_thw(
+ token_h=token_h,
+ token_w=token_w,
+ prev_token_h=prev_h,
+ prev_token_w=prev_w,
+ image_grid_thw=image_grid_thw if not is_text_to_image else None,
+ )
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
+
+ if return_mm_token_type_ids:
+ array_ids = np.array(text_inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+ return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
+
+ def _build_prompt_with_target_shape(
+ self,
+ prompt: str,
+ height: int,
+ width: int,
+ is_text_to_image: bool,
+ ) -> tuple[str, int, int, int, int]:
+ factor = 32
+ height = (height // factor) * factor
+ width = (width // factor) * factor
+ token_h = height // factor
+ token_w = width // factor
+ ratio = token_h / token_w
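+        # besides the target grid, compute a smaller grid (~factor/2 tokens per side, same aspect ratio);
+        # presumably the low-resolution draft image, inserted into the prompt only in text-to-image mode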
+ prev_token_h = int(math.sqrt(ratio) * (factor // 2))
+ prev_token_w = int(math.sqrt(1 / ratio) * (factor // 2))
+
+ if is_text_to_image:
+ expanded_prompt = f"{prompt}{self.grid_bos_token}{token_h} {token_w}{self.grid_eos_token}{self.grid_bos_token}{prev_token_h} {prev_token_w}{self.grid_eos_token}{self.bos_token}"
+ else:
+ expanded_prompt = f"{prompt}{self.grid_bos_token}{token_h} {token_w}{self.grid_eos_token}{self.bos_token}"
+
+ return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w
+
+ @staticmethod
+ def _build_target_image_grid_thw(
+ token_h: int,
+ token_w: int,
+ prev_token_h: int,
+ prev_token_w: int,
+        image_grid_thw: "torch.Tensor | None",
+ ):
+ if image_grid_thw is None:
+ return torch.tensor(
+ [
+ [1, token_h, token_w],
+ [1, prev_token_h, prev_token_w],
+ ],
+ )
+ else:
+ return torch.cat(
+ [image_grid_thw, torch.tensor([[1, token_h, token_w]], device=image_grid_thw.device)], dim=0
+ )
+
+
+__all__ = [
+ "GlmImageVQVAEConfig",
+ "GlmImageVisionConfig",
+ "GlmImageTextConfig",
+ "GlmImageConfig",
+ "GlmImagePreTrainedModel",
+ "GlmImageVQVAE",
+ "GlmImageVisionModel",
+ "GlmImageTextModel",
+ "GlmImageModel",
+ "GlmImageForConditionalGeneration",
+ "GlmImageImageProcessor",
+ "GlmImageImageProcessorFast",
+ "GlmImageProcessor",
+]
diff --git a/src/transformers/models/glm_image/processing_glm_image.py b/src/transformers/models/glm_image/processing_glm_image.py
new file mode 100644
index 000000000000..8cbfabf4ecc2
--- /dev/null
+++ b/src/transformers/models/glm_image/processing_glm_image.py
@@ -0,0 +1,217 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/glm_image/modular_glm_image.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_glm_image.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import numpy as np
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import is_torch_available
+
+
+if is_torch_available():
+ import torch
+
+
+class GlmImageImagesKwargs(ImagesKwargs, total=False):
+ """
+ target_h (`int`):
+ Height of the target image to be generated.
+ target_w (`int`):
+ Width of the target image to be generated.
+ """
+
+ target_h: int
+ target_w: int
+
+
+class GlmImageProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ "return_mm_token_type_ids": False,
+ },
+ "images_kwargs": {
+ "target_h": 1152,
+ "target_w": 768,
+ },
+ }
+ images_kwargs: GlmImageImagesKwargs
+
+
+class GlmImageProcessor(ProcessorMixin):
+ r"""
+ Constructs a GLM-Image processor which wraps a GLM-Image image processor and a GLM-Image tokenizer into a single processor.
+    See [`~GlmImageProcessor.__call__`] and [`~GlmImageProcessor.decode`] for more information.
+    Args:
+        image_processor ([`GlmImageImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`PreTrainedTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+ self.image_token = tokenizer.image_token
+ self.grid_bos_token = tokenizer.grid_bos_token
+ self.grid_eos_token = tokenizer.grid_eos_token
+ self.bos_token = tokenizer.bos_token
+ self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ images: ImageInput | None = None,
+ text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
+ **kwargs: Unpack[GlmImageProcessorKwargs],
+ ) -> BatchFeature:
+ """
+        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
+ and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
+ the text.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ GlmImageProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ target_h = output_kwargs["images_kwargs"].pop("target_h", None)
+ target_w = output_kwargs["images_kwargs"].pop("target_w", None)
+ is_text_to_image = images is None
+
+ if images is not None:
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+ image_grid_thw = image_inputs["image_grid_thw"]
+ else:
+ image_inputs = {}
+ image_grid_thw = None
+
+ if not isinstance(text, list):
+ text = [text]
+
+ if len(text) > 1:
+ raise ValueError("The model does not support batch size > 1")
+
+        text = text.copy()  # the lines below modify text in place
+ if not is_text_to_image:
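+            # expand each image token into one placeholder per vision token (grid height * grid width)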
+ index = 0
+ for i in range(len(text)):
+ while self.image_token in text[i]:
+ grid = image_grid_thw[index]
+ num_image_tokens = int(grid[1] * grid[2])
+ text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
+ index += 1
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+ text[0], token_h, token_w, prev_h, prev_w = self._build_prompt_with_target_shape(
+ text[0], height=target_h, width=target_w, is_text_to_image=is_text_to_image
+ )
+ image_inputs["image_grid_thw"] = self._build_target_image_grid_thw(
+ token_h=token_h,
+ token_w=token_w,
+ prev_token_h=prev_h,
+ prev_token_w=prev_w,
+ image_grid_thw=image_grid_thw if not is_text_to_image else None,
+ )
+
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+ self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
+
+ if return_mm_token_type_ids:
+ array_ids = np.array(text_inputs["input_ids"])
+ mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
+ text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
+ return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
+
+ def _build_prompt_with_target_shape(
+ self,
+ prompt: str,
+ height: int,
+ width: int,
+ is_text_to_image: bool,
+ ) -> tuple[str, int, int, int, int]:
+ factor = 32
+ height = (height // factor) * factor
+ width = (width // factor) * factor
+ token_h = height // factor
+ token_w = width // factor
+ ratio = token_h / token_w
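+        # besides the target grid, compute a smaller grid (~factor/2 tokens per side, same aspect ratio);
+        # presumably the low-resolution draft image, inserted into the prompt only in text-to-image mode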
+ prev_token_h = int(math.sqrt(ratio) * (factor // 2))
+ prev_token_w = int(math.sqrt(1 / ratio) * (factor // 2))
+
+ if is_text_to_image:
+ expanded_prompt = f"{prompt}{self.grid_bos_token}{token_h} {token_w}{self.grid_eos_token}{self.grid_bos_token}{prev_token_h} {prev_token_w}{self.grid_eos_token}{self.bos_token}"
+ else:
+ expanded_prompt = f"{prompt}{self.grid_bos_token}{token_h} {token_w}{self.grid_eos_token}{self.bos_token}"
+
+ return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w
+
+ @staticmethod
+ def _build_target_image_grid_thw(
+ token_h: int,
+ token_w: int,
+ prev_token_h: int,
+ prev_token_w: int,
+        image_grid_thw: "torch.Tensor | None",
+ ):
+ if image_grid_thw is None:
+ return torch.tensor(
+ [
+ [1, token_h, token_w],
+ [1, prev_token_h, prev_token_w],
+ ],
+ )
+ else:
+ return torch.cat(
+ [image_grid_thw, torch.tensor([[1, token_h, token_w]], device=image_grid_thw.device)], dim=0
+ )
+
+
+__all__ = ["GlmImageProcessor"]
diff --git a/tests/models/glm4_moe_lite/test_modeling_glm4_moe_lite.py b/tests/models/glm4_moe_lite/test_modeling_glm4_moe_lite.py
index 56c01c6b2fb0..698ceebb22c0 100644
--- a/tests/models/glm4_moe_lite/test_modeling_glm4_moe_lite.py
+++ b/tests/models/glm4_moe_lite/test_modeling_glm4_moe_lite.py
@@ -22,7 +22,6 @@
from transformers import Cache, is_torch_available
from transformers.testing_utils import (
cleanup,
- require_read_token,
require_torch,
require_torch_accelerator,
slow,
@@ -82,7 +81,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l
@require_torch_accelerator
-@require_read_token
@slow
class Glm4MoeIntegrationTest(unittest.TestCase):
def tearDown(self):
@@ -92,7 +90,6 @@ def tearDown(self):
@slow
@require_torch_accelerator
@pytest.mark.torch_compile_test
- @require_read_token
def test_compile_static_cache(self):
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
diff --git a/tests/models/glm_image/__init__.py b/tests/models/glm_image/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/glm_image/test_modeling_glm_image.py b/tests/models/glm_image/test_modeling_glm_image.py
new file mode 100644
index 000000000000..5aeb255d25fd
--- /dev/null
+++ b/tests/models/glm_image/test_modeling_glm_image.py
@@ -0,0 +1,484 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch GLM-Image model."""
+
+import unittest
+
+from parameterized import parameterized
+
+from transformers import (
+ AutoProcessor,
+ GlmImageConfig,
+ GlmImageForConditionalGeneration,
+ GlmImageModel,
+ is_torch_available,
+)
+from transformers.models.auto import get_values
+from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES
+from transformers.testing_utils import (
+ cleanup,
+ require_flash_attn,
+ require_torch,
+ require_torch_accelerator,
+ run_first,
+ slow,
+ torch_device,
+)
+
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import (
+ TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
+ ModelTesterMixin,
+ floats_tensor,
+ ids_tensor,
+)
+
+
+if is_torch_available():
+ import torch
+
+
+class GlmImageVisionText2TextModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=1,
+ seq_length=7,
+ num_channels=3,
+ ignore_index=-100,
+ image_size=128,
+ image_start_token_id=85,
+ image_end_token_id=86,
+ image_token_id=7,
+ is_training=True,
+ text_config={
+ "vocab_size": 99,
+ "vision_vocab_size": 99,
+ "hidden_size": 16,
+ "intermediate_size": 22,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 1,
+ "output_channels": 64,
+ "hidden_act": "silu",
+ "max_position_embeddings": 512,
+ "rope_parameters": {"type": "default", "mrope_section": [2, 1, 1]},
+ "rope_theta": 10000,
+ "tie_word_embeddings": True,
+ "bos_token_id": 0,
+ "eos_token_id": 0,
+ "pad_token_id": 0,
+ "n_routed_experts": 8,
+ "n_shared_experts": 1,
+ "n_group": 1,
+ "topk_group": 1,
+ "num_experts_per_tok": 8,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_act": "gelu",
+ "hidden_size": 32,
+ "out_hidden_size": 16,
+ "intermediate_size": 22,
+ "patch_size": 16,
+ "spatial_merge_size": 1,
+ "temporal_patch_size": 1,
+ },
+ vq_config={
+ "embed_dim": 48,
+ "in_channels": 3,
+ "initializer_range": 0.02,
+ "latent_channels": 32,
+ "num_embeddings": 32,
+ },
+ ):
+ self.parent = parent
+ self.ignore_index = ignore_index
+ self.bos_token_id = text_config["bos_token_id"]
+ self.eos_token_id = text_config["eos_token_id"]
+ self.pad_token_id = text_config["pad_token_id"]
+ self.image_start_token_id = image_start_token_id
+ self.image_end_token_id = image_end_token_id
+ self.image_token_id = image_token_id
+ self.text_config = text_config
+ self.vision_config = vision_config
+ self.vq_config = vq_config
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.is_training = is_training
+ self.hidden_size = text_config["hidden_size"]
+ self.num_hidden_layers = text_config["num_hidden_layers"]
+ self.num_attention_heads = text_config["num_attention_heads"]
+ self.vision_vocab_size = text_config["vision_vocab_size"]
+ self.vocab_size = text_config["vocab_size"]
+ self.num_image_tokens = 64
+ self.seq_length = seq_length + self.num_image_tokens
+ self.n_routed_experts = text_config["n_routed_experts"]
+ self.n_shared_experts = text_config["n_shared_experts"]
+ self.num_experts_per_tok = text_config["num_experts_per_tok"]
+ self.n_group = text_config["n_group"]
+ self.topk_group = text_config["topk_group"]
+
+ def get_config(self):
+ return GlmImageConfig(
+ text_config=self.text_config,
+ vision_config=self.vision_config,
+ vq_config=self.vq_config,
+ image_token_id=self.image_token_id,
+ image_start_token_id=self.image_start_token_id,
+ image_end_token_id=self.image_end_token_id,
+ )
+
+ def prepare_config_and_inputs(self):
+ config = self.get_config()
+ patch_size = config.vision_config.patch_size
+ temporal_patch_size = config.vision_config.temporal_patch_size
+ pixel_values = floats_tensor(
+ [
+ self.batch_size * (self.image_size**2) // (patch_size**2),
+ self.num_channels * (patch_size**2) * temporal_patch_size,
+ ]
+ )
+
+ return config, pixel_values
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values = config_and_inputs
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
+
+ input_ids[input_ids == self.image_token_id] = self.pad_token_id
+ input_ids[input_ids == self.image_start_token_id] = self.pad_token_id
+ input_ids[input_ids == self.image_end_token_id] = self.pad_token_id
+
+ input_ids[:, 0] = self.image_start_token_id
+ input_ids[:, 1 : 1 + self.num_image_tokens] = self.image_token_id
+ input_ids[:, 1 + self.num_image_tokens] = self.image_end_token_id
+ patch_size = config.vision_config.patch_size
+ patches_per_side = self.image_size // patch_size
+
+        # image_grid_thw needs batch_size rows for the input images, plus one extra row
+        # (the target grid) that is skipped by the model's [:-1] slicing
+ inputs_dict = {
+ "pixel_values": pixel_values,
+ "image_grid_thw": torch.tensor(
+ [[1, patches_per_side, patches_per_side]] * self.batch_size
+ + [[1, patches_per_side, patches_per_side]], # Extra row for model's [:-1]
+ device=torch_device,
+ ),
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class GlmImageModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+ all_model_classes = (GlmImageModel, GlmImageForConditionalGeneration) if is_torch_available() else ()
+
+ model_split_percents = [0.7, 0.9] # model too big to split at 0.5
+ _is_composite = True
+
+ def setUp(self):
+ self.model_tester = GlmImageVisionText2TextModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=GlmImageConfig, has_text_modality=False)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ # GlmImage has images shaped as (bs*patch_len, dim) so we can't slice to batches in generate
+ def prepare_config_and_inputs_for_generate(self, batch_size=2):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ # We don't want a few model inputs in our model input dictionary for generation tests
+ input_keys_to_ignore = [
+ # we don't want to mask attention heads
+ # we don't want encoder-decoder models to start from filled decoder ids
+ "decoder_input_ids",
+ "decoder_attention_mask",
+ # we'll set cache use in each test differently
+ "use_cache",
+ # Ignore labels if it is in the input dict
+ "labels",
+ # model-specific exceptions should overload/overwrite this function
+ ]
+
+ # The diff from the general `prepare_config_and_inputs_for_generate` lies here
+ patch_size = config.vision_config.patch_size
+ filtered_image_length = batch_size * (self.model_tester.image_size**2) // (patch_size**2)
+ filtered_inputs_dict = {
+ k: v[:batch_size, ...] if isinstance(v, torch.Tensor) else v
+ for k, v in inputs_dict.items()
+ if k not in input_keys_to_ignore
+ }
+ filtered_inputs_dict["pixel_values"] = inputs_dict["pixel_values"][:filtered_image_length]
+ filtered_inputs_dict["image_grid_thw"] = inputs_dict["image_grid_thw"][: batch_size + 1]
+
+ # It is important set `eos_token_id` to `None` to avoid early stopping (would break for length-based checks)
+ text_gen_config = config.get_text_config(decoder=True)
+ if text_gen_config.eos_token_id is not None and text_gen_config.pad_token_id is None:
+ text_gen_config.pad_token_id = (
+ text_gen_config.eos_token_id
+ if isinstance(text_gen_config.eos_token_id, int)
+ else text_gen_config.eos_token_id[0]
+ )
+ text_gen_config.eos_token_id = None
+ text_gen_config.forced_eos_token_id = None
+
+ return config, filtered_inputs_dict
+
+ def test_training(self):
+ # Model isn't in any auto-mapping so we need to build labels manually
+ if not self.model_tester.is_training:
+ self.skipTest(reason="ModelTester is not configured to run training tests")
+
+ for model_class in self.all_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ if model_class.__name__ in [
+ *get_values(MODEL_MAPPING_NAMES),
+ ]:
+ continue
+
+ model = model_class(config)
+ model.to(torch_device)
+ model.train()
+ inputs_dict["labels"] = torch.zeros(
+ (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
+ )
+ loss = model(**inputs_dict).loss
+ loss.backward()
+
+ @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
+ @unittest.skip("Needs special input preparation. Not important test for model, skip for now")
+ def test_eager_matches_sdpa_inference(
+ self, name, dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
+ ):
+ pass
+
+ @unittest.skip(reason="No available kernels - not supported")
+ def test_sdpa_can_dispatch_on_flash(self):
+ pass
+
+ @unittest.skip(reason="Size mismatch")
+ def test_multi_gpu_data_parallel_forward(self):
+ pass
+
+ @unittest.skip("Error with compilation")
+ def test_generate_from_inputs_embeds_with_static_cache(self):
+ pass
+
+ @parameterized.expand([("greedy", 1), ("beam search", 2)])
+ @unittest.skip(reason="GLM-Image does not use inputs_embeds")
+ def test_generate_from_inputs_embeds(self, _, num_beams):
+ pass
+
+    @unittest.skip(reason="GLM-Image builds inputs_embeds from both input_ids and image_ids, so the comparison does not apply")
+ def test_inputs_embeds_matches_input_ids(self):
+ pass
+
+ @unittest.skip(reason="GLM-Image does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="GLM-Image can't do text-only inference")
+ def test_generate_from_random_inputs_embeds(self):
+ pass
+
+ @unittest.skip(reason="GLM-Image can't do and does not need assisted generation. Not worth fixing!")
+ def test_assisted_decoding_sample(self):
+ pass
+
+ @unittest.skip(reason="GLM-Image can't do and does not need assisted generation. Not worth fixing!")
+    def test_prompt_lookup_decoding_matches_greedy_search(self):
+ pass
+
+ @parameterized.expand([("random",), ("same",)])
+ @unittest.skip(reason="GLM-Image can't do and does not need assisted generation. Not worth fixing!")
+ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
+ pass
+
+ @unittest.skip(reason="GlmImageVisionModel does not support training")
+ def test_training_gradient_checkpointing(self):
+ pass
+
+ @unittest.skip(reason="GlmImageVision does not support output_hidden_states test")
+ def test_model_outputs_equivalence(self):
+ pass
+
+ @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+ )
+ def test_training_gradient_checkpointing_use_reentrant(self):
+ pass
+
+ @unittest.skip(
+        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
+ )
+ def test_training_gradient_checkpointing_use_reentrant_false(self):
+ pass
+
+ @unittest.skip(reason="GlmImageVisionModel does not support training")
+ def test_retain_grad_hidden_states_attentions(self):
+ pass
+
+ @unittest.skip(reason="GlmImage needs special input preparation to pass this test")
+ def test_generate_compile_model_forward_fullgraph(self):
+ pass
+
+
+@require_torch
+@slow
+class GlmImageIntegrationTest(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.model = None
+
+ @classmethod
+ def get_model(cls):
+ if cls.model is None:
+ cls.model = GlmImageForConditionalGeneration.from_pretrained(
+ "zai-org/GLM-4.5V", dtype="auto", device_map="auto"
+ )
+ return cls.model
+
+ @classmethod
+ def tearDownClass(cls):
+ if hasattr(cls, "model"):
+ del cls.model
+ cleanup(torch_device, gc_collect=True)
+
+ def setUp(self):
+ cleanup(torch_device, gc_collect=True)
+ self.processor = AutoProcessor.from_pretrained(
+ "zai-org/GLM-4.5V", size={"shortest_edge": 10800, "longest_edge": 10800}
+ )
+ self.message = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
+ },
+ {"type": "text", "text": "What kind of dog is this?"},
+ ],
+ }
+ ]
+ self.message2 = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png",
+ },
+ {"type": "text", "text": "What kind of dog is this?"},
+ ],
+ }
+ ]
+ self.message_wo_image = [
+ {"role": "user", "content": [{"type": "text", "text": "Who are you?"}]},
+ ]
+
+ def tearDown(self):
+ cleanup(torch_device, gc_collect=True)
+
+ def test_small_model_integration_test(self):
+ inputs = self.processor.apply_chat_template(
+ self.message, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
+ )
+ expected_input_ids = [151331, 151333, 151336, 198, 151339, 151363, 151363, 151363, 151363, 151363, 151363,
+ 151340, 3838, 3093, 315, 5562, 374] # fmt: skip
+ assert expected_input_ids == inputs.input_ids[0].tolist()[:17]
+
+ expected_pixel_slice = torch.tensor(
+ [
+ [-0.1134, -0.4492, -0.8580],
+ [-0.6244, -1.1645, -0.7120],
+ [-0.3324, -0.7996, -0.7120],
+ [0.2077, 0.2223, 0.4121],
+ [0.4413, 0.1931, 0.4559],
+ [0.5873, 0.3099, 0.4851],
+ ],
+ dtype=torch.float32,
+ device="cpu",
+ )
+ torch.testing.assert_close(expected_pixel_slice, inputs.pixel_values[:6, :3], atol=1e-4, rtol=1e-4)
+
+ def test_small_model_integration_test_batch(self):
+ model = self.get_model()
+ batch_messages = [self.message, self.message2, self.message_wo_image]
+ inputs = self.processor.apply_chat_template(
+ batch_messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_dict=True,
+ return_tensors="pt",
+ padding=True,
+ ).to(torch_device)
+
+ # it should not matter whether two images are the same size or not
+ output = model.generate(**inputs, max_new_tokens=10)
+
+ EXPECTED_DECODED_TEXT = [
+ "\nWhat kind of dog is this?\nGot it, let's try to figure out",
+ "\nWhat kind of dog is this?\nGot it, let's see. The user",
+ '\nWho are you?\nThe user is asking "Who are you?"'
+ ] # fmt: skip
+ decoded = self.processor.batch_decode(output, skip_special_tokens=True)
+ decoded = [x.replace("<|image|>", "") for x in decoded]
+ self.assertEqual(
+ decoded,
+ EXPECTED_DECODED_TEXT,
+ )
+
+ @run_first
+ @require_flash_attn
+ @require_torch_accelerator
+ def test_small_model_integration_test_batch_flashatt2(self):
+ model = GlmImageForConditionalGeneration.from_pretrained(
+ "zai-org/GLM-4.5V",
+ dtype=torch.bfloat16,
+ attn_implementation="flash_attention_2",
+ device_map="auto",
+ )
+ batch_messages = [self.message, self.message2, self.message_wo_image]
+ inputs = self.processor.apply_chat_template(
+ batch_messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_dict=True,
+ return_tensors="pt",
+ padding=True,
+ ).to(torch_device)
+
+ # it should not matter whether two images are the same size or not
+ output = model.generate(**inputs, max_new_tokens=3)
+
+ EXPECTED_DECODED_TEXT = [
+ "\nWhat kind of dog is this?\nGot it",
+ "\nWhat kind of dog is this?\nGot it",
+ "\nWho are you?\nThe user",
+ ] # fmt: skip
+ decoded = self.processor.batch_decode(output, skip_special_tokens=True)
+ decoded = [x.replace("<|image|>", "") for x in decoded]
+ self.assertEqual(
+ decoded,
+ EXPECTED_DECODED_TEXT,
+ )
diff --git a/tests/models/glm_image/test_processor_glm_image.py b/tests/models/glm_image/test_processor_glm_image.py
new file mode 100644
index 000000000000..0475f92a8caa
--- /dev/null
+++ b/tests/models/glm_image/test_processor_glm_image.py
@@ -0,0 +1,155 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_av, require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+if is_vision_available():
+ from transformers import GlmImageProcessor
+
+if is_torch_available():
+ import torch
+
+
+@require_vision
+@require_torch
+@unittest.skip(reason="Model not released yet")
+class GlmImageProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+ processor_class = GlmImageProcessor
+ model_id = "zai-org/GLM-Image"
+
+ @classmethod
+ def _setup_test_attributes(cls, processor):
+ cls.image_token = processor.image_token
+
+ @classmethod
+ def _setup_from_pretrained(cls, model_id, **kwargs):
+ return super()._setup_from_pretrained(
+ model_id,
+ do_sample_frames=False,
+ patch_size=4,
+ size={"shortest_edge": 12 * 12, "longest_edge": 18 * 18},
+ **kwargs,
+ )
+
+ @require_torch
+ @require_av
+ def _test_apply_chat_template(
+ self,
+ modality: str,
+ batch_size: int,
+ return_tensors: str,
+ input_name: str,
+ processor_name: str,
+ input_data: list[str],
+ ):
+ processor = self.get_processor()
+ if processor.chat_template is None:
+ self.skipTest("Processor has no chat template")
+
+ if processor_name not in self.processor_class.get_attributes():
+ self.skipTest(f"{processor_name} attribute not present in {self.processor_class}")
+
+ batch_messages = [
+ [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": "Describe this."}],
+ },
+ ]
+ ] * batch_size
+
+ # Test that jinja can be applied
+ formatted_prompt = processor.apply_chat_template(batch_messages, add_generation_prompt=True, tokenize=False)
+ self.assertEqual(len(formatted_prompt), batch_size)
+
+ # Test that tokenizing with template and directly with `self.tokenizer` gives same output
+ formatted_prompt_tokenized = processor.apply_chat_template(
+ batch_messages, add_generation_prompt=True, tokenize=True, return_tensors=return_tensors
+ )
+ add_special_tokens = True
+ if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
+ add_special_tokens = False
+ tok_output = processor.tokenizer(
+ formatted_prompt, return_tensors=return_tensors, add_special_tokens=add_special_tokens
+ )
+ expected_output = tok_output.input_ids
+ self.assertListEqual(expected_output.tolist(), formatted_prompt_tokenized.tolist())
+
+ # Test that kwargs passed to processor's `__call__` are actually used
+ tokenized_prompt_100 = processor.apply_chat_template(
+ batch_messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ padding="max_length",
+ truncation=True,
+ return_tensors=return_tensors,
+ max_length=100,
+ )
+ self.assertEqual(len(tokenized_prompt_100[0]), 100)
+
+ # Test that `return_dict=True` returns text related inputs in the dict
+ out_dict_text = processor.apply_chat_template(
+ batch_messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors=return_tensors,
+ )
+ self.assertTrue(all(key in out_dict_text for key in ["input_ids", "attention_mask"]))
+ self.assertEqual(len(out_dict_text["input_ids"]), batch_size)
+ self.assertEqual(len(out_dict_text["attention_mask"]), batch_size)
+
+ # Test that with modality URLs and `return_dict=True`, we get modality inputs in the dict
+ for idx, url in enumerate(input_data[:batch_size]):
+ batch_messages[idx][0]["content"] = [batch_messages[idx][0]["content"][0], {"type": modality, "url": url}]
+
+ out_dict = processor.apply_chat_template(
+ batch_messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors=return_tensors,
+ fps=2
+ if isinstance(input_data[0], str)
+ else None, # by default no more than 2 frames per second, otherwise too slow
+ )
+ input_name = getattr(self, input_name)
+ self.assertTrue(input_name in out_dict)
+ self.assertEqual(len(out_dict["input_ids"]), batch_size)
+ self.assertEqual(len(out_dict["attention_mask"]), batch_size)
+
+ mm_len = batch_size * 4
+ self.assertEqual(len(out_dict[input_name]), mm_len)
+
+ return_tensor_to_type = {"pt": torch.Tensor, "np": np.ndarray, None: list}
+ for k in out_dict:
+ self.assertIsInstance(out_dict[k], return_tensor_to_type[return_tensors])
+
+ def test_model_input_names(self):
+ processor = self.get_processor()
+
+ text = self.prepare_text_inputs(modalities=["image"])
+ image_input = self.prepare_image_inputs()
+ inputs_dict = {"text": text, "images": image_input}
+ inputs = processor(**inputs_dict, return_tensors="pt", do_sample_frames=False)
+
+ self.assertSetEqual(set(inputs.keys()), set(processor.model_input_names))
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 7e0819a245c4..699fbf24d0f7 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -100,6 +100,7 @@
"Phi4MultimodalVisionModel",
"Glm4vVisionModel",
"Glm4vMoeVisionModel",
+ "GlmImageVisionModel",
"EvollaSaProtPreTrainedModel",
"BltLocalEncoder", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM.
"BltLocalDecoder", # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM.
@@ -125,6 +126,7 @@
"ErnieMForInformationExtraction",
"FastSpeech2ConformerHifiGan", # Already tested by SpeechT5HifiGan (# Copied from)
"FastSpeech2ConformerWithHifiGan", # Built with two smaller (tested) models.
+ "GlmImageVQVAE", # Building part of bigger (tested) model.
"GraphormerDecoderHead", # Building part of bigger (tested) model.
"JukeboxVQVAE", # Building part of bigger (tested) model.
"JukeboxPrior", # Building part of bigger (tested) model.
@@ -192,6 +194,7 @@
"Emu3TextModel", # Building part of bigger (tested) model
"Glm4vTextModel", # Building part of bigger (tested) model
"Glm4vMoeTextModel", # Building part of bigger (tested) model
+ "GlmImageTextModel", # Building part of bigger (tested) model
"Qwen2VLTextModel", # Building part of bigger (tested) model
"Qwen2_5_VLTextModel", # Building part of bigger (tested) model
"InternVLVisionModel", # Building part of bigger (tested) model
@@ -310,6 +313,7 @@
"FlavaTextModel",
"FlavaImageModel",
"FlavaMultimodalModel",
+ "GlmImageForConditionalGeneration",
"GPT2DoubleHeadsModel",
"GPTSw3DoubleHeadsModel",
"InstructBlipVisionModel",
diff --git a/utils/models_to_deprecate.py b/utils/models_to_deprecate.py
index 24313f65419d..9d6c38cda388 100644
--- a/utils/models_to_deprecate.py
+++ b/utils/models_to_deprecate.py
@@ -93,6 +93,7 @@
"gemma3n": ["gemma3n_audio", "gemma3n_text", "gemma3n_vision"],
"gpt2": ["cpm", "dialogpt", "gpt-sw3", "megatron_gpt2"],
"glm4v_moe": ["glm4v_moe_text", "glm4v_moe_vision"],
+    "glm_image": ["glm_image_text", "glm_image_vision"],
"glm4v": ["glm4v_text", "glm4v_vision"],
"idefics3": ["idefics3_vision"],
"internvl": ["internvl_vision"],