diff --git a/ext/dapr-ext-grpc/dapr/ext/grpc/_servicer.py b/ext/dapr-ext-grpc/dapr/ext/grpc/_servicer.py index 8de632f97..eddd8d417 100644 --- a/ext/dapr-ext-grpc/dapr/ext/grpc/_servicer.py +++ b/ext/dapr-ext-grpc/dapr/ext/grpc/_servicer.py @@ -13,6 +13,7 @@ limitations under the License. """ +import warnings from typing import Callable, Dict, List, Optional, Tuple, Union from cloudevents.sdk.event import v1 # type: ignore @@ -29,6 +30,8 @@ from dapr.proto.runtime.v1.appcallback_pb2 import ( BindingEventRequest, JobEventRequest, + TopicEventBulkRequest, + TopicEventBulkResponse, TopicEventRequest, ) @@ -276,3 +279,81 @@ def OnJobEventAlpha1(self, request: JobEventRequest, context): # Return empty response return appcallback_v1.JobEventResponse() + + def _handle_bulk_topic_event( + self, request: TopicEventBulkRequest, context + ) -> TopicEventBulkResponse: + """Process bulk topic event request - routes each entry to the appropriate topic handler.""" + topic_key = request.pubsub_name + DELIMITER + request.topic + DELIMITER + request.path + no_validation_key = request.pubsub_name + DELIMITER + request.path + + if topic_key not in self._topic_map and no_validation_key not in self._topic_map: + return None # we don't have a handler + + handler_key = topic_key if topic_key in self._topic_map else no_validation_key + cb = self._topic_map[handler_key] # callback + + statuses = [] + for entry in request.entries: + entry_id = entry.entry_id + try: + # Build event from entry & send req with many entries + event = v1.Event() + extensions = dict() + if entry.HasField('cloud_event') and entry.cloud_event: + ce = entry.cloud_event + event.SetEventType(ce.type) + event.SetEventID(ce.id) + event.SetSource(ce.source) + event.SetData(ce.data) + event.SetContentType(ce.data_content_type) + if ce.extensions: + for k, v in ce.extensions.items(): + extensions[k] = v + else: + event.SetEventID(entry_id) + event.SetData(entry.bytes if entry.HasField('bytes') else b'') + event.SetContentType(entry.content_type or '') + event.SetSubject(request.topic) + if entry.metadata: + for k, v in entry.metadata.items(): + extensions[k] = v + for k, v in context.invocation_metadata(): + extensions['_metadata_' + k] = v + if extensions: + event.SetExtensions(extensions) + + response = cb(event) # invoke app registered handler and send event + if isinstance(response, TopicEventResponse): + status = response.status.value + else: + status = appcallback_v1.TopicEventResponse.TopicEventResponseStatus.SUCCESS + except Exception: + status = appcallback_v1.TopicEventResponse.TopicEventResponseStatus.RETRY + statuses.append( + appcallback_v1.TopicEventBulkResponseEntry(entry_id=entry_id, status=status) + ) + return appcallback_v1.TopicEventBulkResponse(statuses=statuses) + + def OnBulkTopicEvent(self, request: TopicEventBulkRequest, context): + """Subscribes bulk events from Pubsub""" + response = self._handle_bulk_topic_event(request, context) + if response is None: + context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore + raise NotImplementedError(f'bulk topic {request.topic} is not implemented!') + return response + + def OnBulkTopicEventAlpha1(self, request: TopicEventBulkRequest, context): + """Subscribes bulk events from Pubsub. + Deprecated: Use OnBulkTopicEvent instead. + """ + warnings.warn( + 'OnBulkTopicEventAlpha1 is deprecated. Use OnBulkTopicEvent instead.', + DeprecationWarning, + stacklevel=2, + ) + response = self._handle_bulk_topic_event(request, context) + if response is None: + context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore + raise NotImplementedError(f'bulk topic {request.topic} is not implemented!') + return response diff --git a/ext/dapr-ext-grpc/tests/test_servicier.py b/ext/dapr-ext-grpc/tests/test_servicier.py index 325d9b6d6..ca6e2f9bc 100644 --- a/ext/dapr-ext-grpc/tests/test_servicier.py +++ b/ext/dapr-ext-grpc/tests/test_servicier.py @@ -183,6 +183,70 @@ def test_non_registered_topic(self): ) +class BulkTopicEventTests(unittest.TestCase): + def setUp(self): + self._servicer = _CallbackServicer() + self._topic_method = Mock() + self._topic_method.return_value = TopicEventResponse('success') + self._servicer.register_topic('pubsub1', 'topic1', self._topic_method, {'session': 'key'}) + + self.fake_context = MagicMock() + self.fake_context.invocation_metadata.return_value = ( + ('key1', 'value1'), + ('key2', 'value1'), + ) + + def test_on_bulk_topic_event(self): + from dapr.proto.runtime.v1.appcallback_pb2 import ( + TopicEventBulkRequest, + TopicEventBulkRequestEntry, + ) + + entry1 = TopicEventBulkRequestEntry( + entry_id='entry1', + bytes=b'hello', + content_type='text/plain', + ) + entry2 = TopicEventBulkRequestEntry( + entry_id='entry2', + bytes=b'{"a": 1}', + content_type='application/json', + ) + request = TopicEventBulkRequest( + id='bulk1', + pubsub_name='pubsub1', + topic='topic1', + path='', + entries=[entry1, entry2], + ) + resp = self._servicer.OnBulkTopicEvent(request, self.fake_context) + self.assertEqual(2, len(resp.statuses)) + self.assertEqual('entry1', resp.statuses[0].entry_id) + self.assertEqual('entry2', resp.statuses[1].entry_id) + self.assertEqual( + appcallback_v1.TopicEventResponse.TopicEventResponseStatus.SUCCESS, + resp.statuses[0].status, + ) + self.assertEqual(2, self._topic_method.call_count) + + def test_on_bulk_topic_event_non_registered(self): + from dapr.proto.runtime.v1.appcallback_pb2 import ( + TopicEventBulkRequest, + TopicEventBulkRequestEntry, + ) + + entry = TopicEventBulkRequestEntry(entry_id='entry1', bytes=b'hello') + request = TopicEventBulkRequest( + id='bulk1', + pubsub_name='pubsub1', + topic='unknown_topic', + path='', + entries=[entry], + ) + with self.assertRaises(NotImplementedError): + self._servicer.OnBulkTopicEvent(request, self.fake_context) + + class BindingTests(unittest.TestCase): def setUp(self): self._servicer = _CallbackServicer() diff --git a/tests/clients/fake_dapr_server.py b/tests/clients/fake_dapr_server.py index 73cd22fd9..76c26cd99 100644 --- a/tests/clients/fake_dapr_server.py +++ b/tests/clients/fake_dapr_server.py @@ -155,6 +155,14 @@ def PublishEvent(self, request, context): context.set_trailing_metadata(trailers) return empty_pb2.Empty() + def BulkPublishEvent(self, request, context): + self.check_for_exception(context) + return api_v1.BulkPublishResponse() + + def BulkPublishEventAlpha1(self, request, context): + self.check_for_exception(context) + return api_v1.BulkPublishResponse() + def SubscribeTopicEventsAlpha1(self, request_iterator, context): for request in request_iterator: if request.HasField('initial_request'): diff --git a/tests/clients/test_dapr_grpc_client.py b/tests/clients/test_dapr_grpc_client.py index a52bbeb0d..6902c7773 100644 --- a/tests/clients/test_dapr_grpc_client.py +++ b/tests/clients/test_dapr_grpc_client.py @@ -271,6 +271,27 @@ def test_publish_error(self): data=111, ) + def test_publish_bulk_event(self): + dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.grpc_port}') + resp = dapr.publish_bulk_event( + pubsub_name='pubsub', + topic_name='example', + events=[ + {'entry_id': '1', 'event': b'{"key": "value1"}'}, + {'entry_id': '2', 'event': b'{"key": "value2"}'}, + ], + ) + self.assertEqual(0, len(resp.failed_entries)) + + def test_publish_bulk_event_invalid_event_type(self): + dapr = DaprGrpcClient(f'{self.scheme}localhost:{self.grpc_port}') + with self.assertRaisesRegex(ValueError, 'invalid type for event data'): + dapr.publish_bulk_event( + pubsub_name='pubsub', + topic_name='example', + events=[{'entry_id': '1', 'event': 123}], + ) + def test_subscribe_topic(self): # The fake server we're using sends two messages and then closes the stream # The client should be able to read both messages, handle the stream closure and reconnect diff --git a/tests/clients/test_dapr_grpc_client_async.py b/tests/clients/test_dapr_grpc_client_async.py index 245c384dd..8406e3813 100644 --- a/tests/clients/test_dapr_grpc_client_async.py +++ b/tests/clients/test_dapr_grpc_client_async.py @@ -266,6 +266,27 @@ async def test_publish_error(self): data=111, ) + async def test_publish_bulk_event(self): + dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.grpc_port}') + resp = await dapr.publish_bulk_event( + pubsub_name='pubsub', + topic_name='example', + events=[ + {'entry_id': '1', 'event': b'{"key": "value1"}'}, + {'entry_id': '2', 'event': b'{"key": "value2"}'}, + ], + ) + self.assertEqual(0, len(resp.failed_entries)) + + async def test_publish_bulk_event_invalid_event_type(self): + dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.grpc_port}') + with self.assertRaisesRegex(ValueError, 'invalid type for event data'): + await dapr.publish_bulk_event( + pubsub_name='pubsub', + topic_name='example', + events=[{'entry_id': '1', 'event': 123}], + ) + async def test_subscribe_topic(self): # The fake server we're using sends two messages and then closes the stream # The client should be able to read both messages, handle the stream closure and reconnect