 of output with special method for the Fast tokenizers)
 """
 
+from __future__ import annotations
+
 import copy
 import json
 import os
@@ -783,7 +785,7 @@ def as_tensor(value, dtype=None):
 
         return self
 
-    def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding":
+    def to(self, device: Union[str, torch.device], *, non_blocking: bool = False) -> BatchEncoding:
         """
         Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only).
 
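The quotes around `"torch.device"` and `"BatchEncoding"` become unnecessary because of the `from __future__ import annotations` added above: under PEP 563, annotations are stored as strings and are never evaluated at definition time, so forward references and optional-dependency types need no quoting. A minimal standalone sketch of the mechanism, not taken from this file (the class body and `TYPE_CHECKING` guard here are illustrative):

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    import torch  # seen by type checkers only; never imported at runtime


class BatchEncoding:
    # `torch.device` and `BatchEncoding` are fine unquoted below: with
    # postponed evaluation, the annotations stay plain strings until a
    # tool (e.g. a type checker) asks for them.
    def to(self, device: Union[str, torch.device], *, non_blocking: bool = False) -> BatchEncoding:
        return self


BatchEncoding().to("cpu")  # runs even in an environment without torch installed
```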
@@ -1858,7 +1860,11 @@ def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional
 
         return chat_template
 
-    def parse_response(self, response: str, schema: Optional[Union[list, dict]] = None):
+    def parse_response(
+        self,
+        response: str | list[str | int | list[int]] | np.ndarray | torch.Tensor,
+        schema: list | dict | None = None,
+    ):
         """
         Converts an output string created by generating text from a model into a parsed message dictionary.
         This method is intended for use with chat models, and will read the tokenizer's `response_schema` attribute to
@@ -1869,16 +1875,29 @@ def parse_response(self, response: str, schema: Optional[Union[list, dict]] = No
 
         Args:
             response (`str`):
-                The output string generated by the model. This should be the decoded string, not raw tokens.
+                The output string generated by the model. This can be either a decoded string, a list of strings,
+                or token IDs as a list/array.
             schema (`Union[list, dict]`, *optional*):
                 A response schema that indicates the expected output format and how parsing should be performed.
                 If not provided, the tokenizer's `response_schema` attribute will be used.
         """
+        batched = (
+            (isinstance(response, list) and not isinstance(response[0], int))
+            or getattr(response, "ndim", 0) > 1  # For torch/numpy tensors
+        )
+
         if schema is None:
             if getattr(self, "response_schema", None) is None:
                 raise AttributeError("This tokenizer does not have a `response_schema` for parsing chat responses!")
             schema = self.response_schema
-        return recursive_parse(response, schema)
+        if batched:
+            if not (isinstance(response, list) and isinstance(response[0], str)):
+                response = self.batch_decode(response)
+            return [recursive_parse(single_response, schema) for single_response in response]
+        else:
+            if not isinstance(response, str):
+                response = self.decode(response)
+            return recursive_parse(response, schema)
 
     @classmethod
     def from_pretrained(
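The `batched` expression decides whether the input is routed through `batch_decode` (returning a list of parsed dicts) or `decode` (returning a single dict). A hedged sketch of the heuristic in isolation, where `is_batched` is a hypothetical helper mirroring the diff's expression and only numpy is assumed:

```python
import numpy as np


def is_batched(response) -> bool:
    # Mirrors the diff: a list whose first element is not a bare token id,
    # or an array/tensor with more than one dimension, counts as a batch.
    return (isinstance(response, list) and not isinstance(response[0], int)) or getattr(response, "ndim", 0) > 1


assert is_batched(["<tool>a</tool>", "<tool>b</tool>"])  # list of decoded strings
assert is_batched([[101, 102], [103, 104]])              # list of token-id lists
assert is_batched(np.zeros((2, 8), dtype=np.int64))      # 2-D array of token ids
assert not is_batched("one decoded string")              # single string
assert not is_batched([101, 102, 103])                   # flat token-id list
assert not is_batched(np.zeros(8, dtype=np.int64))       # 1-D array
```

One consequence worth noting: a single sequence yields one parsed dict while any batched input yields a list of dicts, so callers must branch on the input shape the same way the heuristic does.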
@@ -3863,7 +3882,7 @@ def convert_tokens_to_string(self, tokens: list[str]) -> str:
 
     def batch_decode(
         self,
-        sequences: Union[list[int], list[list[int]], np.ndarray, "torch.Tensor"],
+        sequences: Union[list[int], list[list[int]], np.ndarray, torch.Tensor],
         skip_special_tokens: bool = False,
         clean_up_tokenization_spaces: Optional[bool] = None,
         **kwargs,
@@ -3897,7 +3916,7 @@ def batch_decode(
38973916
38983917 def decode (
38993918 self ,
3900- token_ids : Union [int , list [int ], np .ndarray , " torch.Tensor" ],
3919+ token_ids : Union [int , list [int ], np .ndarray , torch .Tensor ],
39013920 skip_special_tokens : bool = False ,
39023921 clean_up_tokenization_spaces : Optional [bool ] = None ,
39033922 ** kwargs ,