Skip to content

Commit 4f6df14

Browse files
authored
Merge pull request #145 from microsoftgraph/bugfix/set-context-and-feature-usage
Bugfix/set context and feature usage
2 parents 0487b2c + 2a814c6 commit 4f6df14

7 files changed

+75
-44
lines changed

src/msgraph_core/graph_client_factory.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88

99
import httpx
1010
from kiota_http.kiota_client_factory import KiotaClientFactory
11-
from kiota_http.middleware import AsyncKiotaTransport
1211
from kiota_http.middleware.middleware import BaseMiddleware
1312

1413
from ._enums import APIVersion, NationalClouds
15-
from .middleware import GraphTelemetryHandler
14+
from .middleware import AsyncGraphTransport, GraphTelemetryHandler
1615

1716

1817
class GraphClientFactory(KiotaClientFactory):
@@ -40,9 +39,10 @@ def create_with_default_middleware(
4039
middleware, current_transport
4140
)
4241

43-
client._transport = AsyncKiotaTransport(
42+
client._transport = AsyncGraphTransport(
4443
transport=current_transport, pipeline=middleware_pipeline
4544
)
45+
client._transport.pipeline
4646
return client
4747

4848
@staticmethod
@@ -66,7 +66,7 @@ def create_with_custom_middleware(
6666
middleware, current_transport
6767
)
6868

69-
client._transport = AsyncKiotaTransport(
69+
client._transport = AsyncGraphTransport(
7070
transport=current_transport, pipeline=middleware_pipeline
7171
)
7272
return client

src/msgraph_core/middleware/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5+
from .async_graph_transport import AsyncGraphTransport
56
from .request_context import GraphRequestContext
67
from .telemetry import GraphTelemetryHandler
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import json
2+
3+
import httpx
4+
from kiota_http.middleware import MiddlewarePipeline, RedirectHandler, RetryHandler
5+
6+
from .._enums import FeatureUsageFlag
7+
from .request_context import GraphRequestContext
8+
9+
10+
class AsyncGraphTransport(httpx.AsyncBaseTransport):
11+
"""A custom transport for requests to the Microsoft Graph API
12+
"""
13+
14+
def __init__(self, transport: httpx.AsyncBaseTransport, pipeline: MiddlewarePipeline) -> None:
15+
self.transport = transport
16+
self.pipeline = pipeline
17+
18+
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
19+
if self.pipeline:
20+
self.set_request_context_and_feature_usage(request)
21+
response = await self.pipeline.send(request)
22+
return response
23+
24+
response = await self.transport.handle_async_request(request)
25+
return response
26+
27+
def set_request_context_and_feature_usage(self, request: httpx.Request) -> httpx.Request:
28+
29+
request_options = {}
30+
options = request.headers.get('request_options', None)
31+
if options:
32+
request_options = json.loads(options)
33+
34+
context = GraphRequestContext(request_options, request.headers)
35+
middleware = self.pipeline._first_middleware
36+
while middleware:
37+
if isinstance(middleware, RedirectHandler):
38+
context.feature_usage = FeatureUsageFlag.REDIRECT_HANDLER_ENABLED
39+
if isinstance(middleware, RetryHandler):
40+
context.feature_usage = FeatureUsageFlag.RETRY_HANDLER_ENABLED
41+
42+
middleware = middleware.next
43+
request.context = context #type: ignore
44+
return request

src/msgraph_core/middleware/telemetry.py

+4-25
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import platform
44

55
import httpx
6-
from kiota_http.middleware import AsyncKiotaTransport, BaseMiddleware, RedirectHandler, RetryHandler
6+
from kiota_http.middleware import BaseMiddleware
77
from urllib3.util import parse_url
88

99
from .._constants import SDK_VERSION
10-
from .._enums import FeatureUsageFlag, NationalClouds
10+
from .._enums import NationalClouds
11+
from .async_graph_transport import AsyncGraphTransport
1112
from .request_context import GraphRequestContext
1213

1314

@@ -20,10 +21,9 @@ class GraphTelemetryHandler(BaseMiddleware):
2021
the SDK team improve the developer experience.
2122
"""
2223

23-
async def send(self, request: GraphRequest, transport: AsyncKiotaTransport):
24+
async def send(self, request: GraphRequest, transport: AsyncGraphTransport):
2425
"""Adds telemetry headers and sends the http request.
2526
"""
26-
self.set_request_context_and_feature_usage(request, transport)
2727

2828
if self.is_graph_url(request.url):
2929
self._add_client_request_id_header(request)
@@ -34,27 +34,6 @@ async def send(self, request: GraphRequest, transport: AsyncKiotaTransport):
3434
response = await super().send(request, transport)
3535
return response
3636

37-
def set_request_context_and_feature_usage(
38-
self, request: GraphRequest, transport: AsyncKiotaTransport
39-
) -> GraphRequest:
40-
41-
request_options = {}
42-
options = request.headers.pop('request_options', None)
43-
if options:
44-
request_options = json.loads(options)
45-
46-
request.context = GraphRequestContext(request_options, request.headers)
47-
middleware = transport.pipeline._first_middleware
48-
while middleware:
49-
if isinstance(middleware, RedirectHandler):
50-
request.context.feature_usage = FeatureUsageFlag.REDIRECT_HANDLER_ENABLED
51-
if isinstance(middleware, RetryHandler):
52-
request.context.feature_usage = FeatureUsageFlag.RETRY_HANDLER_ENABLED
53-
54-
middleware = middleware.next
55-
56-
return request
57-
5837
def is_graph_url(self, url):
5938
"""Check if the request is made to a graph endpoint. We do not add telemetry headers to
6039
non-graph endpoints"""
+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import pytest
2+
from kiota_http.kiota_client_factory import KiotaClientFactory
3+
4+
from msgraph_core._enums import FeatureUsageFlag
5+
from msgraph_core.middleware import AsyncGraphTransport, GraphRequestContext
6+
7+
8+
def test_set_request_context_and_feature_usage(mock_request, mock_transport):
9+
middleware = KiotaClientFactory.get_default_middleware()
10+
pipeline = KiotaClientFactory.create_middleware_pipeline(middleware, mock_transport)
11+
transport = AsyncGraphTransport(mock_transport, pipeline)
12+
transport.set_request_context_and_feature_usage(mock_request)
13+
14+
assert hasattr(mock_request, 'context')
15+
assert isinstance(mock_request.context, GraphRequestContext)
16+
assert mock_request.context.feature_usage == hex(
17+
FeatureUsageFlag.RETRY_HANDLER_ENABLED | FeatureUsageFlag.REDIRECT_HANDLER_ENABLED
18+
)

tests/unit/test_graph_client_factory.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44
# ------------------------------------
55
import httpx
66
import pytest
7-
from kiota_http.middleware import AsyncKiotaTransport, MiddlewarePipeline, RedirectHandler
7+
from kiota_http.middleware import MiddlewarePipeline, RedirectHandler
88

99
from msgraph_core import APIVersion, GraphClientFactory, NationalClouds
10-
from msgraph_core.middleware.telemetry import GraphTelemetryHandler
10+
from msgraph_core.middleware import AsyncGraphTransport, GraphTelemetryHandler
1111

1212

1313
def test_create_with_default_middleware():
1414
"""Test creation of GraphClient using default middleware"""
1515
client = GraphClientFactory.create_with_default_middleware()
1616

1717
assert isinstance(client, httpx.AsyncClient)
18-
assert isinstance(client._transport, AsyncKiotaTransport)
18+
assert isinstance(client._transport, AsyncGraphTransport)
1919
pipeline = client._transport.pipeline
2020
assert isinstance(pipeline, MiddlewarePipeline)
2121
assert isinstance(pipeline._first_middleware, RedirectHandler)
@@ -30,7 +30,7 @@ def test_create_with_custom_middleware():
3030
client = GraphClientFactory.create_with_custom_middleware(middleware=middleware)
3131

3232
assert isinstance(client, httpx.AsyncClient)
33-
assert isinstance(client._transport, AsyncKiotaTransport)
33+
assert isinstance(client._transport, AsyncGraphTransport)
3434
pipeline = client._transport.pipeline
3535
assert isinstance(pipeline._first_middleware, GraphTelemetryHandler)
3636

tests/unit/test_graph_telemetry_handler.py

-11
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,11 @@
1010
import pytest
1111

1212
from msgraph_core import SDK_VERSION, APIVersion, NationalClouds
13-
from msgraph_core._enums import FeatureUsageFlag
1413
from msgraph_core.middleware import GraphRequestContext, GraphTelemetryHandler
1514

1615
BASE_URL = NationalClouds.Global + '/' + APIVersion.v1
1716

1817

19-
def test_set_request_context_and_feature_usage(mock_request, mock_transport):
20-
telemetry_handler = GraphTelemetryHandler()
21-
telemetry_handler.set_request_context_and_feature_usage(mock_request, mock_transport)
22-
23-
assert hasattr(mock_request, 'context')
24-
assert mock_request.context.feature_usage == hex(
25-
FeatureUsageFlag.RETRY_HANDLER_ENABLED | FeatureUsageFlag.REDIRECT_HANDLER_ENABLED
26-
)
27-
28-
2918
def test_is_graph_url(mock_graph_request):
3019
"""
3120
Test method that checks whether a request url is a graph endpoint

0 commit comments

Comments
 (0)