
Conversation

@zRzRzRzRzRzRzR (Contributor) commented Jan 4, 2026

This PR adapts the implementation of the AR model for GLM-Image. For the full pipeline, check the diffusers repo.

@ArthurZucker (Collaborator) left a comment

Thanks a lot for the great model and your hard work!
Apart from making sure we don't always compute the loss, LGTM!

Comment on lines +372 to +392
# Other implementations: Process each chunk separately
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
splits = [
    torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
]

attn_outputs = [
    attention_interface(
        self,
        q,
        k,
        v,
        attention_mask=None,
        scaling=self.scaling,
        dropout=0.0 if not self.training else self.attention_dropout,
        is_causal=False,
        **kwargs,
    )[0]
    for q, k, v in zip(*splits)
]
attn_output = torch.cat(attn_outputs, dim=1)
Collaborator

Yep, @zucchini-nlp started an internal thread on how to properly do this. I think the best for this model, as it is rushed, is to keep it as is, but let's work on something better for the next models!
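
(For readers: a minimal sketch of what the hunk above does, assuming query/key/value are laid out as (batch, heads, seq, head_dim) with chunks packed along the sequence dimension; the tensor values below are illustrative, not from the PR.)

import torch

# cu_seqlens holds cumulative sequence lengths; adjacent differences
# recover the per-chunk lengths that torch.split consumes
cu_seqlens = torch.tensor([0, 3, 7, 12])
lengths = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([3, 4, 5])

# a stand-in for query_states with the packed sequence on dim 2
query_states = torch.randn(1, 8, 12, 64)
chunks = torch.split(query_states, lengths.tolist(), dim=2)
print([tuple(c.shape) for c in chunks])
# [(1, 8, 3, 64), (1, 8, 4, 64), (1, 8, 5, 64)]
# each chunk is attended to independently (no mask, is_causal=False),
# then the per-chunk outputs are concatenated back together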

Comment on lines +1184 to +1222
def _expand_dict_for_generation_visual(dict_to_expand):
    image_grid_thw = model_kwargs.get("image_grid_thw", None)
    image_nums = self._get_image_nums(input_ids)

    def _repeat_interleave_samples(x, lengths, repeat_times):
        samples = torch.split(x, lengths)
        repeat_args = [repeat_times] + [1] * (x.dim() - 1)
        result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
        return result

    for key in dict_to_expand:
        if key == "pixel_values":
            # split images into samples
            samples = torch.split(image_grid_thw[: sum(image_nums)], list(image_nums))
            # compute the sequence length of images for each sample
            lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
            dict_to_expand[key] = _repeat_interleave_samples(
                dict_to_expand[key], lengths=lengths, repeat_times=expand_size
            )
        elif key == "image_grid_thw":
            # get the num of images for each sample and +1 for the image being generated
            lengths = list(image_nums)
            last_image = dict_to_expand[key][:-1]
            dict_to_expand[key] = _repeat_interleave_samples(
                dict_to_expand[key][: sum(image_nums)], lengths=lengths, repeat_times=expand_size
            )
            dict_to_expand[key] = torch.cat([dict_to_expand[key], last_image], dim=0)
    return dict_to_expand

def _expand_dict_for_generation(dict_to_expand):
    for key in dict_to_expand:
        if (
            key != "cache_position"
            and dict_to_expand[key] is not None
            and isinstance(dict_to_expand[key], torch.Tensor)
            and key not in visual_keys
        ):
            dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
    return dict_to_expand
Collaborator

let's put nesting outside as much as possible please!
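
(One possible reading of this suggestion, sketched under my own assumptions rather than taken from the final PR: the inner helper closes over nothing from the enclosing scope, so it can be hoisted out to keep _expand_dict_for_generation_visual flat.)

import torch

def _repeat_interleave_samples(x, lengths, repeat_times):
    # split the packed tensor into per-sample chunks along dim 0,
    # tile each chunk repeat_times along dim 0, then re-concatenate
    samples = torch.split(x, lengths)
    repeat_args = [repeat_times] + [1] * (x.dim() - 1)
    return torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)

# e.g. two samples of lengths 2 and 1, each repeated twice:
x = torch.arange(3).unsqueeze(1)                # shape (3, 1)
out = _repeat_interleave_samples(x, [2, 1], 2)
print(out.squeeze(1).tolist())                  # [0, 1, 0, 1, 2, 2]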

Comment on lines +1195 to +1203
if key == "pixel_values":
    # split images into samples
    samples = torch.split(image_grid_thw[: sum(image_nums)], list(image_nums))
    # compute the sequence length of images for each sample
    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
    dict_to_expand[key] = _repeat_interleave_samples(
        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
    )
elif key == "image_grid_thw":
Collaborator

Since there are only 2 keys being handled, let's not iterate (for explicitness).
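
(A sketch of the de-looped version being suggested, assuming the same surrounding scope as the quoted hunk; not confirmed as the merged code.)

if "pixel_values" in dict_to_expand:
    # split images into samples and compute each sample's sequence length
    samples = torch.split(image_grid_thw[: sum(image_nums)], list(image_nums))
    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
    dict_to_expand["pixel_values"] = _repeat_interleave_samples(
        dict_to_expand["pixel_values"], lengths=lengths, repeat_times=expand_size
    )
if "image_grid_thw" in dict_to_expand:
    # same slicing and concatenation as the original elif branch
    last_image = dict_to_expand["image_grid_thw"][:-1]
    dict_to_expand["image_grid_thw"] = _repeat_interleave_samples(
        dict_to_expand["image_grid_thw"][: sum(image_nums)], lengths=list(image_nums), repeat_times=expand_size
    )
    dict_to_expand["image_grid_thw"] = torch.cat([dict_to_expand["image_grid_thw"], last_image], dim=0)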

@zRzRzRzRzRzRzR (Contributor, Author)

Is there any reference code for this? I referenced the Qwen2 implementation.

@zucchini-nlp (Member) left a comment

Nice, let's merge 🚀

@zucchini-nlp enabled auto-merge (squash) January 12, 2026 15:53
@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, glm4v, glm4v_moe, glm_image

@github-actions (Contributor)

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43100&sha=ef3af1
