Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions ext/dapr-ext-grpc/dapr/ext/grpc/_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
limitations under the License.
"""

import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union

from cloudevents.sdk.event import v1 # type: ignore
Expand All @@ -29,6 +30,8 @@
from dapr.proto.runtime.v1.appcallback_pb2 import (
BindingEventRequest,
JobEventRequest,
TopicEventBulkRequest,
TopicEventBulkResponse,
TopicEventRequest,
)

Expand Down Expand Up @@ -276,3 +279,81 @@ def OnJobEventAlpha1(self, request: JobEventRequest, context):

# Return empty response
return appcallback_v1.JobEventResponse()

def _handle_bulk_topic_event(
    self, request: TopicEventBulkRequest, context
) -> Optional[TopicEventBulkResponse]:
    """Process bulk topic event request - routes each entry to the appropriate topic handler.

    Looks up the registered callback for the request's pubsub/topic/path
    (falling back to the pubsub/path key registered without topic validation),
    converts each bulk entry into a CloudEvent, invokes the callback per entry,
    and collects a per-entry status.

    Args:
        request: Bulk request carrying one or more entries for a single
            pubsub/topic pair.
        context: gRPC servicer context; its invocation metadata is folded into
            each event's extensions under a '_metadata_' key prefix.

    Returns:
        A TopicEventBulkResponse with one status per entry, or None when no
        handler is registered for this topic (callers translate None into
        UNIMPLEMENTED).
    """
    # Handlers are keyed either with the topic (validated) or without it.
    topic_key = request.pubsub_name + DELIMITER + request.topic + DELIMITER + request.path
    no_validation_key = request.pubsub_name + DELIMITER + request.path

    if topic_key not in self._topic_map and no_validation_key not in self._topic_map:
        return None  # we don't have a handler

    # Prefer the fully-qualified (topic-validated) registration.
    handler_key = topic_key if topic_key in self._topic_map else no_validation_key
    cb = self._topic_map[handler_key]  # callback

    statuses = []
    for entry in request.entries:
        entry_id = entry.entry_id
        try:
            # Build event from entry & send req with many entries
            event = v1.Event()
            extensions = dict()
            if entry.HasField('cloud_event') and entry.cloud_event:
                # Entry already carries a CloudEvent envelope: copy its fields.
                ce = entry.cloud_event
                event.SetEventType(ce.type)
                event.SetEventID(ce.id)
                event.SetSource(ce.source)
                event.SetData(ce.data)
                event.SetContentType(ce.data_content_type)
                if ce.extensions:
                    for k, v in ce.extensions.items():
                        extensions[k] = v
            else:
                # Raw-bytes entry: synthesize a minimal event around the payload.
                event.SetEventID(entry_id)
                event.SetData(entry.bytes if entry.HasField('bytes') else b'')
                event.SetContentType(entry.content_type or '')
            event.SetSubject(request.topic)
            if entry.metadata:
                for k, v in entry.metadata.items():
                    extensions[k] = v
            # Expose the gRPC call metadata alongside entry metadata, prefixed
            # so the two namespaces cannot collide.
            for k, v in context.invocation_metadata():
                extensions['_metadata_' + k] = v
            if extensions:
                event.SetExtensions(extensions)

            response = cb(event)  # invoke app registered handler and send event
            if isinstance(response, TopicEventResponse):
                status = response.status.value
            else:
                # Handlers that return nothing (or a non-TopicEventResponse)
                # are treated as successful.
                status = appcallback_v1.TopicEventResponse.TopicEventResponseStatus.SUCCESS
        except Exception:
            # A failing entry marks only itself RETRY; the rest of the batch
            # is still processed.
            status = appcallback_v1.TopicEventResponse.TopicEventResponseStatus.RETRY
        statuses.append(
            appcallback_v1.TopicEventBulkResponseEntry(entry_id=entry_id, status=status)
        )
    return appcallback_v1.TopicEventBulkResponse(statuses=statuses)

def OnBulkTopicEvent(self, request: TopicEventBulkRequest, context):
    """Subscribes bulk events from Pubsub"""
    result = self._handle_bulk_topic_event(request, context)
    if result is not None:
        return result
    # No handler registered for this topic: report UNIMPLEMENTED to the caller.
    context.set_code(grpc.StatusCode.UNIMPLEMENTED)  # type: ignore
    raise NotImplementedError(f'bulk topic {request.topic} is not implemented!')

def OnBulkTopicEventAlpha1(self, request: TopicEventBulkRequest, context):
    """Subscribes bulk events from Pubsub.
    Deprecated: Use OnBulkTopicEvent instead.
    """
    # Emit the deprecation notice at the caller's frame.
    warnings.warn(
        'OnBulkTopicEventAlpha1 is deprecated. Use OnBulkTopicEvent instead.',
        DeprecationWarning,
        stacklevel=2,
    )
    result = self._handle_bulk_topic_event(request, context)
    if result is not None:
        return result
    # No handler registered for this topic: report UNIMPLEMENTED to the caller.
    context.set_code(grpc.StatusCode.UNIMPLEMENTED)  # type: ignore
    raise NotImplementedError(f'bulk topic {request.topic} is not implemented!')
64 changes: 64 additions & 0 deletions ext/dapr-ext-grpc/tests/test_servicier.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,70 @@ def test_non_registered_topic(self):
)


class BulkTopicEventTests(unittest.TestCase):
    """Tests for the bulk topic event servicer path."""

    def setUp(self):
        # A servicer with a single registered topic handler that reports success.
        self._servicer = _CallbackServicer()
        self._topic_method = Mock()
        self._topic_method.return_value = TopicEventResponse('success')
        self._servicer.register_topic('pubsub1', 'topic1', self._topic_method, {'session': 'key'})

        self.fake_context = MagicMock()
        self.fake_context.invocation_metadata.return_value = (
            ('key1', 'value1'),
            ('key2', 'value1'),
        )

    def test_on_bulk_topic_event(self):
        from dapr.proto.runtime.v1.appcallback_pb2 import (
            TopicEventBulkRequest,
            TopicEventBulkRequestEntry,
        )

        # Two raw-bytes entries with different content types.
        entries = [
            TopicEventBulkRequestEntry(
                entry_id='entry1',
                bytes=b'hello',
                content_type='text/plain',
            ),
            TopicEventBulkRequestEntry(
                entry_id='entry2',
                bytes=b'{"a": 1}',
                content_type='application/json',
            ),
        ]
        bulk_request = TopicEventBulkRequest(
            id='bulk1',
            pubsub_name='pubsub1',
            topic='topic1',
            path='',
            entries=entries,
        )

        resp = self._servicer.OnBulkTopicEvent(bulk_request, self.fake_context)

        # One SUCCESS status per entry, in request order; handler ran per entry.
        self.assertEqual(2, len(resp.statuses))
        self.assertEqual('entry1', resp.statuses[0].entry_id)
        self.assertEqual('entry2', resp.statuses[1].entry_id)
        self.assertEqual(
            appcallback_v1.TopicEventResponse.TopicEventResponseStatus.SUCCESS,
            resp.statuses[0].status,
        )
        self.assertEqual(2, self._topic_method.call_count)

    def test_on_bulk_topic_event_non_registered(self):
        from dapr.proto.runtime.v1.appcallback_pb2 import (
            TopicEventBulkRequest,
            TopicEventBulkRequestEntry,
        )

        bulk_request = TopicEventBulkRequest(
            id='bulk1',
            pubsub_name='pubsub1',
            topic='unknown_topic',
            path='',
            entries=[TopicEventBulkRequestEntry(entry_id='entry1', bytes=b'hello')],
        )

        # An unregistered topic must surface as NotImplementedError.
        with self.assertRaises(NotImplementedError):
            self._servicer.OnBulkTopicEvent(bulk_request, self.fake_context)


class BindingTests(unittest.TestCase):
def setUp(self):
self._servicer = _CallbackServicer()
Expand Down
8 changes: 8 additions & 0 deletions tests/clients/fake_dapr_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,14 @@ def PublishEvent(self, request, context):
context.set_trailing_metadata(trailers)
return empty_pb2.Empty()

def BulkPublishEvent(self, request, context):
    """Fake BulkPublishEvent handler.

    Runs the server's shared exception hook, then returns an empty
    BulkPublishResponse (default proto: no failed entries, i.e. all
    entries published successfully).
    """
    self.check_for_exception(context)
    return api_v1.BulkPublishResponse()

def BulkPublishEventAlpha1(self, request, context):
    """Fake BulkPublishEventAlpha1 handler (legacy alpha RPC name).

    Mirrors BulkPublishEvent: runs the shared exception hook, then returns
    an empty BulkPublishResponse (no failed entries).
    """
    self.check_for_exception(context)
    return api_v1.BulkPublishResponse()

def SubscribeTopicEventsAlpha1(self, request_iterator, context):
for request in request_iterator:
if request.HasField('initial_request'):
Expand Down
21 changes: 21 additions & 0 deletions tests/clients/test_dapr_grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,27 @@ def test_publish_error(self):
data=111,
)

def test_publish_bulk_event(self):
    """A successful bulk publish reports zero failed entries."""
    client = DaprGrpcClient(f'{self.scheme}localhost:{self.grpc_port}')
    entries = [
        {'entry_id': '1', 'event': b'{"key": "value1"}'},
        {'entry_id': '2', 'event': b'{"key": "value2"}'},
    ]
    resp = client.publish_bulk_event(
        pubsub_name='pubsub',
        topic_name='example',
        events=entries,
    )
    self.assertEqual(0, len(resp.failed_entries))

def test_publish_bulk_event_invalid_event_type(self):
    """Non-bytes/str event data is rejected with a ValueError."""
    client = DaprGrpcClient(f'{self.scheme}localhost:{self.grpc_port}')
    bad_events = [{'entry_id': '1', 'event': 123}]
    with self.assertRaisesRegex(ValueError, 'invalid type for event data'):
        client.publish_bulk_event(
            pubsub_name='pubsub',
            topic_name='example',
            events=bad_events,
        )

def test_subscribe_topic(self):
# The fake server we're using sends two messages and then closes the stream
# The client should be able to read both messages, handle the stream closure and reconnect
Expand Down
21 changes: 21 additions & 0 deletions tests/clients/test_dapr_grpc_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,27 @@ async def test_publish_error(self):
data=111,
)

async def test_publish_bulk_event(self):
    """A successful bulk publish reports zero failed entries (async client)."""
    client = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.grpc_port}')
    entries = [
        {'entry_id': '1', 'event': b'{"key": "value1"}'},
        {'entry_id': '2', 'event': b'{"key": "value2"}'},
    ]
    resp = await client.publish_bulk_event(
        pubsub_name='pubsub',
        topic_name='example',
        events=entries,
    )
    self.assertEqual(0, len(resp.failed_entries))

async def test_publish_bulk_event_invalid_event_type(self):
    """Non-bytes/str event data is rejected with a ValueError (async client)."""
    client = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.grpc_port}')
    bad_events = [{'entry_id': '1', 'event': 123}]
    with self.assertRaisesRegex(ValueError, 'invalid type for event data'):
        await client.publish_bulk_event(
            pubsub_name='pubsub',
            topic_name='example',
            events=bad_events,
        )

async def test_subscribe_topic(self):
# The fake server we're using sends two messages and then closes the stream
# The client should be able to read both messages, handle the stream closure and reconnect
Expand Down
Loading