Skip to content

Commit 5a7244b

Browse files
committed
use new type of hidden context item; add stream options
1 parent e29b09e commit 5a7244b

File tree

4 files changed

+142
-74
lines changed

4 files changed

+142
-74
lines changed

chatkit/agents.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
EndOfTurnItem,
5757
FileSource,
5858
HiddenContextItem,
59+
SDKHiddenContextItem,
5960
Task,
6061
TaskItem,
6162
ThoughtTask,
@@ -691,9 +692,6 @@ async def hidden_context_to_input(
691692
"""
692693
Convert a HiddenContextItem into input item(s) to send to the model.
693694
Required to override when HiddenContextItems with non-string content are used.
694-
695-
ChatKitServer may save HiddenContextItems with string content; make sure your override
696-
can also handle HiddenContextItems with string content.
697695
"""
698696
if not isinstance(item.content, str):
699697
raise NotImplementedError(
@@ -715,6 +713,29 @@ async def hidden_context_to_input(
715713
role="user",
716714
)
717715

716+
async def sdk_hidden_context_to_input(
717+
self, item: SDKHiddenContextItem
718+
) -> TResponseInputItem | list[TResponseInputItem] | None:
719+
"""
720+
Convert a SDKHiddenContextItem into input item to send to the model.
721+
This is used by the ChatKit Python SDK for storing additional context
722+
for internal operations; you shouldn't need to override this.
723+
"""
724+
text = (
725+
"Hidden ChatKit SDK context for the agent (not shown to the user):\n"
726+
f"<SDKHiddenContext>\n{item.content}\n</SDKHiddenContext>"
727+
)
728+
return Message(
729+
type="message",
730+
content=[
731+
ResponseInputTextParam(
732+
type="input_text",
733+
text=text,
734+
)
735+
],
736+
role="user",
737+
)
738+
718739
async def task_to_input(
719740
self, item: TaskItem
720741
) -> TResponseInputItem | list[TResponseInputItem] | None:
@@ -951,6 +972,9 @@ async def _thread_item_to_input_item(
951972
case HiddenContextItem():
952973
out = await self.hidden_context_to_input(item) or []
953974
return out if isinstance(out, list) else [out]
975+
case SDKHiddenContextItem():
976+
out = await self.sdk_hidden_context_to_input(item) or []
977+
return out if isinstance(out, list) else [out]
954978
case _:
955979
assert_never(item)
956980

chatkit/server.py

Lines changed: 77 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@
4545
ItemsListReq,
4646
NonStreamingReq,
4747
Page,
48+
SDKHiddenContextItem,
4849
StreamingReq,
50+
StreamOptions,
51+
StreamOptionsEvent,
4952
Thread,
5053
ThreadCreatedEvent,
5154
ThreadItem,
@@ -73,6 +76,8 @@
7376
WidgetRootUpdated,
7477
WidgetStreamingTextValueDelta,
7578
WorkflowItem,
79+
WorkflowTaskAdded,
80+
WorkflowTaskUpdated,
7681
is_streaming_req,
7782
)
7883
from .version import __version__
@@ -315,24 +320,55 @@ def action(
315320
"See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
316321
)
317322

323+
def get_stream_options(
324+
self, thread: ThreadMetadata, context: TContext
325+
) -> StreamOptions:
326+
"""
327+
Return stream-level runtime options. Allows the user to cancel the stream by default.
328+
Override this method to customize behavior.
329+
"""
330+
return StreamOptions(allow_cancel=True)
331+
318332
async def handle_stream_cancelled(
319333
self,
320334
thread: ThreadMetadata,
321335
pending_items: list[ThreadItem],
322336
context: TContext,
323337
):
324338
"""Perform custom cleanup / stop inference when a stream is cancelled.
339+
Updates you make here will not be reflected in the UI until a reload.
340+
341+
The default implementation persists any non-empty pending assistant messages
342+
to the thread but does not auto-save pending widget items or workflow items.
325343
326344
Args:
327345
thread: The thread that was being processed.
328346
pending_items: Items that were not done streaming at cancellation time.
329-
By default, already-streamed assistant messages, widgets, and workflows are
330-
saved to the store during error handling prior to this method being called.
331-
If you want to remove them from the thread, you can do so here.
332-
(Updates you make here will not be reflected in the UI until a reload.)
333347
context: Arbitrary per-request context provided by the caller.
334348
"""
335-
pass
349+
pending_assistant_message_items: list[AssistantMessageItem] = [
350+
item for item in pending_items if isinstance(item, AssistantMessageItem)
351+
]
352+
for item in pending_assistant_message_items:
353+
is_empty = len(item.content) == 0 or all(
354+
(not content.text.strip()) for content in item.content
355+
)
356+
if not is_empty:
357+
await self.store.add_thread_item(thread.id, item, context=context)
358+
359+
# Add a hidden context item to the thread to indicate that the stream was cancelled.
360+
# Otherwise, depending on the timing of the cancellation, subsequent responses may
361+
# attempt to continue the cancelled response.
362+
await self.store.add_thread_item(
363+
thread.id,
364+
SDKHiddenContextItem(
365+
thread_id=thread.id,
366+
created_at=datetime.now(),
367+
id=self.store.generate_item_id("sdk_hidden_context", thread, context),
368+
content="The user cancelled the stream.",
369+
),
370+
context=context,
371+
)
336372

337373
async def process(
338374
self, request: str | bytes | bytearray, context: TContext
@@ -633,6 +669,11 @@ async def _process_events(
633669
) -> AsyncIterator[ThreadStreamEvent]:
634670
await asyncio.sleep(0) # allow the response to start streaming
635671

672+
# Send initial stream options
673+
yield StreamOptionsEvent(
674+
stream_options=self.get_stream_options(thread, context)
675+
)
676+
636677
last_thread = thread.model_copy(deep=True)
637678

638679
# Keep track of items that were streamed but not yet saved
@@ -662,12 +703,10 @@ async def _process_events(
662703
)
663704
pending_items.pop(event.item.id, None)
664705
case ThreadItemUpdatedEvent():
665-
# Keep the pending assistant message item up to date
666-
# so that we can persist already-streamed partial content
667-
# if the stream is cancelled.
668-
self._update_pending_assistant_message_items(
669-
pending_items, event
670-
)
706+
# Keep pending assistant message and workflow items up to date
707+
# so that we have a reference to the latest version of these pending items
708+
# when the stream is cancelled.
709+
self._update_pending_items(pending_items, event)
671710

672711
# special case - don't send hidden context items back to the client
673712
should_swallow_event = isinstance(
@@ -690,10 +729,6 @@ async def _process_events(
690729
await self.store.save_thread(thread, context=context)
691730
yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
692731
except asyncio.CancelledError:
693-
# When a stream is cancelled, whether it's a deliberate stop request or due to a network issue,
694-
# save already-streamed items to the thread.
695-
await self._persist_cancelled_stream_state(thread, pending_items, context)
696-
# Allow custom cleanup.
697732
await self.handle_stream_cancelled(
698733
thread, list(pending_items.values()), context
699734
)
@@ -721,43 +756,6 @@ async def _process_events(
721756
await self.store.save_thread(thread, context=context)
722757
yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
723758

724-
async def _persist_cancelled_stream_state(
725-
self,
726-
thread: ThreadMetadata,
727-
pending_items: dict[str, ThreadItem],
728-
context: TContext,
729-
):
730-
# Persist any streamed items that the UI should keep when cancellation happens mid-stream.
731-
for item in pending_items.values():
732-
if isinstance(
733-
item, (AssistantMessageItem, WidgetItem, WorkflowItem)
734-
) and not self._is_streamed_item_empty(item):
735-
await self.store.add_thread_item(thread.id, item, context=context)
736-
737-
await self.store.add_thread_item(
738-
thread.id,
739-
HiddenContextItem(
740-
thread_id=thread.id,
741-
created_at=datetime.now(),
742-
id=self.store.generate_item_id("hidden_context", thread, context),
743-
content="SYSTEM: The user cancelled the stream.",
744-
),
745-
context=context,
746-
)
747-
748-
def _is_streamed_item_empty(
749-
self, item: AssistantMessageItem | WorkflowItem | WidgetItem
750-
) -> bool:
751-
if isinstance(item, AssistantMessageItem):
752-
return len(item.content) == 0 or all(
753-
(not content.text.strip()) for content in item.content
754-
)
755-
if isinstance(item, WorkflowItem):
756-
return len(item.workflow.tasks) == 0 and item.workflow.summary is None
757-
758-
# Assume all WidgetItems are not empty
759-
return False
760-
761759
def _apply_assistant_message_update(
762760
self,
763761
item: AssistantMessageItem,
@@ -788,27 +786,38 @@ def _apply_assistant_message_update(
788786
updated.content[update.content_index] = update.content
789787
return updated
790788

791-
def _update_pending_assistant_message_items(
789+
def _update_pending_items(
792790
self,
793791
pending_items: dict[str, ThreadItem],
794792
event: ThreadItemUpdatedEvent,
795793
):
796-
if not isinstance(
797-
event.update,
798-
(
799-
AssistantMessageContentPartAdded,
800-
AssistantMessageContentPartTextDelta,
801-
AssistantMessageContentPartAnnotationAdded,
802-
AssistantMessageContentPartDone,
803-
),
804-
):
805-
return
806-
807794
updated_item = pending_items.get(event.item_id)
808-
if updated_item and isinstance(updated_item, AssistantMessageItem):
809-
pending_items[updated_item.id] = self._apply_assistant_message_update(
810-
updated_item, event.update
811-
)
795+
update = event.update
796+
match updated_item:
797+
case AssistantMessageItem():
798+
if isinstance(
799+
update,
800+
(
801+
AssistantMessageContentPartAdded,
802+
AssistantMessageContentPartTextDelta,
803+
AssistantMessageContentPartAnnotationAdded,
804+
AssistantMessageContentPartDone,
805+
),
806+
):
807+
pending_items[updated_item.id] = (
808+
self._apply_assistant_message_update(updated_item, update)
809+
)
810+
case WorkflowItem():
811+
if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)):
812+
match update:
813+
case WorkflowTaskUpdated():
814+
updated_item.workflow.tasks[update.task_index] = update.task
815+
case WorkflowTaskAdded():
816+
updated_item.workflow.tasks.append(update.task)
817+
818+
pending_items[updated_item.id] = updated_item
819+
case _:
820+
pass
812821

813822
async def _build_user_message_item(
814823
self, input: UserMessageInput, thread: ThreadMetadata, context: TContext

chatkit/store.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@
1515
TContext = TypeVar("TContext", default=Any)
1616

1717
StoreItemType = Literal[
18-
"thread", "message", "tool_call", "task", "workflow", "attachment", "hidden_context"
18+
"thread",
19+
"message",
20+
"tool_call",
21+
"task",
22+
"workflow",
23+
"attachment",
24+
"sdk_hidden_context",
1925
]
2026

2127

@@ -26,7 +32,7 @@
2632
"workflow": "wf",
2733
"task": "tsk",
2834
"attachment": "atc",
29-
"hidden_context": "hcx",
35+
"sdk_hidden_context": "shcx",
3036
}
3137

3238

chatkit/types.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,20 @@ class ThreadItemReplacedEvent(BaseModel):
317317
item: ThreadItem
318318

319319

320+
class StreamOptions(BaseModel):
321+
"""Settings that control runtime stream behavior."""
322+
323+
allow_cancel: bool
324+
"""Allow the client to request cancellation mid-stream."""
325+
326+
327+
class StreamOptionsEvent(BaseModel):
328+
"""Event emitted to set stream options at runtime."""
329+
330+
type: Literal["stream_options"] = "stream_options"
331+
stream_options: StreamOptions
332+
333+
320334
class ProgressUpdateEvent(BaseModel):
321335
"""Event providing incremental progress from the assistant."""
322336

@@ -354,6 +368,7 @@ class NoticeEvent(BaseModel):
354368
| ThreadItemUpdated
355369
| ThreadItemRemovedEvent
356370
| ThreadItemReplacedEvent
371+
| StreamOptionsEvent
357372
| ProgressUpdateEvent
358373
| ErrorEvent
359374
| NoticeEvent,
@@ -576,12 +591,25 @@ class EndOfTurnItem(ThreadItemBase):
576591

577592

578593
class HiddenContextItem(ThreadItemBase):
579-
"""HiddenContext is never sent to the client. It's not officially part of ChatKit. It is only used internally to store additional context in a specific place in the thread."""
594+
"""
595+
HiddenContext is never sent to the client. It's not officially part of ChatKit.js.
596+
It is only used internally to store additional context in a specific place in the thread.
597+
"""
580598

581599
type: Literal["hidden_context_item"] = "hidden_context_item"
582600
content: Any
583601

584602

603+
class SDKHiddenContextItem(ThreadItemBase):
604+
"""
605+
Hidden context that is used by the ChatKit Python SDK for storing additional context
606+
for internal operations.
607+
"""
608+
609+
type: Literal["sdk_hidden_context"] = "sdk_hidden_context"
610+
content: str
611+
612+
585613
ThreadItem = Annotated[
586614
UserMessageItem
587615
| AssistantMessageItem
@@ -590,6 +618,7 @@ class HiddenContextItem(ThreadItemBase):
590618
| WorkflowItem
591619
| TaskItem
592620
| HiddenContextItem
621+
| SDKHiddenContextItem
593622
| EndOfTurnItem,
594623
Field(discriminator="type"),
595624
]

0 commit comments

Comments
 (0)