diff --git a/graphql.go b/graphql.go index eff0b9a3..18c9f913 100644 --- a/graphql.go +++ b/graphql.go @@ -72,13 +72,13 @@ type Schema struct { useStringDescriptions bool disableIntrospection bool - extendRes interface{} + extendRes map[string]interface{} } // SchemaOpt is an option to pass to ParseSchema or MustParseSchema. type SchemaOpt func(*Schema) -func UseExtendResolver(ext interface{}) SchemaOpt { +func UseExtendResolver(ext map[string]interface{}) SchemaOpt { return func(s *Schema) { s.extendRes = ext } diff --git a/graphql_test.go b/graphql_test.go index cb0cab4a..37a249fe 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -163,6 +163,48 @@ func (r *discussPlanResolver) DismissVader(ctx context.Context) (string, error) return "", errors.New("I find your lack of faith disturbing") } +type ByeWorldResolver1 struct { + helloWorldResolver1 +} + +func (b *ByeWorldResolver1) Bye() string { + return "Bye world!" +} + +func TestExtendResolver(t *testing.T) { + t.Parallel() + + extResolvers := make(map[string]interface{}) + extResolvers["Query"] = &ByeWorldResolver1{} + + gqltesting.RunTests(t, []*gqltesting.Test{ + { + + Schema: graphql.MustParseSchema(` + schema { + query: Query + } + + type Query { + hello: String! + bye: String! + }`, &helloWorldResolver1{}, graphql.UseExtendResolver(extResolvers)), + Query: ` + { + hello + bye + } + `, + ExpectedResult: ` + { + "hello": "Hello world!", + "bye": "Bye world!" + } + `, + }, + }) +} + func TestHelloWorld(t *testing.T) { t.Parallel() @@ -2742,7 +2784,9 @@ func (r *inputResolver) Nullable(args struct{ Value *int32 }) *int32 { return args.Value } -func (r *inputResolver) List(args struct{ Value []*struct{ V int32 } }) []int32 { +func (r *inputResolver) List(args struct { + Value []*struct{ V int32 } +}) []int32 { l := make([]int32, len(args.Value)) for i, entry := range args.Value { l[i] = entry.V @@ -2750,7 +2794,9 @@ func (r *inputResolver) List(args struct{ Value []*struct{ V int32 } }) []int32 return l } -func (r *inputResolver) NullableList(args struct{ Value *[]*struct{ V int32 } }) *[]*int32 { +func (r *inputResolver) NullableList(args struct { + Value *[]*struct{ V int32 } +}) *[]*int32 { if args.Value == nil { return nil } @@ -2926,7 +2972,7 @@ func TestInput(t *testing.T) { }) } -type inputArgumentsHello struct {} +type inputArgumentsHello struct{} type inputArgumentsScalarMismatch1 struct{} @@ -2946,7 +2992,7 @@ type helloInputMismatch struct { World string } -func (r *inputArgumentsHello) Hello(args struct { Input *helloInput }) string { +func (r *inputArgumentsHello) Hello(args struct{ Input *helloInput }) string { return "Hello " + args.Input.Name + "!" } @@ -2954,7 +3000,7 @@ func (r *inputArgumentsScalarMismatch1) Hello(name string) string { return "Hello " + name + "!" } -func (r *inputArgumentsScalarMismatch2) Hello(args struct { World string }) string { +func (r *inputArgumentsScalarMismatch2) Hello(args struct{ World string }) string { return "Hello " + args.World + "!" } @@ -2962,11 +3008,11 @@ func (r *inputArgumentsObjectMismatch1) Hello(in helloInput) string { return "Hello " + in.Name + "!" } -func (r *inputArgumentsObjectMismatch2) Hello(args struct { Input *helloInputMismatch }) string { +func (r *inputArgumentsObjectMismatch2) Hello(args struct{ Input *helloInputMismatch }) string { return "Hello " + args.Input.World + "!" } -func (r *inputArgumentsObjectMismatch3) Hello(args struct { Input *struct { Thing string } }) string { +func (r *inputArgumentsObjectMismatch3) Hello(args struct{ Input *struct{ Thing string } }) string { return "Hello " + args.Input.Thing + "!" } diff --git a/internal/exec/exec.go b/internal/exec/exec.go index 1504b5ce..bbf1ed49 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -204,8 +204,8 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f var callOut []reflect.Value - if f.field.TypeName == "statTeam" { - res = reflect.NewAt(s.ExtResolver.Elem().Type(), unsafe.Pointer(res.Elem().UnsafeAddr())) + if v, ok := s.ExtResolver[f.field.TypeName]; ok { + res = reflect.NewAt(v.Elem().Type(), unsafe.Pointer(res.Elem().UnsafeAddr())) } callOut = res.Method(f.field.MethodIndex).Call(in) diff --git a/internal/exec/resolvable/resolvable.go b/internal/exec/resolvable/resolvable.go index 6278d2c7..2e6f6e45 100644 --- a/internal/exec/resolvable/resolvable.go +++ b/internal/exec/resolvable/resolvable.go @@ -18,7 +18,7 @@ type Schema struct { Mutation Resolvable Subscription Resolvable Resolver reflect.Value - ExtResolver reflect.Value + ExtResolver map[string]reflect.Value } type Resolvable interface { @@ -62,7 +62,7 @@ func (*Object) isResolvable() {} func (*List) isResolvable() {} func (*Scalar) isResolvable() {} -func ApplyResolver(s *schema.Schema, resolver interface{}, ext interface{}) (*Schema, error) { +func ApplyResolver(s *schema.Schema, resolver interface{}, ext map[string]interface{}) (*Schema, error) { if resolver == nil { return &Schema{Meta: newMeta(s), Schema: *s}, nil } @@ -72,19 +72,19 @@ func ApplyResolver(s *schema.Schema, resolver interface{}, ext interface{}) (*Sc var query, mutation, subscription Resolvable if t, ok := s.EntryPoints["query"]; ok { - if err := b.assignExec(&query, t, reflect.TypeOf(resolver), reflect.TypeOf(ext)); err != nil { + if err := b.assignExec(&query, t, reflect.TypeOf(resolver), ext); err != nil { return nil, err } } if t, ok := s.EntryPoints["mutation"]; ok { - if err := b.assignExec(&mutation, t, reflect.TypeOf(resolver), reflect.TypeOf(ext)); err != nil { + if err := b.assignExec(&mutation, t, reflect.TypeOf(resolver), ext); err != nil { return nil, err } } if t, ok := s.EntryPoints["subscription"]; ok { - if err := b.assignExec(&subscription, t, reflect.TypeOf(resolver), reflect.TypeOf(ext)); err != nil { + if err := b.assignExec(&subscription, t, reflect.TypeOf(resolver), ext); err != nil { return nil, err } } @@ -93,11 +93,16 @@ func ApplyResolver(s *schema.Schema, resolver interface{}, ext interface{}) (*Sc return nil, err } + extResolvers := make(map[string]reflect.Value) + for i := range ext { + extResolvers[i] = reflect.ValueOf(ext[i]) + } + return &Schema{ Meta: newMeta(s), Schema: *s, Resolver: reflect.ValueOf(resolver), - ExtResolver: reflect.ValueOf(ext), + ExtResolver: extResolvers, Query: query, Mutation: mutation, Subscription: subscription, @@ -138,7 +143,7 @@ func (b *execBuilder) finish() error { return b.packerBuilder.Finish() } -func (b *execBuilder) assignExec(target *Resolvable, t common.Type, resolverType reflect.Type, ext reflect.Type) error { +func (b *execBuilder) assignExec(target *Resolvable, t common.Type, resolverType reflect.Type, ext map[string]interface{}) error { k := typePair{t, resolverType} ref, ok := b.resMap[k] if !ok { @@ -155,14 +160,16 @@ func (b *execBuilder) assignExec(target *Resolvable, t common.Type, resolverType return nil } -func (b *execBuilder) makeExec(t common.Type, resolverType reflect.Type, ext reflect.Type) (Resolvable, error) { +func (b *execBuilder) makeExec(t common.Type, resolverType reflect.Type, ext map[string]interface{}) (Resolvable, error) { var nonNull bool t, nonNull = unwrapNonNull(t) switch t := t.(type) { case *schema.Object: - if t.Name == "statTeam" && ext != nil { - resolverType = ext + if ext != nil { + if v, ok := ext[t.Name]; ok { + resolverType = reflect.TypeOf(v) + } } return b.makeObjectExec(t.Name, t.Fields, nil, nonNull, resolverType, ext) @@ -224,7 +231,7 @@ func makeScalarExec(t *schema.Scalar, resolverType reflect.Type) (Resolvable, er } func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, possibleTypes []*schema.Object, - nonNull bool, resolverType reflect.Type, ext reflect.Type) (*Object, error) { + nonNull bool, resolverType reflect.Type, ext map[string]interface{}) (*Object, error) { if !nonNull { if resolverType.Kind() != reflect.Ptr && resolverType.Kind() != reflect.Interface { return nil, fmt.Errorf("%s is not a pointer or interface", resolverType) @@ -301,7 +308,7 @@ var contextType = reflect.TypeOf((*context.Context)(nil)).Elem() var errorType = reflect.TypeOf((*error)(nil)).Elem() func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect.Method, sf reflect.StructField, - methodIndex int, fieldIndex []int, methodHasReceiver bool, ext reflect.Type) (*Field, error) { + methodIndex int, fieldIndex []int, methodHasReceiver bool, ext map[string]interface{}) (*Field, error) { var argsPacker *packer.StructPacker var hasError bool