Skip to content

Commit 4b676e2

Browse files
author
Vasilis Tsiolkas
committed
Adds support for gRPC aio server that allows for async handlers
Signed-off-by: Vasilis Tsiolkas <[email protected]>
1 parent 5882d52 commit 4b676e2

File tree

9 files changed

+911
-0
lines changed

9 files changed

+911
-0
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
Copyright 2023 The Dapr Authors
5+
Licensed under the Apache License, Version 2.0 (the "License");
6+
you may not use this file except in compliance with the License.
7+
You may obtain a copy of the License at
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
"""
15+
16+
from dapr.clients.grpc._request import InvokeMethodRequest, BindingRequest, JobEvent
17+
from dapr.clients.grpc._response import InvokeMethodResponse, TopicEventResponse
18+
from dapr.clients.grpc._jobs import Job, FailurePolicy, DropFailurePolicy, ConstantFailurePolicy
19+
20+
from dapr.ext.grpc.aio.app import App, Rule # type:ignore
21+
22+
23+
__all__ = [
24+
'App',
25+
'Rule',
26+
'InvokeMethodRequest',
27+
'InvokeMethodResponse',
28+
'BindingRequest',
29+
'TopicEventResponse',
30+
'Job',
31+
'JobEvent',
32+
'FailurePolicy',
33+
'DropFailurePolicy',
34+
'ConstantFailurePolicy',
35+
]
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import grpc
2+
from typing import Callable, Optional
3+
4+
from dapr.proto import appcallback_service_v1
5+
from dapr.proto.runtime.v1.appcallback_pb2 import HealthCheckResponse
6+
7+
HealthCheckCallable = Optional[Callable[[], None]]
8+
9+
10+
class _AioHealthCheckServicer(appcallback_service_v1.AppCallbackHealthCheckServicer):
11+
"""The implementation of HealthCheck Server.
12+
13+
:class:`App` provides useful decorators to register method, topic, input bindings.
14+
"""
15+
16+
def __init__(self):
17+
self._health_check_cb: Optional[HealthCheckCallable] = None
18+
19+
def register_health_check(self, cb: HealthCheckCallable) -> None:
20+
if not cb:
21+
raise ValueError('health check callback must be defined')
22+
self._health_check_cb = cb
23+
24+
async def HealthCheck(self, request, context):
25+
"""Health check."""
26+
27+
if not self._health_check_cb:
28+
context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore
29+
context.set_details('Method not implemented!')
30+
raise NotImplementedError('Method not implemented!')
31+
await self._health_check_cb()
32+
return HealthCheckResponse()
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
Copyright 2023 The Dapr Authors
5+
Licensed under the Apache License, Version 2.0 (the "License");
6+
you may not use this file except in compliance with the License.
7+
You may obtain a copy of the License at
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
"""
15+
import grpc.aio
16+
17+
from cloudevents.sdk.event import v1 # type: ignore
18+
from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Union
19+
20+
from google.protobuf import empty_pb2
21+
from google.protobuf.message import Message as GrpcMessage
22+
from google.protobuf.struct_pb2 import Struct
23+
24+
from dapr.proto import appcallback_service_v1, common_v1, appcallback_v1
25+
from dapr.proto.runtime.v1.appcallback_pb2 import (
26+
TopicEventRequest,
27+
BindingEventRequest,
28+
JobEventRequest,
29+
)
30+
from dapr.proto.common.v1.common_pb2 import InvokeRequest
31+
from dapr.clients.base import DEFAULT_JSON_CONTENT_TYPE
32+
from dapr.clients.grpc._request import InvokeMethodRequest, BindingRequest, JobEvent
33+
from dapr.clients.grpc._response import InvokeMethodResponse, TopicEventResponse
34+
35+
InvokeMethodCallable = Callable[
36+
[InvokeMethodRequest], Awaitable[Union[str, bytes, InvokeMethodResponse]]
37+
]
38+
TopicSubscribeCallable = Callable[[v1.Event], Awaitable[Optional[TopicEventResponse]]]
39+
BindingCallable = Callable[[BindingRequest], Awaitable[None]]
40+
JobEventCallable = Callable[[JobEvent], Awaitable[None]]
41+
42+
DELIMITER = ':'
43+
44+
45+
class Rule:
46+
def __init__(self, match: str, priority: int) -> None:
47+
self.match = match
48+
self.priority = priority
49+
50+
51+
class _RegisteredSubscription:
52+
def __init__(
53+
self,
54+
subscription: appcallback_v1.TopicSubscription,
55+
rules: List[Tuple[int, appcallback_v1.TopicRule]],
56+
):
57+
self.subscription = subscription
58+
self.rules = rules
59+
60+
61+
class _AioCallbackServicer(
62+
appcallback_service_v1.AppCallbackServicer, appcallback_service_v1.AppCallbackAlphaServicer
63+
):
64+
"""The asyncio-native implementation of the AppCallback Server.
65+
66+
This internal class implements application server and provides helpers to register
67+
method, topic, and input bindings. It implements the routing handling logic to route
68+
mulitple methods, topics, and bindings.
69+
70+
:class:`App` provides useful decorators to register method, topic, input bindings.
71+
"""
72+
73+
def __init__(self):
74+
self._invoke_method_map: Dict[str, InvokeMethodCallable] = {}
75+
self._topic_map: Dict[str, TopicSubscribeCallable] = {}
76+
self._binding_map: Dict[str, BindingCallable] = {}
77+
self._job_event_map: Dict[str, JobEventCallable] = {}
78+
79+
self._registered_topics_map: Dict[str, _RegisteredSubscription] = {}
80+
self._registered_topics: List[appcallback_v1.TopicSubscription] = []
81+
self._registered_bindings: List[str] = []
82+
83+
def register_method(self, method: str, cb: InvokeMethodCallable) -> None:
84+
"""Registers method for service invocation."""
85+
if method in self._invoke_method_map:
86+
raise ValueError(f'{method} is already registered')
87+
self._invoke_method_map[method] = cb
88+
89+
def register_topic(
90+
self,
91+
pubsub_name: str,
92+
topic: str,
93+
cb: TopicSubscribeCallable,
94+
metadata: Optional[Dict[str, str]],
95+
dead_letter_topic: Optional[str] = None,
96+
rule: Optional[Rule] = None,
97+
disable_topic_validation: Optional[bool] = False,
98+
) -> None:
99+
"""Registers topic subscription for pubsub."""
100+
if not disable_topic_validation:
101+
topic_key = pubsub_name + DELIMITER + topic
102+
else:
103+
topic_key = pubsub_name
104+
pubsub_topic = topic_key + DELIMITER
105+
if rule is not None:
106+
path = getattr(cb, '__name__', rule.match)
107+
pubsub_topic = pubsub_topic + path
108+
if pubsub_topic in self._topic_map:
109+
raise ValueError(f'{topic} is already registered with {pubsub_name}')
110+
self._topic_map[pubsub_topic] = cb
111+
112+
registered_topic = self._registered_topics_map.get(topic_key)
113+
sub: appcallback_v1.TopicSubscription = appcallback_v1.TopicSubscription()
114+
rules: List[Tuple[int, appcallback_v1.TopicRule]] = []
115+
if not registered_topic:
116+
sub = appcallback_v1.TopicSubscription(
117+
pubsub_name=pubsub_name,
118+
topic=topic,
119+
metadata=metadata,
120+
routes=appcallback_v1.TopicRoutes(),
121+
)
122+
if dead_letter_topic:
123+
sub.dead_letter_topic = dead_letter_topic
124+
registered_topic = _RegisteredSubscription(sub, rules)
125+
self._registered_topics_map[topic_key] = registered_topic
126+
self._registered_topics.append(sub)
127+
128+
sub = registered_topic.subscription
129+
rules = registered_topic.rules
130+
131+
if rule:
132+
path = getattr(cb, '__name__', rule.match)
133+
rules.append((rule.priority, appcallback_v1.TopicRule(match=rule.match, path=path)))
134+
rules.sort(key=lambda x: x[0])
135+
rs = [rule for id, rule in rules]
136+
del sub.routes.rules[:]
137+
sub.routes.rules.extend(rs)
138+
139+
def register_binding(self, name: str, cb: BindingCallable) -> None:
140+
"""Registers input bindings."""
141+
if name in self._binding_map:
142+
raise ValueError(f'{name} is already registered')
143+
self._binding_map[name] = cb
144+
self._registered_bindings.append(name)
145+
146+
def register_job_event(self, name: str, cb: JobEventCallable) -> None:
147+
"""Registers job event handler.
148+
149+
Args:
150+
name (str): The name of the job to handle events for.
151+
cb (JobEventCallable): The callback function to handle job events.
152+
"""
153+
if name in self._job_event_map:
154+
raise ValueError(f'Job event handler for {name} is already registered')
155+
self._job_event_map[name] = cb
156+
157+
async def OnInvoke(self, request: InvokeRequest, context):
158+
"""Invokes service method with InvokeRequest."""
159+
if request.method not in self._invoke_method_map:
160+
context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore
161+
raise NotImplementedError(f'{request.method} method not implemented!')
162+
163+
req = InvokeMethodRequest(request.data, request.content_type)
164+
req.metadata = context.invocation_metadata()
165+
resp = await self._invoke_method_map[request.method](req)
166+
167+
if not resp:
168+
return common_v1.InvokeResponse()
169+
170+
resp_data = InvokeMethodResponse()
171+
if isinstance(resp, (bytes, str)):
172+
resp_data.set_data(resp)
173+
resp_data.content_type = DEFAULT_JSON_CONTENT_TYPE
174+
elif isinstance(resp, GrpcMessage):
175+
resp_data.set_data(resp)
176+
elif isinstance(resp, InvokeMethodResponse):
177+
resp_data = resp
178+
else:
179+
context.set_code(grpc.StatusCode.OUT_OF_RANGE)
180+
context.set_details(f'{type(resp)} is the invalid return type.')
181+
raise NotImplementedError(f'{request.method} method not implemented!')
182+
183+
if len(resp_data.get_headers()) > 0:
184+
context.send_initial_metadata(resp_data.get_headers())
185+
186+
content_type = ''
187+
if resp_data.content_type:
188+
content_type = resp_data.content_type
189+
190+
return common_v1.InvokeResponse(data=resp_data.proto, content_type=content_type)
191+
192+
async def ListTopicSubscriptions(self, request, context):
193+
"""Lists all topics subscribed by this app."""
194+
return appcallback_v1.ListTopicSubscriptionsResponse(subscriptions=self._registered_topics)
195+
196+
async def OnTopicEvent(self, request: TopicEventRequest, context):
197+
"""Subscribes events from Pubsub."""
198+
pubsub_topic = request.pubsub_name + DELIMITER + request.topic + DELIMITER + request.path
199+
no_validation_key = request.pubsub_name + DELIMITER + request.path
200+
201+
if pubsub_topic not in self._topic_map:
202+
if no_validation_key in self._topic_map:
203+
pubsub_topic = no_validation_key
204+
else:
205+
context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore
206+
raise NotImplementedError(f'topic {request.topic} is not implemented!')
207+
208+
customdata: Struct = request.extensions
209+
extensions = dict()
210+
for k, v in customdata.items():
211+
extensions[k] = v
212+
for k, v in context.invocation_metadata():
213+
extensions['_metadata_' + k] = v
214+
215+
event = v1.Event()
216+
event.SetEventType(request.type)
217+
event.SetEventID(request.id)
218+
event.SetSource(request.source)
219+
event.SetData(request.data)
220+
event.SetContentType(request.data_content_type)
221+
event.SetSubject(request.topic)
222+
event.SetExtensions(extensions)
223+
224+
response = await self._topic_map[pubsub_topic](event)
225+
if isinstance(response, TopicEventResponse):
226+
return appcallback_v1.TopicEventResponse(status=response.status.value)
227+
return empty_pb2.Empty()
228+
229+
async def ListInputBindings(self, request, context):
230+
"""Lists all input bindings subscribed by this app."""
231+
return appcallback_v1.ListInputBindingsResponse(bindings=self._registered_bindings)
232+
233+
async def OnBindingEvent(self, request: BindingEventRequest, context):
234+
"""Listens events from the input bindings
235+
User application can save the states or send the events to the output
236+
bindings optionally by returning BindingEventResponse.
237+
"""
238+
if request.name not in self._binding_map:
239+
context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore
240+
raise NotImplementedError(f'{request.name} binding not implemented!')
241+
242+
req = BindingRequest(request.data, dict(request.metadata))
243+
req.metadata = context.invocation_metadata()
244+
await self._binding_map[request.name](req)
245+
246+
# TODO: support output bindings options
247+
return appcallback_v1.BindingEventResponse()
248+
249+
async def OnJobEventAlpha1(self, request: JobEventRequest, context):
250+
"""Handles job events from Dapr runtime.
251+
252+
This method is called by Dapr when a scheduled job is triggered.
253+
It routes the job event to the appropriate registered handler based on the job name.
254+
255+
Args:
256+
request (JobEventRequest): The job event request from Dapr.
257+
context: The gRPC context.
258+
259+
Returns:
260+
appcallback_v1.JobEventResponse: Empty response indicating successful handling.
261+
"""
262+
job_name = request.name
263+
264+
if job_name not in self._job_event_map:
265+
context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore
266+
raise NotImplementedError(f'Job event handler for {job_name} not implemented!')
267+
268+
# Create a JobEvent object matching Go SDK's common.JobEvent
269+
# Extract raw data bytes from the Any proto (matching Go implementation)
270+
data_bytes = b''
271+
if request.HasField('data') and request.data.value:
272+
data_bytes = request.data.value
273+
274+
job_event = JobEvent(name=request.name, data=data_bytes)
275+
276+
# Call the registered handler with the JobEvent object
277+
await self._job_event_map[job_name](job_event)
278+
279+
# Return empty response
280+
return appcallback_v1.JobEventResponse()

0 commit comments

Comments
 (0)