Skip to content

Commit

Permalink
Merge pull request #764 from graphql-java-kickstart/664-scan-directiv…
Browse files Browse the repository at this point in the history
…e-enum-input-arguments

Scan directives arguments while parsing schema
  • Loading branch information
oryan-block authored Sep 20, 2023
2 parents 7fb45d6 + a433966 commit ed2b48b
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 77 deletions.
17 changes: 9 additions & 8 deletions src/main/kotlin/graphql/kickstart/tools/SchemaClassScanner.kt
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,11 @@ internal class SchemaClassScanner(
?: error("No ${TypeDefinition::class.java.simpleName} for type name $inputTypeName")
when (typeDefinition) {
is ScalarTypeDefinition -> handleFoundScalarType(typeDefinition)
is InputObjectTypeDefinition -> {
for (input in typeDefinition.inputValueDefinitions) {
handleDirectiveInput(input.type)
}
is EnumTypeDefinition -> handleDictionaryTypes(listOf(typeDefinition)) {
"Enum type '${it.name}' is used in a directive, but no class could be found for that type name. Please pass a class for type '${it.name}' in the parser's dictionary."
}
is InputObjectTypeDefinition -> handleDictionaryTypes(listOf(typeDefinition)) {
"Input object type '${it.name}' is used in a directive, but no class could be found for that type name. Please pass a class for type '${it.name}' in the parser's dictionary."
}
}
}
Expand Down Expand Up @@ -209,9 +210,9 @@ internal class SchemaClassScanner(
log.warn("Schema type was defined but can never be accessed, and can be safely deleted: ${definition.name}")
}

val fieldResolvers = fieldResolversByType.flatMap { it.value.map { it.value } }
val observedNormalResolverInfos = fieldResolvers.map { it.resolverInfo }.distinct().filterIsInstance<NormalResolverInfo>()
val observedMultiResolverInfos = fieldResolvers.map { it.resolverInfo }.distinct().filterIsInstance<MultiResolverInfo>().flatMap { it.resolverInfoList }
val fieldResolvers = fieldResolversByType.flatMap { entry -> entry.value.map { it.value } }
val observedNormalResolverInfos = fieldResolvers.map { it.resolverInfo }.filterIsInstance<NormalResolverInfo>().toSet()
val observedMultiResolverInfos = fieldResolvers.map { it.resolverInfo }.filterIsInstance<MultiResolverInfo>().flatMap { it.resolverInfoList }.toSet()

(resolverInfos - observedNormalResolverInfos - observedMultiResolverInfos).forEach { resolverInfo ->
log.warn("Resolver was provided but no methods on it were used in data fetchers, and can be safely deleted: ${resolverInfo.resolver}")
Expand Down Expand Up @@ -255,7 +256,7 @@ internal class SchemaClassScanner(
}.flatten().distinct()
}

private fun handleDictionaryTypes(types: List<ObjectTypeDefinition>, failureMessage: (ObjectTypeDefinition) -> String) {
private fun handleDictionaryTypes(types: List<TypeDefinition<*>>, failureMessage: (TypeDefinition<*>) -> String) {
types.forEach { type ->
val dictionaryContainsType = dictionary.filter { it.key.name == type.name }.isNotEmpty()
if (!unvalidatedTypes.contains(type) && !dictionaryContainsType) {
Expand Down
138 changes: 69 additions & 69 deletions src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package graphql.kickstart.tools

import graphql.Scalars
import graphql.introspection.Introspection
import graphql.introspection.Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION
import graphql.kickstart.tools.directive.DirectiveWiringHelper
Expand All @@ -9,6 +8,7 @@ import graphql.kickstart.tools.util.getExtendedFieldDefinitions
import graphql.kickstart.tools.util.unwrap
import graphql.language.*
import graphql.schema.*
import graphql.schema.idl.DirectiveInfo
import graphql.schema.idl.RuntimeWiring
import graphql.schema.idl.ScalarInfo
import graphql.schema.visibility.NoIntrospectionGraphqlFieldVisibility
Expand Down Expand Up @@ -60,6 +60,8 @@ class SchemaParser internal constructor(
private val codeRegistryBuilder = GraphQLCodeRegistry.newCodeRegistry()
private val directiveWiringHelper = DirectiveWiringHelper(options, runtimeWiring, codeRegistryBuilder, directiveDefinitions)

private lateinit var schemaDirectives : Set<GraphQLDirective>

/**
* Parses the given schema with respect to the given dictionary and returns GraphQL objects.
*/
Expand All @@ -72,6 +74,7 @@ class SchemaParser internal constructor(

// Create GraphQL objects
val inputObjects: MutableList<GraphQLInputObjectType> = mutableListOf()
createDirectives(inputObjects)
inputObjectDefinitions.forEach {
if (inputObjects.none { io -> io.name == it.name }) {
inputObjects.add(createInputObject(it, inputObjects, mutableSetOf()))
Expand All @@ -82,8 +85,6 @@ class SchemaParser internal constructor(
val unions = unionDefinitions.map { createUnionObject(it, objects) }
val enums = enumDefinitions.map { createEnumObject(it) }

val directives = directiveDefinitions.map { createDirective(it, inputObjects) }.toSet()

// Assign type resolver to interfaces now that we know all of the object types
interfaces.forEach { codeRegistryBuilder.typeResolver(it, InterfaceTypeResolver(dictionary.inverse(), it)) }
unions.forEach { codeRegistryBuilder.typeResolver(it, UnionTypeResolver(dictionary.inverse(), it)) }
Expand All @@ -103,7 +104,7 @@ class SchemaParser internal constructor(
val additionalObjects = objects.filter { o -> o != query && o != subscription && o != mutation }

val types = (additionalObjects.toSet() as Set<GraphQLType>) + inputObjects + enums + interfaces + unions
return SchemaObjects(query, mutation, subscription, types, directives, codeRegistryBuilder, rootInfo.getDescription())
return SchemaObjects(query, mutation, subscription, types, schemaDirectives, codeRegistryBuilder, rootInfo.getDescription())
}

/**
Expand Down Expand Up @@ -300,44 +301,75 @@ class SchemaParser internal constructor(
.name(definition.name)
.definition(definition)
.description(getDocumentation(definition, options))
.type(determineInputType(definition.type, inputObjects, setOf()))
.type(determineInputType(definition.type, inputObjects, mutableSetOf()))
.apply { getDeprecated(definition.directives)?.let { deprecate(it) } }
.apply { definition.defaultValue?.let { defaultValueLiteral(it) } }
.withAppliedDirectives(*buildAppliedDirectives(definition.directives))
.withDirectives(*buildDirectives(definition.directives, Introspection.DirectiveLocation.ARGUMENT_DEFINITION))
.build()
}

private fun createDirective(definition: DirectiveDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLDirective {
val locations = definition.directiveLocations.map { Introspection.DirectiveLocation.valueOf(it.name) }.toTypedArray()
private fun createDirectives(inputObjects: MutableList<GraphQLInputObjectType>) {
schemaDirectives = directiveDefinitions.map { definition ->
val locations = definition.directiveLocations.map { Introspection.DirectiveLocation.valueOf(it.name) }.toTypedArray()

GraphQLDirective.newDirective()
.name(definition.name)
.description(getDocumentation(definition, options))
.definition(definition)
.comparatorRegistry(runtimeWiring.comparatorRegistry)
.validLocations(*locations)
.repeatable(definition.isRepeatable)
.apply {
definition.inputValueDefinitions.forEach { argumentDefinition ->
argument(createDirectiveArgument(argumentDefinition, inputObjects))
}
}
.build()
}.toSet()
// because the arguments can have directives too, we attach them only after the directives themselves are created
schemaDirectives = schemaDirectives.map { d ->
val arguments = d.arguments.map { a -> a.transform {
it.withAppliedDirectives(*buildAppliedDirectives(a.definition!!.directives))
.withDirectives(*buildDirectives(a.definition!!.directives, Introspection.DirectiveLocation.OBJECT))
} }
d.transform { it.replaceArguments(arguments) }
}.toSet()
}

return GraphQLDirective.newDirective()
private fun createDirectiveArgument(definition: InputValueDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLArgument {
return GraphQLArgument.newArgument()
.name(definition.name)
.description(getDocumentation(definition, options))
.definition(definition)
.comparatorRegistry(runtimeWiring.comparatorRegistry)
.validLocations(*locations)
.repeatable(definition.isRepeatable)
.apply {
definition.inputValueDefinitions.forEach { argumentDefinition ->
argument(createArgument(argumentDefinition, inputObjects))
}
}
.description(getDocumentation(definition, options))
.type(determineInputType(definition.type, inputObjects, mutableSetOf()))
.apply { getDeprecated(definition.directives)?.let { deprecate(it) } }
.apply { definition.defaultValue?.let { defaultValueLiteral(it) } }
.build()
}

private fun buildAppliedDirectives(directives: List<Directive>): Array<GraphQLAppliedDirective> {
return directives.map {
return directives.map { directive ->
val graphQLDirective = schemaDirectives.find { d -> d.name == directive.name }
?: DirectiveInfo.GRAPHQL_SPECIFICATION_DIRECTIVE_MAP[directive.name]
?: throw SchemaError("Found applied directive ${directive.name} without corresponding directive definition.")
val graphQLArguments = graphQLDirective.arguments.associateBy { it.name }

GraphQLAppliedDirective.newDirective()
.name(it.name)
.description(getDocumentation(it, options))
.name(directive.name)
.description(getDocumentation(directive, options))
.definition(directive)
.comparatorRegistry(runtimeWiring.comparatorRegistry)
.apply {
it.arguments.forEach { arg ->
directive.arguments.forEach { arg ->
val graphQLArgument = graphQLArguments[arg.name]
?: throw SchemaError("Found an unexpected directive argument ${directive.name}#${arg.name} .")
argument(GraphQLAppliedDirectiveArgument.newArgument()
.name(arg.name)
.type(buildDirectiveInputType(arg.value))
// TODO instead of guessing the type from its value, lookup the directive definition
.type(graphQLArgument.type)
.valueLiteral(arg.value)
.description(graphQLArgument.description)
.build()
)
}
Expand All @@ -358,6 +390,10 @@ class SchemaParser internal constructor(
val repeatable = directiveDefinitions.find { it.name.equals(directive.name) }?.isRepeatable ?: false
if (repeatable || !names.contains(directive.name)) {
names.add(directive.name)
val graphQLDirective = this.schemaDirectives.find { d -> d.name == directive.name }
?: DirectiveInfo.GRAPHQL_SPECIFICATION_DIRECTIVE_MAP[directive.name]
?: throw SchemaError("Found applied directive ${directive.name} without corresponding directive definition.")
val graphQLArguments = graphQLDirective.arguments.associateBy { it.name }
output.add(
GraphQLDirective.newDirective()
.name(directive.name)
Expand All @@ -367,9 +403,11 @@ class SchemaParser internal constructor(
.repeatable(repeatable)
.apply {
directive.arguments.forEach { arg ->
val graphQLArgument = graphQLArguments[arg.name]
?: throw SchemaError("Found an unexpected directive argument ${directive.name}#${arg.name}.")
argument(GraphQLArgument.newArgument()
.name(arg.name)
.type(buildDirectiveInputType(arg.value))
.type(graphQLArgument.type)
// TODO remove this once directives are fully replaced with applied directives
.valueLiteral(arg.value)
.build())
Expand All @@ -383,46 +421,6 @@ class SchemaParser internal constructor(
return output.toTypedArray()
}

private fun buildDirectiveInputType(value: Value<*>): GraphQLInputType? {
return when (value) {
is NullValue -> Scalars.GraphQLString
is FloatValue -> Scalars.GraphQLFloat
is StringValue -> Scalars.GraphQLString
is IntValue -> Scalars.GraphQLInt
is BooleanValue -> Scalars.GraphQLBoolean
is ArrayValue -> GraphQLList.list(buildDirectiveInputType(getArrayValueWrappedType(value)))
// TODO to implement this we'll need to "observe" directive's input types + match them here based on their fields(?)
else -> throw SchemaError("Directive values of type '${value::class.simpleName}' are not supported yet.")
}
}

private fun getArrayValueWrappedType(value: ArrayValue): Value<*> {
// empty array [] is equivalent to [null]
if (value.values.isEmpty()) {
return NullValue.newNullValue().build()
}

// get rid of null values
val nonNullValueList = value.values.filter { v -> v !is NullValue }

// [null, null, ...] unwrapped is null
if (nonNullValueList.isEmpty()) {
return NullValue.newNullValue().build()
}

// make sure the array isn't polymorphic
val distinctTypes = nonNullValueList
.map { it::class.java }
.distinct()

if (distinctTypes.size > 1) {
throw SchemaError("Arrays containing multiple types of values are not supported yet.")
}

// peek at first value, value exists and is assured to be non-null
return nonNullValueList[0]
}

private fun determineOutputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>) =
determineType(GraphQLOutputType::class, typeDefinition, permittedTypesForObject, inputObjects) as GraphQLOutputType

Expand Down Expand Up @@ -455,13 +453,15 @@ class SchemaParser internal constructor(
else -> throw SchemaError("Unknown type: $typeDefinition")
}

private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>, referencingInputObjects: Set<String>) =
private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>, referencingInputObjects: MutableSet<String>) =
determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects, referencingInputObjects)

private fun <T : Any> determineInputType(expectedType: KClass<T>,
typeDefinition: Type<*>, allowedTypeReferences: Set<String>,
inputObjects: List<GraphQLInputObjectType>,
referencingInputObjects: Set<String>): GraphQLInputType =
private fun <T : Any> determineInputType(
expectedType: KClass<T>,
typeDefinition: Type<*>,
allowedTypeReferences: Set<String>,
inputObjects: List<GraphQLInputObjectType>,
referencingInputObjects: MutableSet<String>): GraphQLInputType =
when (typeDefinition) {
is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
Expand Down Expand Up @@ -489,7 +489,7 @@ class SchemaParser internal constructor(
if (referencingInputObject != null) {
GraphQLTypeReference(referencingInputObject)
} else {
val inputObject = createInputObject(filteredDefinitions[0], inputObjects, referencingInputObjects as MutableSet<String>)
val inputObject = createInputObject(filteredDefinitions[0], inputObjects, referencingInputObjects)
(inputObjects as MutableList).add(inputObject)
inputObject
}
Expand Down
Loading

0 comments on commit ed2b48b

Please sign in to comment.