Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
json parser update, handle floating point precision in double values
Browse files Browse the repository at this point in the history
zanderjiang committed Jan 23, 2025
1 parent 220cb36 commit 6b770e8
Showing 3 changed files with 132 additions and 45 deletions.
102 changes: 60 additions & 42 deletions cpp/function_calling.cc
Original file line number Diff line number Diff line change
@@ -5,6 +5,10 @@

#include <xgrammar/function_calling.h>

#include <iomanip>
#include <limits>
#include <sstream>
#include <string>
#include <string_view>
#include <unordered_map>

@@ -13,7 +17,47 @@

namespace xgrammar {

/******************* Parser *******************/
/*!
 * \brief Convert a parsed JSON value into its string representation.
 *
 * A top-level string is returned verbatim (no surrounding quotes). Numbers, booleans and null
 * are rendered as text. Arrays and objects are rendered as compact JSON; strings nested inside
 * them, as well as object keys, are quoted and escaped so that the result stays unambiguous.
 *
 * \param v The JSON value to convert.
 * \return The string representation of \p v.
 */
std::string json_to_str(const picojson::value& v) {
  // Elements of arrays/objects: nested strings keep their quotes and escapes, otherwise an
  // element such as "a,b" could not be told apart from two separate elements.
  const auto element_to_str = [](const picojson::value& element) -> std::string {
    return element.is<std::string>() ? element.serialize() : json_to_str(element);
  };

  if (v.is<std::string>()) {
    return v.get<std::string>();
  }
  if (v.is<double>()) {
    // The default stream precision (6 significant digits) silently truncates values such as
    // 1234567 or 3.14159265. digits10 keeps every digit a double can represent faithfully
    // while still printing 42 as "42" and 3.14 as "3.14".
    std::ostringstream oss;
    oss << std::setprecision(std::numeric_limits<double>::digits10) << v.get<double>();
    return oss.str();
  }
  if (v.is<bool>()) {
    return v.evaluate_as_boolean() ? "true" : "false";
  }
  if (v.is<picojson::null>()) {
    return "null";
  }
  if (v.is<picojson::array>()) {
    const auto& arr = v.get<picojson::array>();
    std::string result = "[";
    for (size_t i = 0; i < arr.size(); ++i) {
      if (i > 0) result += ",";
      result += element_to_str(arr[i]);
    }
    result += "]";
    return result;
  }
  if (v.is<picojson::object>()) {
    const auto& obj = v.get<picojson::object>();
    std::string result = "{";
    bool first = true;
    for (const auto& pair : obj) {
      if (!first) result += ",";
      // Serialize the key through picojson so embedded quotes/backslashes are escaped.
      result += picojson::value(pair.first).serialize();
      result += ":";
      result += element_to_str(pair.second);
      first = false;
    }
    result += "}";
    return result;
  }
  XGRAMMAR_LOG(FATAL) << "Unsupported JSON value type";
  return "";
}

std::vector<std::pair<std::string, std::unordered_map<std::string, std::string>>> parse_message(
const std::string& input, bool ignore_error
@@ -56,7 +100,7 @@ std::vector<std::pair<std::string, std::unordered_map<std::string, std::string>>
std::string function_name = input.substr(name_start, name_end - name_start);
std::string params_str = input.substr(params_start, params_end - params_start);

// Replace single quotes with double quotes
// replace single quotes with double quotes
size_t quote_pos = 0;
while ((quote_pos = params_str.find('\'', quote_pos)) != std::string::npos) {
params_str.replace(quote_pos, 1, "\"");
@@ -83,31 +127,25 @@ std::vector<std::pair<std::string, std::unordered_map<std::string, std::string>>
}

// main processing, can handle string, double, bool, null types and convert to string
if (!v.is<picojson::object>()) {
if (!ignore_error) {
XGRAMMAR_LOG(FATAL) << "Parameters must be a JSON object";
}
input_pos++;
continue;
}
const picojson::object& obj = v.get<picojson::object>();
std::unordered_map<std::string, std::string> params;
for (const auto& pair : obj) {
if (pair.second.is<std::string>()) {
params[pair.first] = pair.second.get<std::string>();
} else if (pair.second.is<double>()) {
params[pair.first] = std::to_string(pair.second.get<double>());
} else if (pair.second.is<bool>()) {
params[pair.first] = pair.second.get<bool>() ? "true" : "false";
} else if (pair.second.is<picojson::null>()) {
params[pair.first] = "null";
} else {
if (!ignore_error) {
XGRAMMAR_LOG(FATAL) << "Invalid parameter type for field: " << pair.first;
}
continue;
}
params[pair.first] = json_to_str(pair.second);
}

tool_calls.emplace_back(function_name, std::move(params));
input_pos = params_end + 11;
continue;
}

// JSON format case,
// JSON format
// e.g. {"name": "get_current_conditions", "parameters": {"location": "San Francisco, CA",
// "unit": "Fahrenheit"}}
int bracket_count = 1;
@@ -153,41 +191,21 @@ std::vector<std::pair<std::string, std::unordered_map<std::string, std::string>>
auto name_it = obj.find("name");
auto params_it = obj.find("parameters");

if (name_it == obj.end() || !name_it->second.is<std::string>()) {
if (name_it == obj.end() || !name_it->second.is<std::string>() || params_it == obj.end() ||
!params_it->second.is<picojson::object>()) {
if (!ignore_error) {
XGRAMMAR_LOG(FATAL) << "Invalid JSON format: missing or invalid name field";
XGRAMMAR_LOG(FATAL) << "Invalid JSON format: missing or invalid name/parameters fields";
}
input_pos = json_end_pos;
continue;
}

if (params_it == obj.end() || !params_it->second.is<picojson::object>()) {
if (!ignore_error) {
XGRAMMAR_LOG(FATAL) << "Invalid JSON format: missing or invalid parameters field";
}
input_pos = json_end_pos;
continue;
}

std::string name = name_it->second.get<std::string>();
const std::string& name = name_it->second.get<std::string>();
const picojson::object& params_obj = params_it->second.get<picojson::object>();
std::unordered_map<std::string, std::string> params;

for (const auto& pair : params_obj) {
if (pair.second.is<std::string>()) {
params[pair.first] = pair.second.get<std::string>();
} else if (pair.second.is<double>()) {
params[pair.first] = std::to_string(pair.second.get<double>());
} else if (pair.second.is<bool>()) {
params[pair.first] = pair.second.get<bool>() ? "true" : "false";
} else if (pair.second.is<picojson::null>()) {
params[pair.first] = "null";
} else {
if (!ignore_error) {
XGRAMMAR_LOG(FATAL) << "Invalid parameter type for field: " << pair.first;
}
continue;
}
params[pair.first] = json_to_str(pair.second);
}

tool_calls.emplace_back(name, std::move(params));
2 changes: 1 addition & 1 deletion cpp/pybind/pybind.cc
Original file line number Diff line number Diff line change
@@ -87,7 +87,7 @@ PYBIND11_MODULE(xgrammar_bindings, m) {
)
.def("_regex_to_ebnf", &RegexToEBNF)
.def("_get_masked_tokens_from_bitmask", &Matcher_DebugGetMaskedTokensFromBitmask)
.def("_parse_message", &parse_message, py::arg("input"), py::arg("ignore_error") = false);
.def("_parse_message", &parse_message);

auto pyKernelsModule = m.def_submodule("kernels");
pyKernelsModule.def("apply_token_bitmask_inplace_cpu", &Kernels_ApplyTokenBitmaskInplaceCPU);
73 changes: 71 additions & 2 deletions tests/python/test_function_calling.py
Original file line number Diff line number Diff line change
@@ -98,8 +98,8 @@ def test_special_values():
"null_value": "null",
"bool_true": "true",
"bool_false": "false",
"number": "42.000000",
"float": "3.140000",
"number": "42",
"float": "3.14",
},
)

