From 4d1e994c1561ae56df198bd3ad6dcf6876a8f461 Mon Sep 17 00:00:00 2001 From: Konstantin Pavlov <1517853+kpavlov@users.noreply.github.com> Date: Thu, 4 Dec 2025 12:57:30 +0200 Subject: [PATCH] Better exception handling in StdioClientTransport Refactor `McpException` for improved handling with convenience constructors and enhanced exception wrapping logic in `StdioClientTransport`. Update tests to use `kotest` matchers and add JUnit parameterized tests for exception handling. --- gradle/libs.versions.toml | 2 + kotlin-sdk-client/build.gradle.kts | 1 + .../kotlin/sdk/client/StdioClientTransport.kt | 23 ++++-- .../StdioClientTransportErrorHandlingTest.kt | 78 +++++++++++++++++++ kotlin-sdk-core/api/kotlin-sdk-core.api | 5 +- .../kotlin/sdk/types/McpException.kt | 17 ++-- .../kotlin/AbstractPromptIntegrationTest.kt | 10 ++- 7 files changed, 121 insertions(+), 15 deletions(-) diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 0b31051c..c1d67730 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -22,6 +22,7 @@ mockk = "1.14.6" mokksy = "0.6.2" serialization = "1.9.0" slf4j = "2.0.17" +junit="6.0.1" [libraries] # Plugins @@ -58,6 +59,7 @@ mockk = { module = "io.mockk:mockk", version.ref = "mockk" } mokksy = { group = "dev.mokksy", name = "mokksy", version.ref = "mokksy" } netty-bom = { group = "io.netty", name = "netty-bom", version.ref = "netty" } slf4j-simple = { group = "org.slf4j", name = "slf4j-simple", version.ref = "slf4j" } +junit-jupiter-params = { module = "org.junit.jupiter:junit-jupiter-params", version.ref = "junit" } # Samples ktor-client-cio = { group = "io.ktor", name = "ktor-client-cio", version.ref = "ktor" } diff --git a/kotlin-sdk-client/build.gradle.kts b/kotlin-sdk-client/build.gradle.kts index 4755ac7c..36a3ed8e 100644 --- a/kotlin-sdk-client/build.gradle.kts +++ b/kotlin-sdk-client/build.gradle.kts @@ -54,6 +54,7 @@ kotlin { implementation(libs.awaitility) implementation(libs.ktor.client.apache5) implementation(libs.mockk) + implementation(libs.junit.jupiter.params) implementation(libs.mokksy) implementation(dependencies.platform(libs.netty.bom)) runtimeOnly(libs.slf4j.simple) diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt index b6aa8275..66d43f76 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt @@ -15,6 +15,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.McpException import io.modelcontextprotocol.kotlin.sdk.types.RPCError.ErrorCode.CONNECTION_CLOSED import io.modelcontextprotocol.kotlin.sdk.types.RPCError.ErrorCode.INTERNAL_ERROR +import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers @@ -71,12 +72,14 @@ import kotlin.jvm.JvmOverloads * @param input The input stream where messages are received. * @param output The output stream where messages are sent. * @param error Optional error stream for stderr monitoring. - * @param sendChannel Channel for outbound messages. Default: buffered channel (capacity 64). + * @param sendChannel Channel for outbound messages. Default: buffered channel + * (implementation-default capacity). * @param classifyStderr Callback to classify stderr lines. Return [StderrSeverity.FATAL] to fail transport, * or [StderrSeverity.WARNING] / [StderrSeverity.INFO] / [StderrSeverity.DEBUG] * to log, or [StderrSeverity.IGNORE] to discard. * Default value: [StderrSeverity.DEBUG]. - * @see MCP Specification + * @see MCP Specification + * @see [Channel.BUFFERED] */ @OptIn(ExperimentalAtomicApi::class) public class StdioClientTransport @JvmOverloads public constructor( @@ -232,15 +235,25 @@ public class StdioClientTransport @JvmOverloads public constructor( @Suppress("TooGenericExceptionCaught", "SwallowedException") try { sendChannel.send(message) + } catch (e: CancellationException) { + throw e // MUST rethrow immediately - don't log, don't wrap } catch (e: ClosedSendChannelException) { logger.debug(e) { "Cannot send message: transport is closed" } - throw McpException(CONNECTION_CLOSED, "Transport is closed") + throw McpException( + code = CONNECTION_CLOSED, + message = "Transport is closed", + cause = e, + ) } catch (e: McpException) { logger.debug(e) { "Error while sending message: ${e.message}" } throw e - } catch (e: Exception) { + } catch (e: Throwable) { logger.error(e) { "Error while sending message: ${e.message}" } - throw McpException(INTERNAL_ERROR, "Error while sending message: ${e.message}") + throw McpException( + code = INTERNAL_ERROR, + message = "Error while sending message: ${e.message}", + cause = e, + ) } } diff --git a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/stdio/StdioClientTransportErrorHandlingTest.kt b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/stdio/StdioClientTransportErrorHandlingTest.kt index ba3bb2a4..43a8b04f 100644 --- a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/stdio/StdioClientTransportErrorHandlingTest.kt +++ b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/stdio/StdioClientTransportErrorHandlingTest.kt @@ -1,12 +1,28 @@ package io.modelcontextprotocol.kotlin.sdk.client.stdio +import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.booleans.shouldBeFalse import io.kotest.matchers.shouldBe +import io.kotest.matchers.types.shouldBeInstanceOf +import io.kotest.matchers.types.shouldBeSameInstanceAs +import io.mockk.coEvery +import io.mockk.mockk import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.McpException +import io.modelcontextprotocol.kotlin.sdk.types.RPCError.ErrorCode +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ClosedSendChannelException import kotlinx.coroutines.delay import kotlinx.coroutines.test.runTest import kotlinx.io.Buffer import kotlinx.io.writeString +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.MethodSource +import java.util.stream.Stream import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi import kotlin.test.Test @@ -95,4 +111,66 @@ class StdioClientTransportErrorHandlingTest { // Empty input should close cleanly without error errorCalled.shouldBeFalse() } + + companion object { + @JvmStatic + fun exceptions(): Stream = Stream.of( + Arguments.of( + CancellationException(), + false, // should not wrap, propagate + null, + ), + Arguments.of( + McpException(-1, "dummy"), + false, // should not wrap, propagate + null, + ), + Arguments.of( + ClosedSendChannelException("dummy"), + true, // should wrap in McpException + ErrorCode.CONNECTION_CLOSED, + ), + Arguments.of( + Exception(), + true, + ErrorCode.INTERNAL_ERROR, + ), + Arguments.of( + OutOfMemoryError(), + true, + ErrorCode.INTERNAL_ERROR, + ), + + ) + } + + @ParameterizedTest + @MethodSource("exceptions") + fun `Send should handle exceptions`(throwable: Throwable, shouldWrap: Boolean, expectedCode: Int?) = runTest { + val sendChannel: Channel = mockk(relaxed = true) + + transport = StdioClientTransport( + input = Buffer(), + output = Buffer(), + sendChannel = sendChannel, + ) + + coEvery { sendChannel.send(any()) } throws throwable + + transport.start() + + // Cancel the coroutine while it's suspended in send() + val exception = shouldThrow { + transport.send(JSONRPCRequest(id = "test-1", method = "test/method")) + } + + if (shouldWrap) { + exception.shouldBeInstanceOf { + it.cause shouldBeSameInstanceAs throwable + it.code shouldBe expectedCode + } + } else { + exception shouldBeSameInstanceAs throwable + } + } } diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index 8694e2d5..49b88cb3 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -2654,11 +2654,12 @@ public abstract interface annotation class io/modelcontextprotocol/kotlin/sdk/ty } public final class io/modelcontextprotocol/kotlin/sdk/types/McpException : java/lang/Exception { + public fun (ILjava/lang/String;)V public fun (ILjava/lang/String;Lkotlinx/serialization/json/JsonElement;)V - public synthetic fun (ILjava/lang/String;Lkotlinx/serialization/json/JsonElement;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (ILjava/lang/String;Lkotlinx/serialization/json/JsonElement;Ljava/lang/Throwable;)V + public synthetic fun (ILjava/lang/String;Lkotlinx/serialization/json/JsonElement;Ljava/lang/Throwable;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun getCode ()I public final fun getData ()Lkotlinx/serialization/json/JsonElement; - public fun getMessage ()Ljava/lang/String; } public abstract interface class io/modelcontextprotocol/kotlin/sdk/types/MediaContent : io/modelcontextprotocol/kotlin/sdk/types/ContentBlock { diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/McpException.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/McpException.kt index fb6e1ee5..f63f4a9b 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/McpException.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/McpException.kt @@ -1,14 +1,19 @@ package io.modelcontextprotocol.kotlin.sdk.types import kotlinx.serialization.json.JsonElement +import kotlin.jvm.JvmOverloads /** * Represents an error specific to the MCP protocol. * - * @property code The error code. - * @property message The error message. - * @property data Additional error data as a JSON object. + * @property code The MCP/JSON‑RPC error code. + * @property data Optional additional error payload as a JSON element; `null` when not provided. + * @param message The error message. + * @param cause The original cause. */ -public class McpException(public val code: Int, message: String, public val data: JsonElement? = null) : Exception() { - override val message: String = "MCP error $code: $message" -} +public class McpException @JvmOverloads public constructor( + public val code: Int, + message: String, + public val data: JsonElement? = null, + cause: Throwable? = null, +) : Exception("MCP error $code: $message", cause) diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt index b66ac7bd..fbaca727 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt @@ -1,5 +1,7 @@ package io.modelcontextprotocol.kotlin.sdk.integration.kotlin +import io.kotest.assertions.withClue +import io.kotest.matchers.string.shouldContain import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequest import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequestParams import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult @@ -391,7 +393,9 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() { } } - assertTrue(exception.message.contains("requiredArg2"), "Exception should mention the missing argument") + withClue("Exception should mention the missing argument") { + exception.message shouldContain "requiredArg2" + } // test with no args val exception2 = assertThrows { @@ -407,7 +411,9 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() { } } - assertTrue(exception2.message.contains("requiredArg"), "Exception should mention a missing required argument") + withClue("Exception should mention a missing required argument") { + exception2.message shouldContain "requiredArg" + } // test with all required args val result = client.getPrompt(