diff --git a/examples/pytorch/quickstart.py b/examples/pytorch/quickstart.py
index 9c81a965e5..d53dee70ae 100644
--- a/examples/pytorch/quickstart.py
+++ b/examples/pytorch/quickstart.py
@@ -1,23 +1,35 @@
 from tensorrt_llm import SamplingParams
 from tensorrt_llm._torch import LLM
+import torch
+from sentence_transformers import SentenceTransformer
 
 
 def main():
     prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
         "The capital of France is",
-        "The future of AI is",
     ]
-    sampling_params = SamplingParams(max_tokens=32)
+    sbert_model_path = "./my_custom_model"
+    sbert = SentenceTransformer(sbert_model_path)
+    sbert_embeddings = sbert.encode(prompts)
+    print(sbert_embeddings)
+    print(f"shape: {sbert_embeddings.shape}")
 
-    llm = LLM(model='TinyLlama/TinyLlama-1.1B-Chat-v1.0')
+    print("=======")
+
+    sampling_params = SamplingParams(max_tokens=32, return_context_logits=True)
+    llm = LLM(model='./converted-classification-model/')
     outputs = llm.generate(prompts, sampling_params)
 
+    tllm_logits = []
     for i, output in enumerate(outputs):
         prompt = output.prompt
-        generated_text = output.outputs[0].text
-        print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        tllm_logit = output.context_logits.cpu()[0, :]
+        print(f"Prompt: {prompt!r}, Context logits: {tllm_logit}")
+        tllm_logits.append(tllm_logit)
+    # Stack the per-prompt logits into one [num_prompts, num_labels] tensor
+    tllm_logits = torch.stack(tllm_logits)
+    print(tllm_logits)
+    print(f"shape: {tllm_logits.shape}")
 
 
 if __name__ == '__main__':
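The updated quickstart prints the SBERT embeddings and the TRT-LLM context logits side by side but leaves the comparison to eyeballing. A minimal sketch of an explicit check, assuming the two tensors line up as [num_prompts, 192]; the compare_embeddings helper and the tolerance are illustrative, not part of this patch:

import torch
import torch.nn.functional as F

def compare_embeddings(sbert_embeddings, tllm_logits, atol=1e-3):
    """Compare SBERT embeddings (numpy) against the stacked TRT-LLM context logits."""
    sbert_t = torch.as_tensor(sbert_embeddings).to(tllm_logits.dtype)
    # Per-prompt cosine similarity should be ~1.0 if the conversion is faithful.
    cos = F.cosine_similarity(sbert_t, tllm_logits, dim=-1)
    print(f"cosine similarity per prompt: {cos}")
    print(f"max abs diff: {(sbert_t - tllm_logits).abs().max().item()}")
    return torch.allclose(sbert_t, tllm_logits, atol=atol)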
diff --git a/examples/pytorch/sbert.py b/examples/pytorch/sbert.py
new file mode 100644
index 0000000000..7ce5c404dc
--- /dev/null
+++ b/examples/pytorch/sbert.py
@@ -0,0 +1,56 @@
+from sentence_transformers import SentenceTransformer
+from transformers import BertForSequenceClassification, BertConfig
+import torch
+
+
+my_model_path = './model/my-custom-sbert'
+
+def load_sentence_transformer():
+    # Load the model
+    model = SentenceTransformer(my_model_path)
+    print(model)
+    # SentenceTransformer(
+    #   (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel
+    #   (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True,
+    #       'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False,
+    #       'pooling_mode_lasttoken': False, 'include_prompt': True})
+    #   (2): Dense({'in_features': 768, 'out_features': 192, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
+    return model
+
+
+# 1. Load the SentenceTransformer model
+sentence_model = load_sentence_transformer()
+
+# 2-1. Get the underlying BERT model
+bert_model = sentence_model._first_module().auto_model
+# 2-2. Extract the Dense layer weights
+dense_layer = sentence_model._last_module()  # Gets the final Dense layer
+dense_weight = dense_layer.linear.weight  # Shape [192, 768]
+
+# 3. Create a config for the classification model
+config = BertConfig.from_pretrained(
+    my_model_path,
+    num_labels=192,  # Output dimension = SBERT Dense out_features, not a real class count
+    hidden_dropout_prob=bert_model.config.hidden_dropout_prob,
+    attention_probs_dropout_prob=bert_model.config.attention_probs_dropout_prob
+)
+
+# 4. Create the new classification model
+classification_model = BertForSequenceClassification(config)
+
+# 5. Copy weights
+classification_model.bert.load_state_dict(bert_model.state_dict(), strict=False)
+classification_model.classifier = torch.nn.Linear(
+    in_features=768,
+    out_features=192,
+    bias=True
+)
+classification_model.classifier.weight.data = dense_weight.clone()
+# sentence-transformers' Dense layer has no bias, but the trtllm weight-loading
+# path enforces one, so zero it out here to stay computationally equivalent,
+# though slightly less efficient.
+classification_model.classifier.bias.data.zero_()
+print(classification_model)
+
+# 6. Save the converted model
+classification_model.save_pretrained('./converted-classification-model')
\ No newline at end of file
diff --git a/tensorrt_llm/_torch/models/modeling_bert.py b/tensorrt_llm/_torch/models/modeling_bert.py
index e8218a53f7..4e8bceef20 100644
--- a/tensorrt_llm/_torch/models/modeling_bert.py
+++ b/tensorrt_llm/_torch/models/modeling_bert.py
@@ -162,23 +162,26 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-
-        # We "pool" the model by simply taking the hidden state corresponding
-        # to the first token.
         if attn_metadata is not None:
-            #NOTE: select the first tokens
-            offset = attn_metadata.seq_lens_cuda
-            selected_tokens = torch.cumsum(
-                attn_metadata.seq_lens_cuda,
-                dim=0,
-                dtype=torch.long,
-            ) - offset
-            hidden_states = hidden_states[selected_tokens]
+            # Mean pooling with attention metadata (for variable-length sequences)
+            pooled_output = []
+            start_idx = 0
+            for seq_len in attn_metadata.seq_lens_cuda:
+                # Slice out this sequence's tokens and average over them
+                seq = hidden_states[start_idx:start_idx + seq_len]
+                pooled_seq = seq.mean(dim=0)  # Average over tokens
+                pooled_output.append(pooled_seq)
+                start_idx += seq_len
+            pooled_output = torch.stack(pooled_output, dim=0)
         else:
-            # hidden_states: [B, N, H]
-            hidden_states = hidden_states[:, 0]
-        pooled_output = self.dense(hidden_states)
-        pooled_output = self.activation(pooled_output)
+            # Standard mean pooling for fixed-length sequences
+            # hidden_states shape: [batch_size, seq_len, hidden_size]
+            pooled_output = hidden_states.mean(dim=1)  # Average over sequence length
+
+        # NOTE: dense + activation are intentionally skipped; the SBERT Dense
+        # layer is folded into the classifier head by examples/pytorch/sbert.py.
+        # pooled_output = self.dense(pooled_output)
+        # pooled_output = self.activation(pooled_output)
 
         return pooled_output
diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index 08b32c15d1..229ed50630 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -83,7 +83,8 @@ def update_requests(self, state: SampleState) -> None:
             logits = state.logits[idx]
             if logits.ndim == 1:
-                # For BERT: Add vocab_size axis to be compatible with LogitsStorage.
-                logits = logits.unsqueeze(-1)
+                # For BERT: add a leading axis ([hidden] -> [1, hidden]) so the
+                # pooled vector is stored as a single row in LogitsStorage.
+                logits = logits.unsqueeze(0)
             request.py_result.append_context_logits(logits)
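For context on the sampler change above: the pooled BERT output arrives as a 1-D vector of size [hidden], and the storage path expects a 2-D tensor laid out like [num_positions, vocab_size]. A standalone illustration of why the axis matters, in plain PyTorch and independent of the executor code:

import torch

pooled = torch.randn(192)             # pooled classifier output, shape [192]
print(pooled.unsqueeze(0).shape)      # torch.Size([1, 192]): one row, the new behavior
print(pooled.unsqueeze(-1).shape)     # torch.Size([192, 1]): 192 rows, the old behavior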
diff --git a/triton_backend/all_models/llmapi/tensorrt_llm/1/model.py b/triton_backend/all_models/llmapi/tensorrt_llm/1/model.py
index e2a311fc27..307e5f405b 100755
--- a/triton_backend/all_models/llmapi/tensorrt_llm/1/model.py
+++ b/triton_backend/all_models/llmapi/tensorrt_llm/1/model.py
@@ -157,6 +157,7 @@ def initialize(self, args):
         )
         return
 
+    # This seems to be boilerplate code; can we extract it into a shared helper?
    def _init_engine(self):
        """
        Initialize the LLM engine in a separate thread running the AsyncIO event loop.
        """
@@ -313,6 +314,9 @@ async def _execute_single_request(self, request):
             raise pb_utils.TritonModelException(
                 "Streaming is only supported in decoupled mode.")
         # Generate the response.
+        # TODO(Caron): hardcoded for the BERT experiment; make these configurable.
+        sampling_params["max_tokens"] = 32
+        sampling_params["return_context_logits"] = True
         response_iterator = self._llm_engine.generate_async(
             prompt, SamplingParams(**sampling_params), streaming)
 
@@ -419,6 +423,9 @@ def _create_response(self, request_output, output_config):
         text_output = [
             output.text.encode("utf-8") for output in request_output.outputs
         ]
+        # TODO(Caron): return the context logits instead of output.text for BERT.
+        # Note: context_logits lives on RequestOutput, not on CompletionOutput.
+        print(request_output.context_logits)
         response.append(
             pb_utils.Tensor("text_output",
diff --git a/triton_backend/all_models/llmapi/tensorrt_llm/1/model.yaml b/triton_backend/all_models/llmapi/tensorrt_llm/1/model.yaml
index 639e84a614..ef44f7aab2 100644
--- a/triton_backend/all_models/llmapi/tensorrt_llm/1/model.yaml
+++ b/triton_backend/all_models/llmapi/tensorrt_llm/1/model.yaml
@@ -1,7 +1,7 @@
 # ======= TRTLLM Model Configurations =======
 # The below configurations are used to initialize the LLM model
 # Check tensorrt_llm.llmapi.LLM for more details
-model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+model: textattack/bert-base-uncased-yelp-polarity
 backend: "pytorch"
 
 # Parallel Configurations
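Stepping back: the converted checkpoint is only equivalent to the SBERT model under this patched runtime, since vanilla BertForSequenceClassification pools via the CLS token rather than averaging. A rough offline sanity check, assuming the local paths from the examples above exist and the SBERT tokenizer files sit alongside the model, is to redo the masked mean pooling by hand in plain transformers and feed it through the transplanted classifier:

import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, BertForSequenceClassification

prompts = ["The capital of France is"]

# Reference embeddings from the original SBERT pipeline.
sbert = SentenceTransformer("./my_custom_model")
ref = torch.as_tensor(sbert.encode(prompts))

tokenizer = AutoTokenizer.from_pretrained("./my_custom_model")
model = BertForSequenceClassification.from_pretrained("./converted-classification-model")
model.eval()

with torch.no_grad():
    batch = tokenizer(prompts, return_tensors="pt", padding=True)
    hidden = model.bert(**batch).last_hidden_state      # [B, T, H]
    mask = batch["attention_mask"].unsqueeze(-1)        # [B, T, 1]
    mean_pooled = (hidden * mask).sum(1) / mask.sum(1)  # masked mean over real tokens
    logits = model.classifier(mean_pooled)              # [B, 192], mirrors the patched pooler

print(torch.allclose(ref, logits, atol=1e-4))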