2 changes: 1 addition & 1 deletion dspy/streaming/streamify.py
@@ -161,7 +161,7 @@ async def use_streaming():
    elif not iscoroutinefunction(program):
        program = asyncify(program)

-   callbacks = settings.callbacks
+   callbacks = list(settings.callbacks)
    status_streaming_callback = StatusStreamingCallback(status_message_provider)
    if not any(isinstance(c, StatusStreamingCallback) for c in callbacks):
        callbacks.append(status_streaming_callback)
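Note: the change above copies settings.callbacks into a new list before the per-call StatusStreamingCallback is appended, so the append no longer mutates the shared settings. A minimal sketch of the aliasing behavior this avoids, in plain Python (illustrative only, not DSPy's actual internals):

# Appending to an aliased list leaks into shared state; appending to a copy does not.
shared_callbacks = []  # stands in for settings.callbacks

def register_aliased():
    callbacks = shared_callbacks          # aliases the shared list
    callbacks.append("status_callback")   # side effect: shared state is mutated
    return callbacks

def register_copied():
    callbacks = list(shared_callbacks)    # independent copy; shared state untouched
    callbacks.append("status_callback")
    return callbacks

register_aliased()
assert shared_callbacks == ["status_callback"]   # the shared list was mutated

shared_callbacks.clear()
register_copied()
assert shared_callbacks == []                    # copying leaves the shared list as-is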
94 changes: 94 additions & 0 deletions tests/streaming/test_streaming.py
@@ -134,6 +134,100 @@ def module_start_status_message(self, instance, inputs):
        assert status_messages[2].message == "Predict starting!"


@pytest.mark.anyio
async def test_default_then_custom_status_message_provider():
    class MyProgram(dspy.Module):
        def __init__(self):
            self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
            self.predict = dspy.Predict("question->answer")

        def __call__(self, x: str):
            question = self.generate_question(x=x)
            return self.predict(question=question)

    class MyStatusMessageProvider(StatusMessageProvider):
        def tool_start_status_message(self, instance, inputs):
            return "Tool starting!"

        def tool_end_status_message(self, outputs):
            return "Tool finished!"

        def module_start_status_message(self, instance, inputs):
            if isinstance(instance, dspy.Predict):
                return "Predict starting!"

    lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
    with dspy.context(lm=lm):
        program = dspy.streamify(MyProgram())
        output = program("sky")

        status_messages = []
        async for value in output:
            if isinstance(value, StatusMessage):
                status_messages.append(value)

        assert len(status_messages) == 2
        assert status_messages[0].message == "Calling tool generate_question..."
        assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."

        program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider())
        output = program("sky")
        status_messages = []
        async for value in output:
            if isinstance(value, StatusMessage):
                status_messages.append(value)
        assert len(status_messages) == 3
        assert status_messages[0].message == "Tool starting!"
        assert status_messages[1].message == "Tool finished!"
        assert status_messages[2].message == "Predict starting!"


@pytest.mark.anyio
async def test_custom_then_default_status_message_provider():
    class MyProgram(dspy.Module):
        def __init__(self):
            self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
            self.predict = dspy.Predict("question->answer")

        def __call__(self, x: str):
            question = self.generate_question(x=x)
            return self.predict(question=question)

    class MyStatusMessageProvider(StatusMessageProvider):
        def tool_start_status_message(self, instance, inputs):
            return "Tool starting!"

        def tool_end_status_message(self, outputs):
            return "Tool finished!"

        def module_start_status_message(self, instance, inputs):
            if isinstance(instance, dspy.Predict):
                return "Predict starting!"

    lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
    with dspy.context(lm=lm):
        program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider())
        output = program("sky")
        status_messages = []
        async for value in output:
            if isinstance(value, StatusMessage):
                status_messages.append(value)
        assert len(status_messages) == 3
        assert status_messages[0].message == "Tool starting!"
        assert status_messages[1].message == "Tool finished!"
        assert status_messages[2].message == "Predict starting!"

        program = dspy.streamify(MyProgram())
        output = program("sky")
        status_messages = []
        async for value in output:
            if isinstance(value, StatusMessage):
                status_messages.append(value)
        assert len(status_messages) == 2
        assert status_messages[0].message == "Calling tool generate_question..."
        assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."


@pytest.mark.llm_call
@pytest.mark.anyio
async def test_stream_listener_chat_adapter(lm_for_test):
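The two new tests exercise both orderings (default provider first, then custom, and the reverse) so that a later streamify call's status_message_provider is not shadowed by a StatusStreamingCallback left behind in settings.callbacks by an earlier call. A condensed, hypothetical repro of that failure mode, with simplified names rather than the real dspy internals:

class StatusStreamingCallback:
    def __init__(self, provider):
        self.provider = provider

settings_callbacks = []  # stands in for the global settings.callbacks

def build_callbacks(provider, copy_first):
    # Mirrors the guarded registration in streamify (simplified).
    callbacks = list(settings_callbacks) if copy_first else settings_callbacks
    if not any(isinstance(c, StatusStreamingCallback) for c in callbacks):
        callbacks.append(StatusStreamingCallback(provider))
    return callbacks

# Without the copy, the first call's callback leaks into the shared list,
# so the second call's custom provider is silently dropped.
build_callbacks("default", copy_first=False)
second = build_callbacks("custom", copy_first=False)
assert all(c.provider == "default" for c in second)

# With the copy, each call registers its own provider against a fresh list.
settings_callbacks.clear()
build_callbacks("default", copy_first=True)
second = build_callbacks("custom", copy_first=True)
assert any(c.provider == "custom" for c in second)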