Adapt Sentence Transformer model to BertForSeqClassification #4462

Open · wants to merge 3 commits into main

26 changes: 19 additions & 7 deletions examples/pytorch/quickstart.py
@@ -1,23 +1,35 @@
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
import torch
from sentence_transformers import SentenceTransformer


def main():
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(max_tokens=32)
    sbert_model_path = "./my_custom_model"
    sbert = SentenceTransformer(sbert_model_path)
    sbert_embeddings = sbert.encode(prompts)
    print(sbert_embeddings)
    print(f"shape: {sbert_embeddings.shape}")

    llm = LLM(model='TinyLlama/TinyLlama-1.1B-Chat-v1.0')
    print("=======")

    sampling_params = SamplingParams(max_tokens=32, return_context_logits=True)
    llm = LLM(model='./converted-classification-model/')
    outputs = llm.generate(prompts, sampling_params)

    tllm_logits = []
    for i, output in enumerate(outputs):
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")
        tllm_logit = output.context_logits.cpu()[0, :]
        print(f"Prompt: {prompt!r}, Context logits: {tllm_logit}")
        tllm_logits += [tllm_logit]
    # Stack the output
    tllm_logits = torch.stack(tllm_logits)
    print(tllm_logits)
    print(f"shape: {tllm_logits.shape}")


if __name__ == '__main__':
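Since the converted model is supposed to reproduce the SentenceTransformer embeddings through its context logits, a quick follow-up check (not part of the diff) is to compare the two equally-shaped arrays built in main() above. A rough sketch, assuming sbert_embeddings and tllm_logits are exactly those arrays:

import numpy as np
import torch
import torch.nn.functional as F

def compare_embeddings(sbert_embeddings: np.ndarray, tllm_logits: torch.Tensor) -> torch.Tensor:
    # Per-prompt cosine similarity; values near 1.0 suggest the conversion preserved the embeddings.
    a = torch.as_tensor(sbert_embeddings, dtype=torch.float32)
    b = tllm_logits.to(torch.float32)
    return F.cosine_similarity(a, b, dim=-1)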
56 changes: 56 additions & 0 deletions examples/pytorch/sbert.py
@@ -0,0 +1,56 @@
from sentence_transformers import SentenceTransformer
from transformers import BertForSequenceClassification, BertConfig
import torch


my_model_path = './model/my-custom-sbert'

def load_sentence_transformer():
    # Load the model
    model = SentenceTransformer(my_model_path)
    print(model)
    # SentenceTransformer(
    #   (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel
    #   (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True,
    #       'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False,
    #       'pooling_mode_lasttoken': False, 'include_prompt': True})
    #   (2): Dense({'in_features': 768, 'out_features': 192, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
    return model


# 1. Load SentenceTransformer model
sentence_model = load_sentence_transformer()

# 2-1. Get the underlying BERT model
bert_model = sentence_model._first_module().auto_model
# 2-2. Extract the Dense layer weights
dense_layer = sentence_model._last_module() # Gets the final Dense layer
dense_weight = dense_layer.linear.weight # Shape [192, 768]

# 3. Create config for classification model
config = BertConfig.from_pretrained(
    my_model_path,
    num_labels=192,  # Set number of classes
    hidden_dropout_prob=bert_model.config.hidden_dropout_prob,
    attention_probs_dropout_prob=bert_model.config.attention_probs_dropout_prob,
)

# 4. Create new classification model
classification_model = BertForSequenceClassification(config)

# 5. Copy weights
classification_model.bert.load_state_dict(bert_model.state_dict(), strict=False)
classification_model.classifier = torch.nn.Linear(
    in_features=768,
    out_features=192,
    bias=True,
)
classification_model.classifier.weight.data = dense_weight.clone()
# The sentence-transformers Dense layer has no bias, but the TRT-LLM weight-loading
# path expects one, so zero it out here to stay numerically equivalent (at a small cost).
classification_model.classifier.bias.data.zero_()
print(classification_model)

# 6. Save if needed
classification_model.save_pretrained('./converted-classification-model')
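One way to verify the conversion offline, before TRT-LLM is involved, is to mean-pool the converted model's hidden states (mirroring the modified pooler below) and apply the copied classifier weights, then compare against SentenceTransformer.encode. A sketch under the same path assumptions as the script above; the tokenizer loading and the masked-mean helper are illustrative, not part of the PR:

import torch
from transformers import AutoTokenizer, BertForSequenceClassification
from sentence_transformers import SentenceTransformer

sbert = SentenceTransformer('./model/my-custom-sbert')
converted = BertForSequenceClassification.from_pretrained('./converted-classification-model')
tokenizer = AutoTokenizer.from_pretrained('./model/my-custom-sbert')

texts = ["The capital of France is"]
batch = tokenizer(texts, return_tensors='pt', padding=True)

with torch.no_grad():
    hidden = converted.bert(**batch).last_hidden_state      # [B, T, H]
    mask = batch['attention_mask'].unsqueeze(-1).float()    # [B, T, 1]
    pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)   # masked mean pooling
    # Apply the copied Dense weights; the bias was zeroed above, so it is a no-op.
    emb_converted = pooled @ converted.classifier.weight.T + converted.classifier.bias

emb_reference = torch.from_numpy(sbert.encode(texts))       # [B, 192]
print(torch.allclose(emb_converted, emb_reference, atol=1e-4))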
32 changes: 17 additions & 15 deletions tensorrt_llm/_torch/models/modeling_bert.py
@@ -162,23 +162,25 @@ def forward(
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:

        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        if attn_metadata is not None:
            # NOTE: select the first tokens
            offset = attn_metadata.seq_lens_cuda
            selected_tokens = torch.cumsum(
                attn_metadata.seq_lens_cuda,
                dim=0,
                dtype=torch.long,
            ) - offset
            hidden_states = hidden_states[selected_tokens]
            # Average pooling with attention metadata (for variable-length sequences)
            pooled_output = []
            start_idx = 0
            for seq_len in attn_metadata.seq_lens_cuda:
                # Extract the sequence and average-pool its tokens
                seq = hidden_states[start_idx:start_idx + seq_len]
                pooled_seq = seq.mean(dim=0)  # Average over tokens
                pooled_output.append(pooled_seq)
                start_idx += seq_len
            pooled_output = torch.stack(pooled_output, dim=0)
        else:
            # hidden_states: [B, N, H]
            hidden_states = hidden_states[:, 0]
            pooled_output = self.dense(hidden_states)
            pooled_output = self.activation(pooled_output)
            # Standard average pooling for fixed-length sequences
            # hidden_states shape: [batch_size, seq_len, hidden_size]
            pooled_output = hidden_states.mean(dim=1)  # Average over sequence length

        # Apply dense layer and activation
        # pooled_output = self.dense(pooled_output)
        # pooled_output = self.activation(pooled_output)
        return pooled_output


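The per-sequence Python loop above iterates over seq_lens_cuda on the host, which is likely to synchronize with the device once per sequence. If that ever matters, the same variable-length mean pooling can be written as a single index_add_ over the packed hidden states. This is only a sketch of the idea, not code from the PR:

import torch

def mean_pool_packed(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: [total_tokens, hidden_size] packed; seq_lens: [num_seqs] on the same device.
    num_seqs = seq_lens.numel()
    # Give every token the index of the sequence it belongs to.
    seq_ids = torch.repeat_interleave(
        torch.arange(num_seqs, device=hidden_states.device), seq_lens.long())
    sums = hidden_states.new_zeros(num_seqs, hidden_states.shape[-1])
    sums.index_add_(0, seq_ids, hidden_states)
    return sums / seq_lens.to(hidden_states.dtype).unsqueeze(-1)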
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -83,7 +83,7 @@ def update_requests(self, state: SampleState) -> None:
            logits = state.logits[idx]
            if logits.ndim == 1:
                # For BERT: Add vocab_size axis to be compatible with LogitsStorage.
                logits = logits.unsqueeze(-1)
                logits = logits.unsqueeze(0)
            request.py_result.append_context_logits(logits)


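For context on the one-line change above: with the mean-pooling pooler, a BERT classification request produces a single 1-D logits vector, and unsqueeze(0) stores it as one row of context logits whose vocab axis is the label dimension, which is why the quickstart reads context_logits[0, :]. A toy illustration, assuming 192 labels as in the conversion script:

import torch

logits = torch.randn(192)       # pooled classifier output for one request
stored = logits.unsqueeze(0)    # [1, 192]: one "token" row, labels act as the vocab axis
assert stored.shape == (1, 192)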
7 changes: 7 additions & 0 deletions triton_backend/all_models/llmapi/tensorrt_llm/1/model.py
@@ -157,6 +157,7 @@ def initialize(self, args):
            )
            return

    # This seems to be boilerplate code; can we extract it out?
    def _init_engine(self):
        """
        Initialize the LLM engine in a separate thread running the AsyncIO event loop.
@@ -313,6 +314,9 @@ async def _execute_single_request(self, request):
            raise pb_utils.TritonModelException(
                "Streaming is only supported in decoupled mode.")
        # Generate the response.
        # TODO(Caron): add max_tokens=32, return_context_logits=True to sampling_params
        sampling_params["max_tokens"] = 32
        sampling_params["return_context_logits"] = True
        response_iterator = self._llm_engine.generate_async(
            prompt, SamplingParams(**sampling_params), streaming)

@@ -419,6 +423,9 @@ def _create_response(self, request_output, output_config):
        text_output = [
            output.text.encode("utf-8") for output in request_output.outputs
        ]
        # TODO(Caron): print the logits of the context instead of output.text for BERT;
        # however, CompletionOutput doesn't seem to have a context_logits field
        print(request_output.outputs[0].context_logits)

        response.append(
            pb_utils.Tensor("text_output",
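On the TODO above: the llmapi RequestOutput itself carries context_logits (that is what the quickstart reads), so instead of printing it, one option is to return it as an extra Triton output tensor. A sketch meant to slot into _create_response, assuming a "context_logits" output is also declared in the model's config.pbtxt:

# Sketch: expose context logits as an additional output tensor.
context_logits = getattr(request_output, "context_logits", None)
if context_logits is not None:
    response.append(
        pb_utils.Tensor("context_logits",
                        context_logits.float().cpu().numpy()))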
2 changes: 1 addition & 1 deletion triton_backend/all_models/llmapi/tensorrt_llm/1/model.yaml
@@ -1,7 +1,7 @@
# ======= TRTLLM Model Configurations =======
# The below configurations are used to initialize the LLM model
# Check tensorrt_llm.llmapi.LLM for more details
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model: textattack/bert-base-uncased-yelp-polarity
backend: "pytorch"

# Parallel Configurations