Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 98 additions & 1 deletion pkg/vmcp/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"log/slog"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -67,6 +68,13 @@ const (
// defaultSessionTTL is the default session time-to-live duration.
// Sessions that are inactive for this duration will be automatically cleaned up.
defaultSessionTTL = 30 * time.Minute

// defaultIdleCheckInterval is how often the idle reaper scans for inactive sessions.
defaultIdleCheckInterval = time.Minute

// defaultRetryAfterSeconds is the Retry-After value returned with HTTP 503
// when the global session limit is reached.
defaultRetryAfterSeconds = 30
)

//go:generate mockgen -destination=mocks/mock_watcher.go -package=mocks -source=server.go Watcher
Expand Down Expand Up @@ -160,6 +168,21 @@ type Config struct {
// SessionFactory creates MultiSessions for Phase 2 session management.
// Required when SessionManagementV2 is true; ignored otherwise.
SessionFactory vmcpsession.MultiSessionFactory

// MaxSessions is the global concurrent session limit when SessionManagementV2 is enabled.
// Requests that would exceed this limit receive HTTP 503 with a Retry-After header.
// 0 uses the default (100). Requires SessionManagementV2 = true.
MaxSessions int

// MaxSessionsPerClient is the per-identity session limit when SessionManagementV2 is enabled.
// Keyed by auth.Identity.Subject; anonymous clients are not limited.
// 0 uses the default (10). Requires SessionManagementV2 = true.
MaxSessionsPerClient int

// IdleSessionTimeout is the duration after which inactive sessions are proactively
// expired when SessionManagementV2 is enabled. Should be ≤ SessionTTL; New clamps
// larger values to SessionTTL and logs a warning.
// 0 uses the default (5 minutes). Requires SessionManagementV2 = true.
IdleSessionTimeout time.Duration
}

// Server is the Virtual MCP Server that aggregates multiple backends.
Expand Down Expand Up @@ -277,6 +300,24 @@ func New(
if cfg.SessionTTL == 0 {
cfg.SessionTTL = defaultSessionTTL
}
if cfg.MaxSessions == 0 {
cfg.MaxSessions = sessionmanager.DefaultMaxSessions
}
if cfg.MaxSessionsPerClient == 0 {
cfg.MaxSessionsPerClient = sessionmanager.DefaultMaxSessionsPerClient
}
if cfg.IdleSessionTimeout == 0 {
cfg.IdleSessionTimeout = sessionmanager.DefaultIdleSessionTimeout
}
// IdleSessionTimeout must not exceed SessionTTL: if it did, the transport
// TTL reaper could evict sessions before the idle reaper fires, leaving
// per-client counters and idle-tracking maps stale.
if cfg.IdleSessionTimeout > cfg.SessionTTL {
slog.Warn("IdleSessionTimeout exceeds SessionTTL; clamping to SessionTTL",
"idle_session_timeout", cfg.IdleSessionTimeout,
"session_ttl", cfg.SessionTTL)
cfg.IdleSessionTimeout = cfg.SessionTTL
}

// Create hooks for SDK integration
hooks := &server.Hooks{}
Expand Down Expand Up @@ -400,7 +441,12 @@ func New(
if cfg.SessionFactory == nil {
return nil, fmt.Errorf("SessionManagementV2 is enabled but no SessionFactory was provided")
}
vmcpSessMgr = sessionmanager.New(sessionManager, cfg.SessionFactory, backendRegistry)
limits := sessionmanager.Limits{
MaxSessions: cfg.MaxSessions,
MaxSessionsPerClient: cfg.MaxSessionsPerClient,
IdleSessionTimeout: cfg.IdleSessionTimeout,
}
vmcpSessMgr = sessionmanager.New(sessionManager, cfg.SessionFactory, backendRegistry, limits)
slog.Info("session-scoped backend lifecycle enabled")

// Warn about incompatible optimizer configuration and disable it
Expand Down Expand Up @@ -557,6 +603,13 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) {
slog.Info("audit middleware enabled for MCP endpoints")
}

// Apply session limit middleware when V2 session management is active.
// Middleware applied later wraps (and therefore runs before) this one, so when
// AuthMiddleware is configured the limit check runs after authentication, not before it.
if s.vmcpSessionMgr != nil && s.config.MaxSessions > 0 {
mcpHandler = s.sessionLimitMiddleware(mcpHandler)
slog.Info("session limit middleware enabled", "max_sessions", s.config.MaxSessions)
}

// Apply authentication middleware if configured (runs first in chain)
if s.config.AuthMiddleware != nil {
mcpHandler = s.config.AuthMiddleware(mcpHandler)
Expand All @@ -575,6 +628,37 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) {
return mux, nil
}

// sessionLimitMiddleware returns a handler that rejects would-be new sessions
// early when the global session cap already appears to be exhausted.
//
// A request is treated as a new-session attempt when it has no Mcp-Session-Id
// header. If the manager's active session count is at or above
// Config.MaxSessions, such a request is answered with HTTP 503, a Retry-After
// header, and a JSON error body, and it never reaches next. Requests that carry
// an Mcp-Session-Id header are always forwarded untouched.
//
// The count is read without any reservation, so this is only an optimistic
// fast path: concurrent initialize requests may slip past it. Strict, atomic
// enforcement is the job of sessionmanager.Manager.Generate(), which reserves
// a slot (increment first) before creating the session.
func (s *Server) sessionLimitMiddleware(next http.Handler) http.Handler {
	// ActiveSessionCount is only reachable through the concrete manager type.
	// If the configured manager is anything else there is nothing to check, so
	// the chain is left unchanged.
	mgr, _ := s.vmcpSessionMgr.(*sessionmanager.Manager)
	if mgr == nil {
		return next
	}

	const overLimitBody = `{"error":{"code":-32000,"message":"Maximum concurrent sessions exceeded. ` +
		`Please try again later or contact administrator."}}`

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hasSession := r.Header.Get("Mcp-Session-Id") != ""
		if hasSession || mgr.ActiveSessionCount() < s.config.MaxSessions {
			next.ServeHTTP(w, r)
			return
		}

		hdr := w.Header()
		hdr.Set("Retry-After", strconv.Itoa(defaultRetryAfterSeconds))
		hdr.Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusServiceUnavailable)
		// Best-effort write: the status is already committed, so a failed body
		// write cannot be reported to the client.
		_, _ = w.Write([]byte(overLimitBody))
	})
}
Comment on lines +631 to +660
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Blocker: can we remove the rate limiting behavior from this PR?

It's not really needed today, so I'd prefer to avoid adding it.


// Start starts the Virtual MCP Server and begins serving requests.
//
//nolint:gocyclo // Complexity from health monitoring and startup orchestration is acceptable
Expand Down Expand Up @@ -667,6 +751,19 @@ func (s *Server) Start(ctx context.Context) error {
}
}

// Start idle session reaper if V2 session management is active with an idle timeout.
if mgr, ok := s.vmcpSessionMgr.(*sessionmanager.Manager); ok && s.config.IdleSessionTimeout > 0 {
idleCtx, idleCancel := context.WithCancel(ctx)
mgr.StartIdleReaper(idleCtx, defaultIdleCheckInterval)
slog.Info("idle session reaper started",
"idle_timeout", s.config.IdleSessionTimeout,
"check_interval", defaultIdleCheckInterval)
s.shutdownFuncs = append(s.shutdownFuncs, func(context.Context) error {
idleCancel()
return nil
})
}

// Start status reporter if configured
if s.statusReporter != nil {
shutdown, err := s.statusReporter.Start(ctx)
Expand Down
118 changes: 118 additions & 0 deletions pkg/vmcp/server/session_management_v2_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,55 @@ func buildV2Server(
return ts
}

// buildV2ServerWithLimits is the MaxSessions-aware sibling of buildV2Server: it
// builds a server with SessionManagementV2 enabled, the given session factory,
// and an explicit global session cap, then exposes its handler through an
// httptest.Server.
//
// Every other collaborator is a gomock mock with permissive (AnyTimes)
// expectations. The mock controller and the test server are both released via
// t.Cleanup, so callers have nothing to tear down.
func buildV2ServerWithLimits(
	t *testing.T,
	factory vmcpsession.MultiSessionFactory,
	maxSessions int,
) *httptest.Server {
	t.Helper()

	mockCtrl := gomock.NewController(t)
	t.Cleanup(mockCtrl.Finish)

	backendClient := mocks.NewMockBackendClient(mockCtrl)
	discoveryMgr := discoveryMocks.NewMockManager(mockCtrl)
	backendRegistry := mocks.NewMockBackendRegistry(mockCtrl)

	// No backends and empty aggregated capabilities: the test only exercises
	// session bookkeeping, not tool routing.
	backendRegistry.EXPECT().List(gomock.Any()).Return(nil).AnyTimes()
	discoveryMgr.EXPECT().
		Discover(gomock.Any(), gomock.Any()).
		Return(&aggregator.AggregatedCapabilities{}, nil).
		AnyTimes()
	discoveryMgr.EXPECT().Stop().AnyTimes()

	cfg := &server.Config{
		Host:                "127.0.0.1",
		Port:                0,
		SessionTTL:          5 * time.Minute,
		SessionManagementV2: true,
		SessionFactory:      factory,
		MaxSessions:         maxSessions,
	}

	vmcpServer, err := server.New(
		context.Background(),
		cfg,
		router.NewDefaultRouter(),
		backendClient,
		discoveryMgr,
		backendRegistry,
		nil,
	)
	require.NoError(t, err)

	mcpHandler, err := vmcpServer.Handler(context.Background())
	require.NoError(t, err)

	httpSrv := httptest.NewServer(mcpHandler)
	t.Cleanup(httpSrv.Close)

	return httpSrv
}

// postMCP sends a JSON-RPC POST to /mcp and returns the response.
func postMCP(t *testing.T, baseURL string, body map[string]any, sessionID string) *http.Response {
t.Helper()
Expand Down Expand Up @@ -474,3 +523,72 @@ func TestIntegration_SessionManagementV2_OldPathUnused(t *testing.T) {
"MakeSessionWithID should NOT be called when SessionManagementV2 is false",
)
}

// TestIntegration_SessionManagementV2_SessionLimitMiddleware verifies that the
// global session cap (MaxSessions) is enforced end-to-end: once the cap is
// reached every new initialize request gets HTTP 503 with a Retry-After header
// and a JSON error body, while existing sessions are unaffected.
func TestIntegration_SessionManagementV2_SessionLimitMiddleware(t *testing.T) {
t.Parallel()

const maxSessions = 2

factory := newV2FakeFactory([]vmcp.Tool{{Name: "noop"}})
ts := buildV2ServerWithLimits(t, factory, maxSessions)

initReq := map[string]any{
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": map[string]any{
"protocolVersion": "2025-06-18",
"capabilities": map[string]any{},
"clientInfo": map[string]any{"name": "test", "version": "1.0"},
},
}

// Fill the pool to exactly MaxSessions.
sessionIDs := make([]string, maxSessions)
for i := range maxSessions {
resp := postMCP(t, ts.URL, initReq, "")
defer resp.Body.Close() //nolint:gocritic // deferred inside loop is intentional for test cleanup
require.Equal(t, http.StatusOK, resp.StatusCode, "session %d should succeed", i+1)
id := resp.Header.Get("Mcp-Session-Id")
require.NotEmpty(t, id, "session %d should return a session ID", i+1)
sessionIDs[i] = id
}

// The next initialize request must be rejected with 503.
overResp := postMCP(t, ts.URL, initReq, "")
defer overResp.Body.Close()

assert.Equal(t, http.StatusServiceUnavailable, overResp.StatusCode,
"initialize beyond MaxSessions must return 503")
assert.NotEmpty(t, overResp.Header.Get("Retry-After"),
"503 response must include Retry-After header")
assert.Equal(t, "application/json", overResp.Header.Get("Content-Type"))

var errBody struct {
Error struct {
Code int `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
require.NoError(t, json.NewDecoder(overResp.Body).Decode(&errBody))
assert.Equal(t, -32000, errBody.Error.Code)
assert.NotEmpty(t, errBody.Error.Message)

// Existing sessions must still be valid (DELETE returns 200, not 404/503).
for _, id := range sessionIDs {
req, err := http.NewRequestWithContext(
context.Background(), http.MethodDelete, ts.URL+"/mcp", http.NoBody,
)
require.NoError(t, err)
req.Header.Set("Mcp-Session-Id", id)
delResp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
delResp.Body.Close()
assert.Equal(t, http.StatusOK, delResp.StatusCode,
"existing session %s should still be terminable after cap is hit", id)
}
}
Loading
Loading