Skip to content
Draft
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 203 additions & 24 deletions src/modelAccessors/openai_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,25 +87,201 @@ def supports_tools(self, model: str) -> bool:

def _prepare_schema_for_openai(self, schema: dict) -> dict:
"""
Prepare schema for OpenAI's structured output requirements.
Prepare schema for OpenAI's structured output requirements with minimal changes.

OpenAI requires the root schema to have 'type': 'object' and does not support
oneOf/anyOf anywhere in the schema. Pydantic's discriminated unions generate
schemas with oneOf at the root level, and nullable fields use anyOf, so we
need to flatten and clean them.
Instead of aggressively flattening all unions, this applies targeted fixes:
1. Only flatten problematic union structures
2. Keep oneOf/anyOf that can work with proper branch structure
3. Clean nullable fields (anyOf with null)
4. Ensure all objects have additionalProperties: false
5. Clean $ref objects with extra keywords
6. Handle empty objects explicitly

This preserves the original schema design while ensuring OpenAI compliance.
"""
# Check if the schema already has a root type of "object" and no oneOf/anyOf
if schema.get("type") == "object" and not self._contains_oneof_anyof(schema):
return schema
import copy

# If it's a oneOf/anyOf schema (discriminated union), flatten it
if "oneOf" in schema or "anyOf" in schema:
flattened = self._flatten_discriminated_union(schema)
# Always start with a deep copy to avoid modifying the original
result_schema = copy.deepcopy(schema)

# Clean only problematic anyOf/oneOf structures
# Keep discriminated unions if they're properly structured
if self._has_problematic_unions(result_schema):
result_schema = self._clean_problematic_unions(result_schema)
else:
flattened = schema
# Even if we're not flattening, we need to ensure each branch is OpenAI compliant
self._make_union_branches_compliant(result_schema)

# Clean nullable fields (anyOf with null) - these are always problematic for OpenAI
result_schema = self._clean_oneof_anyof_recursive(result_schema)

# Apply targeted OpenAI compliance fixes without breaking conditional logic
self._fix_required_fields_recursive(result_schema)

return result_schema

def _has_problematic_unions(self, schema: dict) -> bool:
"""
Check if the schema has union structures that need flattening.

Try to preserve oneOf/anyOf if possible. Only flatten if the branches
can't be made OpenAI-compliant as-is.
"""
if "oneOf" in schema:
# Check if all branches can be made OpenAI compliant
one_of = schema["oneOf"]
for branch in one_of:
if "$ref" in branch:
# $ref branches should be fine if the referenced schema is compliant
continue
elif isinstance(branch, dict):
# Check if this branch can be made compliant
if branch.get("type") == "object" or "properties" in branch:
# This should be fixable - don't flatten
continue
else:
# Complex branch that might need flattening
return True

# Clean any remaining oneOf/anyOf structures (like nullable fields)
return self._clean_oneof_anyof_recursive(flattened)
# All branches look fine, don't flatten
return False

# Root-level anyOf might need flattening if it's not just nullable
if "anyOf" in schema:
any_of = schema["anyOf"]
# Check if it's just a nullable pattern
if len(any_of) == 2:
types = []
for item in any_of:
if isinstance(item, dict) and "type" in item:
types.append(item["type"])
if "null" in types:
# This is just a nullable field, don't flatten at root
return False
# Other anyOf patterns might need flattening
return True

return False

def _make_union_branches_compliant(self, schema: dict):
"""
Make each branch of a oneOf/anyOf union OpenAI compliant without flattening.

This processes each branch to ensure it has additionalProperties: false
and proper structure, while preserving the union semantics.
"""
if "oneOf" in schema:
for branch in schema["oneOf"]:
if isinstance(branch, dict) and "$ref" not in branch:
# Make this branch OpenAI compliant
if branch.get("type") == "object" or "properties" in branch:
branch["additionalProperties"] = False

if "anyOf" in schema:
for branch in schema["anyOf"]:
if isinstance(branch, dict) and "$ref" not in branch:
# Make this branch OpenAI compliant
if branch.get("type") == "object" or "properties" in branch:
branch["additionalProperties"] = False

def _clean_problematic_unions(self, schema: dict) -> dict:
"""
Clean only the problematic union structures, preserving good ones.
"""
if "oneOf" in schema or "anyOf" in schema:
return self._flatten_discriminated_union(schema)
return schema

