Skip to content

Commit e9700be

Browse files
committed
Add notification debouncing support to Protocol
1 parent 42c9260 commit e9700be

File tree

1 file changed

+95
-16
lines changed
  • kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared

1 file changed

+95
-16
lines changed

kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt

Lines changed: 95 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,19 @@ import kotlinx.atomicfu.atomic
2727
import kotlinx.atomicfu.getAndUpdate
2828
import kotlinx.atomicfu.update
2929
import kotlinx.collections.immutable.PersistentMap
30+
import kotlinx.collections.immutable.PersistentSet
3031
import kotlinx.collections.immutable.persistentMapOf
32+
import kotlinx.collections.immutable.persistentSetOf
3133
import kotlinx.coroutines.CompletableDeferred
34+
import kotlinx.coroutines.CoroutineScope
3235
import kotlinx.coroutines.Deferred
36+
import kotlinx.coroutines.Dispatchers
37+
import kotlinx.coroutines.SupervisorJob
3338
import kotlinx.coroutines.TimeoutCancellationException
39+
import kotlinx.coroutines.cancelChildren
40+
import kotlinx.coroutines.launch
3441
import kotlinx.coroutines.withTimeout
42+
import kotlinx.coroutines.yield
3543
import kotlinx.serialization.ExperimentalSerializationApi
3644
import kotlinx.serialization.json.ClassDiscriminatorMode
3745
import kotlinx.serialization.json.Json
@@ -68,23 +76,55 @@ public val McpJson: Json by lazy {
6876

6977
/**
7078
* Additional initialization options.
79+
*
80+
* @property enforceStrictCapabilities whether to restrict emitted requests to only those that the remote side has indicated
81+
* that they can handle, through their advertised capabilities.
82+
*
83+
* Note that this DOES NOT affect checking of _local_ side capabilities, as it is
84+
* considered a logic error to mis-specify those.
85+
*
86+
* Currently, this defaults to false, for backwards compatibility with SDK versions
87+
* that did not advertise capabilities correctly.
88+
* In the future, this will default to true.
89+
*
90+
* @property debouncedNotificationMethods an array of notification method names that should be automatically debounced.
91+
* Any notifications with a method in this list will be coalesced if they occur in the same tick of the event loop.
92+
* e.g., ['notifications/tools/list_changed']
7193
*/
7294
public open class ProtocolOptions(
73-
/**
74-
* Whether to restrict emitted requests to only those that the remote side has indicated
75-
* that they can handle, through their advertised capabilities.
76-
*
77-
* Note that this DOES NOT affect checking of _local_ side capabilities, as it is
78-
* considered a logic error to mis-specify those.
79-
*
80-
* Currently, this defaults to false, for backwards compatibility with SDK versions
81-
* that did not advertise capabilities correctly.
82-
* In the future, this will default to true.
83-
*/
8495
public var enforceStrictCapabilities: Boolean = false,
96+
public val debouncedNotificationMethods: List<Method> = emptyList(),
97+
) {
98+
public operator fun component1(): Boolean = enforceStrictCapabilities
99+
public operator fun component2(): List<Method> = debouncedNotificationMethods
100+
101+
public open fun copy(
102+
enforceStrictCapabilities: Boolean = this.enforceStrictCapabilities,
103+
debouncedNotificationMethods: List<Method> = this.debouncedNotificationMethods,
104+
): ProtocolOptions = ProtocolOptions(enforceStrictCapabilities, debouncedNotificationMethods)
105+
106+
override fun equals(other: Any?): Boolean {
107+
if (this === other) return true
108+
if (other == null || this::class != other::class) return false
109+
110+
other as ProtocolOptions
111+
112+
return when {
113+
enforceStrictCapabilities != other.enforceStrictCapabilities -> false
114+
debouncedNotificationMethods != other.debouncedNotificationMethods -> false
115+
else -> true
116+
}
117+
}
85118

86-
public var timeout: Duration = DEFAULT_REQUEST_TIMEOUT,
87-
)
119+
override fun hashCode(): Int {
120+
var result = enforceStrictCapabilities.hashCode()
121+
result = 31 * result + debouncedNotificationMethods.hashCode()
122+
return result
123+
}
124+
125+
override fun toString(): String =
126+
"ProtocolOptions(enforceStrictCapabilities=$enforceStrictCapabilities, debouncedNotificationMethods=$debouncedNotificationMethods)"
127+
}
88128

89129
/**
90130
* The default request timeout.
@@ -153,6 +193,11 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
153193
public val progressHandlers: Map<ProgressToken, ProgressCallback>
154194
get() = _progressHandlers.value
155195

196+
@Suppress("ktlint:standard:backing-property-naming")
197+
private val _pendingDebouncedNotifications: AtomicRef<PersistentSet<Method>> = atomic(persistentSetOf())
198+
private val notificationScopeJob = SupervisorJob()
199+
private val notificationScope = CoroutineScope(notificationScopeJob + Dispatchers.Default)
200+
156201
/**
157202
* Callback for when the connection is closed for any reason.
158203
*
@@ -224,6 +269,8 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
224269
val handlersToNotify = _responseHandlers.value.values.toList()
225270
_responseHandlers.getAndSet(persistentMapOf())
226271
_progressHandlers.getAndSet(persistentMapOf())
272+
_pendingDebouncedNotifications.update { it.clear() }
273+
notificationScopeJob.cancelChildren()
227274
transport = null
228275
onClose()
229276

@@ -473,13 +520,45 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
473520
/**
474521
* Emits a notification, which is a one-way message that does not expect a response.
475522
*/
476-
public suspend fun notification(notification: Notification) {
523+
public suspend fun notification(notification: Notification, relatedRequestId: RequestId? = null) {
477524
logger.trace { "Sending notification: ${notification.method}" }
478525
val transport = this.transport ?: error("Not connected")
479526
assertNotificationCapability(notification.method)
527+
val jsonRpcNotification = notification.toJSON()
528+
529+
val isDebounced =
530+
options?.debouncedNotificationMethods?.contains(notification.method) == true &&
531+
notification.params == null &&
532+
relatedRequestId == null
533+
534+
if (isDebounced) {
535+
if (notification.method in _pendingDebouncedNotifications.value) {
536+
logger.trace { "Skipping debounced notification: ${notification.method}" }
537+
return
538+
}
539+
540+
_pendingDebouncedNotifications.update { it.add(notification.method) }
541+
542+
notificationScope.launch {
543+
try {
544+
yield()
545+
} finally {
546+
_pendingDebouncedNotifications.update { it.remove(notification.method) }
547+
}
548+
549+
val activeTransport = this@Protocol.transport ?: return@launch
550+
551+
try {
552+
activeTransport.send(jsonRpcNotification)
553+
} catch (cause: Throwable) {
554+
logger.error(cause) { "Error sending debounced notification: ${notification.method}" }
555+
onError(cause)
556+
}
557+
}
558+
return
559+
}
480560

481-
val message = notification.toJSON()
482-
transport.send(message)
561+
transport.send(jsonRpcNotification)
483562
}
484563

485564
/**

0 commit comments

Comments
 (0)