Skip to content

Conversation

@b-biswas
Copy link
Collaborator

jax based detection and segmentation algorithm

Copy link
Owner

@beckermr beckermr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great except for the noise thing we discussed!

@codecov
Copy link

codecov bot commented Nov 18, 2025

Codecov Report

❌ Patch coverage is 95.70312% with 11 lines in your changes missing coverage. Please review.
✅ Project coverage is 98.61%. Comparing base (56c5176) to head (3634a39).

Files with missing lines Patch % Lines
...ield_metadetect/jaxify/tests/test_jax_detection.py 93.61% 9 Missing ⚠️
deep_field_metadetect/jaxify/jax_detection.py 98.26% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #35      +/-   ##
==========================================
- Coverage   99.05%   98.61%   -0.45%     
==========================================
  Files          20       22       +2     
  Lines        1689     1945     +256     
==========================================
+ Hits         1673     1918     +245     
- Misses         16       27      +11     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Owner

@beckermr beckermr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few minor things, but looks great!

jnp.ndarray
Binary mask indicating local maxima positions
"""
noise_array = jnp.broadcast_to(noise, image.shape) if jnp.isscalar(noise) else noise
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect this if statement will be buggy under JIT but maybe JAX re-JITs the function if noise is a scalar and so it works fine?

Comment on lines +37 to +58
def is_local_max(i, j):
center_val = padded_image[i + pad_size, j + pad_size]
threshold = 3 * noise_array[i, j] # noise is not padded

neighborhood = jax.lax.dynamic_slice(
padded_image, (i, j), (window_size, window_size)
)

return (jnp.all(center_val >= neighborhood)) & (threshold < center_val)

height, width = image.shape
i_indices, j_indices = jnp.meshgrid(
jnp.arange(height), jnp.arange(width), indexing="ij"
)

local_max_mask = jax.vmap(
jax.vmap(
is_local_max,
in_axes=(0, 0),
),
in_axes=(0, 0),
)(i_indices, j_indices)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may be able to write this whole op as a convolution with a specific kernel. I mention it because the convolutions get lowered into lax directly AFAIK and so may have faster bits of code underneath depending on the backend.

window_size=window_size,
)

positions = jnp.argwhere(local_max_mask, size=max_objects, fill_value=(-999, -999))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make the fill value a -999 a global constant here. Also I suspect -1 would be sufficient in which case a global constant might not be needed.

Comment on lines +233 to +235
refined_positions, border_flags = refine_centroid_in_cell(
image, peak_positions, window_size=5
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess one iteration is enough or maybe it runs away in some cases?


Parameters:
-----------
nverted_image : jnp.ndarray
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
nverted_image : jnp.ndarray
inverted_image : jnp.ndarray

Comment on lines +434 to +435
for i in range(peaks.shape[0]):
markers = place_marker(i, peaks[i])
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be scan operation. Otherwise it will unroll into a big loop and slow compile times.

@b-biswas b-biswas changed the title Detection jax Detection jax (WIP) Nov 20, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants