Skip to content

Commit e29b09e

Browse files
committed
Don't persist empty pending items
1 parent 942620a commit e29b09e

File tree

2 files changed

+75
-1
lines changed

2 files changed

+75
-1
lines changed

chatkit/server.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,9 @@ async def _persist_cancelled_stream_state(
729729
):
730730
# Persist any streamed items that the UI should keep when cancellation happens mid-stream.
731731
for item in pending_items.values():
732-
if isinstance(item, (AssistantMessageItem, WidgetItem, WorkflowItem)):
732+
if isinstance(
733+
item, (AssistantMessageItem, WidgetItem, WorkflowItem)
734+
) and not self._is_streamed_item_empty(item):
733735
await self.store.add_thread_item(thread.id, item, context=context)
734736

735737
await self.store.add_thread_item(
@@ -743,6 +745,19 @@ async def _persist_cancelled_stream_state(
743745
context=context,
744746
)
745747

748+
def _is_streamed_item_empty(
749+
self, item: AssistantMessageItem | WorkflowItem | WidgetItem
750+
) -> bool:
751+
if isinstance(item, AssistantMessageItem):
752+
return len(item.content) == 0 or all(
753+
(not content.text.strip()) for content in item.content
754+
)
755+
if isinstance(item, WorkflowItem):
756+
return len(item.workflow.tasks) == 0 and item.workflow.summary is None
757+
758+
# Assume all WidgetItems are not empty
759+
return False
760+
746761
def _apply_assistant_message_update(
747762
self,
748763
item: AssistantMessageItem,

tests/test_chatkit_server.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,65 @@ def generate_item_id(
280280
assert assistant_message_item.content[0].text == "Hello, World!"
281281

282282

283+
async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
284+
async def responder(
285+
thread: ThreadMetadata, input: UserMessageItem | None, context: Any
286+
) -> AsyncIterator[ThreadStreamEvent]:
287+
yield ThreadItemAddedEvent(
288+
item=AssistantMessageItem(
289+
id="assistant-message-pending",
290+
created_at=datetime.now(),
291+
content=[],
292+
thread_id=thread.id,
293+
)
294+
)
295+
raise asyncio.CancelledError()
296+
297+
with make_server(responder) as server:
298+
original_generate_item_id = server.store.generate_item_id
299+
300+
def generate_item_id(
301+
item_type: StoreItemType, thread: ThreadMetadata, context: Any
302+
):
303+
if item_type == "hidden_context":
304+
return default_generate_id("hidden_context")
305+
return original_generate_item_id(item_type, thread, context)
306+
307+
server.store.generate_item_id = generate_item_id # type: ignore[method-assign]
308+
309+
stream = await server.process(
310+
ThreadsCreateReq(
311+
params=ThreadCreateParams(
312+
input=UserMessageInput(
313+
content=[UserMessageTextContent(text="Hello")],
314+
attachments=[],
315+
inference_options=InferenceOptions(),
316+
)
317+
)
318+
).model_dump_json(),
319+
DEFAULT_CONTEXT,
320+
)
321+
assert isinstance(stream, StreamingResult)
322+
323+
events: list[ThreadStreamEvent] = []
324+
with pytest.raises(asyncio.CancelledError): # noqa: PT012
325+
async for raw in stream.json_events:
326+
events.append(decode_event(raw))
327+
328+
thread = next(e.thread for e in events if e.type == "thread.created")
329+
items = await server.store.load_thread_items(
330+
thread.id, None, 1, "desc", DEFAULT_CONTEXT
331+
)
332+
hidden_context_item = items.data[-1]
333+
assert hidden_context_item.type == "hidden_context_item"
334+
assert hidden_context_item.content == "SYSTEM: The user cancelled the stream."
335+
336+
with pytest.raises(NotFoundError):
337+
await server.store.load_item(
338+
thread.id, "assistant-message-pending", DEFAULT_CONTEXT
339+
)
340+
341+
283342
async def test_flows_context_to_responder():
284343
responder_context = None
285344
add_feedback_context = None

0 commit comments

Comments
 (0)