Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 43 additions & 19 deletions server/claude-sdk.js
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ function addSession(sessionId, queryInstance, tempImagePaths = [], tempDir = nul
status: 'active',
tempImagePaths,
tempDir,
writer
ws: writer && writer.isWebSocketWriter ? writer.ws : (writer && typeof writer.send === 'function' ? writer : null)
});
}

Expand Down Expand Up @@ -476,7 +476,7 @@ async function loadMcpConfig(cwd) {
* @param {Object} ws - WebSocket connection
* @returns {Promise<void>}
*/
async function queryClaudeSDK(command, options = {}, ws) {
async function queryClaudeSDK(command, options = {}, writer) {
const { sessionId, sessionSummary } = options;
let capturedSessionId = sessionId;
let sessionCreatedSent = false;
Expand All @@ -485,8 +485,8 @@ async function queryClaudeSDK(command, options = {}, ws) {

const emitNotification = (event) => {
notifyUserIfEnabled({
userId: ws?.userId || null,
writer: ws,
userId: writer?.userId || null,
writer: writer,
event
});
};
Expand Down Expand Up @@ -551,7 +551,7 @@ async function queryClaudeSDK(command, options = {}, ws) {
}

const requestId = createRequestId();
ws.send(createNormalizedMessage({ kind: 'permission_request', requestId, toolName, input, sessionId: capturedSessionId || sessionId || null, provider: 'claude' }));
writer.send(createNormalizedMessage({ kind: 'permission_request', requestId, toolName, input, sessionId: capturedSessionId || sessionId || null, provider: 'claude' }));
emitNotification(createNotificationEvent({
provider: 'claude',
sessionId: capturedSessionId || sessionId || null,
Expand All @@ -573,7 +573,7 @@ async function queryClaudeSDK(command, options = {}, ws) {
_receivedAt: new Date(),
},
onCancel: (reason) => {
ws.send(createNormalizedMessage({ kind: 'permission_cancelled', requestId, reason, sessionId: capturedSessionId || sessionId || null, provider: 'claude' }));
writer.send(createNormalizedMessage({ kind: 'permission_cancelled', requestId, reason, sessionId: capturedSessionId || sessionId || null, provider: 'claude' }));
}
});
if (!decision) {
Expand Down Expand Up @@ -629,7 +629,7 @@ async function queryClaudeSDK(command, options = {}, ws) {

// Track the query instance for abort capability
if (capturedSessionId) {
addSession(capturedSessionId, queryInstance, tempImagePaths, tempDir, ws);
addSession(capturedSessionId, queryInstance, tempImagePaths, tempDir, writer);
}

// Process streaming messages
Expand All @@ -639,17 +639,17 @@ async function queryClaudeSDK(command, options = {}, ws) {
if (message.session_id && !capturedSessionId) {

capturedSessionId = message.session_id;
addSession(capturedSessionId, queryInstance, tempImagePaths, tempDir, ws);
addSession(capturedSessionId, queryInstance, tempImagePaths, tempDir, writer);

// Set session ID on writer
if (ws.setSessionId && typeof ws.setSessionId === 'function') {
ws.setSessionId(capturedSessionId);
if (writer && typeof writer.setSessionId === 'function') {
writer.setSessionId(capturedSessionId);
}

// Send session-created event only once for new sessions
if (!sessionId && !sessionCreatedSent) {
sessionCreatedSent = true;
ws.send(createNormalizedMessage({ kind: 'session_created', newSessionId: capturedSessionId, sessionId: capturedSessionId, provider: 'claude' }));
writer.send(createNormalizedMessage({ kind: 'session_created', newSessionId: capturedSessionId, sessionId: capturedSessionId, provider: 'claude' }));
}
} else {
// session_id already captured
Expand All @@ -666,7 +666,7 @@ async function queryClaudeSDK(command, options = {}, ws) {
if (transformedMessage.parentToolUseId && !msg.parentToolUseId) {
msg.parentToolUseId = transformedMessage.parentToolUseId;
}
ws.send(msg);
writer.send(msg);
}

// Extract and send token budget updates from result messages
Expand All @@ -677,7 +677,7 @@ async function queryClaudeSDK(command, options = {}, ws) {
}
const tokenBudgetData = extractTokenBudget(message);
if (tokenBudgetData) {
ws.send(createNormalizedMessage({ kind: 'status', text: 'token_budget', tokenBudget: tokenBudgetData, sessionId: capturedSessionId || sessionId || null, provider: 'claude' }));
writer.send(createNormalizedMessage({ kind: 'status', text: 'token_budget', tokenBudget: tokenBudgetData, sessionId: capturedSessionId || sessionId || null, provider: 'claude' }));
}
}
}
Expand All @@ -691,9 +691,9 @@ async function queryClaudeSDK(command, options = {}, ws) {
await cleanupTempFiles(tempImagePaths, tempDir);

// Send completion event
ws.send(createNormalizedMessage({ kind: 'complete', exitCode: 0, isNewSession: !sessionId && !!command, sessionId: capturedSessionId, provider: 'claude' }));
writer.send(createNormalizedMessage({ kind: 'complete', exitCode: 0, isNewSession: !sessionId && !!command, sessionId: capturedSessionId, provider: 'claude' }));
notifyRunStopped({
userId: ws?.userId || null,
userId: writer?.userId || null,
provider: 'claude',
sessionId: capturedSessionId || sessionId || null,
sessionName: sessionSummary,
Expand All @@ -719,9 +719,9 @@ async function queryClaudeSDK(command, options = {}, ws) {
: error.message;

// Send error to WebSocket
ws.send(createNormalizedMessage({ kind: 'error', content: errorContent, sessionId: capturedSessionId || sessionId || null, provider: 'claude' }));
writer.send(createNormalizedMessage({ kind: 'error', content: errorContent, sessionId: capturedSessionId || sessionId || null, provider: 'claude' }));
notifyRunFailed({
userId: ws?.userId || null,
userId: writer?.userId || null,
provider: 'claude',
sessionId: capturedSessionId || sessionId || null,
sessionName: sessionSummary,
Expand Down Expand Up @@ -765,6 +765,23 @@ async function abortClaudeSDKSession(sessionId) {
}
}

/**
* Aborts all sessions associated with a specific WebSocket
* @param {WebSocket} ws - The WebSocket to match
* @returns {number} - Number of sessions aborted
*/
function abortClaudeSDKSessionsForWebSocket(ws) {
let count = 0;
for (const [id, session] of activeSessions.entries()) {
if (session.ws === ws && session.status === 'active') {
console.log(`[Claude] Aborting orphaned session ${id} due to WebSocket disconnect`);
abortClaudeSDKSession(id);
count++;
}
}
return count;
}

/**
* Checks if an SDK session is currently active
* @param {string} sessionId - Session identifier
Expand All @@ -780,7 +797,13 @@ function isClaudeSDKSessionActive(sessionId) {
* @returns {Array<string>} Array of active session IDs
*/
function getActiveClaudeSDKSessions() {
return getAllSessions();
const activeIds = [];
for (const [id, session] of activeSessions.entries()) {
if (session.status === 'active') {
activeIds.push(id);
}
}
return activeIds;
}

/**
Expand Down Expand Up @@ -820,10 +843,11 @@ function reconnectSessionWriter(sessionId, newRawWs) {
return true;
}

// Export public API
// Export public API - Fixed duplicate export issue
export {
queryClaudeSDK,
abortClaudeSDKSession,
abortClaudeSDKSessionsForWebSocket,
isClaudeSDKSessionActive,
getActiveClaudeSDKSessions,
resolveToolApproval,
Expand Down
50 changes: 35 additions & 15 deletions server/cursor-cli.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { spawn } from 'child_process';
import crossSpawn from 'cross-spawn';
import { StringDecoder } from 'string_decoder';
import { notifyRunFailed, notifyRunStopped } from './services/notification-orchestrator.js';
import { sessionsService } from './modules/providers/services/sessions.service.js';
import { providerAuthService } from './modules/providers/services/provider-auth.service.js';
Expand All @@ -25,7 +26,7 @@ function isWorkspaceTrustPrompt(text = '') {
return WORKSPACE_TRUST_PATTERNS.some((pattern) => pattern.test(text));
}

async function spawnCursor(command, options = {}, ws) {
async function spawnCursor(command, options = {}, writer) {
return new Promise(async (resolve, reject) => {
const { sessionId, projectPath, cwd, resume, toolsSettings, skipPermissions, model, sessionSummary } = options;
let capturedSessionId = sessionId; // Track session ID throughout the process
Expand Down Expand Up @@ -97,7 +98,7 @@ async function spawnCursor(command, options = {}, ws) {
const finalSessionId = capturedSessionId || sessionId || processKey;
if (code === 0 && !error) {
notifyRunStopped({
userId: ws?.userId || null,
userId: writer?.userId || null,
provider: 'cursor',
sessionId: finalSessionId,
sessionName: sessionSummary,
Expand All @@ -107,7 +108,7 @@ async function spawnCursor(command, options = {}, ws) {
}

notifyRunFailed({
userId: ws?.userId || null,
userId: writer?.userId || null,
provider: 'cursor',
sessionId: finalSessionId,
sessionName: sessionSummary,
Expand All @@ -129,6 +130,11 @@ async function spawnCursor(command, options = {}, ws) {
env: { ...process.env } // Inherit all environment variables
});

// Store WebSocket reference for cleanup on disconnect
if (writer && writer.ws) {
cursorProcess.ws = writer.ws;
}

activeCursorProcesses.set(processKey, cursorProcess);

const shouldSuppressForTrustRetry = (text) => {
Expand Down Expand Up @@ -168,14 +174,12 @@ async function spawnCursor(command, options = {}, ws) {
}

// Set session ID on writer (for API endpoint compatibility)
if (ws.setSessionId && typeof ws.setSessionId === 'function') {
ws.setSessionId(capturedSessionId);
}
writer && typeof writer.setSessionId === 'function' && writer.setSessionId(capturedSessionId);

// Send session-created event only once for new sessions
if (!sessionId && !sessionCreatedSent) {
sessionCreatedSent = true;
ws.send(createNormalizedMessage({ kind: 'session_created', newSessionId: capturedSessionId, model: response.model, cwd: response.cwd, sessionId: capturedSessionId, provider: 'cursor' }));
writer.send(createNormalizedMessage({ kind: 'session_created', newSessionId: capturedSessionId, model: response.model, cwd: response.cwd, sessionId: capturedSessionId, provider: 'cursor' }));
}
}

Expand All @@ -191,15 +195,15 @@ async function spawnCursor(command, options = {}, ws) {
// Accumulate assistant message chunks
if (response.message && response.message.content && response.message.content.length > 0) {
const normalized = sessionsService.normalizeMessage('cursor', response, capturedSessionId || sessionId || null);
for (const msg of normalized) ws.send(msg);
for (const msg of normalized) writer.send(msg);
}
break;

case 'result': {
// Session complete — send stream end + lifecycle complete with result payload
console.log('Cursor session result:', response);
const resultText = typeof response.result === 'string' ? response.result : '';
ws.send(createNormalizedMessage({
writer.send(createNormalizedMessage({
kind: 'complete',
exitCode: response.subtype === 'success' ? 0 : 1,
resultText,
Expand All @@ -221,13 +225,16 @@ async function spawnCursor(command, options = {}, ws) {

// If not JSON, send as stream delta via adapter
const normalized = sessionsService.normalizeMessage('cursor', line, capturedSessionId || sessionId || null);
for (const msg of normalized) ws.send(msg);
for (const msg of normalized) writer.send(msg);
}
};

const stdoutDecoder = new StringDecoder('utf8');
const stderrDecoder = new StringDecoder('utf8');

// Handle stdout (streaming JSON responses)
cursorProcess.stdout.on('data', (data) => {
const rawOutput = data.toString();
const rawOutput = stdoutDecoder.write(data);
console.log('Cursor CLI stdout:', rawOutput);

// Stream chunks can split JSON objects across packets; keep trailing partial line.
Expand All @@ -242,14 +249,14 @@ async function spawnCursor(command, options = {}, ws) {

// Handle stderr
cursorProcess.stderr.on('data', (data) => {
const stderrText = data.toString();
const stderrText = stderrDecoder.write(data);
console.error('Cursor CLI stderr:', stderrText);

if (shouldSuppressForTrustRetry(stderrText)) {
return;
}

ws.send(createNormalizedMessage({ kind: 'error', content: stderrText, sessionId: capturedSessionId || sessionId || null, provider: 'cursor' }));
writer.send(createNormalizedMessage({ kind: 'error', content: stderrText, sessionId: capturedSessionId || sessionId || null, provider: 'cursor' }));
});

// Handle process completion
Expand All @@ -276,7 +283,7 @@ async function spawnCursor(command, options = {}, ws) {
return;
}

ws.send(createNormalizedMessage({ kind: 'complete', exitCode: code, isNewSession: !sessionId && !!command, sessionId: finalSessionId, provider: 'cursor' }));
writer.send(createNormalizedMessage({ kind: 'complete', exitCode: code, isNewSession: !sessionId && !!command, sessionId: finalSessionId, provider: 'cursor' }));

if (code === 0) {
notifyTerminalState({ code });
Expand All @@ -301,7 +308,7 @@ async function spawnCursor(command, options = {}, ws) {
? 'Cursor CLI is not installed. Please install it from https://cursor.com'
: error.message;

ws.send(createNormalizedMessage({ kind: 'error', content: errorContent, sessionId: capturedSessionId || sessionId || null, provider: 'cursor' }));
writer.send(createNormalizedMessage({ kind: 'error', content: errorContent, sessionId: capturedSessionId || sessionId || null, provider: 'cursor' }));
notifyTerminalState({ error });

settleOnce(() => reject(error));
Expand All @@ -326,6 +333,18 @@ function abortCursorSession(sessionId) {
return false;
}

function abortCursorSessionsForWebSocket(ws) {
let count = 0;
for (const [sessionId, proc] of activeCursorProcesses.entries()) {
if (proc.ws === ws) {
console.log(`[Cursor] Aborting orphaned session ${sessionId} due to WebSocket disconnect`);
abortCursorSession(sessionId);
count++;
}
}
return count;
}

function isCursorSessionActive(sessionId) {
return activeCursorProcesses.has(sessionId);
}
Expand All @@ -337,6 +356,7 @@ function getActiveCursorSessions() {
export {
spawnCursor,
abortCursorSession,
abortCursorSessionsForWebSocket,
isCursorSessionActive,
getActiveCursorSessions
};
Loading