Skip to content

Commit 0f9209a

Browse files
authored
Handle additional error types in client initialization handling (#404)
May throw the following exceptions: `CancellationException`, `McpException`, `StreamableHttpError`, and `SerializationException` ## Motivation and Context This makes it possible to handle errors directly without needing to inspect the cause. For example: ```kotlin try { client.connect(transport) } catch (e: StreamableHttpError) { // ... } ``` ## How Has This Been Tested? locally ## Breaking Changes None ## Types of changes - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update ## Checklist - [x] I have read the [MCP Documentation](https://modelcontextprotocol.io) - [x] My code follows the repository's style guidelines - [x] New and existing tests pass locally - [x] I have added appropriate error handling - [ ] I have added or updated documentation as needed
1 parent 42c9260 commit 0f9209a

File tree

2 files changed

+88
-5
lines changed
  • kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client
  • kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client

2 files changed

+88
-5
lines changed

kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.ListRootsResult
3434
import io.modelcontextprotocol.kotlin.sdk.types.ListToolsRequest
3535
import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult
3636
import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel
37+
import io.modelcontextprotocol.kotlin.sdk.types.McpException
3738
import io.modelcontextprotocol.kotlin.sdk.types.Method
3839
import io.modelcontextprotocol.kotlin.sdk.types.PingRequest
3940
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest
@@ -54,6 +55,7 @@ import kotlinx.atomicfu.update
5455
import kotlinx.collections.immutable.minus
5556
import kotlinx.collections.immutable.persistentMapOf
5657
import kotlinx.collections.immutable.toPersistentSet
58+
import kotlinx.serialization.SerializationException
5759
import kotlinx.serialization.json.JsonObject
5860
import kotlin.coroutines.cancellation.CancellationException
5961

@@ -196,11 +198,15 @@ public open class Client(private val clientInfo: Implementation, options: Client
196198
logger.error(error) { "Failed to initialize client: ${error.message}" }
197199
close()
198200

199-
if (error !is CancellationException) {
200-
throw IllegalStateException("Error connecting to transport: ${error.message}", error)
201-
}
201+
when (error) {
202+
is CancellationException,
203+
is McpException,
204+
is StreamableHttpError,
205+
is SerializationException,
206+
-> throw error
202207

203-
throw error
208+
else -> throw IllegalStateException("Error connecting to transport: ${error.message}", error)
209+
}
204210
}
205211
}
206212

kotlin-sdk-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult
2626
import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel
2727
import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotification
2828
import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotificationParams
29+
import io.modelcontextprotocol.kotlin.sdk.types.McpException
2930
import io.modelcontextprotocol.kotlin.sdk.types.Method
3031
import io.modelcontextprotocol.kotlin.sdk.types.Role
3132
import io.modelcontextprotocol.kotlin.sdk.types.Root
@@ -232,6 +233,82 @@ class ClientTest {
232233
assertTrue(closed)
233234
}
234235

236+
@Test
237+
fun `should rethrow McpException as is`() = runTest {
238+
var closed = false
239+
val failingTransport = object : AbstractTransport() {
240+
override suspend fun start() {}
241+
242+
override suspend fun send(message: JSONRPCMessage) {
243+
if (message !is JSONRPCRequest) return
244+
check(message.method == Method.Defined.Initialize.value)
245+
throw McpException(
246+
code = -32600,
247+
message = "Invalid Request",
248+
)
249+
}
250+
251+
override suspend fun close() {
252+
closed = true
253+
}
254+
}
255+
256+
val client = Client(
257+
clientInfo = Implementation(
258+
name = "test client",
259+
version = "1.0",
260+
),
261+
options = ClientOptions(),
262+
)
263+
264+
val exception = assertFailsWith<McpException> {
265+
client.connect(failingTransport)
266+
}
267+
268+
assertEquals(-32600, exception.code)
269+
assertEquals("MCP error -32600: Invalid Request", exception.message)
270+
271+
assertTrue(closed)
272+
}
273+
274+
@Test
275+
fun `should rethrow StreamableHttpError as is`() = runTest {
276+
var closed = false
277+
val failingTransport = object : AbstractTransport() {
278+
override suspend fun start() {}
279+
280+
override suspend fun send(message: JSONRPCMessage) {
281+
if (message !is JSONRPCRequest) return
282+
check(message.method == Method.Defined.Initialize.value)
283+
throw StreamableHttpError(
284+
code = 500,
285+
message = "Internal Server Error",
286+
)
287+
}
288+
289+
override suspend fun close() {
290+
closed = true
291+
}
292+
}
293+
294+
val client = Client(
295+
clientInfo = Implementation(
296+
name = "test client",
297+
version = "1.0",
298+
),
299+
options = ClientOptions(),
300+
)
301+
302+
val exception = assertFailsWith<StreamableHttpError> {
303+
client.connect(failingTransport)
304+
}
305+
306+
assertEquals(500, exception.code)
307+
assertEquals("Streamable HTTP error: Internal Server Error", exception.message)
308+
309+
assertTrue(closed)
310+
}
311+
235312
@Test
236313
fun `should respect server capabilities`() = runTest {
237314
val serverOptions = ServerOptions(
@@ -922,7 +999,7 @@ class ClientTest {
922999
println("Client connected")
9231000
},
9241001
launch {
925-
serverSessionResult.complete(server.connect(serverTransport))
1002+
serverSessionResult.complete(server.createSession(serverTransport))
9261003
println("Server connected")
9271004
},
9281005
).joinAll()

0 commit comments

Comments
 (0)