Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,50 @@ on:

permissions:
contents: read
id-token: write

jobs:
test:
name: Test - ${{ matrix.os }} - py${{ matrix.python }}
runs-on: ${{ matrix.os }}
defaults:
run:
working-directory: ${{ env.WORKING_DIRECTORY }}
strategy:
fail-fast: false
matrix:
os: [ ubuntu-latest, macos-latest ]
python: [ '3.9', '3.13' ]
steps:
- uses: actions/checkout@v3

- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}

- name: Install Rust
uses: dtolnay/rust-toolchain@stable

- name: Rust cache + sccache
uses: ./.github/actions/setup-rust
with:
shared-key: python-test
sccache-bucket: ${{ secrets.SCCACHE_S3_BUCKET }}
sccache-region: ${{ secrets.SCCACHE_AWS_REGION }}
sccache-role-arn: ${{ secrets.SCCACHE_AWS_ROLE_ARN }}

- name: Build and install wheel
run: |
python -m pip install --upgrade pip maturin
maturin build --release --out dist -m Cargo.toml
pip install --force-reinstall dist/*.whl
shell: bash

- name: Run tests
run: python -m unittest test_sync test_async
shell: bash

# The build jobs skip on release events for other components (release-please fires
# one `release` event per component; only `python-v*` should build/publish here).
wheels:
Expand Down Expand Up @@ -94,7 +136,7 @@ jobs:
environment: release
# Publishes only on the release-please GitHub Release for the python component.
if: "github.ref_type == 'tag'"
needs: [ wheels, sdist ]
needs: [ test, wheels, sdist ]
steps:
- uses: actions/download-artifact@v4
with:
Expand Down
27 changes: 27 additions & 0 deletions bindings/python/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,30 @@ pub fn response_to_py(py: Python<'_>, response: DecisionGraphResponse) -> PyResu

Ok(dict.into_any().unbind())
}

pub fn batch_results_to_py(
py: Python<'_>,
outcomes: Vec<Result<PortableResponse, Value>>,
) -> PyResult<Py<PyAny>> {
let list = PyList::empty(py);

for outcome in outcomes {
let item = PyDict::new(py);
match outcome {
Ok(response) => {
item.set_item("success", true)?;
item.set_item("data", response.into_py(py)?)?;
item.set_item("error", py.None())?;
}
Err(error) => {
item.set_item("success", false)?;
item.set_item("data", py.None())?;
item.set_item("error", value_to_object(py, &error)?)?;
}
}

list.append(item)?;
}

Ok(list.into_any().unbind())
}
82 changes: 82 additions & 0 deletions bindings/python/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,88 @@ impl PyZenEngine {
Ok(result.unbind())
}

#[pyo3(signature = (requests, opts=None))]
pub fn evaluate_batch(
&self,
py: Python,
requests: Vec<(String, PyValue)>,
opts: Option<PyZenEvaluateOptions>,
) -> PyResult<Py<PyAny>> {
let options: EvaluationOptions = opts.unwrap_or_default().into();
let trace = options.trace;
let max_depth = options.max_depth;
let engine = self.engine.clone();

let outcomes = py.allow_threads(|| {
block_on(async move {
let mut handles = Vec::with_capacity(requests.len());
for (key, ctx) in requests {
let engine = engine.clone();
handles.push(worker_pool().spawn_pinned(move || async move {
let options = EvaluationOptions { trace, max_depth };
engine
.evaluate_with_opts(key, ctx.0.into(), options)
.await
.map(crate::convert::PortableResponse::build)
.map_err(|e| serde_json::to_value(e.as_ref()).unwrap_or_default())
}));
}

let mut outcomes = Vec::with_capacity(handles.len());
for handle in handles {
outcomes.push(match handle.await {
Ok(outcome) => outcome,
Err(_) => Err(Value::String("evaluation worker panicked".to_string())),
});
}

outcomes
})
});

crate::convert::batch_results_to_py(py, outcomes)
}

#[pyo3(signature = (requests, opts=None))]
pub fn async_evaluate_batch<'py>(
&'py self,
py: Python<'py>,
requests: Vec<(String, PyValue)>,
opts: Option<PyZenEvaluateOptions>,
) -> PyResult<Py<PyAny>> {
let options: EvaluationOptions = opts.unwrap_or_default().into();
let trace = options.trace;
let max_depth = options.max_depth;
let engine = self.engine.clone();

let result = tokio::future_into_py_with_locals(py, get_current_locals(py)?, async move {
let mut handles = Vec::with_capacity(requests.len());
for (key, ctx) in requests {
let engine = engine.clone();
handles.push(worker_pool().spawn_pinned(move || async move {
let options = EvaluationOptions { trace, max_depth };
engine
.evaluate_with_opts(key, ctx.0.into(), options)
.await
.map(crate::convert::PortableResponse::build)
.map_err(|e| serde_json::to_value(e.as_ref()).unwrap_or_default())
}));
}

let mut outcomes = Vec::with_capacity(handles.len());
for handle in handles {
outcomes.push(match handle.await {
Ok(outcome) => outcome,
Err(_) => Err(Value::String("evaluation worker panicked".to_string())),
});
}

Python::with_gil(|py| crate::convert::batch_results_to_py(py, outcomes))
})?;

Ok(result.unbind())
}

pub fn create_decision(&self, content: PyZenDecisionContentJson) -> PyResult<PyZenDecision> {
let decision = self
.engine
Expand Down
16 changes: 16 additions & 0 deletions bindings/python/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,22 @@ async def test_create_decisions_from_content(self):
r = await functionDecision.async_evaluate({"input": 15})
self.assertEqual(r["result"]["output"], 30)

async def test_async_evaluate_batch(self):
engine = zen.ZenEngine({"loader": loader})
results = await engine.async_evaluate_batch([
("table.json", {"input": 12}),
("table.json", {"input": 2}),
("does-not-exist.json", {}),
])

self.assertEqual(len(results), 3)
self.assertTrue(results[0]["success"])
self.assertEqual(results[0]["data"]["result"]["output"], 10)
self.assertTrue(results[1]["success"])
self.assertEqual(results[1]["data"]["result"]["output"], 0)
self.assertFalse(results[2]["success"])
self.assertIsNotNone(results[2]["error"])

async def test_evaluate_graphs(self):
engine = zen.ZenEngine({"loader": graph_loader})
json_files = glob.glob("../../test-data/graphs/*.json")
Expand Down
16 changes: 16 additions & 0 deletions bindings/python/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,22 @@ def test_create_decisions_from_content(self):
r = functionDecision.evaluate({"input": 15})
self.assertEqual(r["result"]["output"], 30)

def test_evaluate_batch(self):
engine = zen.ZenEngine({"loader": loader})
results = engine.evaluate_batch([
("table.json", {"input": 12}),
("table.json", {"input": 2}),
("does-not-exist.json", {}),
])

self.assertEqual(len(results), 3)
self.assertTrue(results[0]["success"])
self.assertEqual(results[0]["data"]["result"]["output"], 10)
self.assertTrue(results[1]["success"])
self.assertEqual(results[1]["data"]["result"]["output"], 0)
self.assertFalse(results[2]["success"])
self.assertIsNotNone(results[2]["error"])

def test_engine_custom_handler(self):
engine = zen.ZenEngine({"loader": loader, "customHandler": custom_handler})
r1 = engine.evaluate("custom.json", {"a": 10})
Expand Down
14 changes: 14 additions & 0 deletions bindings/python/zen.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,15 @@ class EvaluateResponse(TypedDict):
trace: dict


class BatchEvaluateResult(TypedDict):
success: bool
data: Optional[EvaluateResponse]
error: Optional[Any]


ZenContext: TypeAlias = Union[str, bytes, dict]
ZenDecisionContentInput: TypeAlias = Union[str, ZenDecisionContent]
ZenBatchRequest: TypeAlias = "tuple[str, ZenContext]"


class ZenEngine:
Expand All @@ -26,6 +33,13 @@ class ZenEngine:
def async_evaluate(self, key: str, context: ZenContext, options: Optional[DecisionEvaluateOptions] = None) -> \
Awaitable[EvaluateResponse]: ...

def evaluate_batch(self, requests: "list[ZenBatchRequest]",
options: Optional[DecisionEvaluateOptions] = None) -> "list[BatchEvaluateResult]": ...

def async_evaluate_batch(self, requests: "list[ZenBatchRequest]",
options: Optional[DecisionEvaluateOptions] = None) -> \
Awaitable["list[BatchEvaluateResult]"]: ...

def create_decision(self, content: ZenDecisionContentInput) -> ZenDecision: ...

def get_decision(self, key: str) -> ZenDecision: ...
Expand Down
Loading