diff --git a/src/main/kotlin/graphql/kickstart/tools/resolver/FieldResolverScanner.kt b/src/main/kotlin/graphql/kickstart/tools/resolver/FieldResolverScanner.kt index 555b96c7..08ce1605 100644 --- a/src/main/kotlin/graphql/kickstart/tools/resolver/FieldResolverScanner.kt +++ b/src/main/kotlin/graphql/kickstart/tools/resolver/FieldResolverScanner.kt @@ -15,10 +15,8 @@ import org.apache.commons.lang3.ClassUtils import org.apache.commons.lang3.reflect.FieldUtils import org.reactivestreams.Publisher import org.slf4j.LoggerFactory -import java.lang.reflect.AccessibleObject -import java.lang.reflect.Method -import java.lang.reflect.Modifier -import java.lang.reflect.Type +import java.lang.reflect.* +import java.util.concurrent.CompletableFuture import kotlin.reflect.full.valueParameters import kotlin.reflect.jvm.javaType import kotlin.reflect.jvm.kotlinFunction @@ -131,7 +129,17 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) { } private fun resolverMethodReturnsPublisher(method: Method) = - method.returnType.isAssignableFrom(Publisher::class.java) || receiveChannelToPublisherWrapper(method) + method.returnType.isAssignableFrom(Publisher::class.java) + || resolverMethodReturnsPublisherFuture(method) + || receiveChannelToPublisherWrapper(method) + + private fun resolverMethodReturnsPublisherFuture(method: Method) = + method.returnType.isAssignableFrom(CompletableFuture::class.java) + && method.genericReturnType is ParameterizedType + && (method.genericReturnType as ParameterizedType).actualTypeArguments + .any { + it is ParameterizedType && it.unwrap().isAssignableFrom(Publisher::class.java) + } private fun receiveChannelToPublisherWrapper(method: Method) = method.returnType.isAssignableFrom(ReceiveChannel::class.java) diff --git a/src/test/kotlin/graphql/kickstart/tools/EndToEndSpecHelper.kt b/src/test/kotlin/graphql/kickstart/tools/EndToEndSpecHelper.kt index b728f7a0..4b9317d5 100644 --- a/src/test/kotlin/graphql/kickstart/tools/EndToEndSpecHelper.kt +++ b/src/test/kotlin/graphql/kickstart/tools/EndToEndSpecHelper.kt @@ -123,6 +123,7 @@ type Mutation { type Subscription { onItemCreated: Item! + onItemCreatedFuture: Item! onItemCreatedCoroutineChannel: Item! onItemCreatedCoroutineChannelAndSuspendFunction: Item! } @@ -373,7 +374,6 @@ class Subscription : GraphQLSubscriptionResolver { fun onItemCreated(env: DataFetchingEnvironment) = Publisher { subscriber -> subscriber.onNext(env.graphQlContext["newItem"]) -// subscriber.onComplete() } fun onItemCreatedCoroutineChannel(env: DataFetchingEnvironment): ReceiveChannel { @@ -382,6 +382,14 @@ class Subscription : GraphQLSubscriptionResolver { return channel } + fun onItemCreatedFuture(env: DataFetchingEnvironment): CompletableFuture> { + return CompletableFuture.supplyAsync { + Publisher { subscriber -> + subscriber.onNext(env.graphQlContext["newItem"]) + } + } + } + suspend fun onItemCreatedCoroutineChannelAndSuspendFunction(env: DataFetchingEnvironment): ReceiveChannel { return coroutineScope { val channel = Channel(1) diff --git a/src/test/kotlin/graphql/kickstart/tools/EndToEndTest.kt b/src/test/kotlin/graphql/kickstart/tools/EndToEndTest.kt index 34cb2a51..01c56965 100644 --- a/src/test/kotlin/graphql/kickstart/tools/EndToEndTest.kt +++ b/src/test/kotlin/graphql/kickstart/tools/EndToEndTest.kt @@ -1,6 +1,5 @@ package graphql.kickstart.tools -import com.fasterxml.jackson.module.kotlin.jacksonMapperBuilder import graphql.* import graphql.execution.AsyncExecutionStrategy import graphql.schema.* @@ -97,6 +96,44 @@ class EndToEndTest { assert(result.errors.isEmpty()) assertEquals(returnedItem?.get("onItemCreated"), mapOf("id" to 1)) } + + @Test + fun `generated schema should execute the subscription query future`() { + val newItem = Item(1, "item", Type.TYPE_1, UUID.randomUUID(), listOf()) + var returnedItem: Map>? = null + + val closure = { + """ + subscription { + onItemCreatedFuture { + id + } + } + """ + } + + val result = gql.execute(ExecutionInput.newExecutionInput() + .query(closure.invoke()) + .graphQLContext(mapOf("newItem" to newItem)) + .variables(mapOf())) + + val data = result.getData() as Publisher + val latch = CountDownLatch(1) + data.subscribe(object : Subscriber { + override fun onNext(item: ExecutionResult?) { + returnedItem = item?.getData() + latch.countDown() + } + + override fun onError(throwable: Throwable?) {} + override fun onComplete() {} + override fun onSubscribe(p0: Subscription?) {} + }) + latch.await(3, TimeUnit.SECONDS) + + assert(result.errors.isEmpty()) + assertEquals(returnedItem?.get("onItemCreatedFuture"), mapOf("id" to 1)) + } @Test fun `generated schema should handle interface types`() { diff --git a/src/test/kotlin/graphql/kickstart/tools/SchemaParserTest.kt b/src/test/kotlin/graphql/kickstart/tools/SchemaParserTest.kt index 87fb9daf..474c8bca 100644 --- a/src/test/kotlin/graphql/kickstart/tools/SchemaParserTest.kt +++ b/src/test/kotlin/graphql/kickstart/tools/SchemaParserTest.kt @@ -10,6 +10,7 @@ import org.junit.Before import org.junit.Test import org.springframework.aop.framework.ProxyFactory import java.io.FileNotFoundException +import java.util.concurrent.CompletableFuture.completedFuture import java.util.concurrent.Future @OptIn(ExperimentalCoroutinesApi::class) @@ -665,6 +666,10 @@ class SchemaParserTest { @Test fun `parser should verify subscription resolver return type`() { + class Subscription : GraphQLSubscriptionResolver { + fun onItemCreated(env: DataFetchingEnvironment) = env.hashCode() + } + val error = assertThrows(FieldResolverError::class.java) { SchemaParser.newParser() .schemaString( @@ -689,9 +694,9 @@ class SchemaParserTest { val expected = """ No method or field found as defined in schema :3 with any of the following signatures (with or without one of [interface graphql.schema.DataFetchingEnvironment, class graphql.GraphQLContext] as the last argument), in priority order: - graphql.kickstart.tools.SchemaParserTest${"$"}Subscription.onItemCreated() - graphql.kickstart.tools.SchemaParserTest${"$"}Subscription.getOnItemCreated() - graphql.kickstart.tools.SchemaParserTest${"$"}Subscription.onItemCreated + graphql.kickstart.tools.SchemaParserTest${"$"}parser should verify subscription resolver return type${"$"}Subscription.onItemCreated() + graphql.kickstart.tools.SchemaParserTest${"$"}parser should verify subscription resolver return type${"$"}Subscription.getOnItemCreated() + graphql.kickstart.tools.SchemaParserTest${"$"}parser should verify subscription resolver return type${"$"}Subscription.onItemCreated Note that a Subscription data fetcher must return a Publisher of events """.trimIndent() @@ -699,7 +704,43 @@ class SchemaParserTest { assertEquals(error.message, expected) } - class Subscription : GraphQLSubscriptionResolver { - fun onItemCreated(env: DataFetchingEnvironment) = env.hashCode() + @Test + fun `parser should verify subscription resolver generic future return type`() { + class Subscription : GraphQLSubscriptionResolver { + fun onItemCreated(env: DataFetchingEnvironment) = completedFuture(env.hashCode()) + } + + val error = assertThrows(FieldResolverError::class.java) { + SchemaParser.newParser() + .schemaString( + """ + type Subscription { + onItemCreated: Int! + } + + type Query { + test: String + } + """ + ) + .resolvers( + Subscription(), + object : GraphQLQueryResolver { fun test() = "test" } + ) + .build() + .makeExecutableSchema() + } + + val expected = """ + No method or field found as defined in schema :3 with any of the following signatures (with or without one of [interface graphql.schema.DataFetchingEnvironment, class graphql.GraphQLContext] as the last argument), in priority order: + + graphql.kickstart.tools.SchemaParserTest${"$"}parser should verify subscription resolver generic future return type${"$"}Subscription.onItemCreated() + graphql.kickstart.tools.SchemaParserTest${"$"}parser should verify subscription resolver generic future return type${"$"}Subscription.getOnItemCreated() + graphql.kickstart.tools.SchemaParserTest${"$"}parser should verify subscription resolver generic future return type${"$"}Subscription.onItemCreated + + Note that a Subscription data fetcher must return a Publisher of events + """.trimIndent() + + assertEquals(error.message, expected) } }