Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dspy/streaming/streamify.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import threading
from asyncio import iscoroutinefunction
from copy import deepcopy
from queue import Queue
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator

Expand Down Expand Up @@ -161,7 +162,7 @@ async def use_streaming():
elif not iscoroutinefunction(program):
program = asyncify(program)

callbacks = settings.callbacks
callbacks = deepcopy(settings.callbacks)
status_streaming_callback = StatusStreamingCallback(status_message_provider)
if not any(isinstance(c, StatusStreamingCallback) for c in callbacks):
callbacks.append(status_streaming_callback)
Expand Down
94 changes: 94 additions & 0 deletions tests/streaming/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,100 @@ def module_start_status_message(self, instance, inputs):
assert status_messages[2].message == "Predict starting!"


@pytest.mark.anyio
async def test_default_then_custom_status_message_provider():
class MyProgram(dspy.Module):
def __init__(self):
self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
self.predict = dspy.Predict("question->answer")

def __call__(self, x: str):
question = self.generate_question(x=x)
return self.predict(question=question)

class MyStatusMessageProvider(StatusMessageProvider):
def tool_start_status_message(self, instance, inputs):
return "Tool starting!"

def tool_end_status_message(self, outputs):
return "Tool finished!"

def module_start_status_message(self, instance, inputs):
if isinstance(instance, dspy.Predict):
return "Predict starting!"

lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
with dspy.context(lm=lm):
program = dspy.streamify(MyProgram())
output = program("sky")

status_messages = []
async for value in output:
if isinstance(value, StatusMessage):
status_messages.append(value)

assert len(status_messages) == 2
assert status_messages[0].message == "Calling tool generate_question..."
assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."

program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider())
output = program("sky")
status_messages = []
async for value in output:
if isinstance(value, StatusMessage):
status_messages.append(value)
assert len(status_messages) == 3
assert status_messages[0].message == "Tool starting!"
assert status_messages[1].message == "Tool finished!"
assert status_messages[2].message == "Predict starting!"


@pytest.mark.anyio
async def test_custom_then_default_status_message_provider():
class MyProgram(dspy.Module):
def __init__(self):
self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
self.predict = dspy.Predict("question->answer")

def __call__(self, x: str):
question = self.generate_question(x=x)
return self.predict(question=question)

class MyStatusMessageProvider(StatusMessageProvider):
def tool_start_status_message(self, instance, inputs):
return "Tool starting!"

def tool_end_status_message(self, outputs):
return "Tool finished!"

def module_start_status_message(self, instance, inputs):
if isinstance(instance, dspy.Predict):
return "Predict starting!"

lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
with dspy.context(lm=lm):
program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider())
output = program("sky")
status_messages = []
async for value in output:
if isinstance(value, StatusMessage):
status_messages.append(value)
assert len(status_messages) == 3
assert status_messages[0].message == "Tool starting!"
assert status_messages[1].message == "Tool finished!"
assert status_messages[2].message == "Predict starting!"

program = dspy.streamify(MyProgram())
output = program("sky")
status_messages = []
async for value in output:
if isinstance(value, StatusMessage):
status_messages.append(value)
assert len(status_messages) == 2
assert status_messages[0].message == "Calling tool generate_question..."
assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."


@pytest.mark.llm_call
@pytest.mark.anyio
async def test_stream_listener_chat_adapter(lm_for_test):
Expand Down