diff --git a/internal/tool/examples_nested_test.go b/internal/tool/examples_nested_test.go new file mode 100644 index 000000000..bfbd5be46 --- /dev/null +++ b/internal/tool/examples_nested_test.go @@ -0,0 +1,94 @@ +package tool_test + +import ( + "encoding/json" + "reflect" + "testing" + + itool "trpc.group/trpc-go/trpc-agent-go/internal/tool" +) + +// Pet defines the user's fury friend. +type Pet struct { + // Name of the animal. + Name string `json:"name" jsonschema:"title=Name"` + // Animal type, e.g., dog, cat, hamster. + AnimalType AnimalType `json:"animal_type" jsonschema:"title=Animal Type"` +} + +type AnimalType string + +// Pets is a collection of Pet objects. +type Pets []*Pet + +// NamedPets is a map of animal names to pets. +type NamedPets map[string]*Pet + +type ( + // Plant represents the plants the user might have and serves as a test + // of structs inside a `type` set. + Plant struct { + Variant string `json:"variant" jsonschema:"title=Variant"` // This comment will be used + // Multicellular is true if the plant is multicellular + Multicellular bool `json:"multicellular,omitempty" jsonschema:"title=Multicellular"` // This comment will be ignored + } +) +type User struct { + // Unique sequential identifier. + ID int `json:"id" jsonschema:"required"` + // This comment will be ignored + Name string `json:"name" jsonschema:"required,minLength=1,maxLength=20,pattern=.*,description=this is a property,title=the name,example=joe,example=lucy,default=alex"` + Friends []int `json:"friends,omitempty" jsonschema_description:"list of IDs, omitted when empty"` + Tags map[string]any `json:"tags,omitempty"` + + // An array of pets the user cares for. + Pets Pets `json:"pets"` + + // Set of animal names to pets + NamedPets NamedPets `json:"named_pets"` + + // Set of plants that the user likes + Plants []*Plant `json:"plants" jsonschema:"title=Plants"` +} + +func Test_GenerateJSONSchema_User(t *testing.T) { + s := itool.GenerateJSONSchema(reflect.TypeOf(&User{})) + data, err := json.MarshalIndent(s, "", " ") + if err != nil { + panic(err.Error()) + } + t.Log(string(data)) +} + +func Test_GenerateJSONSchema_TreeNode(t *testing.T) { + s := itool.GenerateJSONSchema(reflect.TypeOf(&TreeNode{})) + data, err := json.MarshalIndent(s, "", " ") + if err != nil { + panic(err.Error()) + } + t.Log(string(data)) +} + +func Test_GenerateJSONSchema_LinkedListNode(t *testing.T) { + s := itool.GenerateJSONSchema(reflect.TypeOf(&LinkedListNode{})) + data, err := json.MarshalIndent(s, "", " ") + if err != nil { + panic(err.Error()) + } + t.Log(string(data)) +} + +type TreeLinkListNode struct { + Name string `json:"name"` + TreeNode *TreeNode `json:"tree_node,omitempty"` + LinkListNode *LinkedListNode `json:"link_list_node,omitempty"` +} + +func Test_GenerateJSONSchema_TreeLinkListNode(t *testing.T) { + s := itool.GenerateJSONSchema(reflect.TypeOf(&TreeLinkListNode{})) + data, err := json.MarshalIndent(s, "", " ") + if err != nil { + panic(err.Error()) + } + t.Log(string(data)) +} diff --git a/internal/tool/recursive_test.go b/internal/tool/recursive_test.go new file mode 100644 index 000000000..8b7da8e6d --- /dev/null +++ b/internal/tool/recursive_test.go @@ -0,0 +1,152 @@ +// +// Tencent is pleased to support the open source community by making trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// +// + +package tool_test + +import ( + "encoding/json" + "reflect" + "testing" + + itool "trpc.group/trpc-go/trpc-agent-go/internal/tool" +) + +// TreeNode represents a recursive tree structure +type TreeNode struct { + Name string `json:"name"` + Children []*TreeNode `json:"children,omitempty"` +} + +// LinkedListNode represents a recursive linked list structure +type LinkedListNode struct { + Value int `json:"value"` + Next *LinkedListNode `json:"next,omitempty"` +} + +// MutuallyRecursiveA and MutuallyRecursiveB represent mutually recursive structures +type MutuallyRecursiveA struct { + Name string `json:"name"` + B *MutuallyRecursiveB `json:"b,omitempty"` +} + +type MutuallyRecursiveB struct { + Value int `json:"value"` + A *MutuallyRecursiveA `json:"a,omitempty"` +} + +func TestGenerateJSONSchema_RecursiveStructure(t *testing.T) { + t.Run("tree node recursive structure", func(t *testing.T) { + // This should not panic and should generate a schema with $ref and $defs + result := itool.GenerateJSONSchema(reflect.TypeOf(TreeNode{})) + resultJson, _ := json.MarshalIndent(result, "", " ") + t.Logf("%s", resultJson) + + if result.Type != "object" { + t.Errorf("expected object type, got %s", result.Type) + } + + // Check that we have properties + if result.Properties == nil { + t.Fatal("expected properties to be set") + } + + // Check name property + if result.Properties["name"] == nil || result.Properties["name"].Type != "string" { + t.Errorf("expected name property of type string") + } + + // Check children property + if result.Properties["children"] == nil || result.Properties["children"].Type != "array" { + t.Errorf("expected children property of type array") + } + + // The children items should use $ref to avoid infinite recursion + if result.Properties["children"].Items == nil { + t.Fatal("expected children items to be defined") + } + + // Check if we have $defs defined for recursive types + if result.Defs == nil { + t.Errorf("expected $defs to be defined for recursive structure") + } + + // Check that children items use $ref + if result.Properties["children"].Items.Ref != "#/$defs/treenode" { + t.Errorf("expected children items to reference #/$defs/treenode, got %s", result.Properties["children"].Items.Ref) + } + + // Check the definition in $defs + treeDef := result.Defs["treenode"] + if treeDef == nil { + t.Fatal("expected treenode definition in $defs") + } + + if treeDef.Type != "object" { + t.Errorf("expected treenode definition type to be object, got %s", treeDef.Type) + } + + // Check that the definition also uses $ref for children + if treeDef.Properties["children"].Items == nil || treeDef.Properties["children"].Items.Ref != "#/$defs/treenode" { + t.Errorf("expected treenode definition children items to reference #/$defs/treenode") + } + }) + + t.Run("linked list recursive structure", func(t *testing.T) { + // This should not panic and should generate a schema with $ref and $defs + result := itool.GenerateJSONSchema(reflect.TypeOf(LinkedListNode{})) + resultJson, _ := json.MarshalIndent(result, "", " ") + t.Logf("%s", resultJson) + + if result.Type != "object" { + t.Errorf("expected object type, got %s", result.Type) + } + + // Check that we have properties + if result.Properties == nil { + t.Fatal("expected properties to be set") + } + + // Check value property + if result.Properties["value"] == nil || result.Properties["value"].Type != "integer" { + t.Errorf("expected value property of type integer") + } + + // Check next property - should use $ref to avoid infinite recursion + if result.Properties["next"] == nil { + t.Fatal("expected next property to be defined") + } + + // Check if we have $defs defined for recursive types + if result.Defs == nil { + t.Errorf("expected $defs to be defined for recursive structure") + } + }) + + t.Run("mutually recursive structures", func(t *testing.T) { + // This should not panic and should generate a schema with $ref and $defs + result := itool.GenerateJSONSchema(reflect.TypeOf(MutuallyRecursiveA{})) + resultJson, _ := json.MarshalIndent(result, "", " ") + t.Logf("%s", resultJson) + + if result.Type != "object" { + t.Errorf("expected object type, got %s", result.Type) + } + + // Check that we have $defs for both types + if result.Defs == nil { + t.Fatal("expected $defs to be defined for mutually recursive structures") + } + + // Should have definitions for both types + expectedDefs := 2 // MutuallyRecursiveA and MutuallyRecursiveB + if len(result.Defs) < expectedDefs { + t.Errorf("expected at least %d definitions in $defs, got %d", expectedDefs, len(result.Defs)) + } + }) +} diff --git a/internal/tool/tool.go b/internal/tool/tool.go index ccbb1b1b6..49ea07eae 100644 --- a/internal/tool/tool.go +++ b/internal/tool/tool.go @@ -23,72 +23,212 @@ import ( // GenerateJSONSchema generates a basic JSON schema from a reflect.Type. func GenerateJSONSchema(t reflect.Type) *tool.Schema { - schema := &tool.Schema{Type: "object"} + // Use a context to track visited types and handle recursion + ctx := &schemaContext{ + visited: make(map[reflect.Type]string), + defs: make(map[string]*tool.Schema), + } + + schema := generateJSONSchema(t, ctx, true) + + // Add $defs to the root schema if we have any definitions + if len(ctx.defs) > 0 { + schema.Defs = ctx.defs + } + + return schema +} + +// schemaContext tracks the state during schema generation to handle recursion +type schemaContext struct { + visited map[reflect.Type]string // Maps types to their definition names + defs map[string]*tool.Schema // Stores reusable schema definitions +} +// generateJSONSchema generates a JSON schema with recursion handling +func generateJSONSchema(t reflect.Type, ctx *schemaContext, isRoot bool) *tool.Schema { // Handle different kinds of types. switch t.Kind() { case reflect.Struct: - properties := map[string]*tool.Schema{} - required := make([]string, 0) + return handleGenerateJSONSchemaStruct(t, ctx, isRoot) - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - if !field.IsExported() { - continue - } + case reflect.Ptr: + // For function tool parameters, we typically use value types + // So we can just return the element type schema. + return generateFieldSchema(t.Elem(), ctx, isRoot) - // Get JSON tag or use field name. - jsonTag := field.Tag.Get("json") - if jsonTag == "-" { - continue // Skip fields marked with json:"-" - } + default: + return generateFieldSchema(t, ctx, isRoot) + } +} - fieldName := field.Name - isOmitEmpty := false +// hasRecursiveFields checks if a struct type has fields that reference itself +func hasRecursiveFields(t reflect.Type) bool { + return checkRecursion(t, t, make(map[reflect.Type]bool)) +} - if jsonTag != "" { - // Parse json tag (handle omitempty, etc.) - if commaIdx := strings.Index(jsonTag, ","); commaIdx != -1 { - fieldName = jsonTag[:commaIdx] - isOmitEmpty = strings.Contains(jsonTag[commaIdx:], "omitempty") - } else { - fieldName = jsonTag - } +// handleGenerateJSONSchemaStruct contains the previously large struct-handling +// logic extracted from generateJSONSchema to reduce cyclomatic complexity. +func handleGenerateJSONSchemaStruct(t reflect.Type, ctx *schemaContext, isRoot bool) *tool.Schema { + // Check if we've already seen this struct type + if defName, exists := ctx.visited[t]; exists { + // Return a reference to the existing definition + return &tool.Schema{Ref: "#/$defs/" + defName} + } + + // Generate a unique name for this type and mark it visited + defName := generateDefName(t) + ctx.visited[t] = defName + + // Create the schema for this struct + schema := &tool.Schema{Type: "object"} + properties := map[string]*tool.Schema{} + required := make([]string, 0) + + // Check if this struct has recursive fields + hasRecursion := hasRecursiveFields(t) + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.IsExported() { + continue + } + + // Get JSON tag or use field name. + jsonTag := field.Tag.Get("json") + if jsonTag == "-" { + continue // Skip fields marked with json:"-" + } + + fieldName := field.Name + isOmitEmpty := false + + if jsonTag != "" { + // Parse json tag (handle omitempty, etc.) + if commaIdx := strings.Index(jsonTag, ","); commaIdx != -1 { + fieldName = jsonTag[:commaIdx] + isOmitEmpty = strings.Contains(jsonTag[commaIdx:], "omitempty") + } else { + fieldName = jsonTag } + } - // Generate schema for field type. - fieldSchema := GenerateFieldSchema(field.Type) + // Generate schema for field type. + fieldSchema := generateFieldSchema(field.Type, ctx, false) + properties[fieldName] = fieldSchema - // Parse jsonschema tag to customize the schema + // Parse jsonschema tag to customize the schema + // Only apply jsonschema tags if the field schema is not a reference + if fieldSchema.Ref == "" { isRequiredByTag, err := parseJSONSchemaTag(field.Type, field.Tag, fieldSchema) if err != nil { log.Errorf("parseJSONSchemaTag error for field %s: %v", fieldName, err) // Continue execution with the field schema as is } - properties[fieldName] = fieldSchema - // Check if field is required (not a pointer and no omitempty, or explicitly marked as required by jsonschema tag). if (field.Type.Kind() != reflect.Ptr && !isOmitEmpty) || isRequiredByTag { required = append(required, fieldName) } + } else { + // For reference fields, check if they should be required based on type and omitempty + if field.Type.Kind() != reflect.Ptr && !isOmitEmpty { + required = append(required, fieldName) + } } + } - schema.Properties = properties - if len(required) > 0 { - schema.Required = required + schema.Properties = properties + if len(required) > 0 { + schema.Required = required + } + + // Store the definition if we have recursion or if it's not the root + if hasRecursion || !isRoot { + // Create a copy of the schema for the definition to avoid circular references + defSchema := &tool.Schema{ + Type: schema.Type, + Properties: make(map[string]*tool.Schema), + Required: schema.Required, } - case reflect.Ptr: - // For function tool parameters, we typically use value types - // So we can just return the element type schema. - return GenerateFieldSchema(t.Elem()) + // Copy properties but ensure we use references for recursive types + for propName, propSchema := range schema.Properties { + defSchema.Properties[propName] = propSchema + } - default: - return GenerateFieldSchema(t) + ctx.defs[defName] = defSchema } - return schema + // For the root type with recursion, return the actual schema + // For nested recursive types, return a reference + if isRoot { + return schema + } + + return &tool.Schema{Ref: "#/$defs/" + defName} +} + +// checkRecursion recursively checks if targetType appears in the fields of currentType +func checkRecursion(targetType, currentType reflect.Type, visited map[reflect.Type]bool) bool { + if visited[currentType] { + return false + } + visited[currentType] = true + + switch currentType.Kind() { + case reflect.Struct: + for i := 0; i < currentType.NumField(); i++ { + field := currentType.Field(i) + if !field.IsExported() { + continue + } + + fieldType := field.Type + // Check through pointers, slices, and arrays + for fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Array { + fieldType = fieldType.Elem() + } + + if fieldType == targetType { + return true + } + + if fieldType.Kind() == reflect.Struct && checkRecursion(targetType, fieldType, visited) { + return true + } + } + case reflect.Slice, reflect.Array: + elemType := currentType.Elem() + for elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + if elemType == targetType { + return true + } + if elemType.Kind() == reflect.Struct && checkRecursion(targetType, elemType, visited) { + return true + } + case reflect.Ptr: + elemType := currentType.Elem() + if elemType == targetType { + return true + } + if elemType.Kind() == reflect.Struct && checkRecursion(targetType, elemType, visited) { + return true + } + } + + return false +} + +// generateDefName creates a unique definition name for a type +func generateDefName(t reflect.Type) string { + // Use the type name if available, otherwise use a generic name + if t.Name() != "" { + return strings.ToLower(t.Name()) + } + return "anonymousStruct" } // parseJSONSchemaTag parses jsonschema struct tag and applies the settings to the schema. @@ -154,34 +294,89 @@ func parseJSONSchemaTag(fieldType reflect.Type, tag reflect.StructTag, schema *t return isRequiredByTag, nil } -// GenerateFieldSchema generates schema for a specific field type. -func GenerateFieldSchema(t reflect.Type) *tool.Schema { +// generateFieldSchema generates schema for a specific field type with recursion handling. +func generateFieldSchema(t reflect.Type, ctx *schemaContext, isRoot bool) *tool.Schema { + // Delegate to smaller focused helpers to reduce cyclomatic complexity. switch t.Kind() { - case reflect.String: - return &tool.Schema{Type: "string"} - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return &tool.Schema{Type: "integer"} - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return &tool.Schema{Type: "integer"} - case reflect.Float32, reflect.Float64: - return &tool.Schema{Type: "number"} - case reflect.Bool: - return &tool.Schema{Type: "boolean"} + case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, reflect.Bool: + return handlePrimitiveType(t) case reflect.Slice, reflect.Array: - return &tool.Schema{ - Type: "array", - Items: GenerateFieldSchema(t.Elem()), - } + return handleArrayOrSlice(t, ctx) case reflect.Map: - return &tool.Schema{ - Type: "object", - AdditionalProperties: GenerateFieldSchema(t.Elem()), - } + return handleMapType(t, ctx, isRoot) case reflect.Ptr: - // For function tool parameters, we typically use value types - // So we can just return the element type schema - return GenerateFieldSchema(t.Elem()) + return handlePointerType(t, ctx, isRoot) case reflect.Struct: + return handleStructType(t, ctx, isRoot) + default: + return &tool.Schema{Type: "object"} + } +} + +// handlePrimitiveType returns a simple schema for primitive kinds. +func handlePrimitiveType(t reflect.Type) *tool.Schema { + switch t.Kind() { + case reflect.String: + return &tool.Schema{Type: "string"} + case reflect.Bool: + return &tool.Schema{Type: "boolean"} + case reflect.Float32, reflect.Float64: + return &tool.Schema{Type: "number"} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return &tool.Schema{Type: "integer"} + default: + return &tool.Schema{Type: "object"} + } +} + +// handleArrayOrSlice builds schema for arrays and slices. +func handleArrayOrSlice(t reflect.Type, ctx *schemaContext) *tool.Schema { + // For struct element types we might prefer references; generateFieldSchema will + // handle nested struct recursion correctly. + return &tool.Schema{ + Type: "array", + Items: generateFieldSchema(t.Elem(), ctx, false), + } +} + +// handleMapType builds schema for map types using additionalProperties. +func handleMapType(t reflect.Type, ctx *schemaContext, isRoot bool) *tool.Schema { + valueSchema := generateFieldSchema(t.Elem(), ctx, false) + if valueSchema == nil { + valueSchema = &tool.Schema{Type: "object"} + } + + schema := &tool.Schema{ + Type: "object", + AdditionalProperties: valueSchema, + } + + if isRoot && len(ctx.defs) > 0 { + schema.Defs = ctx.defs + } + + return schema +} + +// handlePointerType returns the element type schema for pointer types. +func handlePointerType(t reflect.Type, ctx *schemaContext, isRoot bool) *tool.Schema { + return generateFieldSchema(t.Elem(), ctx, isRoot) +} + +// handleStructType handles inline and named struct schemas with recursion tracking. +func handleStructType(t reflect.Type, ctx *schemaContext, isRoot bool) *tool.Schema { + // If we've already created a definition for this type, return a reference. + if defName, exists := ctx.visited[t]; exists { + return &tool.Schema{Ref: "#/$defs/" + defName} + } + + hasRecursion := hasRecursiveFields(t) + + // Inline schema when there is no recursion (backwards compat). + if !hasRecursion { nestedSchema := &tool.Schema{ Type: "object", Properties: make(map[string]*tool.Schema), @@ -192,12 +387,10 @@ func GenerateFieldSchema(t reflect.Type) *tool.Schema { if !field.IsExported() { continue } - jsonTag := field.Tag.Get("json") if jsonTag == "-" { continue } - fieldName := field.Name if jsonTag != "" { if commaIdx := strings.Index(jsonTag, ","); commaIdx != -1 { @@ -206,13 +399,42 @@ func GenerateFieldSchema(t reflect.Type) *tool.Schema { fieldName = jsonTag } } - - nestedSchema.Properties[fieldName] = GenerateFieldSchema(field.Type) + nestedSchema.Properties[fieldName] = generateFieldSchema(field.Type, ctx, false) } - return nestedSchema - default: - // Default to any type - return &tool.Schema{Type: "object"} } + + // Named struct with recursion: create definition and return a reference. + defName := generateDefName(t) + ctx.visited[t] = defName + + nestedSchema := &tool.Schema{ + Type: "object", + Properties: make(map[string]*tool.Schema), + } + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.IsExported() { + continue + } + jsonTag := field.Tag.Get("json") + if jsonTag == "-" { + continue + } + fieldName := field.Name + if jsonTag != "" { + if commaIdx := strings.Index(jsonTag, ","); commaIdx != -1 { + fieldName = jsonTag[:commaIdx] + } else { + fieldName = jsonTag + } + } + nestedSchema.Properties[fieldName] = generateFieldSchema(field.Type, ctx, false) + } + + // Store the definition + ctx.defs[defName] = nestedSchema + + return &tool.Schema{Ref: "#/$defs/" + defName} } diff --git a/tool/tool.go b/tool/tool.go index e9218889a..ecd785b83 100644 --- a/tool/tool.go +++ b/tool/tool.go @@ -63,7 +63,7 @@ type Declaration struct { // and to validate that incoming data conforms to the expected structure. type Schema struct { // Type Specifies the data type (e.g., "object", "array", "string", "number") - Type string `json:"type"` + Type string `json:"type,omitempty"` Description string `json:"description,omitempty"` Required []string `json:"required,omitempty"` // Properties of the arguments, each with its own schema @@ -76,4 +76,8 @@ type Schema struct { Default any `json:"default,omitempty"` // Enum contains the list of allowed values for the parameter Enum []any `json:"enum,omitempty"` + // Ref is used for JSON Schema references to avoid infinite recursion + Ref string `json:"$ref,omitempty"` + // Defs contains reusable schema definitions + Defs map[string]*Schema `json:"$defs,omitempty"` }