4545 ItemsListReq ,
4646 NonStreamingReq ,
4747 Page ,
48+ SDKHiddenContextItem ,
4849 StreamingReq ,
50+ StreamOptions ,
51+ StreamOptionsEvent ,
4952 Thread ,
5053 ThreadCreatedEvent ,
5154 ThreadItem ,
7376 WidgetRootUpdated ,
7477 WidgetStreamingTextValueDelta ,
7578 WorkflowItem ,
79+ WorkflowTaskAdded ,
80+ WorkflowTaskUpdated ,
7681 is_streaming_req ,
7782)
7883from .version import __version__
@@ -315,24 +320,55 @@ def action(
315320 "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
316321 )
317322
323+ def get_stream_options (
324+ self , thread : ThreadMetadata , context : TContext
325+ ) -> StreamOptions :
326+ """
327+ Return stream-level runtime options. Allows the user to cancel the stream by default.
328+ Override this method to customize behavior.
329+ """
330+ return StreamOptions (allow_cancel = True )
331+
318332 async def handle_stream_cancelled (
319333 self ,
320334 thread : ThreadMetadata ,
321335 pending_items : list [ThreadItem ],
322336 context : TContext ,
323337 ):
324338 """Perform custom cleanup / stop inference when a stream is cancelled.
339+ Updates you make here will not be reflected in the UI until a reload.
340+
341+ The default implementation persists any non-empty pending assistant messages
342+ to the thread but does not auto-save pending widget items or workflow items.
325343
326344 Args:
327345 thread: The thread that was being processed.
328346 pending_items: Items that were not done streaming at cancellation time.
329- By default, already-streamed assistant messages, widgets, and workflows are
330- saved to the store during error handling prior to this method being called.
331- If you want to remove them from the thread, you can do so here.
332- (Updates you make here will not be reflected in the UI until a reload.)
333347 context: Arbitrary per-request context provided by the caller.
334348 """
335- pass
349+ pending_assistant_message_items : list [AssistantMessageItem ] = [
350+ item for item in pending_items if isinstance (item , AssistantMessageItem )
351+ ]
352+ for item in pending_assistant_message_items :
353+ is_empty = len (item .content ) == 0 or all (
354+ (not content .text .strip ()) for content in item .content
355+ )
356+ if not is_empty :
357+ await self .store .add_thread_item (thread .id , item , context = context )
358+
359+ # Add a hidden context item to the thread to indicate that the stream was cancelled.
360+ # Otherwise, depending on the timing of the cancellation, subsequent responses may
361+ # attempt to continue the cancelled response.
362+ await self .store .add_thread_item (
363+ thread .id ,
364+ SDKHiddenContextItem (
365+ thread_id = thread .id ,
366+ created_at = datetime .now (),
367+ id = self .store .generate_item_id ("sdk_hidden_context" , thread , context ),
368+ content = "The user cancelled the stream." ,
369+ ),
370+ context = context ,
371+ )
336372
337373 async def process (
338374 self , request : str | bytes | bytearray , context : TContext
@@ -633,6 +669,11 @@ async def _process_events(
633669 ) -> AsyncIterator [ThreadStreamEvent ]:
634670 await asyncio .sleep (0 ) # allow the response to start streaming
635671
672+ # Send initial stream options
673+ yield StreamOptionsEvent (
674+ stream_options = self .get_stream_options (thread , context )
675+ )
676+
636677 last_thread = thread .model_copy (deep = True )
637678
638679 # Keep track of items that were streamed but not yet saved
@@ -662,12 +703,10 @@ async def _process_events(
662703 )
663704 pending_items .pop (event .item .id , None )
664705 case ThreadItemUpdatedEvent ():
665- # Keep the pending assistant message item up to date
666- # so that we can persist already-streamed partial content
667- # if the stream is cancelled.
668- self ._update_pending_assistant_message_items (
669- pending_items , event
670- )
706+ # Keep pending assistant message and workflow items up to date
707+ # so that we have a reference to the latest version of these pending items
708+ # when the stream is cancelled.
709+ self ._update_pending_items (pending_items , event )
671710
672711 # special case - don't send hidden context items back to the client
673712 should_swallow_event = isinstance (
@@ -690,10 +729,6 @@ async def _process_events(
690729 await self .store .save_thread (thread , context = context )
691730 yield ThreadUpdatedEvent (thread = self ._to_thread_response (thread ))
692731 except asyncio .CancelledError :
693- # When a stream is cancelled, whether it's a deliberate stop request or due to a network issue,
694- # save already-streamed items to the thread.
695- await self ._persist_cancelled_stream_state (thread , pending_items , context )
696- # Allow custom cleanup.
697732 await self .handle_stream_cancelled (
698733 thread , list (pending_items .values ()), context
699734 )
@@ -721,43 +756,6 @@ async def _process_events(
721756 await self .store .save_thread (thread , context = context )
722757 yield ThreadUpdatedEvent (thread = self ._to_thread_response (thread ))
723758
724- async def _persist_cancelled_stream_state (
725- self ,
726- thread : ThreadMetadata ,
727- pending_items : dict [str , ThreadItem ],
728- context : TContext ,
729- ):
730- # Persist any streamed items that the UI should keep when cancellation happens mid-stream.
731- for item in pending_items .values ():
732- if isinstance (
733- item , (AssistantMessageItem , WidgetItem , WorkflowItem )
734- ) and not self ._is_streamed_item_empty (item ):
735- await self .store .add_thread_item (thread .id , item , context = context )
736-
737- await self .store .add_thread_item (
738- thread .id ,
739- HiddenContextItem (
740- thread_id = thread .id ,
741- created_at = datetime .now (),
742- id = self .store .generate_item_id ("hidden_context" , thread , context ),
743- content = "SYSTEM: The user cancelled the stream." ,
744- ),
745- context = context ,
746- )
747-
748- def _is_streamed_item_empty (
749- self , item : AssistantMessageItem | WorkflowItem | WidgetItem
750- ) -> bool :
751- if isinstance (item , AssistantMessageItem ):
752- return len (item .content ) == 0 or all (
753- (not content .text .strip ()) for content in item .content
754- )
755- if isinstance (item , WorkflowItem ):
756- return len (item .workflow .tasks ) == 0 and item .workflow .summary is None
757-
758- # Assume all WidgetItems are not empty
759- return False
760-
761759 def _apply_assistant_message_update (
762760 self ,
763761 item : AssistantMessageItem ,
@@ -788,27 +786,38 @@ def _apply_assistant_message_update(
788786 updated .content [update .content_index ] = update .content
789787 return updated
790788
791- def _update_pending_assistant_message_items (
789+ def _update_pending_items (
792790 self ,
793791 pending_items : dict [str , ThreadItem ],
794792 event : ThreadItemUpdatedEvent ,
795793 ):
796- if not isinstance (
797- event .update ,
798- (
799- AssistantMessageContentPartAdded ,
800- AssistantMessageContentPartTextDelta ,
801- AssistantMessageContentPartAnnotationAdded ,
802- AssistantMessageContentPartDone ,
803- ),
804- ):
805- return
806-
807794 updated_item = pending_items .get (event .item_id )
808- if updated_item and isinstance (updated_item , AssistantMessageItem ):
809- pending_items [updated_item .id ] = self ._apply_assistant_message_update (
810- updated_item , event .update
811- )
795+ update = event .update
796+ match updated_item :
797+ case AssistantMessageItem ():
798+ if isinstance (
799+ update ,
800+ (
801+ AssistantMessageContentPartAdded ,
802+ AssistantMessageContentPartTextDelta ,
803+ AssistantMessageContentPartAnnotationAdded ,
804+ AssistantMessageContentPartDone ,
805+ ),
806+ ):
807+ pending_items [updated_item .id ] = (
808+ self ._apply_assistant_message_update (updated_item , update )
809+ )
810+ case WorkflowItem ():
811+ if isinstance (update , (WorkflowTaskUpdated , WorkflowTaskAdded )):
812+ match update :
813+ case WorkflowTaskUpdated ():
814+ updated_item .workflow .tasks [update .task_index ] = update .task
815+ case WorkflowTaskAdded ():
816+ updated_item .workflow .tasks .append (update .task )
817+
818+ pending_items [updated_item .id ] = updated_item
819+ case _:
820+ pass
812821
813822 async def _build_user_message_item (
814823 self , input : UserMessageInput , thread : ThreadMetadata , context : TContext
0 commit comments