Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions dspy/utils/asyncify.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import TYPE_CHECKING, Any, Awaitable, Callable
from typing import TYPE_CHECKING, Any, Awaitable, Callable, ParamSpec, TypeVar, Union, overload

import asyncer
from anyio import CapacityLimiter

if TYPE_CHECKING:
from dspy.primitives.module import Module

P = ParamSpec("P")
T = TypeVar("T")

_limiter = None


Expand All @@ -27,22 +30,30 @@ def get_limiter():
return _limiter


def asyncify(program: "Module") -> Callable[[Any, Any], Awaitable[Any]]:
@overload
def asyncify(program: Callable[P, T]) -> Callable[P, Awaitable[T]]: ...


@overload
def asyncify(program: "Module") -> Callable[..., Awaitable[Any]]: ...


def asyncify(program: Union[Callable[P, T], "Module"]) -> Callable[P, Awaitable[T]] | Callable[..., Awaitable[Any]]:
"""
Wraps a DSPy program so that it can be called asynchronously. This is useful for running a
Wraps a DSPy program or callable so that it can be called asynchronously. This is useful for running a
program in parallel with another task (e.g., another DSPy program).

This implementation propagates the current thread's configuration context to the worker thread.

Args:
program: The DSPy program to be wrapped for asynchronous execution.
program: The DSPy program or callable to be wrapped for asynchronous execution.

Returns:
An async function: An async function that, when awaited, runs the program in a worker thread.
The current thread's configuration context is inherited for each call.
"""

async def async_program(*args, **kwargs) -> Any:
async def async_program(*args: P.args, **kwargs: P.kwargs) -> T:
# Capture the current overrides at call-time.
from dspy.dsp.utils.settings import thread_local_overrides

Expand All @@ -62,4 +73,4 @@ def wrapped_program(*a, **kw):
call_async = asyncer.asyncify(wrapped_program, abandon_on_cancel=True, limiter=get_limiter())
return await call_async(*args, **kwargs)

return async_program
return async_program # type: ignore[return-value]
21 changes: 21 additions & 0 deletions tests/utils/test_asyncify.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,24 @@ async def verify_asyncify(capacity: int, number_of_tasks: int, wait: float = 0.5
await verify_asyncify(4, 10)
await verify_asyncify(8, 15)
await verify_asyncify(8, 30)


@pytest.mark.anyio
async def test_asyncify_with_dspy_module():
"""Test that asyncify works with DSPy modules and can be type-checked."""

class SimpleModule(dspy.Module):
def forward(self, x: int) -> int:
return x * 2

module = SimpleModule()
async_module = dspy.asyncify(module)

# Test with positional argument
result = await async_module(5)
assert result == 10, "Asyncified module should return correct result"

# Test with keyword argument
result = await async_module(x=7)
assert result == 14, "Asyncified module should work with keyword arguments"