Skip to content

Commit

Permalink
refactor: return reasoning traces for non-JSON models
Browse files Browse the repository at this point in the history
  • Loading branch information
JeanKaddour committed Jan 27, 2025
1 parent 942e1ce commit 0506d5c
Showing 1 changed file with 16 additions and 19 deletions.
35 changes: 16 additions & 19 deletions backend/app/nodes/llm/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,30 +638,19 @@ async def generate_text(
# For models that don't support JSON output, wrap the response in a JSON structure
if not supports_json:
sanitized_response = response.replace('"', '\\"').replace("\n", "\\n")
# Check for provider-specific fields
if hasattr(raw_response, 'choices') and len(raw_response.choices) > 0:
if hasattr(raw_response.choices[0].message, 'provider_specific_fields'):
provider_fields = raw_response.choices[0].message.provider_specific_fields
return json.dumps({
"output": sanitized_response,
"provider_specific_fields": provider_fields
})
return f'{{"output": "{sanitized_response}"}}'

# Ensure response is valid JSON for models that support it
if supports_json:
try:
# Check if the response has provider-specific fields
if hasattr(raw_response, 'choices') and len(raw_response.choices) > 0:
if hasattr(raw_response.choices[0].message, 'provider_specific_fields'):
provider_fields = raw_response.choices[0].message.provider_specific_fields
try:
# Parse the existing response
response_json = json.loads(response)
# Add provider-specific fields
response_json['provider_specific_fields'] = provider_fields
return json.dumps(response_json)
except json.JSONDecodeError:
# If response is not valid JSON, create a new JSON object
sanitized_response = response.replace('"', '\\"').replace("\n", "\\n")
return json.dumps({
"output": sanitized_response,
"provider_specific_fields": provider_fields
})

# If no provider-specific fields, proceed with normal JSON validation
json.loads(response)
return response
except json.JSONDecodeError:
Expand All @@ -680,6 +669,14 @@ async def generate_text(

# If all attempts to parse JSON fail, wrap the response in a JSON structure
sanitized_response = response.replace('"', '\\"').replace("\n", "\\n")
# Check for provider-specific fields
if hasattr(raw_response, 'choices') and len(raw_response.choices) > 0:
if hasattr(raw_response.choices[0].message, 'provider_specific_fields'):
provider_fields = raw_response.choices[0].message.provider_specific_fields
return json.dumps({
"output": sanitized_response,
"provider_specific_fields": provider_fields
})
return f'{{"output": "{sanitized_response}"}}'

return response
Expand Down

0 comments on commit 0506d5c

Please sign in to comment.