diff --git a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt index a1f6bb38..b53d6c24 100644 --- a/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt +++ b/src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt @@ -76,7 +76,7 @@ class SchemaParser internal constructor( val inputObjects: MutableList = mutableListOf() inputObjectDefinitions.forEach { if (inputObjects.none { io -> io.name == it.name }) { - inputObjects.add(createInputObject(it, inputObjects)) + inputObjects.add(createInputObject(it, inputObjects, mutableSetOf())) } } val interfaces = interfaceDefinitions.map { createInterfaceObject(it, inputObjects) } @@ -155,7 +155,8 @@ class SchemaParser internal constructor( return schemaGeneratorDirectiveHelper.onObject(objectType, directiveHelperParameters) } - private fun createInputObject(definition: InputObjectTypeDefinition, inputObjects: List): GraphQLInputObjectType { + private fun createInputObject(definition: InputObjectTypeDefinition, inputObjects: List, + referencingInputObjects: MutableSet): GraphQLInputObjectType { val extensionDefinitions = inputExtensionDefinitions.filter { it.name == definition.name } val builder = GraphQLInputObjectType.newInputObject() @@ -166,6 +167,8 @@ class SchemaParser internal constructor( builder.withDirectives(*buildDirectives(definition.directives, Introspection.DirectiveLocation.INPUT_OBJECT)) + referencingInputObjects.add(definition.name) + (extensionDefinitions + definition).forEach { it.inputValueDefinitions.forEach { inputDefinition -> val fieldBuilder = GraphQLInputObjectField.newInputObjectField() @@ -173,7 +176,7 @@ class SchemaParser internal constructor( .definition(inputDefinition) .description(if (inputDefinition.description != null) inputDefinition.description.content else getDocumentation(inputDefinition)) .defaultValue(buildDefaultValue(inputDefinition.defaultValue)) - .type(determineInputType(inputDefinition.type, inputObjects)) + .type(determineInputType(inputDefinition.type, inputObjects, referencingInputObjects)) .withDirectives(*buildDirectives(inputDefinition.directives, Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION)) builder.field(fieldBuilder.build()) } @@ -280,7 +283,7 @@ class SchemaParser internal constructor( .name(argumentDefinition.name) .definition(argumentDefinition) .description(if (argumentDefinition.description != null) argumentDefinition.description.content else getDocumentation(argumentDefinition)) - .type(determineInputType(argumentDefinition.type, inputObjects)) + .type(determineInputType(argumentDefinition.type, inputObjects, setOf())) .apply { buildDefaultValue(argumentDefinition.defaultValue)?.let { defaultValue(it) } } .withDirectives(*buildDirectives(argumentDefinition.directives, Introspection.DirectiveLocation.ARGUMENT_DEFINITION)) @@ -380,7 +383,7 @@ class SchemaParser internal constructor( is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects)) is InputObjectTypeDefinition -> { log.info("Create input object") - createInputObject(typeDefinition, inputObjects) + createInputObject(typeDefinition, inputObjects, mutableSetOf()) } is TypeName -> { val scalarType = customScalars[typeDefinition.name] @@ -398,16 +401,19 @@ class SchemaParser internal constructor( else -> throw SchemaError("Unknown type: $typeDefinition") } - private fun determineInputType(typeDefinition: Type<*>, inputObjects: List) = - determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects) as GraphQLInputType + private fun determineInputType(typeDefinition: Type<*>, inputObjects: List, referencingInputObjects: Set) = + determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects, referencingInputObjects) as GraphQLInputType - private fun determineInputType(expectedType: KClass, typeDefinition: Type<*>, allowedTypeReferences: Set, inputObjects: List): GraphQLType = + private fun determineInputType(expectedType: KClass, + typeDefinition: Type<*>, allowedTypeReferences: Set, + inputObjects: List, + referencingInputObjects: Set): GraphQLType = when (typeDefinition) { is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects)) is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects)) is InputObjectTypeDefinition -> { log.info("Create input object") - createInputObject(typeDefinition, inputObjects) + createInputObject(typeDefinition, inputObjects, referencingInputObjects as MutableSet) } is TypeName -> { val scalarType = customScalars[typeDefinition.name] @@ -425,9 +431,14 @@ class SchemaParser internal constructor( } else { val filteredDefinitions = inputObjectDefinitions.filter { it.name == typeDefinition.name } if (filteredDefinitions.isNotEmpty()) { - val inputObject = createInputObject(filteredDefinitions[0], inputObjects) - (inputObjects as MutableList).add(inputObject) - inputObject + val referencingInputObject = referencingInputObjects.find { it == typeDefinition.name } + if (referencingInputObject != null) { + GraphQLTypeReference(referencingInputObject) + } else { + val inputObject = createInputObject(filteredDefinitions[0], inputObjects, referencingInputObjects as MutableSet) + (inputObjects as MutableList).add(inputObject) + inputObject + } } else { // todo: handle enum type GraphQLTypeReference(typeDefinition.name) diff --git a/src/test/groovy/graphql/kickstart/tools/SchemaParserSpec.groovy b/src/test/groovy/graphql/kickstart/tools/SchemaParserSpec.groovy index d08df14c..a773e4b9 100644 --- a/src/test/groovy/graphql/kickstart/tools/SchemaParserSpec.groovy +++ b/src/test/groovy/graphql/kickstart/tools/SchemaParserSpec.groovy @@ -368,6 +368,50 @@ class SchemaParserSpec extends Specification { noExceptionThrown() } + def "allow circular relations in input objects"() { + when: + SchemaParser.newParser().schemaString('''\ + input A { + id: ID! + b: B + } + input B { + id: ID! + a: A + } + input C { + id: ID! + c: C + } + type Query {} + type Mutation { + test(input: A!): Boolean + testC(input: C!): Boolean + } + '''.stripIndent()) + .resolvers(new GraphQLMutationResolver() { + static class A { + String id; + B b; + } + static class B { + String id; + A a; + } + static class C { + String id; + C c; + } + boolean test(A a) { return true } + boolean testC(C c) { return true } + }, new GraphQLQueryResolver() {}) + .build() + .makeExecutableSchema() + + then: + noExceptionThrown() + } + enum EnumType { TEST }