Retrofit sentence bert into BertForSeqClassification #3924


Closed
wants to merge 1 commit

Conversation


@caronzh03 caronzh03 commented Apr 28, 2025

The primary goal of this PR is to validate my idea of retrofitting a BERT-based SentenceTransformer model into a BertForSequenceClassification model and using trtllm's LLM API to run inference.

The SentenceTransformer model has 3 modules:

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Dense({'in_features': 768, 'out_features': 192, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
)

To map the weights above to a BertForSequenceClassification model, I'm planning to do the following (a rough sketch follows the list):

  1. Copy the Transformer layers' weights.
  2. Change the Pooler layer implementation in modeling_bert.py to compute the mean of the token embeddings (this PR contains the changes needed).
  3. Copy the Dense layer's weights (it is equivalent to torch.nn.Linear) into the classifier layer of BertForSequenceClassification.
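
A minimal sketch of this mapping, assuming sentence-transformers module indexing and attribute names (st_model[0].auto_model for the underlying BertModel, st_model[2].linear for the Dense layer) and using my_model_path as a placeholder; the mean-pooling change from step 2 lives in modeling_bert.py rather than in this snippet:

import torch
from sentence_transformers import SentenceTransformer
from transformers import BertConfig, BertForSequenceClassification

my_model_path = "path/to/sentence-transformer-checkpoint"  # placeholder

# Modules: (0) Transformer, (1) Pooling, (2) Dense
st_model = SentenceTransformer(my_model_path)

config = BertConfig.from_pretrained(my_model_path, num_labels=192)
cls_model = BertForSequenceClassification(config)

# Step 1: copy the Transformer (BertModel) weights.
# strict=False tolerates the pooler-weight mismatch between the two models.
cls_model.bert.load_state_dict(st_model[0].auto_model.state_dict(), strict=False)

# Step 2 (mean pooling) is handled by the modeling_bert.py change in this PR.

# Step 3: copy the Dense layer's weights into the classifier head.
# The Dense module has bias=False, so the classifier bias is zeroed out.
with torch.no_grad():
    cls_model.classifier.weight.copy_(st_model[2].linear.weight)
    cls_model.classifier.bias.zero_()

cls_model.save_pretrained("retrofitted-bert-cls")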

I have not been able to verify the above method due to some Docker setup issues.
@qixiang-99, if you could comment on whether the above approach is reasonable, that would help greatly.

Note: I wasn't planning on merging this PR, since it is a very specific experiment for my own use case. However, if people see value in generalizing this approach, we can certainly polish it and create a new PR against the main branch.

@caronzh03 caronzh03 requested a review from a team as a code owner April 28, 2025 22:43
@symphonylyh symphonylyh requested a review from qixiang-99 April 28, 2025 23:38
@symphonylyh
Collaborator

Requesting @qixiang-99's review.

Btw, @caronzh03, can you re-target the PR to main instead of the 0.19 release branch? The rules don't allow merges into a release branch.

@qixiang-99
Collaborator

Hi @caronzh03, your plan seems reasonable. One thing to consider is the naming convention; you might find the weight-name conversion function helpful, since it spells out the name mapping for each weight individually. Let me know if I can assist further.
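
For illustration only, a generic key-renaming sketch; the prefix pairs below (e.g. 0.auto_model. → bert.) are assumptions about the sentence-transformers vs. BertForSequenceClassification state-dict naming and are not taken from trtllm's conversion function:

def rename_sentence_transformer_keys(st_state_dict):
    # Hypothetical helper: map SentenceTransformer-style state-dict keys to
    # BertForSequenceClassification-style keys. The prefixes are assumptions.
    renamed = {}
    for key, value in st_state_dict.items():
        if key.startswith("0.auto_model."):
            renamed["bert." + key[len("0.auto_model."):]] = value
        elif key == "2.linear.weight":
            renamed["classifier.weight"] = value
        # The Pooling module (index 1) has no weights, so nothing to map there.
    return renamed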

@caronzh03
Author

caronzh03 commented May 17, 2025

Hi @qixiang-99, I finally got the Docker setup working and was able to test out my changes.
I have a BertForSequenceClassification model with 192 classes:

config = BertConfig.from_pretrained(
    my_model_path,
    num_labels=192, 
    ...
)
my_model = BertForSequenceClassification(config)
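
For reference, a rough sketch of how the LLM API might be invoked for this checkpoint (this is not the actual modified quickstart.py; the import path varies by version, and the return_context_logits flag is an assumption about how to surface the classification logits):

from tensorrt_llm import LLM, SamplingParams  # some versions expose the PyTorch backend as tensorrt_llm._torch.LLM

llm = LLM(model=my_model_path)  # retrofitted BertForSequenceClassification checkpoint
sampling_params = SamplingParams(max_tokens=1, return_context_logits=True)  # flag name is an assumption

outputs = llm.generate(["example input text"], sampling_params)
for output in outputs:
    # For a classification head, the per-class scores come from the logits
    # rather than from generated tokens.
    print(output.context_logits)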

When I tried to load this model using trtllm's PyTorch API and run inference on it with the modified quickstart.py, I got this error:

tensorrt_llm.executor.utils.RequestError: LogitsStorage overflow. This storage can only hold 16 logits (0 already filled) but trying to append 192 more logits
Traceback (most recent call last):
  File "/mnt/task_runtime/myenv/lib/python3.10/site-packages/tensorrt_llm/_torch/pyexecutor/py_executor.py", line 1732, in _update_requests
    self.decoder.update_requests(decoder_state)
  File "/mnt/task_runtime/myenv/lib/python3.10/site-packages/tensorrt_llm/_torch/pyexecutor/decoder.py", line 78, in update_requests
    request.py_result.append_context_logits(logits)
  File "/mnt/task_runtime/myenv/lib/python3.10/site-packages/tensorrt_llm/_torch/pyexecutor/llm_request.py", line 138, in append_context_logits
    self._context_logits.append(context_logits)
  File "/mnt/task_runtime/myenv/lib/python3.10/site-packages/tensorrt_llm/_torch/pyexecutor/llm_request.py", line 79, in append
    raise ValueError(
ValueError: LogitsStorage overflow. This storage can only hold 33 logits (0 already filled) but trying to append 192 more logits
[05/16/2025-17:26:23] [TRT-LLM] [E] Encountered an error in decode: LogitsStorage overflow. This storage can only hold 33 logits (0 already filled) but trying to append 192 more logits

I'm wondering why there is a limit on the logits size, and how do I fix this issue?

@MartinMarciniszyn
Collaborator

We do not accept any changes in the release branch. Please target main.

@caronzh03
Author

Closing in favor of another PR against main: #4462.

@caronzh03 caronzh03 closed this May 20, 2025