Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/browser/components/AppLoader.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ function AppLoaderInner() {
getPRStatusStoreInstance().setClient(api ?? null);

if (!workspaceContext.loading) {
// Tell the store which workspace is selected before syncing, so it only
// subscribes to onChat for that workspace (not all of them).
workspaceStoreInstance.setSelectedWorkspaceId(
workspaceContext.selectedWorkspace?.workspaceId ?? null
);
workspaceStoreInstance.syncWorkspaces(workspaceContext.workspaceMetadata);
gitStatusStore.syncWorkspaces(workspaceContext.workspaceMetadata);

Expand All @@ -107,6 +112,7 @@ function AppLoaderInner() {
}, [
workspaceContext.loading,
workspaceContext.workspaceMetadata,
workspaceContext.selectedWorkspace?.workspaceId,
workspaceStoreInstance,
gitStatusStore,
backgroundBashStore,
Expand Down
12 changes: 3 additions & 9 deletions src/browser/components/ConnectionStatusToast.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export const ConnectionStatusToast: React.FC<ConnectionStatusToastProps> = ({ wr
return null;
}

if (apiState.status === "degraded" || apiState.status === "reconnecting") {
if (apiState.status === "reconnecting") {
const content = (
<div
role="status"
Expand All @@ -41,14 +41,8 @@ export const ConnectionStatusToast: React.FC<ConnectionStatusToastProps> = ({ wr
>
<span className="bg-warning inline-block h-2 w-2 animate-pulse rounded-full" />
<span>
{apiState.status === "degraded" ? (
"Connection unstable β€” messages may be delayed"
) : (
<>
Reconnecting to server
{apiState.attempt > 1 && ` (attempt ${apiState.attempt})`}…
</>
)}
Reconnecting to server
{apiState.attempt > 1 && ` (attempt ${apiState.attempt})`}…
</span>
</div>
);
Expand Down
223 changes: 59 additions & 164 deletions src/browser/contexts/API.test.tsx
Original file line number Diff line number Diff line change
@@ -1,63 +1,20 @@
import { act, cleanup, render, waitFor } from "@testing-library/react";
import type { APIClient } from "@/browser/contexts/API";
import { afterEach, beforeEach, describe, expect, mock, test } from "bun:test";
import { GlobalWindow } from "happy-dom";

// Mock WebSocket that we can control
class MockWebSocket {
static instances: MockWebSocket[] = [];
url: string;
readyState = 0; // CONNECTING
eventListeners = new Map<string, Array<(event?: unknown) => void>>();

constructor(url: string) {
this.url = url;
MockWebSocket.instances.push(this);
}

addEventListener(event: string, handler: (event?: unknown) => void) {
const handlers = this.eventListeners.get(event) ?? [];
handlers.push(handler);
this.eventListeners.set(event, handlers);
}

close() {
this.readyState = 3; // CLOSED
}

// Test helpers
simulateOpen() {
this.readyState = 1; // OPEN
this.eventListeners.get("open")?.forEach((h) => h());
}

simulateClose(code: number) {
this.readyState = 3;
this.eventListeners.get("close")?.forEach((h) => h({ code }));
}

simulateError() {
this.eventListeners.get("error")?.forEach((h) => h());
}

static reset() {
MockWebSocket.instances = [];
}

static lastInstance(): MockWebSocket | undefined {
return MockWebSocket.instances[MockWebSocket.instances.length - 1];
}
}
// Control what ping() returns across tests.
let mockPing: (input: string) => Promise<string> = () => Promise.resolve("pong");

// Mock orpc client
void mock.module("@/common/orpc/client", () => ({
createClient: () => ({
general: {
ping: () => Promise.resolve("pong"),
ping: (input: string) => mockPing(input),
},
}),
}));

void mock.module("@orpc/client/websocket", () => ({
void mock.module("@orpc/client/fetch", () => ({
RPCLink: class {},
}));

Expand Down Expand Up @@ -98,10 +55,7 @@ function APIStateObserver(props: { onState: (state: UseAPIResult) => void }) {
return null;
}

// Factory that creates MockWebSocket instances (injected via prop)
const createMockWebSocket = (url: string) => new MockWebSocket(url) as unknown as WebSocket;

describe("API reconnection", () => {
describe("API connection (fetch transport)", () => {
beforeEach(() => {
// Minimal DOM setup required by @testing-library/react.
//
Expand All @@ -110,185 +64,126 @@ describe("API reconnection", () => {
const happyWindow = new GlobalWindow({ url: "https://mux.example.com/" });
globalThis.window = happyWindow as unknown as Window & typeof globalThis;
globalThis.document = happyWindow.document as unknown as Document;
MockWebSocket.reset();

// Default: ping succeeds.
mockPing = () => Promise.resolve("pong");
});

afterEach(() => {
cleanup();
MockWebSocket.reset();
globalThis.window = undefined as unknown as Window & typeof globalThis;
globalThis.document = undefined as unknown as Document;
});

test("constructs WebSocket URL with app proxy prefix", () => {
window.location.href = "https://coder.example.com/@u/ws/apps/mux/?token=abc";

render(
<APIProvider createWebSocket={createMockWebSocket}>
<APIStateObserver onState={() => undefined} />
</APIProvider>
);

const ws1 = MockWebSocket.lastInstance();
expect(ws1).toBeDefined();
expect(ws1!.url).toBe("wss://coder.example.com/@u/ws/apps/mux/orpc/ws?token=abc");
});

test("reconnects on close without showing auth_required when previously connected", async () => {
test("transitions to connected when ping succeeds", async () => {
const states: string[] = [];

render(
<APIProvider createWebSocket={createMockWebSocket}>
<APIProvider>
<APIStateObserver onState={(s) => states.push(s.status)} />
</APIProvider>
);

const ws1 = MockWebSocket.lastInstance();
expect(ws1).toBeDefined();

// Simulate successful connection (open + ping success)
await act(async () => {
ws1!.simulateOpen();
// Wait for ping promise to resolve
await new Promise((r) => setTimeout(r, 10));
});

expect(states).toContain("connected");

// Simulate server restart (close code 1006 = abnormal closure)
act(() => {
ws1!.simulateClose(1006);
});

// Should be "reconnecting", NOT "auth_required"
await waitFor(() => {
expect(states).toContain("reconnecting");
});

expect(states.filter((s) => s === "auth_required")).toHaveLength(0);

// New WebSocket should be created for reconnect attempt (after delay)
await waitFor(() => {
expect(MockWebSocket.instances.length).toBeGreaterThan(1);
expect(states).toContain("connected");
});
});

test("shows auth_required on close with auth error codes (4401)", async () => {
test("shows auth_required when ping returns an auth error", async () => {
mockPing = () => Promise.reject(new Error("401 Unauthorized"));

const states: string[] = [];

render(
<APIProvider createWebSocket={createMockWebSocket}>
<APIProvider>
<APIStateObserver onState={(s) => states.push(s.status)} />
</APIProvider>
);

const ws1 = MockWebSocket.lastInstance();
expect(ws1).toBeDefined();

await act(async () => {
ws1!.simulateOpen();
await new Promise((r) => setTimeout(r, 10));
});

expect(states).toContain("connected");

act(() => {
ws1!.simulateClose(4401);
});

await waitFor(() => {
expect(states).toContain("auth_required");
});
});

test("shows auth_required on close with auth error codes (1008)", async () => {
test("retries on first connection failure without showing auth_required", async () => {
// First call fails with a network error, subsequent calls succeed.
let callCount = 0;
mockPing = () => {
callCount++;
if (callCount === 1) {
return Promise.reject(new Error("Failed to fetch"));
}
return Promise.resolve("pong");
};

const states: string[] = [];

render(
<APIProvider createWebSocket={createMockWebSocket}>
<APIProvider>
<APIStateObserver onState={(s) => states.push(s.status)} />
</APIProvider>
);

const ws1 = MockWebSocket.lastInstance();
expect(ws1).toBeDefined();

await act(async () => {
ws1!.simulateOpen();
await new Promise((r) => setTimeout(r, 10));
// Should retry via reconnection, not show auth_required.
await waitFor(() => {
expect(states).toContain("reconnecting");
});

expect(states).toContain("connected");

act(() => {
ws1!.simulateClose(1008);
});
expect(states.filter((s) => s === "auth_required")).toHaveLength(0);

// Eventually connects on retry.
await waitFor(() => {
expect(states).toContain("auth_required");
expect(states).toContain("connected");
});
});

test("retries on first connection failure without showing auth_required", async () => {
test("uses pre-created client and skips connection flow", async () => {
const mockClient = {
general: { ping: () => Promise.resolve("pong") },
} as unknown as APIClient;

const states: string[] = [];

render(
<APIProvider createWebSocket={createMockWebSocket}>
<APIProvider client={mockClient}>
<APIStateObserver onState={(s) => states.push(s.status)} />
</APIProvider>
);

const ws1 = MockWebSocket.lastInstance();
expect(ws1).toBeDefined();

// First connection fails - browser fires error then close.
act(() => {
ws1!.simulateError();
ws1!.simulateClose(1006);
});

// Should immediately be connected without going through connecting.
await waitFor(() => {
expect(states).toContain("reconnecting");
});

expect(states.filter((s) => s === "auth_required")).toHaveLength(0);

// Should create a new WebSocket for the reconnect attempt.
await waitFor(() => {
expect(MockWebSocket.instances.length).toBeGreaterThan(1);
expect(states[0]).toBe("connected");
});
});

test("reconnects on connection loss when previously connected", async () => {
const states: string[] = [];
test("authenticate() triggers reconnection with new token", async () => {
mockPing = () => Promise.reject(new Error("401 Unauthorized"));

const capturedStates: UseAPIResult[] = [];

render(
<APIProvider createWebSocket={createMockWebSocket}>
<APIStateObserver onState={(s) => states.push(s.status)} />
<APIProvider>
<APIStateObserver onState={(s) => capturedStates.push(s)} />
</APIProvider>
);

const ws1 = MockWebSocket.lastInstance();
expect(ws1).toBeDefined();

await act(async () => {
ws1!.simulateOpen();
await new Promise((r) => setTimeout(r, 10));
// Wait for auth_required state.
await waitFor(() => {
expect(capturedStates.some((s) => s.status === "auth_required")).toBe(true);
});

expect(states).toContain("connected");
// Now make ping succeed and call authenticate.
mockPing = () => Promise.resolve("pong");

// Connection lost after being connected
act(() => {
ws1!.simulateError();
ws1!.simulateClose(1006);
await act(async () => {
const lastState = capturedStates[capturedStates.length - 1];
expect(lastState.status).toBe("auth_required");
lastState.authenticate("new-token");
await new Promise((r) => setTimeout(r, 50));
});

await waitFor(() => {
expect(states).toContain("reconnecting");
expect(capturedStates.some((s) => s.status === "connected")).toBe(true);
});

const authRequiredAfterConnected = states.slice(states.indexOf("connected") + 1);
expect(authRequiredAfterConnected.filter((s) => s === "auth_required")).toHaveLength(0);
});
});
Loading
Loading