Skip to content

Commit 3a98fed

Browse files
committed
pyo3: support module prefix + naming
1 parent a3bb997 commit 3a98fed

File tree

4 files changed

+71
-6
lines changed

4 files changed

+71
-6
lines changed

extensions/pyo3/private/pyo3.bzl

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,19 @@ def _py_pyo3_library_impl(ctx):
8787
is_windows = extension.basename.endswith(".dll")
8888

8989
# https://pyo3.rs/v0.26.0/building-and-distribution#manual-builds
90-
ext = ctx.actions.declare_file("{}{}".format(
91-
ctx.label.name,
92-
".pyd" if is_windows else ".so",
93-
))
90+
# Determine the on-disk and logical Python module layout.
91+
module_name = ctx.attr.module if ctx.attr.module else ctx.label.name
92+
93+
# Convert a dotted prefix (e.g. "foo.bar") into a path ("foo/bar").
94+
if ctx.attr.module_prefix:
95+
module_prefix_path = ctx.attr.module_prefix.replace(".", "/")
96+
module_relpath = "{}/{}.{}".format(module_prefix_path, module_name, "pyd" if is_windows else "so")
97+
stub_relpath = "{}/{}.pyi".format(module_prefix_path, module_name)
98+
else:
99+
module_relpath = "{}.{}".format(module_name, "pyd" if is_windows else "so")
100+
stub_relpath = "{}.pyi".format(module_name)
101+
102+
ext = ctx.actions.declare_file(module_relpath)
94103
ctx.actions.symlink(
95104
output = ext,
96105
target_file = extension,
@@ -99,10 +108,10 @@ def _py_pyo3_library_impl(ctx):
99108

100109
stub = None
101110
if _stubs_enabled(ctx.attr.stubs, toolchain):
102-
stub = ctx.actions.declare_file("{}.pyi".format(ctx.label.name))
111+
stub = ctx.actions.declare_file(stub_relpath)
103112

104113
args = ctx.actions.args()
105-
args.add(ctx.label.name, format = "--module_name=%s")
114+
args.add(module_name, format = "--module_name=%s")
106115
args.add(ext, format = "--module_path=%s")
107116
args.add(stub, format = "--output=%s")
108117
ctx.actions.run(
@@ -180,6 +189,12 @@ py_pyo3_library = rule(
180189
"imports": attr.string_list(
181190
doc = "List of import directories to be added to the `PYTHONPATH`.",
182191
),
192+
"module": attr.string(
193+
doc = "The Python module name implemented by this extension.",
194+
),
195+
"module_prefix": attr.string(
196+
doc = "A dotted Python package prefix for the module.",
197+
),
183198
"stubs": attr.int(
184199
doc = "Whether or not to generate stubs. `-1` will default to the global config, `0` will never generate, and `1` will always generate stubs.",
185200
default = -1,
@@ -218,6 +233,8 @@ def pyo3_extension(
218233
stubs = None,
219234
version = None,
220235
compilation_mode = "opt",
236+
module = None,
237+
module_prefix = None,
221238
**kwargs):
222239
"""Define a PyO3 python extension module.
223240
@@ -259,6 +276,8 @@ def pyo3_extension(
259276
For more details see [rust_shared_library][rsl].
260277
compilation_mode (str, optional): The [compilation_mode](https://bazel.build/reference/command-line-reference#flag--compilation_mode)
261278
value to build the extension for. If set to `"current"`, the current configuration will be used.
279+
module (str, optional): The Python module name implemented by this extension.
280+
module_prefix (str, optional): A dotted Python package prefix for the module.
262281
**kwargs (dict): Additional keyword arguments.
263282
"""
264283
tags = kwargs.pop("tags", [])
@@ -318,6 +337,8 @@ def pyo3_extension(
318337
compilation_mode = compilation_mode,
319338
stubs = stubs_int,
320339
imports = imports,
340+
module = module,
341+
module_prefix = module_prefix,
321342
tags = tags,
322343
visibility = visibility,
323344
**kwargs
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
load("@rules_python//python:defs.bzl", "py_test")
2+
load("//:defs.bzl", "pyo3_extension")
3+
4+
pyo3_extension(
5+
name = "module_prefix",
6+
srcs = ["bar.rs"],
7+
edition = "2021",
8+
imports = ["."],
9+
module = "bar",
10+
module_prefix = "foo",
11+
)
12+
13+
py_test(
14+
name = "module_prefix_import_test",
15+
srcs = ["module_prefix_import_test.py"],
16+
deps = [":module_prefix"],
17+
)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
use pyo3::prelude::*;
2+
3+
#[pyfunction]
4+
fn thing() -> PyResult<&'static str> {
5+
Ok("hello from rust")
6+
}
7+
8+
#[pymodule]
9+
fn bar(m: &Bound<'_, PyModule>) -> PyResult<()> {
10+
m.add_function(wrap_pyfunction!(thing, m)?)?;
11+
Ok(())
12+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Tests that a pyo3 extension can be imported via a module prefix."""
2+
3+
import unittest
4+
5+
import foo.bar
6+
7+
8+
class ModulePrefixImportTest(unittest.TestCase):
9+
def test_import_and_call(self) -> None:
10+
result = foo.bar.thing()
11+
self.assertEqual("hello from rust", result)
12+
13+
14+
if __name__ == "__main__":
15+
unittest.main()

0 commit comments

Comments
 (0)