Skip to content
This repository was archived by the owner on Oct 22, 2025. It is now read-only.

Commit 8cbfd48

Browse files
committed
feat: add onAuth lifecycle hooks
1 parent 0a8dd91 commit 8cbfd48

File tree

29 files changed

+871
-207
lines changed

29 files changed

+871
-207
lines changed

CLAUDE.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,7 @@ This ensures imports resolve correctly across different build environments and p
108108
- Run `yarn check-types` regularly during development to catch type errors early. Prefer `yarn check-types` instead of `yarn build`.
109109
- Use `tsx` CLI to execute TypeScript scripts directly (e.g., `tsx script.ts` instead of `node script.js`).
110110
- Do not auto-commit changes
111+
112+
## Test Guidelines
113+
114+
- Do not check if errors are an instanceOf WorkerError in tests. Many error types do not have the same prototype chain when sent over the network, but still have the same properties so you can safely cast with `as`.

docs/workers/quickstart.mdx

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ description: Start building awesome documentation in under 5 minutes
77
```ts registry.ts
88
import { setup } from "rivetkit";
99
import { worker } from "rivetkit/worker";
10-
import { workflow } from "rivetkit/workflow";
11-
import { realtime } from "rivetkit/realtime";
1210

1311
const counter = worker({
12+
onAuth: async () => {
13+
// Allow public access
14+
},
1415
state: {
1516
count: 0,
1617
},
@@ -36,18 +37,24 @@ export type Registry = typeof registry;
3637
```
3738

3839
```ts server.ts
40+
// With router
3941
import { registry } from "./registry";
4042

41-
const registry = new Hono();
42-
app.route("/registry", registry.handler);
43+
const app = new Hono();
44+
app.route("/registry", c => registry.handler(c.req.raw));
45+
serve(app);
46+
47+
// Without router
48+
import { serve } from "@rivetkit/node";
49+
4350
serve(registry);
4451
```
4552

4653
```ts client.ts
4754
import { createClient } from "rivetkit/client";
4855
import type { Registry } from "./registry";
4956

50-
const client = createClient<Registry>("http://localhost:8080/registry");
57+
const client = createClient<Registry>("http://localhost:8080");
5158
```
5259

5360
<Steps>
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import { worker, UserError } from "rivetkit";
2+
3+
// Basic auth worker - requires API key
4+
export const authWorker = worker({
5+
state: { requests: 0 },
6+
onAuth: (opts) => {
7+
const { req, intents, params } = opts;
8+
const apiKey = (params as any)?.apiKey;
9+
if (!apiKey) {
10+
throw new UserError("API key required", { code: "missing_auth" });
11+
}
12+
13+
if (apiKey !== "valid-api-key") {
14+
throw new UserError("Invalid API key", { code: "invalid_auth" });
15+
}
16+
17+
return { userId: "user123", token: apiKey };
18+
},
19+
actions: {
20+
getRequests: (c) => {
21+
c.state.requests++;
22+
return c.state.requests;
23+
},
24+
getUserAuth: (c) => c.conn.auth,
25+
},
26+
});
27+
28+
// Intent-specific auth worker - checks different permissions for different intents
29+
export const intentAuthWorker = worker({
30+
state: { value: 0 },
31+
onAuth: (opts) => {
32+
const { req, intents, params } = opts;
33+
console.log('intents', intents, params);
34+
const role = (params as any)?.role;
35+
36+
if (intents.has("create") && role !== "admin") {
37+
throw new UserError("Admin role required for create operations", { code: "insufficient_permissions" });
38+
}
39+
40+
if (intents.has("action") && !["admin", "user"].includes(role || "")) {
41+
throw new UserError("User or admin role required for actions", { code: "insufficient_permissions" });
42+
}
43+
44+
return { role, timestamp: Date.now() };
45+
},
46+
actions: {
47+
getValue: (c) => c.state.value,
48+
setValue: (c, value: number) => {
49+
c.state.value = value;
50+
return value;
51+
},
52+
getAuth: (c) => c.conn.auth,
53+
},
54+
});
55+
56+
// Public worker - empty onAuth to allow public access
57+
export const publicWorker = worker({
58+
state: { visitors: 0 },
59+
onAuth: () => {
60+
return null; // Allow public access
61+
},
62+
actions: {
63+
visit: (c) => {
64+
c.state.visitors++;
65+
return c.state.visitors;
66+
},
67+
},
68+
});
69+
70+
// No auth worker - should fail when accessed publicly (no onAuth defined)
71+
export const noAuthWorker = worker({
72+
state: { value: 42 },
73+
actions: {
74+
getValue: (c) => c.state.value,
75+
},
76+
});
77+
78+
// Async auth worker - tests promise-based authentication
79+
export const asyncAuthWorker = worker({
80+
state: { count: 0 },
81+
onAuth: async (opts) => {
82+
const { req, intents, params } = opts;
83+
// Simulate async auth check (e.g., database lookup)
84+
await new Promise(resolve => setTimeout(resolve, 10));
85+
86+
const token = (params as any)?.token;
87+
if (!token) {
88+
throw new UserError("Token required", { code: "missing_token" });
89+
}
90+
91+
// Simulate token validation
92+
if (token === "invalid") {
93+
throw new UserError("Token is invalid", { code: "invalid_token" });
94+
}
95+
96+
return { userId: `user-${token}`, validated: true };
97+
},
98+
actions: {
99+
increment: (c) => {
100+
c.state.count++;
101+
return c.state.count;
102+
},
103+
getAuthData: (c) => c.conn.auth,
104+
},
105+
});

