Retrofit sentence bert into BertForSeqClassification #3924
Closed
Primary goal of this PR is to validate my idea of retrofitting a BERT-based `SentenceTransformer` model into a `BertForSequenceClassification` model, and using trtllm's `LLM` API to do inference.

The `SentenceTransformer` model has 3 modules:

- `Transformer` (the underlying BERT encoder)
- `Pooling` (mean pooling over token embeddings)
- `Dense` (a `torch.nn.Linear` on top of the pooled embedding)
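For reference, the module stack can be listed directly, since `SentenceTransformer` subclasses `nn.Sequential` (the checkpoint path below is a placeholder, not a specific model from this PR):

```python
from sentence_transformers import SentenceTransformer

# Placeholder path: substitute the actual BERT-based checkpoint.
model = SentenceTransformer("path/to/bert-based-sentence-transformer")

# Expect Transformer, Pooling, and Dense for a model like the one described above.
for idx, module in enumerate(model):
    print(idx, type(module).__name__)
```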
To map the weights above to a `BertForSequenceClassification` model, I'm planning to:

1. Change the pooler in `modeling_bert.py` to compute the average of token embeddings instead of taking the first token. (This PR has the changes needed; a sketch follows this list.)
2. Map the `Dense` module (`torch.nn.Linear`)'s weights to the `classifier` layer in `BertForSequenceClassification`. (See the weight-mapping sketch below.)
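A minimal sketch of step 1, assuming the swap happens at the module level; the `MeanPooler` name and its exact wiring into `modeling_bert.py` are illustrative, not this PR's actual diff:

```python
import torch
from torch import nn

class MeanPooler(nn.Module):
    """Mean pooling in place of BERT's first-token pooler (illustrative)."""

    def forward(self, hidden_states: torch.Tensor,
                attention_mask: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq, hidden); attention_mask: (batch, seq).
        # Zero out padding positions, then divide by the real token count,
        # mirroring SentenceTransformer's mean Pooling module.
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts
```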
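And a sketch of step 2, again with placeholder paths. One caveat worth flagging: `sentence_transformers.models.Dense` defaults to a `Tanh` activation, while `BertForSequenceClassification` applies its `classifier` with no activation, so this direct copy only matches a `Dense` module configured with an identity activation:

```python
import torch
from sentence_transformers import SentenceTransformer
from transformers import BertForSequenceClassification

# Placeholder paths for illustration.
st_model = SentenceTransformer("path/to/bert-based-sentence-transformer")
dense = st_model[2].linear  # index 2 assumes Transformer/Pooling/Dense order

hf_model = BertForSequenceClassification.from_pretrained(
    "path/to/base-bert", num_labels=dense.out_features)

# classifier is nn.Linear(hidden_size, num_labels), the same shape as the
# Dense module's linear layer, so the parameters transfer one-to-one.
with torch.no_grad():
    hf_model.classifier.weight.copy_(dense.weight)
    hf_model.classifier.bias.copy_(dense.bias)
```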
I have not been able to verify the above method due to some Docker setup issues.
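In the meantime, a pure-PyTorch cross-check that sidesteps the Docker/trtllm setup could look like the following; it reuses `st_model` and `hf_model` from the sketches above and assumes the mean pooler is already patched in and the `Dense` activation is identity:

```python
import torch
from transformers import AutoTokenizer

# Placeholder path; st_model and hf_model come from the previous sketches.
tokenizer = AutoTokenizer.from_pretrained("path/to/base-bert")
texts = ["a quick sanity-check sentence"]

# Reference path: encode() runs the full Transformer -> Pooling -> Dense stack.
ref = torch.tensor(st_model.encode(texts))

inputs = tokenizer(texts, return_tensors="pt", padding=True)
hf_model.eval()  # disable the dropout applied before the classifier
with torch.no_grad():
    logits = hf_model(**inputs).logits

# If the retrofit is faithful, the logits should equal the sentence embeddings.
print(torch.allclose(ref, logits, atol=1e-4))
```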
@qixiang-99, if you could help comment on whether the above approach is reasonable, that would greatly help.
Note - I wasn't planning on merging this PR, because this is a very specific experiment for my own use case. However, if people see value in generalizing this approach, we can certainly polish it and create a new PR against the main branch.