Skip to content

Commit ed2b48b

Browse files
authored
Merge pull request #764 from graphql-java-kickstart/664-scan-directive-enum-input-arguments
Scan directives arguments while parsing schema
2 parents 7fb45d6 + a433966 commit ed2b48b

File tree

4 files changed

+221
-77
lines changed

4 files changed

+221
-77
lines changed

src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt

+9-8
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,11 @@ internal class SchemaClassScanner(
149149
?: error("No ${TypeDefinition::class.java.simpleName} for type name $inputTypeName")
150150
when (typeDefinition) {
151151
is ScalarTypeDefinition -> handleFoundScalarType(typeDefinition)
152-
is InputObjectTypeDefinition -> {
153-
for (input in typeDefinition.inputValueDefinitions) {
154-
handleDirectiveInput(input.type)
155-
}
152+
is EnumTypeDefinition -> handleDictionaryTypes(listOf(typeDefinition)) {
153+
"Enum type '${it.name}' is used in a directive, but no class could be found for that type name. Please pass a class for type '${it.name}' in the parser's dictionary."
154+
}
155+
is InputObjectTypeDefinition -> handleDictionaryTypes(listOf(typeDefinition)) {
156+
"Input object type '${it.name}' is used in a directive, but no class could be found for that type name. Please pass a class for type '${it.name}' in the parser's dictionary."
156157
}
157158
}
158159
}
@@ -209,9 +210,9 @@ internal class SchemaClassScanner(
209210
log.warn("Schema type was defined but can never be accessed, and can be safely deleted: ${definition.name}")
210211
}
211212

212-
val fieldResolvers = fieldResolversByType.flatMap { it.value.map { it.value } }
213-
val observedNormalResolverInfos = fieldResolvers.map { it.resolverInfo }.distinct().filterIsInstance<NormalResolverInfo>()
214-
val observedMultiResolverInfos = fieldResolvers.map { it.resolverInfo }.distinct().filterIsInstance<MultiResolverInfo>().flatMap { it.resolverInfoList }
213+
val fieldResolvers = fieldResolversByType.flatMap { entry -> entry.value.map { it.value } }
214+
val observedNormalResolverInfos = fieldResolvers.map { it.resolverInfo }.filterIsInstance<NormalResolverInfo>().toSet()
215+
val observedMultiResolverInfos = fieldResolvers.map { it.resolverInfo }.filterIsInstance<MultiResolverInfo>().flatMap { it.resolverInfoList }.toSet()
215216

216217
(resolverInfos - observedNormalResolverInfos - observedMultiResolverInfos).forEach { resolverInfo ->
217218
log.warn("Resolver was provided but no methods on it were used in data fetchers, and can be safely deleted: ${resolverInfo.resolver}")
@@ -255,7 +256,7 @@ internal class SchemaClassScanner(
255256
}.flatten().distinct()
256257
}
257258

258-
private fun handleDictionaryTypes(types: List<ObjectTypeDefinition>, failureMessage: (ObjectTypeDefinition) -> String) {
259+
private fun handleDictionaryTypes(types: List<TypeDefinition<*>>, failureMessage: (TypeDefinition<*>) -> String) {
259260
types.forEach { type ->
260261
val dictionaryContainsType = dictionary.filter { it.key.name == type.name }.isNotEmpty()
261262
if (!unvalidatedTypes.contains(type) && !dictionaryContainsType) {

src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt

+69-69
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package graphql.kickstart.tools
22

3-
import graphql.Scalars
43
import graphql.introspection.Introspection
54
import graphql.introspection.Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION
65
import graphql.kickstart.tools.directive.DirectiveWiringHelper
@@ -9,6 +8,7 @@ import graphql.kickstart.tools.util.getExtendedFieldDefinitions
98
import graphql.kickstart.tools.util.unwrap
109
import graphql.language.*
1110
import graphql.schema.*
11+
import graphql.schema.idl.DirectiveInfo
1212
import graphql.schema.idl.RuntimeWiring
1313
import graphql.schema.idl.ScalarInfo
1414
import graphql.schema.visibility.NoIntrospectionGraphqlFieldVisibility
@@ -60,6 +60,8 @@ class SchemaParser internal constructor(
6060
private val codeRegistryBuilder = GraphQLCodeRegistry.newCodeRegistry()
6161
private val directiveWiringHelper = DirectiveWiringHelper(options, runtimeWiring, codeRegistryBuilder, directiveDefinitions)
6262

63+
private lateinit var schemaDirectives : Set<GraphQLDirective>
64+
6365
/**
6466
* Parses the given schema with respect to the given dictionary and returns GraphQL objects.
6567
*/
@@ -72,6 +74,7 @@ class SchemaParser internal constructor(
7274

7375
// Create GraphQL objects
7476
val inputObjects: MutableList<GraphQLInputObjectType> = mutableListOf()
77+
createDirectives(inputObjects)
7578
inputObjectDefinitions.forEach {
7679
if (inputObjects.none { io -> io.name == it.name }) {
7780
inputObjects.add(createInputObject(it, inputObjects, mutableSetOf()))
@@ -82,8 +85,6 @@ class SchemaParser internal constructor(
8285
val unions = unionDefinitions.map { createUnionObject(it, objects) }
8386
val enums = enumDefinitions.map { createEnumObject(it) }
8487

85-
val directives = directiveDefinitions.map { createDirective(it, inputObjects) }.toSet()
86-
8788
// Assign type resolver to interfaces now that we know all of the object types
8889
interfaces.forEach { codeRegistryBuilder.typeResolver(it, InterfaceTypeResolver(dictionary.inverse(), it)) }
8990
unions.forEach { codeRegistryBuilder.typeResolver(it, UnionTypeResolver(dictionary.inverse(), it)) }
@@ -103,7 +104,7 @@ class SchemaParser internal constructor(
103104
val additionalObjects = objects.filter { o -> o != query && o != subscription && o != mutation }
104105

105106
val types = (additionalObjects.toSet() as Set<GraphQLType>) + inputObjects + enums + interfaces + unions
106-
return SchemaObjects(query, mutation, subscription, types, directives, codeRegistryBuilder, rootInfo.getDescription())
107+
return SchemaObjects(query, mutation, subscription, types, schemaDirectives, codeRegistryBuilder, rootInfo.getDescription())
107108
}
108109

109110
/**
@@ -300,44 +301,75 @@ class SchemaParser internal constructor(
300301
.name(definition.name)
301302
.definition(definition)
302303
.description(getDocumentation(definition, options))
303-
.type(determineInputType(definition.type, inputObjects, setOf()))
304+
.type(determineInputType(definition.type, inputObjects, mutableSetOf()))
304305
.apply { getDeprecated(definition.directives)?.let { deprecate(it) } }
305306
.apply { definition.defaultValue?.let { defaultValueLiteral(it) } }
306307
.withAppliedDirectives(*buildAppliedDirectives(definition.directives))
307308
.withDirectives(*buildDirectives(definition.directives, Introspection.DirectiveLocation.ARGUMENT_DEFINITION))
308309
.build()
309310
}
310311

311-
private fun createDirective(definition: DirectiveDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLDirective {
312-
val locations = definition.directiveLocations.map { Introspection.DirectiveLocation.valueOf(it.name) }.toTypedArray()
312+
private fun createDirectives(inputObjects: MutableList<GraphQLInputObjectType>) {
313+
schemaDirectives = directiveDefinitions.map { definition ->
314+
val locations = definition.directiveLocations.map { Introspection.DirectiveLocation.valueOf(it.name) }.toTypedArray()
315+
316+
GraphQLDirective.newDirective()
317+
.name(definition.name)
318+
.description(getDocumentation(definition, options))
319+
.definition(definition)
320+
.comparatorRegistry(runtimeWiring.comparatorRegistry)
321+
.validLocations(*locations)
322+
.repeatable(definition.isRepeatable)
323+
.apply {
324+
definition.inputValueDefinitions.forEach { argumentDefinition ->
325+
argument(createDirectiveArgument(argumentDefinition, inputObjects))
326+
}
327+
}
328+
.build()
329+
}.toSet()
330+
// because the arguments can have directives too, we attach them only after the directives themselves are created
331+
schemaDirectives = schemaDirectives.map { d ->
332+
val arguments = d.arguments.map { a -> a.transform {
333+
it.withAppliedDirectives(*buildAppliedDirectives(a.definition!!.directives))
334+
.withDirectives(*buildDirectives(a.definition!!.directives, Introspection.DirectiveLocation.OBJECT))
335+
} }
336+
d.transform { it.replaceArguments(arguments) }
337+
}.toSet()
338+
}
313339

314-
return GraphQLDirective.newDirective()
340+
private fun createDirectiveArgument(definition: InputValueDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLArgument {
341+
return GraphQLArgument.newArgument()
315342
.name(definition.name)
316-
.description(getDocumentation(definition, options))
317343
.definition(definition)
318-
.comparatorRegistry(runtimeWiring.comparatorRegistry)
319-
.validLocations(*locations)
320-
.repeatable(definition.isRepeatable)
321-
.apply {
322-
definition.inputValueDefinitions.forEach { argumentDefinition ->
323-
argument(createArgument(argumentDefinition, inputObjects))
324-
}
325-
}
344+
.description(getDocumentation(definition, options))
345+
.type(determineInputType(definition.type, inputObjects, mutableSetOf()))
346+
.apply { getDeprecated(definition.directives)?.let { deprecate(it) } }
347+
.apply { definition.defaultValue?.let { defaultValueLiteral(it) } }
326348
.build()
327349
}
328350

329351
private fun buildAppliedDirectives(directives: List<Directive>): Array<GraphQLAppliedDirective> {
330-
return directives.map {
352+
return directives.map { directive ->
353+
val graphQLDirective = schemaDirectives.find { d -> d.name == directive.name }
354+
?: DirectiveInfo.GRAPHQL_SPECIFICATION_DIRECTIVE_MAP[directive.name]
355+
?: throw SchemaError("Found applied directive ${directive.name} without corresponding directive definition.")
356+
val graphQLArguments = graphQLDirective.arguments.associateBy { it.name }
357+
331358
GraphQLAppliedDirective.newDirective()
332-
.name(it.name)
333-
.description(getDocumentation(it, options))
359+
.name(directive.name)
360+
.description(getDocumentation(directive, options))
361+
.definition(directive)
334362
.comparatorRegistry(runtimeWiring.comparatorRegistry)
335363
.apply {
336-
it.arguments.forEach { arg ->
364+
directive.arguments.forEach { arg ->
365+
val graphQLArgument = graphQLArguments[arg.name]
366+
?: throw SchemaError("Found an unexpected directive argument ${directive.name}#${arg.name} .")
337367
argument(GraphQLAppliedDirectiveArgument.newArgument()
338368
.name(arg.name)
339-
.type(buildDirectiveInputType(arg.value))
369+
// TODO instead of guessing the type from its value, lookup the directive definition
370+
.type(graphQLArgument.type)
340371
.valueLiteral(arg.value)
372+
.description(graphQLArgument.description)
341373
.build()
342374
)
343375
}
@@ -358,6 +390,10 @@ class SchemaParser internal constructor(
358390
val repeatable = directiveDefinitions.find { it.name.equals(directive.name) }?.isRepeatable ?: false
359391
if (repeatable || !names.contains(directive.name)) {
360392
names.add(directive.name)
393+
val graphQLDirective = this.schemaDirectives.find { d -> d.name == directive.name }
394+
?: DirectiveInfo.GRAPHQL_SPECIFICATION_DIRECTIVE_MAP[directive.name]
395+
?: throw SchemaError("Found applied directive ${directive.name} without corresponding directive definition.")
396+
val graphQLArguments = graphQLDirective.arguments.associateBy { it.name }
361397
output.add(
362398
GraphQLDirective.newDirective()
363399
.name(directive.name)
@@ -367,9 +403,11 @@ class SchemaParser internal constructor(
367403
.repeatable(repeatable)
368404
.apply {
369405
directive.arguments.forEach { arg ->
406+
val graphQLArgument = graphQLArguments[arg.name]
407+
?: throw SchemaError("Found an unexpected directive argument ${directive.name}#${arg.name}.")
370408
argument(GraphQLArgument.newArgument()
371409
.name(arg.name)
372-
.type(buildDirectiveInputType(arg.value))
410+
.type(graphQLArgument.type)
373411
// TODO remove this once directives are fully replaced with applied directives
374412
.valueLiteral(arg.value)
375413
.build())
@@ -383,46 +421,6 @@ class SchemaParser internal constructor(
383421
return output.toTypedArray()
384422
}
385423

386-
private fun buildDirectiveInputType(value: Value<*>): GraphQLInputType? {
387-
return when (value) {
388-
is NullValue -> Scalars.GraphQLString
389-
is FloatValue -> Scalars.GraphQLFloat
390-
is StringValue -> Scalars.GraphQLString
391-
is IntValue -> Scalars.GraphQLInt
392-
is BooleanValue -> Scalars.GraphQLBoolean
393-
is ArrayValue -> GraphQLList.list(buildDirectiveInputType(getArrayValueWrappedType(value)))
394-
// TODO to implement this we'll need to "observe" directive's input types + match them here based on their fields(?)
395-
else -> throw SchemaError("Directive values of type '${value::class.simpleName}' are not supported yet.")
396-
}
397-
}
398-
399-
private fun getArrayValueWrappedType(value: ArrayValue): Value<*> {
400-
// empty array [] is equivalent to [null]
401-
if (value.values.isEmpty()) {
402-
return NullValue.newNullValue().build()
403-
}
404-
405-
// get rid of null values
406-
val nonNullValueList = value.values.filter { v -> v !is NullValue }
407-
408-
// [null, null, ...] unwrapped is null
409-
if (nonNullValueList.isEmpty()) {
410-
return NullValue.newNullValue().build()
411-
}
412-
413-
// make sure the array isn't polymorphic
414-
val distinctTypes = nonNullValueList
415-
.map { it::class.java }
416-
.distinct()
417-
418-
if (distinctTypes.size > 1) {
419-
throw SchemaError("Arrays containing multiple types of values are not supported yet.")
420-
}
421-
422-
// peek at first value, value exists and is assured to be non-null
423-
return nonNullValueList[0]
424-
}
425-
426424
private fun determineOutputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>) =
427425
determineType(GraphQLOutputType::class, typeDefinition, permittedTypesForObject, inputObjects) as GraphQLOutputType
428426

@@ -455,13 +453,15 @@ class SchemaParser internal constructor(
455453
else -> throw SchemaError("Unknown type: $typeDefinition")
456454
}
457455

458-
private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>, referencingInputObjects: Set<String>) =
456+
private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>, referencingInputObjects: MutableSet<String>) =
459457
determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects, referencingInputObjects)
460458

461-
private fun <T : Any> determineInputType(expectedType: KClass<T>,
462-
typeDefinition: Type<*>, allowedTypeReferences: Set<String>,
463-
inputObjects: List<GraphQLInputObjectType>,
464-
referencingInputObjects: Set<String>): GraphQLInputType =
459+
private fun <T : Any> determineInputType(
460+
expectedType: KClass<T>,
461+
typeDefinition: Type<*>,
462+
allowedTypeReferences: Set<String>,
463+
inputObjects: List<GraphQLInputObjectType>,
464+
referencingInputObjects: MutableSet<String>): GraphQLInputType =
465465
when (typeDefinition) {
466466
is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
467467
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
@@ -489,7 +489,7 @@ class SchemaParser internal constructor(
489489
if (referencingInputObject != null) {
490490
GraphQLTypeReference(referencingInputObject)
491491
} else {
492-
val inputObject = createInputObject(filteredDefinitions[0], inputObjects, referencingInputObjects as MutableSet<String>)
492+
val inputObject = createInputObject(filteredDefinitions[0], inputObjects, referencingInputObjects)
493493
(inputObjects as MutableList).add(inputObject)
494494
inputObject
495495
}

0 commit comments

Comments
 (0)