diff --git a/examples/pytorch/quickstart.py b/examples/pytorch/quickstart.py
index 9c81a965e5..c351c658c6 100644
--- a/examples/pytorch/quickstart.py
+++ b/examples/pytorch/quickstart.py
@@ -1,6 +1,8 @@
 from tensorrt_llm import SamplingParams
 from tensorrt_llm._torch import LLM
 
+import torch
+
 
 def main():
     prompts = [
@@ -9,15 +11,23 @@ def main():
         "The capital of France is",
         "The future of AI is",
     ]
-    sampling_params = SamplingParams(max_tokens=32)
+    sampling_params = SamplingParams(max_tokens=32, return_context_logits=True)
 
-    llm = LLM(model='TinyLlama/TinyLlama-1.1B-Chat-v1.0')
+    model_path = "/code/tensorrt_llm/custom_bert_classifier"
+    llm = LLM(model=model_path)
 
     outputs = llm.generate(prompts, sampling_params)
 
+    tllm_logits = []
     for i, output in enumerate(outputs):
         prompt = output.prompt
-        generated_text = output.outputs[0].text
-        print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        logits = output.context_logits.cpu()
+        print(f"[{i}] Prompt: {prompt!r}, logits: {logits}")
+        tllm_logits += [logits]
+
+    # stack logits
+    tllm_logits = torch.stack(tllm_logits)
+    print(f"tllm_logits: {tllm_logits}")
+
 
 if __name__ == '__main__':
diff --git a/tensorrt_llm/_torch/models/modeling_bert.py b/tensorrt_llm/_torch/models/modeling_bert.py
index 19d61f6f0e..f1d1b81373 100644
--- a/tensorrt_llm/_torch/models/modeling_bert.py
+++ b/tensorrt_llm/_torch/models/modeling_bert.py
@@ -156,8 +156,6 @@ class BertPooler(nn.Module):
 
     def __init__(self, config):
         super().__init__()
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.activation = nn.Tanh()
 
     def forward(
         self,
@@ -165,22 +163,16 @@ def forward(
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
 
-        # We "pool" the model by simply taking the hidden state corresponding
-        # to the first token.
+        # Apply average pooling over all token embeddings
         if attn_metadata is not None:
-            #NOTE: select the first tokens
-            offset = attn_metadata.seq_lens_cuda
-            selected_tokens = torch.cumsum(
-                attn_metadata.seq_lens_cuda,
-                dim=0,
-                dtype=torch.long,
-            ) - offset
-            hidden_states = hidden_states[selected_tokens]
+            # Use sequence lengths from attn_metadata to mask padding tokens
+            seq_lens = attn_metadata.seq_lens_cuda
+            mask = torch.arange(hidden_states.size(1), device=hidden_states.device).unsqueeze(0) < seq_lens.unsqueeze(1)
+            hidden_states = hidden_states * mask.unsqueeze(-1)  # Mask padding tokens
+            pooled_output = hidden_states.sum(dim=1) / seq_lens.unsqueeze(1)
         else:
-            # hidden_states: [B, N, H]
-            hidden_states = hidden_states[:, 0]
-        pooled_output = self.dense(hidden_states)
-        pooled_output = self.activation(pooled_output)
+            # Average pooling over all tokens (assumes no padding)
+            pooled_output = hidden_states.mean(dim=1)
 
         return pooled_output
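
For reference, a minimal standalone sketch of the masked mean pooling that the updated BertPooler.forward performs, assuming a padded [batch, seq_len, hidden] input; the function name, shapes, and test values below are illustrative only and are not part of the patch:

    import torch

    def masked_mean_pool(hidden_states: torch.Tensor,
                         seq_lens: torch.Tensor) -> torch.Tensor:
        # hidden_states: [batch, seq_len, hidden] padded activations
        # seq_lens:      [batch] true (unpadded) length of each sequence
        positions = torch.arange(hidden_states.size(1),
                                 device=hidden_states.device).unsqueeze(0)
        mask = positions < seq_lens.unsqueeze(1)           # [batch, seq_len]
        masked = hidden_states * mask.unsqueeze(-1)        # zero out padding tokens
        return masked.sum(dim=1) / seq_lens.unsqueeze(1)   # [batch, hidden]

    hs = torch.randn(2, 4, 8)
    lens = torch.tensor([4, 2])
    pooled = masked_mean_pool(hs, lens)
    # The second sequence's pooled vector should equal a plain mean over its first two tokens
    assert torch.allclose(pooled[1], hs[1, :2].mean(dim=0), atol=1e-6)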