Skip to content
This repository was archived by the owner on Oct 22, 2025. It is now read-only.

Commit beffb43

Browse files
committed
feat: add onAuth lifecycle hooks
1 parent 0a8dd91 commit beffb43

File tree

11 files changed

+269
-11
lines changed

11 files changed

+269
-11
lines changed

docs/workers/quickstart.mdx

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ description: Start building awesome documentation in under 5 minutes
77
```ts registry.ts
88
import { setup } from "rivetkit";
99
import { worker } from "rivetkit/worker";
10-
import { workflow } from "rivetkit/workflow";
11-
import { realtime } from "rivetkit/realtime";
1210

1311
const counter = worker({
12+
onAuth: async () => {
13+
// Allow public access
14+
},
1415
state: {
1516
count: 0,
1617
},
@@ -36,18 +37,24 @@ export type Registry = typeof registry;
3637
```
3738

3839
```ts server.ts
40+
// With router
3941
import { registry } from "./registry";
4042

41-
const registry = new Hono();
42-
app.route("/registry", registry.handler);
43+
const app = new Hono();
44+
app.route("/registry", c => registry.handler(c.req.raw));
45+
serve(app);
46+
47+
// Without router
48+
import { serve } from "@rivetkit/node";
49+
4350
serve(registry);
4451
```
4552

4653
```ts client.ts
4754
import { createClient } from "rivetkit/client";
4855
import type { Registry } from "./registry";
4956

50-
const client = createClient<Registry>("http://localhost:8080/registry");
57+
const client = createClient<Registry>("http://localhost:8080");
5158
```
5259

5360
<Steps>

packages/core/src/manager/auth.ts

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import * as errors from "@/worker/errors";
2+
import { Hono, type Context as HonoContext, type Next } from "hono";
3+
import type { WorkerQuery } from "./protocol/query";
4+
import type { AuthIntent } from "@/worker/config";
5+
import type { AnyWorkerDefinition } from "@/worker/definition";
6+
import type { RegistryConfig } from "@/registry/config";
7+
8+
/**
9+
* Get authentication intents from a worker query
10+
*/
11+
export function getIntentsFromQuery(query: WorkerQuery): Set<AuthIntent> {
12+
const intents = new Set<AuthIntent>();
13+
14+
if ("getForId" in query) {
15+
intents.add("get");
16+
} else if ("getForKey" in query) {
17+
intents.add("get");
18+
} else if ("getOrCreateForKey" in query) {
19+
intents.add("get");
20+
intents.add("create");
21+
} else if ("create" in query) {
22+
intents.add("create");
23+
}
24+
25+
return intents;
26+
}
27+
28+
/**
29+
* Get worker name from a worker query
30+
*/
31+
export function getWorkerNameFromQuery(query: WorkerQuery): string {
32+
if ("getForId" in query) {
33+
throw new errors.InvalidRequest(
34+
"Cannot determine worker name from getForId query",
35+
);
36+
} else if ("getForKey" in query) {
37+
return query.getForKey.name;
38+
} else if ("getOrCreateForKey" in query) {
39+
return query.getOrCreateForKey.name;
40+
} else if ("create" in query) {
41+
return query.create.name;
42+
} else {
43+
throw new errors.InvalidRequest("Invalid query format");
44+
}
45+
}
46+
47+
/**
48+
* Authenticate a request using the worker's onAuth function
49+
*/
50+
export async function authenticateRequest(
51+
c: HonoContext,
52+
workerDefinition: AnyWorkerDefinition,
53+
intents: Set<AuthIntent>,
54+
params: unknown,
55+
): Promise<unknown> {
56+
const { onAuth } = workerDefinition.config;
57+
58+
if (!onAuth) {
59+
throw new errors.Forbidden(
60+
"Worker requires authentication but no onAuth handler is defined",
61+
);
62+
}
63+
64+
try {
65+
return await onAuth({
66+
req: c.req.raw,
67+
intents,
68+
params,
69+
});
70+
} catch (error) {
71+
if (errors.WorkerError.isWorkerError(error)) {
72+
throw error;
73+
}
74+
throw new errors.Forbidden("Authentication failed");
75+
}
76+
}
77+
78+
/**
79+
* Simplified authentication for endpoints that combines all auth steps
80+
*/
81+
export async function authenticateEndpoint(
82+
c: HonoContext,
83+
registryConfig: RegistryConfig,
84+
query: WorkerQuery,
85+
additionalIntents: AuthIntent[],
86+
params?: unknown,
87+
): Promise<unknown> {
88+
// Get base intents from query
89+
const intents = getIntentsFromQuery(query);
90+
91+
// Add endpoint-specific intents
92+
for (const intent of additionalIntents) {
93+
intents.add(intent);
94+
}
95+
96+
// Get worker definition
97+
const workerName = getWorkerNameFromQuery(query);
98+
const workerDefinition = registryConfig.workers[workerName];
99+
if (!workerDefinition) {
100+
throw new errors.WorkerNotFound(workerName);
101+
}
102+
103+
// Authenticate
104+
return await authenticateRequest(c, workerDefinition, intents, params);
105+
}
106+