@@ -150,6 +150,75 @@ def test_whitespace_handling():
assert result[0] == ("test", {"param": "value"})


def test_unicode_parameters():
    """Non-ASCII text in parameter values must come back from the parser unchanged."""
    message = '<function=translate>{"text": "こんにちは世界", "target": "español"}</function>'
    expected_call = ("translate", {"text": "こんにちは世界", "target": "español"})

    calls = _parse_message(message)

    assert len(calls) == 1
    assert calls[0] == expected_call


def test_escaped_characters():
    """Test handling of escaped backslashes and escaped double quotes in JSON strings."""
    # Every backslash is doubled in the Python literal, so the parser receives the JSON
    # escapes `\\` (one backslash) and `\"` (one double quote). A bare double quote inside
    # the value would terminate the JSON string early and make the payload invalid.
    input_str = (
        '<function=process>{"path": "C:\\\\Program Files\\\\App", '
        '"query": "\\"quoted string\\""}</function>'
    )
    result = _parse_message(input_str)
    assert len(result) == 1
    assert result[0] == ("process", {"path": "C:\\Program Files\\App", "query": '"quoted string"'})


def test_empty_array_parameters():
    """Test handling of empty arrays and null in parameters.

    The parser returns every parameter value as a string, so arrays are checked in their
    serialized form rather than against Python lists.
    """
    input_str = '<function=update_list>{"ids": [], "tags": [""], "flags": null}</function>'
    result = _parse_message(input_str)
    assert len(result) == 1

    name, params = result[0]
    assert name == "update_list"
    assert set(params) == {"ids", "tags", "flags"}
    assert params["ids"] == "[]"
    assert params["flags"] == "null"

    # Only the bracketed array shape is pinned here; how the nested empty string is
    # rendered inside the brackets is left to the serializer.
    tags = params["tags"]
    assert isinstance(tags, str)
    assert tags.startswith("[") and tags.endswith("]")


