@@ -134,6 +134,100 @@ def module_start_status_message(self, instance, inputs):
134134 assert status_messages [2 ].message == "Predict starting!"
135135
136136
137+ @pytest .mark .anyio
138+ async def test_default_then_custom_status_message_provider ():
139+ class MyProgram (dspy .Module ):
140+ def __init__ (self ):
141+ self .generate_question = dspy .Tool (lambda x : f"What color is the { x } ?" , name = "generate_question" )
142+ self .predict = dspy .Predict ("question->answer" )
143+
144+ def __call__ (self , x : str ):
145+ question = self .generate_question (x = x )
146+ return self .predict (question = question )
147+
148+ class MyStatusMessageProvider (StatusMessageProvider ):
149+ def tool_start_status_message (self , instance , inputs ):
150+ return "Tool starting!"
151+
152+ def tool_end_status_message (self , outputs ):
153+ return "Tool finished!"
154+
155+ def module_start_status_message (self , instance , inputs ):
156+ if isinstance (instance , dspy .Predict ):
157+ return "Predict starting!"
158+
159+ lm = dspy .utils .DummyLM ([{"answer" : "red" }, {"answer" : "blue" }])
160+ with dspy .context (lm = lm ):
161+ program = dspy .streamify (MyProgram ())
162+ output = program ("sky" )
163+
164+ status_messages = []
165+ async for value in output :
166+ if isinstance (value , StatusMessage ):
167+ status_messages .append (value )
168+
169+ assert len (status_messages ) == 2
170+ assert status_messages [0 ].message == "Calling tool generate_question..."
171+ assert status_messages [1 ].message == "Tool calling finished! Querying the LLM with tool calling results..."
172+
173+ program = dspy .streamify (MyProgram (), status_message_provider = MyStatusMessageProvider ())
174+ output = program ("sky" )
175+ status_messages = []
176+ async for value in output :
177+ if isinstance (value , StatusMessage ):
178+ status_messages .append (value )
179+ assert len (status_messages ) == 3
180+ assert status_messages [0 ].message == "Tool starting!"
181+ assert status_messages [1 ].message == "Tool finished!"
182+ assert status_messages [2 ].message == "Predict starting!"
183+
184+
185+ @pytest .mark .anyio
186+ async def test_custom_then_default_status_message_provider ():
187+ class MyProgram (dspy .Module ):
188+ def __init__ (self ):
189+ self .generate_question = dspy .Tool (lambda x : f"What color is the { x } ?" , name = "generate_question" )
190+ self .predict = dspy .Predict ("question->answer" )
191+
192+ def __call__ (self , x : str ):
193+ question = self .generate_question (x = x )
194+ return self .predict (question = question )
195+
196+ class MyStatusMessageProvider (StatusMessageProvider ):
197+ def tool_start_status_message (self , instance , inputs ):
198+ return "Tool starting!"
199+
200+ def tool_end_status_message (self , outputs ):
201+ return "Tool finished!"
202+
203+ def module_start_status_message (self , instance , inputs ):
204+ if isinstance (instance , dspy .Predict ):
205+ return "Predict starting!"
206+
207+ lm = dspy .utils .DummyLM ([{"answer" : "red" }, {"answer" : "blue" }])
208+ with dspy .context (lm = lm ):
209+ program = dspy .streamify (MyProgram (), status_message_provider = MyStatusMessageProvider ())
210+ output = program ("sky" )
211+ status_messages = []
212+ async for value in output :
213+ if isinstance (value , StatusMessage ):
214+ status_messages .append (value )
215+ assert len (status_messages ) == 3
216+ assert status_messages [0 ].message == "Tool starting!"
217+ assert status_messages [1 ].message == "Tool finished!"
218+ assert status_messages [2 ].message == "Predict starting!"
219+
220+ program = dspy .streamify (MyProgram ())
221+ output = program ("sky" )
222+ status_messages = []
223+ async for value in output :
224+ if isinstance (value , StatusMessage ):
225+ status_messages .append (value )
226+ assert len (status_messages ) == 2
227+ assert status_messages [0 ].message == "Calling tool generate_question..."
228+ assert status_messages [1 ].message == "Tool calling finished! Querying the LLM with tool calling results..."
229+
230+
137231@pytest .mark .llm_call
138232@pytest .mark .anyio
139233async def test_stream_listener_chat_adapter (lm_for_test ):
0 commit comments