From f4cb457c403e4aac1d76107b4d03f0aee3a57e10 Mon Sep 17 00:00:00 2001 From: Alexey Michurin Date: Fri, 5 Oct 2018 23:23:39 +0300 Subject: [PATCH] Add RequestCallbackFn --- handler.go | 46 +++++++++++++++++++++++++++------------------- handler_test.go | 22 ++++++++++++++++++---- 2 files changed, 45 insertions(+), 23 deletions(-) diff --git a/handler.go b/handler.go index 57a37b8..fd494aa 100644 --- a/handler.go +++ b/handler.go @@ -19,14 +19,16 @@ const ( ) type ResultCallbackFn func(ctx context.Context, params *graphql.Params, result *graphql.Result, responseBody []byte) +type RequestCallbackFn func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context type Handler struct { - Schema *graphql.Schema - pretty bool - graphiql bool - playground bool - rootObjectFn RootObjectFn - resultCallbackFn ResultCallbackFn + Schema *graphql.Schema + pretty bool + graphiql bool + playground bool + requestCallbackFn RequestCallbackFn + rootObjectFn RootObjectFn + resultCallbackFn ResultCallbackFn } type RequestOptions struct { Query string `json:"query" url:"query" schema:"query"` @@ -178,19 +180,24 @@ func (h *Handler) ContextHandler(ctx context.Context, w http.ResponseWriter, r * // ServeHTTP provides an entrypoint into executing graphQL queries. func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - h.ContextHandler(r.Context(), w, r) + ctx := r.Context() + if h.requestCallbackFn != nil { + ctx = h.requestCallbackFn(ctx, w, r) + } + h.ContextHandler(ctx, w, r) } // RootObjectFn allows a user to generate a RootObject per request type RootObjectFn func(ctx context.Context, r *http.Request) map[string]interface{} type Config struct { - Schema *graphql.Schema - Pretty bool - GraphiQL bool - Playground bool - RootObjectFn RootObjectFn - ResultCallbackFn ResultCallbackFn + Schema *graphql.Schema + Pretty bool + GraphiQL bool + Playground bool + RequestCallbackFn RequestCallbackFn + RootObjectFn RootObjectFn + ResultCallbackFn ResultCallbackFn } func NewConfig() *Config { @@ -211,11 +218,12 @@ func New(p *Config) *Handler { } return &Handler{ - Schema: p.Schema, - pretty: p.Pretty, - graphiql: p.GraphiQL, - playground: p.Playground, - rootObjectFn: p.RootObjectFn, - resultCallbackFn: p.ResultCallbackFn, + Schema: p.Schema, + pretty: p.Pretty, + graphiql: p.GraphiQL, + playground: p.Playground, + requestCallbackFn: p.RequestCallbackFn, + rootObjectFn: p.RootObjectFn, + resultCallbackFn: p.ResultCallbackFn, } } diff --git a/handler_test.go b/handler_test.go index b8f68b2..dd787cd 100644 --- a/handler_test.go +++ b/handler_test.go @@ -96,16 +96,24 @@ func TestHandler_BasicQuery_Pretty(t *testing.T) { queryString := `query=query HeroNameQuery { hero { name } }&operationName=HeroNameQuery` req, _ := http.NewRequest("GET", fmt.Sprintf("/graphql?%v", queryString), nil) - callbackCalled := false + requestCallbackCalled := false + resultCallbackCalled := false h := handler.New(&handler.Config{ Schema: &testutil.StarWarsSchema, Pretty: true, + RequestCallbackFn: func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { + requestCallbackCalled = true + w.Header().Add("X-Custom-Header", "header-value") + return context.WithValue(ctx, "custom", "value") + }, ResultCallbackFn: func(ctx context.Context, params *graphql.Params, result *graphql.Result, responseBody []byte) { - callbackCalled = true + resultCallbackCalled = true if params.OperationName != "HeroNameQuery" { t.Fatalf("OperationName passed to callback was not HeroNameQuery: %v", params.OperationName) } - + if ctx.Value("custom").(string) != "value" { + t.Fatalf("context was not feeled in RequestCallbackFn") + } if result.HasErrors() { t.Fatalf("unexpected graphql result errors") } @@ -115,10 +123,16 @@ func TestHandler_BasicQuery_Pretty(t *testing.T) { if resp.Code != http.StatusOK { t.Fatalf("unexpected server response %v", resp.Code) } + if resp.Header()["X-Custom-Header"][0] != "header-value" { + t.Fatalf("HTTP headers was not feeled in RequestCallbackFn") + } if !reflect.DeepEqual(result, expected) { t.Fatalf("wrong result, graphql result diff: %v", testutil.Diff(expected, result)) } - if !callbackCalled { + if !requestCallbackCalled { + t.Fatalf("RequestCallbackFn was not called when it should have been") + } + if !resultCallbackCalled { t.Fatalf("ResultCallbackFn was not called when it should have been") } }