def test_large_number_of_functions():
    """One hundred newline-separated calls must all be parsed, in order."""
    input_str = "\n".join(
        f'<function=func_{i}>{{"id": "{i}"}}</function>' for i in range(100)
    )

    calls = _parse_message(input_str)

    assert len(calls) == 100
    for i, call in enumerate(calls):
        name, params = call
        assert name == f"func_{i}"
        assert params == {"id": str(i)}


def test_mixed_line_endings():
    """Calls separated by CRLF and by LF must each be parsed."""
    input_str = '<function=test1>{"a": "1"}</function>\r\n<function=test2>{"b": "2"}</function>\n<function=test3>{"c": "3"}</function>'
    expected_calls = [
        ("test1", {"a": "1"}),
        ("test2", {"b": "2"}),
        ("test3", {"c": "3"}),
    ]

    calls = _parse_message(input_str)

    assert len(calls) == 3
    for index, expected_call in enumerate(expected_calls):
        assert calls[index] == expected_call


def test_json_with_comments():
    """A payload containing // and /* */ comments is not valid JSON and must raise."""
    commented_payload = """
<function=test>{
// This is a comment
"param": "value" /* inline comment */
}</function>
"""
    with pytest.raises(Exception):
        _parse_message(commented_payload)


def test_case_sensitivity():
    """Function names and parameter keys must keep their original letter case."""
    message = """
<function=TestFunction>{"ParamOne": "Value", "paramTwo": "value"}</function>
{"name": "TestFunction2", "parameters": {"PARAM": "VALUE"}}
"""
    tagged_call = ("TestFunction", {"ParamOne": "Value", "paramTwo": "value"})
    json_call = ("TestFunction2", {"PARAM": "VALUE"})

    calls = _parse_message(message)

    assert len(calls) == 2
    assert calls[0] == tagged_call
    assert calls[1] == json_call


@pytest.mark.parametrize(
"input_str,expected",
[

0 comments on commit 6b770e8

Please sign in to comment.