diff --git a/packages/kernel-agents/package.json b/packages/kernel-agents/package.json index bab12f4cf..0ecb436e4 100644 --- a/packages/kernel-agents/package.json +++ b/packages/kernel-agents/package.json @@ -23,6 +23,26 @@ "default": "./dist/index.cjs" } }, + "./json": { + "import": { + "types": "./dist/strategies/json-agent.d.mts", + "default": "./dist/strategies/json-agent.mjs" + }, + "require": { + "types": "./dist/strategies/json-agent.d.cts", + "default": "./dist/strategies/json-agent.cjs" + } + }, + "./repl": { + "import": { + "types": "./dist/strategies/repl-agent.d.mts", + "default": "./dist/strategies/repl-agent.mjs" + }, + "require": { + "types": "./dist/strategies/repl-agent.d.cts", + "default": "./dist/strategies/repl-agent.cjs" + } + }, "./package.json": "./package.json" }, "files": [ @@ -82,8 +102,12 @@ "node": "^20.6 || >=22" }, "dependencies": { + "@metamask/kernel-errors": "workspace:^", "@metamask/kernel-utils": "workspace:^", "@metamask/logger": "workspace:^", - "@ocap/kernel-language-model-service": "workspace:^" + "@ocap/kernel-language-model-service": "workspace:^", + "ses": "^1.14.0", + "tree-sitter": "^0.25.0", + "tree-sitter-javascript": "^0.25.0" } } diff --git a/packages/kernel-agents/src/README.md b/packages/kernel-agents/src/README.md new file mode 100644 index 000000000..c0b679b07 --- /dev/null +++ b/packages/kernel-agents/src/README.md @@ -0,0 +1,66 @@ +# Kernel Agents + +A kernel agent, [tasked](./task.ts) with an objective, attempts to fulfill the objective within the context of its available capabilities and knowledge. + +So doing, the agent collects experiences, which, in addition to the objective and context of a given task, include the history of the agent's actions and observations throughout the attempt, and, if relevant, the error state or final result achieved by the attempt. + +Learning from experience requires some value signal associated to said experience. None is implemented nor provided for within this package. + +Although every task is defined by an objective and the context within which to attempt that objective, the attempts themselves may follow various strategies. A [strategy](./strategies/README.md) presents the task specification, together with the history of an agent's observations and actions, in a textual form that elicits useful responses from a language model. + +## Implementation Sketches + +An abstract agent formulation looks as follows. + +```js +// A highly abstract sketch of an agent +const agent = (params) => { + let state = initState(), + done = false; + const { act } = makeModel(params); + const { observe, step, render } = makeEnvironment(params); + for (let i = 0; i < params.maxSteps; i++) { + const observation = observe(state); + const action = await act(observation); + [state, done] = await step(state, action); + if (done) { + return state.result; + } + render(state); + } +} +``` + +In practice, agents are constructible from a language model by a slightly more detailed implementation. Although not an exact factorization of the generic structure given above, the broad sketch of `observe->act->step->render->repeat` remains. + +```js +// A more detailed abstract sketch of an agent +const agent = ({ task, llm }, { maxSteps = 10 }) => { + const state = initState(); + const prompter = makePrompter(state, task), + reader = makeReader(), + evaluator = makeEvaluator(state), + printer = makePrinter(state); + for (let i = 0; i < maxSteps; i++) { + // Observe + const { prompt, readerOptions } = prompter(state); + // Act + const { sample, abort } = await llm.sample(prompt); + const action = await reader({ sample, abort, ...readerOptions }); + // Step + const dState = await evaluator(state, action); + state.update(action, dState); + if (task.isDone(action, state)) { + return result; + } + + // Render + printer(action, observation); + } +} +``` + +For concrete implementations, see below. + +- [json-agent](./strategies/json-agent.ts) +- [repl-agent](./strategies/repl-agent.ts) diff --git a/packages/kernel-agents/src/agent.test.ts b/packages/kernel-agents/src/agent.test.ts index 4bb8a1389..8541db208 100644 --- a/packages/kernel-agents/src/agent.test.ts +++ b/packages/kernel-agents/src/agent.test.ts @@ -1,25 +1,29 @@ import '@ocap/repo-tools/test-utils/mock-endoify'; -import type { Logger } from '@metamask/logger'; +import { Logger } from '@metamask/logger'; import { vi, describe, it, expect } from 'vitest'; -import { makeAgent } from './agent.ts'; -import { capability } from './capability.ts'; -import { end } from './default-capabilities.ts'; -import { AssistantMessage, CapabilityResultMessage } from './messages.ts'; -import { makeChat } from './prompt.ts'; +import { makeJsonAgent } from './strategies/json-agent.ts'; +import { makeReplAgent } from './strategies/repl-agent.ts'; const prompt = 'test prompt'; const prefix = '{"messageType":"assistant","'; +const stop = ''; -vi.mock('./prompt.ts', () => ({ - makeChat: vi.fn(() => ({ - getPromptAndPrefix: vi.fn(() => ({ prompt, prefix })), - pushMessages: vi.fn(), - })), +vi.mock('./strategies/repl/prompter.ts', () => ({ + makePrompter: vi.fn(() => () => ({ prompt, readerArgs: { stop } })), })); -describe('makeAgent', () => { +vi.mock('./strategies/json/prompter.ts', () => ({ + makePrompter: vi.fn(() => () => ({ prompt, readerArgs: { prefix } })), +})); + +const logger = new Logger('test'); + +describe.each([ + ['Json', makeJsonAgent, [`invoke":[{"name":"end","args":{"final":"x"}}]}`]], + ['Repl', makeReplAgent, ["await end({ final: 'x' });", stop]], +])('make%sAgent', (strategy, makeAgent, endStatement) => { const mockLlm = (...chunks: string[]) => ({ getInfo: vi.fn(), load: vi.fn(), @@ -37,88 +41,54 @@ describe('makeAgent', () => { }); it('makes an agent', () => { - const llm = mockLlm(); - const agent = makeAgent({ llm, capabilities: {} }); + const languageModel = mockLlm(); + const agent = makeAgent({ languageModel, capabilities: {} }); expect(agent).toBeDefined(); expect(agent).toHaveProperty('task'); }); - it('endows the "end" capability by default', async () => { - const llm = mockLlm(); - const mockMergeDisjointRecordsSpy = vi.spyOn( - await import('@metamask/kernel-utils'), - 'mergeDisjointRecords', - ); - const capabilities = {}; - makeAgent({ llm, capabilities }); - expect(mockMergeDisjointRecordsSpy).toHaveBeenCalledWith( - { end }, - capabilities, - ); - }); - describe('task', () => { it('invokes the LLM', async () => { - const llm = mockLlm(`invoke":[{"name":"end","args":{"final":"x"}}]}`); - const agent = makeAgent({ llm, capabilities: {} }); + const languageModel = mockLlm(...endStatement); + const agent = makeAgent({ languageModel, capabilities: {}, logger }); const result = await agent.task(''); expect(result).toBe('x'); // This is a massive understatement, but we don't want to test the prompt - expect(llm.sample).toHaveBeenCalledWith(prompt); + expect(languageModel.sample).toHaveBeenCalledWith(prompt); }); - it('throws if the LLM did not invoke a capability', async () => { - // LLM finishes valid JSON, but no invoke property - const llm = mockLlm(`content":""}`); - const agent = makeAgent({ llm, capabilities: {} }); - const task = agent.task(''); - await expect(task).rejects.toThrow('No invoke in result'); - }); + it.skipIf(strategy !== 'Json')( + 'throws if the LLM did not invoke a capability', + async () => { + // LLM finishes valid JSON, but no invoke property + const languageModel = mockLlm(`content":""}`); + const agent = makeAgent({ languageModel, capabilities: {} }); + const task = agent.task(''); + await expect(task).rejects.toThrow('No invoke in message'); + }, + ); it('throws if invocation budget is exceeded', async () => { - const llm = mockLlm(`invoke":[{"name":"end","args":{"final":"x"}}]}`); - const agent = makeAgent({ llm, capabilities: {} }); - const task = agent.task('', { invocationBudget: 0 }); - await expect(task).rejects.toThrow('Invocation budget exceeded'); - }); - - // XXX This test reflects a poor factorization of the agent. - it('pushes messages to the transcript', async () => { - const llm = mockLlm(`invoke":[{"name":"test","args":{}}]}`); - const pushMessages = vi.fn(); - vi.mocked(makeChat).mockReturnValue({ - getPromptAndPrefix: vi.fn(() => ({ prompt, prefix })), - pushMessages, - }); - const { makeAgent: makeAgent2 } = await import('./agent.ts'); - const agent = makeAgent2({ - llm, - capabilities: { - test: capability(async () => 'test', { - description: 'test', - args: {}, - returns: { type: 'string' }, - }), - }, - }); - const task = agent.task('test', { invocationBudget: 1 }); + const languageModel = mockLlm(...endStatement); + const agent = makeAgent({ languageModel, capabilities: {} }); + const task = agent.task('', undefined, { invocationBudget: 0 }); await expect(task).rejects.toThrow('Invocation budget exceeded'); - expect(pushMessages).toHaveBeenCalledWith( - expect.any(AssistantMessage), - expect.any(CapabilityResultMessage), - ); }); it('logs to the provided logger', async () => { - const llm = mockLlm(`invoke":[{"name":"end","args":{"final":"x"}}]}`); - const logger = { + const languageModel = mockLlm(...endStatement); + const testLogger = { info: vi.fn(), - subLogger: vi.fn(() => logger), + subLogger: vi.fn(() => testLogger), } as unknown as Logger; - const agent = makeAgent({ llm, capabilities: {}, logger }); - await agent.task('test', { invocationBudget: 1 }); - expect(logger.info).toHaveBeenCalledWith('query:', 'test'); - expect(logger.subLogger).toHaveBeenCalledWith({ tags: ['t001'] }); + const agent = makeAgent({ + languageModel, + capabilities: {}, + logger: testLogger, + }); + await agent.task('test', undefined, { invocationBudget: 1 }); + expect(testLogger.info).toHaveBeenCalledWith('intent:', 'test'); + expect(testLogger.subLogger).toHaveBeenCalledWith({ tags: ['t001'] }); }); }); }); diff --git a/packages/kernel-agents/src/agent.ts b/packages/kernel-agents/src/agent.ts index 3f34f783b..74ffe74b6 100644 --- a/packages/kernel-agents/src/agent.ts +++ b/packages/kernel-agents/src/agent.ts @@ -1,98 +1,143 @@ -import { makeCounter, mergeDisjointRecords } from '@metamask/kernel-utils'; +import { mergeDisjointRecords } from '@metamask/kernel-utils'; import type { Logger } from '@metamask/logger'; import type { LanguageModel } from '@ocap/kernel-language-model-service'; -import { invokeCapabilities } from './capability.ts'; -import { end } from './default-capabilities.ts'; -import { AssistantMessage, CapabilityResultMessage } from './messages.ts'; -import type { AssistantMessageJson } from './messages.ts'; -import { gatherStreamingResponse, makeIncrementalParser } from './parser.ts'; -import { makeChat } from './prompt.ts'; -import type { Agent, CapabilityRecord } from './types.ts'; +import { doAttempt } from './attempt.ts'; +import { TaskManager } from './task.ts'; +import type { Message, MessageTypeBase } from './types/messages.ts'; +import type { + Agent, + CapabilityRecord, + PrepareAttempt, + TaskArgs, +} from './types.ts'; +import { ifDefined } from './utils.ts'; + +export type MakeAgentArgs = { + languageModel: LanguageModel; + capabilities: CapabilityRecord; + logger?: Logger; +}; /** * Make a capability-augmented agent * * @param args - The arguments to make the agent. - * @param args.llm - The language model to use for the agent + * @param args.languageModel - The language model to use for the agent * @param args.capabilities - The agent's capabilities * @param args.logger - The logger to use for the agent + * @param prepareAttempt - A strategy function to prepare the attempt. * @returns A kernel agent */ -export const makeAgent = ({ - llm, - capabilities, - logger, -}: { - llm: LanguageModel; - capabilities: CapabilityRecord; - logger?: Logger; -}): Agent => { - const agentCapabilities = mergeDisjointRecords( - { end }, - capabilities, - ) as CapabilityRecord; - - const taskCounter = makeCounter(); +export const makeAgent = < + State extends Message[], + Action extends Message, + Observation extends Message, +>( + { + languageModel, + capabilities: agentCapabilities, + logger: agentLogger, + }: MakeAgentArgs, + prepareAttempt: PrepareAttempt, +): Agent => { + const taskManager = new TaskManager(); return { - task: async ( - query: string, - { invocationBudget = 10 }: { invocationBudget?: number } = {}, + /** + * Task the agent to fulfill an objective. + * + * @param intent - A string specifying the objective of the task. + * @param judgment - A function that determines if the task is complete. + * @param options - The options for the task. + * @param options.invocationBudget - The maximum number of steps the agent is allowed to take. + * @param options.seed - The seed for the task. + * @param options.logger - The logger for the task. + * @param options.capabilities - The capabilities for the task. + * @param options.nAttempts - The number of attempts the agent is allowed to make. + * @returns The result of the task. + */ + task: async ( + intent: string, + judgment?: (result: unknown) => result is Result, + { + invocationBudget = 10, + seed = Date.now().valueOf(), // XXX: Replace with something more real + logger: printLogger, + capabilities: taskCapabilities = {}, + nAttempts = 1, + }: TaskArgs = {}, ) => { - // XXX Tasks could be integrated deeper in the kernel - const taskId = `t${taskCounter().toString().padStart(3, '0')}`; - const taskLogger = logger?.subLogger({ tags: [taskId] }); - taskLogger?.info('query:', query); - - const { getPromptAndPrefix, pushMessages } = makeChat( + const capabilities = mergeDisjointRecords( agentCapabilities, - query, - ); + taskCapabilities, + ) as CapabilityRecord; - for (let invocation = 0; invocation < invocationBudget; invocation++) { - taskLogger?.info(`begin invocation ${invocation}/${invocationBudget}`); + const thisTask = taskManager.makeTask({ + intent, + capabilities, + ...ifDefined({ judgment }), + }); + const { id: taskId, objective, context } = thisTask; + const taskLogger = agentLogger?.subLogger({ tags: [taskId] }); + taskLogger?.info('intent:', intent); - const { prompt, prefix } = getPromptAndPrefix(); - const parse = makeIncrementalParser({ - prefix, - ...(taskLogger ? { logger: taskLogger } : {}), - }); - taskLogger?.info('prompt:', prompt); + for (let attempt = 0; attempt < nAttempts; attempt++) { + taskLogger?.info(`Attempt ${attempt + 1} of ${nAttempts}`); - const { stream, abort } = await llm.sample(prompt); - let assistantMessage: AssistantMessageJson; + const [prep, state] = prepareAttempt({ + objective, + context, + options: ifDefined({ seed, printLogger, taskLogger }), + }); + const { history } = state; try { - assistantMessage = await gatherStreamingResponse({ - stream, - parse, - }); - } finally { - // Stop the LLM from generating anymore - await abort(); - } - taskLogger?.info('assistantMessage:', assistantMessage); - - // TODO: this should already be validated by the parser - if (!assistantMessage.invoke) { - throw new Error('No invoke in result'); + const result = await doAttempt( + prep, + state, + languageModel, + ifDefined({ maxSteps: invocationBudget, logger: taskLogger }), + ); + thisTask.attempts.push({ history, result }); + return result; + } catch (error) { + if (error instanceof Error) { + thisTask.attempts.push({ history, error }); + } else { + throw new Error(`Unknown error: ${error as string}`, { + cause: error, + }); + } } - const results = await invokeCapabilities( - assistantMessage.invoke, - agentCapabilities, - ); - taskLogger?.info('results:', results); - const didEnd = results.find((capability) => capability.name === 'end'); - if (didEnd) { - taskLogger?.info('exit invocation with result:', didEnd.result); - return didEnd.result; - } - pushMessages( - new AssistantMessage(assistantMessage), - new CapabilityResultMessage(results), - ); } - throw new Error('Invocation budget exceeded'); + const howManyAttempts = `${nAttempts} attempt${nAttempts === 1 ? '' : 's'}`; + throw new Error( + [ + `Failed to complete task in ${howManyAttempts}`, + ...thisTask.attempts.map( + (attempt, index) => + `${index + 1}: ${attempt.error?.message ?? 'Unknown'}`, + ), + ].join('\n'), + ); + }, + /** + * Get the experiences of the agent. Used for learning. + * + * @returns An iterator over the experiences. + */ + get experiences() { + return (async function* () { + for (const task of taskManager.tasks) { + for (const attempt of task.attempts) { + yield { + objective: task.objective, + context: task.context, + ...attempt, + }; + } + } + })(); }, }; }; diff --git a/packages/kernel-agents/src/attempt.test.ts b/packages/kernel-agents/src/attempt.test.ts new file mode 100644 index 000000000..02afb192f --- /dev/null +++ b/packages/kernel-agents/src/attempt.test.ts @@ -0,0 +1,222 @@ +import '@ocap/repo-tools/test-utils/mock-endoify'; + +import type { Logger } from '@metamask/logger'; +import type { LanguageModel } from '@ocap/kernel-language-model-service'; +import { vi, describe, it, expect, beforeEach } from 'vitest'; + +import { doAttempt } from './attempt.ts'; +import { makeTestStream } from './test-utils.ts'; +import type { PREP, Progress } from './types/agent.ts'; +import { Message } from './types/messages.ts'; + +class TestMessage extends Message { + constructor(messageType: string, messageBody: Record = {}) { + super(messageType, messageBody); + } +} + +describe('doAttempt', () => { + let mockPrompter: ReturnType; + let mockReader: ReturnType; + let mockEvaluator: ReturnType; + let mockPrinter: ReturnType; + let mockLanguageModel: { sample: ReturnType }; + let prep: PREP; + let logger: Logger; + + beforeEach(() => { + mockPrompter = vi.fn(); + mockReader = vi.fn(); + mockEvaluator = vi.fn(); + mockPrinter = vi.fn(); + mockLanguageModel = { sample: vi.fn() }; + logger = { + info: vi.fn(), + subLogger: vi.fn(() => logger), + } as unknown as Logger; + prep = [ + mockPrompter, + mockReader, + mockEvaluator, + mockPrinter, + ] as unknown as PREP; + }); + + const makeProgress = ( + history: TestMessage[], + isDone: ReturnType, + result?: string, + ): Progress => { + const progress: Progress = { + history, + isDone, + }; + if (result !== undefined) { + progress.result = result; + } + return progress; + }; + + it('returns result when done on first step', async () => { + const history: TestMessage[] = []; + const action = new TestMessage('action'); + const observation = new TestMessage('observation'); + const result = 'test result'; + + mockPrompter.mockReturnValue({ prompt: 'test prompt', readerArgs: {} }); + mockLanguageModel.sample.mockResolvedValue(makeTestStream(['response'])); + mockReader.mockResolvedValue(action); + mockEvaluator.mockResolvedValue(observation); + + const actual = await doAttempt( + prep, + makeProgress( + history, + vi.fn(() => true), + result, + ), + mockLanguageModel as unknown as LanguageModel< + unknown, + { response: string } + >, + { maxSteps: 10, logger }, + ); + + expect(actual).toBe(result); + expect(mockPrompter).toHaveBeenCalledWith(history); + expect(mockLanguageModel.sample).toHaveBeenCalledWith('test prompt'); + expect(mockEvaluator).toHaveBeenCalledWith(history, action); + expect(mockPrinter).not.toHaveBeenCalled(); + expect(logger.info).toHaveBeenCalledWith('Step 1 of 10'); + expect(logger.info).toHaveBeenCalledWith('done:', result); + }); + + it('returns result after multiple steps', async () => { + const history: TestMessage[] = []; + const action1 = new TestMessage('action1'); + const action2 = new TestMessage('action2'); + const observation1 = new TestMessage('observation1'); + const observation2 = new TestMessage('observation2'); + const result = 'final result'; + + let callCount = 0; + mockPrompter.mockImplementation(() => ({ + // The ++ operator is exactly what we want here. + // eslint-disable-next-line no-plusplus + prompt: `prompt ${++callCount}`, + readerArgs: {}, + })); + mockLanguageModel.sample.mockResolvedValue(makeTestStream(['response'])); + mockReader.mockResolvedValueOnce(action1).mockResolvedValueOnce(action2); + mockEvaluator + .mockResolvedValueOnce(observation1) + .mockResolvedValueOnce(observation2); + + const actual = await doAttempt( + prep, + makeProgress( + history, + vi.fn(() => callCount === 2), + result, + ), + mockLanguageModel as unknown as LanguageModel< + unknown, + { response: string } + >, + { maxSteps: 10 }, + ); + + expect(actual).toBe(result); + expect(mockPrompter).toHaveBeenCalledTimes(2); + expect(mockPrinter).toHaveBeenCalledWith(action1, observation1); + }); + + it('passes readerArgs to reader', async () => { + const history: TestMessage[] = []; + const action = new TestMessage('action'); + const readerArgs = { stop: '', prefix: 'test' }; + + mockPrompter.mockReturnValue({ prompt: 'test prompt', readerArgs }); + mockLanguageModel.sample.mockResolvedValue(makeTestStream(['response'])); + mockReader.mockResolvedValue(action); + mockEvaluator.mockResolvedValue(new TestMessage('observation')); + + await doAttempt( + prep, + makeProgress( + history, + vi.fn(() => true), + 'result', + ), + mockLanguageModel as unknown as LanguageModel< + unknown, + { response: string } + >, + { maxSteps: 10 }, + ); + + expect(mockReader).toHaveBeenCalledWith( + expect.objectContaining({ + stream: expect.anything(), + abort: expect.any(Function), + ...readerArgs, + }), + ); + }); + + it('throws error when maxSteps is exceeded', async () => { + const history: TestMessage[] = []; + const action = new TestMessage('action'); + + mockPrompter.mockReturnValue({ prompt: 'test prompt', readerArgs: {} }); + mockLanguageModel.sample.mockResolvedValue(makeTestStream(['response'])); + mockReader.mockResolvedValue(action); + mockEvaluator.mockResolvedValue(new TestMessage('observation')); + + const attempt = doAttempt( + prep, + makeProgress( + history, + vi.fn(() => false), + ), + mockLanguageModel as unknown as LanguageModel< + unknown, + { response: string } + >, + { maxSteps: 3 }, + ); + + await expect(attempt).rejects.toThrow('Invocation budget exceeded'); + expect(mockPrompter).toHaveBeenCalledTimes(3); + expect(mockPrinter).toHaveBeenCalledTimes(3); + }); + + it('handles null observation from evaluator', async () => { + const history: TestMessage[] = []; + const action = new TestMessage('action'); + + mockPrompter.mockReturnValue({ prompt: 'test prompt', readerArgs: {} }); + mockLanguageModel.sample.mockResolvedValue(makeTestStream(['response'])); + mockReader.mockResolvedValue(action); + mockEvaluator.mockResolvedValue(null); + + await doAttempt( + prep, + makeProgress( + history, + vi + .fn(() => false) + .mockReturnValueOnce(false) + .mockReturnValueOnce(true), + 'result', + ), + mockLanguageModel as unknown as LanguageModel< + unknown, + { response: string } + >, + { maxSteps: 10 }, + ); + + expect(mockPrinter).toHaveBeenCalledWith(action, null); + }); +}); diff --git a/packages/kernel-agents/src/attempt.ts b/packages/kernel-agents/src/attempt.ts new file mode 100644 index 000000000..ff6bcafee --- /dev/null +++ b/packages/kernel-agents/src/attempt.ts @@ -0,0 +1,61 @@ +import { SampleGenerationError } from '@metamask/kernel-errors'; +import type { Logger } from '@metamask/logger'; +import type { LanguageModel } from '@ocap/kernel-language-model-service'; + +import type { Message, MessageTypeBase } from './types/messages.ts'; +import type { PREP, Progress } from './types.ts'; +import { withRetries } from './utils.ts'; + +export const doAttempt = async < + Result, + State extends Message[], + Action extends Message, + Observation extends Message, +>( + [prompter, reader, evaluator, printer]: PREP, + progress: Progress, + languageModel: LanguageModel, + { + maxSteps = 10, + maxRetries = 3, + logger, + }: { + maxSteps?: number; + maxRetries?: number; + logger?: Logger; + }, +): Promise => { + const { history } = progress; + + for (let step = 1; step <= maxSteps; step++) { + logger?.info(`Step ${step} of ${maxSteps}`); + + const actionAndOutcome = await withRetries( + async () => { + // Observe + const { prompt, readerArgs } = prompter(history); + + // Act + const { stream, abort } = await languageModel.sample(prompt); + const action = await reader({ stream, abort, ...readerArgs }); + + // Step + const outcome = await evaluator(history, action); + return [action, outcome]; + }, + maxRetries, + (error) => error instanceof SampleGenerationError, + ); + + // If done, exit + if (progress.isDone()) { + const { result } = progress; + logger?.info('done:', result); + return result as Result; + } + + // Render + printer(...actionAndOutcome); + } + throw new Error('Invocation budget exceeded'); +}; diff --git a/packages/kernel-agents/src/capabilities/capability.test.ts b/packages/kernel-agents/src/capabilities/capability.test.ts new file mode 100644 index 000000000..d364c48b9 --- /dev/null +++ b/packages/kernel-agents/src/capabilities/capability.test.ts @@ -0,0 +1,17 @@ +import { describe, it, expect } from 'vitest'; + +import { capability } from './capability.ts'; + +describe('capability', () => { + it('creates a capability with func and schema', () => { + const testCapability = capability(async () => Promise.resolve('test'), { + description: 'a test capability', + args: {}, + }); + expect(testCapability.func).toBeInstanceOf(Function); + expect(testCapability.schema).toStrictEqual({ + description: 'a test capability', + args: {}, + }); + }); +}); diff --git a/packages/kernel-agents/src/capability.ts b/packages/kernel-agents/src/capabilities/capability.ts similarity index 53% rename from packages/kernel-agents/src/capability.ts rename to packages/kernel-agents/src/capabilities/capability.ts index a379daaac..6bce8cb45 100644 --- a/packages/kernel-agents/src/capability.ts +++ b/packages/kernel-agents/src/capabilities/capability.ts @@ -1,12 +1,18 @@ -import type { Invocation } from './messages.ts'; +import type { ExtractRecordKeys } from '../types/capability.ts'; import type { - Capability, CapabilityRecord, - CapabilitySchema, CapabilitySpec, - ExtractRecordKeys, -} from './types.ts'; + CapabilitySchema, + Capability, +} from '../types.ts'; +/** + * Create a capability specification. + * + * @param func - The function to create a capability specification for + * @param schema - The schema for the capability + * @returns A capability specification + */ export const capability = , Return = null>( func: Capability, schema: CapabilitySchema>, @@ -31,20 +37,21 @@ export const extractCapabilitySchemas = ( ), ); -export const invokeCapabilities = async ( - invocations: Invocation[], +type CapabilityEntry = [string, CapabilitySpec]; +/** + * Extract only the functions from the capabilities + * + * @param capabilities - The capabilities to extract the functions from + * @returns A record mapping capability names to their functions + */ +export const extractCapabilities = ( capabilities: CapabilityRecord, -): Promise<(Invocation & { result: unknown })[]> => - await Promise.all( - invocations.map(async ({ name, args }) => ({ - name, - args, - result: await (async () => { - const toInvoke = capabilities[name]; - if (!toInvoke) { - throw new Error(`Invoked capability ${name} not found`); - } - return await toInvoke.func(args as never); - })(), - })), +): Record< + keyof typeof capabilities, + (typeof capabilities)[keyof typeof capabilities]['func'] +> => + Object.fromEntries( + (Object.entries(capabilities) as unknown as CapabilityEntry[]).map( + ([name, { func }]) => [name, func], + ), ); diff --git a/packages/kernel-agents/src/capabilities/end.ts b/packages/kernel-agents/src/capabilities/end.ts new file mode 100644 index 000000000..9dda4ab27 --- /dev/null +++ b/packages/kernel-agents/src/capabilities/end.ts @@ -0,0 +1,49 @@ +import { ifDefined } from '../utils.ts'; +import { capability } from './capability.ts'; + +/** + * A factory function to make a task's `end` capability, which stores the first + * invocation as the final result and ignores all subsequent invocations. + * + * @template Result - The expected type of the final result. + * @returns A tuple containing the end capability, a function to check if the end capability was invoked, and a function to get the final result. + */ +// eslint-disable-next-line @typescript-eslint/explicit-function-return-type +export const makeEnd = () => { + const result: { final?: Result; attachments?: Record } = {}; + const end = capability( + async ({ + final, + attachments, + }: { + final: Result; + attachments?: Record; + }): Promise => { + if (!Object.hasOwn(result, 'final')) { + Object.assign(result, { final, ...ifDefined({ attachments }) }); + } + }, + { + description: 'Return a final response to the user.', + args: { + final: { + required: true, + type: 'string', + description: + 'A concise final response that restates the requested information.', + }, + attachments: { + required: false, + type: 'object', + description: 'Attachments to the final response.', + }, + }, + }, + ); + return [end, () => 'final' in result, () => result.final as Result] as const; +}; + +/** + * A default `end` capability that does nothing. + */ +export const [end] = makeEnd(); diff --git a/packages/kernel-agents/src/capabilities/examples.test.ts b/packages/kernel-agents/src/capabilities/examples.test.ts new file mode 100644 index 000000000..04184e8e8 --- /dev/null +++ b/packages/kernel-agents/src/capabilities/examples.test.ts @@ -0,0 +1,13 @@ +import { describe, it, expect } from 'vitest'; + +import { exampleCapabilities } from './examples.ts'; + +describe('exampleCapabilities', () => { + it('contains the correct capabilities', () => { + expect(exampleCapabilities).toBeDefined(); + expect(Object.keys(exampleCapabilities)).toStrictEqual([ + 'search', + 'getMoonPhase', + ]); + }); +}); diff --git a/packages/kernel-agents/src/capabilities/examples.ts b/packages/kernel-agents/src/capabilities/examples.ts new file mode 100644 index 000000000..02fe23d71 --- /dev/null +++ b/packages/kernel-agents/src/capabilities/examples.ts @@ -0,0 +1,70 @@ +import { capability } from './capability.ts'; + +type SearchResult = { + source: string; + published: string; + snippet: string; +}; +export const search = capability( + async ({ query }: { query: string }): Promise => [ + { + source: 'https://www.google.com', + published: '2025-01-01', + snippet: `No information found for ${query}`, + }, + ], + { + description: 'Search the web for information.', + args: { query: { type: 'string', description: 'The query to search for' } }, + returns: { + type: 'array', + item: { + type: 'object', + properties: { + source: { + type: 'string', + description: 'The source of the information.', + }, + published: { + type: 'string', + description: 'The date the information was published.', + }, + snippet: { + type: 'string', + description: 'The snippet of information.', + }, + }, + }, + }, + }, +); + +const moonPhases = [ + 'new moon', + 'waxing crescent', + 'first quarter', + 'waxing gibbous', + 'full moon', + 'waning gibbous', + 'third quarter', + 'waning crescent', +] as const; +type MoonPhase = (typeof moonPhases)[number]; + +export const getMoonPhase = capability( + async (): Promise => + moonPhases[Math.floor(Math.random() * moonPhases.length)] as MoonPhase, + { + description: 'Get the current phase of the moon.', + args: {}, + returns: { + type: 'string', + // TODO: Add enum support to the capability schema + // @ts-expect-error - enum is not supported by the capability schema + enum: moonPhases, + description: 'The current phase of the moon.', + }, + }, +); + +export const exampleCapabilities = { search, getMoonPhase }; diff --git a/packages/kernel-agents/src/capabilities/math.test.ts b/packages/kernel-agents/src/capabilities/math.test.ts new file mode 100644 index 000000000..65fe010ce --- /dev/null +++ b/packages/kernel-agents/src/capabilities/math.test.ts @@ -0,0 +1,15 @@ +import { it, expect, describe } from 'vitest'; + +import math from './math.ts'; + +describe('math', () => { + it.each([ + ['count', { word: 'abcdefg' }, 7], + ['add', { summands: [1, 2, 3, 4] }, 10], + ['multiply', { factors: [1, 2, 3, 4] }, 24], + ])('%s(%s) = %s', async (name, args, expected) => { + const capability = math[name as keyof typeof math]; + expect(capability).toBeDefined(); + expect(await capability.func(args as never)).toStrictEqual(expected); + }); +}); diff --git a/packages/kernel-agents/src/example-capabilities.ts b/packages/kernel-agents/src/capabilities/math.ts similarity index 94% rename from packages/kernel-agents/src/example-capabilities.ts rename to packages/kernel-agents/src/capabilities/math.ts index 9305e24c1..d521899d7 100644 --- a/packages/kernel-agents/src/example-capabilities.ts +++ b/packages/kernel-agents/src/capabilities/math.ts @@ -40,8 +40,5 @@ export const multiply = capability( }, ); -export const exampleCapabilities = { - count, - add, - multiply, -}; +const capabilities = { count, add, multiply }; +export default capabilities; diff --git a/packages/kernel-agents/src/capability.test.ts b/packages/kernel-agents/src/capability.test.ts deleted file mode 100644 index 709930d54..000000000 --- a/packages/kernel-agents/src/capability.test.ts +++ /dev/null @@ -1,38 +0,0 @@ -import { describe, it, expect } from 'vitest'; - -import { capability, invokeCapabilities } from './capability.ts'; - -describe('capability', () => { - it('creates a capability', () => { - const testCapability = capability(async () => Promise.resolve('test'), { - description: 'a test capability', - args: {}, - }); - expect(testCapability).toStrictEqual({ - func: expect.any(Function), - schema: { description: 'a test capability', args: {} }, - }); - }); -}); - -describe('invokeCapabilities', () => { - it("invokes the assistant's chosen capability", async () => { - const testCapability = capability(async () => Promise.resolve('test'), { - description: 'a test capability', - args: {}, - }); - const result = await invokeCapabilities( - [{ name: 'testCapability', args: {} }], - { testCapability }, - ); - expect(result).toStrictEqual([ - { name: 'testCapability', args: {}, result: 'test' }, - ]); - }); - - it('throws if the capability is not found', async () => { - await expect( - invokeCapabilities([{ name: 'testCapability', args: {} }], {}), - ).rejects.toThrow('Invoked capability testCapability not found'); - }); -}); diff --git a/packages/kernel-agents/src/default-capabilities.ts b/packages/kernel-agents/src/default-capabilities.ts deleted file mode 100644 index 17e1b0039..000000000 --- a/packages/kernel-agents/src/default-capabilities.ts +++ /dev/null @@ -1,12 +0,0 @@ -import { capability } from './capability.ts'; - -export const end = capability(async ({ final }: { final: string }) => final, { - description: 'Return a final response to the user.', - args: { - final: { - type: 'string', - description: - 'A concise final response that restates the requested information', - }, - }, -}); diff --git a/packages/kernel-agents/src/example-capabilities.test.ts b/packages/kernel-agents/src/example-capabilities.test.ts deleted file mode 100644 index b03870ee0..000000000 --- a/packages/kernel-agents/src/example-capabilities.test.ts +++ /dev/null @@ -1,25 +0,0 @@ -import { describe, it, expect } from 'vitest'; - -import { exampleCapabilities } from './example-capabilities.ts'; - -describe('exampleCapabilities', () => { - it('contains the correct capabilities', () => { - expect(exampleCapabilities).toBeDefined(); - expect(Object.keys(exampleCapabilities)).toStrictEqual([ - 'count', - 'add', - 'multiply', - ]); - }); - - it.each([ - ['count', { word: 'abcdefg' }, 7], - ['add', { summands: [1, 2, 3, 4] }, 10], - ['multiply', { factors: [1, 2, 3, 4] }, 24], - ])('%s(%s) = %s', async (name, args, expected) => { - const capability = - exampleCapabilities[name as keyof typeof exampleCapabilities]; - expect(capability).toBeDefined(); - expect(await capability.func(args as never)).toStrictEqual(expected); - }); -}); diff --git a/packages/kernel-agents/src/index.test.ts b/packages/kernel-agents/src/index.test.ts index 3f79c6c12..c78c640e4 100644 --- a/packages/kernel-agents/src/index.test.ts +++ b/packages/kernel-agents/src/index.test.ts @@ -6,7 +6,7 @@ import * as indexModule from './index.ts'; describe('index', () => { it('has the expected exports', () => { expect(Object.keys(indexModule).sort()).toStrictEqual( - expect.arrayContaining(['makeAgent']), + expect.arrayContaining([]), ); }); }); diff --git a/packages/kernel-agents/src/index.ts b/packages/kernel-agents/src/index.ts index 8fb5c9191..d2493229a 100644 --- a/packages/kernel-agents/src/index.ts +++ b/packages/kernel-agents/src/index.ts @@ -1,2 +1 @@ -export { makeAgent } from './agent.ts'; export type { CapabilityRecord } from './types.ts'; diff --git a/packages/kernel-agents/src/parser.test.ts b/packages/kernel-agents/src/parser.test.ts deleted file mode 100644 index 6fe6ea68e..000000000 --- a/packages/kernel-agents/src/parser.test.ts +++ /dev/null @@ -1,88 +0,0 @@ -import type { Logger } from '@metamask/logger'; -import { describe, it, expect, vi, beforeEach } from 'vitest'; - -import { makeIncrementalParser, gatherStreamingResponse } from './parser.ts'; - -describe('makeIncrementalParser', () => { - let mockLogger: Logger; - - beforeEach(() => { - mockLogger = { - info: vi.fn(), - error: vi.fn(), - warn: vi.fn(), - debug: vi.fn(), - } as unknown as Logger; - }); - - it('parses complete JSON in single chunk', () => { - const parser = makeIncrementalParser({}); - expect(parser('{"key": "value"}')).toStrictEqual({ key: 'value' }); - }); - - it('parses JSON across multiple chunks', () => { - const parser = makeIncrementalParser({}); - expect(parser('{"key": "val')).toBeNull(); - expect(parser('ue", "content": 42}')).toStrictEqual({ - key: 'value', - content: 42, - }); - }); - - it('parses JSON with prefix', () => { - const parser = makeIncrementalParser({ prefix: '{"start": true, ' }); - expect(parser('"end": false}')).toStrictEqual({ start: true, end: false }); - }); - - it('logs parsing attempts when logger provided', () => { - const parser = makeIncrementalParser({ logger: mockLogger }); - parser('{"test": "value"}'); - expect(mockLogger.info).toHaveBeenCalledWith( - 'toParse:', - '{"test": "value"}', - ); - }); - - it('throws error for invalid JSON', () => { - const parser = makeIncrementalParser({}); - expect(() => parser('{"invalid": json}')).toThrow('not valid JSON'); - }); - - it('throws error when max chunk count exceeded', () => { - const parser = makeIncrementalParser({ maxChunkCount: 2 }); - parser('chunk1'); - parser('chunk2'); - expect(() => parser('chunk3')).toThrow('Max chunk count reached'); - }); -}); - -describe('gatherStreamingResponse', () => { - it('gathers complete response from single chunk', async () => { - const stream = (async function* () { - yield { response: '{"key": "value"}' }; - })(); - const parser = makeIncrementalParser({}); - const result = await gatherStreamingResponse({ stream, parse: parser }); - expect(result).toStrictEqual({ key: 'value' }); - }); - - it('gathers response from multiple chunks', async () => { - const stream = (async function* () { - yield { response: '{"key": "val' }; - yield { response: 'ue", "content": 42}' }; - })(); - const parser = makeIncrementalParser({}); - const result = await gatherStreamingResponse({ stream, parse: parser }); - expect(result).toStrictEqual({ key: 'value', content: 42 }); - }); - - it('throws error when stream ends without parse event', async () => { - const stream = (async function* () { - yield { response: 'incomplete json' }; - })(); - const parser = makeIncrementalParser({}); - await expect( - gatherStreamingResponse({ stream, parse: parser }), - ).rejects.toThrow('stream ended without a parse event'); - }); -}); diff --git a/packages/kernel-agents/src/parser.ts b/packages/kernel-agents/src/parser.ts deleted file mode 100644 index a87e2465a..000000000 --- a/packages/kernel-agents/src/parser.ts +++ /dev/null @@ -1,81 +0,0 @@ -import type { Logger } from '@metamask/logger'; - -export type MakeIncrementalParserArgs = { - prefix?: string; - maxChunkCount?: number; - logger?: Logger; -}; -export type IncrementalParser = ( - delta: string, -) => Result | null; -/** - * A quick and dirty 'incremental' parser for a streaming response. - * - * @param args - The arguments to make the incremental parser. - * @param args.prefix - The prefix to prepend to the response - * @param args.maxChunkCount - The maximum number of chunks to parse - * @param args.logger - The logger to use for the incremental parser - * @returns An async function that parses a delta of a streaming response, - * returning the result value if parsing is complete or null otherwise. - */ -export const makeIncrementalParser = ({ - prefix = '', - maxChunkCount = 200, - logger, -}: MakeIncrementalParserArgs): IncrementalParser => { - let response = prefix; - let chunkCount = 0; - let leftBracketCount = prefix.split('{').length - 1; - let rightBracketCount = prefix.split('}').length - 1; - return (delta: string) => { - chunkCount += 1; - const subchunks = delta.split('}'); - const lastSubchunk = subchunks.pop() as string; - for (const subchunk of subchunks) { - rightBracketCount += 1; - leftBracketCount += subchunk.split('{').length - 1; - response += `${subchunk}}`; - logger?.info('toParse:', response); - try { - return JSON.parse(response); - } catch (error) { - // XXX There are other ways to detect an irrecoverable state. - // This is the simplest. - if (leftBracketCount === rightBracketCount) { - throw error; - } - } - } - leftBracketCount += lastSubchunk.split('{').length - 1; - response += lastSubchunk; - if (maxChunkCount && chunkCount > maxChunkCount) { - throw new Error(`Max chunk count reached with response:\n${response}`); - } - return null; - }; -}; - -/** - * Gather a streaming response from an stream of chunks. - * - * @param args - The arguments to gather the streaming response. - * @param args.stream - The stream to gather from. - * @param args.parse - The incremental parser to use to parse the response. - * @returns The parsed response. - */ -export const gatherStreamingResponse = async ({ - stream, - parse, -}: { - stream: AsyncIterable<{ response: string }>; - parse: IncrementalParser; -}): Promise => { - for await (const chunk of stream) { - const delta = (chunk as { response: string }).response; - const parsed = parse(delta); - if (parsed !== null) { - return parsed; - } - } - throw new Error('stream ended without a parse event'); -}; diff --git a/packages/kernel-agents/src/prompt.test.ts b/packages/kernel-agents/src/prompt.test.ts deleted file mode 100644 index 15a8ab09e..000000000 --- a/packages/kernel-agents/src/prompt.test.ts +++ /dev/null @@ -1,29 +0,0 @@ -import { describe, it, expect } from 'vitest'; - -import { AssistantMessage } from './messages.ts'; -import type { Transcript } from './messages.ts'; -import { makeChat } from './prompt.ts'; - -describe('makeChat', () => { - it('should make a chat', () => { - const chat = makeChat({}, 'test'); - expect(chat).toBeDefined(); - expect(chat).toHaveProperty('getPromptAndPrefix'); - expect(chat).toHaveProperty('pushMessages'); - }); - - it('should get the prompt and prefix', () => { - const chat = makeChat({}, 'test'); - const { prompt, prefix } = chat.getPromptAndPrefix(); - expect(prompt).toBeDefined(); - expect(prefix).toBeDefined(); - }); - - it('should push a transcript', () => { - const transcript: Transcript = []; - const chat = makeChat({}, 'test', transcript); - const testMessage = new AssistantMessage({ think: ['test'], invoke: [] }); - chat.pushMessages(testMessage); - expect(transcript.pop()).toStrictEqual(testMessage); - }); -}); diff --git a/packages/kernel-agents/src/prompt.ts b/packages/kernel-agents/src/prompt.ts deleted file mode 100644 index 6fd086501..000000000 --- a/packages/kernel-agents/src/prompt.ts +++ /dev/null @@ -1,52 +0,0 @@ -import { extractCapabilitySchemas } from './capability.ts'; -import { exampleTranscripts } from './example-transcripts.ts'; -import { CapabilitySpecMessage, UserMessage } from './messages.ts'; -import type { Transcript } from './messages.ts'; -import type { CapabilityRecord, Chat } from './types.ts'; - -const stringifyTranscript = (transcript: Transcript, index: number): string => - [ - `TRANSCRIPT ${index + 1}: [`, - transcript.map((message) => message.toJSON()).join(', '), - `]`, - ].join(' '); - -export const makeChat = ( - capabilities: CapabilityRecord, - query: string, - transcript: Transcript = [], -): Chat => { - transcript.push( - new CapabilitySpecMessage(extractCapabilitySchemas(capabilities)), - new UserMessage(query), - ); - const transcripts = [...exampleTranscripts, transcript]; - const preamble = [ - `The following are ${transcripts.length} transcripts of conversations between a user and a state-of-the-art capability-augmented assistant.`, - `Each transcript begins with a JSON-formatted list of the assistant's available capabilities, then proceeds to the conversation history, including user messages, assistant capability invocations, and the results of those invocations.`, - `Note that the assistant efficiently invokes capabilities to perform tasks. This reflects that the assistant is intelligent and can reason logically about function composition, and prefers to invoke external capabilities to prove the correctness of its answers.`, - `Also note that, although the assistant does not necessarily use every available capability, it never attempts to use a capability that was not specified prior in the transcript.`, - ].join('\n'); - /** - * The assistant must either immediately invoke a capability, or think and then - * invoke a capability. In either case, the next piece of the transcript must - * begin with this incomplete JSON prefix. - * - * XXX Subtle changes in the prefix can disrupt the tokenized pattern; - * this prompt string is aligned to llama3's implicit tokenizer boundaries. - */ - const responsePrefix = `{"messageType":"assistant","`; - return { - getPromptAndPrefix: () => { - const rawPrompt = [ - preamble, - ...transcripts.map(stringifyTranscript), - ].join('\n\n'); - const prompt = `${rawPrompt.slice(0, rawPrompt.length - 1)}, ${responsePrefix}`; - return { prompt, prefix: responsePrefix }; - }, - pushMessages: (...messages: Transcript) => { - transcript.push(...messages); - }, - }; -}; diff --git a/packages/kernel-agents/src/strategies/README.md b/packages/kernel-agents/src/strategies/README.md new file mode 100644 index 000000000..6511214d3 --- /dev/null +++ b/packages/kernel-agents/src/strategies/README.md @@ -0,0 +1,31 @@ +# Strategies + +A strategy is an abstraction of a multi-agent environment that presents the history of an agent's observations and actions to a language model in a form that elicits useable responses. + +#### Multi-Agent Environment + +``` + + (T) + .- step -. + V | +.------. (Y_U) .---------. (X_A) .-------. +| | --- act --> | | --- obs --> | | +| User | | Env | | Agent | +| | <-- obs --- | | <-- act --- | | +'------' (X_U) '---------' (Y_A) '-------' + + + +``` + +Strategies respect roughly the following mapping between the arrows in the above diagram and the implementations. + +| Arrow | Implementation | +| ----- | -------------- | +| $X_A$ | Prompter | +| $Y_A$ | Reader | +| $T$ | Evaluator | +| $X_U$ | Printer | + +Note that $Y_U$ --- the implementation of the user's actions --- is missing from this table. User actions are theoretically represented via interjections, but the exact mechanism by which the user interleaves interjections into the history is undefined. diff --git a/packages/kernel-agents/src/strategies/json-agent.ts b/packages/kernel-agents/src/strategies/json-agent.ts new file mode 100644 index 000000000..6d78789ab --- /dev/null +++ b/packages/kernel-agents/src/strategies/json-agent.ts @@ -0,0 +1,8 @@ +import { makeAgent } from '../agent.ts'; +import type { Agent } from '../types.ts'; +import type { State, Action, Observation } from './json/messages.ts'; +import { prepareAttempt } from './json/prepare-attempt.ts'; + +export const makeJsonAgent = ( + args: Parameters>[0], +): Agent => makeAgent(args, prepareAttempt); diff --git a/packages/kernel-agents/src/strategies/json/README.md b/packages/kernel-agents/src/strategies/json/README.md new file mode 100644 index 000000000..408a2eed8 --- /dev/null +++ b/packages/kernel-agents/src/strategies/json/README.md @@ -0,0 +1,11 @@ +# Ocap Kernel JSON Agent + +This strategy provides a capability interface that aligns closely with JSON tool calling generations abroad, but it relies on [JSON parsing](./sample-collector.ts) instead of special token use. The JSON agent is capable of capability composition by steps - the agent can choose to feed the result of one invocation as the arguments to another - but abstract capability composition is not supported. + +## Action Space + +Every step for the agent consists of a nonnegative number of thoughts followed by at least one invocation request. The thoughts are made to preceed the invocation in the transcript so that the invocation request is conditioned on the thoughts, not the other way around. + +## Evaluation + +Capabilities are [invoked](./evaluator.ts) by name, parsing JSON-serialized arguments and returning JSON-serialized results. diff --git a/packages/kernel-agents/src/strategies/json/evaluator.test.ts b/packages/kernel-agents/src/strategies/json/evaluator.test.ts new file mode 100644 index 000000000..fa1cf1290 --- /dev/null +++ b/packages/kernel-agents/src/strategies/json/evaluator.test.ts @@ -0,0 +1,36 @@ +import { describe, it, expect } from 'vitest'; + +import { makeEvaluator } from './evaluator.ts'; +import { AssistantMessage, CapabilityResultMessage } from './messages.ts'; +import { capability } from '../../capabilities/capability.ts'; + +describe('invokeCapabilities', () => { + it("invokes the assistant's chosen capability", async () => { + const testCapability = capability(async () => Promise.resolve('test'), { + description: 'a test capability', + args: {}, + }); + const evaluator = makeEvaluator({ capabilities: { testCapability } }); + const result = await evaluator( + [], + new AssistantMessage({ invoke: [{ name: 'testCapability', args: {} }] }), + ); + expect(result).toStrictEqual( + new CapabilityResultMessage([ + { name: 'testCapability', args: {}, result: 'test' }, + ]), + ); + }); + + it('throws if the capability is not found', async () => { + const evaluator = makeEvaluator({ capabilities: {} }); + await expect( + evaluator( + [], + new AssistantMessage({ + invoke: [{ name: 'testCapability', args: {} }], + }), + ), + ).rejects.toThrow('Invoked capability testCapability not found'); + }); +}); diff --git a/packages/kernel-agents/src/strategies/json/evaluator.ts b/packages/kernel-agents/src/strategies/json/evaluator.ts new file mode 100644 index 000000000..bbf811ab2 --- /dev/null +++ b/packages/kernel-agents/src/strategies/json/evaluator.ts @@ -0,0 +1,48 @@ +import type { Logger } from '@metamask/logger'; + +import { CapabilityResultMessage } from './messages.ts'; +import type { AssistantMessage, Transcript } from './messages.ts'; +import type { CapabilityRecord } from '../../types.ts'; + +export const makeEvaluator = + ({ + capabilities = {}, + logger, + }: { + capabilities?: CapabilityRecord; + logger?: Logger; + }) => + async ( + history: Transcript, + message: AssistantMessage, + ): Promise => { + logger?.info('history:', history); + logger?.info('message:', message.toJSON()); + + // Validate the message. + const invocations = message.messageBody.invoke; + if (!invocations) { + throw new Error('No invoke in message'); + } + if (invocations.length === 0) { + throw new Error('Empty invocation list in message'); + } + + const results = await Promise.all( + invocations.map(async ({ name, args }) => ({ + name, + args, + result: await (async () => { + const toInvoke = capabilities[name]; + if (!toInvoke) { + throw new Error(`Invoked capability ${name} not found`); + } + return await toInvoke.func(args as never); + })(), + })), + ); + logger?.info('results:', results); + const resultMessage = new CapabilityResultMessage(results); + history.push(message, resultMessage); + return resultMessage; + }; diff --git a/packages/kernel-agents/src/example-transcripts.ts b/packages/kernel-agents/src/strategies/json/example-transcripts.ts similarity index 80% rename from packages/kernel-agents/src/example-transcripts.ts rename to packages/kernel-agents/src/strategies/json/example-transcripts.ts index 9099a2e23..e9cc549cd 100644 --- a/packages/kernel-agents/src/example-transcripts.ts +++ b/packages/kernel-agents/src/strategies/json/example-transcripts.ts @@ -1,9 +1,6 @@ /** * Example transcripts for the prompt */ -import { extractCapabilitySchemas } from './capability.ts'; -import { end as endCapability } from './default-capabilities.ts'; -import { exampleCapabilities } from './example-capabilities.ts'; import { CapabilitySpecMessage, UserMessage, @@ -11,14 +8,23 @@ import { CapabilityResultMessage, } from './messages.ts'; import type { Transcript } from './messages.ts'; +import { extractCapabilitySchemas } from '../../capabilities/capability.ts'; +import { end as endCapability } from '../../capabilities/end.ts'; +import { search as searchCapability } from '../../capabilities/examples.ts'; +import { + count as countCapability, + multiply as multiplyCapability, +} from '../../capabilities/math.ts'; -const { end, count, add, multiply } = extractCapabilitySchemas({ - ...exampleCapabilities, +const { end, search, count, multiply } = extractCapabilitySchemas({ end: endCapability, + search: searchCapability, + count: countCapability, + multiply: multiplyCapability, }); const simpleSemanticTask: Transcript = [ - new CapabilitySpecMessage({ end, add }), + new CapabilitySpecMessage({ end, search }), new UserMessage('What color is a banana?'), new AssistantMessage({ think: [ diff --git a/packages/kernel-agents/src/message.test.ts b/packages/kernel-agents/src/strategies/json/message.test.ts similarity index 59% rename from packages/kernel-agents/src/message.test.ts rename to packages/kernel-agents/src/strategies/json/message.test.ts index 8a3fe8519..0ccf906bd 100644 --- a/packages/kernel-agents/src/message.test.ts +++ b/packages/kernel-agents/src/strategies/json/message.test.ts @@ -3,30 +3,25 @@ import { describe, it, expect } from 'vitest'; import { AssistantMessage } from './messages.ts'; describe('AssistantMessage', () => { - it('should create an assistant message', () => { + it('creates an assistant message', () => { const message = new AssistantMessage({ think: ['test'], invoke: [] }); expect(message).toBeDefined(); }); - it('serializes think before invoke if present', () => { + it('serializes think before invoke', () => { const message = new AssistantMessage({ invoke: [{ name: 'test', args: {} }], think: ['test'], }); const json = message.toJSON(); - const [left, right] = json.split('think'); - expect(left).toContain('messageType'); - expect(left).not.toContain('invoke'); - expect(right).not.toContain('messageType'); - expect(right).toContain('invoke'); + expect(json.indexOf('think')).toBeLessThan(json.indexOf('invoke')); }); - it('serializes if think is not present', () => { + it('serializes without think when absent', () => { const message = new AssistantMessage({ invoke: [{ name: 'test', args: {} }], }); const json = message.toJSON(); - expect(json).toContain('messageType'); expect(json).not.toContain('think'); expect(json).toContain('invoke'); }); diff --git a/packages/kernel-agents/src/messages.ts b/packages/kernel-agents/src/strategies/json/messages.ts similarity index 92% rename from packages/kernel-agents/src/messages.ts rename to packages/kernel-agents/src/strategies/json/messages.ts index 7dbb4799f..cafa8e83d 100644 --- a/packages/kernel-agents/src/messages.ts +++ b/packages/kernel-agents/src/strategies/json/messages.ts @@ -85,3 +85,12 @@ export type CapabilityResultMessageJson = { messageType: 'capabilityResult'; results: (Invocation & { result: unknown })[]; }; + +export type Observation = + | UserMessage + | CapabilitySpecMessage + | CapabilityResultMessage; + +export type Action = AssistantMessage; + +export type State = (Observation | Action)[]; diff --git a/packages/kernel-agents/src/strategies/json/prepare-attempt.ts b/packages/kernel-agents/src/strategies/json/prepare-attempt.ts new file mode 100644 index 000000000..7cd98730d --- /dev/null +++ b/packages/kernel-agents/src/strategies/json/prepare-attempt.ts @@ -0,0 +1,78 @@ +import { mergeDisjointRecords } from '@metamask/kernel-utils'; +import type { Logger } from '@metamask/logger'; + +import { makeEvaluator } from './evaluator.ts'; +import { CapabilitySpecMessage, UserMessage } from './messages.ts'; +import type { State, Action, Observation } from './messages.ts'; +import { makePrinter } from './printer.ts'; +import { makePrompter } from './prompter.ts'; +import { makeReader } from './reader.ts'; +import { extractCapabilitySchemas } from '../../capabilities/capability.ts'; +import { makeEnd } from '../../capabilities/end.ts'; +import type { + CapabilityRecord, + Context, + Objective, + PrepareAttempt, + PREP, + Progress, +} from '../../types.ts'; +import { ifDefined } from '../../utils.ts'; + +export const prepareAttempt: PrepareAttempt = < + Result, +>({ + objective: { intent, judgment }, + context, + options = {}, +}: { + objective: Objective; + context: Context; + options?: { + taskLogger?: Logger; + printLogger?: Logger; + }; +}): [PREP, Progress] => { + const { taskLogger, printLogger } = options; + + const [end, didEnd, getEnd] = makeEnd(); + + const capabilities = mergeDisjointRecords(context.capabilities, { + end, + }) as CapabilityRecord; + + const history = [ + new CapabilitySpecMessage(extractCapabilitySchemas(capabilities)), + new UserMessage(intent), + ]; + + const progress: Progress = { + history, + isDone: () => { + if (didEnd()) { + const result = getEnd(); + if (!judgment(result)) { + throw new Error(`Invalid result: ${result as string}`); + } + Object.assign(progress, { result }); + return true; + } + return false; + }, + // result: not defined until judgment is satisfied + }; + + const readLogger = taskLogger?.subLogger({ tags: ['read'] }); + const evalLogger = taskLogger?.subLogger({ tags: ['eval'] }); + + return [ + [ + makePrompter(), + makeReader(ifDefined({ logger: readLogger })), + makeEvaluator(ifDefined({ capabilities, logger: evalLogger })), + makePrinter({ history, ...ifDefined({ logger: printLogger }) }), + // TODO: Fix these types + ] as unknown as PREP, + progress, + ]; +}; diff --git a/packages/kernel-agents/src/strategies/json/printer.ts b/packages/kernel-agents/src/strategies/json/printer.ts new file mode 100644 index 000000000..47f1888e0 --- /dev/null +++ b/packages/kernel-agents/src/strategies/json/printer.ts @@ -0,0 +1,28 @@ +import type { Logger } from '@metamask/logger'; + +import type { + AssistantMessage, + CapabilityResultMessage, + Transcript, +} from './messages.ts'; + +export const makePrinter = ({ + history, + logger, +}: { + history: Transcript; + logger?: Logger; +}) => { + for (const message of history) { + logger?.info(message.toJSON()); + } + return ( + message: AssistantMessage, + result: CapabilityResultMessage | null, + ) => { + logger?.info(message.toJSON()); + if (result) { + logger?.info(result.toJSON()); + } + }; +}; diff --git a/packages/kernel-agents/src/strategies/json/prompter.test.ts b/packages/kernel-agents/src/strategies/json/prompter.test.ts new file mode 100644 index 000000000..bc04e7007 --- /dev/null +++ b/packages/kernel-agents/src/strategies/json/prompter.test.ts @@ -0,0 +1,14 @@ +import '@ocap/repo-tools/test-utils/mock-endoify'; + +import { describe, it, expect } from 'vitest'; + +import { makePrompter } from './prompter.ts'; + +describe('makePrompter', () => { + it('returns prompt and prefix', () => { + const prompter = makePrompter(); + const { prompt, readerArgs } = prompter([]); + expect(typeof prompt).toBe('string'); + expect(typeof readerArgs.prefix).toBe('string'); + }); +}); diff --git a/packages/kernel-agents/src/strategies/json/prompter.ts b/packages/kernel-agents/src/strategies/json/prompter.ts new file mode 100644 index 000000000..2db4247d5 --- /dev/null +++ b/packages/kernel-agents/src/strategies/json/prompter.ts @@ -0,0 +1,43 @@ +import { exampleTranscripts } from './example-transcripts.ts'; +import type { Transcript } from './messages.ts'; + +const makePreamble = (nTranscripts: number): string => { + const firstLinePrefix = + nTranscripts === 1 + ? 'The following is a transcript of a conversation' + : `The following are ${nTranscripts} transcripts of conversations`; + const secondLinePrefix = nTranscripts === 1 ? 'The' : 'Each'; + return [ + `${firstLinePrefix} between a user and a state-of-the-art capability-augmented assistant.`, + `${secondLinePrefix} transcript begins with a JSON-formatted list of the assistant's available capabilities, then proceeds to the conversation history, including user messages, assistant capability invocations, and the results of those invocations.`, + `Note that the assistant efficiently invokes capabilities to perform tasks. This reflects that the assistant is intelligent and can reason logically about function composition, and prefers to invoke external capabilities to prove the correctness of its answers.`, + `Also note that, although the assistant does not necessarily use every available capability, it never attempts to use a capability that was not specified prior in the transcript.`, + ].join('\n'); +}; + +/** + * The assistant must either immediately invoke a capability, or think and then + * invoke a capability. In either case, the next piece of the transcript must + * begin with this incomplete JSON prefix. + * + * XXX Subtle changes in the prefix can disrupt the tokenized pattern; + * this prompt string is aligned to llama3's implicit tokenizer boundaries. + */ +const prefix = `{"messageType":"assistant","`; + +export const makePrompter = () => (history: Transcript) => { + const transcripts = [...exampleTranscripts, history]; + const preamble = makePreamble(transcripts.length); + const rawPrompt = [ + preamble, + ...transcripts.map((transcript, index) => + [ + `TRANSCRIPT ${index + 1}: [`, + transcript.map((message) => message.toJSON()).join(', '), + `]`, + ].join(' '), + ), + ].join('\n\n'); + const prompt = `${rawPrompt.slice(0, rawPrompt.length - 1)}, ${prefix}`; + return { prompt, readerArgs: { prefix } }; +}; diff --git a/packages/kernel-agents/src/strategies/json/reader.ts b/packages/kernel-agents/src/strategies/json/reader.ts new file mode 100644 index 000000000..09a640488 --- /dev/null +++ b/packages/kernel-agents/src/strategies/json/reader.ts @@ -0,0 +1,33 @@ +import type { Logger } from '@metamask/logger'; + +import { AssistantMessage } from './messages.ts'; +import type { AssistantMessageJson } from './messages.ts'; +import { makeSampleCollector } from './sample-collector.ts'; +import { gatherStreamingResponse, ifDefined, withAbort } from '../../utils.ts'; + +export const makeReader = + ({ logger }: { logger?: Logger }) => + async ({ + stream, + abort, + prefix, + }: { + stream: AsyncIterable<{ response: string }>; + abort: () => Promise; + prefix: string; + }) => { + const sampleLogger = logger?.subLogger({ tags: ['sample'] }); + const gatherLogger = logger?.subLogger({ tags: ['gather'] }); + return await withAbort(abort, async (): Promise => { + const json = await gatherStreamingResponse({ + stream, + parse: makeSampleCollector({ + prefix, + ...ifDefined({ logger: sampleLogger }), + }), + ...ifDefined({ logger: gatherLogger }), + }); + logger?.info('assistant message:', json); + return new AssistantMessage(json); + }); + }; diff --git a/packages/kernel-agents/src/strategies/json/sample-collector.test.ts b/packages/kernel-agents/src/strategies/json/sample-collector.test.ts new file mode 100644 index 000000000..d00980922 --- /dev/null +++ b/packages/kernel-agents/src/strategies/json/sample-collector.test.ts @@ -0,0 +1,77 @@ +import '@ocap/repo-tools/test-utils/mock-endoify'; + +import type { Logger } from '@metamask/logger'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; + +import { makeSampleCollector } from './sample-collector.ts'; + +describe('makeSampleCollector', () => { + let mockLogger: Logger; + + beforeEach(() => { + mockLogger = { + info: vi.fn(), + error: vi.fn(), + warn: vi.fn(), + debug: vi.fn(), + } as unknown as Logger; + }); + + it('collects complete JSON in single chunk', () => { + const collector = makeSampleCollector({}); + expect(collector('{"key": "value"}')).toStrictEqual({ key: 'value' }); + }); + + it('collects JSON across multiple chunks', () => { + const collector = makeSampleCollector({}); + expect(collector('{"key": "val')).toBeNull(); + expect(collector('ue", "content": 42}')).toStrictEqual({ + key: 'value', + content: 42, + }); + }); + + it('collects JSON with prefix', () => { + const collector = makeSampleCollector({ prefix: '{"start": true, ' }); + expect(collector('"end": false}')).toStrictEqual({ + start: true, + end: false, + }); + }); + + it('logs collection attempts when logger provided', () => { + const collector = makeSampleCollector({ logger: mockLogger }); + collector('{"test": "value"}'); + expect(mockLogger.info).toHaveBeenCalledWith( + 'toParse:', + '{"test": "value"}', + ); + }); + + it('throws error for invalid JSON', () => { + const invalidJson = '{"invalid": json}'; + const collector = makeSampleCollector({}); + expect(() => collector(invalidJson)).toThrow( + expect.objectContaining({ + message: 'LLM generated invalid response.', + cause: expect.objectContaining({ + message: expect.stringContaining(invalidJson), + }), + }), + ); + }); + + it('throws error when max chunk count exceeded', () => { + const collector = makeSampleCollector({ maxChunkCount: 2 }); + collector('chunk1'); + collector('chunk2'); + expect(() => collector('chunk3')).toThrow( + expect.objectContaining({ + message: 'LLM generated invalid response.', + cause: expect.objectContaining({ + message: expect.stringContaining('Max chunk count reached'), + }), + }), + ); + }); +}); diff --git a/packages/kernel-agents/src/strategies/json/sample-collector.ts b/packages/kernel-agents/src/strategies/json/sample-collector.ts new file mode 100644 index 000000000..8449196e5 --- /dev/null +++ b/packages/kernel-agents/src/strategies/json/sample-collector.ts @@ -0,0 +1,65 @@ +import { SampleGenerationError } from '@metamask/kernel-errors'; +import type { Logger } from '@metamask/logger'; + +import type { SampleCollector } from '../../types.ts'; + +/** + * A quick and dirty sample collector for a streaming response. + * + * @param args - The arguments to make the sample collector. + * @param args.prefix - The prefix to prepend to the response + * @param args.maxChunkCount - The maximum number of chunks to parse + * @param args.logger - The logger to use for the sample collector + * @returns A function that collects a delta of a streaming response, + * returning the result value if collecting is complete or null otherwise. + */ +export const makeSampleCollector = ({ + prefix = '', + maxChunkCount = 200, + logger, +}: { + prefix?: string; + maxChunkCount?: number; + logger?: Logger; +}): SampleCollector => { + let response = prefix; + let chunkCount = 0; + let leftBracketCount = prefix.split('{').length - 1; + let rightBracketCount = prefix.split('}').length - 1; + return (delta: string) => { + chunkCount += 1; + const subchunks = delta.split('}'); + const lastSubchunk = subchunks.pop() as string; + for (const subchunk of subchunks) { + rightBracketCount += 1; + leftBracketCount += subchunk.split('{').length - 1; + response += `${subchunk}}`; + logger?.info('toParse:', response); + try { + const result = JSON.parse(response); + logger?.info('parsed:', result); + return result; + } catch (cause) { + // XXX There are other ways to detect an irrecoverable state. + // This is the simplest. + if (leftBracketCount === rightBracketCount) { + throw new SampleGenerationError( + response, + cause instanceof Error + ? cause + : new Error('Invalid JSON', { cause }), + ); + } + } + } + leftBracketCount += lastSubchunk.split('{').length - 1; + response += lastSubchunk; + if (maxChunkCount && chunkCount > maxChunkCount) { + throw new SampleGenerationError( + response, + new Error('Max chunk count reached'), + ); + } + return null; + }; +}; diff --git a/packages/kernel-agents/src/strategies/repl-agent.ts b/packages/kernel-agents/src/strategies/repl-agent.ts new file mode 100644 index 000000000..c6e5972c7 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl-agent.ts @@ -0,0 +1,8 @@ +import { makeAgent } from '../agent.ts'; +import type { Agent } from '../types.ts'; +import type { State, Action, Observation } from './repl/messages.ts'; +import { prepareAttempt } from './repl/prepare-attempt.ts'; + +export const makeReplAgent = ( + args: Parameters>[0], +): Agent => makeAgent(args, prepareAttempt); diff --git a/packages/kernel-agents/src/strategies/repl/README.md b/packages/kernel-agents/src/strategies/repl/README.md new file mode 100644 index 000000000..1a42119e6 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/README.md @@ -0,0 +1,17 @@ +# Ocap Kernel REPL Agent + +This strategy provides a language model shell that endows an agent with arbitrary javascript definition and invocation capabilities. + +## Action Space + +An action is a single, unambiguously complete, valid javascript statement. For example, `let x = 10` is valid but not unambiguosly complete, because it could be extended to either `let x = 10;` or `let x = 100`, which effect different consequences. On the other hand `let x = 10;` is unambiguously complete, because extending the string results in the creation of a second javascript statement. Likewise, `function x() {}` is unambiguously complete because any extension is either semantically equivalent or invalid, but `const x = () => {}` is not due to caveats like `const x = () => {} && sideEffect()` or `const x = () => {}.prototype`. + +An agent's user conceives of actions as capability invocations like `buyNFT('penguin', ETH(4));`, but other types of statements constitute internal actions the agent can take to compose together its capabilities. Comment statements represent the action of thinking. Elementary mathematics capabilities are available via numeric literals and intrinsic operators. We aim for an agent action space consisting of any single javascript statement. + +### REPL Evaluation + +Javascript statements are broadly separated into _declarations_, _expressions_, and _everything else_. We can tell them apart using a [parser](./parse/javascript.ts). + +A declaration alters the namespace of the evaluated lexical scope, i.e. `let x;` or `function foo() {}`. Declarations, expressions and everything else can all alter the state of the REPL (by executing code), e.g. `console.log('hello');` or `foo('bar');`, but an expression has an implied return value, e.g. `foo(bar);` might return `3` even though the value was not assigned. + +The REPL [evaluator](./evaluator.ts) wraps agent code statements in [helper code](./prepare-evaluation.ts) and evaluates the wrapped code in a [compartment](./compartment.ts). In the case of a declaration, the helper code needs to capture the value of the newly assigned namespace entry or entries. In the case of an expression, the helper code captures the implicit return value of the expression so it can be observed. In any case, the helper code writes the namespace into the scope of the evaluating compartment before agent code evaluation and captures all mutable names of the namespace after agent code evaluation. diff --git a/packages/kernel-agents/src/strategies/repl/compartment.test.ts b/packages/kernel-agents/src/strategies/repl/compartment.test.ts new file mode 100644 index 000000000..18c3e61a0 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/compartment.test.ts @@ -0,0 +1,27 @@ +import '@ocap/repo-tools/test-utils/mock-endoify'; + +import { describe, it, expect } from 'vitest'; + +import { makeCompartment } from './compartment.ts'; + +describe('compartment', () => { + it('gets the value from an expression', () => { + const compartment = makeCompartment(); + expect(compartment.evaluate('1 + 1')).toBe(2); + expect(compartment.evaluate('1 + 1; 2 + 2;')).toBe(4); + }); + + it('gets the value from an async expression', async () => { + const compartment = makeCompartment(); + expect( + await compartment.evaluate( + '(async () => await Promise.resolve(1 + 1))()', + ), + ).toBe(2); + expect( + await compartment.evaluate( + '(async () => { await Promise.resolve(1 + 1); return await Promise.resolve(2 + 2); })()', + ), + ).toBe(4); + }); +}); diff --git a/packages/kernel-agents/src/strategies/repl/compartment.ts b/packages/kernel-agents/src/strategies/repl/compartment.ts new file mode 100644 index 000000000..55ad44749 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/compartment.ts @@ -0,0 +1,6 @@ +import 'ses'; +// Assigns Compartment to globalThis. +export type Compartment = { evaluate: (code: string) => unknown }; +export const makeCompartment = ( + endowments: Record = {}, +): Compartment => new Compartment(endowments); diff --git a/packages/kernel-agents/src/strategies/repl/evaluator-error.test.ts b/packages/kernel-agents/src/strategies/repl/evaluator-error.test.ts new file mode 100644 index 000000000..79c8a0bcf --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/evaluator-error.test.ts @@ -0,0 +1,103 @@ +import 'ses'; +import '@ocap/repo-tools/test-utils/mock-endoify'; +import { EvaluatorError, SampleGenerationError } from '@metamask/kernel-errors'; +import { describe, it, expect } from 'vitest'; + +import { processEvaluationError, stripStackTrace } from './evaluator-error.ts'; +import { ERROR } from './symbols.ts'; + +describe('stripStackTrace', () => { + it('strips stack trace from Error', () => { + const error = new Error('test error'); + error.stack = 'Error: test error\n at test.js:1:1'; + const stripped = stripStackTrace(error); + expect(stripped).toBeInstanceOf(Error); + expect((stripped as Error).message).toBe('test error'); + const strippedError = stripped as Error; + expect(strippedError.stack).not.toContain('at test.js'); + }); + + it('preserves error cause chain', () => { + const inner = new Error('inner'); + const outer = new Error('outer', { cause: inner }); + const stripped = stripStackTrace(outer); + expect((stripped as Error).message).toBe('outer'); + expect((stripped as Error).cause).toBeInstanceOf(Error); + expect(((stripped as Error).cause as Error).message).toBe('inner'); + }); + + it('returns non-Error values unchanged', () => { + expect(stripStackTrace('string')).toBe('string'); + expect(stripStackTrace(42)).toBe(42); + expect(stripStackTrace(null)).toBeNull(); + }); +}); + +describe('processEvaluationError', () => { + it('does nothing when result has no error', () => { + const result: { [ERROR]?: unknown } = {}; + expect(() => processEvaluationError(result, 'code')).not.toThrow(); + }); + + it('throws EvaluatorError for internal errors', () => { + const result: { [ERROR]?: unknown } = { + [ERROR]: new EvaluatorError('test', 'code', new Error('cause')), + }; + expect(() => processEvaluationError(result, 'code')).toThrow( + EvaluatorError, + ); + }); + + it('throws SampleGenerationError for SyntaxError', () => { + const result: { [ERROR]?: unknown } = { + [ERROR]: new SyntaxError('syntax error'), + }; + expect(() => processEvaluationError(result, 'bad code')).toThrow( + SampleGenerationError, + ); + }); + + it('throws SampleGenerationError for ReferenceError', () => { + const result: { [ERROR]?: unknown } = { + [ERROR]: new ReferenceError('reference error'), + }; + expect(() => processEvaluationError(result, 'bad code')).toThrow( + SampleGenerationError, + ); + }); + + it('throws SampleGenerationError for Error objects with SyntaxError name', () => { + const error = Object.assign(new Error('error'), { name: 'SyntaxError' }); + const result: { [ERROR]?: unknown } = { [ERROR]: error }; + expect(() => processEvaluationError(result, 'bad code')).toThrow( + SampleGenerationError, + ); + }); + + it('processes and assigns valid-feedback errors', () => { + const result: { [ERROR]?: unknown } = { + [ERROR]: new Error('user error'), + }; + processEvaluationError(result, 'code'); + expect(result[ERROR]).toBeInstanceOf(Error); + const processedError = result[ERROR] as Error; + expect(processedError.message).toBe('user error'); + }); + + it('wraps non-Error values as Error for valid-feedback', () => { + const result: { [ERROR]?: unknown } = { [ERROR]: 'string error' }; + processEvaluationError(result, 'code'); + expect(result[ERROR]).toBeInstanceOf(Error); + expect((result[ERROR] as Error).message).toBe('string error'); + }); + + it('strips stack traces from valid-feedback errors', () => { + const error = new Error('user error'); + error.stack = 'Error: user error\n at test.js:1:1'; + const result: { [ERROR]?: unknown } = { [ERROR]: error }; + processEvaluationError(result, 'code'); + const processedError = result[ERROR] as Error; + expect(processedError.message).toBe('user error'); + expect(processedError.stack).not.toContain('at test.js'); + }); +}); diff --git a/packages/kernel-agents/src/strategies/repl/evaluator-error.ts b/packages/kernel-agents/src/strategies/repl/evaluator-error.ts new file mode 100644 index 000000000..e89a56927 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/evaluator-error.ts @@ -0,0 +1,67 @@ +import { SampleGenerationError, EvaluatorError } from '@metamask/kernel-errors'; + +import type { EvaluationResult } from './prepare-evaluation.ts'; +import { ERROR } from './symbols.ts'; + +/** + * Strips stack traces from an error while preserving the message and cause chain. + * + * @param error - The error to strip stack traces from. + * @returns The error without stack traces. + */ +export const stripStackTrace = (error: unknown): unknown => { + if (!(error instanceof Error)) { + return error; + } + return new Error( + error.message, + ...(error.cause ? [{ cause: stripStackTrace(error.cause) }] : []), + ); +}; + +const asError = (error: unknown): Error => + error instanceof Error ? error : new Error(String(error)); + +const isSyntaxError = (error: unknown): boolean => + error instanceof SyntaxError || + (error instanceof Error && error.name === 'SyntaxError'); + +const isReferenceError = (error: unknown): boolean => + error instanceof ReferenceError || + (error instanceof Error && error.name === 'ReferenceError'); + +/** + * Processes any error in the evaluation result. If an error exists, classifies it + * and either throws (for retry/exit errors) or processes and assigns it back to + * the result (for valid feedback errors). + * + * @param result - The evaluation result object that may contain an error. + * @param code - The code that was being evaluated. + * @throws {SampleGenerationError} For syntax/reference errors that should trigger retry. + * @throws {EvaluatorError} For internal errors that should exit the attempt. + */ +export const processEvaluationError = ( + result: EvaluationResult, + code: string, +): void => { + if (!Object.hasOwn(result, ERROR)) { + return; + } + const error = result[ERROR]; + + // Check if this is already an EvaluatorError (thrown by safe wrappers) + if (error instanceof EvaluatorError) { + throw error; + } + + // Check if this is a sample generation error (syntax/reference errors) + if (isSyntaxError(error) || isReferenceError(error)) { + throw new SampleGenerationError( + code, + stripStackTrace(asError(error)) as Error, + ); + } + + // All other errors are valid feedback (capability errors, NotImplemented, etc.) + result[ERROR] = stripStackTrace(asError(error)); +}; diff --git a/packages/kernel-agents/src/strategies/repl/evaluator.test.ts b/packages/kernel-agents/src/strategies/repl/evaluator.test.ts new file mode 100644 index 000000000..ee8bd8e0b --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/evaluator.test.ts @@ -0,0 +1,180 @@ +import 'ses'; // We need the real Compartment, not the mock. +import '@ocap/repo-tools/test-utils/mock-endoify'; +import { Logger } from '@metamask/logger'; +import { describe, it, expect, beforeEach, vi } from 'vitest'; + +import { makeEvaluator } from './evaluator.ts'; +import { + CommentMessage, + ImportMessage, + ResultMessage, + StatementMessage, +} from './messages.ts'; +import type { ReplTranscript } from './messages.ts'; +import type { EvaluatorState } from './types.ts'; + +describe('evaluator', () => { + let state: EvaluatorState; + let evaluator: ReturnType; + const logger = new Logger('test'); + + beforeEach(() => { + state = { consts: {}, lets: {} }; + evaluator = makeEvaluator({ initState: () => state, logger }); + }); + + const evaluateStatements = async (...statements: string[]) => { + const history: ReplTranscript = []; + for (const statement of statements) { + await evaluator(history, StatementMessage.fromCode(statement)); + } + }; + + describe('evaluates statements', () => { + it('evaluates expressions', async () => { + const history: ReplTranscript = []; + const result = await evaluator(history, StatementMessage.fromCode('42;')); + expect(result).toBeInstanceOf(ResultMessage); + expect(result?.messageBody.return).toBe('42'); + }); + + it('evaluates declarations and updates state', async () => { + await evaluateStatements( + 'const x = 1;', + 'let y = x + 2;', + 'function foo() { return "bar"; }', + 'const z = foo();', + ); + expect(state).toStrictEqual({ + consts: { x: 1, z: 'bar' }, + lets: { y: 3, foo: expect.any(Function) }, + }); + }); + + it('evaluates loops', async () => { + await evaluateStatements( + 'let x = 1;', + 'for (let i = 1; i <= 4; i++) { x *= i; }', + ); + expect(state.lets).toStrictEqual({ x: 24 }); + }); + + it('captures mutated let variables', async () => { + await evaluateStatements('let x = 1;', 'x = 2;', 'x = 3;'); + expect(state.lets).toStrictEqual({ x: 3 }); + }); + }); + + describe('handles statement types', () => { + it('handles comment messages', async () => { + const history: ReplTranscript = []; + const comment = new CommentMessage('// comment'); + const result = await evaluator(history, comment); + expect(result).toBeNull(); + expect(history).toHaveLength(1); + expect(history[0]).toBe(comment); + }); + + it('handles import messages', async () => { + const history: ReplTranscript = []; + const importMsg = new ImportMessage('import { x } from "y";'); + const result = await evaluator(history, importMsg); + expect(result).toBeInstanceOf(ResultMessage); + expect(history).toHaveLength(2); + expect(history[0]).toBe(importMsg); + }); + + it('rejects variable declarations', async () => { + const statement = StatementMessage.fromCode('var x = 1;'); + await expect(evaluator([], statement)).rejects.toThrow( + 'Variable declarations are not allowed', + ); + }); + }); + + describe('creates result messages', () => { + it('creates result with return value', async () => { + const history: ReplTranscript = []; + const result = await evaluator( + history, + StatementMessage.fromCode('"hello";'), + ); + expect(result).toBeInstanceOf(ResultMessage); + expect(result?.messageBody.return).toBe('"hello"'); + }); + + it('creates result with declaration value', async () => { + const history: ReplTranscript = []; + const result = await evaluator( + history, + StatementMessage.fromCode('const x = 42;'), + ); + expect(result).toBeInstanceOf(ResultMessage); + expect(result?.messageBody.value).toBe('x: 42'); + }); + + it('returns null when no result keys are present', async () => { + const history: ReplTranscript = []; + const result = await evaluator( + history, + StatementMessage.fromCode('for (let i = 0; i < 1; i++) {}'), + ); + expect(result).toBeNull(); + }); + }); + + describe('manages state', () => { + it('does not update state when evaluation has error', async () => { + const initialState = { consts: {}, lets: {} }; + const history: ReplTranscript = []; + await evaluator( + history, + StatementMessage.fromCode('const x = undefined.y;'), + ); + expect(state).toStrictEqual(initialState); + }); + }); + + describe('integrates capabilities', () => { + it('evaluates capability calls', async () => { + const mockCap = vi.fn().mockReturnValue('result'); + const evaluatorWithCap = makeEvaluator({ + initState: () => state, + capabilities: { + testCap: { + func: mockCap, + schema: { description: 'Test capability', args: {} }, + }, + }, + }); + const history: ReplTranscript = []; + await evaluatorWithCap(history, StatementMessage.fromCode('testCap();')); + expect(mockCap).toHaveBeenCalled(); + }); + + it('handles multiple capabilities', async () => { + const cap1 = vi.fn().mockReturnValue(1); + const cap2 = vi.fn().mockReturnValue(2); + const evaluatorWithCaps = makeEvaluator({ + initState: () => state, + capabilities: { + cap1: { + func: cap1, + schema: { description: 'Capability 1', args: {} }, + }, + cap2: { + func: cap2, + schema: { description: 'Capability 2', args: {} }, + }, + }, + }); + const history: ReplTranscript = []; + await evaluatorWithCaps( + history, + StatementMessage.fromCode('cap1() + cap2();'), + ); + expect(cap1).toHaveBeenCalled(); + expect(cap2).toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/kernel-agents/src/strategies/repl/evaluator.ts b/packages/kernel-agents/src/strategies/repl/evaluator.ts new file mode 100644 index 000000000..78262de09 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/evaluator.ts @@ -0,0 +1,123 @@ +import { EvaluatorError } from '@metamask/kernel-errors'; +import { mergeDisjointRecords } from '@metamask/kernel-utils'; +import type { Logger } from '@metamask/logger'; + +import { makeCompartment } from './compartment.ts'; +import { processEvaluationError } from './evaluator-error.ts'; +import { + CommentMessage, + EvaluationMessage, + ImportMessage, + ResultMessage, +} from './messages.ts'; +import type { ReplTranscript, StatementMessage } from './messages.ts'; +import { prepareEvaluation } from './prepare-evaluation.ts'; +import { ERROR, RETURN } from './symbols.ts'; +import type { EvaluatorState } from './types.ts'; +import { extractCapabilities } from '../../capabilities/capability.ts'; +import type { CapabilityRecord } from '../../types.ts'; +import { ifDefined } from '../../utils.ts'; + +const validateStatement = ( + statement: StatementMessage, +): { earlyResult?: ResultMessage | null } => { + if (statement instanceof CommentMessage) { + // Comments are not evaluated. + return { earlyResult: null }; + } + if (statement instanceof ImportMessage) { + // Imports are not implemented yet. + return { + earlyResult: new ResultMessage({ + [ERROR]: new SyntaxError('Additional imports are not allowed.'), + }), + }; + } + if (!(statement instanceof EvaluationMessage)) { + // This should never happen. + throw new Error( + [ + 'Internal: Unknown statement', + `statement: ${statement.messageBody.node.text}`, + `type: ${statement.messageBody.node.toString()}`, + ].join('\n'), + ); + } + // Otherwise, proceed with the evaluation. + return {}; +}; + +export const makeEvaluator = ({ + capabilities = {}, + logger, + // For testing purposes. + initState = () => ({ consts: {}, lets: {} }), +}: { + capabilities?: CapabilityRecord; + logger?: Logger; + initState?: () => EvaluatorState; +}) => { + const state: EvaluatorState = initState(); + + return async ( + history: ReplTranscript, + statement: StatementMessage, + ): Promise => { + // Validate the statement. + const validation = validateStatement(statement); + if ('earlyResult' in validation) { + const { earlyResult } = validation; + history.push(statement, ...(earlyResult ? [earlyResult] : [])); + return earlyResult; + } + + // Prepare the evaluation. + const { code, endowments, result, commit } = prepareEvaluation( + state, + statement.messageBody.node, + ifDefined({ logger }), + ); + + logger?.info('capabilities:', capabilities); + logger?.info('endowments:', endowments); + logger?.info('evaluating:', code); + + // Prepare the compartment. + const compartmentEndowments = mergeDisjointRecords( + endowments, + extractCapabilities(capabilities), + ); + const compartment = makeCompartment(compartmentEndowments); + + try { + await compartment.evaluate(code); + } catch (cause) { + const asError = (error: unknown): Error => + error instanceof Error ? error : new Error(String(error)); + // Errors that evade $catch are always an EvaluationError + throw new EvaluatorError( + 'REPL evaluation failed', + code, + // If the error is already an EvaluatorError, we rethrow with the code, + cause instanceof EvaluatorError + ? (cause.cause as Error) + : // Otherwise, wrap the error as EvaluatorError + asError(cause), + ); + } + + // Handle errors caught by $catch (user code errors) + processEvaluationError(result, code); + + // Update the state and return the result + const stepResult = [ERROR, RETURN, 'value'].some((key) => + Object.hasOwn(result, key), + ) + ? new ResultMessage(result) + : null; + history.push(statement, ...(stepResult ? [stepResult] : [])); + commit(); + + return stepResult; + }; +}; diff --git a/packages/kernel-agents/src/strategies/repl/example-transcripts.ts b/packages/kernel-agents/src/strategies/repl/example-transcripts.ts new file mode 100644 index 000000000..fb21beaa7 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/example-transcripts.ts @@ -0,0 +1,142 @@ +import { + CommentMessage, + EvaluationMessage, + ImportMessage, + InterjectionMessage, + ResultMessage, +} from './messages.ts'; +import type { ReplTranscript } from './messages.ts'; +import { RETURN } from './symbols.ts'; +import { extractCapabilitySchemas } from '../../capabilities/capability.ts'; +import { end } from '../../capabilities/end.ts'; +import { search } from '../../capabilities/examples.ts'; +import type { CapabilityRecord } from '../../types.ts'; + +// eslint-disable-next-line @typescript-eslint/explicit-function-return-type +const transcribeCapabilities = (capabilities: CapabilityRecord) => [ + ImportMessage.fromNames(Object.keys(capabilities)), + new ResultMessage( + { value: extractCapabilitySchemas(capabilities) }, + { compress: false }, + ), +]; + +/* eslint-disable no-tabs */ + +const treeSearchTask: ReplTranscript = [ + ...transcribeCapabilities({ end, search }), + new InterjectionMessage('What is the oldest tree in South America?'), + new CommentMessage( + '// This information is too specific for me to know on my own.', + ), + new EvaluationMessage( + `await search({ query: 'oldest tree in South America' });`, + ), + new ResultMessage({ + [RETURN]: [ + { + source: + 'https://unofficialnetworks.com/2022/06/02/worlds-oldest-tree-south-america/', + published: '2022-06-02', + snippet: + 'Barichivich turned to statistical modeling to determine the Alerce Milenario’s full age. He used complete cores from other alerce trees and information on how environmental factors and random variation affect tree growth to calibrate a model that simulated a range of possible ages the tree had reached by the beginning of the period covered by the partial core, along with a probability for each age. The method yielded an overall age estimate of 5484 years old, with an 80% chance that the tree has lived for more than 5000 years.', + }, + { + source: + 'https://economictimes.indiatimes.com/news/new-updates/worlds-oldest-tree-great-grandfather-tree-in-chile-to-reveal-the-planets-secrets/articleshow/99690454.cms', + published: '2023-04-22', + snippet: + 'The 5,000-year-old Great Grandfather tree will replace the current oldest tree, the Methuselah, which is 4,850 years old. While Methuselah is located in California, United States of America, the Great Grandfather tree is in Santiago, Chile, South America.\n\nThe Great Grandfather tree is a form of cypress, also known as the Fitzroya Cupressoides or the Patagonian cypress, while the Methuselah is a pine. The Patagonian cypress tree is the largest one found in South America.', + }, + { + source: 'https://forestry.com/guides/top-10-oldest-trees-in-the-world/', + published: '2025-07-14', + snippet: + 'Top 10 Oldest Living Trees in the World\n\nRank Name Species Location Age (Years)\n1 Methuselah Great Basin bristlecone pine White Mountains, California, USA 4,856 (Verified)\n2 Alerce Milenario (Gran Abuelo) Patagonian cypress Alerce Costero National Park, Chile 3,654 (Verified, up to 5,484 debated)\n3 Sarv-e Abarqu Cypress Yazd Province, Iran Approximately 4,000+ (Estimated)\n4 Llangernyw Yew Yew Llangernyw Village, Wales, UK Approximately 4,000 (Estimated)\n5 Olive Tree of Vouves Olive Crete, Greece Approximately 3,000+ (Estimated)\n6 BLK227 Bald cypress Black River, North Carolina, USA 2,650 (Verified)\n7 Jōmon Sugi Cryptomeria Yakushima, Japan 2,000–3,000+ (Estimated)\n8 Chestnut Tree of One Hundred Horses Chestnut Mount Etna, Sicily, Italy 2,000–4,000 (Estimated)\n9 General Sherman Giant sequoia Sequoia National Park, California, USA Approximately 2,500 (Estimated)\n10 Patriarca da Floresta Cariniana legalis Brazil Approximately 2,000+ (Estimated)', + }, + ], + }), + new EvaluationMessage( + 'await end({ final: "According to multiple sources, the oldest tree in South America is Alerce Milenario." });', + ), +]; + +const simpleSemanticTask: ReplTranscript = [ + ...transcribeCapabilities({ end }), + new InterjectionMessage('What color is a banana?'), + ...[ + 'Bananas can be either yellow or green, depending on the variety and ripeness.', + 'Typically, people think of yellow bananas when they think of bananas.', + 'I should give the typical response, but clarify that I am assuming the banana is ripe.', + ].map((comment) => new CommentMessage(`// ${comment}`)), + new EvaluationMessage('const response = "A banana is yellow when ripe.";'), + new ResultMessage({ value: { response: 'A banana is yellow when ripe.' } }), + new EvaluationMessage('await end({ final: response });'), +]; + +const multiStepCalculation: ReplTranscript = [ + ...transcribeCapabilities({ end }), + new InterjectionMessage( + 'What is the size of a matrix with rows indexed by the letters of "piano" and columns by the letters of "guitar"?', + ), + new CommentMessage( + '// The answer will be the product of the length of the word "piano" and the length of the word "guitar".', + ), + new EvaluationMessage( + // eslint-disable-next-line no-template-curly-in-string + 'const response = `Such a matrix would have ${"piano".length * "guitar".length} elements.`;', + ), + new ResultMessage({ + value: { response: 'Such a matrix would have 30 elements.' }, + }), + new EvaluationMessage('await end({ final: response });'), +]; + +const functionDefinition: ReplTranscript = [ + ...transcribeCapabilities({ end }), + new InterjectionMessage( + 'What is the average depth of the following tree? [a, [b, c], d, [e, [f, g]]]', + ), + new CommentMessage( + '// I can solve this problem by recursively finding the depth of each node in the tree.', + ), + new CommentMessage( + '// First, let me define a function to check if a node is a leaf.', + ), + new EvaluationMessage('const isLeaf = (node) => node.length === undefined;'), + new ResultMessage({ value: { isLeaf: '[Function isLeaf]' } }), + new CommentMessage( + '// Next, let me define a tree walking function to calculate the total depth and node count.', + ), + new CommentMessage( + '// I should initialize the total depth and node count before walking the tree.', + ), + new EvaluationMessage('let [totalDepth, nodeCount] = [0, 0];'), + new ResultMessage({ value: { totalDepth: 0, nodeCount: 0 } }), + new EvaluationMessage( + [ + 'function walk(node, depth = 0) {', + ' if (isLeaf(node)) {', + ' totalDepth += depth;', + ' nodeCount += 1;', + ' return;', + ' }', + ' node.forEach(child => { walk(child, depth + 1); });', + '}', + ].join('\n'), + ), + new ResultMessage({ value: { walk: '[Function walk]' } }), + new EvaluationMessage('walk(["a", ["b", "c"], "d", ["e", ["f", "g"]]]);'), + new EvaluationMessage('const averageDepth = totalDepth / nodeCount;'), + new ResultMessage({ value: { averageDepth: 2 } }), + new EvaluationMessage('await end({ final: String(averageDepth) });'), +]; + +/* eslint-enable no-tabs */ + +export const exampleTranscripts = [ + simpleSemanticTask, + multiStepCalculation, + treeSearchTask, + functionDefinition, +] as const; diff --git a/packages/kernel-agents/src/strategies/repl/messages.test.ts b/packages/kernel-agents/src/strategies/repl/messages.test.ts new file mode 100644 index 000000000..accf516b9 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/messages.test.ts @@ -0,0 +1,198 @@ +import 'ses'; +import '@ocap/repo-tools/test-utils/mock-endoify'; +import type { SyntaxNode } from 'tree-sitter'; +import { describe, it, expect } from 'vitest'; + +import { + CommentMessage, + EvaluationMessage, + ImportMessage, + InterjectionMessage, + MAX_LINES, + ResultMessage, + StatementMessage, + observeJson, +} from './messages.ts'; +import { ERROR, RETURN } from './symbols.ts'; + +describe('observeJson', () => { + it.each([ + ['hello', '"hello"'], + [42, '42'], + [true, 'true'], + [null, 'null'], + [undefined, 'undefined'], + ])('stringifies primitive %s', (value, expected) => { + expect(observeJson(value)).toBe(expected); + }); + + it('uses toJsonString for JsonObservable', () => { + expect(observeJson({ toJsonString: () => '"custom"' })).toBe('"custom"'); + }); + + it('stringifies arrays and objects', () => { + expect(observeJson([1, 2, 3])).toBe('[\n 1,\n 2,\n 3\n]'); + expect(observeJson({ a: 1, b: 2 })).toBe('{\n "a": 1,\n "b": 2\n}'); + }); +}); + +describe('StatementMessage', () => { + it.each([ + ['// comment', CommentMessage, 'comment'], + ['import { foo } from "bar";', ImportMessage, 'import'], + ['1 + 1;', EvaluationMessage, 'evaluation'], + ])('creates %s from code', (code, Class, type) => { + const message = StatementMessage.fromCode(code); + expect(message).toBeInstanceOf(Class); + expect(message.messageType).toBe(type); + expect(message.messageBody.code).toBe(code); + expect(message.toReplString()).toBe(`> ${code}`); + }); +}); + +describe.each([ + ['CommentMessage', CommentMessage, '// comment', 'comment'], + ['ImportMessage', ImportMessage, 'import { foo } from "bar";', 'import'], + ['EvaluationMessage', EvaluationMessage, '1 + 1;', 'evaluation'], +])('%s', (_, Class, code, type) => { + it('creates message with optional node', () => { + const node = { type } as SyntaxNode; + const message = new Class(code, node); + expect(message.messageType).toBe(type); + expect(message.messageBody.code).toBe(code); + expect(message.messageBody.node).toBe(node); + expect(message.toReplString()).toBe(`> ${code}`); + }); +}); + +describe('ImportMessage', () => { + it.each([ + [['foo', 'bar'], 'import { foo, bar } from "@ocap/abilities";'], + [['foo'], 'import { foo } from "@ocap/abilities";'], + [[], 'import { } from "@ocap/abilities";'], + ])('creates from names %s', (names, expected) => { + const message = ImportMessage.fromNames(names); + expect(message.messageType).toBe('import'); + expect(message.messageBody.code).toBe(expected); + }); +}); + +describe('InterjectionMessage', () => { + it('creates message and serializes', () => { + const message = new InterjectionMessage('test'); + expect(message.messageType).toBe('interjection'); + expect(message.messageBody.interjection).toBe('test'); + expect(message.toReplString()).toBe('! test'); + expect(message.toJsonString()).toBe('{ "messageType": "interjection", }'); + }); +}); + +describe('ResultMessage', () => { + const longValue = Array.from({ length: 2 * MAX_LINES }, (_, i) => ({ + [`key${i}`]: `value${i}`, + })); + const longString = Array.from( + { length: 2 * MAX_LINES }, + (_, i) => `line ${i}`, + ).join('\n'); + + it('creates message with return value', () => { + const message = new ResultMessage({ [RETURN]: 'hello' }); + expect(message.messageType).toBe('result'); + expect(message.messageBody.return).toBeDefined(); + expect(message.messageBody.error).toBeUndefined(); + expect(message.messageBody.value).toBeUndefined(); + }); + + it('creates message with error', () => { + const message = new ResultMessage({ [ERROR]: new Error('test') }); + expect(message.messageType).toBe('result'); + expect(message.messageBody.error).toBeDefined(); + expect(message.messageBody.return).toBeUndefined(); + expect(message.messageBody.value).toBeUndefined(); + }); + + it('creates message with value', () => { + const message = new ResultMessage({ value: { x: 1 } }); + expect(message.messageType).toBe('result'); + expect(message.messageBody.value).toBeDefined(); + expect(message.messageBody.return).toBeUndefined(); + expect(message.messageBody.error).toBeUndefined(); + }); + + it('formats error correctly', () => { + const message = new ResultMessage({ [ERROR]: new Error('test') }); + expect(message.messageBody.error).toBe('Error: test'); + }); + + it('creates message with all result types', () => { + const message = new ResultMessage({ + [ERROR]: new Error('test'), + [RETURN]: 'returned', + value: { x: 1 }, + }); + expect(message.messageType).toBe('result'); + expect(message.messageBody.error).toBeDefined(); + expect(message.messageBody.return).toBeDefined(); + expect(message.messageBody.value).toBeDefined(); + }); + + it('creates message with empty result', () => { + const message = new ResultMessage({}); + expect(message.messageType).toBe('result'); + expect(message.messageBody.error).toBeUndefined(); + expect(message.messageBody.return).toBeUndefined(); + expect(message.messageBody.value).toBeUndefined(); + }); + + it('compresses long output by default', () => { + const message = new ResultMessage({ value: { output: longValue } }); + const replString = message.toReplString(); + expect(replString.split('\n').length).toBeLessThan(60); + expect(replString).toContain('// ...'); + }); + + it.each([ + ['long error', { [ERROR]: new Error(longString) }], + ['long return', { [RETURN]: longString }], + ])('compresses %s by default', (_, result) => { + const message = new ResultMessage(result); + const replString = message.toReplString(); + expect(replString.split('\n').length).toBeLessThan(30); + }); + + it('does not compress when disabled', () => { + const message = new ResultMessage( + { value: { output: longValue } }, + { compress: false }, + ); + const replString = message.toReplString(); + expect(replString.split('\n').length).toBeGreaterThan(30); + expect(replString).not.toContain('// ...'); + }); + + it('handles multiline values', () => { + const message = new ResultMessage({ + value: { a: 'line1\nline2\nline3', b: 'single' }, + }); + expect(message.messageBody.value).toContain('line1'); + expect(message.messageBody.value).toContain('line3'); + }); + + it('serializes to JSON', () => { + const message = new ResultMessage({ [RETURN]: 'test' }); + expect(message.toJsonString()).toBe('{ "messageType": "result", }'); + }); +}); + +describe('ReplMessage toJsonString', () => { + it('filters non-JsonObservable and includes JsonObservable values', () => { + const message = new InterjectionMessage('test'); + expect(message.toJsonString()).not.toContain('node'); + const observable = { toJsonString: () => '"custom"' }; + ( + message as { messageBody: { test?: typeof observable } } + ).messageBody.test = observable; + expect(message.toJsonString()).toContain('"test": "custom"'); + }); +}); diff --git a/packages/kernel-agents/src/strategies/repl/messages.ts b/packages/kernel-agents/src/strategies/repl/messages.ts new file mode 100644 index 000000000..97ba4e8f6 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/messages.ts @@ -0,0 +1,229 @@ +import { stringify } from '@metamask/kernel-utils'; +import type { SyntaxNode } from 'tree-sitter'; + +import { makeCompartment } from './compartment.ts'; +import { parse } from './parse/javascript.ts'; +import { ERROR, RETURN } from './symbols.ts'; +import { Message } from '../../types/messages.ts'; +import type { Transcript } from '../../types/messages.ts'; + +export type StatementType = 'import' | 'evaluation' | 'comment'; + +export type ReplMessageType = StatementType | 'interjection' | 'result'; + +export type ReplObservable = { + toReplString(): string; +}; + +export type JsonObservable = { + toJsonString(): string; +}; + +type JsonKey = string | number; + +type Primitive = string | number | boolean | null | undefined; + +type JsonObservation = + | Primitive + | JsonObservable + | JsonObservation[] + | { [key: JsonKey]: JsonObservation }; + +const isJsonObservable = (value: unknown): value is JsonObservable => + typeof value === 'object' && value !== null && 'toJsonString' in value; + +export const observeJson = (value: JsonObservation): string => + isJsonObservable(value) ? value.toJsonString() : stringify(value); + +export abstract class ReplMessage< + Type extends ReplMessageType, + Body extends Record, + > + extends Message + implements ReplObservable, JsonObservable +{ + toJsonString(): string { + const messageType = `"messageType": "${this.messageType}"`; + const bodyEntries = Object.entries(this.messageBody) + .filter(([, value]) => isJsonObservable(value)) + .map( + ([key, value]) => `"${key}": ${observeJson(value as JsonObservation)}`, + ); + return `{ ${messageType}, ${bodyEntries.join(', ')} }`; + } + + abstract toReplString(): string; +} + +// Statements comprise the action space of the REPL agent. +export class StatementMessage< + Type extends StatementType = StatementType, +> extends ReplMessage { + toReplString(): string { + return `> ${this.messageBody.code}`; + } + + static fromCode(code: string): StatementMessage { + return statementMessageFromCode(code); + } +} + +const parseStatement = ( + code: string, + name?: string, + bound?: StatementType[], +): SyntaxNode => { + const { rootNode } = parse(code); + const [statement] = rootNode.children as [SyntaxNode]; + if (bound && !bound.includes(statement.type as StatementType)) { + throw new Error(`"${code}" is not a valid ${name}.`); + } + return statement; +}; + +export class CommentMessage extends StatementMessage<'comment'> { + constructor(code: string, statement?: SyntaxNode) { + const node = statement ?? parseStatement(code, 'comment'); + super('comment', { code, node }); + } +} + +export class ImportMessage extends StatementMessage<'import'> { + constructor(code: string, statement?: SyntaxNode) { + const node = statement ?? parseStatement(code, 'import_statement'); + super('import', { code, node }); + } + + static fromNames(names: string[]): ImportMessage { + const code = `import { ${names.join(', ')} } from "@ocap/abilities";`; + return new ImportMessage(code); + } +} + +export class EvaluationMessage extends StatementMessage<'evaluation'> { + constructor(code: string, statement?: SyntaxNode) { + const node = statement ?? parseStatement(code, 'expression_statement'); + super('evaluation', { code, node }); + } +} + +/** + * Make a statement message from code. + * + * @param code - The code to parse. + * @returns A statement message. + */ +function statementMessageFromCode(code: string): StatementMessage { + const { rootNode } = parse(code); + const [statement] = rootNode.children as [SyntaxNode]; + switch (statement.type) { + case 'comment': + return new CommentMessage(code, statement); + case 'import_statement': + return new ImportMessage(code, statement); + default: // XXX Maybe too permissive as the default case. + return new EvaluationMessage(code, statement); + } +} + +export class InterjectionMessage extends ReplMessage< + 'interjection', + { interjection: string } +> { + constructor(interjection: string) { + super('interjection', { interjection }); + } + + toReplString(): string { + return `! ${this.messageBody.interjection}`; + } +} + +const $stringify = harden(stringify); + +export const MAX_LINES = 20; +const HEAD_LENGTH = 14; +const ELLIPSIS = '// ...'; + +const hardenEntry = ([key, value]: [string, unknown]): [string, string] => { + const hardValue = harden(value); + const compartment = makeCompartment({ hardValue, $stringify }); + const stringified = compartment.evaluate('$stringify(hardValue);') as string; + return [key, stringified]; +}; + +type ResultMessageBody = { value?: string; error?: string; return?: string }; + +const compressLines = ( + lines: string[], + { + maxLines = MAX_LINES, + headLength = HEAD_LENGTH, + ellipsis = ELLIPSIS, + }: { maxLines?: number; headLength?: number; ellipsis?: string } = {}, +): string[] => + lines.length > maxLines + ? [ + ...lines.slice(0, headLength), + ellipsis, + ...lines.slice(-(maxLines - headLength - 1)), + ] + : lines; + +type ResultArg = { + value?: Record; + [ERROR]?: unknown; + [RETURN]?: unknown; +}; + +export class ResultMessage extends ReplMessage<'result', ResultMessageBody> { + readonly #compress: boolean; + + constructor( + result: ResultArg, + { compress = true }: { compress?: boolean } = {}, + ) { + const messageBody: ResultMessageBody = {}; + if (Object.hasOwn(result, ERROR)) { + const error = result[ERROR] as Error; + messageBody.error = `${error.name}: ${error.message}`; + } + if (Object.hasOwn(result, RETURN)) { + messageBody.return = hardenEntry(['', result[RETURN]])[1]; + } + if (Object.hasOwn(result, 'value')) { + messageBody.value = Object.entries( + result.value as Record, + ) + .map(hardenEntry) + .map(([key, val]) => `${key}: ${val}`) + .join('\n'); + } + super('result', messageBody); + this.#compress = compress; + } + + toReplString(): string { + const lines = { + error: this.messageBody.error?.split('\n') ?? [], + return: this.messageBody.return?.split('\n') ?? [], + value: this.messageBody.value?.split('\n') ?? [], + }; + const transform = this.#compress + ? compressLines + : (value: string[]) => value; + return [ + ...transform(lines.error), + ...transform(lines.return), + ...transform(lines.value), + ].join('\n'); + } +} + +export type ReplTranscript = Transcript; + +export type Observation = InterjectionMessage | ResultMessage; + +export type Action = StatementMessage; + +export type State = (Observation | Action)[]; diff --git a/packages/kernel-agents/src/strategies/repl/parse/identifiers.test.ts b/packages/kernel-agents/src/strategies/repl/parse/identifiers.test.ts new file mode 100644 index 000000000..aa9883f41 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/parse/identifiers.test.ts @@ -0,0 +1,76 @@ +import '@ocap/repo-tools/test-utils/mock-endoify'; + +import type { SyntaxNode } from 'tree-sitter'; +import { describe, it, expect } from 'vitest'; + +import { extractNamesFromDeclaration } from './identifiers.ts'; +import { parse } from './javascript.ts'; + +describe('extractNamesFromDeclaration', () => { + it.each([ + ['function x() { return 1; }', ['x']], + ['function* x() { yield 1; }', ['x']], + ['async function x() { return 1; }', ['x']], + ['async function* x() { yield 1; }', ['x']], + ])('extracts declaration names from %s', (text, expected) => { + const { rootNode } = parse(text); + const [statement] = rootNode.children as [SyntaxNode]; + expect(extractNamesFromDeclaration(statement)).toStrictEqual(expected); + }); + + describe.each(['const', 'let', 'var'])('variable declaration', (keyword) => { + it.each([ + // Variable declaration + [`${keyword} x = 1;`, ['x']], + [`${keyword} x = foo(bar);`, ['x']], + // Array destructuring + [`${keyword} [x] = [foo()];`, ['x']], + [`${keyword} [x, y] = [1, 2];`, ['x', 'y']], + [`${keyword} [x,, y] = [1, 2, 3];`, ['x', 'y']], + [`${keyword} [x, ...rest] = arr;`, ['x', 'rest']], + [`${keyword} [x = 1, y] = arr;`, ['x', 'y']], + // Object destructuring + [`${keyword} { x } = { x: foo() };`, ['x']], + [`${keyword} { x, y } = { x: 1, y: 2 };`, ['x', 'y']], + [`${keyword} { x, ...rest } = obj;`, ['x', 'rest']], + [`${keyword} { x: a, y: b } = obj;`, ['a', 'b']], + [`${keyword} { x: { y, z } } = { x: { y: 1, z: 2 } };`, ['y', 'z']], + [`${keyword} { x: { y, z: w } } = obj;`, ['y', 'w']], + [`${keyword} { a = "b" } = { c: "d" };`, ['a']], + // Arrow function definition + [`${keyword} foo = (x) => x;`, ['foo']], + [`${keyword} foo = ([x]) => x;`, ['foo']], + [`${keyword} foo = ({x}) => x;`, ['foo']], + [`${keyword} x = 1, y = 2, z = 3;`, ['x', 'y', 'z']], + ])('extracts declaration names from %s', (text, expected) => { + const { rootNode } = parse(text); + const [statement] = rootNode.children as [SyntaxNode]; + expect(extractNamesFromDeclaration(statement)).toStrictEqual(expected); + }); + }); + + it('throws for declaration with childless declarator', () => { + expect(() => + // @ts-expect-error Destructive testing + extractNamesFromDeclaration({ + type: 'lexical_declaration', + children: [ + { type: 'const', text: 'const' }, + { type: 'declarator', children: [] }, + { type: ';', text: ';' }, + ] as unknown as [SyntaxNode, SyntaxNode, SyntaxNode], + }), + ).toThrow('Internal: Declarator missing pattern'); + }); + + it.each([ + ['expression_statement', '1 + 1'], + ['for_statement', 'for (let i = 0; i < 10; i++) { console.log(i); }'], + ])('throws for %s', (statementType, code) => { + const { rootNode } = parse(code); + const [statement] = rootNode.children as [SyntaxNode]; + expect(() => extractNamesFromDeclaration(statement)).toThrow( + `Unknown declaration type: ${statementType}`, + ); + }); +}); diff --git a/packages/kernel-agents/src/strategies/repl/parse/identifiers.ts b/packages/kernel-agents/src/strategies/repl/parse/identifiers.ts new file mode 100644 index 000000000..836620529 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/parse/identifiers.ts @@ -0,0 +1,71 @@ +import type { SyntaxNode } from 'tree-sitter'; + +const extractIdentifiers = (pattern?: SyntaxNode): string[] => { + if (!pattern) { + // This would be a tree-sitter error. + throw new Error('Internal: Declarator missing pattern'); + } + + const identifiers: string[] = []; + + // Handle the case where the pattern itself is an identifier (simple cases like 'const x = 1') + if (pattern.type === 'identifier') { + return [pattern.text]; + } + + for (const child of pattern.children) { + switch (child.type) { + case 'identifier': + case 'shorthand_property_identifier_pattern': + identifiers.push(child.text); + break; + default: + // Recursively handle other pattern types + if (child.type.endsWith('_pattern')) { + identifiers.push(...extractIdentifiers(child)); + } + } + } + + return identifiers; +}; + +/** + * Given a declaration, extract the names of the declared identifiers. + * These names cover the keys of the namespace delta resulting from evaluation. + * + * @param declaration - The declaration to extract the names from. + * A declaration is a top level node which is also one of the following: + * - a const statement + * - a let statement + * - a var statement + * - a function declaration + * @returns The names of the identifiers declared in the declaration. + */ +export const extractNamesFromDeclaration = ( + declaration: SyntaxNode, +): string[] => { + const variableIdentifiers = ({ children }: SyntaxNode): string[] => + children + .filter(({ type }) => + ['variable_declarator', 'declarator'].includes(type), + ) + .flatMap(({ children: [pattern] }) => extractIdentifiers(pattern)); + const functionIdentifier = ({ children }: SyntaxNode): string => { + const identifier = children.find((child) => child.type === 'identifier'); + if (!identifier) { + throw new Error('Internal: Function declaration missing identifier'); + } + return identifier.text; + }; + switch (declaration.type) { + case 'lexical_declaration': + case 'variable_declaration': + return variableIdentifiers(declaration); + case 'function_declaration': + case 'generator_function_declaration': + return [functionIdentifier(declaration)]; + default: + throw new Error(`Unknown declaration type: ${declaration.type}`); + } +}; diff --git a/packages/kernel-agents/src/strategies/repl/parse/javascript.test.ts b/packages/kernel-agents/src/strategies/repl/parse/javascript.test.ts new file mode 100644 index 000000000..7e9c01918 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/parse/javascript.test.ts @@ -0,0 +1,37 @@ +/** + * This suite declares expected AST nodes for various JavaScript expressions. + */ +import '@ocap/repo-tools/test-utils/mock-endoify'; +import type { SyntaxNode } from 'tree-sitter'; +import { describe, it, expect } from 'vitest'; + +import { parse } from './javascript.ts'; + +describe('javascript parser', () => { + it.each([ + // An array of expected proposals from the LLM and their AST types. + ['const a = 1;', 'lexical_declaration'], + ['let { b } = { b: 2 };', 'lexical_declaration'], + ['var [ c ] = [ 3 ];', 'variable_declaration'], + ['const x = () => 42;', 'lexical_declaration'], + ['function y() { return 42; }', 'function_declaration'], + ['function* z() { yield 42; }', 'generator_function_declaration'], + ['1 + 1', 'expression_statement'], + ['for (let i = 0; i < 4; i++) { console.log(i); }', 'for_statement'], + ['(function() { return 42; })()', 'expression_statement'], + // Note: the below case becomes a function_declaration once the body closes. + ['function test() {', 'expression_statement'], + ['let length = 11, width = 47, height = 63;', 'lexical_declaration'], + ['// This is a comment', 'comment'], + ['import { foo } from "@ocap/abilities";', 'import_statement'], + ])('parses `%s` as %s', (expression: string, expectedType: string) => { + const tree = parse(expression); + const { rootNode } = tree; + expect(rootNode.text).toStrictEqual(expression); + expect(rootNode.type).toBe('program'); + expect(rootNode.children).toHaveLength(1); + const [child] = rootNode.children as [SyntaxNode]; + expect(child.text).toStrictEqual(expression); + expect(child.type).toBe(expectedType); + }); +}); diff --git a/packages/kernel-agents/src/strategies/repl/parse/javascript.ts b/packages/kernel-agents/src/strategies/repl/parse/javascript.ts new file mode 100644 index 000000000..a1e4e0fb9 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/parse/javascript.ts @@ -0,0 +1,15 @@ +/* eslint-disable @typescript-eslint/naming-convention */ +import Parser from 'tree-sitter'; +import JavaScript from 'tree-sitter-javascript'; +/* eslint-enable @typescript-eslint/naming-convention */ + +const parser = new Parser(); +parser.setLanguage(JavaScript as Parser.Language); + +/** + * Parse a JavaScript statement into a tree-sitter abstract syntax tree. + * + * @param text - The text to parse. + * @returns The parsed tree-sitter abstract syntax tree. + */ +export const parse = (text: string): Parser.Tree => parser.parse(text); diff --git a/packages/kernel-agents/src/strategies/repl/prepare-attempt.ts b/packages/kernel-agents/src/strategies/repl/prepare-attempt.ts new file mode 100644 index 000000000..54e06dfe1 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/prepare-attempt.ts @@ -0,0 +1,87 @@ +import { mergeDisjointRecords } from '@metamask/kernel-utils'; +import type { Logger } from '@metamask/logger'; + +import { makeEvaluator } from './evaluator.ts'; +import type { State, Observation, Action } from './messages.ts'; +import { + InterjectionMessage, + ImportMessage, + ResultMessage, +} from './messages.ts'; +import { makePrinter } from './printer.ts'; +import { makePrompter } from './prompter.ts'; +import { makeReader } from './reader.ts'; +import { extractCapabilitySchemas } from '../../capabilities/capability.ts'; +import { makeEnd } from '../../capabilities/end.ts'; +import type { + PREP, + Objective, + Context, + CapabilityRecord, + Progress, + PrepareAttempt, +} from '../../types.ts'; +import { ifDefined } from '../../utils.ts'; + +export const prepareAttempt: PrepareAttempt = < + Result, +>({ + objective: { intent, judgment }, + context, + options = {}, +}: { + objective: Objective; + context: Context; + options?: { + seed?: number; + tokenLength?: number; + taskLogger?: Logger; + printLogger?: Logger; + }; +}): [PREP, Progress] => { + const { seed, tokenLength, taskLogger, printLogger } = options; + + const [end, didEnd, getEnd] = makeEnd(); + + const capabilities = mergeDisjointRecords(context.capabilities, { + end, + }) as CapabilityRecord; + + const history = [ + new InterjectionMessage(intent), + ImportMessage.fromNames(Object.keys(capabilities)), + new ResultMessage({ value: extractCapabilitySchemas(capabilities) }), + ]; + + const progress: Progress = { + history, + isDone: () => { + if (didEnd()) { + const result = getEnd(); + if (!judgment(result)) { + throw new Error(`Invalid result: ${result as string}`, { + cause: result, + }); + } + Object.assign(progress, { result }); + return true; + } + return false; + }, + // result: not defined until judgment is satisfied + }; + + const readLogger = taskLogger?.subLogger({ tags: ['read'] }); + const evalLogger = taskLogger?.subLogger({ tags: ['eval'] }); + + return [ + [ + makePrompter(ifDefined({ seed, tokenLength })), + makeReader(ifDefined({ logger: readLogger })), + makeEvaluator(ifDefined({ capabilities, logger: evalLogger })), + makePrinter({ history, ...ifDefined({ logger: printLogger }) }), + // TODO: Fix these types + ] as unknown as PREP, + progress, + ]; +}; diff --git a/packages/kernel-agents/src/strategies/repl/prepare-evaluation.test.ts b/packages/kernel-agents/src/strategies/repl/prepare-evaluation.test.ts new file mode 100644 index 000000000..182c32758 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/prepare-evaluation.test.ts @@ -0,0 +1,452 @@ +import '@ocap/repo-tools/test-utils/mock-endoify'; +import { Logger } from '@metamask/logger'; +import type { SyntaxNode } from 'tree-sitter'; +import { describe, it, expect, beforeEach, vi } from 'vitest'; + +import * as identifiers from './parse/identifiers.ts'; +import { prepareEvaluation } from './prepare-evaluation.ts'; +import { ERROR, RETURN } from './symbols.ts'; +import type { EvaluatorState } from './types.ts'; + +vi.mock('./parse/identifiers.ts', () => ({ + extractNamesFromDeclaration: vi.fn(), +})); + +describe('prepareEvaluation', () => { + let state: EvaluatorState; + const mockExtractNames = vi.mocked(identifiers.extractNamesFromDeclaration); + + const createMockNode = ( + type: string, + text: string, + children: SyntaxNode[] = [], + ): SyntaxNode => + ({ + type, + text, + children, + toString: () => type, + }) as SyntaxNode; + + beforeEach(() => { + state = { consts: {}, lets: {} }; + vi.clearAllMocks(); + }); + + describe('lexical_declaration', () => { + it('prepares const declaration', () => { + const statement = createMockNode('lexical_declaration', 'const x = 1;', [ + createMockNode('const', 'const'), + ]); + mockExtractNames.mockReturnValue(['x']); + + const evaluable = prepareEvaluation(state, statement); + + expect(evaluable.endowments.consts).toBe(state.consts); + expect(evaluable.endowments.lets).toBe(state.lets); + expect(evaluable.endowments.$frame.$capture).toBeDefined(); + expect(evaluable.endowments.$frame.$catch).toBeDefined(); + expect(evaluable.endowments.$frame.$return).toBeDefined(); + expect(evaluable.code).toContain('const x = 1;'); + expect(evaluable.code).toContain('$return(x);'); + expect(evaluable.result.value).toBeDefined(); + }); + + it('prepares let declaration', () => { + const statement = createMockNode('lexical_declaration', 'let y = 2;', [ + createMockNode('let', 'let'), + ]); + mockExtractNames.mockReturnValue(['y']); + + const evaluable = prepareEvaluation(state, statement); + + expect(evaluable.endowments.$frame.$return).toBeDefined(); + expect(evaluable.code).toContain('let y = 2;'); + expect(evaluable.code).toContain('$return(y);'); + }); + + it('throws for unknown lexical_declaration', () => { + const statement = createMockNode( + 'lexical_declaration', + 'unknown x = 1;', + [createMockNode('unknown', 'unknown')], + ); + + expect(() => prepareEvaluation(state, statement)).toThrow( + 'Unknown lexical_declaration', + ); + }); + + it('commits const declaration to consts', () => { + const statement = createMockNode('lexical_declaration', 'const x = 1;', [ + createMockNode('const', 'const'), + ]); + mockExtractNames.mockReturnValue(['x']); + + const evaluable = prepareEvaluation(state, statement); + evaluable.result.value = { x: 1 }; + evaluable.commit(); + + expect(state.consts).toStrictEqual({ x: 1 }); + expect(state.lets).toStrictEqual({}); + }); + + it('commits let declaration to lets', () => { + const statement = createMockNode('lexical_declaration', 'let y = 2;', [ + createMockNode('let', 'let'), + ]); + mockExtractNames.mockReturnValue(['y']); + + const evaluable = prepareEvaluation(state, statement); + evaluable.result.value = { y: 2 }; + evaluable.commit(); + + expect(state.consts).toStrictEqual({}); + expect(state.lets).toStrictEqual({ y: 2 }); + }); + + it('does not commit when result has error', () => { + const statement = createMockNode('lexical_declaration', 'const x = 1;', [ + createMockNode('const', 'const'), + ]); + mockExtractNames.mockReturnValue(['x']); + + const evaluable = prepareEvaluation(state, statement); + evaluable.result[ERROR] = new Error('test error'); + evaluable.result.value = { x: 1 }; + evaluable.commit(); + + expect(state.consts).toStrictEqual({}); + }); + + it('captures lets namespace on commit', () => { + state.lets = { existing: 'value' }; + const statement = createMockNode('lexical_declaration', 'const x = 1;', [ + createMockNode('const', 'const'), + ]); + mockExtractNames.mockReturnValue(['x']); + + const evaluable = prepareEvaluation(state, statement); + // Simulate captured namespace by directly testing the commit behavior + // The actual capture happens during evaluation, but we test the commit logic + evaluable.result.value = { x: 1 }; + evaluable.commit(); + + expect(state.consts).toStrictEqual({ x: 1 }); + }); + }); + + describe('function_declaration', () => { + it('prepares function declaration', () => { + const statement = createMockNode( + 'function_declaration', + 'function foo() { return 1; }', + ); + mockExtractNames.mockReturnValue(['foo']); + + const evaluable = prepareEvaluation(state, statement); + + expect(evaluable.endowments.$frame.$return).toBeDefined(); + expect(evaluable.code).toContain('function foo() { return 1; }'); + expect(evaluable.code).toContain('$return(foo);'); + }); + + it('prepares generator function declaration', () => { + const statement = createMockNode( + 'generator_function_declaration', + 'function* gen() { yield 1; }', + ); + mockExtractNames.mockReturnValue(['gen']); + + const evaluable = prepareEvaluation(state, statement); + + expect(evaluable.endowments.$frame.$return).toBeDefined(); + expect(evaluable.code).toContain('function* gen() { yield 1; }'); + }); + }); + + describe('expression_statement', () => { + it('prepares expression statement', () => { + const statement = createMockNode('expression_statement', '1 + 1;'); + + const evaluable = prepareEvaluation(state, statement); + + expect(evaluable.endowments.$frame.$return).toBeDefined(); + expect(evaluable.code).toContain('$return(1 + 1);'); + expect(evaluable.result.value).toBeUndefined(); + }); + + it('strips trailing semicolons from expression', () => { + const statement = createMockNode('expression_statement', '42;;;'); + + const evaluable = prepareEvaluation(state, statement); + + expect(evaluable.code).toContain('$return(42);'); + }); + + it('commits expression result', () => { + const statement = createMockNode('expression_statement', '42;'); + + const evaluable = prepareEvaluation(state, statement); + evaluable.result[RETURN] = 42; + evaluable.commit(); + + expect(evaluable.result[RETURN]).toBe(42); + }); + + it('throws when result has no return and no error', () => { + const statement = createMockNode('expression_statement', '42;'); + + const evaluable = prepareEvaluation(state, statement); + + expect(() => evaluable.commit()).toThrow( + 'Internal: Result is undefined but no error was thrown', + ); + }); + + it('does not throw when result has error', () => { + const statement = createMockNode('expression_statement', '42;'); + + const evaluable = prepareEvaluation(state, statement); + evaluable.result[ERROR] = new Error('test error'); + + expect(() => evaluable.commit()).not.toThrow(); + }); + }); + + describe('statement types', () => { + it.each([ + ['if_statement', 'if (true) { }'], + ['for_statement', 'for (let i = 0; i < 10; i++) { }'], + ['for_in_statement', 'for (let x in obj) { }'], + ['for_of_statement', 'for (let x of arr) { }'], + ['for_await_of_statement', 'for await (let x of asyncIter) { }'], + ['for_await_in_statement', 'for await (let x in asyncIter) { }'], + ['for_await_statement', 'for await (let x of asyncIter) { }'], + ['while_statement', 'while (true) { }'], + ['do_while_statement', 'do { } while (true);'], + ['switch_statement', 'switch (x) { case 1: break; }'], + ['try_statement', 'try { } catch (e) { }'], + ['catch_clause', 'catch (e) { }'], + ['finally_clause', 'finally { }'], + ])('prepares %s', (type, code) => { + const statement = createMockNode(type, code); + + const evaluable = prepareEvaluation(state, statement); + + expect(evaluable.endowments.$frame.$catch).toBeDefined(); + expect(evaluable.endowments.$frame.$capture).toBeDefined(); + expect(evaluable.endowments.$frame.$return).toBeUndefined(); + expect(evaluable.code).toContain(code); + }); + + it('commits statement without error', () => { + const statement = createMockNode('if_statement', 'if (true) { }'); + + const evaluable = prepareEvaluation(state, statement); + evaluable.commit(); + + expect(evaluable.result[ERROR]).toBeUndefined(); + }); + + it('does not commit when statement has error', () => { + const statement = createMockNode('if_statement', 'if (true) { }'); + + const evaluable = prepareEvaluation(state, statement); + evaluable.result[ERROR] = new Error('test error'); + evaluable.commit(); + + expect(evaluable.result[ERROR]).toBeDefined(); + }); + }); + + describe('error cases', () => { + it('throws for variable_declaration', () => { + const statement = createMockNode('variable_declaration', 'var x = 1;'); + + expect(() => prepareEvaluation(state, statement)).toThrow( + 'Variable declarations are not allowed', + ); + }); + + it('throws for import_statement', () => { + const statement = createMockNode( + 'import_statement', + 'import { x } from "module";', + ); + + expect(() => prepareEvaluation(state, statement)).toThrow( + 'Imports are not allowed', + ); + expect(() => prepareEvaluation(state, statement)).toThrow(SyntaxError); + }); + + it('throws for unknown statement type', () => { + const statement = createMockNode('unknown_type', 'unknown code'); + + expect(() => prepareEvaluation(state, statement)).toThrow( + 'Unknown statement type', + ); + }); + }); + + describe('with logger', () => { + let logger: Logger; + let infoSpy: ReturnType; + let subLoggerSpy: ReturnType; + + beforeEach(() => { + logger = new Logger('test'); + infoSpy = vi.spyOn(logger, 'info'); + subLoggerSpy = vi.fn().mockReturnValue({ info: infoSpy }); + logger.subLogger = subLoggerSpy; + }); + + it.each([ + ['const', 'const x = 1;', ['x'], { x: 1 }, 'const declaration:'], + ['let', 'let y = 2;', ['y'], { y: 2 }, 'let declaration:'], + ])( + 'logs commit for %s declaration', + (_, code, names, value, logMessage) => { + const statement = createMockNode('lexical_declaration', code, [ + createMockNode( + code.split(' ')[0] as string, + code.split(' ')[0] as string, + ), + ]); + mockExtractNames.mockReturnValue(names); + + const evaluable = prepareEvaluation(state, statement, { logger }); + evaluable.result.value = value; + evaluable.commit(); + + expect(subLoggerSpy).toHaveBeenCalledWith({ tags: ['commit'] }); + expect(infoSpy).toHaveBeenCalledWith(logMessage, value); + }, + ); + + it('logs error on commit', () => { + const statement = createMockNode('lexical_declaration', 'const x = 1;', [ + createMockNode('const', 'const'), + ]); + mockExtractNames.mockReturnValue(['x']); + + const evaluable = prepareEvaluation(state, statement, { logger }); + evaluable.result[ERROR] = new Error('test error'); + evaluable.commit(); + + expect(infoSpy).toHaveBeenCalledWith('result error:', expect.any(Error)); + }); + + it('logs expression return value', () => { + const statement = createMockNode('expression_statement', '42;'); + const evaluable = prepareEvaluation(state, statement, { logger }); + evaluable.result[RETURN] = 42; + evaluable.commit(); + + expect(infoSpy).toHaveBeenCalledWith('result return:', 42); + }); + }); + + describe('wrap function', () => { + it('wraps function to throw EvaluatorError on error', () => { + const statement = createMockNode('expression_statement', '42;'); + + const evaluable = prepareEvaluation(state, statement); + const { $catch } = evaluable.endowments.$frame; + + // $catch doesn't throw, it sets the error in result + $catch(new Error('test error')); + expect(evaluable.result[ERROR]).toBeInstanceOf(Error); + }); + + it('handles non-Error values', () => { + const statement = createMockNode('expression_statement', '42;'); + + const evaluable = prepareEvaluation(state, statement); + const { $catch } = evaluable.endowments.$frame; + + // $catch accepts any value and stores it + $catch('string error'); + expect(evaluable.result[ERROR]).toBe('string error'); + }); + }); + + describe('makeCaptor', () => { + it('creates captor that captures values', () => { + const statement = createMockNode('lexical_declaration', 'let x = 1;', [ + createMockNode('let', 'let'), + ]); + mockExtractNames.mockReturnValue(['x']); + + const evaluable = prepareEvaluation(state, statement); + const { $return } = evaluable.endowments.$frame; + + expect($return).toBeDefined(); + // The captor function is created with specific names, so we test it indirectly + // by verifying the structure is correct + expect(evaluable.result.value).toBeDefined(); + }); + + it('throws for reserved captor name', () => { + const statement = createMockNode('lexical_declaration', 'let x = 1;', [ + createMockNode('let', 'let'), + ]); + mockExtractNames.mockReturnValue(['$UNIQUE']); + + expect(() => prepareEvaluation(state, statement)).toThrow( + 'Captor name "$UNIQUE" is reserved', + ); + }); + }); + + describe('wrapAsyncEvaluation', () => { + it('wraps code with consts and lets', () => { + state.consts = { a: 1 }; + state.lets = { b: 2 }; + const statement = createMockNode('expression_statement', 'a + b;'); + + const evaluable = prepareEvaluation(state, statement); + + expect(evaluable.code).toContain('const { a } = consts;'); + expect(evaluable.code).toContain('let { b } = lets;'); + expect(evaluable.code).toContain('$capture(b);'); + }); + + it('wraps code without consts or lets', () => { + const statement = createMockNode('expression_statement', '42;'); + + const evaluable = prepareEvaluation(state, statement); + + // Check that destructuring for consts/lets is not present + // (the $frame destructuring is always present) + expect(evaluable.code).not.toContain('} = consts;'); + expect(evaluable.code).not.toContain('} = lets;'); + expect(evaluable.code).not.toMatch(/\$capture\([^)]+\);/u); + }); + + it('wraps code in async IIFE', () => { + const statement = createMockNode('expression_statement', '42;'); + + const evaluable = prepareEvaluation(state, statement); + + expect(evaluable.code).toMatch(/^\(async \(\) => \{/u); + expect(evaluable.code).toContain('await null;'); + expect(evaluable.code).toContain( + 'const { $capture, $catch, $return } = $frame;', + ); + expect(evaluable.code).toMatch(/\}\)\(\);$/u); + }); + + it('includes try-catch-finally block', () => { + const statement = createMockNode('expression_statement', '42;'); + + const evaluable = prepareEvaluation(state, statement); + + expect(evaluable.code).toContain('try {'); + expect(evaluable.code).toContain('} catch (e) {'); + expect(evaluable.code).toContain('$catch(e);'); + expect(evaluable.code).toContain('} finally {'); + }); + }); +}); diff --git a/packages/kernel-agents/src/strategies/repl/prepare-evaluation.ts b/packages/kernel-agents/src/strategies/repl/prepare-evaluation.ts new file mode 100644 index 000000000..c6726bb37 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/prepare-evaluation.ts @@ -0,0 +1,387 @@ +import { EvaluatorError } from '@metamask/kernel-errors'; +import type { Logger } from '@metamask/logger'; +import type { SyntaxNode } from 'tree-sitter'; + +import { extractNamesFromDeclaration } from './parse/identifiers.ts'; +import { ERROR, RETURN } from './symbols.ts'; +import type { EvaluatorState, VariableRecord } from './types.ts'; + +/** + * Creates a wrapper for a function that throws EvaluatorError if the wrapped function throws. + * This is used to wrap $return, $catch, and $capture to differentiate between internal and user errors. + * + * @param func - The function to wrap. + * @returns A hardened function that wraps the original function. + */ +const wrap = ( + func: (...args: Args) => void, +): ((...args: Args) => void) => { + return harden((...args: Args) => { + try { + func(...args); + } catch (error) { + throw new EvaluatorError( + 'REPL evaluation failed', + '', + error instanceof Error ? error : new Error(String(error)), + ); + } + }); +}; + +export type Evaluable = { + endowments: { + consts: VariableRecord; + lets: VariableRecord; + $frame: { + $catch: (caught: unknown) => void; + $capture: (...lets: never[]) => void; + $return?: (...values: never[]) => void; + }; + }; + code: string; + result: EvaluationResult; + commit: () => void; +}; + +export type EvaluationResult = { + value?: VariableRecord; + [ERROR]?: unknown; + [RETURN]?: unknown; +}; + +/** + * Wraps an async evaluation in an IIFE to be awaited outside the compartment. + * Assumes a compartment endowed with `{ consts, lets, $catch, $capture }` at + * least. + * + * TODO: Move this functionality to endojs/endo-evaluator + * + * @param args - The arguments to wrap the async evaluation. + * @param args.consts - The consts to destructure. + * @param args.lets - The lets to destructure. + * @param args.code - The code to evaluate. + * @returns Wrapped code ready to evaluate in a compartment endowed with `{ consts, lets, $catch, $capture }`. + */ +const wrapAsyncEvaluation = ({ + consts, + lets, + code, +}: { + consts: VariableRecord; + lets: VariableRecord; + code: string; +}): string => { + const constsKeys = Object.keys(consts); + const letsKeys = Object.keys(lets); + const destructureConsts = + constsKeys.length > 0 ? `const { ${constsKeys.join(',')} } = consts;` : ''; + const destructureLets = + letsKeys.length > 0 ? `let { ${letsKeys.join(',')} } = lets;` : ''; + // The let namespace can be arbitrarily mutated by the statement; the best + // detection is captureion. + const captureLets = + letsKeys.length > 0 ? `$capture(${letsKeys.join(',')});` : ''; + // Async IIFE, to be awaited outside the compartment. We are 'vulnerable' to + // return statements, but we only await whatever is returned; we don't read + // the value. We can also prevent top level return via parsing. + return `(async () => { + await null; + const { $capture, $catch, $return } = $frame; + ${destructureConsts} + ${destructureLets} + try { + ${code} + } catch (e) { + $catch(e); + } finally { + ${captureLets} + } + })();`; +}; + +/** + * Make a captor function that captures names from lexical scope into a record. + * + * Building a function factory from source permits captureion of the + * arguments as individual variables while the record to which they are + * assigned is a reference not endowed to the compartment. + * + * The returned function is wrapped with makeSafe to detect internal errors. + * + * @param names - The names to capture. + * @returns A tuple containing the record and a safe-wrapped function that captures the names into the record. + */ +const makeCaptor = ( + names: string[], +): [VariableRecord, (...names: string[]) => void] => { + const $value = '$UNIQUE'; + if (names.includes($value)) { + throw new Error(`Captor name "${$value}" is reserved`); + } + const value: VariableRecord = {}; + const namespace = names.join(','); + // We use eval safely by constructing the function with care and only + // ever evaluating it in a compartment. + // eslint-disable-next-line @typescript-eslint/no-implied-eval, no-new-func + const captor = Function( + $value, + `return (${namespace}) => void Object.assign(${$value}, { ${namespace} });`, + )(value); + return [value, wrap(captor)]; +}; + +export const prepareEvaluation = ( + state: EvaluatorState, + statement: SyntaxNode, + options: { logger?: Logger } = {}, +): Evaluable => { + switch (statement.type) { + case 'lexical_declaration': + switch (statement.children[0]?.type) { + case 'const': + return prepareImmutableDeclaration(state, statement, options); + case 'let': + return prepareMutableDeclaration(state, statement, options); + case undefined: + default: + throw new Error( + [ + `Unknown lexical_declaration.`, + `statement: ${statement.text}`, + `type: ${statement.toString()}`, + ].join('\n'), + ); + } + case 'function_declaration': + case 'generator_function_declaration': + return prepareMutableDeclaration(state, statement, options); + case 'variable_declaration': + throw new Error( + `Variable declarations are not allowed: "${statement.text}"`, + ); + case 'expression_statement': + return prepareExpression(state, statement, options); + case 'import_statement': + throw new SyntaxError( + 'Imports are not allowed. All accessible capabilities are already imported.', + ); + case 'if_statement': + case 'for_statement': + case 'for_in_statement': + case 'for_of_statement': + case 'for_await_of_statement': + case 'for_await_in_statement': + case 'for_await_statement': + case 'while_statement': + case 'do_while_statement': + case 'switch_statement': + case 'try_statement': + case 'catch_clause': + case 'finally_clause': + // XXX The above case selector is probably long enough to be the default + return prepareStatement(state, statement, options); + default: + throw new Error( + [ + `Unknown statement type.`, + `statement: ${statement.text}`, + `type: ${statement.toString()}`, + ].join('\n'), + ); + } +}; + +/** + * Prepare a declaration for evaluation. + * + * @param state - The evaluator state. + * @param statement - The declaration to prepare. + * @returns The prepared declaration. + */ +function prepareDeclaration( + state: EvaluatorState, + statement: SyntaxNode, +): Omit & { captured: VariableRecord } { + const { consts, lets } = state; + const [captured, $capture] = makeCaptor(Object.keys(lets)); + const names = extractNamesFromDeclaration(statement); + const [value, $return] = makeCaptor(names); + const result: EvaluationResult = { value }; + const $catch = wrap((caught: unknown) => (result[ERROR] = caught)); + return { + endowments: { consts, lets, $frame: { $capture, $catch, $return } }, + code: wrapAsyncEvaluation({ + consts, + lets, + code: `${statement.text};$return(${names.join(',')});`, + }), + result, + captured, + }; +} + +/** + * Prepare a mutable declaration (let or function declaration) for evaluation. + * + * @param state - The evaluator state. + * @param statement - The declaration to prepare. + * @param options - The options. + * @param options.logger - The logger. + * @returns The prepared declaration. + */ +function prepareMutableDeclaration( + state: EvaluatorState, + statement: SyntaxNode, + options: { logger?: Logger } = {}, +): Evaluable { + const { endowments, code, result, captured } = prepareDeclaration( + state, + statement, + ); + const commitLogger = options.logger?.subLogger({ tags: ['commit'] }); + return { + endowments, + code, + result, + commit: () => { + commitLogger?.info('captured namespace:', captured); + Object.assign(state.lets, captured); + if (result[ERROR]) { + commitLogger?.info('result error:', result[ERROR]); + return; + } + commitLogger?.info('let declaration:', result.value); + Object.assign(state.lets, result.value); + }, + }; +} + +/** + * Prepare an immutable declaration (const declaration) for evaluation. + * + * @param state - The evaluator state. + * @param statement - The declaration to prepare. + * @param options - The options. + * @param options.logger - The logger. + * @returns The prepared declaration. + */ +function prepareImmutableDeclaration( + state: EvaluatorState, + statement: SyntaxNode, + options: { logger?: Logger } = {}, +): Evaluable { + const { endowments, code, result, captured } = prepareDeclaration( + state, + statement, + ); + const commitLogger = options.logger?.subLogger({ tags: ['commit'] }); + return { + endowments, + code, + result, + commit: () => { + commitLogger?.info('captured namespace:', captured); + Object.assign(state.lets, captured); + if (result[ERROR]) { + commitLogger?.info('result error:', result[ERROR]); + return; + } + commitLogger?.info('const declaration:', result.value); + Object.assign(state.consts, result.value); + }, + }; +} + +/** + * Strips any trailing semicolons from the code. + * + * @param code - The code to strip the trailing semicolons from. + * @returns The code without the trailing semicolons. + */ +const stripTrailingSemicolons = (code: string): string => + code.trimEnd().endsWith(';') + ? stripTrailingSemicolons(code.trimEnd().slice(0, -1)) + : code.trimEnd(); + +/** + * Prepare an expression for evaluation. + * + * @param state - The evaluator state. + * @param statement - The expression to prepare. + * @param options - The options. + * @param options.logger - The logger. + * @returns The prepared expression. + */ +function prepareExpression( + state: EvaluatorState, + statement: SyntaxNode, + options: { logger?: Logger } = {}, +): Evaluable { + const { consts, lets } = state; + const [captured, $capture] = makeCaptor(Object.keys(lets)); + const result: EvaluationResult = {}; + const $return = wrap((value: unknown) => (result[RETURN] = value)); + const $catch = wrap((caught: unknown) => (result[ERROR] = caught)); + const commitLogger = options.logger?.subLogger({ tags: ['commit'] }); + return { + endowments: { consts, lets, $frame: { $capture, $catch, $return } }, + code: wrapAsyncEvaluation({ + consts, + lets, + code: `$return(${stripTrailingSemicolons(statement.text)});`, + }), + result, + commit: () => { + commitLogger?.info('captured namespace:', captured); + Object.assign(state.lets, captured); + if (result[ERROR]) { + commitLogger?.info('result error:', result[ERROR]); + return; + } + commitLogger?.info('result return:', result[RETURN]); + if (!(RETURN in result)) { + throw new Error( + 'Internal: Result is undefined but no error was thrown', + ); + } + }, + }; +} + +/** + * Prepare an arbitrary statement for evaluation. + * + * @param state - The evaluator state. + * @param statement - The statement to prepare. + * @param options - The options. + * @param options.logger - The logger. + * @returns The prepared statement. + */ +function prepareStatement( + state: EvaluatorState, + statement: SyntaxNode, + options: { logger?: Logger } = {}, +): Evaluable { + const { consts, lets } = state; + const [captured, $capture] = makeCaptor(Object.keys(lets)); + const result: EvaluationResult = {}; + const $catch = wrap((caught: unknown) => (result[ERROR] = caught)); + const commitLogger = options.logger?.subLogger({ tags: ['commit'] }); + return { + endowments: { consts, lets, $frame: { $capture, $catch } }, + code: wrapAsyncEvaluation({ + consts, + lets, + code: statement.text, + }), + result, + commit: () => { + commitLogger?.info('captured namespace:', captured); + Object.assign(state.lets, captured); + if (result[ERROR]) { + commitLogger?.info('result error:', result[ERROR]); + } + }, + }; +} diff --git a/packages/kernel-agents/src/strategies/repl/printer.ts b/packages/kernel-agents/src/strategies/repl/printer.ts new file mode 100644 index 000000000..604d9df80 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/printer.ts @@ -0,0 +1,35 @@ +import type { Logger } from '@metamask/logger'; + +import { ResultMessage } from './messages.ts'; +import type { ReplTranscript, StatementMessage } from './messages.ts'; + +export const makePrinter = ({ + history, + logger, +}: { + history: ReplTranscript; + logger?: Logger; +}) => { + // Render initial state + for (const message of history) { + if (message instanceof ResultMessage) { + const lines = message + .toReplString() + .split('\n') + .filter( + (line) => line.trim() === line || line.startsWith(' "description"'), + ); + if (lines && lines?.length > 0) { + logger?.info(lines?.join('\n')); + } + continue; + } + logger?.info(message.toReplString()); + } + return (statement: StatementMessage, result: ResultMessage | null) => { + logger?.info(statement.toReplString()); + if (result) { + logger?.info(result.toReplString()); + } + }; +}; diff --git a/packages/kernel-agents/src/strategies/repl/prompter.test.ts b/packages/kernel-agents/src/strategies/repl/prompter.test.ts new file mode 100644 index 000000000..b4d524334 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/prompter.test.ts @@ -0,0 +1,29 @@ +import '@ocap/repo-tools/test-utils/mock-endoify'; + +import { describe, it, expect } from 'vitest'; + +import { makePrompter } from './prompter.ts'; + +describe('makePrompter', () => { + it('makes the expected prompt', () => { + const prompter = makePrompter({}); + const { prompt, readerArgs } = prompter([]); + expect(typeof readerArgs.stop).toBe('string'); + expect(prompt).toContain( + '> import { end, search } from "@ocap/abilities";', + ); + expect(prompt).toContain( + `end: {\n "description": "Return a final response to the user.",`, + ); + expect(prompt).toContain( + `search: {\n "description": "Search the web for information.",`, + ); + expect(prompt).toContain('! What is the oldest tree in South America?'); + expect(prompt).toContain( + '> // This information is too specific for me to know on my own.', + ); + expect(prompt).toContain( + "> await search({ query: 'oldest tree in South America' });", + ); + }); +}); diff --git a/packages/kernel-agents/src/strategies/repl/prompter.ts b/packages/kernel-agents/src/strategies/repl/prompter.ts new file mode 100644 index 000000000..df0beb62f --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/prompter.ts @@ -0,0 +1,77 @@ +import { exampleTranscripts } from './example-transcripts.ts'; +import type { ReplTranscript } from './messages.ts'; +import { makeRandom } from './random.ts'; +import { ifDefined } from '../../utils.ts'; + +const makePreamble = ( + nTranscripts: number, + wrapWithToken: (text: string) => string, +): string => { + const firstLinePrefix = + nTranscripts === 1 + ? 'The following is a transcript of a javascript REPL session environment' + : `The following are ${nTranscripts} transcripts of javascript REPL session environments`; + return [ + `${firstLinePrefix} controlled by a state-of-the-art capability-augmented computer assistant.`, + `The assistant responds to user interjections by invoking capabilities to perform tasks.`, + `The actions and observations in the transcript environment are wrapped in a line identifier, like ${wrapWithToken('> [action]')}.`, + `Agent actions take the form of javascript statements, like ${wrapWithToken('> let x = 1;')}, ${wrapWithToken('> // I can solve this problem by...')} or ${wrapWithToken('> await search({ query: "eip-1559" });')}.`, + `Observations are either evaluation results like ${wrapWithToken('{ "cost": 508 }')} or user interjections like ${wrapWithToken('! Merge and normalize these datasets.')} or ${wrapWithToken("! Don't schedule anything for Wednesday; I'm busy.")}.`, + 'Each transcript ends with an invocation of the end capability.', + `Note that the assistant efficiently invokes capabilities to perform tasks. This reflects that the assistant is intelligent and can reason logically about function composition, and prefers to invoke external capabilities to prove the correctness of its answers.`, + `Also note that, although the assistant does not necessarily use every available capability, it never attempts to use a capability that was not specified prior in the transcript.`, + ].join('\n'); +}; + +/** + * The Repl Prompter pieces together repl-like representation of message, + * wrapped in a token pair. + * e.g. + * ``` + * > let x = 1; + * x: 1 + * > x += 1; + * x: 2 + * > await end({ final: String(x) }); + * ``` + * + * @param args - The arguments to make the prompter. + * @param args.seed - The seed to use for the random number generator. + * @param args.tokenLength - The length of the token to use for the prompt. + * @returns A prompter function. + */ +export const makePrompter = + ({ seed, tokenLength = 6 }: { seed?: number; tokenLength?: number }) => + ( + history: ReplTranscript, + ): { + prompt: string; + readerArgs: { stop: string }; + } => { + // The random number generator is seeded with a fixed value so that the + // same prompt is generated for the same history. + // Language model otherwise slow. + const random = makeRandom(ifDefined({ seed })); + const makeTokenPair = (): [string, string] => { + const token = random(tokenLength); + return [`〚${token}〛`, `〚/${token}〛`]; + }; + const wrapWithToken = (text: string): string => { + const [open, close] = makeTokenPair(); + return `${open}${text}${close}`; + }; + const transcripts = [...exampleTranscripts, history]; + const rawPrompt = [ + makePreamble(transcripts.length, wrapWithToken), + ...transcripts.map((transcript, index) => + [ + `REPL ${index + 1}:`, + '---', + ...transcript.map((message) => wrapWithToken(message.toReplString())), + ].join('\n'), + ), + ].join('\n\n'); + const [open, stop] = makeTokenPair(); + const prompt = `${rawPrompt}\n${open}>`; + return { prompt, readerArgs: { stop } }; + }; diff --git a/packages/kernel-agents/src/strategies/repl/random.ts b/packages/kernel-agents/src/strategies/repl/random.ts new file mode 100644 index 000000000..8346682c3 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/random.ts @@ -0,0 +1,39 @@ +const MAX_LENGTH = 8; + +/** + * A simple seeded random number generator + * Not cryptographically secure. + * + * @param a - The seed to use for the PRNG. + * @returns A PsuedoRandomNumberGenerator. + */ +/* eslint-disable */ // CTRL+V implementation from https://stackoverflow.com/a/47593316/1123955 +const mulberry32 = (a: number): number => { + let t = (a += 0x6d2b79f5); + t = Math.imul(t ^ (t >>> 15), t | 1); + t ^= t + Math.imul(t ^ (t >>> 7), t | 61); + return (t ^ (t >>> 14)) >>> 0; +}; +/* eslint-enable */ + +/** + * Make a PsuedoRandomNumberGenerator. + * + * @param args - The arguments to make the PRNG. + * @param args.seed - The seed to use for the PRNG. + * @returns A PsuedoRandomNumberGenerator. + */ +export const makeRandom = ({ seed }: { seed?: number }) => { + let _seed = seed ?? 1; + return (length: number = MAX_LENGTH, radix: number = 16): string => { + if (length > MAX_LENGTH) { + throw new Error(`Length must be less than or equal to ${MAX_LENGTH}`); + } + // Get a random 32-bit unsigned integer and update the seed + _seed = mulberry32(_seed); + + // Convert to hex and slice to desired length + // padStart ensures we have leading zeros if needed + return _seed.toString(radix).padStart(length, '0').slice(0, length); + }; +}; diff --git a/packages/kernel-agents/src/strategies/repl/reader.test.ts b/packages/kernel-agents/src/strategies/repl/reader.test.ts new file mode 100644 index 000000000..9678e5b3b --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/reader.test.ts @@ -0,0 +1,54 @@ +import '@ocap/repo-tools/test-utils/mock-endoify'; + +import type { Logger } from '@metamask/logger'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; + +import { makeReader } from './reader.ts'; +import { makeTestStream } from '../../test-utils.ts'; + +describe('reader', () => { + let logger: Logger; + + beforeEach(() => { + logger = { + info: vi.fn(), + error: vi.fn(), + warn: vi.fn(), + debug: vi.fn(), + subLogger: vi.fn(() => logger), + } as unknown as Logger; + }); + + const transform = (statement: string) => ({ response: statement }); + + it('reads a statement from a stream', async () => { + const stop = ''; + const { stream, abort } = makeTestStream( + [`console.log("hello");${stop}`], + transform, + ); + const reader = makeReader({ logger }); + const statement = await reader({ stream, abort, stop }); + expect(statement.toReplString()).toBe('> console.log("hello");'); + }); + + it('throws an error if the stream has no stop token', async () => { + const { stream, abort } = makeTestStream( + [`console.log("hello");`], + transform, + ); + const reader = makeReader({ logger }); + await expect(reader({ stream, abort, stop: '' })).rejects.toThrow( + 'Stream ended without a parse event', + ); + }); + + it('throws an error if the stream is empty', async () => { + const stop = ''; + const { stream, abort } = makeTestStream([], transform); + const reader = makeReader({ logger }); + await expect(reader({ stream, abort, stop })).rejects.toThrow( + 'Stream ended without a parse event', + ); + }); +}); diff --git a/packages/kernel-agents/src/strategies/repl/reader.ts b/packages/kernel-agents/src/strategies/repl/reader.ts new file mode 100644 index 000000000..3de7af7ab --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/reader.ts @@ -0,0 +1,32 @@ +import type { Logger } from '@metamask/logger'; + +import type { StatementMessage } from './messages.ts'; +import { makeSampleCollector } from './sample-collector.ts'; +import { ifDefined, withAbort, gatherStreamingResponse } from '../../utils.ts'; + +export const makeReader = + ({ logger }: { logger?: Logger }) => + async ({ + stream, + abort, + stop, + }: { + stream: AsyncIterable<{ response: string }>; + abort: () => Promise; + stop: string; + }) => { + const sampleLogger = logger?.subLogger({ tags: ['sample'] }); + const gatherLogger = logger?.subLogger({ tags: ['gather'] }); + return await withAbort( + abort, + async (): Promise => + await gatherStreamingResponse({ + stream, + parse: makeSampleCollector({ + stop, + ...ifDefined({ logger: sampleLogger }), + }), + ...ifDefined({ logger: gatherLogger }), + }), + ); + }; diff --git a/packages/kernel-agents/src/strategies/repl/sample-collector.ts b/packages/kernel-agents/src/strategies/repl/sample-collector.ts new file mode 100644 index 000000000..4bc025c22 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/sample-collector.ts @@ -0,0 +1,52 @@ +import { SampleGenerationError } from '@metamask/kernel-errors'; +import type { Logger } from '@metamask/logger'; + +import { StatementMessage } from './messages.ts'; +import type { SampleCollector } from '../../types.ts'; + +/** + * A simple sample collector that collects content from a stream until a stop + * string is encountered and returns the content prior, formatted. + * + * XXX This functionality is typically available in the language model service, + * but a reimplementation appears here to remain agnostic to the service API. + * + * @param args - The arguments to make the sample collector. + * @param args.stop - The stop string to stop collection. + * @param args.maxChunkCount - The maximum number of chunks to collect before + * throwing an error. + * @param args.logger - The logger to use for the sample collector. + * @returns A function that collects a delta of a streaming response, returning a + * StatementMessage if the stop string is encountered or null otherwise. + */ +export const makeSampleCollector = ({ + stop, + maxChunkCount = 200, + logger, +}: { + stop: string; + maxChunkCount?: number; + logger?: Logger; +}): SampleCollector => { + let buffer = ''; + let chunkCount = 0; + return (delta: string) => { + chunkCount += 1; + buffer += delta; + if (buffer.includes(stop)) { + const [content] = buffer.split(stop); + if (content === undefined || content.trim() === '') { + throw new SampleGenerationError(buffer, new Error('Empty content')); + } + logger?.info('content:', content); + return StatementMessage.fromCode(content.trim()); + } + if (maxChunkCount && chunkCount > maxChunkCount) { + throw new SampleGenerationError( + buffer, + new Error('Max chunk count reached'), + ); + } + return null; + }; +}; diff --git a/packages/kernel-agents/src/strategies/repl/symbols.ts b/packages/kernel-agents/src/strategies/repl/symbols.ts new file mode 100644 index 000000000..9f711446b --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/symbols.ts @@ -0,0 +1,3 @@ +// Symbols for special keys in the result object. +export const RETURN = Symbol('RETURN'); +export const ERROR = Symbol('ERROR'); diff --git a/packages/kernel-agents/src/strategies/repl/types.ts b/packages/kernel-agents/src/strategies/repl/types.ts new file mode 100644 index 000000000..319802c63 --- /dev/null +++ b/packages/kernel-agents/src/strategies/repl/types.ts @@ -0,0 +1,7 @@ +// Todo: support number & symbol keys +export type VariableRecord = Record; + +export type EvaluatorState = { + consts: VariableRecord; + lets: VariableRecord; +}; diff --git a/packages/kernel-agents/src/task.test.ts b/packages/kernel-agents/src/task.test.ts new file mode 100644 index 000000000..e53c01b29 --- /dev/null +++ b/packages/kernel-agents/src/task.test.ts @@ -0,0 +1,11 @@ +import '@ocap/repo-tools/test-utils/mock-endoify'; + +import { describe, expect, it } from 'vitest'; + +import { defaultJudgment } from './task.ts'; + +describe('defaultJudgment', () => { + it('returns true', () => { + expect(defaultJudgment(1)).toBe(true); + }); +}); diff --git a/packages/kernel-agents/src/task.ts b/packages/kernel-agents/src/task.ts new file mode 100644 index 000000000..385305c8c --- /dev/null +++ b/packages/kernel-agents/src/task.ts @@ -0,0 +1,61 @@ +import { makeCounter } from '@metamask/kernel-utils'; + +import type { Task, CapabilityRecord } from './types.ts'; + +/** + * A trivial judgment that always returns true. + * + * @param _result - The result to judge. + * @returns True. + */ +export const defaultJudgment = (_result: unknown): _result is Result => + true; + +const formatTaskId = (count: number): string => + `t${count.toString().padStart(3, '0')}`; + +export class TaskManager { + readonly #tasks: Task[] = []; + + readonly #taskCounter = makeCounter(); + + /** + * Specify a task. + * + * @param args - The arguments to specify the task. + * @param args.intent - A specification of the task to be performed, or a query to be answered. + * @param args.judgment - The function to determine if the task is complete. + * @param args.capabilities - The capabilities available to the task - revocable. + * @param args.knowledge - The knowledge available to the task - irrevocable. + * @returns A task. + */ + makeTask({ + intent, + judgment = defaultJudgment, + capabilities = {}, + knowledge = {}, + }: { + intent: string; + judgment?: (result: unknown) => result is Result; + capabilities?: CapabilityRecord; + knowledge?: Record; + }): Task { + const task: Task = { + id: formatTaskId(this.#taskCounter()), + objective: { intent, judgment }, + context: { knowledge, capabilities }, + attempts: [], + }; + this.#tasks.push(task); + return task; + } + + /** + * Get the tasks managed by the task manager. + * + * @returns The tasks. + */ + get tasks(): Task[] { + return [...this.#tasks]; + } +} diff --git a/packages/kernel-agents/src/test-utils.ts b/packages/kernel-agents/src/test-utils.ts new file mode 100644 index 000000000..f9099ff84 --- /dev/null +++ b/packages/kernel-agents/src/test-utils.ts @@ -0,0 +1,31 @@ +/** + * Make a test stream. + * + * @param statements - The statements to yield. + * @param transform - A function to transform the statements. + * @returns A stream of statements. + * @example + * const stream = makeTestStream(['console.log("hello");', 'console.log("world");']); + * for await (const statement of stream) { + * console.log(statement); + * } + */ +export const makeTestStream = ( + statements: string[], + transform = (statement: string): Yield => statement as Yield, +): { stream: AsyncIterable; abort: () => Promise } => { + let shouldAbort = false; + return { + abort: async () => { + shouldAbort = true; + }, + stream: (async function* () { + for (const statement of statements) { + if (shouldAbort) { + break; + } + yield transform(statement); + } + })(), + }; +}; diff --git a/packages/kernel-agents/src/types.ts b/packages/kernel-agents/src/types.ts index 49e58f8a7..82d53a5d3 100644 --- a/packages/kernel-agents/src/types.ts +++ b/packages/kernel-agents/src/types.ts @@ -1,65 +1,22 @@ -import type { Transcript } from './messages.ts'; - -export type JsonSchema = - | PrimitiveJsonSchema - | ArrayJsonSchema - | ObjectJsonSchemaProperty; - -type PrimitiveJsonSchema = { - type: 'string' | 'number' | 'boolean'; - description?: string; -}; - -type ArrayJsonSchema = { - type: 'array'; - description?: string; - item: JsonSchema; -}; - -type ObjectJsonSchemaProperty = { - type: 'object'; - description?: string; - properties: { - [key: string]: JsonSchema; - }; - required?: string[]; - additionalProperties?: boolean; -}; - -export type Capability, Return = null> = ( - args: Args, -) => Promise; - -export type CapabilitySchema = { - description: string; - args: Record; - returns?: JsonSchema; -}; - -export type ExtractRecordKeys = - Rec extends Record ? Key : never; - -export type CapabilitySpec< - Args extends Record = Record, - Return = void, -> = { - func: Capability; - schema: CapabilitySchema>; -}; - -export type CapabilityRecord = Record< - Keys, - CapabilitySpec ->; - -export type Agent = { - task: ( - prompt: string, - options?: { invocationBudget?: number }, - ) => Promise; -}; - -export type Chat = { - getPromptAndPrefix: () => { prompt: string; prefix: string }; - pushMessages: (...messages: Transcript) => void; -}; +export type { + Capability, + CapabilityRecord, + CapabilitySchema, + CapabilitySpec, +} from './types/capability.ts'; +export type { TaskArgs } from './types/task.ts'; +export type { + Task, + Objective, + Context, + Attempt, + Experience, +} from './types/task.ts'; +export type { Message, Transcript } from './types/messages.ts'; +export type { + Agent, + PREP, + Progress, + SampleCollector, + PrepareAttempt, +} from './types/agent.ts'; diff --git a/packages/kernel-agents/src/types/agent.ts b/packages/kernel-agents/src/types/agent.ts new file mode 100644 index 000000000..8d6e674b5 --- /dev/null +++ b/packages/kernel-agents/src/types/agent.ts @@ -0,0 +1,73 @@ +import type { Logger } from '@metamask/logger'; + +import type { Message, MessageTypeBase } from './messages.ts'; +import type { Context, Experience, Objective } from './task.ts'; + +export type Agent = { + task: ( + intent: string, + judgment?: (result: unknown) => result is Result, + options?: { invocationBudget?: number; logger?: Logger }, + ) => Promise; + get experiences(): AsyncIterable; +}; + +export type SampleCollector = ( + delta: string, +) => Result | null; + +export type Prompter[]> = ( + state: State, +) => { + prompt: string; + readerArgs?: Record; +}; + +export type Reader> = (args: { + // This can be threaded with the stream type from the language model. + stream: AsyncIterable<{ response: string }>; + abort: () => Promise; +}) => Promise; + +export type Evaluator< + State extends Message[], + Action extends Message, + Observation extends Message, +> = (state: State, action: Action) => Promise; + +export type Printer< + Action extends Message, + Observation extends Message, +> = (action: Action, observation: Observation | null) => void; + +export type PREP< + State extends Message[], + Action extends Message, + Observation extends Message, +> = [ + Prompter, + Reader, + Evaluator, + Printer, +]; + +export type Progress[]> = { + history: History; + isDone: () => boolean; + result?: Result; +}; + +export type PrepareAttempt< + // The agent's environment. + State extends Message[], + Action extends Message, + Observation extends Message, + // The user's expectation. +> = (args: { + objective: Objective; + context: Context; + options?: { + taskLogger?: Logger; + printLogger?: Logger; + }; +}) => [PREP, Progress]; diff --git a/packages/kernel-agents/src/types/capability.ts b/packages/kernel-agents/src/types/capability.ts new file mode 100644 index 000000000..c16c41d59 --- /dev/null +++ b/packages/kernel-agents/src/types/capability.ts @@ -0,0 +1,27 @@ +import type { JsonSchema } from './json-schema.ts'; + +export type Capability, Return = null> = ( + args: Args, +) => Promise; + +export type CapabilitySchema = { + description: string; + args: Record; + returns?: JsonSchema; +}; + +export type ExtractRecordKeys = + Rec extends Record ? Key : never; + +export type CapabilitySpec< + Args extends Record = Record, + Return = void, +> = { + func: Capability; + schema: CapabilitySchema>; +}; + +export type CapabilityRecord = Record< + Keys, + CapabilitySpec +>; diff --git a/packages/kernel-agents/src/types/json-schema.ts b/packages/kernel-agents/src/types/json-schema.ts new file mode 100644 index 000000000..5c3574cc3 --- /dev/null +++ b/packages/kernel-agents/src/types/json-schema.ts @@ -0,0 +1,25 @@ +export type JsonSchema = + | PrimitiveJsonSchema + | ArrayJsonSchema + | ObjectJsonSchemaProperty; + +type PrimitiveJsonSchema = { + type: 'string' | 'number' | 'boolean'; + description?: string; +}; + +type ArrayJsonSchema = { + type: 'array'; + description?: string; + item: JsonSchema; +}; + +type ObjectJsonSchemaProperty = { + type: 'object'; + description?: string; + properties: { + [key: string]: JsonSchema; + }; + required?: string[]; + additionalProperties?: boolean; +}; diff --git a/packages/kernel-agents/src/types/messages.ts b/packages/kernel-agents/src/types/messages.ts new file mode 100644 index 000000000..f4ab24231 --- /dev/null +++ b/packages/kernel-agents/src/types/messages.ts @@ -0,0 +1,20 @@ +export type MessageTypeBase = string; + +export abstract class Message< + Type extends MessageTypeBase, + Body extends Record = Record, +> { + messageType: Type; + + messageBody: Body; + + constructor(messageType: Type, messageBody: Body) { + this.messageType = messageType; + this.messageBody = messageBody; + } +} + +export type Transcript< + MessageTypes extends MessageTypeBase, + Interface = unknown, +> = (Message & Interface)[]; diff --git a/packages/kernel-agents/src/types/task.ts b/packages/kernel-agents/src/types/task.ts new file mode 100644 index 000000000..ae99d5867 --- /dev/null +++ b/packages/kernel-agents/src/types/task.ts @@ -0,0 +1,75 @@ +import type { Logger } from '@metamask/logger'; + +import type { CapabilityRecord } from './capability.ts'; +import type { Message, MessageTypeBase, Transcript } from './messages.ts'; + +export type Task = { + id: string; + objective: Objective; + context: Context; + attempts: Attempt[]; +}; + +/** + * A specification of what a user wants from an agent. + */ +export type Objective = { + intent: string; + // For wonky cases, this criterion can be satisfied by assignment. + judgment: (result: unknown) => result is Result; +}; + +/** + * A specification of the context in which an agent is operating. + */ +export type Context = { + capabilities: CapabilityRecord; + knowledge?: Record; +}; + +/** + * An experience of an agent fulfilling an objective in a particular context. + */ +export type Experience = { + objective: Objective; + context: Context; + history: Message[]; +} & ( + | { + result?: unknown; + error?: never; + } + | { + result?: never; + error?: Error; + } +); + +/** + * An attempt by an agent to fulfill an objective in a particular context. + * Organized for the agent's learning process. + */ +export type Attempt< + Result, + Action extends string, + Observation extends string, +> = { + history: Transcript; +} & ( + | { + result?: Result; + error?: never; + } + | { + result?: never; + error?: Error; + } +); + +export type TaskArgs = { + logger?: Logger; + seed?: number; + invocationBudget?: number; + capabilities?: CapabilityRecord; + nAttempts?: number; +}; diff --git a/packages/kernel-agents/src/utils.test.ts b/packages/kernel-agents/src/utils.test.ts new file mode 100644 index 000000000..882a475e3 --- /dev/null +++ b/packages/kernel-agents/src/utils.test.ts @@ -0,0 +1,102 @@ +import '@ocap/repo-tools/test-utils/mock-endoify'; + +import { describe, it, expect, vi } from 'vitest'; + +import { makeTestStream } from './test-utils.ts'; +import { + gatherStreamingResponse, + ifDefined, + withAbort, + withRetries, +} from './utils.ts'; + +describe('ifDefined', () => { + it('removes undefined values', () => { + expect(ifDefined({ a: 1, b: undefined, c: 3 })).toStrictEqual({ + a: 1, + c: 3, + }); + }); +}); + +describe('withAbort', () => { + it('calls abort on success', async () => { + const abort = vi.fn(); + const func = vi.fn(); + await withAbort(abort, func); + expect(abort).toHaveBeenCalled(); + }); + + it('calls abort on error', async () => { + const abort = vi.fn(); + const func = vi.fn().mockRejectedValue(new Error('test')); + await expect(withAbort(abort, func)).rejects.toThrow('test'); + expect(abort).toHaveBeenCalled(); + }); +}); + +const makeTestParser = (chunks: string[], finishOn: number) => { + let count = 0; + return (_: string) => { + count += 1; + if (count >= finishOn) { + return JSON.parse(chunks.slice(0, count).join('')); + } + return null; + }; +}; + +describe('gatherStreamingResponse', () => { + const asResponse = (response: string) => ({ response }); + + const prepareStreamAndParse = (chunks: string[], finishOn: number) => { + const { stream } = makeTestStream(chunks, asResponse); + const parse = makeTestParser(chunks, finishOn); + return { stream, parse }; + }; + + it('gathers complete response from single chunk', async () => { + const chunks = ['{"key": "value"}']; + const { stream, parse } = prepareStreamAndParse(chunks, 1); + const result = await gatherStreamingResponse({ stream, parse }); + expect(result).toStrictEqual({ key: 'value' }); + }); + + it('gathers response from multiple chunks', async () => { + const chunks = ['{"key": "val', 'ue", "content": 42}']; + const { stream, parse } = prepareStreamAndParse(chunks, 2); + const result = await gatherStreamingResponse({ stream, parse }); + expect(result).toStrictEqual({ key: 'value', content: 42 }); + }); + + it('throws error when stream ends without parse event', async () => { + const chunks = ['incomplete json']; + const { stream, parse } = prepareStreamAndParse(chunks, 2); + await expect(gatherStreamingResponse({ stream, parse })).rejects.toThrow( + 'Stream ended without a parse event', + ); + }); +}); + +describe('withRetries', () => { + it('retries a function', async () => { + const func = vi + .fn() + .mockRejectedValueOnce(new Error('test')) + .mockResolvedValueOnce('result'); + const result = await withRetries(func, 2); + expect(result).toBe('result'); + }); + + it('throws an error if the function fails after all retries', async () => { + const func = vi.fn().mockRejectedValue(new Error('test')); + await expect(async () => withRetries(func, 2)).rejects.toThrow('test'); + }); + + it('throws an error if the function throws an error that is not retryable', async () => { + const func = vi.fn().mockRejectedValue(new Error('test')); + await expect(async () => withRetries(func, 2, () => false)).rejects.toThrow( + 'test', + ); + }); +}); diff --git a/packages/kernel-agents/src/utils.ts b/packages/kernel-agents/src/utils.ts new file mode 100644 index 000000000..6112bfe2b --- /dev/null +++ b/packages/kernel-agents/src/utils.ts @@ -0,0 +1,104 @@ +import type { Logger } from '@metamask/logger'; + +import type { SampleCollector } from './types.ts'; + +/** + * Return a new object with the undefined values removed. + * + * @param record - The record to filter. + * @returns The new object with the undefined values removed. + */ +// eslint-disable-next-line @typescript-eslint/explicit-function-return-type +export const ifDefined = (record: Record) => + Object.fromEntries( + Object.entries(record).filter(([_, value]) => value !== undefined), + ); + +/** + * Await a promise, and call the abort callback when done or on error. + * + * @param abort - The function to call to abort the operation. + * @param func - The function to call to perform the operation. + * @returns The result of the operation. + */ +export const withAbort = async ( + abort: () => Promise, + func: () => Promise, +): Promise => { + try { + return await func(); + } finally { + await abort(); + } +}; + +/** + * Gather a streaming response from a stream of chunks. + * + * @param args - The arguments to gather the streaming response. + * @param args.stream - The stream to gather from. + * @param args.parse - The incremental parser to use to parse the response. + * @param args.logger - The logger to use for the gather. + * @returns The parsed response. + */ +export const gatherStreamingResponse = async ({ + stream, + parse, + logger, +}: { + stream: AsyncIterable<{ response: string }>; + parse: SampleCollector; + logger?: Logger; +}): Promise => { + for await (const chunk of stream) { + const delta = (chunk as { response: string }).response; + logger?.info('delta:', delta); + const parsed = parse(delta); + if (parsed !== null) { + logger?.info('parsed:', parsed); + return parsed; + } + } + throw new Error('Stream ended without a parse event'); +}; + +/** + * Retry a function up to a given number of times. + * + * @param func - The function to retry. + * @param maxRetries - The maximum number of times to retry. + * @param isRetryable - A function that determines if an error should be retried. Defaults to always retrying. + * @returns The result of the function. + * @throws An error if the function fails after all retries. + * @throws An error if the function throws an error that is not retryable. + */ +export const withRetries = async ( + func: () => Promise<[Action, Observation | null]>, + maxRetries: number = 0, + isRetryable: (error: unknown) => boolean = () => true, +): Promise<[Action, Observation | null]> => { + if (maxRetries < 1) { + return await func(); + } + const errors: unknown[] = []; + for (let retry = 0; retry < maxRetries; retry++) { + try { + return await func(); + } catch (error) { + if (!isRetryable(error)) { + throw error; + } + errors.push(error); + } + } + throw new Error( + [ + `Exceeded retry budget of ${maxRetries}.`, + ...errors.map((error, index) => { + const message = error instanceof Error ? error.message : String(error); + return ` ${index + 1}: ${message}`; + }), + ].join('\n'), + { cause: errors }, + ); +}; diff --git a/packages/kernel-agents/test/constants.ts b/packages/kernel-agents/test/constants.ts index 17d108edf..bfb5ac9d5 100644 --- a/packages/kernel-agents/test/constants.ts +++ b/packages/kernel-agents/test/constants.ts @@ -2,6 +2,7 @@ * Test constants for E2E tests */ export const DEFAULT_MODEL = 'llama3.1:latest'; +export const TEST_MODELS = ['llama3.1:latest', 'gpt-oss:20b']; /** * Ollama API endpoints diff --git a/packages/kernel-agents/test/e2e/agent.test.ts b/packages/kernel-agents/test/e2e/agent.test.ts deleted file mode 100644 index d6f7483ec..000000000 --- a/packages/kernel-agents/test/e2e/agent.test.ts +++ /dev/null @@ -1,115 +0,0 @@ -import '@ocap/repo-tools/test-utils/mock-endoify'; - -import { Logger } from '@metamask/logger'; -import { OllamaNodejsService } from '@ocap/kernel-language-model-service/ollama/nodejs'; -import { fetchMock } from '@ocap/repo-tools/test-utils/fetch-mock'; -import { - afterAll, - beforeAll, - beforeEach, - describe, - expect, - it, - vi, -} from 'vitest'; - -import { makeAgent } from '../../src/agent.ts'; -import { count, add, multiply } from '../../src/example-capabilities.ts'; -import { DEFAULT_MODEL } from '../constants.ts'; - -/** - * Generate a random letter. - * - * @returns a random letter. - */ -function randomLetter(): string { - return String.fromCharCode(Math.floor(Math.random() * 26) + 97); -} - -const logger = new Logger('test'); - -describe('agent', () => { - beforeAll(() => { - fetchMock.disableMocks(); - }); - - afterAll(() => { - fetchMock.enableMocks(); - }); - - let llmService: OllamaNodejsService; - beforeEach(() => { - llmService = new OllamaNodejsService({ endowments: { fetch } }); - }); - - it( - 'should create an agent and process a request', - { - retry: 3, - timeout: 5_000, - }, - async () => { - const llm = await llmService.makeInstance({ model: DEFAULT_MODEL }); - const agent = makeAgent({ llm, capabilities: {}, logger }); - expect(agent).toBeDefined(); - - const letter = randomLetter().toUpperCase(); - const result = (await agent.task( - `Name an animal that starts with the letter "${letter}"`, - )) as string; - expect(result).toBeDefined(); - expect(result.length).toBeGreaterThan(0); - expect(result).toMatch( - new RegExp(`(${letter}|${letter.toLowerCase()})+`, 'u'), - ); - }, - ); - - it( - 'should create an agent that uses tools', - { - retry: 3, - timeout: 5_000, - }, - async () => { - const llm = await llmService.makeInstance({ model: DEFAULT_MODEL }); - const word = 'xf9147qsdhdkj'; - const countSpy = vi.spyOn(count, 'func'); - const agent = makeAgent({ llm, capabilities: { count }, logger }); - expect(agent).toBeDefined(); - const result = await agent.task( - `What is the length of the word "${word}"?`, - ); - expect(result).toBeDefined(); - expect(countSpy).toHaveBeenCalled(); - expect(result).toMatch(word.length.toString()); - }, - ); - - it( - 'performs multi-step calculations', - { - retry: 3, - timeout: 10_000, - }, - async () => { - const llm = await llmService.makeInstance({ model: DEFAULT_MODEL }); - const agent = makeAgent({ - llm, - capabilities: { - count, - add, - multiply, - }, - logger, - }); - expect(agent).toBeDefined(); - const [length, width, height] = [11, 47, 63]; - const result = await agent.task( - `A box with length ${length}, width ${width}, and height ${height} have volume V. How many digits are in the numerical value of V?`, - ); - expect(result).toBeDefined(); - expect(result).toMatch(`${length * width * height}`.length.toString()); - }, - ); -}); diff --git a/packages/kernel-agents/test/e2e/agents.test.ts b/packages/kernel-agents/test/e2e/agents.test.ts new file mode 100644 index 000000000..f0f78d066 --- /dev/null +++ b/packages/kernel-agents/test/e2e/agents.test.ts @@ -0,0 +1,204 @@ +import '@ocap/repo-tools/test-utils/mock-endoify'; + +import { consoleTransport, Logger } from '@metamask/logger'; +import { OllamaNodejsService } from '@ocap/kernel-language-model-service/ollama/nodejs'; +import { fetchMock } from '@ocap/repo-tools/test-utils/fetch-mock'; +import { + afterAll, + afterEach, + beforeAll, + beforeEach, + describe, + expect, + it, + vi, +} from 'vitest'; + +import type { MakeAgentArgs } from '../../src/agent.ts'; +import { getMoonPhase } from '../../src/capabilities/examples.ts'; +import { count, add, multiply } from '../../src/capabilities/math.ts'; +import { makeJsonAgent } from '../../src/strategies/json-agent.ts'; +import { makeReplAgent } from '../../src/strategies/repl-agent.ts'; +import type { Agent } from '../../src/types.ts'; +import { DEFAULT_MODEL } from '../constants.ts'; +import { filterTransports, randomLetter } from '../utils.ts'; + +const logger = new Logger({ + tags: ['test'], + transports: [filterTransports(consoleTransport)], +}); + +const makeJsonAgentWithMathCapabilities = (args: MakeAgentArgs) => + makeJsonAgent({ + ...args, + capabilities: { count, add, multiply, ...args.capabilities }, + }); + +describe.each([ + ['json', makeJsonAgentWithMathCapabilities], + ['repl', makeReplAgent], +])( + '%s agent', + (strategy: string, makeAgent: (args: MakeAgentArgs) => Agent) => { + let result: unknown; + const retry = 2; + const timeout = 60_000; + + const printLogger = new Logger({ + tags: [strategy], + transports: [ + ({ message, data, level }) => console[level](message, ...(data ?? [])), + ], + }); + + const catchErrorAsResult = < + Func extends (...args: unknown[]) => Promise, + >( + func: Func, + ): Func => + (async (...args: unknown[]) => { + try { + return await func(...args); + } catch (error) { + result = error; + throw error; + } + }) as Func; + + beforeAll(() => { + fetchMock.disableMocks(); + }); + + afterAll(() => { + fetchMock.enableMocks(); + }); + + let languageModelService: OllamaNodejsService; + beforeEach(() => { + result = undefined; + languageModelService = new OllamaNodejsService({ endowments: { fetch } }); + printLogger.log(`\n<== New ${strategy.toUpperCase()} ===`); + }); + + afterEach(() => { + printLogger.log('=== ======== ==='); + printLogger.log(`~ ${result as string}`); + printLogger.log(`=== End ${strategy.toUpperCase()} ==>`); + }); + + it( + 'processes a semantic request', + { retry, timeout }, + catchErrorAsResult(async () => { + const languageModel = await languageModelService.makeInstance({ + model: DEFAULT_MODEL, + }); + const agent = makeAgent({ languageModel, capabilities: {}, logger }); + expect(agent).toBeDefined(); + + const categories = ['animal', 'vegetable', 'mineral'] as const; + const category = + categories[Math.floor(Math.random() * categories.length)]; + + const letter = randomLetter().toUpperCase(); + const query = `Name a kind of ${category} that starts with the letter "${letter}"`; + const containsLetter = (content: string): boolean => + content.includes(letter) || content.includes(letter.toLowerCase()); + type CategoryElement = string; + const judgment = (content: unknown): content is CategoryElement => + // In a multi-agent system, we might another LLM to judge the result. + // For now, we'll just check the type and length. + typeof content === 'string' && + content.length > 0 && + containsLetter(content); + result = await agent.task(query, judgment, { logger: printLogger }); + expect(result).toBeDefined(); + expect((result as string).length).toBeGreaterThan(0); + expect(containsLetter(result as string)).toBe(true); + }), + ); + + it( + 'uses tools', + { retry, timeout }, + catchErrorAsResult(async () => { + const languageModel = await languageModelService.makeInstance({ + model: DEFAULT_MODEL, + }); + const getMoonPhaseSpy = vi.spyOn(getMoonPhase, 'func'); + const agent = makeAgent({ + languageModel, + capabilities: { getMoonPhase }, + logger, + }); + expect(agent).toBeDefined(); + const query = `Is it a full moon tonight?`; + result = await agent.task(query, undefined, { logger: printLogger }); + + expect(result).toBeDefined(); + expect(getMoonPhaseSpy).toHaveBeenCalled(); + }), + ); + + it( + 'performs multi-step calculations', + { retry, timeout }, + catchErrorAsResult(async () => { + const languageModel = await languageModelService.makeInstance({ + model: DEFAULT_MODEL, + }); + const capabilities = {}; + const agent = makeAgent({ languageModel, capabilities, logger }); + expect(agent).toBeDefined(); + const [length, width, height] = [11, 47, 63]; + const query = `A box with length ${length}, width ${width}, and height ${height} have volume V. How many digits are in the numerical value of V?`; + result = await agent.task(query, undefined, { logger: printLogger }); + expect(result).toBeDefined(); + expect(result).includes(`${length * width * height}`.length.toString()); + }), + ); + + it( + 'writes complex code to solve a problem', + // Caveat: We don't expect the solution to be correct. + { retry, timeout: 120_000 }, + catchErrorAsResult(async () => { + const languageModel = await languageModelService.makeInstance({ + model: DEFAULT_MODEL, + }); + const capabilities = {}; + const agent = makeAgent({ languageModel, capabilities, logger }); + expect(agent).toBeDefined(); + const query = [ + 'Let S2(42) be the set of all sets of positive two digit numbers that sum to 42.', + 'For example, the sets { 42 }, { 19, 24 }, and { 10, 12, 20 } are elements of S2(42),', + 'but { 10, 12 } is not because 10 + 12 does not equal 42,', + 'and { 2, 40 } is not because 2 is not a two digit number.', + 'What is |S2(42)|?', + ].join('\n'); + result = await agent.task(query, undefined, { + logger: printLogger, + invocationBudget: 42, + }); + expect(result).toBeDefined(); + }), + ); + + it.skipIf(strategy === 'json')( + 'imports capabilities', + { retry, timeout }, + // TODO: This functionality is not yet implemented. + catchErrorAsResult(async () => { + const languageModel = await languageModelService.makeInstance({ + model: DEFAULT_MODEL, + }); + const capabilities = {}; + const agent = makeAgent({ languageModel, capabilities, logger }); + expect(agent).toBeDefined(); + const query = `What is the current moon phase? You may want to import { getMoonPhase } from "@ocap/abilities"`; + result = await agent.task(query, undefined, { logger: printLogger }); + expect(result).toBeDefined(); + }), + ); + }, +); diff --git a/packages/kernel-agents/test/utils.ts b/packages/kernel-agents/test/utils.ts index 8870d4ef9..684077b57 100644 --- a/packages/kernel-agents/test/utils.ts +++ b/packages/kernel-agents/test/utils.ts @@ -1,13 +1,34 @@ +import type { LogEntry } from '@metamask/logger'; + +// extract ignored logger tags from environment variable + +const ignoreTags = + // eslint-disable-next-line n/no-process-env + process?.env?.LOGGER_IGNORE?.split(',')?.map((tag) => tag.trim()) ?? []; + /** - * Construct a bundle path URL from a bundle name. + * Filter a logger transport to ignore command line specified ignore tags. * - * @param bundleName - The name of the bundle. + * @param transports - The transports to filter. + * @returns A transport that filters out the ignore tags. + */ +export const filterTransports = ( + ...transports: ((entry: LogEntry) => void)[] +): ((entry: LogEntry) => void) => + ignoreTags.includes('all') + ? () => undefined + : (entry) => { + if (ignoreTags.some((tag) => entry.tags.includes(tag))) { + return; + } + transports.forEach((transport) => transport(entry)); + }; + +/** + * Generate a random letter. * - * @returns a path string for the named bundle. + * @returns a random letter. */ -export function getBundleSpec(bundleName: string): string { - return new URL( - `../kernel-test/src/vats/${bundleName}.bundle`, - import.meta.url, - ).toString(); +export function randomLetter(): string { + return String.fromCharCode(Math.floor(Math.random() * 26) + 97); } diff --git a/packages/kernel-agents/tsconfig.json b/packages/kernel-agents/tsconfig.json index 08d6aabb6..d36996f86 100644 --- a/packages/kernel-agents/tsconfig.json +++ b/packages/kernel-agents/tsconfig.json @@ -3,7 +3,7 @@ "compilerOptions": { "baseUrl": "./", "lib": ["ES2022"], - "types": ["vitest", "node"] + "types": ["vitest", "node", "ses"] }, "references": [ { "path": "../kernel-language-model-service" }, diff --git a/packages/kernel-errors/src/constants.ts b/packages/kernel-errors/src/constants.ts index a27d5d990..2c98f6b43 100644 --- a/packages/kernel-errors/src/constants.ts +++ b/packages/kernel-errors/src/constants.ts @@ -31,6 +31,8 @@ export const ErrorCode = { VatDeleted: 'VAT_DELETED', VatNotFound: 'VAT_NOT_FOUND', SubclusterNotFound: 'SUBCLUSTER_NOT_FOUND', + SampleGenerationError: 'SAMPLE_GENERATION_ERROR', + InternalError: 'INTERNAL_ERROR', } as const; export type ErrorCode = (typeof ErrorCode)[keyof typeof ErrorCode]; diff --git a/packages/kernel-errors/src/errors/EvaluatorError.test.ts b/packages/kernel-errors/src/errors/EvaluatorError.test.ts new file mode 100644 index 000000000..1ee154696 --- /dev/null +++ b/packages/kernel-errors/src/errors/EvaluatorError.test.ts @@ -0,0 +1,107 @@ +import { describe, it, expect } from 'vitest'; + +import { EvaluatorError } from './EvaluatorError.ts'; +import { ErrorCode, ErrorSentinel } from '../constants.ts'; +import { unmarshalErrorOptions } from '../marshal/unmarshalError.ts'; +import type { MarshaledOcapError } from '../types.ts'; + +describe('EvaluatorError', () => { + const mockMessage = 'REPL evaluation failed'; + const mockCode = 'const x = 1;'; + const mockCause = new Error('Internal: $return threw an error'); + + it('creates an EvaluatorError with the correct properties', () => { + const error = new EvaluatorError(mockMessage, mockCode, mockCause); + expect(error).toBeInstanceOf(EvaluatorError); + expect(error).toBeInstanceOf(Error); + expect(error.code).toBe(ErrorCode.InternalError); + expect(error.message).toBe(mockMessage); + expect(error.data).toStrictEqual({ code: mockCode }); + expect(error.cause).toBe(mockCause); + }); + + it('creates an EvaluatorError with optional error options', () => { + const mockStack = 'custom stack trace'; + const error = new EvaluatorError(mockMessage, mockCode, mockCause, { + stack: mockStack, + }); + expect(error.stack).toBe(mockStack); + expect(error.data).toStrictEqual({ code: mockCode }); + expect(error.cause).toBe(mockCause); + }); + + it('unmarshals a valid marshaled EvaluatorError', () => { + const marshaledError: MarshaledOcapError = { + [ErrorSentinel]: true, + message: mockMessage, + stack: 'customStack', + code: ErrorCode.InternalError, + data: { code: mockCode }, + cause: { + [ErrorSentinel]: true, + message: 'Internal: $return threw an error', + stack: 'causeStack', + }, + }; + + const unmarshaledError = EvaluatorError.unmarshal( + marshaledError, + unmarshalErrorOptions, + ); + expect(unmarshaledError).toBeInstanceOf(EvaluatorError); + expect(unmarshaledError.code).toBe(ErrorCode.InternalError); + expect(unmarshaledError.message).toBe(mockMessage); + expect(unmarshaledError.stack).toBe('customStack'); + expect(unmarshaledError.data).toStrictEqual({ code: mockCode }); + expect(unmarshaledError.cause).toBeInstanceOf(Error); + expect((unmarshaledError.cause as Error).message).toBe( + 'Internal: $return threw an error', + ); + }); + + it('unmarshals an EvaluatorError without a cause', () => { + const marshaledError: MarshaledOcapError = { + [ErrorSentinel]: true, + message: mockMessage, + code: ErrorCode.InternalError, + data: { code: mockCode }, + }; + + const unmarshaledError = EvaluatorError.unmarshal( + marshaledError, + unmarshalErrorOptions, + ); + expect(unmarshaledError).toBeInstanceOf(EvaluatorError); + expect(unmarshaledError.data).toStrictEqual({ code: mockCode }); + expect(unmarshaledError.cause).toBeInstanceOf(Error); + expect((unmarshaledError.cause as Error).message).toBe('Unknown cause'); + }); + + it('throws an error when an invalid data structure is unmarshaled', () => { + const marshaledError: MarshaledOcapError = { + [ErrorSentinel]: true, + message: mockMessage, + code: ErrorCode.InternalError, + data: 'invalid data', + stack: 'stack trace', + }; + + expect(() => + EvaluatorError.unmarshal(marshaledError, unmarshalErrorOptions), + ).toThrow(/At path: data --/u); + }); + + it('throws an error when an invalid code is unmarshaled', () => { + const marshaledError: MarshaledOcapError = { + [ErrorSentinel]: true, + message: mockMessage, + code: ErrorCode.VatNotFound, + data: { code: mockCode }, + stack: 'stack trace', + }; + + expect(() => + EvaluatorError.unmarshal(marshaledError, unmarshalErrorOptions), + ).toThrow(/At path: code --/u); + }); +}); diff --git a/packages/kernel-errors/src/errors/EvaluatorError.ts b/packages/kernel-errors/src/errors/EvaluatorError.ts new file mode 100644 index 000000000..da9e04f6e --- /dev/null +++ b/packages/kernel-errors/src/errors/EvaluatorError.ts @@ -0,0 +1,79 @@ +import { + assert, + lazy, + literal, + object, + optional, + string, + union, +} from '@metamask/superstruct'; + +import { BaseError } from '../BaseError.ts'; +import { + marshaledErrorSchema, + ErrorCode, + MarshaledErrorStruct, +} from '../constants.ts'; +import type { ErrorOptionsWithStack, MarshaledOcapError } from '../types.ts'; + +/** + * An error indicating a violation of evaluator infrastructure expectations. + * These errors indicate internal failures that should exit the attempt, + * such as $return, $catch, or $capture throwing inside the compartment. + * + * Note: This error should be impossible to throw in normal operation, + * even if the compiler cannot detect this. + */ +export class EvaluatorError extends BaseError { + constructor( + message: string, + code: string, + cause: Error, + options?: ErrorOptionsWithStack, + ) { + super(ErrorCode.InternalError, message, { + ...options, + cause, + data: { code }, + }); + harden(this); + } + + /** + * A superstruct struct for validating marshaled {@link EvaluatorError} instances. + */ + public static struct = object({ + ...marshaledErrorSchema, + code: literal(ErrorCode.InternalError), + data: object({ + code: string(), + }), + cause: optional(union([string(), lazy(() => MarshaledErrorStruct)])), + }); + + /** + * Unmarshals a {@link MarshaledError} into an {@link EvaluatorError}. + * + * @param marshaledError - The marshaled error to unmarshal. + * @param unmarshalErrorOptions - The function to unmarshal the error options. + * @returns The unmarshaled error. + */ + public static unmarshal( + marshaledError: MarshaledOcapError, + unmarshalErrorOptions: ( + marshaledError: MarshaledOcapError, + ) => ErrorOptionsWithStack, + ): EvaluatorError { + assert(marshaledError, this.struct); + const cause = marshaledError.cause + ? (unmarshalErrorOptions(marshaledError).cause as Error) + : new Error('Unknown cause'); + return new EvaluatorError( + marshaledError.message, + marshaledError.data.code, + cause, + unmarshalErrorOptions(marshaledError), + ); + } +} +harden(EvaluatorError); diff --git a/packages/kernel-errors/src/errors/SampleGenerationError.test.ts b/packages/kernel-errors/src/errors/SampleGenerationError.test.ts new file mode 100644 index 000000000..5f30e59e4 --- /dev/null +++ b/packages/kernel-errors/src/errors/SampleGenerationError.test.ts @@ -0,0 +1,104 @@ +import { describe, it, expect } from 'vitest'; + +import { SampleGenerationError } from './SampleGenerationError.ts'; +import { ErrorCode, ErrorSentinel } from '../constants.ts'; +import { unmarshalErrorOptions } from '../marshal/unmarshalError.ts'; +import type { MarshaledOcapError } from '../types.ts'; + +describe('SampleGenerationError', () => { + const mockSample = 'const x = invalid syntax'; + const mockCause = new SyntaxError('Unexpected token'); + + it('creates a SampleGenerationError with the correct properties', () => { + const error = new SampleGenerationError(mockSample, mockCause); + expect(error).toBeInstanceOf(SampleGenerationError); + expect(error).toBeInstanceOf(Error); + expect(error.code).toBe(ErrorCode.SampleGenerationError); + expect(error.message).toBe('LLM generated invalid response.'); + expect(error.data).toStrictEqual({ sample: mockSample }); + expect(error.cause).toBe(mockCause); + }); + + it('creates a SampleGenerationError with optional error options', () => { + const mockStack = 'custom stack trace'; + const error = new SampleGenerationError(mockSample, mockCause, { + stack: mockStack, + }); + expect(error.stack).toBe(mockStack); + expect(error.data).toStrictEqual({ sample: mockSample }); + expect(error.cause).toBe(mockCause); + }); + + it('unmarshals a valid marshaled SampleGenerationError', () => { + const marshaledError: MarshaledOcapError = { + [ErrorSentinel]: true, + message: 'LLM generated invalid response.', + stack: 'customStack', + code: ErrorCode.SampleGenerationError, + data: { sample: mockSample }, + cause: { + [ErrorSentinel]: true, + message: 'Unexpected token', + stack: 'syntaxErrorStack', + }, + }; + + const unmarshaledError = SampleGenerationError.unmarshal( + marshaledError, + unmarshalErrorOptions, + ); + expect(unmarshaledError).toBeInstanceOf(SampleGenerationError); + expect(unmarshaledError.code).toBe(ErrorCode.SampleGenerationError); + expect(unmarshaledError.message).toBe('LLM generated invalid response.'); + expect(unmarshaledError.stack).toBe('customStack'); + expect(unmarshaledError.data).toStrictEqual({ sample: mockSample }); + expect(unmarshaledError.cause).toBeInstanceOf(Error); + expect((unmarshaledError.cause as Error).message).toBe('Unexpected token'); + }); + + it('unmarshals a SampleGenerationError without a cause', () => { + const marshaledError: MarshaledOcapError = { + [ErrorSentinel]: true, + message: 'LLM generated invalid response.', + code: ErrorCode.SampleGenerationError, + data: { sample: mockSample }, + }; + + const unmarshaledError = SampleGenerationError.unmarshal( + marshaledError, + unmarshalErrorOptions, + ); + expect(unmarshaledError).toBeInstanceOf(SampleGenerationError); + expect(unmarshaledError.data).toStrictEqual({ sample: mockSample }); + expect(unmarshaledError.cause).toBeInstanceOf(Error); + expect((unmarshaledError.cause as Error).message).toBe('Unknown cause'); + }); + + it('throws an error when an invalid data structure is unmarshaled', () => { + const marshaledError: MarshaledOcapError = { + [ErrorSentinel]: true, + message: 'LLM generated invalid response.', + code: ErrorCode.SampleGenerationError, + data: 'invalid data', + stack: 'stack trace', + }; + + expect(() => + SampleGenerationError.unmarshal(marshaledError, unmarshalErrorOptions), + ).toThrow(/At path: data --/u); + }); + + it('throws an error when an invalid code is unmarshaled', () => { + const marshaledError: MarshaledOcapError = { + [ErrorSentinel]: true, + message: 'LLM generated invalid response.', + code: ErrorCode.VatNotFound, + data: { sample: mockSample }, + stack: 'stack trace', + }; + + expect(() => + SampleGenerationError.unmarshal(marshaledError, unmarshalErrorOptions), + ).toThrow(/At path: code --/u); + }); +}); diff --git a/packages/kernel-errors/src/errors/SampleGenerationError.ts b/packages/kernel-errors/src/errors/SampleGenerationError.ts new file mode 100644 index 000000000..1191bd9d0 --- /dev/null +++ b/packages/kernel-errors/src/errors/SampleGenerationError.ts @@ -0,0 +1,69 @@ +import { + assert, + lazy, + literal, + object, + optional, + string, + union, +} from '@metamask/superstruct'; + +import { BaseError } from '../BaseError.ts'; +import { + marshaledErrorSchema, + ErrorCode, + MarshaledErrorStruct, +} from '../constants.ts'; +import type { ErrorOptionsWithStack, MarshaledOcapError } from '../types.ts'; + +/** + * An error indicating that the LLM generated invalid response. + * This error should trigger resampling from the LLM. + */ +export class SampleGenerationError extends BaseError { + constructor(sample: string, cause: Error, options?: ErrorOptionsWithStack) { + super(ErrorCode.SampleGenerationError, 'LLM generated invalid response.', { + ...options, + cause, + data: { sample }, + }); + harden(this); + } + + /** + * A superstruct struct for validating marshaled {@link SampleGenerationError} instances. + */ + public static struct = object({ + ...marshaledErrorSchema, + code: literal(ErrorCode.SampleGenerationError), + data: object({ + sample: string(), + }), + cause: optional(union([string(), lazy(() => MarshaledErrorStruct)])), + }); + + /** + * Unmarshals a {@link MarshaledError} into a {@link SampleGenerationError}. + * + * @param marshaledError - The marshaled error to unmarshal. + * @param unmarshalErrorOptions - The function to unmarshal the error options. + * @returns The unmarshaled error. + */ + public static unmarshal( + marshaledError: MarshaledOcapError, + unmarshalErrorOptions: ( + marshaledError: MarshaledOcapError, + ) => ErrorOptionsWithStack, + ): SampleGenerationError { + assert(marshaledError, this.struct); + const cause = marshaledError.cause + ? (unmarshalErrorOptions(marshaledError).cause as Error) + : new Error('Unknown cause'); + return new SampleGenerationError( + marshaledError.data.sample, + cause, + unmarshalErrorOptions(marshaledError), + ); + } +} +harden(SampleGenerationError); diff --git a/packages/kernel-errors/src/errors/index.ts b/packages/kernel-errors/src/errors/index.ts index a127bf85a..adc182cb3 100644 --- a/packages/kernel-errors/src/errors/index.ts +++ b/packages/kernel-errors/src/errors/index.ts @@ -1,5 +1,7 @@ import { AbortError } from './AbortError.ts'; import { DuplicateEndowmentError } from './DuplicateEndowmentError.ts'; +import { EvaluatorError } from './EvaluatorError.ts'; +import { SampleGenerationError } from './SampleGenerationError.ts'; import { StreamReadError } from './StreamReadError.ts'; import { VatAlreadyExistsError } from './VatAlreadyExistsError.ts'; import { VatDeletedError } from './VatDeletedError.ts'; @@ -15,4 +17,6 @@ export const errorClasses = { [ErrorCode.VatDeleted]: VatDeletedError, [ErrorCode.VatNotFound]: VatNotFoundError, [ErrorCode.SubclusterNotFound]: SubclusterNotFoundError, + [ErrorCode.SampleGenerationError]: SampleGenerationError, + [ErrorCode.InternalError]: EvaluatorError, } as const; diff --git a/packages/kernel-errors/src/index.test.ts b/packages/kernel-errors/src/index.test.ts index 3facc7183..0e55551ea 100644 --- a/packages/kernel-errors/src/index.test.ts +++ b/packages/kernel-errors/src/index.test.ts @@ -10,8 +10,10 @@ describe('index', () => { 'ErrorCode', 'ErrorSentinel', 'ErrorStruct', + 'EvaluatorError', 'MarshaledErrorStruct', 'MarshaledOcapErrorStruct', + 'SampleGenerationError', 'StreamReadError', 'SubclusterNotFoundError', 'VatAlreadyExistsError', diff --git a/packages/kernel-errors/src/index.ts b/packages/kernel-errors/src/index.ts index 7fe485227..43981ae55 100644 --- a/packages/kernel-errors/src/index.ts +++ b/packages/kernel-errors/src/index.ts @@ -1,5 +1,7 @@ export type { OcapError, MarshaledError } from './types.ts'; export { DuplicateEndowmentError } from './errors/DuplicateEndowmentError.ts'; +export { EvaluatorError } from './errors/EvaluatorError.ts'; +export { SampleGenerationError } from './errors/SampleGenerationError.ts'; export { VatAlreadyExistsError } from './errors/VatAlreadyExistsError.ts'; export { VatDeletedError } from './errors/VatDeletedError.ts'; export { VatNotFoundError } from './errors/VatNotFoundError.ts'; diff --git a/packages/repo-tools/src/test-utils/env/mock-endoify.ts b/packages/repo-tools/src/test-utils/env/mock-endoify.ts index b0953da9c..13a7f44fb 100644 --- a/packages/repo-tools/src/test-utils/env/mock-endoify.ts +++ b/packages/repo-tools/src/test-utils/env/mock-endoify.ts @@ -27,7 +27,7 @@ globalThis.assert = assertFn as unknown as typeof assert; globalThis.HandledPromise = Promise; // @ts-expect-error: Mocks are lies -globalThis.Compartment = vi.fn(); +globalThis.Compartment ??= vi.fn(); vi.mock('@endo/promise-kit', async () => { return makePromiseKitMock(); diff --git a/scripts/rebuild-native.sh b/scripts/rebuild-native.sh index 72fbf9aa5..942dfbb1f 100755 --- a/scripts/rebuild-native.sh +++ b/scripts/rebuild-native.sh @@ -20,6 +20,18 @@ if [ "$FORCE_REBUILD" -eq 1 ]; then echo "🔁 Force rebuild enabled" fi +# Detect Node.js version and set C++ standard for Node.js v24+ +# Node.js v24+ requires C++20 due to V8 API changes +NODE_VERSION=$(node -v | cut -d'v' -f2 | cut -d'.' -f1) +if [ "$NODE_VERSION" -ge 24 ]; then + # Set C++20 flags for node-gyp builds + export CXXFLAGS="${CXXFLAGS} -std=c++20" + export CPPFLAGS="${CPPFLAGS} -std=c++20" + # npm/node-gyp also respects npm_config_* environment variables + export npm_config_cxxflags="${npm_config_cxxflags} -std=c++20" + echo "🔧 Node.js v${NODE_VERSION} detected: Using C++20 for native builds" +fi + # Check and rebuild better-sqlite3 if [ -d node_modules/better-sqlite3 ] && \ { [ "$FORCE_REBUILD" -eq 1 ] || \ @@ -44,6 +56,18 @@ if [ -d node_modules/@ipshipyard/node-datachannel ] && \ fi fi +# Check and rebuild tree-sitter +if [ -d node_modules/tree-sitter ] && \ + { [ "$FORCE_REBUILD" -eq 1 ] || \ + [ ! -f node_modules/tree-sitter/build/Release/tree_sitter.node ]; \ + }; then + echo "🔨 Building tree-sitter..." + if ! npm rebuild tree-sitter; then + echo "❌ Failed to build tree-sitter" >&2 + BUILD_FAILED=1 + fi +fi + # Exit with error if any builds failed if [ $BUILD_FAILED -eq 1 ]; then echo "⚠️ Some native modules failed to build. This may cause runtime errors." >&2 diff --git a/vitest.config.ts b/vitest.config.ts index a2ff91311..4b370a4b4 100644 --- a/vitest.config.ts +++ b/vitest.config.ts @@ -80,10 +80,10 @@ export default defineConfig({ lines: 3.62, }, 'packages/kernel-agents/**': { - statements: 100, - functions: 100, - branches: 100, - lines: 100, + statements: 94.15, + functions: 94.64, + branches: 91.39, + lines: 94.15, }, 'packages/kernel-browser-runtime/**': { statements: 83.71, @@ -130,7 +130,7 @@ export default defineConfig({ 'packages/kernel-ui/**': { statements: 97.57, functions: 97.29, - branches: 93.26, + branches: 93.25, lines: 97.57, }, 'packages/kernel-utils/**': { diff --git a/yarn.lock b/yarn.lock index 548e6d375..aed5e8cc3 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3271,6 +3271,7 @@ __metadata: "@metamask/eslint-config": "npm:^14.0.0" "@metamask/eslint-config-nodejs": "npm:^14.0.0" "@metamask/eslint-config-typescript": "npm:^14.0.0" + "@metamask/kernel-errors": "workspace:^" "@metamask/kernel-utils": "workspace:^" "@metamask/logger": "workspace:^" "@ocap/kernel-language-model-service": "workspace:^" @@ -3293,6 +3294,9 @@ __metadata: eslint-plugin-promise: "npm:^7.2.1" prettier: "npm:^3.5.3" rimraf: "npm:^6.0.1" + ses: "npm:^1.14.0" + tree-sitter: "npm:^0.25.0" + tree-sitter-javascript: "npm:^0.25.0" turbo: "npm:^2.5.6" typedoc: "npm:^0.28.1" typescript: "npm:~5.8.2" @@ -11340,6 +11344,15 @@ __metadata: languageName: node linkType: hard +"node-addon-api@npm:^8.3.0, node-addon-api@npm:^8.3.1": + version: 8.5.0 + resolution: "node-addon-api@npm:8.5.0" + dependencies: + node-gyp: "npm:latest" + checksum: 10/9a893f4f835fbc3908e0070f7bcacf36e37fd06be8008409b104c30df4092a0d9a29927b3a74cdbc1d34338274ba4116d597a41f573e06c29538a1a70d07413f + languageName: node + linkType: hard + "node-emoji@npm:^2.2.0": version: 2.2.0 resolution: "node-emoji@npm:2.2.0" @@ -11352,6 +11365,17 @@ __metadata: languageName: node linkType: hard +"node-gyp-build@npm:^4.8.4": + version: 4.8.4 + resolution: "node-gyp-build@npm:4.8.4" + bin: + node-gyp-build: bin.js + node-gyp-build-optional: optional.js + node-gyp-build-test: build-test.js + checksum: 10/6a7d62289d1afc419fc8fc9bd00aa4e554369e50ca0acbc215cb91446148b75ff7e2a3b53c2c5b2c09a39d416d69f3d3237937860373104b5fe429bf30ad9ac5 + languageName: node + linkType: hard + "node-gyp@npm:^10.0.0": version: 10.3.1 resolution: "node-gyp@npm:10.3.1" @@ -14050,6 +14074,33 @@ __metadata: languageName: node linkType: hard +"tree-sitter-javascript@npm:^0.25.0": + version: 0.25.0 + resolution: "tree-sitter-javascript@npm:0.25.0" + dependencies: + node-addon-api: "npm:^8.3.1" + node-gyp: "npm:latest" + node-gyp-build: "npm:^4.8.4" + peerDependencies: + tree-sitter: ^0.25.0 + peerDependenciesMeta: + tree-sitter: + optional: true + checksum: 10/b2571944b11064ae07cfc06d43a786dda45a10bd774e7f4fa416d89445215d964672a10d16b6cae9fa9544aa7a6e751df2ec84ad099bcad51237e1b541ff840c + languageName: node + linkType: hard + +"tree-sitter@npm:^0.25.0": + version: 0.25.0 + resolution: "tree-sitter@npm:0.25.0" + dependencies: + node-addon-api: "npm:^8.3.0" + node-gyp: "npm:latest" + node-gyp-build: "npm:^4.8.4" + checksum: 10/cba1090bf7eb3fb68ce21f631375d5c7c5f9505db7ed554fcaadbb21dbae1ef5a6ce754e89e024d5fc332706e31bea90e913957726e0783d3a4e253e979b2a95 + languageName: node + linkType: hard + "truncate-utf8-bytes@npm:^1.0.0": version: 1.0.2 resolution: "truncate-utf8-bytes@npm:1.0.2"