Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix coercion of typedef primitives and their pointers #489

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 99 additions & 8 deletions scalars.go
Original file line number Diff line number Diff line change
@@ -3,12 +3,79 @@ package graphql
import (
"fmt"
"math"
"reflect"
"strconv"
"time"

"github.com/graphql-go/graphql/language/ast"
)

func unwrapInt(value interface{}) (interface{}, bool) {
r := reflect.Indirect(reflect.ValueOf(value))
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
return nil, false
}

switch r.Kind() {
case reflect.Int:
return int(r.Int()), true
case reflect.Int8:
return int8(r.Int()), true
case reflect.Int16:
return int16(r.Int()), true
case reflect.Int32:
return int32(r.Int()), true
case reflect.Int64:
return r.Int(), true
default:
return nil, false
}
}

func unwrapFloat(value interface{}) (interface{}, bool) {
r := reflect.Indirect(reflect.ValueOf(value))
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
return nil, false
}

switch r.Kind() {
case reflect.Float32:
return float32(r.Float()), true
case reflect.Float64:
return r.Float(), true
default:
return nil, false
}
}

func unwrapBool(value interface{}) (interface{}, bool) {
r := reflect.Indirect(reflect.ValueOf(value))
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
return nil, false
}

switch r.Kind() {
case reflect.Bool:
return r.Bool(), true
default:
return nil, false
}
}

func unwrapString(value interface{}) (interface{}, bool) {
r := reflect.Indirect(reflect.ValueOf(value))
if !r.IsValid() || (r.Kind() == reflect.Ptr && r.IsNil()) {
return nil, false
}

switch r.Kind() {
case reflect.String:
return r.String(), true
default:
return nil, false
}
}

// As per the GraphQL Spec, Integers are only treated as valid when a valid
// 32-bit signed integer, providing the broadest support across platforms.
//
@@ -142,11 +209,14 @@ func coerceInt(value interface{}) interface{} {
return nil
}
return coerceInt(*value)
default:
if v, ok := unwrapInt(value); ok {
return coerceInt(v)
}
// If the value cannot be transformed into an int, return nil instead of '0'
// to denote 'no integer found'
return nil
}

// If the value cannot be transformed into an int, return nil instead of '0'
// to denote 'no integer found'
return nil
}

// Int is the GraphQL Integer type definition.
@@ -276,6 +346,10 @@ func coerceFloat(value interface{}) interface{} {
return coerceFloat(*value)
}

if v, ok := unwrapFloat(value); ok {
return coerceFloat(v)
}

// If the value cannot be transformed into an float, return nil instead of '0.0'
// to denote 'no float found'
return nil
@@ -305,13 +379,23 @@ var Float = NewScalar(ScalarConfig{
})

func coerceString(value interface{}) interface{} {
if v, ok := value.(*string); ok {
if v == nil {
switch t := value.(type) {
case *string:
if t == nil {
return nil
}
return *v
return *t
case string:
return t
default:
if v, ok := unwrapString(value); ok {
return coerceString(v)
}
if r := reflect.ValueOf(value); r.Kind() == reflect.Ptr && r.IsNil() {
return nil
}
return fmt.Sprintf("%v", value)
}
return fmt.Sprintf("%v", value)
}

// String is the GraphQL string type definition
@@ -472,6 +556,13 @@ func coerceBool(value interface{}) interface{} {
}
return coerceBool(*value)
}

if v, ok := unwrapBool(value); ok {
return coerceBool(v)
}
if r := reflect.ValueOf(value); r.Kind() == reflect.Ptr && r.IsNil() {
return nil
}
return false
}

95 changes: 92 additions & 3 deletions scalars_test.go
Original file line number Diff line number Diff line change
@@ -5,6 +5,50 @@ import (
"testing"
)

type (
myInt int
myString string
myBool bool
myFloat32 float32
)

func TestCoerceString(t *testing.T) {
tests := []struct {
in interface{}
want interface{}
}{
{
in: "hello",
want: "hello",
},
{
in: func() interface{} { s := "hello"; return &s }(),
want: "hello",
},
// Typedef
{
in: myString("hello"),
want: "hello",
},
// Typedef with pointer
{
in: func() interface{} { v := myString("hello"); return &v }(),
want: "hello",
},
// Typedef with nil pointer
{
in: (*myString)(nil),
want: nil,
},
}

for i, tt := range tests {
if got, want := coerceString(tt.in), tt.want; got != want {
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
}
}
}

func TestCoerceInt(t *testing.T) {
tests := []struct {
in interface{}
@@ -240,11 +284,26 @@ func TestCoerceInt(t *testing.T) {
in: make(map[string]interface{}),
want: nil,
},
// Typedef
{
in: myInt(42),
want: int(42),
},
// Typedef with pointer
{
in: func() interface{} { v := myInt(42); return &v }(),
want: int(42),
},
// Typedef with nil pointer
{
in: (*myInt)(nil),
want: nil,
},
}

for i, tt := range tests {
if got, want := coerceInt(tt.in), tt.want; got != want {
t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want)
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
}
}
}
@@ -438,11 +497,26 @@ func TestCoerceFloat(t *testing.T) {
in: make(map[string]interface{}),
want: nil,
},
// Typedef
{
in: myFloat32(3.14),
want: float32(3.14),
},
// Typedef with pointer
{
in: func() interface{} { v := myFloat32(3.14); return &v }(),
want: float32(3.14),
},
// Typedef with nil pointer
{
in: (*myFloat32)(nil),
want: nil,
},
}

for i, tt := range tests {
if got, want := coerceFloat(tt.in), tt.want; got != want {
t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want)
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
}
}
}
@@ -740,11 +814,26 @@ func TestCoerceBool(t *testing.T) {
in: make(map[string]interface{}),
want: false,
},
// Typedef
{
in: myBool(true),
want: true,
},
// Typedef with pointer
{
in: func() interface{} { v := myBool(true); return &v }(),
want: true,
},
// Typedef with nil pointer
{
in: (*myBool)(nil),
want: nil,
},
}

for i, tt := range tests {
if got, want := coerceBool(tt.in), tt.want; got != want {
t.Errorf("%d: in=%v, got=%v, want=%v", i, tt.in, got, want)
t.Errorf("%d: in=%#v, got=%#v, want=%#v", i, tt.in, got, want)
}
}
}