packages/core/src/manager/router.ts

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ import {
5555
import type { WorkerQuery } from "./protocol/query";
5656
import { VERSION } from "@/utils";
5757
import { ConnRoutingHandler } from "@/worker/conn-routing-handler";
58-
import { ClientDriver, createClientWithDriver } from "@/client/client";
59-
import { Transport, TransportSchema } from "@/worker/protocol/message/mod";
58+
import { ClientDriver } from "@/client/client";
59+
import { Transport } from "@/worker/protocol/message/mod";
60+
import { authenticateEndpoint, authenticateMessage } from "./auth";
6061

6162
type ManagerRouterHandler = {
6263
onConnectInspector?: ManagerInspectorConnHandler;
@@ -141,7 +142,10 @@ export function createManagerRouter(
141142

142143
return cors({
143144
...corsConfig,
144-
allowHeaders: [...(registryConfig.cors?.allowHeaders ?? []), ...ALL_HEADERS],
145+
allowHeaders: [
146+
...(registryConfig.cors?.allowHeaders ?? []),
147+
...ALL_HEADERS,
148+
],
145149
})(c, next);
146150
});
147151
}
@@ -194,7 +198,9 @@ export function createManagerRouter(
194198
responses: buildOpenApiResponses(ResolveResponseSchema),
195199
});
196200

197-
router.openapi(resolveRoute, (c) => handleResolveRequest(c, driver));
201+
router.openapi(resolveRoute, (c) =>
202+
handleResolveRequest(c, registryConfig, driver),
203+
);
198204
}
199205

