Conversation

@zRzRzRzRzRzRzR (Contributor):

@yiyixuxu @sayakpaul Please check with the model.

zRzRzRzRzRzRzR marked this pull request as draft January 7, 2026 11:01
zRzRzRzRzRzRzR marked this pull request as ready for review January 8, 2026 07:56
zRzRzRzRzRzRzR changed the title from "GLM-Imge for test" to "[GLM-Imge] New Models Support" Jan 8, 2026
zRzRzRzRzRzRzR changed the title from "[GLM-Imge] New Models Support" to "[GLM-Image] New Models Support" Jan 8, 2026
@sayakpaul (Member) left a comment:


Looking quite good!

I think all the precomputations are in place, and the use of caching also reads quite simply.

sayakpaul requested a review from yiyixuxu January 8, 2026 10:38
@yiyixuxu (Collaborator) left a comment:


Thanks, looking great, and super excited about this model!
I left some comments; mostly, I'm a bit confused about the correct logic to set height/width.

Comment on lines 604 to 609
prior_token_id, prior_token_image_ids, ar_height, ar_width = self.generate_prior_tokens(
    prompt=prompt[0] if isinstance(prompt, list) else prompt,
    image=image,
    height=height,
    width=width,
)
@yiyixuxu (Collaborator):

A few things here:

  1. generate_prior_tokens will error out if height = None and width = None
  2. ar_height/ar_width are pretty straightforward to calculate; let's calculate them separately for clarity
  3. we can update generate_prior_tokens to only return the two token tensors; this way it is easier for users to skip this stage by reusing pre-computed tokens

Here is just a suggestion; I'm not completely sure the logic to assign default height/width is correct:

Suggested change
prior_token_id, prior_token_image_ids, ar_height, ar_width = self.generate_prior_tokens(
    prompt=prompt[0] if isinstance(prompt, list) else prompt,
    image=image,
    height=height,
    width=width,
)
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
height = (height // 32) * 32
width = (width // 32) * 32
prior_token_id, prior_token_image_ids = self.generate_prior_tokens(
    prompt=prompt[0] if isinstance(prompt, list) else prompt,
    image=image,
    height=height,
    width=width,
)

@zRzRzRzRzRzRzR (Contributor, Author):

Let me add a check to ensure that height and width cannot be None. This is a strict requirement, as these two parameters must be present for the AR model to correctly output tokens.
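
For reference, a minimal sketch of such a check; the helper name, placement, and the 32-pixel rounding are assumptions based on the suggestion above, not the PR's final code:

def check_height_width(height, width):
    # Hypothetical validation sketch: the AR prior needs explicit output
    # dimensions to emit the right number of tokens, so None is rejected.
    if height is None or width is None:
        raise ValueError(
            "`height` and `width` must both be provided; the AR model cannot "
            "generate prior tokens without explicit output dimensions."
        )
    # Snap to 32-pixel granularity, mirroring the reviewer's suggestion.
    return (height // 32) * 32, (width // 32) * 32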

f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)

if prompt is not None and prompt_embeds is not None:
@yiyixuxu (Collaborator) Jan 9, 2026:

So the current code structure won't work with prompt=None when prompt_embeds is passed; we still need the prompt to generate tokens using the AR model.

I think we'd need to accept both prior_token_id and prompt_embeds as inputs if prompt is None, so something like:

if prompt is None:
    if prior_token_id is None or prompt_embeds is None:
        raise ValueError(
            "When `prompt` is not provided, both `prior_token_id` and `prompt_embeds` must be passed."
        )

You also need to add prior_token_id to the pipeline inputs.
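
A rough sketch of what that could look like on the pipeline; only prior_token_id is named in the review, and the class name and surrounding parameters are placeholders:

from typing import Optional

import torch

class GlmImagePipelineSketch:
    # Placeholder pipeline illustrating the proposed input validation.
    def __call__(
        self,
        prompt: Optional[str] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prior_token_id: Optional[torch.Tensor] = None,  # new pipeline input
        **kwargs,
    ):
        # When no prompt is given, both pre-computed inputs are required,
        # since the AR stage cannot run without a prompt.
        if prompt is None and (prior_token_id is None or prompt_embeds is None):
            raise ValueError(
                "When `prompt` is not provided, both `prior_token_id` and "
                "`prompt_embeds` must be passed."
            )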

@zRzRzRzRzRzRzR (Contributor, Author) Jan 9, 2026:

In the current implementation, prior_token_id must be generated by the AR model, so prompt must not be None.

@zRzRzRzRzRzRzR (Contributor, Author):

Let me change the logic of this.

    self.k_cache = k
    self.v_cache = v
else:
    self.k_cache = torch.cat([self.k_cache, k], dim=2)


Referring to L253, should dim be equal to 1 here?

        self.k_cache = torch.cat([self.k_cache, k], dim=2)
        self.v_cache = torch.cat([self.v_cache, v], dim=2)

    def get(self):
@yiyixuxu (Collaborator):

Not sure if it should be 1 or 2, but they should be the same.

Probably better to move the logic together like this, so mistakes like https://github.com/huggingface/diffusers/pull/12921/files#r2678634789 are less likely to happen:

Suggested change
def get(self):
def get(self, k: torch.Tensor, v: torch.Tensor):
    k_cache = torch.cat([self.k_cache, k], dim=2)
    v_cache = torch.cat([self.v_cache, v], dim=2)
    return k_cache, v_cache
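
For comparison, a self-contained sketch of a cache that keeps the update and read logic on a single concatenation axis; the [batch, heads, seq_len, head_dim] layout implied by dim=2 is an assumption, not confirmed by the PR:

import torch

class KVCacheSketch:
    # Hypothetical sketch, not the PR's implementation. One constant is the
    # single source of truth for the sequence axis, so the dim=1 vs dim=2
    # mismatch discussed above cannot occur.
    SEQ_DIM = 2  # assumes [batch, heads, seq_len, head_dim] tensors

    def __init__(self):
        self.k_cache = None
        self.v_cache = None

    def update(self, k: torch.Tensor, v: torch.Tensor):
        # Append the new keys/values and return the full cached sequences.
        if self.k_cache is None:
            self.k_cache, self.v_cache = k, v
        else:
            self.k_cache = torch.cat([self.k_cache, k], dim=self.SEQ_DIM)
            self.v_cache = torch.cat([self.v_cache, v], dim=self.SEQ_DIM)
        return self.k_cache, self.v_cache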

Comment on lines +252 to +254
k_cache, v_cache = kv_cache.get()
key = torch.cat([k_cache, key], dim=1) if k_cache is not None else key
value = torch.cat([v_cache, value], dim=1) if v_cache is not None else value
@yiyixuxu (Collaborator):

Suggested change
k_cache, v_cache = kv_cache.get()
key = torch.cat([k_cache, key], dim=1) if k_cache is not None else key
value = torch.cat([v_cache, value], dim=1) if v_cache is not None else value
if kv_cache is not None:
    key, value = kv_cache.get(key, value)

num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
@yiyixuxu (Collaborator):

Suggested change
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prior_token_ids: Optional[torch.Tensor] = None,
prior_image_token_ids: Optional[torch.Tensor] = None,

We should allow users to pre-compute the tokens, since that is the most compute-expensive part. We should also allow passing a pre-computed negative_prompt_embeds, because it is fixed.
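
A hypothetical usage sketch of that precompute-then-reuse flow; pipe, the embedding variables, and the keyword names are assumptions based on the suggested signature, not the final API:

import torch

# Hypothetical: run the compute-heavy AR stage once up front...
prior_token_ids, prior_image_token_ids = pipe.generate_prior_tokens(
    prompt="a photo of a cat", image=None, height=1024, width=1024
)

# ...then reuse the cached tensors across multiple sampling calls.
for seed in range(4):
    image = pipe(
        prompt=None,
        prompt_embeds=prompt_embeds,                    # pre-computed text embeds
        negative_prompt_embeds=negative_prompt_embeds,  # fixed, so cache it too
        prior_token_ids=prior_token_ids,
        prior_image_token_ids=prior_image_token_ids,
        generator=torch.Generator().manual_seed(seed),
    ).images[0]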


device = self._execution_device

prior_token_id, prior_token_image_ids = self.generate_prior_tokens(
@yiyixuxu (Collaborator):

Suggested change
prior_token_id, prior_token_image_ids = self.generate_prior_tokens(
if prior_token_ids is None:
    prior_token_id, prior_token_image_ids = self.generate_prior_tokens(...)
