diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index f448b413..8e7822f3 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -15,8 +15,50 @@ on: permissions: contents: read + id-token: write jobs: + test: + name: Test - ${{ matrix.os }} - py${{ matrix.python }} + runs-on: ${{ matrix.os }} + defaults: + run: + working-directory: ${{ env.WORKING_DIRECTORY }} + strategy: + fail-fast: false + matrix: + os: [ ubuntu-latest, macos-latest ] + python: [ '3.9', '3.13' ] + steps: + - uses: actions/checkout@v3 + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python }} + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Rust cache + sccache + uses: ./.github/actions/setup-rust + with: + shared-key: python-test + sccache-bucket: ${{ secrets.SCCACHE_S3_BUCKET }} + sccache-region: ${{ secrets.SCCACHE_AWS_REGION }} + sccache-role-arn: ${{ secrets.SCCACHE_AWS_ROLE_ARN }} + + - name: Build and install wheel + run: | + python -m pip install --upgrade pip maturin + maturin build --release --out dist -m Cargo.toml + pip install --force-reinstall dist/*.whl + shell: bash + + - name: Run tests + run: python -m unittest test_sync test_async + shell: bash + # The build jobs skip on release events for other components (release-please fires # one `release` event per component; only `python-v*` should build/publish here). wheels: @@ -94,7 +136,7 @@ jobs: environment: release # Publishes only on the release-please GitHub Release for the python component. if: "github.ref_type == 'tag'" - needs: [ wheels, sdist ] + needs: [ test, wheels, sdist ] steps: - uses: actions/download-artifact@v4 with: diff --git a/bindings/python/src/convert.rs b/bindings/python/src/convert.rs index 1c345f1d..4b4cd21d 100644 --- a/bindings/python/src/convert.rs +++ b/bindings/python/src/convert.rs @@ -293,3 +293,30 @@ pub fn response_to_py(py: Python<'_>, response: DecisionGraphResponse) -> PyResu Ok(dict.into_any().unbind()) } + +pub fn batch_results_to_py( + py: Python<'_>, + outcomes: Vec>, +) -> PyResult> { + let list = PyList::empty(py); + + for outcome in outcomes { + let item = PyDict::new(py); + match outcome { + Ok(response) => { + item.set_item("success", true)?; + item.set_item("data", response.into_py(py)?)?; + item.set_item("error", py.None())?; + } + Err(error) => { + item.set_item("success", false)?; + item.set_item("data", py.None())?; + item.set_item("error", value_to_object(py, &error)?)?; + } + } + + list.append(item)?; + } + + Ok(list.into_any().unbind()) +} diff --git a/bindings/python/src/engine.rs b/bindings/python/src/engine.rs index 07f116ba..80ebfb3a 100644 --- a/bindings/python/src/engine.rs +++ b/bindings/python/src/engine.rs @@ -165,6 +165,88 @@ impl PyZenEngine { Ok(result.unbind()) } + #[pyo3(signature = (requests, opts=None))] + pub fn evaluate_batch( + &self, + py: Python, + requests: Vec<(String, PyValue)>, + opts: Option, + ) -> PyResult> { + let options: EvaluationOptions = opts.unwrap_or_default().into(); + let trace = options.trace; + let max_depth = options.max_depth; + let engine = self.engine.clone(); + + let outcomes = py.allow_threads(|| { + block_on(async move { + let mut handles = Vec::with_capacity(requests.len()); + for (key, ctx) in requests { + let engine = engine.clone(); + handles.push(worker_pool().spawn_pinned(move || async move { + let options = EvaluationOptions { trace, max_depth }; + engine + .evaluate_with_opts(key, ctx.0.into(), options) + .await + .map(crate::convert::PortableResponse::build) + .map_err(|e| serde_json::to_value(e.as_ref()).unwrap_or_default()) + })); + } + + let mut outcomes = Vec::with_capacity(handles.len()); + for handle in handles { + outcomes.push(match handle.await { + Ok(outcome) => outcome, + Err(_) => Err(Value::String("evaluation worker panicked".to_string())), + }); + } + + outcomes + }) + }); + + crate::convert::batch_results_to_py(py, outcomes) + } + + #[pyo3(signature = (requests, opts=None))] + pub fn async_evaluate_batch<'py>( + &'py self, + py: Python<'py>, + requests: Vec<(String, PyValue)>, + opts: Option, + ) -> PyResult> { + let options: EvaluationOptions = opts.unwrap_or_default().into(); + let trace = options.trace; + let max_depth = options.max_depth; + let engine = self.engine.clone(); + + let result = tokio::future_into_py_with_locals(py, get_current_locals(py)?, async move { + let mut handles = Vec::with_capacity(requests.len()); + for (key, ctx) in requests { + let engine = engine.clone(); + handles.push(worker_pool().spawn_pinned(move || async move { + let options = EvaluationOptions { trace, max_depth }; + engine + .evaluate_with_opts(key, ctx.0.into(), options) + .await + .map(crate::convert::PortableResponse::build) + .map_err(|e| serde_json::to_value(e.as_ref()).unwrap_or_default()) + })); + } + + let mut outcomes = Vec::with_capacity(handles.len()); + for handle in handles { + outcomes.push(match handle.await { + Ok(outcome) => outcome, + Err(_) => Err(Value::String("evaluation worker panicked".to_string())), + }); + } + + Python::with_gil(|py| crate::convert::batch_results_to_py(py, outcomes)) + })?; + + Ok(result.unbind()) + } + pub fn create_decision(&self, content: PyZenDecisionContentJson) -> PyResult { let decision = self .engine diff --git a/bindings/python/test_async.py b/bindings/python/test_async.py index 0efad8e7..0a2106c4 100644 --- a/bindings/python/test_async.py +++ b/bindings/python/test_async.py @@ -79,6 +79,22 @@ async def test_create_decisions_from_content(self): r = await functionDecision.async_evaluate({"input": 15}) self.assertEqual(r["result"]["output"], 30) + async def test_async_evaluate_batch(self): + engine = zen.ZenEngine({"loader": loader}) + results = await engine.async_evaluate_batch([ + ("table.json", {"input": 12}), + ("table.json", {"input": 2}), + ("does-not-exist.json", {}), + ]) + + self.assertEqual(len(results), 3) + self.assertTrue(results[0]["success"]) + self.assertEqual(results[0]["data"]["result"]["output"], 10) + self.assertTrue(results[1]["success"]) + self.assertEqual(results[1]["data"]["result"]["output"], 0) + self.assertFalse(results[2]["success"]) + self.assertIsNotNone(results[2]["error"]) + async def test_evaluate_graphs(self): engine = zen.ZenEngine({"loader": graph_loader}) json_files = glob.glob("../../test-data/graphs/*.json") diff --git a/bindings/python/test_sync.py b/bindings/python/test_sync.py index 8472d6df..922f22c0 100644 --- a/bindings/python/test_sync.py +++ b/bindings/python/test_sync.py @@ -58,6 +58,22 @@ def test_create_decisions_from_content(self): r = functionDecision.evaluate({"input": 15}) self.assertEqual(r["result"]["output"], 30) + def test_evaluate_batch(self): + engine = zen.ZenEngine({"loader": loader}) + results = engine.evaluate_batch([ + ("table.json", {"input": 12}), + ("table.json", {"input": 2}), + ("does-not-exist.json", {}), + ]) + + self.assertEqual(len(results), 3) + self.assertTrue(results[0]["success"]) + self.assertEqual(results[0]["data"]["result"]["output"], 10) + self.assertTrue(results[1]["success"]) + self.assertEqual(results[1]["data"]["result"]["output"], 0) + self.assertFalse(results[2]["success"]) + self.assertIsNotNone(results[2]["error"]) + def test_engine_custom_handler(self): engine = zen.ZenEngine({"loader": loader, "customHandler": custom_handler}) r1 = engine.evaluate("custom.json", {"a": 10}) diff --git a/bindings/python/zen.pyi b/bindings/python/zen.pyi index ecf448a2..c5a307be 100644 --- a/bindings/python/zen.pyi +++ b/bindings/python/zen.pyi @@ -13,8 +13,15 @@ class EvaluateResponse(TypedDict): trace: dict +class BatchEvaluateResult(TypedDict): + success: bool + data: Optional[EvaluateResponse] + error: Optional[Any] + + ZenContext: TypeAlias = Union[str, bytes, dict] ZenDecisionContentInput: TypeAlias = Union[str, ZenDecisionContent] +ZenBatchRequest: TypeAlias = "tuple[str, ZenContext]" class ZenEngine: @@ -26,6 +33,13 @@ class ZenEngine: def async_evaluate(self, key: str, context: ZenContext, options: Optional[DecisionEvaluateOptions] = None) -> \ Awaitable[EvaluateResponse]: ... + def evaluate_batch(self, requests: "list[ZenBatchRequest]", + options: Optional[DecisionEvaluateOptions] = None) -> "list[BatchEvaluateResult]": ... + + def async_evaluate_batch(self, requests: "list[ZenBatchRequest]", + options: Optional[DecisionEvaluateOptions] = None) -> \ + Awaitable["list[BatchEvaluateResult]"]: ... + def create_decision(self, content: ZenDecisionContentInput) -> ZenDecision: ... def get_decision(self, key: str) -> ZenDecision: ...