200206
// GET /workers/connect/websocket
@@ -667,6 +673,20 @@ async function handleSseConnectRequest(
667673

668674
const query = params.data.query;
669675

676+
// Parse connection parameters for authentication
677+
const connParams = params.data.connParams
678+
? JSON.parse(params.data.connParams)
679+
: undefined;
680+
681+
// Authenticate the request
682+
const authData = await authenticateEndpoint(
683+
c,
684+
registryConfig,
685+
query,
686+
["connect"],
687+
connParams,
688+
);
689+
670690
// Get the worker ID and meta
671691
const { workerId, meta } = await queryWorker(c, query, driver);
672692
invariant(workerId, "Missing worker ID");
@@ -682,6 +702,7 @@ async function handleSseConnectRequest(
682702
driverConfig,
683703
handler.routingHandler.inline.handlers.onConnectSse,
684704
workerId,
705+
authData,
685706
);
686707
} else if ("custom" in handler.routingHandler) {
687708
logger().debug("using custom proxy mode for sse connection");
@@ -790,6 +811,14 @@ async function handleWebSocketConnectRequest(
790811
throw new errors.InvalidRequest(params.error);
791812
}
792813

814+
// For WebSocket, params come later over the socket, so we auth without params for now
815+
const authData = await authenticateEndpoint(
816+
c,
817+
registryConfig,
818+
params.data.query,
819+
["connect"],
820+
);
821+
793822
// Get the worker ID and meta
794823
const { workerId, meta } = await queryWorker(c, params.data.query, driver);
795824
logger().debug("found worker for websocket connection", { workerId, meta });
@@ -811,6 +840,7 @@ async function handleWebSocketConnectRequest(
811840
driverConfig,
812841
onConnectWebSocket,
813842
workerId,
843+
authData,
814844
)();
815845
})(c, noopNext());
816846
} else if ("custom" in handler.routingHandler) {
@@ -878,6 +908,9 @@ async function handleWebSocketConnectRequest(
878908

879909
/**
880910
* Handle a connection message request to a worker
911+
*
912+
* There is no authentication handler on this request since the connection
913+
* token is used to authenticate the message.
881914
*/
882915
async function handleMessageRequest(
883916
c: HonoContext,
@@ -900,6 +933,22 @@ async function handleMessageRequest(
900933
}
901934
const { workerId, connId, encoding, connToken } = params.data;
902935

936+
// TODO: This endpoint can be used to exhause resources (DoS attack) on an worker if you know the worker ID:
937+
// 1. Get the worker ID (usually this is reasonably secure, but we don't assume worker ID is sensitive)
938+
// 2. Spam messages to the worker (the conn token can be invalid)
939+
// 3. The worker will be exhausted processing messages — even if the token is invalid
940+
//
941+
// The solution is we need to move the authorization of the connection token to this request handler
942+
// AND include the worker ID in the connection token so we can verify that it has permission to send
943+
// a message to that worker. This would require changing the token to a JWT so we can include a secure
944+
// payload, but this requires managing a private key & managing key rotations.
945+
//
946+
// All other solutions (e.g. include the worker name as a header or include the worker name in the worker ID)
947+
// have exploits that allow the caller to send messages to arbitrary workers.
948+
//
949+
// Currently, we assume this is not a critical problem because requests will likely get rate
950+
// limited before enough messages are passed to the worker to exhaust resources.
951+
903952
// Handle based on mode
904953
if ("inline" in handler.routingHandler) {
905954
logger().debug("using inline proxy mode for connection message");
@@ -972,6 +1021,20 @@ async function handleActionRequest(
9721021
throw new errors.InvalidRequest(params.error);
9731022
}
9741023

1024+
// Parse connection parameters for authentication
1025+
const connParams = params.data.connParams
1026+
? JSON.parse(params.data.connParams)
1027+
: undefined;
1028+
1029+
// Authenticate the request
1030+
const authData = await authenticateEndpoint(
1031+
c,
1032+
registryConfig,
1033+
params.data.query,
1034+
["action"],
1035+
connParams,
1036+
);
1037+
9751038
// Get the worker ID and meta
9761039
const { workerId, meta } = await queryWorker(c, params.data.query, driver);
9771040
logger().debug("found worker for action", { workerId, meta });
@@ -988,6 +1051,7 @@ async function handleActionRequest(
9881051
handler.routingHandler.inline.handlers.onAction,
9891052
actionName,
9901053
workerId,
1054+
authData,
9911055
);
9921056
} else if ("custom" in handler.routingHandler) {
9931057
logger().debug("using custom proxy mode for action call");
@@ -1031,6 +1095,7 @@ async function handleActionRequest(
10311095
*/
10321096
async function handleResolveRequest(
10331097
c: HonoContext,
1098+
registryConfig: RegistryConfig,
10341099
driver: ManagerDriver,
10351100
): Promise<Response> {
10361101
const encoding = getRequestEncoding(c.req, false);
@@ -1046,8 +1111,13 @@ async function handleResolveRequest(
10461111
throw new errors.InvalidRequest(params.error);
10471112
}
10481113

1114+
const query = params.data.query;
1115+
1116+
// Authenticate the request
1117+
await authenticateEndpoint(c, registryConfig, query, []);
1118+
10491119
// Get the worker ID and meta
1050-
const { workerId, meta } = await queryWorker(c, params.data.query, driver);
1120+
const { workerId, meta } = await queryWorker(c, query, driver);
10511121
logger().debug("resolved worker", { workerId, meta });
10521122
invariant(workerId, "Missing worker ID");
10531123

packages/core/src/topologies/standalone/topology.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ export class StandaloneTopology {
157157
connState,
158158
CONN_DRIVER_GENERIC_WEBSOCKET,
159159
{ encoding: opts.encoding } satisfies GenericWebSocketDriverState,
160+
opts.authData,
160161
);
161162
},
162163
onMessage: async (message) => {
@@ -199,6 +200,7 @@ export class StandaloneTopology {
199200
connState,
200201
CONN_DRIVER_GENERIC_SSE,
201202
{ encoding: opts.encoding } satisfies GenericSseDriverState,
203+
opts.authData,
202204
);
203205
},
204206
onClose: async () => {
@@ -224,6 +226,7 @@ export class StandaloneTopology {
224226
connState,
225227
CONN_DRIVER_GENERIC_HTTP,
226228
{} satisfies GenericHttpDriverState,
229+
opts.authData,
227230
);
228231

229232
// Call action

0 commit comments

Comments
 (0)