diff --git a/libs/langchain/src/agents/middleware/index.ts b/libs/langchain/src/agents/middleware/index.ts
index 989604052514..3f64ef80263f 100644
--- a/libs/langchain/src/agents/middleware/index.ts
+++ b/libs/langchain/src/agents/middleware/index.ts
@@ -64,6 +64,10 @@ export {
   toolRetryMiddleware,
   type ToolRetryMiddlewareConfig,
 } from "./toolRetry.js";
+export {
+  modelRetryMiddleware,
+  type ModelRetryMiddlewareConfig,
+} from "./modelRetry.js";
 export {
   toolEmulatorMiddleware,
   type ToolEmulatorOptions,
diff --git a/libs/langchain/src/agents/middleware/modelRetry.ts b/libs/langchain/src/agents/middleware/modelRetry.ts
new file mode 100644
index 000000000000..3d985a8f420b
--- /dev/null
+++ b/libs/langchain/src/agents/middleware/modelRetry.ts
@@ -0,0 +1,294 @@
+/**
+ * Model retry middleware for agents.
+ */
+import { z } from "zod/v3";
+import { AIMessage } from "@langchain/core/messages";
+
+import { createMiddleware } from "../middleware.js";
+import type { AgentMiddleware } from "./types.js";
+
+/**
+ * Configuration options for the Model Retry Middleware.
+ */
+export const ModelRetryMiddlewareOptionsSchema = z.object({
+  /**
+   * Maximum number of retry attempts after the initial call.
+   * Default is 2 retries (3 total attempts). Must be >= 0.
+   */
+  maxRetries: z.number().min(0).default(2),
+
+  /**
+   * Either an array of error constructors to retry on, or a function
+   * that takes an error and returns `true` if it should be retried.
+   * Default is to retry on all errors.
+   */
+  retryOn: z
+    .union([
+      z.function().args(z.instanceof(Error)).returns(z.boolean()),
+      // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      z.array(z.custom<new (...args: any[]) => Error>()),
+    ])
+    .default(() => () => true),
+
+  /**
+   * Behavior when all retries are exhausted. Options:
+   * - `"raise"` (default): Re-raise the exception, stopping agent execution.
+   * - `"return_message"`: Return an AIMessage with error details, allowing
+   *   the agent to potentially handle the failure gracefully.
+   * - Custom function: Function that takes the exception and returns a string
+   *   for the AIMessage content, allowing custom error formatting.
+   */
+  onFailure: z
+    .union([
+      z.literal("raise"),
+      z.literal("return_message"),
+      z.function().args(z.instanceof(Error)).returns(z.string()),
+    ])
+    .default("raise"),
+
+  /**
+   * Multiplier for exponential backoff. Each retry waits
+   * `initialDelayMs * (backoffFactor ** retryNumber)` milliseconds.
+   * Set to 0.0 for constant delay. Default is 2.0.
+   */
+  backoffFactor: z.number().min(0).default(2.0),
+
+  /**
+   * Initial delay in milliseconds before first retry. Default is 1000 (1 second).
+   */
+  initialDelayMs: z.number().min(0).default(1000),
+
+  /**
+   * Maximum delay in milliseconds between retries. Caps exponential
+   * backoff growth. Default is 60000 (60 seconds).
+   */
+  maxDelayMs: z.number().min(0).default(60000),
+
+  /**
+   * Whether to add random jitter (±25%) to delay to avoid thundering herd.
+   * Default is `true`.
+   */
+  jitter: z.boolean().default(true),
+});
+
+export type ModelRetryMiddlewareConfig = z.input<
+  typeof ModelRetryMiddlewareOptionsSchema
+>;
+
+/**
+ * Middleware that automatically retries failed model calls with configurable backoff.
+ *
+ * Each failed call is checked against `retryOn`; retryable failures are re-attempted
+ * up to `maxRetries` times, waiting between attempts according to the backoff
+ * settings. When no further attempt is made, `onFailure` decides whether the error
+ * is re-thrown or converted into an `AIMessage`.
+ *
+ * @example Basic usage with default settings (2 retries, exponential backoff)
+ * ```ts
+ * import { createAgent, modelRetryMiddleware } from "langchain";
+ *
+ * const agent = createAgent({
+ *   model: "openai:gpt-4o",
+ *   tools: [searchTool],
+ *   middleware: [modelRetryMiddleware()],
+ * });
+ * ```
+ *
+ * @example Retry specific exceptions only
+ * ```ts
+ * import { modelRetryMiddleware } from "langchain";
+ *
+ * const retry = modelRetryMiddleware({
+ *   maxRetries: 4,
+ *   retryOn: [TimeoutError, NetworkError],
+ *   backoffFactor: 1.5,
+ * });
+ * ```
+ *
+ * @example Custom exception filtering
+ * ```ts
+ * function shouldRetry(error: Error): boolean {
+ *   // Only retry on rate limit errors
+ *   if (error.name === "RateLimitError") {
+ *     return true;
+ *   }
+ *   // Or check for specific HTTP status codes
+ *   if (error.name === "HTTPError" && "statusCode" in error) {
+ *     const statusCode = (error as any).statusCode;
+ *     return statusCode === 429 || statusCode === 503;
+ *   }
+ *   return false;
+ * }
+ *
+ * const retry = modelRetryMiddleware({
+ *   maxRetries: 3,
+ *   retryOn: shouldRetry,
+ * });
+ * ```
+ *
+ * @example Return error message instead of raising
+ * ```ts
+ * const retry = modelRetryMiddleware({
+ *   maxRetries: 4,
+ *   onFailure: "return_message", // Return AIMessage with error instead of throwing
+ * });
+ * ```
+ *
+ * @example Custom error message formatting
+ * ```ts
+ * const formatError = (error: Error) =>
+ *   `Model call failed: ${error.message}. Please try again later.`;
+ *
+ * const retry = modelRetryMiddleware({
+ *   maxRetries: 4,
+ *   onFailure: formatError,
+ * });
+ * ```
+ *
+ * @example Constant backoff (no exponential growth)
+ * ```ts
+ * const retry = modelRetryMiddleware({
+ *   maxRetries: 5,
+ *   backoffFactor: 0.0, // No exponential growth
+ *   initialDelayMs: 2000, // Always wait 2 seconds
+ * });
+ * ```
+ *
+ * @example Raise exception on failure (default)
+ * ```ts
+ * const retry = modelRetryMiddleware({
+ *   maxRetries: 2,
+ *   onFailure: "raise", // Re-raise exception (default behavior)
+ * });
+ * ```
+ *
+ * @param config - Configuration options for the retry middleware
+ * @returns A middleware instance that handles model failures with retries
+ * @throws ZodError if `config` fails schema validation (e.g. a negative `maxRetries`)
+ */
+export function modelRetryMiddleware(
+  config: ModelRetryMiddlewareConfig = {}
+): AgentMiddleware {
+  const {
+    maxRetries,
+    retryOn,
+    onFailure,
+    backoffFactor,
+    initialDelayMs,
+    maxDelayMs,
+    jitter,
+  } = ModelRetryMiddlewareOptionsSchema.parse(config);
+
+  /**
+   * Coerce an arbitrary thrown value into an `Error`. The `retryOn` and
+   * `onFailure` callbacks are declared to take an `Error` instance, so
+   * non-Error throwables are wrapped before being handed to them.
+   */
+  const toError = (thrown: unknown): Error => {
+    if (thrown instanceof Error) {
+      return thrown;
+    }
+    if (thrown && typeof thrown === "object" && "message" in thrown) {
+      return new Error(String((thrown as { message: unknown }).message));
+    }
+    return new Error(String(thrown));
+  };
+
+  /**
+   * Decide whether a failed attempt is eligible for another try.
+   */
+  const shouldRetry = (error: Error): boolean => {
+    if (typeof retryOn === "function") {
+      return retryOn(error);
+    }
+    return retryOn.some((ErrorClass) => error instanceof ErrorClass);
+  };
+
+  /**
+   * Compute how long to wait before retry number `retryNumber` (0-based).
+   */
+  const computeDelayMs = (retryNumber: number): number => {
+    // A factor of 0 means "constant delay"; the plain formula would yield
+    // `initialDelayMs * 0 ** n`, i.e. 0 for every retry after the first.
+    const uncapped =
+      backoffFactor === 0
+        ? initialDelayMs
+        : initialDelayMs * backoffFactor ** retryNumber;
+    const capped = Math.min(uncapped, maxDelayMs);
+    if (!jitter || capped <= 0) {
+      return capped;
+    }
+    // Spread the delay uniformly over ±25% of its nominal value.
+    const spread = capped * 0.25;
+    return Math.max(0, capped + (Math.random() * 2 - 1) * spread);
+  };
+
+  /**
+   * Resolve after `ms` milliseconds.
+   */
+  const sleep = (ms: number): Promise<void> =>
+    new Promise((resolve) => {
+      setTimeout(resolve, ms);
+    });
+
+  /**
+   * Format the failure message when retries are exhausted.
+   */
+  const formatFailureMessage = (error: Error, attemptsMade: number): string => {
+    const errorType = error.constructor.name;
+    const attemptWord = attemptsMade === 1 ? "attempt" : "attempts";
+    return `Model call failed after ${attemptsMade} ${attemptWord} with ${errorType}: ${error.message}`;
+  };
+
+  /**
+   * Handle a failure that will not be retried, either because the error is
+   * not retryable or because all retries are exhausted.
+   *
+   * @param thrown - The value originally thrown by the model call
+   * @param error - `thrown` normalized to an `Error`
+   * @param attemptsMade - Number of model calls made so far (>= 1)
+   */
+  const handleFailure = (
+    thrown: unknown,
+    error: Error,
+    attemptsMade: number
+  ): AIMessage => {
+    if (onFailure === "raise") {
+      // Re-throw exactly what the model call threw, not the normalized copy.
+      throw thrown;
+    }
+
+    const content =
+      typeof onFailure === "function"
+        ? onFailure(error)
+        : formatFailureMessage(error, attemptsMade);
+
+    return new AIMessage({ content });
+  };
+
+  return createMiddleware({
+    name: "modelRetryMiddleware",
+    contextSchema: ModelRetryMiddlewareOptionsSchema,
+    wrapModelCall: async (request, handler) => {
+      // `attempt` is 0 for the initial call and 1..maxRetries for retries.
+      // The loop always exits via `return` or `throw` inside the body.
+      for (let attempt = 0; ; attempt += 1) {
+        try {
+          return await handler(request);
+        } catch (thrown) {
+          const error = toError(thrown);
+          const attemptsMade = attempt + 1;
+
+          if (attempt >= maxRetries || !shouldRetry(error)) {
+            return handleFailure(thrown, error, attemptsMade);
+          }
+
+          const delayMs = computeDelayMs(attempt);
+          if (delayMs > 0) {
+            await sleep(delayMs);
+          }
+        }
+      }
+    },
+  });
+}
diff --git a/libs/langchain/src/agents/middleware/tests/modelRetry.test.ts b/libs/langchain/src/agents/middleware/tests/modelRetry.test.ts
new file mode 100644
index 000000000000..dd9c8fc5cdd9
--- /dev/null
+++ b/libs/langchain/src/agents/middleware/tests/modelRetry.test.ts
@@ -0,0 +1,548 @@
+/**
+ * Tests for ModelRetryMiddleware functionality.
+ */
+
+import { describe, it, expect } from "vitest";
+import { HumanMessage, AIMessage } from "@langchain/core/messages";
+import { MemorySaver } from "@langchain/langgraph-checkpoint";
+import { z } from "zod/v3";
+
+import { createAgent } from "../../index.js";
+import { modelRetryMiddleware } from "../modelRetry.js";
+import { FakeToolCallingModel } from "../../tests/utils.js";
+
+// Custom error types for testing
+class TimeoutError extends Error {
+  constructor(message: string) {
+    super(message);
+    this.name = "TimeoutError";
+  }
+}
+
+class NetworkError extends Error {
+  constructor(message: string) {
+    super(message);
+    this.name = "NetworkError";
+  }
+}
+
+class RateLimitError extends Error {
+  constructor(message: string) {
+    super(message);
+    this.name = "RateLimitError";
+  }
+}
+
+/**
+ * Helper class to create a model that fails a certain number of times before succeeding.
+ */
+class TemporaryFailureModel extends FakeToolCallingModel {
+  private attempt = 0;
+  private failCount: number;
+
+  constructor(failCount: number) {
+    super({ toolCalls: [[]] });
+    this.failCount = failCount;
+  }
+
+  async _generate(...args: Parameters<FakeToolCallingModel["_generate"]>) {
+    this.attempt += 1;
+    if (this.attempt <= this.failCount) {
+      throw new Error(`Temporary failure ${this.attempt}`);
+    }
+    const result = await super._generate(...args);
+    // Modify the content to indicate success after retries
+    if (result.generations[0]?.message) {
+      result.generations[0].message = new AIMessage({
+        content: `Success after ${this.attempt} attempts`,
+        id: result.generations[0].message.id,
+      });
+    }
+    return result;
+  }
+}
+
+/**
+ * Helper class to create a model that always fails with a specific error.
+ */
+class AlwaysFailingModel extends FakeToolCallingModel {
+  private error: Error;
+
+  constructor(error: Error) {
+    super({ toolCalls: [[]] });
+    this.error = error;
+  }
+
+  async _generate() {
+    throw this.error;
+  }
+}
+
+describe("modelRetryMiddleware", () => {
+  describe("Initialization", () => {
+    it("should initialize with default values", () => {
+      const retry = modelRetryMiddleware();
+      expect(retry).toBeDefined();
+      expect(retry.name).toBe("modelRetryMiddleware");
+    });
+
+    it("should initialize with custom values", () => {
+      const retry = modelRetryMiddleware({
+        maxRetries: 5,
+        retryOn: [TimeoutError, NetworkError],
+        onFailure: "raise",
+        backoffFactor: 1.5,
+        initialDelayMs: 500,
+        maxDelayMs: 30000,
+        jitter: false,
+      });
+      expect(retry).toBeDefined();
+      expect(retry.name).toBe("modelRetryMiddleware");
+    });
+  });
+
+  describe("Validation", () => {
+    it("should throw ZodError for invalid maxRetries", () => {
+      try {
+        modelRetryMiddleware({ maxRetries: -1 });
+        expect.fail("Should have thrown an error");
+      } catch (error) {
+        expect(error).toBeInstanceOf(z.ZodError);
+        const zodError = error as z.ZodError;
+        expect(zodError.issues[0].path).toEqual(["maxRetries"]);
+        expect(zodError.issues[0].code).toBe("too_small");
+      }
+    });
+
+    it("should throw ZodError for invalid initialDelayMs", () => {
+      try {
+        modelRetryMiddleware({ initialDelayMs: -1 });
+        expect.fail("Should have thrown an error");
+      } catch (error) {
+        expect(error).toBeInstanceOf(z.ZodError);
+        const zodError = error as z.ZodError;
+        expect(zodError.issues[0].path).toEqual(["initialDelayMs"]);
+        expect(zodError.issues[0].code).toBe("too_small");
+      }
+    });
+
+    it("should throw ZodError for invalid maxDelayMs", () => {
+      try {
+        modelRetryMiddleware({ maxDelayMs: -1 });
+        expect.fail("Should have thrown an error");
+      } catch (error) {
+        expect(error).toBeInstanceOf(z.ZodError);
+        const zodError = error as z.ZodError;
+        expect(zodError.issues[0].path).toEqual(["maxDelayMs"]);
+        expect(zodError.issues[0].code).toBe("too_small");
+      }
+    });
+
+    it("should throw ZodError for invalid backoffFactor", () => {
+      try {
+        modelRetryMiddleware({ backoffFactor: -1 });
+        expect.fail("Should have thrown an error");
+      } catch (error) {
+        expect(error).toBeInstanceOf(z.ZodError);
+        const zodError = error as z.ZodError;
+        expect(zodError.issues[0].path).toEqual(["backoffFactor"]);
+        expect(zodError.issues[0].code).toBe("too_small");
+      }
+    });
+  });
+
+  describe("Basic functionality", () => {
+    it("should not retry working model (no retry needed)", async () => {
+      const model = new FakeToolCallingModel({
+        toolCalls: [[]],
+      });
+
+      const retry = modelRetryMiddleware({
+        maxRetries: 2,
+        initialDelayMs: 10,
+        jitter: false,
+      });
+
+      const agent = createAgent({
+        model,
+        tools: [],
+        middleware: [retry] as const,
+        checkpointer: new MemorySaver(),
+      });
+
+      const result = await agent.invoke(
+        { messages: [new HumanMessage("Hello")] },
+        { configurable: { thread_id: "test" } }
+      );
+
+      const aiMessages = result.messages.filter(AIMessage.isInstance);
+      expect(aiMessages.length).toBeGreaterThan(0);
+    });
+
+    it("should retry failing model and succeed after temporary failures", async () => {
+      const model = new TemporaryFailureModel(2);
+
+      const retry = modelRetryMiddleware({
+        maxRetries: 3,
+        initialDelayMs: 10,
+        jitter: false,
+      });
+
+      const agent = createAgent({
+        model,
+        tools: [],
+        middleware: [retry] as const,
+        checkpointer: new MemorySaver(),
+      });
+
+      const result = await agent.invoke(
+        { messages: [new HumanMessage("Hello")] },
+        { configurable: { thread_id: "test" } }
+      );
+
+      const aiMessages = result.messages.filter(AIMessage.isInstance);
+      expect(aiMessages.length).toBeGreaterThan(0);
+      expect(aiMessages[aiMessages.length - 1].content).toContain(
+        "Success after 3 attempts"
+      );
+    });
+
+    it("should retry failing model and raise on failure (default)", async () => {
+      const model = new AlwaysFailingModel(new Error("Model failed"));
+
+      const retry = modelRetryMiddleware({
+        maxRetries: 2,
+        initialDelayMs: 10,
+        jitter: false,
+        onFailure: "raise",
+      });
+
+      const agent = createAgent({
+        model,
+        tools: [],
+        middleware: [retry] as const,
+        checkpointer: new MemorySaver(),
+      });
+
+      // Should raise the Error from the model
+      await expect(
+        agent.invoke(
+          { messages: [new HumanMessage("Hello")] },
+          { configurable: { thread_id: "test" } }
+        )
+      ).rejects.toThrow("Model failed");
+    });
+
+    it("should retry failing model and return error message", async () => {
+      const model = new AlwaysFailingModel(new Error("Model failed"));
+
+      const retry = modelRetryMiddleware({
+        maxRetries: 2,
+        initialDelayMs: 10,
+        jitter: false,
+        onFailure: "return_message",
+      });
+
+      const agent = createAgent({
+        model,
+        tools: [],
+        middleware: [retry] as const,
+        checkpointer: new MemorySaver(),
+      });
+
+      const result = await agent.invoke(
+        { messages: [new HumanMessage("Hello")] },
+        { configurable: { thread_id: "test" } }
+      );
+
+      const aiMessages = result.messages.filter(AIMessage.isInstance);
+      expect(aiMessages.length).toBeGreaterThan(0);
+      // Should contain error message with attempts
+      expect(aiMessages[aiMessages.length - 1].content).toContain(
+        "3 attempts"
+      );
+      expect(aiMessages[aiMessages.length - 1].content).toContain("Error");
+    });
+
+    it("should use custom failure formatter", async () => {
+      const customFormatter = (error: Error): string => {
+        return `Custom error: ${error.constructor.name}`;
+      };
+
+      const model = new AlwaysFailingModel(new Error("Model failed"));
+
+      const retry = modelRetryMiddleware({
+        maxRetries: 1,
+        initialDelayMs: 10,
+        jitter: false,
+        onFailure: customFormatter,
+      });
+
+      const agent = createAgent({
+        model,
+        tools: [],
+        middleware: [retry] as const,
+        checkpointer: new MemorySaver(),
+      });
+
+      const result = await agent.invoke(
+        { messages: [new HumanMessage("Hello")] },
+        { configurable: { thread_id: "test" } }
+      );
+
+      const aiMessages = result.messages.filter(AIMessage.isInstance);
+      expect(aiMessages.length).toBeGreaterThan(0);
+      expect(aiMessages[aiMessages.length - 1].content).toBe(
+        "Custom error: Error"
+      );
+    });
+  });
+
+  describe("Retry on specific exceptions", () => {
+    it("should retry on specified error types", async () => {
+      class TimeoutFailureModel extends FakeToolCallingModel {
+        private attempt = 0;
+
+        constructor() {
+          super({ toolCalls: [[]] });
+        }
+
+        async _generate(...args: Parameters<FakeToolCallingModel["_generate"]>) {
+          this.attempt += 1;
+          if (this.attempt <= 1) {
+            throw new TimeoutError("Timeout");
+          }
+          return super._generate(...args);
+        }
+      }
+
+      const model = new TimeoutFailureModel();
+
+      const retry = modelRetryMiddleware({
+        maxRetries: 2,
+        initialDelayMs: 10,
+        jitter: false,
+        retryOn: [TimeoutError],
+      });
+
+      const agent = createAgent({
+        model,
+        tools: [],
+        middleware: [retry] as const,
+        checkpointer: new MemorySaver(),
+      });
+
+      const result = await agent.invoke(
+        { messages: [new HumanMessage("Hello")] },
+        { configurable: { thread_id: "test" } }
+      );
+
+      const aiMessages = result.messages.filter(AIMessage.isInstance);
+      expect(aiMessages.length).toBeGreaterThan(0);
+    });
+
+    it("should not retry on non-specified error types", async () => {
+      const model = new AlwaysFailingModel(new Error("Generic error"));
+
+      const retry = modelRetryMiddleware({
+        maxRetries: 2,
+        initialDelayMs: 10,
+        jitter: false,
+        retryOn: [TimeoutError, RateLimitError],
+        onFailure: "return_message",
+      });
+
+      const agent = createAgent({
+        model,
+        tools: [],
+        middleware: [retry] as const,
+        checkpointer: new MemorySaver(),
+      });
+
+      const result = await agent.invoke(
+        { messages: [new HumanMessage("Hello")] },
+        { configurable: { thread_id: "test" } }
+      );
+
+      const aiMessages = result.messages.filter(AIMessage.isInstance);
+      expect(aiMessages.length).toBeGreaterThan(0);
+      // Should fail immediately without retries since Error is not in retryOn list
+      expect(aiMessages[aiMessages.length - 1].content).toContain("1 attempt");
+    });
+
+    it("should use custom retry function", async () => {
+      class RateLimitFailureModel extends FakeToolCallingModel {
+        private attempt = 0;
+
+        constructor() {
+          super({ toolCalls: [[]] });
+        }
+
+        async _generate(...args: Parameters<FakeToolCallingModel["_generate"]>) {
+          this.attempt += 1;
+          if (this.attempt <= 1) {
+            const error = new Error("Rate limit exceeded");
+            (error as any).statusCode = 429;
+            throw error;
+          }
+          return super._generate(...args);
+        }
+      }
+
+      const model = new RateLimitFailureModel();
+
+      const shouldRetry = (error: Error): boolean => {
+        return (
+          error.name === "RateLimitError" ||
+          ((error as any).statusCode === 429)
+        );
+      };
+
+      const retry = modelRetryMiddleware({
+        maxRetries: 2,
+        initialDelayMs: 10,
+        jitter: false,
+        retryOn: shouldRetry,
+      });
+
+      const agent = createAgent({
+        model,
+        tools: [],
+        middleware: [retry] as const,
+        checkpointer: new MemorySaver(),
+      });
+
+      const result = await agent.invoke(
+        { messages: [new HumanMessage("Hello")] },
+        { configurable: { thread_id: "test" } }
+      );
+
+      const aiMessages = result.messages.filter(AIMessage.isInstance);
+      expect(aiMessages.length).toBeGreaterThan(0);
+    });
+  });
+
+  describe("Backoff behavior", () => {
+    it("should apply exponential backoff", async () => {
+      class BackoffTestModel extends FakeToolCallingModel {
+        private attempt = 0;
+        private delays: number[] = [];
+        private lastTime = Date.now();
+
+        constructor() {
+          super({ toolCalls: [[]] });
+        }
+
+        async _generate(...args: Parameters<FakeToolCallingModel["_generate"]>) {
+          const currentTime = Date.now();
+          if (this.attempt > 0) {
+            this.delays.push(currentTime - this.lastTime);
+          }
+          this.lastTime = currentTime;
+          this.attempt += 1;
+          if (this.attempt <= 2) {
+            throw new Error(`Temporary failure ${this.attempt}`);
+          }
+          return super._generate(...args);
+        }
+
+        getDelays() {
+          return this.delays;
+        }
+      }
+
+      const model = new BackoffTestModel();
+
+      const retry = modelRetryMiddleware({
+        maxRetries: 3,
+        initialDelayMs: 100,
+        backoffFactor: 2.0,
+        jitter: false,
+      });
+
+      const agent = createAgent({
+        model,
+        tools: [],
+        middleware: [retry] as const,
+        checkpointer: new MemorySaver(),
+      });
+
+      await agent.invoke(
+        { messages: [new HumanMessage("Hello")] },
+        { configurable: { thread_id: "test" } }
+      );
+
+      const delays = model.getDelays();
+      // Should have delays between retries
+      expect(delays.length).toBeGreaterThan(0);
+      // First delay should be around initialDelayMs (100ms)
+      expect(delays[0]).toBeGreaterThanOrEqual(90);
+      expect(delays[0]).toBeLessThan(150);
+      // Second delay should be around initialDelayMs * backoffFactor (200ms)
+      if (delays.length > 1) {
+        expect(delays[1]).toBeGreaterThanOrEqual(180);
+        expect(delays[1]).toBeLessThan(250);
+      }
+    });
+
+    it("should apply constant backoff when backoffFactor is 0", async () => {
+      class ConstantBackoffTestModel extends FakeToolCallingModel {
+        private attempt = 0;
+        private delays: number[] = [];
+        private lastTime = Date.now();
+
+        constructor() {
+          super({ toolCalls: [[]] });
+        }
+
+        async _generate(...args: Parameters<FakeToolCallingModel["_generate"]>) {
+          const currentTime = Date.now();
+          if (this.attempt > 0) {
+            this.delays.push(currentTime - this.lastTime);
+          }
+          this.lastTime = currentTime;
+          this.attempt += 1;
+          if (this.attempt <= 2) {
+            throw new Error(`Temporary failure ${this.attempt}`);
+          }
+          return super._generate(...args);
+        }
+
+        getDelays() {
+          return this.delays;
+        }
+      }
+
+      const model = new ConstantBackoffTestModel();
+
+      const retry = modelRetryMiddleware({
+        maxRetries: 3,
+        initialDelayMs: 100,
+        backoffFactor: 0.0,
+        jitter: false,
+      });
+
+      const agent = createAgent({
+        model,
+        tools: [],
+        middleware: [retry] as const,
+        checkpointer: new MemorySaver(),
+      });
+
+      await agent.invoke(
+        { messages: [new HumanMessage("Hello")] },
+        { configurable: { thread_id: "test" } }
+      );
+
+      const delays = model.getDelays();
+      // All delays should be approximately the same (around initialDelayMs)
+      if (delays.length > 1) {
+        const avgDelay =
+          delays.reduce((a, b) => a + b, 0) / delays.length;
+        expect(avgDelay).toBeGreaterThanOrEqual(90);
+        expect(avgDelay).toBeLessThan(150);
+      }
+    });
+  });
+});
+