-
Notifications
You must be signed in to change notification settings - Fork 1
Detection jax (WIP) #35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
beckermr
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great except for the noise thing we discussed!
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #35 +/- ##
==========================================
- Coverage 99.05% 98.61% -0.45%
==========================================
Files 20 22 +2
Lines 1689 1945 +256
==========================================
+ Hits 1673 1918 +245
- Misses 16 27 +11 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
beckermr
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few minor things, but looks great!
| jnp.ndarray | ||
| Binary mask indicating local maxima positions | ||
| """ | ||
| noise_array = jnp.broadcast_to(noise, image.shape) if jnp.isscalar(noise) else noise |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect this if statement will be buggy under JIT but maybe JAX re-JITs the function if noise is a scalar and so it works fine?
| def is_local_max(i, j): | ||
| center_val = padded_image[i + pad_size, j + pad_size] | ||
| threshold = 3 * noise_array[i, j] # noise is not padded | ||
|
|
||
| neighborhood = jax.lax.dynamic_slice( | ||
| padded_image, (i, j), (window_size, window_size) | ||
| ) | ||
|
|
||
| return (jnp.all(center_val >= neighborhood)) & (threshold < center_val) | ||
|
|
||
| height, width = image.shape | ||
| i_indices, j_indices = jnp.meshgrid( | ||
| jnp.arange(height), jnp.arange(width), indexing="ij" | ||
| ) | ||
|
|
||
| local_max_mask = jax.vmap( | ||
| jax.vmap( | ||
| is_local_max, | ||
| in_axes=(0, 0), | ||
| ), | ||
| in_axes=(0, 0), | ||
| )(i_indices, j_indices) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You may be able to write this whole op as a convolution with a specific kernel. I mention it because the convolutions get lowered into lax directly AFAIK and so may have faster bits of code underneath depending on the backend.
| window_size=window_size, | ||
| ) | ||
|
|
||
| positions = jnp.argwhere(local_max_mask, size=max_objects, fill_value=(-999, -999)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should make the fill value a -999 a global constant here. Also I suspect -1 would be sufficient in which case a global constant might not be needed.
| refined_positions, border_flags = refine_centroid_in_cell( | ||
| image, peak_positions, window_size=5 | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess one iteration is enough or maybe it runs away in some cases?
|
|
||
| Parameters: | ||
| ----------- | ||
| nverted_image : jnp.ndarray |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| nverted_image : jnp.ndarray | |
| inverted_image : jnp.ndarray |
| for i in range(peaks.shape[0]): | ||
| markers = place_marker(i, peaks[i]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be scan operation. Otherwise it will unroll into a big loop and slow compile times.
jax based detection and segmentation algorithm