Skip to content

Commit

Permalink
Merge pull request #56 from tartiflette/ISSUE-55
Browse files Browse the repository at this point in the history
ISSUE-55 - Provides a context factory parameter
  • Loading branch information
Maximilien-R authored Oct 3, 2019
2 parents ddaf215 + ed6e739 commit c76fda1
Show file tree
Hide file tree
Showing 10 changed files with 116 additions and 37 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Released]

- [1.x.x]
- [1.1.x]
- [1.1.0](./changelogs/1.1.0.md) - 2019-10-02
- [1.0.x]
- [1.0.0](./changelogs/1.0.0.md) - 2019-09-12
- [0.x.x]
Expand Down
26 changes: 26 additions & 0 deletions changelogs/1.1.0.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# [1.1.0] -- 2019-10-02

## Added

- [ISSUE-55](https://github.com/tartiflette/tartiflette-aiohttp/issues/55) - Add
a new optional `context_factory` parameter to the `register_graphql_handlers`
function. This parameter can take a coroutine function which will be called on
each request with the following signature:
```python
async def context_factory(
context: Dict[str, Any], req: "aiohttp.web.Request"
) -> Dict[str, Any]:
"""
Generates a new context.
:param context: the value filled in through the `executor_context`
parameter
:param req: the incoming aiohttp request instance
:type context: Dict[str, Any]
:type req: aiohttp.web.Request
:return: the context for the incoming request
:rtype: Dict[str, Any]
"""
```

The aim of this function will be to returns the context which will be forwarded
to the Tartiflette engine on the `execute` or `subscribe` method.
4 changes: 2 additions & 2 deletions changelogs/next.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# [next]
# [Next]

## Added

## Changed

## Fixed
## Fixed
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"isort==4.3.21",
]

_VERSION = "1.0.0"
_VERSION = "1.1.0"

_PACKAGES = find_packages(exclude=["tests*"])

Expand Down
26 changes: 20 additions & 6 deletions tartiflette_aiohttp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json

from functools import partial
from inspect import iscoroutine
from typing import Any, Dict, List, Optional, Union
from inspect import iscoroutine, iscoroutinefunction
from typing import Any, Callable, Dict, List, Optional, Union

from tartiflette import Engine
from tartiflette_aiohttp._context_factory import default_context_factory
from tartiflette_aiohttp._graphiql import graphiql_handler
from tartiflette_aiohttp._handler import Handlers
from tartiflette_aiohttp._subscription_ws_handler import (
Expand Down Expand Up @@ -35,15 +36,15 @@ def validate_and_compute_graphiql_option(
def _set_subscription_ws_handler(
app: "Application",
subscription_ws_endpoint: Optional[str],
context: Dict[str, Any],
context_factory: Callable,
) -> None:
if not subscription_ws_endpoint:
return

app.router.add_route(
"GET",
subscription_ws_endpoint,
AIOHTTPSubscriptionHandler(app, context),
AIOHTTPSubscriptionHandler(app, context_factory),
)


Expand Down Expand Up @@ -116,6 +117,7 @@ def register_graphql_handlers(
engine_modules: Optional[
List[Union[str, Dict[str, Union[str, Dict[str, str]]]]]
] = None,
context_factory: Optional[Callable] = None,
) -> "Application":
"""Register a Tartiflette Engine to an app
Expand All @@ -133,10 +135,12 @@ def register_graphql_handlers(
graphiql_enabled {bool} -- Determines whether or not we should handle a GraphiQL endpoint (default: {False})
graphiql_options {dict} -- Customization options for the GraphiQL instance (default: {None})
engine_modules: {Optional[List[Union[str, Dict[str, Union[str, Dict[str, str]]]]]]} -- Module to import (default:{None})
context_factory: {Optional[Callable]} -- coroutine function in charge of generating the context for each request (default: {None})
Raises:
Exception -- On bad sdl/engine parameter combinaison.
Exception -- On unsupported HTTP Method.
Exception -- if `context_factory` is filled in without a coroutine function.
Return:
The app object.
Expand All @@ -150,6 +154,16 @@ def register_graphql_handlers(
if not executor_http_methods:
executor_http_methods = ["GET", "POST"]

if context_factory is None:
context_factory = default_context_factory

if not iscoroutinefunction(context_factory):
raise Exception(
"`context_factory` parameter should be a coroutine function."
)

context_factory = partial(context_factory, executor_context)

if not engine:
engine = Engine()

Expand All @@ -174,14 +188,14 @@ def register_graphql_handlers(
executor_http_endpoint,
partial(
getattr(Handlers, "handle_%s" % method.lower()),
executor_context,
context_factory=context_factory,
),
)
except AttributeError:
raise Exception("Unsupported < %s > http method" % method)

_set_subscription_ws_handler(
app, subscription_ws_endpoint, executor_context
app, subscription_ws_endpoint, context_factory
)

_set_graphiql_handler(
Expand Down
19 changes: 19 additions & 0 deletions tartiflette_aiohttp/_context_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Any, Dict

__all__ = ("default_context_factory",)


async def default_context_factory(
context: Dict[str, Any], req: "aiohttp.web.Request"
) -> Dict[str, Any]:
"""
Generates a new context.
:param context: the value filled in through the `executor_context`
parameter
:param req: the incoming aiohttp request instance
:type context: Dict[str, Any]
:type req: aiohttp.web.Request
:return: the context for the incoming request
:rtype: Dict[str, Any]
"""
return {**context, "req": req}
25 changes: 13 additions & 12 deletions tartiflette_aiohttp/_handler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import json
import logging

from copy import copy

from aiohttp import web

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -34,8 +32,11 @@ def prepare_response(data):
return web.json_response(data, headers=headers, dumps=json.dumps)


async def _handle_query(req, query, query_vars, operation_name, context):
context = copy(context)
async def _handle_query(
req, query, query_vars, operation_name, context_factory
):
context = await context_factory(req)

try:
if not operation_name:
operation_name = None
Expand Down Expand Up @@ -94,23 +95,23 @@ async def _post_params(req):

class Handlers:
@staticmethod
async def _handle(param_func, user_c, req):
user_c["req"] = req

async def _handle(param_func, req, context_factory):
try:
qry, qry_vars, oprn_name = await param_func(req)
return prepare_response(
await _handle_query(req, qry, qry_vars, oprn_name, user_c)
await _handle_query(
req, qry, qry_vars, oprn_name, context_factory
)
)
except BadRequestError as e:
return prepare_response(
{"data": None, "errors": _format_errors([e])}
)

@staticmethod
async def handle_get(user_context, req):
return await Handlers._handle(_get_params, user_context, req)
async def handle_get(req, context_factory):
return await Handlers._handle(_get_params, req, context_factory)

@staticmethod
async def handle_post(user_context, req):
return await Handlers._handle(_post_params, user_context, req)
async def handle_post(req, context_factory):
return await Handlers._handle(_post_params, req, context_factory)
13 changes: 7 additions & 6 deletions tartiflette_aiohttp/_subscription_ws_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json

from asyncio import ensure_future, shield, wait
from typing import Any, AsyncIterator, Dict, Optional, Set
from typing import Any, AsyncIterator, Callable, Dict, Optional, Set

from aiohttp import WSMsgType, web

Expand Down Expand Up @@ -79,9 +79,11 @@ async def close(self, code: int) -> None:


class AIOHTTPSubscriptionHandler:
def __init__(self, app: "Application", context: Dict[str, Any]) -> None:
def __init__(self, app: "Application", context_factory: Callable) -> None:
self._app: "Application" = app
self._context = context
self._context_factory = context_factory
self._socket: Optional["web.WebSocketResponse"] = None
self._context: Optional[Dict[str, Any]] = None

async def _send_message(
self,
Expand Down Expand Up @@ -255,9 +257,8 @@ async def _handle_request(self) -> None:
await self._on_close(connection_context, tasks)

async def __call__(self, request: "Request") -> "WebSocketResponse":
self._socket = web.WebSocketResponse( # pylint: disable=attribute-defined-outside-init
protocols=(WS_PROTOCOL,)
)
self._socket = web.WebSocketResponse(protocols=(WS_PROTOCOL,))
self._context = await self._context_factory(request)
await self._socket.prepare(request)
await shield(self._handle_request())
return self._socket
15 changes: 9 additions & 6 deletions tests/integration/test_handlers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from functools import partial
from unittest.mock import Mock

import pytest

from tartiflette_aiohttp import default_context_factory


@pytest.mark.asyncio
async def test_handler__handle_query__context_unicity():
Expand Down Expand Up @@ -31,18 +34,18 @@ async def resolver_hello(parent, args, ctx, info):
a_req = Mock()
a_req.app = {"ttftt_engine": tftt_engine}

ctx = {}
context_factory = partial(default_context_factory, {})

await _handle_query(
a_req, 'query { hello(name: "Chuck") }', None, None, ctx
a_req, 'query { hello(name: "Chuck") }', None, None, context_factory
)

await _handle_query(
a_req, 'query { hello(name: "Chuck") }', None, None, ctx
a_req, 'query { hello(name: "Chuck") }', None, None, context_factory
)

b_response = await _handle_query(
a_req, 'query { hello(name: "Chuck") }', None, None, ctx
a_req, 'query { hello(name: "Chuck") }', None, None, context_factory
)

assert b_response == {"data": {"hello": "hello 1"}}
Expand Down Expand Up @@ -71,7 +74,7 @@ async def resolver_hello(parent, args, ctx, info):
a_req = Mock()
a_req.app = {"ttftt_engine": tftt_engine}

ctx = {}
context_factory = partial(default_context_factory, {})

result = await _handle_query(
a_req,
Expand All @@ -82,7 +85,7 @@ async def resolver_hello(parent, args, ctx, info):
""",
None,
"B",
ctx,
context_factory,
)

assert result == {"data": {"hello": "hello Bar"}}
21 changes: 17 additions & 4 deletions tests/unit/test_handlers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from functools import partial
from unittest.mock import Mock

import pytest

from asynctest import CoroutineMock

from tartiflette_aiohttp import default_context_factory


@pytest.mark.parametrize(
"value,expected",
Expand Down Expand Up @@ -32,7 +35,11 @@ async def test_handler__handle_query():
a_req.app = {"ttftt_engine": an_engine}

a_response = await _handle_query(
a_req, "query a {}", {"B": "C"}, "a", {"D": "E"}
a_req,
"query a {}",
{"B": "C"},
"a",
partial(default_context_factory, {"D": "E"}),
)

assert a_response == "T"
Expand All @@ -42,7 +49,7 @@ async def test_handler__handle_query():
{
"query": "query a {}",
"variables": {"B": "C"},
"context": {"D": "E"},
"context": {"D": "E", "req": a_req},
"operation_name": "a",
},
)
Expand All @@ -60,7 +67,11 @@ async def test_handler__handle_query_nok():
a_req.app = {}

a_response = await _handle_query(
a_req, "query a {}", {"B": "C"}, "a", {"D": "E"}
a_req,
"query a {}",
{"B": "C"},
"a",
partial(default_context_factory, {"D": "E"}),
)

assert a_response == {
Expand Down Expand Up @@ -179,7 +190,9 @@ async def test_handler__handle():

a_method = CoroutineMock(return_value=("a", "b", "c"))

await Handlers._handle(a_method, {}, a_req)
await Handlers._handle(
a_method, a_req, partial(default_context_factory, {})
)

assert a_method.call_args_list == [((a_req,),)]

Expand Down

0 comments on commit c76fda1

Please sign in to comment.