def _fix_required_fields_recursive(self, schema):
"""
Recursively ensure all objects in the schema are OpenAI-compliant.

Instead of forcing all properties to be required, this applies targeted fixes:
1. Add additionalProperties: false to all objects
2. Extend (don't overwrite) existing required arrays only when needed
3. Handle empty objects explicitly
4. Clean problematic $ref objects

This preserves the original schema structure while ensuring OpenAI compliance.
"""
if not isinstance(schema, dict):
return

# If this is an object type, ensure additionalProperties is false
if schema.get("type") == "object":
schema["additionalProperties"] = False

# Handle objects with properties
if "properties" in schema and isinstance(schema["properties"], dict):
properties = schema["properties"]

# Ensure additionalProperties is false for objects with properties
schema["additionalProperties"] = False

# Handle empty objects explicitly (OpenAI requirement)
if not properties:
schema.setdefault("required", [])
else:
# Ensure required key exists (OpenAI requirement)
schema.setdefault("required", [])

# Only extend required array if it's missing properties that should be required
# Don't force ALL properties to be required - preserve conditional logic
existing_required = set(schema.get("required", []))

# For discriminated unions, only the discriminator should be universally required
# Other fields remain conditional based on the original schema design
pass # Let the original schema determine what should be required

# Handle empty objects without properties (must be explicit)
elif schema.get("type") == "object" and "properties" not in schema:
schema["properties"] = {}
schema["required"] = []
schema["additionalProperties"] = False

# Recursively fix objects in $defs and definitions
for defs_key in ["$defs", "definitions"]:
if defs_key in schema and isinstance(schema[defs_key], dict):
for def_schema in schema[defs_key].values():
self._fix_required_fields_recursive(def_schema)

# Recursively fix nested objects in properties
if "properties" in schema and isinstance(schema["properties"], dict):
for prop_schema in schema["properties"].values():
self._fix_required_fields_recursive(prop_schema)

# Fix $ref objects that have additional keywords (OpenAI doesn't allow this)
self._clean_ref_objects_recursive(schema)

# Recursively fix array item schemas
if "items" in schema and isinstance(schema["items"], dict):
self._fix_required_fields_recursive(schema["items"])

def _clean_ref_objects_recursive(self, schema):
"""
Recursively clean $ref objects that have additional keywords.

OpenAI's strict validation doesn't allow $ref to be combined with other keywords
like 'default', 'title', etc. This removes such keywords from $ref objects.
"""
if not isinstance(schema, dict):
return

# If this object has $ref, remove all other keywords except $ref
if "$ref" in schema:
ref_value = schema["$ref"]
schema.clear()
schema["$ref"] = ref_value
return # Don't recurse into a pure $ref object

# Recursively clean nested structures
for key, value in list(schema.items()):
if isinstance(value, dict):
self._clean_ref_objects_recursive(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
self._clean_ref_objects_recursive(item)

def _contains_oneof_anyof(self, obj) -> bool:
"""Recursively check if an object contains oneOf or anyOf."""
Expand Down Expand Up @@ -140,18 +316,21 @@ def _clean_oneof_anyof_recursive(self, obj):
# Create a new dict without anyOf
cleaned = {k: v for k, v in obj.items() if k != "anyOf"}
cleaned.update(other_schema)
# Make it nullable by not including it in required fields

# If the field had a null default but is now non-nullable, remove the default
# This prevents invalid schemas where required string fields have null defaults
if cleaned.get("default") is None and cleaned.get("type") != "null":
# Remove null default for non-nullable required fields
cleaned.pop("default", None)

return self._clean_oneof_anyof_recursive(cleaned)

# Handle oneOf (shouldn't happen after flattening, but just in case)
# Handle oneOf - only remove if it's problematic
if "oneOf" in obj:
# This is a complex case - for now, take the first option
# In practice, this shouldn't happen after proper flattening
one_of = obj["oneOf"]
if one_of:
cleaned = {k: v for k, v in obj.items() if k != "oneOf"}
cleaned.update(one_of[0])
return self._clean_oneof_anyof_recursive(cleaned)
# Don't automatically remove oneOf at root level
# Only clean nested oneOf that might be problematic
# Skip root-level oneOf that we want to preserve
pass

# Recursively clean all nested objects
return {k: self._clean_oneof_anyof_recursive(v) for k, v in obj.items()}
Expand Down Expand Up @@ -222,7 +401,7 @@ def _flatten_discriminated_union(self, schema: dict) -> dict:
flattened = {
"type": "object",
"properties": all_properties,
"required": list(all_properties.keys()), # OpenAI requires all properties to be in required
"required": list(required_fields), # Only require discriminator and truly required fields
"additionalProperties": False
}

Expand Down
Loading
Loading