Skip to content

Commit 901ddcf

Browse files
committed
Tests for streamify status message settings leak
1 parent eb29c47 commit 901ddcf

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed

tests/streaming/test_streaming.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,100 @@ def module_start_status_message(self, instance, inputs):
134134
assert status_messages[2].message == "Predict starting!"
135135

136136

137+
@pytest.mark.anyio
async def test_default_then_custom_status_message_provider():
    """Regression test: default status messages must not leak into a later
    ``dspy.streamify`` call that supplies a custom ``StatusMessageProvider``.

    Runs the same program twice — first with the default provider, then with a
    custom one — and checks that the second run emits only the custom messages.
    """

    class MyProgram(dspy.Module):
        def __init__(self):
            self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
            self.predict = dspy.Predict("question->answer")

        def __call__(self, x: str):
            question = self.generate_question(x=x)
            return self.predict(question=question)

    class MyStatusMessageProvider(StatusMessageProvider):
        def tool_start_status_message(self, instance, inputs):
            return "Tool starting!"

        def tool_end_status_message(self, outputs):
            return "Tool finished!"

        def module_start_status_message(self, instance, inputs):
            # Only the Predict submodule announces its start.
            if isinstance(instance, dspy.Predict):
                return "Predict starting!"

    lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
    with dspy.context(lm=lm):
        # First run: default provider — expect the built-in tool messages.
        program = dspy.streamify(MyProgram())
        output = program("sky")

        status_messages = []
        async for value in output:
            if isinstance(value, StatusMessage):
                status_messages.append(value)

        assert len(status_messages) == 2
        assert status_messages[0].message == "Calling tool generate_question..."
        assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."

        # Second run: custom provider — its messages must fully replace the
        # defaults (no leakage from the previous streamify call).
        program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider())
        output = program("sky")
        status_messages = []
        async for value in output:
            if isinstance(value, StatusMessage):
                status_messages.append(value)
        assert len(status_messages) == 3
        assert status_messages[0].message == "Tool starting!"
        assert status_messages[1].message == "Tool finished!"
        assert status_messages[2].message == "Predict starting!"
184+
185+
@pytest.mark.anyio
async def test_custom_then_default_status_message_provider():
    """Regression test: a custom ``StatusMessageProvider`` must not leak into a
    later ``dspy.streamify`` call that uses the default provider.

    Mirror of the default-then-custom ordering: custom provider first, then the
    default, asserting the second run emits only the built-in messages.
    """

    class MyProgram(dspy.Module):
        def __init__(self):
            self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
            self.predict = dspy.Predict("question->answer")

        def __call__(self, x: str):
            question = self.generate_question(x=x)
            return self.predict(question=question)

    class MyStatusMessageProvider(StatusMessageProvider):
        def tool_start_status_message(self, instance, inputs):
            return "Tool starting!"

        def tool_end_status_message(self, outputs):
            return "Tool finished!"

        def module_start_status_message(self, instance, inputs):
            # Only the Predict submodule announces its start.
            if isinstance(instance, dspy.Predict):
                return "Predict starting!"

    lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}])
    with dspy.context(lm=lm):
        # First run: custom provider — expect the three custom messages.
        program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider())
        output = program("sky")
        status_messages = []
        async for value in output:
            if isinstance(value, StatusMessage):
                status_messages.append(value)
        assert len(status_messages) == 3
        assert status_messages[0].message == "Tool starting!"
        assert status_messages[1].message == "Tool finished!"
        assert status_messages[2].message == "Predict starting!"

        # Second run: default provider — the custom messages must not leak in.
        program = dspy.streamify(MyProgram())
        output = program("sky")
        status_messages = []
        async for value in output:
            if isinstance(value, StatusMessage):
                status_messages.append(value)
        assert len(status_messages) == 2
        assert status_messages[0].message == "Calling tool generate_question..."
        assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."
137231
@pytest.mark.llm_call
138232
@pytest.mark.anyio
139233
async def test_stream_listener_chat_adapter(lm_for_test):

0 commit comments

Comments
 (0)