packages/core/fixtures/driver-test-suite/registry.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ import {
2727
uniqueVarWorker,
2828
driverCtxWorker,
2929
} from "./vars";
30+
import {
31+
authWorker,
32+
intentAuthWorker,
33+
publicWorker,
34+
noAuthWorker,
35+
asyncAuthWorker,
36+
} from "./auth";
3037

3138
// Consolidated setup with all workers
3239
export const registry = setup({
@@ -63,6 +70,12 @@ export const registry = setup({
6370
dynamicVarWorker,
6471
uniqueVarWorker,
6572
driverCtxWorker,
73+
// From auth.ts
74+
authWorker,
75+
intentAuthWorker,
76+
publicWorker,
77+
noAuthWorker,
78+
asyncAuthWorker,
6679
},
6780
});
6881

packages/core/src/client/client.ts

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ export interface ClientDriver {
172172
c: HonoContext | undefined,
173173
workerQuery: WorkerQuery,
174174
encodingKind: Encoding,
175+
params: unknown,
175176
): Promise<string>;
176177
connectWebSocket(
177178
c: HonoContext | undefined,
@@ -364,6 +365,7 @@ export class ClientRaw {
364365
undefined,
365366
createQuery,
366367
this.#encodingKind,
368+
opts?.params,
367369
);
368370
logger().debug("created worker with ID", {
369371
name,
@@ -481,11 +483,9 @@ export function createClientWithDriver<A extends Registry<any>>(
481483
key?: string | string[],
482484
opts?: GetOptions,
483485
): WorkerHandle<ExtractWorkersFromRegistry<A>[typeof prop]> => {
484-
return target.getOrCreate<ExtractWorkersFromRegistry<A>[typeof prop]>(
485-
prop,
486-
key,
487-
opts,
488-
);
486+
return target.getOrCreate<
487+
ExtractWorkersFromRegistry<A>[typeof prop]
488+
>(prop, key, opts);
489489
},
490490
getForId: (
491491
workerId: string,
@@ -500,12 +500,12 @@ export function createClientWithDriver<A extends Registry<any>>(
500500
create: async (
501501
key: string | string[],
502502
opts: CreateOptions = {},
503-
): Promise<WorkerHandle<ExtractWorkersFromRegistry<A>[typeof prop]>> => {
504-
return await target.create<ExtractWorkersFromRegistry<A>[typeof prop]>(
505-
prop,
506-
key,
507-
opts,
508-
);
503+
): Promise<
504+
WorkerHandle<ExtractWorkersFromRegistry<A>[typeof prop]>
505+
> => {
506+
return await target.create<
507+
ExtractWorkersFromRegistry<A>[typeof prop]
508+
>(prop, key, opts);
509509
},
510510
} as WorkerAccessor<ExtractWorkersFromRegistry<A>[typeof prop]>;
511511
}

packages/core/src/client/http-client-driver.ts

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ export function createHttpClientDriver(managerEndpoint: string): ClientDriver {
8282
_c: HonoContext | undefined,
8383
workerQuery: WorkerQuery,
8484
encodingKind: Encoding,
85+
params: unknown,
8586
): Promise<string> => {
8687
logger().debug("resolving worker ID", { query: workerQuery });
8788

@@ -95,6 +96,9 @@ export function createHttpClientDriver(managerEndpoint: string): ClientDriver {
9596
headers: {
9697
[HEADER_ENCODING]: encodingKind,
9798
[HEADER_WORKER_QUERY]: JSON.stringify(workerQuery),
99+
...(params !== undefined
100+
? { [HEADER_CONN_PARAMS]: JSON.stringify(params) }
101+
: {}),
98102
},
99103
body: {},
100104
encoding: encodingKind,
@@ -122,14 +126,20 @@ export function createHttpClientDriver(managerEndpoint: string): ClientDriver {
122126
): Promise<WebSocket> => {
123127
const { WebSocket } = await dynamicImports;
124128

125-
const workerQueryStr = encodeURIComponent(JSON.stringify(workerQuery));
126129
const endpoint = managerEndpoint
127130
.replace(/^http:/, "ws:")
128131
.replace(/^https:/, "wss:");
129-
const url = `${endpoint}/workers/connect/websocket?encoding=${encodingKind}&query=${workerQueryStr}`;
132+
const url = `${endpoint}/workers/connect/websocket`;
133+
134+
// Pass sensitive data via protocol
135+
const protocol = [
136+
`query.${btoa(JSON.stringify(workerQuery))}`,
137+
"encoding.encodingKind",
138+
];
139+
if (params) protocol.push(`conn_params.${btoa(JSON.stringify(params))}`);
130140

131141
logger().debug("connecting to websocket", { url });
132-
const ws = new WebSocket(url);
142+
const ws = new WebSocket(url, protocol);
133143
if (encodingKind === "cbor") {
134144
ws.binaryType = "arraybuffer";
135145
} else if (encodingKind === "json") {

packages/core/src/client/worker-handle.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ export class WorkerHandleRaw {
109109
undefined,
110110
this.#workerQuery,
111111
this.#encodingKind,
112+
this.#params,
112113
);
113114
this.#workerQuery = { getForId: { workerId } };
114115
return workerId;

packages/core/src/driver-test-suite/mod.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import { runWorkerVarsTests } from "./tests/worker-vars";
2121
import { runWorkerConnStateTests } from "./tests/worker-conn-state";
2222
import { runWorkerMetadataTests } from "./tests/worker-metadata";
2323
import { runWorkerErrorHandlingTests } from "./tests/worker-error-handling";
24+
import { runWorkerAuthTests } from "./tests/worker-auth";
2425

2526
export interface DriverTestConfig {
2627
/** Deploys an registry and returns the connection endpoint. */
@@ -90,6 +91,8 @@ export function runDriverTests(
9091
runWorkerMetadataTests(driverTestConfig);
9192

9293
runWorkerErrorHandlingTests(driverTestConfig);
94+
95+
runWorkerAuthTests(driverTestConfig);
9396
});
9497
}
9598
}

packages/core/src/driver-test-suite/test-inline-client-driver.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,14 @@ export function createTestInlineClientDriver(
4343
c: HonoContext | undefined,
4444
workerQuery: WorkerQuery,
4545
encodingKind: Encoding,
46+
params: unknown,
4647
): Promise<string> => {
4748
return makeInlineRequest<string>(
4849
endpoint,
4950
encodingKind,
5051
transport,
5152
"resolveWorkerId",
52-
[undefined, workerQuery, encodingKind],
53+
[undefined, workerQuery, encodingKind, params],
5354
);
5455
},
5556

0 commit comments

Comments
 (0)