diff --git a/pycbsdk/examples/repro_disable_others.py b/pycbsdk/examples/repro_disable_others.py new file mode 100644 index 0000000..a8d5d10 --- /dev/null +++ b/pycbsdk/examples/repro_disable_others.py @@ -0,0 +1,66 @@ +"""Repro for the bulk-setter `disable_others=True` hang report. + +Mirrors the call sequence reported as hanging indefinitely in +ezmsg-blackrock: + + session.set_sample_group( + 2, ChannelType.FRONTEND, SampleRate.SR_30kHz, disable_others=True + ) + +Each phase is timed and logged so a hang shows up as the last printed +line. Run with: + + python repro_disable_others.py NPLAY + python repro_disable_others.py HUB2 + python repro_disable_others.py HUB1 +""" + +import asyncio +import sys +import time +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) + +from pycbsdk import ChannelType, DeviceType, SampleRate, Session + + +def _stamp(t0: float, label: str) -> None: + print(f" +{time.monotonic() - t0:6.3f}s {label}", flush=True) + + +async def main(device_type: str) -> None: + print(f"Connecting to {device_type}...", flush=True) + t0 = time.monotonic() + + with Session(device_type=device_type) as session: + _stamp(t0, "session created") + + await session.wait_until_running(timeout=10.0) + _stamp(t0, f"runlevel reached RUNNING ({session.runlevel})") + + n_fe = session.num_fe_chans() + _stamp(t0, f"num_fe_chans={n_fe}") + + # The exact call reported as hanging. + session.set_sample_group( + 2, ChannelType.FRONTEND, SampleRate.SR_30kHz, disable_others=True + ) + _stamp(t0, "set_sample_group returned") + + session.sync() + _stamp(t0, "sync returned") + + ch_30k = session.get_group_channels(int(SampleRate.SR_30kHz)) + _stamp(t0, f"30kHz group: {sorted(ch_30k)}") + + # Sanity: only chans 1, 2 should be at 30kHz. + assert sorted(ch_30k) == [1, 2], ( + f"Expected [1, 2] at 30kHz, got {sorted(ch_30k)}" + ) + print("OK", flush=True) + + +if __name__ == "__main__": + device_type = sys.argv[1] if len(sys.argv) > 1 else "NPLAY" + asyncio.run(main(device_type)) diff --git a/pycbsdk/src/pycbsdk/_cdef.py b/pycbsdk/src/pycbsdk/_cdef.py index c13bcfd..ae7be8c 100644 --- a/pycbsdk/src/pycbsdk/_cdef.py +++ b/pycbsdk/src/pycbsdk/_cdef.py @@ -135,6 +135,7 @@ size_t n_channels, const uint64_t* timestamps, void* user_data); typedef void (*cbsdk_config_callback_fn)(const cbPKT_GENERIC* pkt, void* user_data); +typedef void (*cbsdk_runlevel_callback_fn)(uint32_t runlevel, void* user_data); typedef void (*cbsdk_error_callback_fn)(const char* error_message, void* user_data); /////////////////////////////////////////////////////////////////////////// @@ -178,6 +179,8 @@ cbsdk_callback_handle_t cbsdk_session_register_config_callback( cbsdk_session_t session, uint16_t packet_type, cbsdk_config_callback_fn callback, void* user_data); +cbsdk_callback_handle_t cbsdk_session_register_runlevel_callback( + cbsdk_session_t session, cbsdk_runlevel_callback_fn callback, void* user_data); void cbsdk_session_unregister_callback(cbsdk_session_t session, cbsdk_callback_handle_t handle); @@ -204,13 +207,17 @@ cbsdk_result_t cbsdk_session_get_group_list(cbsdk_session_t session, uint32_t group_id, uint16_t* list, uint32_t* count); -// Channel configuration +// Channel configuration. +// `chans` is an optional explicit list of 1-based channel ids. Pass NULL +// for the legacy "first n_chans matching" / "all matching" (n_chans=UINT32_MAX) +// behavior; pass a non-NULL pointer for the explicit-list mode. cbsdk_result_t cbsdk_session_set_sample_group( - cbsdk_session_t session, size_t n_chans, cbproto_channel_type_t chan_type, - cbproto_group_rate_t rate, _Bool disable_others); + cbsdk_session_t session, uint32_t n_chans, const uint32_t* chans, + cbproto_channel_type_t chan_type, cbproto_group_rate_t rate, + _Bool disable_others); cbsdk_result_t cbsdk_session_set_ac_input_coupling( - cbsdk_session_t session, size_t n_chans, cbproto_channel_type_t chan_type, - _Bool enabled); + cbsdk_session_t session, uint32_t n_chans, const uint32_t* chans, + cbproto_channel_type_t chan_type, _Bool enabled); // Per-channel getters cbproto_channel_type_t cbsdk_session_get_channel_type(cbsdk_session_t session, uint32_t chan_id); @@ -226,25 +233,29 @@ cbsdk_result_t cbsdk_session_get_channel_scaling( cbsdk_session_t session, uint32_t chan_id, cbsdk_channel_scaling_t* scaling); -// Per-channel setters +// Per-channel setters. +// `auto_sync` is non-zero to run an internal sync() before the read-modify-write +// (so any prior in-flight config from this process has landed in the cache); +// set_channel_spkthrlevel is "narrow" (CHANSETSPKTHR overwrites the only field +// the firmware reads) and so does not need an auto_sync flag. cbsdk_result_t cbsdk_session_set_channel_label(cbsdk_session_t session, - uint32_t chan_id, const char* label); + uint32_t chan_id, const char* label, int auto_sync); cbsdk_result_t cbsdk_session_set_channel_smpgroup(cbsdk_session_t session, - uint32_t chan_id, cbproto_group_rate_t rate); + uint32_t chan_id, cbproto_group_rate_t rate, int auto_sync); cbsdk_result_t cbsdk_session_set_channel_smpfilter(cbsdk_session_t session, - uint32_t chan_id, uint32_t filter_id); + uint32_t chan_id, uint32_t filter_id, int auto_sync); cbsdk_result_t cbsdk_session_set_channel_spkfilter(cbsdk_session_t session, - uint32_t chan_id, uint32_t filter_id); + uint32_t chan_id, uint32_t filter_id, int auto_sync); cbsdk_result_t cbsdk_session_set_channel_ainpopts(cbsdk_session_t session, - uint32_t chan_id, uint32_t ainpopts); + uint32_t chan_id, uint32_t ainpopts, int auto_sync); cbsdk_result_t cbsdk_session_set_channel_lncrate(cbsdk_session_t session, - uint32_t chan_id, uint32_t lncrate); + uint32_t chan_id, uint32_t lncrate, int auto_sync); cbsdk_result_t cbsdk_session_set_channel_spkopts(cbsdk_session_t session, - uint32_t chan_id, uint32_t spkopts); + uint32_t chan_id, uint32_t spkopts, int auto_sync); cbsdk_result_t cbsdk_session_set_channel_spkthrlevel(cbsdk_session_t session, uint32_t chan_id, int32_t level); cbsdk_result_t cbsdk_session_set_channel_autothreshold(cbsdk_session_t session, - uint32_t chan_id, _Bool enabled); + uint32_t chan_id, _Bool enabled, int auto_sync); // Channel info field selector typedef enum { @@ -317,6 +328,7 @@ // CCF configuration files cbsdk_result_t cbsdk_session_load_channel_map(cbsdk_session_t session, const char* filepath, uint32_t start_chan, uint32_t hs_id); +cbsdk_result_t cbsdk_session_clear_channel_map(cbsdk_session_t session); cbsdk_result_t cbsdk_session_save_ccf(cbsdk_session_t session, const char* filename); cbsdk_result_t cbsdk_session_load_ccf(cbsdk_session_t session, const char* filename); cbsdk_result_t cbsdk_session_load_ccf_sync(cbsdk_session_t session, const char* filename, uint32_t timeout_ms); @@ -331,15 +343,15 @@ // Spike sorting cbsdk_result_t cbsdk_session_set_spike_sorting( - cbsdk_session_t session, size_t n_chans, cbproto_channel_type_t chan_type, - uint32_t sort_options); + cbsdk_session_t session, uint32_t n_chans, const uint32_t* chans, + cbproto_channel_type_t chan_type, uint32_t sort_options); cbsdk_result_t cbsdk_session_set_channel_spike_sorting( - cbsdk_session_t session, uint32_t chan_id, uint32_t sort_options); + cbsdk_session_t session, uint32_t chan_id, uint32_t sort_options, int auto_sync); // Spike extraction (enable/disable cbAINPSPK_EXTRACT via CHANSETSPK) cbsdk_result_t cbsdk_session_set_spike_extraction( - cbsdk_session_t session, size_t n_chans, cbproto_channel_type_t chan_type, - bool enabled); + cbsdk_session_t session, uint32_t n_chans, const uint32_t* chans, + cbproto_channel_type_t chan_type, bool enabled); // Clock synchronization cbsdk_result_t cbsdk_session_get_clock_offset(cbsdk_session_t session, int64_t* offset_ns); diff --git a/pycbsdk/src/pycbsdk/session.py b/pycbsdk/src/pycbsdk/session.py index 011bf16..2c2c39a 100644 --- a/pycbsdk/src/pycbsdk/session.py +++ b/pycbsdk/src/pycbsdk/session.py @@ -4,6 +4,7 @@ from __future__ import annotations +import asyncio import enum import time as _time import threading @@ -12,6 +13,20 @@ from ._lib import ffi, load_library +# Device runlevels (cbRUNLEVEL_*). Mirrors values in cbproto types.h. +RUNLEVEL_STARTUP = 10 +RUNLEVEL_HARDRESET = 20 +RUNLEVEL_STANDBY = 30 +RUNLEVEL_RESET = 40 +RUNLEVEL_RUNNING = 50 +RUNLEVEL_STRESSED = 60 +RUNLEVEL_ERROR = 70 +RUNLEVEL_UPDATE = 78 +RUNLEVEL_SHUTDOWN = 80 + +# "All matching" sentinel for bulk channel-config setters. +_ALL_CHANS = 0xFFFFFFFF + lib = None # Lazy-loaded @@ -500,6 +515,31 @@ def c_config_cb(pkt, user_data): self._handles.append(handle) self._callback_refs.append(c_config_cb) + def _register_runlevel_callback(self, fn: Callable[[int], None]) -> int: + """Register a runlevel-change callback. Returns the callback handle. + + The callback is invoked on the SDK's receive thread (or shmem-receive + thread in CLIENT mode), so ``fn`` must be fast and thread-safe. For + async use, see :meth:`wait_for_runlevel`. + """ + _lib = _get_lib() + + @ffi.callback("void(uint32_t, void*)") + def c_runlevel_cb(runlevel, user_data): + try: + fn(int(runlevel)) + except Exception: + pass # Never let exceptions propagate into C + + handle = _lib.cbsdk_session_register_runlevel_callback( + self._session, c_runlevel_cb, ffi.NULL + ) + if handle == 0: + raise RuntimeError("Failed to register runlevel callback") + self._handles.append(handle) + self._callback_refs.append(c_runlevel_cb) + return handle + def _register_packet_callback(self, fn): _lib = _get_lib() @@ -554,6 +594,52 @@ def runlevel(self) -> int: """Get the current device run level.""" return _get_lib().cbsdk_session_get_runlevel(self._session) + async def wait_for_runlevel(self, level: int, timeout: float = 10.0) -> None: + """Resolve when the device runlevel reaches ``level``. + + Returns immediately if ``self.runlevel >= level`` at call time. + Otherwise registers a one-shot listener for the next runlevel + transition matching the predicate and resolves on its arrival. + + Args: + level: Target runlevel (use one of the ``RUNLEVEL_*`` module + constants, e.g. :data:`RUNLEVEL_RUNNING`). + timeout: Seconds to wait before raising + :class:`asyncio.TimeoutError`. + """ + if self.runlevel >= level: + return + + loop = asyncio.get_running_loop() + fut: asyncio.Future = loop.create_future() + + def _on_change(rl: int) -> None: + if rl >= level and not fut.done(): + loop.call_soon_threadsafe(_resolve, rl) + + def _resolve(_rl: int) -> None: + if not fut.done(): + fut.set_result(None) + + handle = self._register_runlevel_callback(_on_change) + try: + await asyncio.wait_for(fut, timeout) + finally: + _get_lib().cbsdk_session_unregister_callback(self._session, handle) + # Drop our local handle reference so close() doesn't double-free. + try: + self._handles.remove(handle) + except ValueError: + pass + + async def wait_until_running(self, timeout: float = 10.0) -> None: + """Resolve when the device runlevel reaches ``RUNLEVEL_RUNNING``. + + Thin wrapper over :meth:`wait_for_runlevel` for the common case of + gating boot-time configuration on device readiness. + """ + await self.wait_for_runlevel(RUNLEVEL_RUNNING, timeout) + @property def is_standalone(self) -> bool: """Whether this session owns the device connection (STANDALONE mode). @@ -901,43 +987,65 @@ def get_channels_positions( # --- Channel Configuration --- + @staticmethod + def _normalize_chans(chans): + """Resolve ``chans`` (``int | Iterable[int] | None``) for the C-API. + + Returns ``(n_chans, c_arr)`` where ``c_arr`` is either ``ffi.NULL`` + (count-based mode) or an ffi-owned ``uint32_t[]`` (explicit-list + mode). Caller must keep ``c_arr`` alive across the C call. + """ + if chans is None: + return _ALL_CHANS, ffi.NULL + if isinstance(chans, (bool,)): + raise TypeError("chans must be an int, an iterable of ints, or None") + if isinstance(chans, int): + return int(chans), ffi.NULL + chan_list = [int(c) for c in chans] + return len(chan_list), ffi.new("uint32_t[]", chan_list) + def set_sample_group( self, - n_chans: int, + chans: "int | list[int] | None", channel_type: ChannelType, rate: SampleRate, disable_others: bool = False, - ): + ) -> None: """Set sampling rate for channels of a specific type (fire-and-forget). - Configures the first *n_chans* channels matching *channel_type*. - To configure a specific channel by ID, use :meth:`set_channel_smpgroup`. + ``chans`` selects the channels to configure: - The device will not have applied the new configuration when this - call returns. Call :meth:`sync` before reading back state (e.g., - :meth:`get_group_channels`) or registering callbacks that depend - on the new configuration. - - .. note:: + - ``None``: every channel matching *channel_type*. + - ``int N``: the first N channels matching *channel_type* in + channel-id order. + - ``list[int]`` (or any iterable of ints): the explicit list of + 1-based channel IDs. Caller is trusted; no type filter is + applied to listed channels, but *channel_type* still defines + the "others" set for *disable_others*. - The device does not update the ``smpgroup`` field for raw - channels. After setting ``SampleRate.SR_RAW``, - :meth:`get_channel_config` will show ``smpgroup=0``. - Use ``get_group_channels(int(SampleRate.SR_RAW))`` to check - raw group membership. + Runs :meth:`sync` before reading the local cache so any prior + in-flight config from this process has landed before we seed + CHANSET* packets from it. Does NOT sync after sending — call + :meth:`sync` before reading back state that depends on the new + configuration. Always sends a CHANSET* packet for every in-scope + channel — never skips channels that look already configured + (the local cache may be stale due to a dropped CHANREP). Args: - n_chans: Number of channels to configure. - channel_type: Channel type filter (e.g., ``ChannelType.FRONTEND``). + chans: Channel selection (see above). + channel_type: Channel type filter. rate: Sample rate (e.g., ``SampleRate.SR_30kHz``, ``SampleRate.NONE`` to disable). - disable_others: Disable sampling on unselected channels. + disable_others: Disable sampling on channels of *channel_type* + that are not in the configured set. """ _lib = _get_lib() + n_chans, c_arr = self._normalize_chans(chans) _check( _lib.cbsdk_session_set_sample_group( self._session, n_chans, + c_arr, int(_coerce_enum(ChannelType, channel_type)), int(_coerce_enum(SampleRate, rate, _RATE_ALIASES)), disable_others, @@ -947,48 +1055,62 @@ def set_sample_group( def set_ac_input_coupling( self, - n_chans: int, + chans: "int | list[int] | None", channel_type: ChannelType, enabled: bool, - ): - """Set AC/DC input coupling for channels of a specific type (fire-and-forget). + ) -> None: + """Set AC/DC input coupling for channels of a specific type + (fire-and-forget). - Call :meth:`sync` before reading back state that depends on this - configuration. + See :meth:`set_sample_group` for the channel selection and + pre-sync contract. Args: - n_chans: Number of channels to configure. - channel_type: Channel type filter (e.g., ``ChannelType.FRONTEND``). + chans: Channel selection (see :meth:`set_sample_group`). + channel_type: Channel type filter. enabled: ``True`` for AC coupling (offset correction on), ``False`` for DC coupling. """ _lib = _get_lib() + n_chans, c_arr = self._normalize_chans(chans) _check( _lib.cbsdk_session_set_ac_input_coupling( self._session, n_chans, + c_arr, int(_coerce_enum(ChannelType, channel_type)), enabled, ), "Failed to set AC input coupling", ) - def set_channel_label(self, chan_id: int, label: str): - """Set a channel's label.""" + def set_channel_label(self, chan_id: int, label: str, auto_sync: bool = False): + """Set a channel's label. + + Per-channel CHANSET setters seed the outgoing CHANINFO from the + local cache, so a prior in-flight setter from this process can + leave the seed stale. Pass ``auto_sync=True`` to run an internal + :meth:`sync` first (one round-trip latency cost). + """ _check( _get_lib().cbsdk_session_set_channel_label( - self._session, chan_id, label.encode() + self._session, chan_id, label.encode(), 1 if auto_sync else 0 ), "Failed to set channel label", ) - def set_channel_smpgroup(self, chan_id: int, rate: SampleRate | int): + def set_channel_smpgroup( + self, chan_id: int, rate: SampleRate | int, auto_sync: bool = False + ): """Set a single channel's sample group (fire-and-forget). Handles group-specific logic: RAWSTREAM flag for group 6 (raw), filter mapping for groups 1-4, clearing conflicting flags. Groups 5 (30 kHz filtered) and 6 (raw) are mutually exclusive. - Call :meth:`sync` before reading back state that depends on this. + + Pass ``auto_sync=True`` to internally sync before the read-modify- + write so the local cache reflects any in-flight config from this + process. .. note:: @@ -1002,61 +1124,76 @@ def set_channel_smpgroup(self, chan_id: int, rate: SampleRate | int): chan_id: 1-based channel ID. rate: Sample group (``SampleRate.NONE`` to disable, 1-5 for filtered groups, ``SampleRate.SR_RAW`` for raw). + auto_sync: If True, sync before the read-modify-write. """ _check( _get_lib().cbsdk_session_set_channel_smpgroup( self._session, chan_id, int(_coerce_enum(SampleRate, rate, _RATE_ALIASES)), + 1 if auto_sync else 0, ), "Failed to set channel smpgroup", ) - def set_channel_smpfilter(self, chan_id: int, filter_id: int): + def set_channel_smpfilter( + self, chan_id: int, filter_id: int, auto_sync: bool = False + ): """Set a channel's continuous-time pathway filter.""" _check( _get_lib().cbsdk_session_set_channel_smpfilter( - self._session, chan_id, filter_id + self._session, chan_id, filter_id, 1 if auto_sync else 0 ), "Failed to set channel smpfilter", ) - def set_channel_spkfilter(self, chan_id: int, filter_id: int): + def set_channel_spkfilter( + self, chan_id: int, filter_id: int, auto_sync: bool = False + ): """Set a channel's spike pathway filter.""" _check( _get_lib().cbsdk_session_set_channel_spkfilter( - self._session, chan_id, filter_id + self._session, chan_id, filter_id, 1 if auto_sync else 0 ), "Failed to set channel spkfilter", ) - def set_channel_ainpopts(self, chan_id: int, ainpopts: int): + def set_channel_ainpopts( + self, chan_id: int, ainpopts: int, auto_sync: bool = False + ): """Set a channel's analog input options (cbAINP_* flags).""" _check( _get_lib().cbsdk_session_set_channel_ainpopts( - self._session, chan_id, ainpopts + self._session, chan_id, ainpopts, 1 if auto_sync else 0 ), "Failed to set channel ainpopts", ) - def set_channel_lncrate(self, chan_id: int, rate: int): + def set_channel_lncrate(self, chan_id: int, rate: int, auto_sync: bool = False): """Set a channel's line noise cancellation adaptation rate.""" _check( - _get_lib().cbsdk_session_set_channel_lncrate(self._session, chan_id, rate), + _get_lib().cbsdk_session_set_channel_lncrate( + self._session, chan_id, rate, 1 if auto_sync else 0 + ), "Failed to set channel lncrate", ) - def set_channel_spkopts(self, chan_id: int, spkopts: int): + def set_channel_spkopts(self, chan_id: int, spkopts: int, auto_sync: bool = False): """Set a channel's spike processing options (cbAINPSPK_* flags).""" _check( _get_lib().cbsdk_session_set_channel_spkopts( - self._session, chan_id, spkopts + self._session, chan_id, spkopts, 1 if auto_sync else 0 ), "Failed to set channel spkopts", ) def set_channel_spkthrlevel(self, chan_id: int, level: int): - """Set a channel's spike threshold level.""" + """Set a channel's spike threshold level. + + Uses CHANSETSPKTHR which is "narrow" — the firmware only reads + ``spkthrlevel`` and the setter overwrites it fully, so no + ``auto_sync`` flag is needed. + """ _check( _get_lib().cbsdk_session_set_channel_spkthrlevel( self._session, chan_id, level @@ -1064,23 +1201,30 @@ def set_channel_spkthrlevel(self, chan_id: int, level: int): "Failed to set channel spkthrlevel", ) - def set_channel_autothreshold(self, chan_id: int, enabled: bool): + def set_channel_autothreshold( + self, chan_id: int, enabled: bool, auto_sync: bool = False + ): """Enable or disable auto-thresholding for a channel.""" _check( _get_lib().cbsdk_session_set_channel_autothreshold( - self._session, chan_id, enabled + self._session, chan_id, enabled, 1 if auto_sync else 0 ), "Failed to set channel autothreshold", ) - def configure_channel(self, chan_id: int, **kwargs): + def configure_channel(self, chan_id: int, *, auto_sync: bool = False, **kwargs): """Configure one or more attributes of a single channel. This is a convenience method that dispatches to the individual setters. Each keyword argument maps to a channel attribute. + When ``auto_sync=True``, the wrapper applies it to the **first** + dispatched setter only. Subsequent setters skip the sync because the + local cache is already fresh after the first one. + Args: chan_id: 1-based channel ID. + auto_sync: Sync once before the first dispatched setter. Keyword Args: label (str): Channel label (max 15 chars). @@ -1101,6 +1245,17 @@ def configure_channel(self, chan_id: int, **kwargs): autothreshold=True, ) """ + # All setters except set_channel_spkthrlevel accept an auto_sync kwarg. + _accepts_auto_sync = { + "label", + "smpgroup", + "smpfilter", + "spkfilter", + "ainpopts", + "lncrate", + "spkopts", + "autothreshold", + } _dispatch = { "label": self.set_channel_label, "smpfilter": self.set_channel_smpfilter, @@ -1111,13 +1266,18 @@ def configure_channel(self, chan_id: int, **kwargs): "spkthrlevel": self.set_channel_spkthrlevel, "autothreshold": self.set_channel_autothreshold, } + first = True for key, value in kwargs.items(): + sync_arg = auto_sync and first and key in _accepts_auto_sync if key == "smpgroup": - self.set_channel_smpgroup(chan_id, value) - elif key in _dispatch: + self.set_channel_smpgroup(chan_id, value, auto_sync=sync_arg) + elif key == "spkthrlevel": _dispatch[key](chan_id, value) + elif key in _dispatch: + _dispatch[key](chan_id, value, auto_sync=sync_arg) else: raise ValueError(f"Unknown channel attribute: {key!r}") + first = False # --- Channel Mapping (CMP) Files --- @@ -1148,6 +1308,20 @@ def load_channel_map(self, filepath: str, start_chan: int = 1, hs_id: int = 0): "Failed to load channel map", ) + def clear_channel_map(self): + """Remove all channel maps loaded via :meth:`load_channel_map`. + + Drops the local position+label overlay and pushes the device's + default labels (``"chan{N}"``) back to the device for every + previously-mapped channel so the device-side state matches. + Fire-and-forget; call :meth:`sync` if you need to read back state + before issuing further config calls. + """ + _check( + _get_lib().cbsdk_session_clear_channel_map(self._session), + "Failed to clear channel map", + ) + # --- CCF Configuration Files --- def save_ccf(self, filename: str): @@ -1352,32 +1526,37 @@ def close_central_file_dialog(self): def set_spike_sorting( self, - n_chans: int, + chans: "int | list[int] | None", channel_type: ChannelType, sort_options: int, - ): - """Set spike sorting options for channels of a specific type (fire-and-forget). + ) -> None: + """Set spike sorting options for channels of a specific type + (fire-and-forget). - Call :meth:`sync` before reading back state that depends on this - configuration. + See :meth:`set_sample_group` for the channel selection and + pre-sync contract. Args: - n_chans: Number of channels to configure. + chans: Channel selection (see :meth:`set_sample_group`). channel_type: Channel type filter (e.g., ``ChannelType.FRONTEND``). sort_options: Spike sorting option flags (cbAINPSPK_*). """ _lib = _get_lib() + n_chans, c_arr = self._normalize_chans(chans) _check( _lib.cbsdk_session_set_spike_sorting( self._session, n_chans, + c_arr, int(_coerce_enum(ChannelType, channel_type)), sort_options, ), "Failed to set spike sorting", ) - def set_channel_spike_sorting(self, chan_id: int, sort_options: int): + def set_channel_spike_sorting( + self, chan_id: int, sort_options: int, auto_sync: bool = False + ): """Set spike sorting options for a single channel (fire-and-forget). Clears ``cbAINPSPK_ALLSORT`` bits then sets *sort_options*. @@ -1385,44 +1564,46 @@ def set_channel_spike_sorting(self, chan_id: int, sort_options: int): ``spkopts`` field), this preserves non-sorting bits like ``cbAINPSPK_EXTRACT``. - Call :meth:`sync` before reading back state that depends on this - configuration. - Args: chan_id: 1-based channel ID. sort_options: Spike sorting option flags (cbAINPSPK_*). + auto_sync: If True, sync before the read-modify-write so the + local cache reflects any in-flight config from this process. """ _check( _get_lib().cbsdk_session_set_channel_spike_sorting( - self._session, chan_id, sort_options + self._session, chan_id, sort_options, 1 if auto_sync else 0 ), "Failed to set channel spike sorting", ) def set_spike_extraction( self, - n_chans: int, + chans: "int | list[int] | None", channel_type: ChannelType, enabled: bool, - ): - """Enable or disable spike extraction for channels of a specific type (fire-and-forget). + ) -> None: + """Enable or disable spike extraction for channels of a specific type + (fire-and-forget). Controls the ``cbAINPSPK_EXTRACT`` bit which determines whether the device emits spike event packets. Uses ``cbPKTTYPE_CHANSETSPK``. - Call :meth:`sync` before reading back state that depends on this - configuration. + See :meth:`set_sample_group` for the channel selection and + pre-sync contract. Args: - n_chans: Number of channels to configure. + chans: Channel selection (see :meth:`set_sample_group`). channel_type: Channel type filter (e.g., ``ChannelType.FRONTEND``). enabled: ``True`` to enable spike extraction, ``False`` to disable. """ _lib = _get_lib() + n_chans, c_arr = self._normalize_chans(chans) _check( _lib.cbsdk_session_set_spike_extraction( self._session, n_chans, + c_arr, int(_coerce_enum(ChannelType, channel_type)), enabled, ), diff --git a/pycbsdk/tests/test_configuration.py b/pycbsdk/tests/test_configuration.py index 279fd0f..1edbf37 100644 --- a/pycbsdk/tests/test_configuration.py +++ b/pycbsdk/tests/test_configuration.py @@ -149,6 +149,21 @@ def test_disable_others(self, nplay_session): channels_1k = nplay_session.get_group_channels(int(SampleRate.SR_1kHz)) assert len(channels_1k) == 2 + def test_disable_others_30k_no_preamble(self, nplay_session): + """Repro for ezmsg-blackrock hang: disable_others=True at SR_30kHz + called immediately after session creation, no preamble. + + Must complete within the per-test timeout (60 s). An indefinite + hang here points to a sync-vs-bulk-send race in the bulk setter. + """ + nplay_session.set_sample_group( + 2, ChannelType.FRONTEND, SampleRate.SR_30kHz, + disable_others=True, + ) + nplay_session.sync() + ch_30k = nplay_session.get_group_channels(int(SampleRate.SR_30kHz)) + assert sorted(ch_30k) == [1, 2] + # --- Per-channel setter (set_channel_smpgroup) --- # # These tests use only per-channel setters for setup (no batch clear) @@ -417,6 +432,140 @@ def test_set_spike_sorting(self, nplay_session): nplay_session.sync() +# --------------------------------------------------------------------------- +# Explicit channel-list mode for bulk setters +# --------------------------------------------------------------------------- + + +class TestBulkSettersChanList: + """``chans`` accepts an explicit list of 1-based channel ids.""" + + def test_disjoint_ranges_via_two_calls(self, nplay_session): + """Configure chans 1-2 to one rate, chans 3-4 to another.""" + nplay_session.set_sample_group( + [1, 2], ChannelType.FRONTEND, SampleRate.SR_1kHz, + disable_others=False, + ) + nplay_session.set_sample_group( + [3, 4], ChannelType.FRONTEND, SampleRate.SR_2kHz, + disable_others=False, + ) + nplay_session.sync() + assert 1 in nplay_session.get_group_channels(int(SampleRate.SR_1kHz)) + assert 2 in nplay_session.get_group_channels(int(SampleRate.SR_1kHz)) + assert 3 in nplay_session.get_group_channels(int(SampleRate.SR_2kHz)) + assert 4 in nplay_session.get_group_channels(int(SampleRate.SR_2kHz)) + + def test_disjoint_set(self, nplay_session): + """Non-contiguous list of channels is configured correctly.""" + # First clear any prior state on these chans. + nplay_session.set_sample_group( + None, ChannelType.FRONTEND, SampleRate.NONE, disable_others=True, + ) + nplay_session.set_sample_group( + [1, 3], ChannelType.FRONTEND, SampleRate.SR_30kHz, + disable_others=False, + ) + nplay_session.sync() + ch_30k = nplay_session.get_group_channels(int(SampleRate.SR_30kHz)) + assert 1 in ch_30k + assert 3 in ch_30k + assert 2 not in ch_30k + + def test_disable_others_with_list(self, nplay_session): + """disable_others=True with an explicit list disables every FE chan + that isn't in the list.""" + n_fe = nplay_session.num_fe_chans() + # Enable everything at 30kHz first. + nplay_session.set_sample_group( + None, ChannelType.FRONTEND, SampleRate.SR_30kHz, disable_others=True, + ) + # Now keep only chans 1, 2 at 1kHz; disable everything else. + nplay_session.set_sample_group( + [1, 2], ChannelType.FRONTEND, SampleRate.SR_1kHz, disable_others=True, + ) + nplay_session.sync() + ch_1k = nplay_session.get_group_channels(int(SampleRate.SR_1kHz)) + ch_30k = nplay_session.get_group_channels(int(SampleRate.SR_30kHz)) + assert sorted(ch_1k) == [1, 2] + if n_fe > 2: + # The not-listed FE chans should be disabled (smpgroup=0). + for chan_id in range(3, n_fe + 1): + assert chan_id not in ch_30k + + def test_list_with_set_ac_input_coupling(self, nplay_session): + """Explicit list works for set_ac_input_coupling — smoke test.""" + nplay_session.set_ac_input_coupling( + [1, 3], ChannelType.FRONTEND, True, + ) + + def test_list_with_set_spike_extraction(self, nplay_session): + """Explicit list works for set_spike_extraction — smoke test.""" + nplay_session.set_spike_extraction( + None, ChannelType.FRONTEND, True, + ) + nplay_session.set_spike_extraction( + [1, 3], ChannelType.FRONTEND, False, + ) + + def test_empty_list_with_disable_others(self, nplay_session): + """Empty list + disable_others disables all FE chans.""" + # Enable everything first + nplay_session.set_sample_group( + None, ChannelType.FRONTEND, SampleRate.SR_30kHz, disable_others=True, + ) + # Empty list with disable_others: nothing configured, all disabled. + nplay_session.set_sample_group( + [], ChannelType.FRONTEND, SampleRate.SR_1kHz, disable_others=True, + ) + nplay_session.sync() + assert nplay_session.get_group_channels(int(SampleRate.SR_30kHz)) == [] + assert nplay_session.get_group_channels(int(SampleRate.SR_1kHz)) == [] + + def test_tuple_and_range_accepted(self, nplay_session): + """chans accepts any iterable of ints, not just list.""" + nplay_session.set_sample_group( + (1, 2), ChannelType.FRONTEND, SampleRate.SR_30kHz, disable_others=False, + ) + nplay_session.set_sample_group( + range(1, 3), ChannelType.FRONTEND, SampleRate.SR_30kHz, disable_others=False, + ) + + +# --------------------------------------------------------------------------- +# Async wait_until_running +# --------------------------------------------------------------------------- + + +class TestWaitUntilRunning: + """Async runlevel-awaitable. + + The fixture brings the device up to RUNNING before yielding the session, + so wait_until_running() returns immediately without registering a + listener; covering the registration path requires fault injection that + isn't worth the test complexity. + """ + + def test_returns_immediately_when_running(self, nplay_session): + import asyncio + from pycbsdk.session import RUNLEVEL_RUNNING + + async def run() -> None: + assert nplay_session.runlevel >= RUNLEVEL_RUNNING + await nplay_session.wait_until_running(timeout=1.0) + + asyncio.run(run()) + + def test_wait_for_runlevel_immediate(self, nplay_session): + import asyncio + from pycbsdk.session import RUNLEVEL_STANDBY + + async def run() -> None: + await nplay_session.wait_for_runlevel(RUNLEVEL_STANDBY, timeout=1.0) + + asyncio.run(run()) + + # --------------------------------------------------------------------------- # Per-channel configuration (configure_channel) # --------------------------------------------------------------------------- @@ -873,6 +1022,54 @@ def test_overlay_survives_chanrep_refresh(self, nplay_session, cmp_path): assert pos == expected[chan_id].position assert label == expected[chan_id].label + def test_clear_channel_map_resets_labels(self, nplay_session, cmp_path): + """clear_channel_map() reverts labels to the device default ("chanN").""" + from pycbsdk.cmp import parse_cmp + + nplay_session.load_channel_map(str(cmp_path), hs_id=11) + nplay_session.sync() + + expected = parse_cmp(str(cmp_path), hs_id=11) + view_before = self._frontend_view(nplay_session) + loaded = sorted(set(expected) & set(view_before)) + assert loaded + for chan_id in loaded: + _, label = view_before[chan_id] + assert label.startswith("hs11-") + + nplay_session.clear_channel_map() + nplay_session.sync() + + view_after = self._frontend_view(nplay_session) + for chan_id in loaded: + _, label = view_after[chan_id] + assert label == f"chan{chan_id}", ( + f"chan {chan_id}: expected default label, got {label!r}" + ) + + def test_clear_channel_map_resets_positions(self, nplay_session, cmp_path): + """clear_channel_map() drops the position overlay; positions revert to zero.""" + nplay_session.load_channel_map(str(cmp_path)) + nplay_session.sync() + + view_before = self._frontend_view(nplay_session) + loaded = [c for c, (pos, _) in view_before.items() if any(p != 0 for p in pos)] + assert loaded, "expected at least one frontend chan with non-zero position" + + nplay_session.clear_channel_map() + nplay_session.sync() + + view_after = self._frontend_view(nplay_session) + for chan_id in loaded: + pos, _ = view_after[chan_id] + assert all(p == 0 for p in pos), ( + f"chan {chan_id}: expected zeroed position after clear, got {pos}" + ) + + def test_clear_channel_map_empty_is_noop(self, nplay_session): + """Calling clear_channel_map() with no map loaded is a no-op.""" + nplay_session.clear_channel_map() # must not raise + # --------------------------------------------------------------------------- # Spike length configuration diff --git a/src/cbdev/src/device_session.cpp b/src/cbdev/src/device_session.cpp index e5423e4..a28c090 100644 --- a/src/cbdev/src/device_session.cpp +++ b/src/cbdev/src/device_session.cpp @@ -1220,9 +1220,13 @@ Result DeviceSession::setChannelsSpikeSortingByType(const size_t nChans, c continue; } - // Create channel config packet + // Create channel config packet. + // + // Use CHANSET so the firmware applies the spkopts modification. + // CHANSETSPKTHR (used by older revisions) only reads spkthrlevel + // server-side, so the spkopts changes were silently dropped. cbPKT_CHANINFO pkt = chaninfo; // Start with current config - pkt.cbpkt_header.type = cbPKTTYPE_CHANSETSPKTHR; // Use spike threshold set command + pkt.cbpkt_header.type = cbPKTTYPE_CHANSET; pkt.chan = chan; // Clear all spike sorting flags and set new ones @@ -1249,8 +1253,9 @@ Result DeviceSession::setChannelsSpikeSortingSync(const size_t nChans, con return setChannelsSpikeSortingByType(nChans, chanType, sortOptions); }, [](const cbPKT_HEADER* hdr) { + // CHANSET broadcasts a CHANREP echo (not CHANREPSPKTHR). return (hdr->chid & cbPKTCHAN_CONFIGURATION) == cbPKTCHAN_CONFIGURATION && - hdr->type == cbPKTTYPE_CHANREPSPKTHR; + hdr->type == cbPKTTYPE_CHANREP; }, timeout, total_matching diff --git a/src/cbsdk/include/cbsdk/cbsdk.h b/src/cbsdk/include/cbsdk/cbsdk.h index e3e24cc..0ca0f21 100644 --- a/src/cbsdk/include/cbsdk/cbsdk.h +++ b/src/cbsdk/include/cbsdk/cbsdk.h @@ -340,6 +340,24 @@ CBSDK_API cbsdk_callback_handle_t cbsdk_session_register_config_callback( cbsdk_config_callback_fn callback, void* user_data); +/// Callback type for device-runlevel transitions. +typedef void (*cbsdk_runlevel_callback_fn)(uint32_t runlevel, void* user_data); + +/// Register a callback that fires whenever the device-reported runlevel +/// transitions to a new value (driven by SYSREP packets). Useful for +/// awaiting cbRUNLEVEL_RUNNING after a reset. The callback runs on the +/// receive thread (STANDALONE) or shmem-receive thread (CLIENT) — keep it +/// fast or post into a queue. +/// @param session Session handle (must not be NULL) +/// @param callback Function to call on each transition (must not be NULL) +/// @param user_data Opaque user pointer passed to @p callback +/// @return Callback handle for later cbsdk_session_unregister_callback, +/// or 0 on error +CBSDK_API cbsdk_callback_handle_t cbsdk_session_register_runlevel_callback( + cbsdk_session_t session, + cbsdk_runlevel_callback_fn callback, + void* user_data); + /// Unregister a previously registered callback /// @param session Session handle (must not be NULL) /// @param handle Handle returned by a register_*_callback function @@ -543,29 +561,55 @@ CBSDK_API cbsdk_result_t cbsdk_session_get_group_list( // Channel Configuration /////////////////////////////////////////////////////////////////////////////////////////////////// -/// Set sampling rate for channels of a specific type +/// Set sampling rate for channels of a specific type (fire-and-forget). +/// +/// Runs sync() before reading the local cache so any prior in-flight config +/// from this process has landed before we seed CHANSET* packets from it. +/// Does NOT sync after sending — the caller should call cbsdk_session_sync() +/// before reading back state that depends on the new configuration. Always +/// sends a CHANSET* packet for every in-scope channel — no +/// skip-if-already-correct optimization, since the local cache may be stale +/// due to a dropped CHANREP. +/// +/// Channel selection has two modes: +/// - @p chans is NULL: configure the first @p n_chans channels of +/// @p chan_type in channel-id order. Use UINT32_MAX for all matching. +/// - @p chans is non-NULL: configure the explicit list of @p n_chans +/// 1-based channel ids. No type filtering on listed channels (caller +/// is trusted), but @p chan_type still selects the "others" set for +/// @p disable_others. +/// /// @param session Session handle (must not be NULL) -/// @param n_chans Number of channels to configure (use cbMAXCHANS for all) -/// @param chan_type Channel type filter +/// @param n_chans When @p chans is NULL: count to configure (UINT32_MAX +/// for all matching). When @p chans is non-NULL: list length. +/// @param chans Optional explicit list of 1-based channel ids; pass NULL +/// for the count/UINT32_MAX-based mode. +/// @param chan_type Channel type filter (matching when @p chans is NULL, +/// "others" set when @p disable_others). /// @param rate Sample rate (CBPROTO_GROUP_RATE_NONE to disable, _500Hz through _RAW) -/// @param disable_others If true, disable sampling on unselected channels of this type +/// @param disable_others If true, disable sampling on unselected channels of @p chan_type. /// @return CBSDK_RESULT_SUCCESS on success, error code on failure CBSDK_API cbsdk_result_t cbsdk_session_set_sample_group( cbsdk_session_t session, - size_t n_chans, + uint32_t n_chans, + const uint32_t* chans, cbproto_channel_type_t chan_type, cbproto_group_rate_t rate, bool disable_others); /// Set AC input coupling (offset correction) for channels of a specific type +/// (fire-and-forget). See cbsdk_session_set_sample_group() for the +/// channel-selection and pre-sync contract. /// @param session Session handle (must not be NULL) -/// @param n_chans Number of channels to configure (use cbMAXCHANS for all) +/// @param n_chans Count or list length (see cbsdk_session_set_sample_group) +/// @param chans Optional explicit list of 1-based channel ids /// @param chan_type Channel type filter /// @param enabled true = AC coupling, false = DC coupling /// @return CBSDK_RESULT_SUCCESS on success, error code on failure CBSDK_API cbsdk_result_t cbsdk_session_set_ac_input_coupling( cbsdk_session_t session, - size_t n_chans, + uint32_t n_chans, + const uint32_t* chans, cbproto_channel_type_t chan_type, bool enabled); @@ -577,15 +621,24 @@ CBSDK_API cbsdk_result_t cbsdk_session_set_channel_config( cbsdk_session_t session, const cbPKT_CHANINFO* chaninfo); -/// Set a channel's label +/// Set a channel's label. +/// +/// Per-channel CHANSET setters are read-modify-write: they seed the outgoing +/// CHANINFO from the locally cached state. If a prior config command sent by +/// this process is still in flight, the seed may be stale. Pass +/// @p auto_sync != 0 to run an internal sync() barrier before reading the +/// cache, at the cost of one round-trip to the device. +/// /// @param session Session handle (must not be NULL) /// @param chan_id 1-based channel ID (1 to cbMAXCHANS) /// @param label New label string (max 15 chars, null-terminated) +/// @param auto_sync Non-zero to sync() before the read-modify-write /// @return CBSDK_RESULT_SUCCESS on success, error code on failure CBSDK_API cbsdk_result_t cbsdk_session_set_channel_label( cbsdk_session_t session, uint32_t chan_id, - const char* label); + const char* label, + int auto_sync); /// Set a single channel's sample group (fire-and-forget). /// Handles group-specific logic: RAWSTREAM flag for group 6, filter mapping @@ -593,63 +646,79 @@ CBSDK_API cbsdk_result_t cbsdk_session_set_channel_label( /// @param session Session handle (must not be NULL) /// @param chan_id 1-based channel ID (1 to cbMAXCHANS) /// @param rate Sample group (0=NONE, 1-5=filtered, 6=RAW) +/// @param auto_sync Non-zero to sync() before the read-modify-write /// @return CBSDK_RESULT_SUCCESS on success, error code on failure CBSDK_API cbsdk_result_t cbsdk_session_set_channel_smpgroup( cbsdk_session_t session, uint32_t chan_id, - cbproto_group_rate_t rate); + cbproto_group_rate_t rate, + int auto_sync); /// Set a channel's continuous-time pathway filter /// @param session Session handle (must not be NULL) /// @param chan_id 1-based channel ID (1 to cbMAXCHANS) /// @param filter_id Filter ID (0 to cbMAXFILTS-1) +/// @param auto_sync Non-zero to sync() before the read-modify-write /// @return CBSDK_RESULT_SUCCESS on success, error code on failure CBSDK_API cbsdk_result_t cbsdk_session_set_channel_smpfilter( cbsdk_session_t session, uint32_t chan_id, - uint32_t filter_id); + uint32_t filter_id, + int auto_sync); /// Set a channel's spike pathway filter /// @param session Session handle (must not be NULL) /// @param chan_id 1-based channel ID (1 to cbMAXCHANS) /// @param filter_id Filter ID (0 to cbMAXFILTS-1) +/// @param auto_sync Non-zero to sync() before the read-modify-write /// @return CBSDK_RESULT_SUCCESS on success, error code on failure CBSDK_API cbsdk_result_t cbsdk_session_set_channel_spkfilter( cbsdk_session_t session, uint32_t chan_id, - uint32_t filter_id); + uint32_t filter_id, + int auto_sync); /// Set a channel's analog input options (LNC mode, reference electrode, etc.) /// @param session Session handle (must not be NULL) /// @param chan_id 1-based channel ID (1 to cbMAXCHANS) /// @param ainpopts Analog input option flags (cbAINP_* flags) +/// @param auto_sync Non-zero to sync() before the read-modify-write /// @return CBSDK_RESULT_SUCCESS on success, error code on failure CBSDK_API cbsdk_result_t cbsdk_session_set_channel_ainpopts( cbsdk_session_t session, uint32_t chan_id, - uint32_t ainpopts); + uint32_t ainpopts, + int auto_sync); /// Set a channel's line noise cancellation adaptation rate /// @param session Session handle (must not be NULL) /// @param chan_id 1-based channel ID (1 to cbMAXCHANS) /// @param lncrate LNC rate +/// @param auto_sync Non-zero to sync() before the read-modify-write /// @return CBSDK_RESULT_SUCCESS on success, error code on failure CBSDK_API cbsdk_result_t cbsdk_session_set_channel_lncrate( cbsdk_session_t session, uint32_t chan_id, - uint32_t lncrate); + uint32_t lncrate, + int auto_sync); /// Set a channel's spike processing options /// @param session Session handle (must not be NULL) /// @param chan_id 1-based channel ID (1 to cbMAXCHANS) /// @param spkopts Spike option flags (cbAINPSPK_* flags) +/// @param auto_sync Non-zero to sync() before the read-modify-write /// @return CBSDK_RESULT_SUCCESS on success, error code on failure CBSDK_API cbsdk_result_t cbsdk_session_set_channel_spkopts( cbsdk_session_t session, uint32_t chan_id, - uint32_t spkopts); + uint32_t spkopts, + int auto_sync); -/// Set a channel's spike threshold level +/// Set a channel's spike threshold level. +/// +/// CHANSETSPKTHR is narrow (firmware only reads spkthrlevel and the setter +/// fully overwrites it), so no auto_sync flag is required. +/// /// @param session Session handle (must not be NULL) /// @param chan_id 1-based channel ID (1 to cbMAXCHANS) /// @param level Threshold level @@ -663,11 +732,13 @@ CBSDK_API cbsdk_result_t cbsdk_session_set_channel_spkthrlevel( /// @param session Session handle (must not be NULL) /// @param chan_id 1-based channel ID (1 to cbMAXCHANS) /// @param enabled true to enable, false to disable +/// @param auto_sync Non-zero to sync() before the read-modify-write /// @return CBSDK_RESULT_SUCCESS on success, error code on failure CBSDK_API cbsdk_result_t cbsdk_session_set_channel_autothreshold( cbsdk_session_t session, uint32_t chan_id, - bool enabled); + bool enabled, + int auto_sync); /////////////////////////////////////////////////////////////////////////////////////////////////// // Bulk Channel Queries @@ -846,6 +917,16 @@ CBSDK_API cbsdk_result_t cbsdk_session_load_channel_map( uint32_t start_chan, uint32_t hs_id); +/// Clear all channel maps loaded via cbsdk_session_load_channel_map(). +/// +/// Drops the local position+label overlay and pushes default labels +/// ("chan{N}") to the device for every previously-mapped channel so the +/// device-side label state reverts. Fire-and-forget. +/// +/// @param session Session handle (must not be NULL) +/// @return CBSDK_RESULT_SUCCESS on success, error code on failure +CBSDK_API cbsdk_result_t cbsdk_session_clear_channel_map(cbsdk_session_t session); + /////////////////////////////////////////////////////////////////////////////////////////////////// // CCF Configuration Files /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -970,14 +1051,18 @@ CBSDK_API cbsdk_result_t cbsdk_session_close_central_file_dialog(cbsdk_session_t /////////////////////////////////////////////////////////////////////////////////////////////////// /// Set spike sorting options for channels of a specific type +/// (fire-and-forget). See cbsdk_session_set_sample_group() for the +/// channel-selection and pre-sync contract. /// @param session Session handle (must not be NULL) -/// @param n_chans Number of channels to configure +/// @param n_chans Count or list length (see cbsdk_session_set_sample_group) +/// @param chans Optional explicit list of 1-based channel ids /// @param chan_type Channel type filter /// @param sort_options Spike sorting option flags (cbAINPSPK_*) /// @return CBSDK_RESULT_SUCCESS on success, error code on failure CBSDK_API cbsdk_result_t cbsdk_session_set_spike_sorting( cbsdk_session_t session, - size_t n_chans, + uint32_t n_chans, + const uint32_t* chans, cbproto_channel_type_t chan_type, uint32_t sort_options); @@ -986,23 +1071,29 @@ CBSDK_API cbsdk_result_t cbsdk_session_set_spike_sorting( /// @param session Session handle (must not be NULL) /// @param chan_id 1-based channel ID (1 to cbMAXCHANS) /// @param sort_options Spike sorting option flags (cbAINPSPK_*) +/// @param auto_sync Non-zero to sync() before the read-modify-write /// @return CBSDK_RESULT_SUCCESS on success, error code on failure CBSDK_API cbsdk_result_t cbsdk_session_set_channel_spike_sorting( cbsdk_session_t session, uint32_t chan_id, - uint32_t sort_options); + uint32_t sort_options, + int auto_sync); -/// Enable or disable spike extraction for channels of a specific type. -/// Controls the cbAINPSPK_EXTRACT bit via cbPKTTYPE_CHANSETSPK. -/// When enabled, the device emits spike event packets for matching channels. +/// Enable or disable spike extraction for channels of a specific type +/// (fire-and-forget). Controls the cbAINPSPK_EXTRACT bit via +/// cbPKTTYPE_CHANSETSPK. When enabled, the device emits spike event +/// packets for matching channels. See cbsdk_session_set_sample_group() +/// for the channel-selection and pre-sync contract. /// @param session Session handle (must not be NULL) -/// @param n_chans Number of channels to configure +/// @param n_chans Count or list length (see cbsdk_session_set_sample_group) +/// @param chans Optional explicit list of 1-based channel ids /// @param chan_type Channel type filter /// @param enabled true = enable spike extraction, false = disable /// @return CBSDK_RESULT_SUCCESS on success, error code on failure CBSDK_API cbsdk_result_t cbsdk_session_set_spike_extraction( cbsdk_session_t session, - size_t n_chans, + uint32_t n_chans, + const uint32_t* chans, cbproto_channel_type_t chan_type, bool enabled); diff --git a/src/cbsdk/include/cbsdk/sdk_session.h b/src/cbsdk/include/cbsdk/sdk_session.h index 55d538a..3f518a0 100644 --- a/src/cbsdk/include/cbsdk/sdk_session.h +++ b/src/cbsdk/include/cbsdk/sdk_session.h @@ -269,6 +269,11 @@ using GroupBatchCallback = std::function; +/// Runlevel-change callback. Fired when the device runlevel reported in a +/// SYSREP packet differs from the previously recorded value. +/// @param runlevel The new runlevel (cbRUNLEVEL_*) +using RunlevelCallback = std::function; + /// Error callback for queue overflow and other errors /// @param error_message Description of the error using ErrorCallback = std::function; @@ -377,6 +382,16 @@ class SdkSession { /// @return Handle for unregistration CallbackHandle registerConfigCallback(uint16_t packet_type, ConfigCallback callback) const; + /// Register callback for device-runlevel transitions. + /// Fires when the runlevel reported in a SYSREP packet differs from the + /// previously recorded value. Useful for waiting on the device to reach + /// cbRUNLEVEL_RUNNING after a reset. Callbacks run on the receive + /// thread (STANDALONE) or shmem-receive thread (CLIENT) — keep them fast, + /// or post to your own queue. + /// @param callback Function to call on each transition (receives new runlevel) + /// @return Handle for unregistration + CallbackHandle registerRunlevelChangeCallback(RunlevelCallback callback) const; + /// Unregister a previously registered callback /// @param handle Handle returned by any register*Callback method void unregisterCallback(CallbackHandle handle) const; @@ -510,45 +525,73 @@ class SdkSession { ///-------------------------------------------------------------------------------------------- /// Set sampling rate for channels of a specific type (fire-and-forget). - /// The device will not have applied the new configuration when this call - /// returns. Call sync() before reading back state (e.g., getGroupChannelList) - /// or registering callbacks that depend on the new configuration. - /// @param nChans Number of channels to configure (cbMAXCHANS for all) - /// @param chanType Channel type filter + /// + /// Runs sync() before reading the local cache so any prior in-flight + /// config from this process has landed before we seed CHANSET* packets + /// from it (the #177 contract). Does NOT sync after sending — the + /// caller should call sync() before reading back state that depends on + /// the new configuration. Always sends a CHANSET* packet for every + /// in-scope channel — never skips channels that look already configured, + /// since the local cache may be stale due to dropped CHANREP packets + /// from a concurrent client. + /// + /// Channel selection has two modes: + /// - @p chans is nullptr: configure the first @p nChans channels of + /// @p chanType in channel-id order. Use UINT32_MAX for all matching. + /// - @p chans is non-null: configure the explicit list of @p nChans + /// 1-based channel ids. No type filtering is performed on the listed + /// channels (caller is trusted), but @p chanType still selects the + /// "others" set for @p disableOthers. + /// + /// @param nChans When @p chans is nullptr: count to configure (UINT32_MAX + /// for all matching). When @p chans is non-null: list length. + /// @param chanType Channel type filter (used for matching when @p chans + /// is nullptr, and for the "others" set when @p disableOthers). /// @param rate Desired sample rate (NONE to disable, SR_500 through SR_RAW) - /// @param disableOthers Disable sampling on channels not in the first nChans of type - /// @return Result indicating success or error - Result setSampleGroup(size_t nChans, ChannelType chanType, - SampleRate rate, bool disableOthers = false); - - /// Set spike sorting options for channels of a specific type (fire-and-forget). - /// Call sync() before reading back state that depends on this configuration. - /// @param nChans Number of channels to configure + /// @param disableOthers Disable sampling on channels of @p chanType not + /// in the configured set. + /// @param chans Optional explicit list of 1-based channel ids; must be + /// non-null only when caller wants the explicit-list mode. + /// @return Result indicating success or error. + Result setSampleGroup(uint32_t nChans, ChannelType chanType, + SampleRate rate, bool disableOthers = false, + const uint32_t* chans = nullptr); + + /// Set spike sorting options for channels of a specific type + /// (fire-and-forget). See setSampleGroup() for the channel-selection + /// and pre-sync contract. + /// @param nChans Count or list length (see setSampleGroup) /// @param chanType Channel type filter /// @param sortOptions Spike sorting option flags (cbAINPSPK_*) - /// @return Result indicating success or error - Result setSpikeSorting(size_t nChans, ChannelType chanType, - uint32_t sortOptions); + /// @param chans Optional explicit list of 1-based channel ids + /// @return Result indicating success or error. + Result setSpikeSorting(uint32_t nChans, ChannelType chanType, + uint32_t sortOptions, + const uint32_t* chans = nullptr); /// Enable or disable spike extraction (cbAINPSPK_EXTRACT) for channels - /// of a type (fire-and-forget). - /// This controls whether the device emits spike event packets for these channels. - /// Uses cbPKTTYPE_CHANSETSPK (not the threshold command). - /// Call sync() before reading back state that depends on this configuration. - /// @param nChans Number of channels to configure + /// of a type (fire-and-forget). See setSampleGroup() for the + /// channel-selection and pre-sync contract. + /// @param nChans Count or list length (see setSampleGroup) /// @param chanType Channel type filter /// @param enabled true = enable spike extraction, false = disable - /// @return Result indicating success or error - Result setSpikeExtraction(size_t nChans, ChannelType chanType, bool enabled); + /// @param chans Optional explicit list of 1-based channel ids + /// @return Result indicating success or error. + Result setSpikeExtraction(uint32_t nChans, ChannelType chanType, + bool enabled, + const uint32_t* chans = nullptr); /// Set AC input coupling (offset correction) for channels of a specific - /// type (fire-and-forget). - /// Call sync() before reading back state that depends on this configuration. - /// @param nChans Number of channels to configure (cbMAXCHANS for all) + /// type (fire-and-forget). See setSampleGroup() for the + /// channel-selection and pre-sync contract. + /// @param nChans Count or list length (see setSampleGroup) /// @param chanType Channel type filter /// @param enabled true = AC coupling (offset correction on), false = DC coupling - /// @return Result indicating success or error - Result setACInputCoupling(size_t nChans, ChannelType chanType, bool enabled); + /// @param chans Optional explicit list of 1-based channel ids + /// @return Result indicating success or error. + Result setACInputCoupling(uint32_t nChans, ChannelType chanType, + bool enabled, + const uint32_t* chans = nullptr); /// Set full channel configuration by packet (fire-and-forget). /// Call sync() before reading back state that depends on this configuration. @@ -659,6 +702,17 @@ class SdkSession { uint32_t start_chan = 1, uint32_t hs_id = 0); + /// Clear all channel maps loaded via loadChannelMap(). + /// + /// Wipes the local position+label overlay (so positions revert) and pushes + /// default labels ("chan1", "chan2", ...) to the device for every channel + /// that was previously mapped, so the device-side label state matches. + /// Fire-and-forget: returns once labels are queued; the caller can call + /// sync() if it needs to read back state. + /// + /// @return Result indicating success or error + Result clearChannelMap(); + ///-------------------------------------------------------------------------------------------- /// CCF Configuration Files ///-------------------------------------------------------------------------------------------- @@ -740,6 +794,15 @@ class SdkSession { /// @return Result indicating success or error Result sendPacket(const cbPKT_GENERIC& pkt); + /// Bulk-send a vector of packets using the most efficient available path + /// (direct UDP via device_session in STANDALONE — coalesces with built-in + /// pacing — and per-packet shmem enqueue in CLIENT). Used by the bulk + /// by-type setters; useful directly when the caller has already built + /// the packet vector. + /// @param packets Packets to send (may be empty, in which case this is a no-op) + /// @return Result indicating success or error + Result sendBulkPackets(const std::vector& packets); + /// Send a runlevel command packet to the device /// @param runlevel Desired runlevel (cbRUNLEVEL_*) /// @param resetque Channel for reset to queue on (default: 0) diff --git a/src/cbsdk/src/cbsdk.cpp b/src/cbsdk/src/cbsdk.cpp index 3ef7892..99533dd 100644 --- a/src/cbsdk/src/cbsdk.cpp +++ b/src/cbsdk/src/cbsdk.cpp @@ -578,6 +578,24 @@ cbsdk_callback_handle_t cbsdk_session_register_config_callback( } } +cbsdk_callback_handle_t cbsdk_session_register_runlevel_callback( + cbsdk_session_t session, + cbsdk_runlevel_callback_fn callback, + void* user_data) { + if (!session || !session->cpp_session || !callback) { + return 0; + } + try { + return session->cpp_session->registerRunlevelChangeCallback( + [callback, user_data](uint32_t runlevel) { + callback(runlevel, user_data); + } + ); + } catch (...) { + return 0; + } +} + void cbsdk_session_unregister_callback(cbsdk_session_t session, cbsdk_callback_handle_t handle) { if (!session || !session->cpp_session || handle == 0) { @@ -890,7 +908,8 @@ cbsdk_result_t cbsdk_session_get_group_list( cbsdk_result_t cbsdk_session_set_sample_group( cbsdk_session_t session, - size_t n_chans, + uint32_t n_chans, + const uint32_t* chans, cbproto_channel_type_t chan_type, cbproto_group_rate_t rate, bool disable_others) { @@ -899,7 +918,8 @@ cbsdk_result_t cbsdk_session_set_sample_group( } try { auto result = session->cpp_session->setSampleGroup( - n_chans, to_cpp_channel_type(chan_type), static_cast(rate), disable_others); + n_chans, to_cpp_channel_type(chan_type), + static_cast(rate), disable_others, chans); return result.isOk() ? CBSDK_RESULT_SUCCESS : CBSDK_RESULT_INTERNAL_ERROR; } catch (...) { return CBSDK_RESULT_INTERNAL_ERROR; @@ -922,13 +942,20 @@ cbsdk_result_t cbsdk_session_set_channel_config( /// Helper: read-modify-write a channel config field /// Gets current chaninfo from shared memory, calls `modify` on a copy, sends the packet. +/// If @p auto_sync is true, runs sync() first so the local cache reflects any +/// in-flight CHANREP packets before this read-modify-write seeds from it. static cbsdk_result_t modify_and_send_chaninfo( cbsdk_session_t session, uint32_t chan_id, uint16_t pkt_type, - std::function modify) { + std::function modify, + bool auto_sync = false) { if (!session || !session->cpp_session) return CBSDK_RESULT_INVALID_PARAMETER; try { + if (auto_sync) { + auto sync_result = session->cpp_session->sync(5000); + if (sync_result.isError()) return CBSDK_RESULT_INTERNAL_ERROR; + } const cbPKT_CHANINFO* info = session->cpp_session->getChanInfo(chan_id); if (!info) return CBSDK_RESULT_INVALID_PARAMETER; cbPKT_CHANINFO ci = *info; @@ -945,22 +972,28 @@ static cbsdk_result_t modify_and_send_chaninfo( } cbsdk_result_t cbsdk_session_set_channel_label( - cbsdk_session_t session, uint32_t chan_id, const char* label) { + cbsdk_session_t session, uint32_t chan_id, const char* label, int auto_sync) { if (!label) return CBSDK_RESULT_INVALID_PARAMETER; return modify_and_send_chaninfo(session, chan_id, cbPKTTYPE_CHANSETLABEL, [label](cbPKT_CHANINFO& ci) { std::strncpy(ci.label, label, sizeof(ci.label) - 1); ci.label[sizeof(ci.label) - 1] = '\0'; - }); + }, + auto_sync != 0); } cbsdk_result_t cbsdk_session_set_channel_smpgroup( - const cbsdk_session_t session, const uint32_t chan_id, const cbproto_group_rate_t rate) { + const cbsdk_session_t session, const uint32_t chan_id, const cbproto_group_rate_t rate, + int auto_sync) { if (!session || !session->cpp_session) return CBSDK_RESULT_INVALID_PARAMETER; // Mirror the per-group logic from DeviceSession::setChannelsGroupByType. // The packet type varies by group because the device firmware only reads // specific fields depending on the type. try { + if (auto_sync) { + auto sync_result = session->cpp_session->sync(5000); + if (sync_result.isError()) return CBSDK_RESULT_INTERNAL_ERROR; + } const cbPKT_CHANINFO* info = session->cpp_session->getChanInfo(chan_id); if (!info) return CBSDK_RESULT_INVALID_PARAMETER; cbPKT_CHANINFO ci = *info; @@ -1002,47 +1035,57 @@ cbsdk_result_t cbsdk_session_set_channel_smpgroup( } cbsdk_result_t cbsdk_session_set_channel_smpfilter( - cbsdk_session_t session, uint32_t chan_id, uint32_t filter_id) { + cbsdk_session_t session, uint32_t chan_id, uint32_t filter_id, int auto_sync) { return modify_and_send_chaninfo(session, chan_id, cbPKTTYPE_CHANSETSMP, [filter_id](cbPKT_CHANINFO& ci) { ci.smpfilter = filter_id; - }); + }, + auto_sync != 0); } cbsdk_result_t cbsdk_session_set_channel_spkfilter( - cbsdk_session_t session, uint32_t chan_id, uint32_t filter_id) { + cbsdk_session_t session, uint32_t chan_id, uint32_t filter_id, int auto_sync) { return modify_and_send_chaninfo(session, chan_id, cbPKTTYPE_CHANSETSPK, [filter_id](cbPKT_CHANINFO& ci) { ci.spkfilter = filter_id; - }); + }, + auto_sync != 0); } cbsdk_result_t cbsdk_session_set_channel_ainpopts( - cbsdk_session_t session, uint32_t chan_id, uint32_t ainpopts) { + cbsdk_session_t session, uint32_t chan_id, uint32_t ainpopts, int auto_sync) { return modify_and_send_chaninfo(session, chan_id, cbPKTTYPE_CHANSETAINP, [ainpopts](cbPKT_CHANINFO& ci) { ci.ainpopts = ainpopts; - }); + }, + auto_sync != 0); } cbsdk_result_t cbsdk_session_set_channel_lncrate( - cbsdk_session_t session, uint32_t chan_id, uint32_t lncrate) { + cbsdk_session_t session, uint32_t chan_id, uint32_t lncrate, int auto_sync) { return modify_and_send_chaninfo(session, chan_id, cbPKTTYPE_CHANSETAINP, [lncrate](cbPKT_CHANINFO& ci) { ci.lncrate = lncrate; - }); + }, + auto_sync != 0); } cbsdk_result_t cbsdk_session_set_channel_spkopts( - cbsdk_session_t session, uint32_t chan_id, uint32_t spkopts) { - return modify_and_send_chaninfo(session, chan_id, cbPKTTYPE_CHANSETSPKTHR, + cbsdk_session_t session, uint32_t chan_id, uint32_t spkopts, int auto_sync) { + // Use CHANSETSPK (firmware reads spkopts+spkfilter for this type). + // CHANSETSPKTHR — used by older revisions of this code — only reads + // spkthrlevel and would silently ignore the spkopts update. + return modify_and_send_chaninfo(session, chan_id, cbPKTTYPE_CHANSETSPK, [spkopts](cbPKT_CHANINFO& ci) { ci.spkopts = spkopts; - }); + }, + auto_sync != 0); } cbsdk_result_t cbsdk_session_set_channel_spkthrlevel( cbsdk_session_t session, uint32_t chan_id, int32_t level) { + // CHANSETSPKTHR is "narrow": firmware reads only spkthrlevel and we + // overwrite it fully here, so no auto_sync flag is needed. return modify_and_send_chaninfo(session, chan_id, cbPKTTYPE_CHANSETSPKTHR, [level](cbPKT_CHANINFO& ci) { ci.spkthrlevel = level; @@ -1050,14 +1093,15 @@ cbsdk_result_t cbsdk_session_set_channel_spkthrlevel( } cbsdk_result_t cbsdk_session_set_channel_autothreshold( - cbsdk_session_t session, uint32_t chan_id, bool enabled) { + cbsdk_session_t session, uint32_t chan_id, bool enabled, int auto_sync) { return modify_and_send_chaninfo(session, chan_id, cbPKTTYPE_CHANSETAUTOTHRESHOLD, [enabled](cbPKT_CHANINFO& ci) { if (enabled) ci.spkopts |= cbAINPSPK_THRAUTO; else ci.spkopts &= ~cbAINPSPK_THRAUTO; - }); + }, + auto_sync != 0); } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1310,6 +1354,18 @@ cbsdk_result_t cbsdk_session_load_channel_map( } } +cbsdk_result_t cbsdk_session_clear_channel_map(cbsdk_session_t session) { + if (!session || !session->cpp_session) { + return CBSDK_RESULT_INVALID_PARAMETER; + } + try { + auto result = session->cpp_session->clearChannelMap(); + return result.isOk() ? CBSDK_RESULT_SUCCESS : CBSDK_RESULT_INTERNAL_ERROR; + } catch (...) { + return CBSDK_RESULT_INTERNAL_ERROR; + } +} + /////////////////////////////////////////////////////////////////////////////////////////////////// // CCF Configuration Files /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1490,7 +1546,8 @@ cbsdk_result_t cbsdk_session_close_central_file_dialog(cbsdk_session_t session) cbsdk_result_t cbsdk_session_set_ac_input_coupling( cbsdk_session_t session, - size_t n_chans, + uint32_t n_chans, + const uint32_t* chans, cbproto_channel_type_t chan_type, bool enabled) { if (!session || !session->cpp_session) { @@ -1498,7 +1555,7 @@ cbsdk_result_t cbsdk_session_set_ac_input_coupling( } try { auto result = session->cpp_session->setACInputCoupling( - n_chans, to_cpp_channel_type(chan_type), enabled); + n_chans, to_cpp_channel_type(chan_type), enabled, chans); return result.isOk() ? CBSDK_RESULT_SUCCESS : CBSDK_RESULT_INTERNAL_ERROR; } catch (...) { return CBSDK_RESULT_INTERNAL_ERROR; @@ -1511,7 +1568,8 @@ cbsdk_result_t cbsdk_session_set_ac_input_coupling( cbsdk_result_t cbsdk_session_set_spike_sorting( cbsdk_session_t session, - size_t n_chans, + uint32_t n_chans, + const uint32_t* chans, cbproto_channel_type_t chan_type, uint32_t sort_options) { if (!session || !session->cpp_session) { @@ -1519,7 +1577,7 @@ cbsdk_result_t cbsdk_session_set_spike_sorting( } try { auto result = session->cpp_session->setSpikeSorting( - n_chans, to_cpp_channel_type(chan_type), sort_options); + n_chans, to_cpp_channel_type(chan_type), sort_options, chans); return result.isOk() ? CBSDK_RESULT_SUCCESS : CBSDK_RESULT_INTERNAL_ERROR; } catch (...) { return CBSDK_RESULT_INTERNAL_ERROR; @@ -1527,17 +1585,22 @@ cbsdk_result_t cbsdk_session_set_spike_sorting( } cbsdk_result_t cbsdk_session_set_channel_spike_sorting( - cbsdk_session_t session, uint32_t chan_id, uint32_t sort_options) { - return modify_and_send_chaninfo(session, chan_id, cbPKTTYPE_CHANSETSPKTHR, + cbsdk_session_t session, uint32_t chan_id, uint32_t sort_options, int auto_sync) { + // Use CHANSETSPK (firmware reads spkopts+spkfilter). Earlier revisions + // used CHANSETSPKTHR, which only reads spkthrlevel and would silently + // ignore the spkopts update we want here. + return modify_and_send_chaninfo(session, chan_id, cbPKTTYPE_CHANSETSPK, [sort_options](cbPKT_CHANINFO& ci) { ci.spkopts &= ~cbAINPSPK_ALLSORT; ci.spkopts |= sort_options; - }); + }, + auto_sync != 0); } cbsdk_result_t cbsdk_session_set_spike_extraction( cbsdk_session_t session, - size_t n_chans, + uint32_t n_chans, + const uint32_t* chans, cbproto_channel_type_t chan_type, bool enabled) { if (!session || !session->cpp_session) { @@ -1545,7 +1608,7 @@ cbsdk_result_t cbsdk_session_set_spike_extraction( } try { auto result = session->cpp_session->setSpikeExtraction( - n_chans, to_cpp_channel_type(chan_type), enabled); + n_chans, to_cpp_channel_type(chan_type), enabled, chans); return result.isOk() ? CBSDK_RESULT_SUCCESS : CBSDK_RESULT_INTERNAL_ERROR; } catch (...) { return CBSDK_RESULT_INTERNAL_ERROR; diff --git a/src/cbsdk/src/sdk_session.cpp b/src/cbsdk/src/sdk_session.cpp index 5938ef7..2a4d2cf 100644 --- a/src/cbsdk/src/sdk_session.cpp +++ b/src/cbsdk/src/sdk_session.cpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include "cbdev/clock_sync.h" #ifndef _WIN32 @@ -242,12 +243,30 @@ struct SdkSession::Impl { struct GroupCB { CallbackHandle handle; uint8_t group_id; GroupCallback cb; }; struct GroupBatchCB { CallbackHandle handle; uint8_t group_id; GroupBatchCallback cb; }; struct ConfigCB { CallbackHandle handle; uint16_t packet_type; ConfigCallback cb; }; + struct RunlevelCB { CallbackHandle handle; RunlevelCallback cb; }; std::vector packet_callbacks; std::vector event_callbacks; std::vector group_callbacks; std::vector group_batch_callbacks; std::vector config_callbacks; + std::vector runlevel_callbacks; + + /// Atomically update device_runlevel; fire registered callbacks if the + /// value changed. Called from the receive thread (STANDALONE) or the + /// shmem-receive thread (CLIENT) — both paths converge here. + void updateRunlevel(uint32_t new_runlevel) { + const uint32_t prev = device_runlevel.exchange(new_runlevel, std::memory_order_acq_rel); + if (prev == new_runlevel) return; + std::vector snap; + { + std::lock_guard lock(user_callback_mutex); + snap = runlevel_callbacks; + } + for (const auto& cb : snap) { + if (cb.cb) cb.cb(new_runlevel); + } + } CallbackHandle next_callback_handle = 1; ErrorCallback error_callback; @@ -872,7 +891,7 @@ Result SdkSession::start() { // Check for SYSREP packets (handshake responses) if ((pkt.cbpkt_header.type & 0xF0) == cbPKTTYPE_SYSREP) { const auto* sysinfo = reinterpret_cast(&pkt); - impl->device_runlevel.store(sysinfo->runlevel, std::memory_order_release); + impl->updateRunlevel(sysinfo->runlevel); impl->received_sysrep.store(true, std::memory_order_release); impl->handshake_cv.notify_all(); } @@ -1183,7 +1202,7 @@ Result SdkSession::start() { // Check for SYSREP packets (handshake responses) if ((packets[i].cbpkt_header.type & 0xF0) == cbPKTTYPE_SYSREP) { const auto* sysinfo = reinterpret_cast(&packets[i]); - impl->device_runlevel.store(sysinfo->runlevel, std::memory_order_release); + impl->updateRunlevel(sysinfo->runlevel); impl->received_sysrep.store(true, std::memory_order_release); impl->handshake_cv.notify_all(); } @@ -1320,6 +1339,13 @@ CallbackHandle SdkSession::registerConfigCallback(const uint16_t packet_type, Co return handle; } +CallbackHandle SdkSession::registerRunlevelChangeCallback(RunlevelCallback callback) const { + std::lock_guard lock(m_impl->user_callback_mutex); + const auto handle = m_impl->next_callback_handle++; + m_impl->runlevel_callbacks.push_back({handle, std::move(callback)}); + return handle; +} + void SdkSession::unregisterCallback(CallbackHandle handle) const { std::lock_guard lock(m_impl->user_callback_mutex); auto erase_by_handle = [handle](auto& vec) { @@ -1332,6 +1358,7 @@ void SdkSession::unregisterCallback(CallbackHandle handle) const { erase_by_handle(m_impl->group_callbacks); erase_by_handle(m_impl->group_batch_callbacks); erase_by_handle(m_impl->config_callbacks); + erase_by_handle(m_impl->runlevel_callbacks); } void SdkSession::setErrorCallback(ErrorCallback callback) { @@ -1663,38 +1690,49 @@ static cbdev::ChannelType toDevChannelType(const ChannelType chanType) { } -Result SdkSession::setSampleGroup(const size_t nChans, const ChannelType chanType, - const SampleRate rate, const bool disableOthers) { - const uint32_t group_id = static_cast(rate); - // STANDALONE mode: delegate to device session (has full config + direct send) - if (m_impl->device_session) { - const auto r = m_impl->device_session->setChannelsGroupByType( - nChans, toDevChannelType(chanType), static_cast(group_id), disableOthers); - if (r.isError()) - return Result::error(r.error()); - return Result::ok(); +// Resolve (nChans, chans) into the explicit set of 1-based channel ids to +// configure. When @p chans is non-null, it's the verbatim list (length +// nChans). When @p chans is null, walk channel ids in ascending order and +// pick the first @p nChans that match @p chanType; nChans == UINT32_MAX +// selects all matching. +static std::vector resolveTargetChans( + const SdkSession& session, uint32_t nChans, ChannelType chanType, + const uint32_t* chans) { + if (chans != nullptr) { + return std::vector(chans, chans + nChans); } + std::vector result; + if (nChans != UINT32_MAX) { + result.reserve(std::min(nChans, cbMAXCHANS)); + } + for (uint32_t chan = 1; chan <= cbMAXCHANS && result.size() < nChans; ++chan) { + const cbPKT_CHANINFO* ci = session.getChanInfo(chan); + if (ci && classifyChannelByCaps(*ci) == chanType) { + result.push_back(chan); + } + } + return result; +} - // CLIENT mode: build packets from shmem chaninfo and send through shmem transmit queue - if (!m_impl->shmem_session) - return Result::error("No session available"); - - size_t count = 0; - for (uint32_t chan = 1; chan <= cbMAXCHANS; ++chan) { - if (!disableOthers && count >= nChans) - break; - - auto ci_result = m_impl->shmem_session->getChanInfo(chan - 1); - if (ci_result.isError()) - continue; - auto chaninfo = ci_result.value(); - - if (classifyChannelByCaps(chaninfo) != chanType) - continue; +Result SdkSession::setSampleGroup( + const uint32_t nChans, const ChannelType chanType, const SampleRate rate, + const bool disableOthers, const uint32_t* chans) { + // Pre-sync: ensure local chaninfo cache is up to date before we seed + // outgoing CHANSET* packets from it. Stale cache would re-send obsolete + // values for fields we don't explicitly modify here. + if (auto r = sync(5000); r.isError()) return r; - const auto grp = count < nChans ? group_id : 0u; + const uint32_t group_id = static_cast(rate); + const std::vector targets = resolveTargetChans(*this, nChans, chanType, chans); + + // Build a packet for `chan` with `grp` as its target group. + // Always sends the full chaninfo (seeded from local cache), so a stale + // CHANREP can't leave us stuck against a concurrent change. + auto build_packet = [this](uint32_t chan, uint32_t grp) -> std::optional { + const cbPKT_CHANINFO* base = getChanInfo(chan); + if (!base || base->chan == 0) return std::nullopt; + cbPKT_CHANINFO chaninfo = *base; chaninfo.chan = chan; - if (grp > 0 && grp < 6) { chaninfo.cbpkt_header.type = cbPKTTYPE_CHANSETSMP; chaninfo.smpgroup = grp; @@ -1712,142 +1750,113 @@ Result SdkSession::setSampleGroup(const size_t nChans, const ChannelType c chaninfo.smpgroup = 0; chaninfo.ainpopts &= ~cbAINP_RAWSTREAM; } + return chaninfo; + }; - auto r = sendPacket(reinterpret_cast(chaninfo)); - if (r.isError()) - return r; - count++; + // Build the packet vector. Targets first (configured to `rate`), then + // — if disableOthers — every other matching channel set to disabled. + std::vector packets; + packets.reserve(targets.size() + (disableOthers ? cbMAXCHANS : 0)); + for (uint32_t chan : targets) { + if (auto pkt = build_packet(chan, group_id); pkt) { + packets.push_back(reinterpret_cast(*pkt)); + } + } + if (disableOthers) { + std::set target_set(targets.begin(), targets.end()); + for (uint32_t chan = 1; chan <= cbMAXCHANS; ++chan) { + if (target_set.count(chan)) continue; + const cbPKT_CHANINFO* ci = getChanInfo(chan); + if (!ci || classifyChannelByCaps(*ci) != chanType) continue; + if (auto pkt = build_packet(chan, 0u); pkt) { + packets.push_back(reinterpret_cast(*pkt)); + } + } } - if (count == 0) - return Result::error("No channels found matching type"); - return Result::ok(); + return sendBulkPackets(packets); } -Result SdkSession::setSpikeSorting(const size_t nChans, const ChannelType chanType, - const uint32_t sortOptions) { - // STANDALONE mode: delegate to device session +// Bulk-send a vector of packets using the most efficient path: +// - STANDALONE: device_session->sendPackets — direct UDP with built-in +// pacing. +// - CLIENT: per-packet shmem enqueue, drained in FIFO order by the peer +// STANDALONE process. +Result SdkSession::sendBulkPackets(const std::vector& packets) { + if (packets.empty()) return Result::ok(); if (m_impl->device_session) { - const auto r = m_impl->device_session->setChannelsSpikeSortingByType( - nChans, toDevChannelType(chanType), sortOptions); - if (r.isError()) - return Result::error(r.error()); - return Result::ok(); + return m_impl->device_session->sendPackets(packets); } - - // CLIENT mode: build packets from shmem chaninfo - if (!m_impl->shmem_session) + if (!m_impl->shmem_session) { return Result::error("No session available"); - - size_t count = 0; - for (uint32_t chan = 1; chan <= cbMAXCHANS && count < nChans; ++chan) { - auto ci_result = m_impl->shmem_session->getChanInfo(chan - 1); - if (ci_result.isError()) - continue; - auto chaninfo = ci_result.value(); - - if (classifyChannelByCaps(chaninfo) != chanType) - continue; - - chaninfo.cbpkt_header.type = cbPKTTYPE_CHANSETSPKTHR; - chaninfo.chan = chan; - chaninfo.spkopts &= ~cbAINPSPK_ALLSORT; - chaninfo.spkopts |= sortOptions; - - auto r = sendPacket(reinterpret_cast(chaninfo)); - if (r.isError()) - return r; - count++; } - - if (count == 0) - return Result::error("No channels found matching type"); + for (const auto& pkt : packets) { + if (auto r = sendPacket(pkt); r.isError()) return r; + } return Result::ok(); } -Result SdkSession::setSpikeExtraction(const size_t nChans, const ChannelType chanType, - const bool enabled) { - // STANDALONE mode: delegate to device session - if (m_impl->device_session) { - const auto r = m_impl->device_session->setSpikeExtraction( - nChans, toDevChannelType(chanType), enabled); - if (r.isError()) - return Result::error(r.error()); - return Result::ok(); - } - - // CLIENT mode: build packets from shmem chaninfo - if (!m_impl->shmem_session) - return Result::error("No session available"); - - size_t count = 0; - for (uint32_t chan = 1; chan <= cbMAXCHANS && count < nChans; ++chan) { - auto ci_result = m_impl->shmem_session->getChanInfo(chan - 1); - if (ci_result.isError()) - continue; - auto chaninfo = ci_result.value(); - - if (classifyChannelByCaps(chaninfo) != chanType) - continue; - - chaninfo.cbpkt_header.type = cbPKTTYPE_CHANSETSPK; +// Helper for the three "configure-each" bulk setters that don't have +// disable_others semantics: setSpikeSorting, setSpikeExtraction, +// setACInputCoupling. Pre-syncs, iterates the resolved target list, +// builds a packet per chan via @p mutate, and sends. Fire-and-forget +// on the response side — caller can call sync() if it needs to read +// back state. +template +static Result applyBulkSetter( + SdkSession& session, uint32_t nChans, ChannelType chanType, + const uint32_t* chans, Mutate&& mutate) { + if (auto r = session.sync(5000); r.isError()) return r; + + const std::vector targets = resolveTargetChans(session, nChans, chanType, chans); + std::vector packets; + packets.reserve(targets.size()); + for (uint32_t chan : targets) { + const cbPKT_CHANINFO* base = session.getChanInfo(chan); + if (!base || base->chan == 0) continue; + cbPKT_CHANINFO chaninfo = *base; chaninfo.chan = chan; - chaninfo.spkopts &= ~cbAINPSPK_EXTRACT; - if (enabled) - chaninfo.spkopts |= cbAINPSPK_EXTRACT; - - auto r = sendPacket(reinterpret_cast(chaninfo)); - if (r.isError()) - return r; - count++; + mutate(chaninfo); + packets.push_back(reinterpret_cast(chaninfo)); } - if (count == 0) - return Result::error("No channels found matching type"); - return Result::ok(); + return session.sendBulkPackets(packets); } -Result SdkSession::setACInputCoupling(const size_t nChans, const ChannelType chanType, - const bool enabled) { - // STANDALONE mode: delegate to device session - if (m_impl->device_session) { - const auto r = m_impl->device_session->setChannelsACInputCouplingByType( - nChans, toDevChannelType(chanType), enabled); - if (r.isError()) - return Result::error(r.error()); - return Result::ok(); - } - - // CLIENT mode: build packets from shmem chaninfo - if (!m_impl->shmem_session) - return Result::error("No session available"); - - size_t count = 0; - for (uint32_t chan = 1; chan <= cbMAXCHANS && count < nChans; ++chan) { - auto ci_result = m_impl->shmem_session->getChanInfo(chan - 1); - if (ci_result.isError()) - continue; - auto chaninfo = ci_result.value(); - - if (classifyChannelByCaps(chaninfo) != chanType) - continue; - - chaninfo.cbpkt_header.type = cbPKTTYPE_CHANSETAINP; - chaninfo.chan = chan; - if (enabled) - chaninfo.ainpopts |= cbAINP_OFFSET_CORRECT; - else - chaninfo.ainpopts &= ~cbAINP_OFFSET_CORRECT; +Result SdkSession::setSpikeSorting( + const uint32_t nChans, const ChannelType chanType, const uint32_t sortOptions, + const uint32_t* chans) { + return applyBulkSetter(*this, nChans, chanType, chans, + [sortOptions](cbPKT_CHANINFO& ci) { + // Use CHANSET so the firmware applies spkopts. CHANSETSPKTHR + // (used by older revisions) only reads spkthrlevel and would + // silently ignore the spkopts modification. + ci.cbpkt_header.type = cbPKTTYPE_CHANSET; + ci.spkopts &= ~cbAINPSPK_ALLSORT; + ci.spkopts |= sortOptions; + }); +} - auto r = sendPacket(reinterpret_cast(chaninfo)); - if (r.isError()) - return r; - count++; - } +Result SdkSession::setSpikeExtraction( + const uint32_t nChans, const ChannelType chanType, const bool enabled, + const uint32_t* chans) { + return applyBulkSetter(*this, nChans, chanType, chans, + [enabled](cbPKT_CHANINFO& ci) { + ci.cbpkt_header.type = cbPKTTYPE_CHANSETSPK; + ci.spkopts &= ~cbAINPSPK_EXTRACT; + if (enabled) ci.spkopts |= cbAINPSPK_EXTRACT; + }); +} - if (count == 0) - return Result::error("No channels found matching type"); - return Result::ok(); +Result SdkSession::setACInputCoupling( + const uint32_t nChans, const ChannelType chanType, const bool enabled, + const uint32_t* chans) { + return applyBulkSetter(*this, nChans, chanType, chans, + [enabled](cbPKT_CHANINFO& ci) { + ci.cbpkt_header.type = cbPKTTYPE_CHANSETAINP; + if (enabled) ci.ainpopts |= cbAINP_OFFSET_CORRECT; + else ci.ainpopts &= ~cbAINP_OFFSET_CORRECT; + }); } Result SdkSession::setChannelConfig(const cbPKT_CHANINFO& chaninfo) { @@ -2111,6 +2120,55 @@ Result SdkSession::loadChannelMap( return Result::ok(); } +Result SdkSession::clearChannelMap() { + // Snapshot the previously-mapped channel ids and drop the overlay so + // future CHANREPs land in chaninfo as the device sends them. Any further + // applyCmpToAllChannels() call is now a no-op. + std::vector mapped_chans; + { + std::lock_guard lock(m_impl->cmp_mutex); + mapped_chans.reserve(m_impl->cmp_entries.size()); + for (const auto& [chan_id, _] : m_impl->cmp_entries) { + mapped_chans.push_back(chan_id); + } + m_impl->cmp_entries.clear(); + } + + // Reset shmem positions to zero for previously-mapped channels. The + // device doesn't persist positions, so there's no on-device action — just + // wipe the local overlay we applied in applyCmpToAllChannels(). + if (m_impl->shmem_session) { + for (uint32_t chan_id : mapped_chans) { + if (chan_id < 1 || chan_id > cbMAXCHANS) continue; + auto r = m_impl->shmem_session->getChanInfo(chan_id - 1); + if (r.isError()) continue; + cbPKT_CHANINFO ci = r.value(); + std::memset(ci.position, 0, sizeof(ci.position)); + m_impl->shmem_session->setChanInfo(chan_id - 1, ci); + } + } + + // Push default labels ("chan{N}") to the device so the device-side label + // state matches. Labels were modified by loadChannelMap; positions were + // local-only. + if (m_impl->device_session || m_impl->shmem_session) { + for (uint32_t chan_id : mapped_chans) { + const cbPKT_CHANINFO* info = getChanInfo(chan_id); + if (!info) continue; + cbPKT_CHANINFO ci = *info; + ci.chan = chan_id; + ci.cbpkt_header.type = cbPKTTYPE_CHANSETLABEL; + char default_label[16]; + std::snprintf(default_label, sizeof(default_label), "chan%u", chan_id); + std::strncpy(ci.label, default_label, sizeof(ci.label) - 1); + ci.label[sizeof(ci.label) - 1] = '\0'; + (void)setChannelConfig(ci); // best-effort + } + } + + return Result::ok(); +} + /////////////////////////////////////////////////////////////////////////////////////////////////// // CCF Configuration Files /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tests/integration/test_capi_configuration.cpp b/tests/integration/test_capi_configuration.cpp index 6201b40..d835ef6 100644 --- a/tests/integration/test_capi_configuration.cpp +++ b/tests/integration/test_capi_configuration.cpp @@ -132,10 +132,9 @@ TEST_F(CApiSampleGroupTest, SetFrontend30kHz) { ASSERT_TRUE(sg.create()); EXPECT_EQ(cbsdk_session_set_sample_group( - sg.session, 4, CBPROTO_CHANNEL_TYPE_FRONTEND, + sg.session, 4, /*chans=*/nullptr, CBPROTO_CHANNEL_TYPE_FRONTEND, CBPROTO_GROUP_RATE_30000Hz, true), CBSDK_RESULT_SUCCESS); - - std::this_thread::sleep_for(std::chrono::milliseconds(500)); + EXPECT_EQ(cbsdk_session_sync(sg.session, 5000), CBSDK_RESULT_SUCCESS); // Verify via group list uint16_t list[256]; @@ -150,10 +149,9 @@ TEST_F(CApiSampleGroupTest, SetAndVerifyField) { ASSERT_TRUE(sg.create()); EXPECT_EQ(cbsdk_session_set_sample_group( - sg.session, 4, CBPROTO_CHANNEL_TYPE_FRONTEND, + sg.session, 4, /*chans=*/nullptr, CBPROTO_CHANNEL_TYPE_FRONTEND, CBPROTO_GROUP_RATE_10000Hz, true), CBSDK_RESULT_SUCCESS); - - std::this_thread::sleep_for(std::chrono::milliseconds(500)); + EXPECT_EQ(cbsdk_session_sync(sg.session, 5000), CBSDK_RESULT_SUCCESS); // Query via bulk field getter int64_t values[512]; @@ -268,7 +266,8 @@ TEST_F(CApiACCouplingTest, SetACCoupling) { ASSERT_TRUE(sg.create()); EXPECT_EQ(cbsdk_session_set_ac_input_coupling( - sg.session, 4, CBPROTO_CHANNEL_TYPE_FRONTEND, true), CBSDK_RESULT_SUCCESS); + sg.session, 4, /*chans=*/nullptr, CBPROTO_CHANNEL_TYPE_FRONTEND, true), + CBSDK_RESULT_SUCCESS); } TEST_F(CApiACCouplingTest, SetDCCoupling) { @@ -276,7 +275,8 @@ TEST_F(CApiACCouplingTest, SetDCCoupling) { ASSERT_TRUE(sg.create()); EXPECT_EQ(cbsdk_session_set_ac_input_coupling( - sg.session, 4, CBPROTO_CHANNEL_TYPE_FRONTEND, false), CBSDK_RESULT_SUCCESS); + sg.session, 4, /*chans=*/nullptr, CBPROTO_CHANNEL_TYPE_FRONTEND, false), + CBSDK_RESULT_SUCCESS); } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -290,7 +290,8 @@ TEST_F(CApiSpikeSortingTest, SetSpikeSorting) { ASSERT_TRUE(sg.create()); EXPECT_EQ(cbsdk_session_set_spike_sorting( - sg.session, 4, CBPROTO_CHANNEL_TYPE_FRONTEND, 0), CBSDK_RESULT_SUCCESS); + sg.session, 4, /*chans=*/nullptr, CBPROTO_CHANNEL_TYPE_FRONTEND, 0), + CBSDK_RESULT_SUCCESS); } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -327,9 +328,9 @@ TEST_F(CApiPerChannelTest, SetChannelLabel) { SessionGuard sg; ASSERT_TRUE(sg.create()); - EXPECT_EQ(cbsdk_session_set_channel_label(sg.session, 1, "TestCh"), + EXPECT_EQ(cbsdk_session_set_channel_label(sg.session, 1, "TestCh", /*auto_sync=*/0), CBSDK_RESULT_SUCCESS); - std::this_thread::sleep_for(std::chrono::milliseconds(300)); + EXPECT_EQ(cbsdk_session_sync(sg.session, 5000), CBSDK_RESULT_SUCCESS); const char* label = cbsdk_session_get_channel_label(sg.session, 1); ASSERT_NE(label, nullptr); @@ -340,9 +341,9 @@ TEST_F(CApiPerChannelTest, SetChannelSmpfilter) { SessionGuard sg; ASSERT_TRUE(sg.create()); - EXPECT_EQ(cbsdk_session_set_channel_smpfilter(sg.session, 1, 2), + EXPECT_EQ(cbsdk_session_set_channel_smpfilter(sg.session, 1, 2, /*auto_sync=*/0), CBSDK_RESULT_SUCCESS); - std::this_thread::sleep_for(std::chrono::milliseconds(300)); + EXPECT_EQ(cbsdk_session_sync(sg.session, 5000), CBSDK_RESULT_SUCCESS); EXPECT_EQ(cbsdk_session_get_channel_smpfilter(sg.session, 1), 2u); } @@ -351,9 +352,9 @@ TEST_F(CApiPerChannelTest, SetChannelSpkfilter) { SessionGuard sg; ASSERT_TRUE(sg.create()); - EXPECT_EQ(cbsdk_session_set_channel_spkfilter(sg.session, 1, 3), + EXPECT_EQ(cbsdk_session_set_channel_spkfilter(sg.session, 1, 3, /*auto_sync=*/0), CBSDK_RESULT_SUCCESS); - std::this_thread::sleep_for(std::chrono::milliseconds(300)); + EXPECT_EQ(cbsdk_session_sync(sg.session, 5000), CBSDK_RESULT_SUCCESS); EXPECT_EQ(cbsdk_session_get_channel_spkfilter(sg.session, 1), 3u); } diff --git a/tests/unit/test_cbsdk_c_api.cpp b/tests/unit/test_cbsdk_c_api.cpp index f1ebab2..729e46a 100644 --- a/tests/unit/test_cbsdk_c_api.cpp +++ b/tests/unit/test_cbsdk_c_api.cpp @@ -385,8 +385,9 @@ TEST_F(CbsdkCApiTest, ConfigAccess_WithSession) { /////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(CbsdkCApiTest, SetChannelSampleGroup_NullSession) { - EXPECT_EQ(cbsdk_session_set_sample_group(nullptr, 256, - CBPROTO_CHANNEL_TYPE_FRONTEND, CBPROTO_GROUP_RATE_30000Hz, false), CBSDK_RESULT_INVALID_PARAMETER); + EXPECT_EQ(cbsdk_session_set_sample_group(nullptr, 256u, /*chans=*/nullptr, + CBPROTO_CHANNEL_TYPE_FRONTEND, CBPROTO_GROUP_RATE_30000Hz, false), + CBSDK_RESULT_INVALID_PARAMETER); } TEST_F(CbsdkCApiTest, SetChannelConfig_NullSession) {