Skip to content

Commit eda106e

Browse files
committed
don't close stream of a response API observer
1 parent c570038 commit eda106e

File tree

4 files changed

+29
-16
lines changed

4 files changed

+29
-16
lines changed

shai-http/src/apis/openai/completion/handler.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async fn handle_chat_completion_stream(
6262
let formatter = ChatCompletionFormatter::new(model);
6363

6464
// Create SSE stream
65-
let stream = session_to_sse_stream(request_session, formatter, session_id);
65+
let stream = session_to_sse_stream(request_session, formatter, session_id, true);
6666

6767
Ok(Sse::new(stream).into_response())
6868
}

shai-http/src/apis/openai/response/handler.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ async fn handle_response_stream(
6969
let formatter = ResponseFormatter::new(model, payload);
7070

7171
// Create SSE stream
72-
let stream = session_to_sse_stream(request_session, formatter, session_id);
72+
let stream = session_to_sse_stream(request_session, formatter, session_id, true);
7373

7474
Ok(Sse::new(stream).into_response())
7575
}
@@ -115,7 +115,8 @@ pub async fn handle_get_response(
115115
let formatter = ResponseFormatter::new(agent_session.agent_name.clone(), placeholder_payload);
116116

117117
// Create SSE stream using the simple sse_stream (no lifecycle needed for read-only)
118-
let stream = event_to_sse_stream(event_rx, formatter, response_id);
118+
// stop_on_pause = false means stream stops on Completed OR Paused
119+
let stream = event_to_sse_stream(event_rx, formatter, response_id, false);
119120

120121
Ok(Sse::new(stream).into_response())
121122
}

shai-http/src/apis/simple/handler.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ pub async fn handle_multimodal_query_stream(
6363
let formatter = SimpleFormatter::new(payload.model.clone());
6464

6565
// Create SSE stream
66-
let stream = session_to_sse_stream(request_session, formatter, session_id);
66+
let stream = session_to_sse_stream(request_session, formatter, session_id, true);
6767

6868
Ok(Sse::new(stream).into_response())
6969
}

shai-http/src/streaming.rs

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ fn sse_stream_internal<F, L>(
3636
formatter: F,
3737
session_id: String,
3838
lifecycle: Option<L>,
39+
stop_on_pause: bool,
3940
) -> impl Stream<Item = Result<Event, Infallible>>
4041
where
4142
F: EventFormatter + 'static,
@@ -55,7 +56,7 @@ where
5556
loop {
5657
match rx.next().await {
5758
Some(Ok(event)) => {
58-
let is_terminal = is_terminal_event(&event);
59+
let is_terminal = is_terminal_event(&event, stop_on_pause);
5960
let formatted = fmt.format_event(event, &session_id).await;
6061
let new_done = if is_terminal { true } else { done };
6162

@@ -93,23 +94,31 @@ where
9394

9495
/// Core SSE stream creation from event receiver
9596
/// Watches events, formats them, and stops on completion or client disconnect
97+
///
98+
/// # Parameters
99+
/// * `stop_on_pause` - If true, only stops on Completed. If false, stops on Completed or StatusChanged to Paused.
96100
pub fn event_to_sse_stream<F>(
97101
event_rx: Receiver<AgentEvent>,
98102
formatter: F,
99103
session_id: String,
104+
stop_on_pause: bool,
100105
) -> impl Stream<Item = Result<Event, Infallible>>
101106
where
102107
F: EventFormatter + 'static,
103108
{
104-
sse_stream_internal(event_rx, formatter, session_id, None::<()>)
109+
sse_stream_internal(event_rx, formatter, session_id, None::<()>, stop_on_pause)
105110
}
106111

107112
/// Create an SSE stream from a RequestSession
108113
/// Same as sse_stream but keeps lifecycle in scope for session cleanup
114+
///
115+
/// # Parameters
116+
/// * `stop_on_pause` - If true, only stops on Completed. If false, stops on Completed or StatusChanged to Paused.
109117
pub fn session_to_sse_stream<F>(
110118
request_session: RequestSession,
111119
formatter: F,
112120
session_id: String,
121+
stop_on_pause: bool,
113122
) -> impl Stream<Item = Result<Event, Infallible>>
114123
where
115124
F: EventFormatter + 'static,
@@ -118,17 +127,20 @@ where
118127
let _controller = request_session.controller;
119128
let lifecycle = request_session.lifecycle;
120129

121-
sse_stream_internal(event_rx, formatter, session_id, Some(lifecycle))
130+
sse_stream_internal(event_rx, formatter, session_id, Some(lifecycle), stop_on_pause)
122131
}
123132

124133
/// Check if an event signals the end of the stream
125-
fn is_terminal_event(event: &AgentEvent) -> bool {
126-
matches!(
127-
event,
128-
AgentEvent::Completed { .. }
129-
| AgentEvent::StatusChanged {
130-
new_status: PublicAgentState::Paused,
131-
..
132-
}
133-
)
134+
///
135+
/// # Parameters
136+
/// * `stop_on_pause` - If true, only Completed is terminal. If false, both Completed and Paused are terminal.
137+
fn is_terminal_event(event: &AgentEvent, stop_on_pause: bool) -> bool {
138+
match event {
139+
AgentEvent::Completed { .. } => true,
140+
AgentEvent::StatusChanged {
141+
new_status: PublicAgentState::Paused,
142+
..
143+
} => stop_on_pause,
144+
_ => false,
145+
}
134146
}

0 commit comments

Comments
 (0)