diff --git a/dspy/utils/asyncify.py b/dspy/utils/asyncify.py index 842746e49f..3ff5ae4741 100644 --- a/dspy/utils/asyncify.py +++ b/dspy/utils/asyncify.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Awaitable, Callable +from typing import TYPE_CHECKING, Any, Awaitable, Callable, ParamSpec, TypeVar, Union, overload import asyncer from anyio import CapacityLimiter @@ -6,6 +6,9 @@ if TYPE_CHECKING: from dspy.primitives.module import Module +P = ParamSpec("P") +T = TypeVar("T") + _limiter = None @@ -27,22 +30,30 @@ def get_limiter(): return _limiter -def asyncify(program: "Module") -> Callable[[Any, Any], Awaitable[Any]]: +@overload +def asyncify(program: Callable[P, T]) -> Callable[P, Awaitable[T]]: ... + + +@overload +def asyncify(program: "Module") -> Callable[..., Awaitable[Any]]: ... + + +def asyncify(program: Union[Callable[P, T], "Module"]) -> Callable[P, Awaitable[T]] | Callable[..., Awaitable[Any]]: """ - Wraps a DSPy program so that it can be called asynchronously. This is useful for running a + Wraps a DSPy program or callable so that it can be called asynchronously. This is useful for running a program in parallel with another task (e.g., another DSPy program). This implementation propagates the current thread's configuration context to the worker thread. Args: - program: The DSPy program to be wrapped for asynchronous execution. + program: The DSPy program or callable to be wrapped for asynchronous execution. Returns: An async function: An async function that, when awaited, runs the program in a worker thread. The current thread's configuration context is inherited for each call. """ - async def async_program(*args, **kwargs) -> Any: + async def async_program(*args: P.args, **kwargs: P.kwargs) -> T: # Capture the current overrides at call-time. from dspy.dsp.utils.settings import thread_local_overrides @@ -62,4 +73,4 @@ def wrapped_program(*a, **kw): call_async = asyncer.asyncify(wrapped_program, abandon_on_cancel=True, limiter=get_limiter()) return await call_async(*args, **kwargs) - return async_program + return async_program # type: ignore[return-value] diff --git a/tests/utils/test_asyncify.py b/tests/utils/test_asyncify.py index d34209ad6e..4049e8c996 100644 --- a/tests/utils/test_asyncify.py +++ b/tests/utils/test_asyncify.py @@ -50,3 +50,24 @@ async def verify_asyncify(capacity: int, number_of_tasks: int, wait: float = 0.5 await verify_asyncify(4, 10) await verify_asyncify(8, 15) await verify_asyncify(8, 30) + + +@pytest.mark.anyio +async def test_asyncify_with_dspy_module(): + """Test that asyncify works with DSPy modules and can be type-checked.""" + + class SimpleModule(dspy.Module): + def forward(self, x: int) -> int: + return x * 2 + + module = SimpleModule() + async_module = dspy.asyncify(module) + + # Test with positional argument + result = await async_module(5) + assert result == 10, "Asyncified module should return correct result" + + # Test with keyword argument + result = await async_module(x=7) + assert result == 14, "Asyncified module should work with keyword arguments" +