
Commit 38df1e9

Allow parse_response to accept token IDs (#41849)
* Allow tokenizer.parse_response() to accept IDs/arrays directly
1 parent 5462376 commit 38df1e9

File tree

2 files changed: +59 -6 lines changed


src/transformers/tokenization_utils_base.py

Lines changed: 25 additions & 6 deletions
@@ -18,6 +18,8 @@
 of output with special method for the Fast tokenizers)
 """

+from __future__ import annotations
+
 import copy
 import json
 import os
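
The new `from __future__ import annotations` import is what allows the quoted forward references in the hunks below (`"torch.device"`, `"BatchEncoding"`, `"torch.Tensor"`) to be unquoted: under PEP 563, annotations are stored as strings and never evaluated at definition time. A minimal standalone sketch of that behavior (the function name `to_device` is illustrative, not from the diff):

    # With PEP 563 semantics, this runs even though neither `torch` nor
    # `BatchEncoding` is imported or defined anywhere in the module.
    from __future__ import annotations

    def to_device(x, device: torch.device) -> BatchEncoding:
        # The annotations above are kept as plain strings; they would only
        # be evaluated by an explicit typing.get_type_hints() call.
        return x

    print(to_device.__annotations__)  # {'device': 'torch.device', 'return': 'BatchEncoding'}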
@@ -783,7 +785,7 @@ def as_tensor(value, dtype=None):

         return self

-    def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding":
+    def to(self, device: Union[str, torch.device], *, non_blocking: bool = False) -> BatchEncoding:
         """
         Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only).
@@ -1858,7 +1860,11 @@ def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional

         return chat_template

-    def parse_response(self, response: str, schema: Optional[Union[list, dict]] = None):
+    def parse_response(
+        self,
+        response: str | list[str | int | list[int]] | np.ndarray | torch.Tensor,
+        schema: list | dict | None = None,
+    ):
         """
         Converts an output string created by generating text from a model into a parsed message dictionary.
         This method is intended for use with chat models, and will read the tokenizer's `response_schema` attribute to
@@ -1869,16 +1875,29 @@ def parse_response(self, response: str, schema: Optional[Union[list, dict]] = None):

         Args:
             response (`str`):
-                The output string generated by the model. This should be the decoded string, not raw tokens.
+                The output string generated by the model. This can be either a decoded string or list of strings,
+                or token IDs as a list/array.
             schema (`Union[list, dict]`, *optional*):
                 A response schema that indicates the expected output format and how parsing should be performed.
                 If not provided, the tokenizer's `response_schema` attribute will be used.
         """
+        batched = (
+            (isinstance(response, list) and not isinstance(response[0], int))
+            or getattr(response, "ndim", 0) > 1  # For torch/numpy tensors
+        )
+
         if schema is None:
             if getattr(self, "response_schema", None) is None:
                 raise AttributeError("This tokenizer does not have a `response_schema` for parsing chat responses!")
             schema = self.response_schema
-        return recursive_parse(response, schema)
+        if batched:
+            if not (isinstance(response, list) and isinstance(response[0], str)):
+                response = self.batch_decode(response)
+            return [recursive_parse(single_response, schema) for single_response in response]
+        else:
+            if not isinstance(response, str):
+                response = self.decode(response)
+            return recursive_parse(response, schema)

     @classmethod
     def from_pretrained(
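
Taken together, the changes above mean `parse_response` now accepts generation output in any of the common shapes, decoding token IDs itself when needed. A hedged usage sketch (assumes a tokenizer whose `response_schema` has been set; `my_schema` is a hypothetical placeholder for a real schema such as the Cohere-style one used in the tests below):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    tokenizer.response_schema = my_schema  # hypothetical: any valid response schema

    text = "<|START_ACTION|>...<|END_ACTION|>"  # stand-in for decoded model output
    ids = tokenizer(text).input_ids             # flat list of token IDs

    tokenizer.parse_response(text)           # str -> single parsed message dict
    tokenizer.parse_response(ids)            # list[int] -> decoded first, then parsed
    tokenizer.parse_response([ids, ids])     # list of ID lists -> list of parsed dicts
    tokenizer.parse_response([text, text])   # list of strings -> list of parsed dicts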
@@ -3863,7 +3882,7 @@ def convert_tokens_to_string(self, tokens: list[str]) -> str:

     def batch_decode(
         self,
-        sequences: Union[list[int], list[list[int]], np.ndarray, "torch.Tensor"],
+        sequences: Union[list[int], list[list[int]], np.ndarray, torch.Tensor],
         skip_special_tokens: bool = False,
         clean_up_tokenization_spaces: Optional[bool] = None,
         **kwargs,
@@ -3897,7 +3916,7 @@ def decode(

     def decode(
         self,
-        token_ids: Union[int, list[int], np.ndarray, "torch.Tensor"],
+        token_ids: Union[int, list[int], np.ndarray, torch.Tensor],
         skip_special_tokens: bool = False,
         clean_up_tokenization_spaces: Optional[bool] = None,
         **kwargs,

tests/utils/test_chat_parsing_utils.py

Lines changed: 34 additions & 0 deletions
@@ -200,6 +200,40 @@ def test_tokenizer_method(self):
         tokenizer_parsed_chat = tokenizer.parse_response(model_out)
         self.assertEqual(tokenizer_parsed_chat, parsed_chat)

+    def test_batched_inputs(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
+        tokenizer.response_schema = cohere_schema
+        parsed_chat = tokenizer.parse_response(model_out)
+        self.assertEqual(tokenizer.parse_response([model_out]), [parsed_chat])
+        self.assertEqual(tokenizer.parse_response([model_out] * 2), [parsed_chat] * 2)
+
+    def test_token_id_inputs(self):
+        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")  # Need an actual tokenizer to encode
+        model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
+        tokenizer.response_schema = cohere_schema
+        parsed_chat = tokenizer.parse_response(model_out)
+        tokenized_out = tokenizer(model_out).input_ids
+        self.assertEqual(tokenizer.parse_response(tokenized_out), parsed_chat)
+        self.assertEqual(tokenizer.parse_response([tokenized_out]), [parsed_chat])
+        self.assertEqual(tokenizer.parse_response([tokenized_out] * 2), [parsed_chat] * 2)
+
+    def test_numpy_inputs(self):
+        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")  # Need an actual tokenizer to encode
+        model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
+        tokenizer.response_schema = cohere_schema
+        parsed_chat = tokenizer.parse_response(model_out)
+        tokenized_out = tokenizer(model_out, return_tensors="np").input_ids
+        self.assertEqual(tokenizer.parse_response(tokenized_out), [parsed_chat])
+
+    def test_tensor_inputs(self):
+        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")  # Need an actual tokenizer to encode
+        model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
+        tokenizer.response_schema = cohere_schema
+        parsed_chat = tokenizer.parse_response(model_out)
+        tokenized_out = tokenizer(model_out, return_tensors="pt").input_ids
+        self.assertEqual(tokenizer.parse_response(tokenized_out), [parsed_chat])
+
     def test_cohere_template(self):
         model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
         parsed_chat = recursive_parse(model_out, cohere_schema)
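
One subtlety the numpy/torch tests pin down: `return_tensors="np"` / `"pt"` produce 2-D arrays of shape `(1, seq_len)`, so the new `ndim > 1` check classifies them as batches and `parse_response` returns a one-element list. A minimal standalone sketch of that heuristic (mirroring the `batched` expression in the first file's diff; `is_batched` is an illustrative name, not part of the API):

    import numpy as np

    def is_batched(response):
        # A list whose first element is not an int, or a >1-D array/tensor,
        # is treated as a batch of sequences.
        return (isinstance(response, list) and not isinstance(response[0], int)) or getattr(
            response, "ndim", 0
        ) > 1

    assert not is_batched([1, 2, 3])            # flat ID list: one sequence
    assert is_batched([[1, 2], [3, 4]])         # list of ID lists: batch
    assert is_batched(["a", "b"])               # list of strings: batch
    assert is_batched(np.array([[1, 2, 3]]))    # shape (1, 3), as return_tensors="np" yields
    assert not is_batched(np.array([1, 2, 3]))  # 1-D array: one sequence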
