diff --git a/src/internal_modules/roc_pipeline/state_tracker.cpp b/src/internal_modules/roc_pipeline/state_tracker.cpp index bc2b6ef6f..9981d3a1b 100644 --- a/src/internal_modules/roc_pipeline/state_tracker.cpp +++ b/src/internal_modules/roc_pipeline/state_tracker.cpp @@ -8,14 +8,51 @@ #include "roc_pipeline/state_tracker.h" #include "roc_core/panic.h" - namespace roc { namespace pipeline { StateTracker::StateTracker() - : halt_state_(-1) + : sem_(0) + , halt_state_(-1) , active_sessions_(0) - , pending_packets_(0) { + , pending_packets_(0) + , waiting_mask_(0) { +} + +// This method should block until the state becomes any of the states specified by the +// mask, or deadline expires. E.g. if mask is ACTIVE | PAUSED, it should block until state +// becomes either ACTIVE or PAUSED. (Currently only two states are used, but later more +// states will be needed). Deadline should be an absolute timestamp. + +// Questions: +// - When should the function return true vs false +bool StateTracker::wait_state(unsigned state_mask, core::nanoseconds_t deadline) { + waiting_mask_ = state_mask; + for (;;) { + // If no state is specified in state_mask, return immediately + if (state_mask == 0) { + return true; + } + + if (static_cast(get_state()) & state_mask) { + waiting_mask_ = 0; + return true; + } + + if (deadline >= 0 && deadline <= core::timestamp(core::ClockMonotonic)) { + waiting_mask_ = 0; + return false; + } + + if (deadline >= 0) { + if (!sem_.timed_wait(deadline)) { + waiting_mask_ = 0; + return false; + } + } else { + sem_.wait(); + } + } } sndio::DeviceState StateTracker::get_state() const { @@ -65,22 +102,46 @@ size_t StateTracker::num_sessions() const { } void StateTracker::register_session() { - active_sessions_++; + if (active_sessions_++ == 0) { + signal_state_change(); + } } void StateTracker::unregister_session() { - if (--active_sessions_ < 0) { + int prev_sessions = active_sessions_--; + if (prev_sessions == 0) { roc_panic("state tracker: unpaired register/unregister session"); + } else if (prev_sessions == 1 && pending_packets_ == 0) { + signal_state_change(); } + + // if (--active_sessions_ < 0) { + // roc_panic("state tracker: unpaired register/unregister session"); + // } } void StateTracker::register_packet() { - pending_packets_++; + if (pending_packets_++ == 0 && active_sessions_ == 0) { + signal_state_change(); + } } void StateTracker::unregister_packet() { - if (--pending_packets_ < 0) { + int prev_packets = pending_packets_--; + if (prev_packets == 0) { roc_panic("state tracker: unpaired register/unregister packet"); + } else if (prev_packets == 1 && active_sessions_ == 0) { + signal_state_change(); + } + + // if (--pending_packets_ < 0) { + // roc_panic("state tracker: unpaired register/unregister packet"); + // } +} + +void StateTracker::signal_state_change() { + if (waiting_mask_ != 0 && (static_cast(get_state()) & waiting_mask_)) { + sem_.post(); } } diff --git a/src/internal_modules/roc_pipeline/state_tracker.h b/src/internal_modules/roc_pipeline/state_tracker.h index 1f5bb1499..5d3e7f498 100644 --- a/src/internal_modules/roc_pipeline/state_tracker.h +++ b/src/internal_modules/roc_pipeline/state_tracker.h @@ -14,7 +14,9 @@ #include "roc_core/atomic.h" #include "roc_core/noncopyable.h" +#include "roc_core/semaphore.h" #include "roc_core/stddefs.h" +#include "roc_core/time.h" #include "roc_sndio/device_state.h" namespace roc { @@ -32,6 +34,9 @@ class StateTracker : public core::NonCopyable<> { //! Initialize all counters to zero. StateTracker(); + //! Block until state becomes any of the ones specified by state_mask. + bool wait_state(unsigned state_mask, core::nanoseconds_t deadline); + //! Compute current state. sndio::DeviceState get_state() const; @@ -63,9 +68,12 @@ class StateTracker : public core::NonCopyable<> { void unregister_packet(); private: + core::Semaphore sem_; core::Atomic halt_state_; core::Atomic active_sessions_; core::Atomic pending_packets_; + core::Atomic waiting_mask_; + void signal_state_change(); }; } // namespace pipeline