Skip to content

Commit de02a44

Browse files
committed
Add unit tests for Protocol to validate progress token handling and meta behavior
1 parent e43695f commit de02a44

File tree

2 files changed

+200
-0
lines changed

2 files changed

+200
-0
lines changed

kotlin-sdk-core/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ kotlin {
122122
commonTest {
123123
dependencies {
124124
implementation(kotlin("test"))
125+
implementation(libs.kotlinx.coroutines.test)
125126
implementation(libs.kotest.assertions.core)
126127
implementation(libs.kotest.assertions.json)
127128
}
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
package io.modelcontextprotocol.kotlin.sdk.shared
2+
3+
import io.modelcontextprotocol.kotlin.sdk.types.CustomRequest
4+
import io.modelcontextprotocol.kotlin.sdk.types.EmptyResult
5+
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
6+
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest
7+
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse
8+
import io.modelcontextprotocol.kotlin.sdk.types.McpJson
9+
import io.modelcontextprotocol.kotlin.sdk.types.Method
10+
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest
11+
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequestParams
12+
import io.modelcontextprotocol.kotlin.sdk.types.RequestMeta
13+
import kotlinx.coroutines.async
14+
import kotlinx.coroutines.channels.Channel
15+
import kotlinx.coroutines.test.runTest
16+
import kotlinx.serialization.json.JsonObject
17+
import kotlinx.serialization.json.JsonObjectBuilder
18+
import kotlinx.serialization.json.JsonPrimitive
19+
import kotlinx.serialization.json.buildJsonObject
20+
import kotlinx.serialization.json.encodeToJsonElement
21+
import kotlinx.serialization.json.int
22+
import kotlinx.serialization.json.jsonObject
23+
import kotlinx.serialization.json.jsonPrimitive
24+
import kotlin.test.BeforeTest
25+
import kotlin.test.Test
26+
import kotlin.test.assertEquals
27+
28+
class ProtocolTest {
29+
private lateinit var protocol: TestProtocol
30+
private lateinit var transport: RecordingTransport
31+
32+
@BeforeTest
33+
fun setUp() {
34+
protocol = TestProtocol()
35+
transport = RecordingTransport()
36+
}
37+
38+
@Test
39+
fun `should preserve existing meta when adding progress token`() = runTest {
40+
protocol.connect(transport)
41+
val request = ReadResourceRequest(
42+
ReadResourceRequestParams(
43+
uri = "test://resource",
44+
meta = metaOf {
45+
put("customField", JsonPrimitive("customValue"))
46+
put("anotherField", JsonPrimitive(123))
47+
},
48+
),
49+
)
50+
51+
val inFlight = async {
52+
protocol.request<EmptyResult>(
53+
request = request,
54+
options = RequestOptions(onProgress = {}),
55+
)
56+
}
57+
58+
val sent = transport.awaitRequest()
59+
val params = requireNotNull(sent.params).jsonObject
60+
val meta = params["_meta"]!!.jsonObject
61+
62+
assertEquals("test://resource", params["uri"]!!.jsonPrimitive.content)
63+
assertEquals("customValue", meta["customField"]!!.jsonPrimitive.content)
64+
assertEquals(123, meta["anotherField"]!!.jsonPrimitive.int)
65+
assertEquals(McpJson.encodeToJsonElement(sent.id), meta["progressToken"])
66+
67+
transport.deliver(JSONRPCResponse(sent.id, EmptyResult()))
68+
inFlight.await()
69+
}
70+
71+
@Test
72+
fun `should create meta with progress token when none exists`() = runTest {
73+
protocol.connect(transport)
74+
val request = ReadResourceRequest(
75+
ReadResourceRequestParams(
76+
uri = "test://resource",
77+
meta = null,
78+
),
79+
)
80+
81+
val inFlight = async {
82+
protocol.request<EmptyResult>(
83+
request = request,
84+
options = RequestOptions(onProgress = {}),
85+
)
86+
}
87+
88+
val sent = transport.awaitRequest()
89+
val params = requireNotNull(sent.params).jsonObject
90+
val meta = params["_meta"]!!.jsonObject
91+
92+
assertEquals("test://resource", params["uri"]!!.jsonPrimitive.content)
93+
assertEquals(McpJson.encodeToJsonElement(sent.id), meta["progressToken"])
94+
95+
transport.deliver(JSONRPCResponse(sent.id, EmptyResult()))
96+
inFlight.await()
97+
}
98+
99+
@Test
100+
fun `should not modify meta when onProgress is absent`() = runTest {
101+
protocol.connect(transport)
102+
val originalMeta = metaJson {
103+
put("customField", JsonPrimitive("customValue"))
104+
}
105+
val request = ReadResourceRequest(
106+
ReadResourceRequestParams(
107+
uri = "test://resource",
108+
meta = RequestMeta(originalMeta),
109+
),
110+
)
111+
112+
val inFlight = async {
113+
protocol.request<EmptyResult>(request)
114+
}
115+
116+
val sent = transport.awaitRequest()
117+
val params = requireNotNull(sent.params).jsonObject
118+
val meta = params["_meta"]!!.jsonObject
119+
120+
assertEquals(originalMeta, meta)
121+
assertEquals("test://resource", params["uri"]!!.jsonPrimitive.content)
122+
123+
transport.deliver(JSONRPCResponse(sent.id, EmptyResult()))
124+
inFlight.await()
125+
}
126+
127+
@Test
128+
fun `should create params object when request params are null`() = runTest {
129+
protocol.connect(transport)
130+
val request = CustomRequest(
131+
method = Method.Custom("example"),
132+
params = null,
133+
)
134+
135+
val inFlight = async {
136+
protocol.request<EmptyResult>(
137+
request = request,
138+
options = RequestOptions(onProgress = {}),
139+
)
140+
}
141+
142+
val sent = transport.awaitRequest()
143+
val params = requireNotNull(sent.params).jsonObject
144+
val meta = params["_meta"]!!.jsonObject
145+
146+
assertEquals(setOf("_meta"), params.keys)
147+
assertEquals(McpJson.encodeToJsonElement(sent.id), meta["progressToken"])
148+
149+
transport.deliver(JSONRPCResponse(sent.id, EmptyResult()))
150+
inFlight.await()
151+
}
152+
}
153+
154+
private class TestProtocol : Protocol(null) {
155+
override fun assertCapabilityForMethod(method: Method) {}
156+
override fun assertNotificationCapability(method: Method) {}
157+
override fun assertRequestHandlerCapability(method: Method) {}
158+
}
159+
160+
private class RecordingTransport : Transport {
161+
private val sentMessages = Channel<JSONRPCMessage>(Channel.UNLIMITED)
162+
private var onMessageCallback: (suspend (JSONRPCMessage) -> Unit)? = null
163+
private var onCloseCallback: (() -> Unit)? = null
164+
165+
override suspend fun start() {}
166+
167+
override suspend fun send(message: JSONRPCMessage) {
168+
sentMessages.send(message)
169+
}
170+
171+
override suspend fun close() {
172+
onCloseCallback?.invoke()
173+
}
174+
175+
override fun onClose(block: () -> Unit) {
176+
onCloseCallback = block
177+
}
178+
179+
override fun onError(block: (Throwable) -> Unit) {}
180+
181+
override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) {
182+
onMessageCallback = block
183+
}
184+
185+
suspend fun awaitRequest(): JSONRPCRequest {
186+
val message = sentMessages.receive()
187+
return message as? JSONRPCRequest
188+
?: error("Expected JSONRPCRequest but received ${message::class.simpleName}")
189+
}
190+
191+
suspend fun deliver(message: JSONRPCMessage) {
192+
val callback = onMessageCallback ?: error("onMessage callback not registered")
193+
callback(message)
194+
}
195+
}
196+
197+
private fun metaOf(builderAction: JsonObjectBuilder.() -> Unit): RequestMeta = RequestMeta(metaJson(builderAction))
198+
199+
private fun metaJson(builderAction: JsonObjectBuilder.() -> Unit): JsonObject = buildJsonObject(builderAction)

0 commit comments

Comments
 (0)