Skip to content

Commit 4a4b81b

Browse files
committed
update tests
1 parent 5a7244b commit 4a4b81b

File tree

1 file changed

+40
-33
lines changed

1 file changed

+40
-33
lines changed

tests/test_chatkit_server.py

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,8 @@ async def responder(
240240
def generate_item_id(
241241
item_type: StoreItemType, thread: ThreadMetadata, context: Any
242242
):
243-
if item_type == "hidden_context":
244-
return default_generate_id("hidden_context")
243+
if item_type == "sdk_hidden_context":
244+
return default_generate_id("sdk_hidden_context")
245245
return original_generate_item_id(item_type, thread, context)
246246

247247
server.store.generate_item_id = generate_item_id # type: ignore[method-assign]
@@ -270,8 +270,8 @@ def generate_item_id(
270270
thread.id, None, 1, "desc", DEFAULT_CONTEXT
271271
)
272272
hidden_context_item = items.data[-1]
273-
assert hidden_context_item.type == "hidden_context_item"
274-
assert hidden_context_item.content == "SYSTEM: The user cancelled the stream."
273+
assert hidden_context_item.type == "sdk_hidden_context"
274+
assert hidden_context_item.content == "The user cancelled the stream."
275275

276276
assistant_message_item = await server.store.load_item(
277277
thread.id, "assistant-message-pending", DEFAULT_CONTEXT
@@ -300,8 +300,8 @@ async def responder(
300300
def generate_item_id(
301301
item_type: StoreItemType, thread: ThreadMetadata, context: Any
302302
):
303-
if item_type == "hidden_context":
304-
return default_generate_id("hidden_context")
303+
if item_type == "sdk_hidden_context":
304+
return default_generate_id("sdk_hidden_context")
305305
return original_generate_item_id(item_type, thread, context)
306306

307307
server.store.generate_item_id = generate_item_id # type: ignore[method-assign]
@@ -330,8 +330,8 @@ def generate_item_id(
330330
thread.id, None, 1, "desc", DEFAULT_CONTEXT
331331
)
332332
hidden_context_item = items.data[-1]
333-
assert hidden_context_item.type == "hidden_context_item"
334-
assert hidden_context_item.content == "SYSTEM: The user cancelled the stream."
333+
assert hidden_context_item.type == "sdk_hidden_context"
334+
assert hidden_context_item.content == "The user cancelled the stream."
335335

336336
with pytest.raises(NotFoundError):
337337
await server.store.load_item(
@@ -643,19 +643,21 @@ async def responder(
643643
)
644644
)
645645

646-
assert len(events) == 3
646+
assert len(events) == 4
647647
assert events[0].type == "thread.created"
648648
thread = events[0].thread
649649

650650
assert events[1].type == "thread.item.done"
651651
assert events[1].item.type == "user_message"
652652

653-
assert events[2].type == "thread.item.done"
654-
assert events[2].item.type == "client_tool_call"
655-
assert events[2].item.id == "msg_1"
656-
assert events[2].item.name == "tool_call_1"
657-
assert events[2].item.arguments == {"arg1": "val1", "arg2": False}
658-
assert events[2].item.call_id == "tool_call_1"
653+
assert events[2].type == "stream_options"
654+
655+
assert events[3].type == "thread.item.done"
656+
assert events[3].item.type == "client_tool_call"
657+
assert events[3].item.id == "msg_1"
658+
assert events[3].item.name == "tool_call_1"
659+
assert events[3].item.arguments == {"arg1": "val1", "arg2": False}
660+
assert events[3].item.call_id == "tool_call_1"
659661

660662
events = await server.process_streaming(
661663
ThreadsAddClientToolOutputReq(
@@ -666,9 +668,10 @@ async def responder(
666668
)
667669
)
668670

669-
assert len(events) == 1
670-
assert events[0].type == "thread.item.done"
671-
assert events[0].item.type == "assistant_message"
671+
assert len(events) == 2
672+
assert events[0].type == "stream_options"
673+
assert events[1].type == "thread.item.done"
674+
assert events[1].item.type == "assistant_message"
672675

673676

674677
async def test_removes_tool_call_if_no_output_provided():
@@ -764,11 +767,12 @@ async def responder(
764767
)
765768
)
766769

767-
assert len(events) == 4
770+
assert len(events) == 5
768771
assert events[0].type == "thread.created"
769772
assert events[1].type == "thread.item.done"
770-
assert events[2].type == "progress_update"
771-
assert events[3].type == "thread.item.done"
773+
assert events[2].type == "stream_options"
774+
assert events[3].type == "progress_update"
775+
assert events[4].type == "thread.item.done"
772776

773777

774778
async def test_list_threads_response():
@@ -939,11 +943,12 @@ async def action(
939943
widget_item,
940944
)
941945

942-
assert len(events) == 1
943-
assert events[0].type == "thread.item.updated"
944-
assert isinstance(events[0], ThreadItemUpdatedEvent)
945-
assert events[0].update.type == "widget.root.updated"
946-
assert events[0].update.widget == Card(children=[Text(value="Email sent!")])
946+
assert len(events) == 2
947+
assert events[0].type == "stream_options"
948+
assert events[1].type == "thread.item.updated"
949+
assert isinstance(events[1], ThreadItemUpdatedEvent)
950+
assert events[1].update.type == "widget.root.updated"
951+
assert events[1].update.widget == Card(children=[Text(value="Email sent!")])
947952

948953

949954
async def test_add_feedback():
@@ -1603,11 +1608,12 @@ async def responder(
16031608
)
16041609

16051610
# Verify the retry generated new response
1606-
assert len(retry_events) == 1
1607-
assert retry_events[0].type == "thread.item.done"
1608-
assert retry_events[0].item.type == "assistant_message"
1609-
assert retry_events[0].item.content[0].type == "output_text"
1610-
assert retry_events[0].item.content[0].text == "Retried response"
1611+
assert len(retry_events) == 2
1612+
assert retry_events[0].type == "stream_options"
1613+
assert retry_events[1].type == "thread.item.done"
1614+
assert retry_events[1].item.type == "assistant_message"
1615+
assert retry_events[1].item.content[0].type == "output_text"
1616+
assert retry_events[1].item.content[0].text == "Retried response"
16111617

16121618
# Verify the responder was called twice with the same user message
16131619
assert len(responder_calls) == 2
@@ -1750,8 +1756,9 @@ async def responder(
17501756
)
17511757
)
17521758

1753-
assert len(retry_events) == 1
1754-
assert retry_events[0].type == "thread.item.done"
1759+
assert len(retry_events) == 2
1760+
assert retry_events[0].type == "stream_options"
1761+
assert retry_events[1].type == "thread.item.done"
17551762

17561763
# Verify retry used the second user message
17571764
assert len(responder_calls) == 3 # Original 2 + 1 retry

0 commit comments

Comments
 (0)