Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ classifiers = [
"Programming Language :: Python :: 3.13",
"Typing :: Typed",
]
dependencies = ["witty>=v0.2.1", "CT3>=3.3.3", "numpy", "setuptools>=75.8.0"]
dependencies = [
"witty>=0.3.1",
"CT3>=3.3.3",
"numpy",
"setuptools>=75.8.0",
]

[dependency-groups]
test = ["pytest>=8.3.5", "pytest-cov>=6.1.1"]
Expand Down
2 changes: 1 addition & 1 deletion src/spatial_graph/_graph/graph_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _compile_graph(
edge_attr_dtypes=edge_attr_dtypes,
directed=directed,
)
wrapper = witty.compile_module(
wrapper = witty.compile_cython(
wrapper_template,
source_files=[str(SRC_DIR / "src" / "graph_lite.h")],
extra_compile_args=EXTRA_COMPILE_ARGS,
Expand Down
56 changes: 30 additions & 26 deletions src/spatial_graph/_rtree/line_rtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,39 @@


class LineRTree(RTree):
pyx_item_t_declaration = """
cdef struct item_t:
item_base_t u
item_base_t v
bool corner_mask[DIMS]
"""

c_item_t_declaration = """
typedef struct item_t {
item_base_t u;
item_base_t v;
item_data_base_t u;
item_data_base_t v;
bool corner_mask[DIMS];
} item_t;
"""

c_converter_functions = """
inline item_t convert_pyx_to_c_item(pyx_item_t *pyx_item,
coord_t *start, coord_t *end) {
inline void item_to_item_data(
const item_t& item,
item_data_t *item_data) {

(*item_data)[0] = item.u;
(*item_data)[1] = item.v;
}
inline item_t item_data_to_item(
item_data_base_t *item_data,
coord_t *start,
coord_t *end) {

item_t item;
coord_t tmp;
item.u = (*pyx_item)[0];
item.v = (*pyx_item)[1];
for (int d = 0; d < DIMS; d++) {
item.u = item_data[0];
item.v = item_data[1];
for (unsigned int d = 0; d < DIMS; d++) {
item.corner_mask[d] = (start[d] < end[d]);
if (!item.corner_mask[d]) {
// swap coordinates to create bounding box
tmp = start[d];
start[d] = end[d];
end[d] = tmp;
std::swap(start[d], end[d]);
}
}
return item;
}
inline void copy_c_to_pyx_item(const item_t c_item, pyx_item_t *pyx_item) {
(*pyx_item)[0] = c_item.u;
(*pyx_item)[1] = c_item.v;
}
"""

c_equal_function = """
Expand All @@ -51,15 +47,18 @@ class LineRTree(RTree):

c_distance_function = """
inline coord_t length2(const coord_t x[]) {

coord_t length2 = 0;
for (int d = 0; d < DIMS; d++) {
length2 += pow(x[d], 2);
}
return length2;
}

inline coord_t point_segment_dist2(const coord_t point[], const coord_t start[],
const coord_t end[]) {
inline coord_t point_segment_dist2(
const coord_t point[],
const coord_t start[],
const coord_t end[]) {

coord_t a[DIMS];
coord_t b[DIMS];
Expand All @@ -79,7 +78,7 @@ class LineRTree(RTree):
alpha /= length2(a);

// clip at 0 and 1 (beginning and end of line segment)
alpha = min0(1, max0(0, alpha));
alpha = std::min((coord_t)1, std::max((coord_t)0, alpha));

for (int d = 0; d < DIMS; d++) {

Expand All @@ -95,7 +94,10 @@ class LineRTree(RTree):
}

extern inline coord_t distance(
const coord_t point[], const struct rect *rect, const struct item_t item) {
const coord_t point[],
const struct rect *rect,
const struct item_t item) {

coord_t start[DIMS];
coord_t end[DIMS];
for (int d = 0; d < DIMS; d++) {
Expand All @@ -107,6 +109,8 @@ class LineRTree(RTree):
end[d] = rect->min[d];
}
}


return point_segment_dist2(point, start, end);
}
"""
Expand Down
15 changes: 10 additions & 5 deletions src/spatial_graph/_rtree/rtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _build_wrapper(
############################################

wrapper_template = Template(
file=str(SRC_DIR / "wrapper_template.pyx"),
file=str(SRC_DIR / "wrapper_template.cpp"),
compilerSettings={"directiveStartToken": "%"},
)
wrapper_template.item_dtype = DType(item_dtype)
Expand All @@ -46,7 +46,7 @@ def _compile_tree(
cls: type[RTree], item_dtype: str, coord_dtype: str, dims: int
) -> type:
wrapper = _build_wrapper(cls, item_dtype, coord_dtype, dims)
module = witty.compile_module(
module = witty.compile_nanobind(
wrapper,
source_files=[
SRC_DIR / "src" / "rtree.h",
Expand All @@ -55,7 +55,7 @@ def _compile_tree(
],
extra_compile_args=EXTRA_COMPILE_ARGS,
include_dirs=[str(SRC_DIR)],
language="c",
language="c++",
quiet=True,
define_macros=DEFINE_MACROS,
)
Expand Down Expand Up @@ -161,7 +161,7 @@ def delete_item(self, item, bb_min, bb_max=None):
"""
items = np.array([item], dtype=self.item_dtype.base)
bb_mins = bb_min[np.newaxis, :]
bb_maxs = None if bb_max is None else bb_max[np.newaxis, :]
bb_maxs = bb_mins if bb_max is None else bb_max[np.newaxis, :]
return self._ctree.delete_items(items, bb_mins, bb_maxs)

def delete_items(self, items, bb_mins, bb_maxs=None):
Expand All @@ -183,6 +183,8 @@ def delete_items(self, items, bb_mins, bb_maxs=None):
Array of shape `(n, dims)`, the minimum/maximum points of the
bounding boxes per item to delete.
"""
if bb_maxs is None:
bb_maxs = bb_mins
return self._ctree.delete_items(items, bb_mins, bb_maxs)

def count(self, bb_min, bb_max):
Expand Down Expand Up @@ -222,7 +224,10 @@ def nearest(self, point, k=1, return_distances=False):
`distances` contains the distance of each found item to the
query point.
"""
return self._ctree.nearest(point, k, return_distances)
if return_distances:
return self._ctree.nearest_with_distances(point, k)
else:
return self._ctree.nearest(point, k)

def insert_bb_items(self, items, bb_mins, bb_maxs):
"""Insert items with bounding boxes.
Expand Down
10 changes: 5 additions & 5 deletions src/spatial_graph/_rtree/src/rtree.c
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ void heapify_down(struct priority_queue* queue, size_t index) {
bool enqueue(struct priority_queue* queue, struct element element) {
if (queue->size == queue->capacity) {
queue->capacity *= 2;
queue->elements = realloc(queue->elements, sizeof(struct element) * queue->capacity);
queue->elements = (struct element*)realloc(queue->elements, sizeof(struct element) * queue->capacity);
if (!queue->elements)
return false;
}
Expand All @@ -175,7 +175,7 @@ struct element dequeue(struct priority_queue* queue) {
// reclaim some memory when the queue is shrinking
if (queue->size < queue->capacity/4) {
queue->capacity /= 2;
struct element *elements = realloc(queue->elements, sizeof(struct element) * queue->capacity);
struct element *elements = (struct element*)realloc(queue->elements, sizeof(struct element) * queue->capacity);
if (!elements) {
queue->capacity *= 2;
} else {
Expand Down Expand Up @@ -908,7 +908,6 @@ static bool node_delete(struct rtree *tr, struct rect *nr, struct node *node,
if (!rect_contains(&node->rects[h], ir)) {
continue;
}
struct rect crect = node->rects[h];
cow_node_or(node->nodes[h], return false);
if (!node_delete(tr, &node->rects[h], node->nodes[h], ir, item, depth+1,
removed, shrunk, compare, udata))
Expand All @@ -919,6 +918,7 @@ static bool node_delete(struct rtree *tr, struct rect *nr, struct node *node,
continue;
}
removed:
struct rect crect = node->rects[h];
if (node->nodes[h]->count == 0) {
// underflow
node_free(tr, node->nodes[h]);
Expand Down Expand Up @@ -995,7 +995,7 @@ int rtree_delete(struct rtree *tr, const coord_t *min, const coord_t *max,
return rtree_delete0(tr, min, max, item, NULL, NULL);
}

int rtree_delete_with_comparator(struct rtree *tr, const coord_t *min,
bool rtree_delete_with_comparator(struct rtree *tr, const coord_t *min,
const coord_t *max, const item_t item,
int (*compare)(const item_t a, const item_t b, void *udata),
void *udata)
Expand All @@ -1005,7 +1005,7 @@ int rtree_delete_with_comparator(struct rtree *tr, const coord_t *min,

struct rtree *rtree_clone(struct rtree *tr) {
if (!tr) return NULL;
struct rtree *tr2 = tr->malloc(sizeof(struct rtree));
struct rtree *tr2 = (struct rtree*)tr->malloc(sizeof(struct rtree));
if (!tr2) return NULL;
memcpy(tr2, tr, sizeof(struct rtree));
if (tr2->root) rc_fetch_add(&tr2->root->rc, 1);
Expand Down
Loading