[GLM-Image] AR Model Support for GLM-Image #43100
base: main
Conversation
ArthurZucker left a comment
Thanks a lot for the great model and your hard work!
Apart from making sure we don't always compute the loss, LGTM!
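For context, a minimal sketch of the pattern being requested, assuming the usual transformers convention where `labels` is an optional forward argument and `self.loss_function` is the model's loss helper; this is an illustration, not the PR's diff:

```python
# Hedged sketch: compute the loss only when labels are actually provided,
# so pure-inference forward passes skip the loss computation entirely.
# `logits`, `labels`, and `self.loss_function` are assumed to exist in the
# surrounding forward(), as in other transformers causal-LM models.
loss = None
if labels is not None:
    loss = self.loss_function(
        logits=logits, labels=labels, vocab_size=self.config.vocab_size
    )
```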
```python
# Other implementations: process each chunk separately
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
splits = [
    torch.split(tensor, lengths.tolist(), dim=2)
    for tensor in (query_states, key_states, value_states)
]

attn_outputs = [
    attention_interface(
        self,
        q,
        k,
        v,
        attention_mask=None,
        scaling=self.scaling,
        dropout=0.0 if not self.training else self.attention_dropout,
        is_causal=False,
        **kwargs,
    )[0]
    for q, k, v in zip(*splits)
]
attn_output = torch.cat(attn_outputs, dim=1)
```
Yep, @zucchini-nlp started an internal thread on how to properly do this. I think the best for this model, since it is rushed, is to keep it as is, but let's work on something better for the next models!
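For illustration, one alternative to the per-chunk Python loop is a single attention call over a block-diagonal mask derived from `cu_seqlens`. A minimal sketch, assuming PyTorch SDPA, a (batch, num_heads, seq_len, head_dim) layout, and `cu_seqlens[0] == 0`; `blockwise_attention` is a hypothetical helper, not code from this PR:

```python
import torch
import torch.nn.functional as F

def blockwise_attention(query_states, key_states, value_states, cu_seqlens, scaling):
    seq_len = query_states.shape[2]
    positions = torch.arange(seq_len, device=cu_seqlens.device)
    # chunk_ids[p] = i such that cu_seqlens[i] <= p < cu_seqlens[i + 1]
    chunk_ids = torch.bucketize(positions, cu_seqlens[1:], right=True)
    # allow attention only within the same chunk: a block-diagonal boolean mask
    block_mask = chunk_ids.unsqueeze(0) == chunk_ids.unsqueeze(1)
    return F.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=block_mask,  # (seq_len, seq_len), broadcast over batch and heads
        scale=scaling,         # the scale kwarg requires PyTorch >= 2.1
    )
```

This trades the Python-level loop for one fused kernel call at the cost of materializing a seq_len × seq_len mask; varlen attention kernels (e.g. FlashAttention's varlen path) avoid both, which is presumably the direction of the internal thread.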
```python
def _expand_dict_for_generation_visual(dict_to_expand):
    image_grid_thw = model_kwargs.get("image_grid_thw", None)
    image_nums = self._get_image_nums(input_ids)

    def _repeat_interleave_samples(x, lengths, repeat_times):
        samples = torch.split(x, lengths)
        repeat_args = [repeat_times] + [1] * (x.dim() - 1)
        result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
        return result

    for key in dict_to_expand:
        if key == "pixel_values":
            # split images into samples
            samples = torch.split(image_grid_thw[: sum(image_nums)], list(image_nums))
            # compute the sequence length of images for each sample
            lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
            dict_to_expand[key] = _repeat_interleave_samples(
                dict_to_expand[key], lengths=lengths, repeat_times=expand_size
            )
        elif key == "image_grid_thw":
            # get the num of images for each sample and +1 for the image being generated
            lengths = list(image_nums)
            # keep the grid row of the image being generated ([-1:], not [:-1])
            last_image = dict_to_expand[key][-1:]
            dict_to_expand[key] = _repeat_interleave_samples(
                dict_to_expand[key][: sum(image_nums)], lengths=lengths, repeat_times=expand_size
            )
            dict_to_expand[key] = torch.cat([dict_to_expand[key], last_image], dim=0)
    return dict_to_expand


def _expand_dict_for_generation(dict_to_expand):
    for key in dict_to_expand:
        if (
            key != "cache_position"
            and dict_to_expand[key] is not None
            and isinstance(dict_to_expand[key], torch.Tensor)
            and key not in visual_keys
        ):
            dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
    return dict_to_expand
```
Let's put nesting outside as much as possible, please!
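As an illustration of that suggestion, the nested helper could be hoisted to module level so it is defined once rather than on every call, and can be tested independently (a sketch, not the code the PR ended up with):

```python
import torch

def _repeat_interleave_samples(x, lengths, repeat_times):
    # Split x into per-sample chunks along dim 0, tile each chunk
    # `repeat_times` times, and re-concatenate in the original order.
    samples = torch.split(x, lengths)
    repeat_args = [repeat_times] + [1] * (x.dim() - 1)
    return torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
```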
```python
if key == "pixel_values":
    # split images into samples
    samples = torch.split(image_grid_thw[: sum(image_nums)], list(image_nums))
    # compute the sequence length of images for each sample
    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
    dict_to_expand[key] = _repeat_interleave_samples(
        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
    )
elif key == "image_grid_thw":
```
Since there are only 2 keys being handled, let's not iterate (for explicitness).
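A sketch of the explicit, loop-free version being suggested; the names (`dict_to_expand`, `image_grid_thw`, `image_nums`, `expand_size`, `_repeat_interleave_samples`) follow the snippet above, and this is an illustration rather than the merged code:

```python
# Handle the two visual keys directly instead of iterating over the dict.
if dict_to_expand.get("pixel_values") is not None:
    # split image grids into per-sample groups, then derive each sample's
    # pixel-sequence length as t * h * w summed over that sample's images
    samples = torch.split(image_grid_thw[: sum(image_nums)], list(image_nums))
    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
    dict_to_expand["pixel_values"] = _repeat_interleave_samples(
        dict_to_expand["pixel_values"], lengths=lengths, repeat_times=expand_size
    )
if dict_to_expand.get("image_grid_thw") is not None:
    # set aside the grid of the image being generated, expand the rest
    last_image = dict_to_expand["image_grid_thw"][-1:]
    expanded = _repeat_interleave_samples(
        dict_to_expand["image_grid_thw"][: sum(image_nums)],
        lengths=list(image_nums),
        repeat_times=expand_size,
    )
    dict_to_expand["image_grid_thw"] = torch.cat([expanded, last_image], dim=0)
```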
Is there any reference code for this? I referenced the Qwen2 implementation.
zucchini-nlp left a comment
Nice, let's merge 🚀
[For maintainers] Suggested jobs to run (before merge)
run-slow: auto, glm4v, glm4v_moe, glm_image

View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43100&sha=ef3af1
This PR adapts the implementation of the AR model for GLM-Image. For the full pipeline, check the diffusers repos.
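For readers who want to try the AR model once this is merged, a hedged sketch of typical transformers usage; the auto class, checkpoint id, and prompt below are placeholders (assumptions), not names confirmed by this PR:

```python
from transformers import AutoModelForCausalLM, AutoProcessor

# Hypothetical checkpoint id, for illustration only.
model_id = "org/glm-image-checkpoint"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype="auto", device_map="auto"
)

# The AR model autoregressively produces tokens (including image tokens);
# decoding them into pixels is handled by the diffusers pipeline.
inputs = processor(text="A photo of a cat", return_tensors="pt").to(model.device)
generated = model.generate(**inputs, max_new_tokens=128)
print(processor.batch_decode(generated, skip_special_tokens=True)[0])
```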