diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index b793534595..1e04032955 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -16,6 +16,7 @@ import ( "log/slog" "net" "net/http" + "strconv" "strings" "sync" "time" @@ -67,6 +68,13 @@ const ( // defaultSessionTTL is the default session time-to-live duration. // Sessions that are inactive for this duration will be automatically cleaned up. defaultSessionTTL = 30 * time.Minute + + // defaultIdleCheckInterval is how often the idle reaper scans for inactive sessions. + defaultIdleCheckInterval = time.Minute + + // defaultRetryAfterSeconds is the Retry-After value returned with HTTP 503 + // when the global session limit is reached. + defaultRetryAfterSeconds = 30 ) //go:generate mockgen -destination=mocks/mock_watcher.go -package=mocks -source=server.go Watcher @@ -160,6 +168,21 @@ type Config struct { // SessionFactory creates MultiSessions for Phase 2 session management. // Required when SessionManagementV2 is true; ignored otherwise. SessionFactory vmcpsession.MultiSessionFactory + + // MaxSessions is the global concurrent session limit when SessionManagementV2 is enabled. + // Requests that would exceed this limit receive HTTP 503 with a Retry-After header. + // 0 uses the default (100). Requires SessionManagementV2 = true. + MaxSessions int + + // MaxSessionsPerClient is the per-identity session limit when SessionManagementV2 is enabled. + // Keyed by auth.Identity.Subject; anonymous clients are not limited. + // 0 uses the default (10). Requires SessionManagementV2 = true. + MaxSessionsPerClient int + + // IdleSessionTimeout is the duration after which inactive sessions are proactively + // expired when SessionManagementV2 is enabled. Must be ≤ SessionTTL. + // 0 uses the default (5 minutes). Requires SessionManagementV2 = true. + IdleSessionTimeout time.Duration } // Server is the Virtual MCP Server that aggregates multiple backends. 
@@ -277,6 +300,24 @@ func New(
 	if cfg.SessionTTL == 0 {
 		cfg.SessionTTL = defaultSessionTTL
 	}
+	if cfg.MaxSessions == 0 {
+		cfg.MaxSessions = sessionmanager.DefaultMaxSessions
+	}
+	if cfg.MaxSessionsPerClient == 0 {
+		cfg.MaxSessionsPerClient = sessionmanager.DefaultMaxSessionsPerClient
+	}
+	if cfg.IdleSessionTimeout == 0 {
+		cfg.IdleSessionTimeout = sessionmanager.DefaultIdleSessionTimeout
+	}
+	// IdleSessionTimeout must not exceed SessionTTL: if it did, the transport
+	// TTL reaper could evict sessions before the idle reaper fires, leaving
+	// per-client counters and idle-tracking maps stale.
+	if cfg.IdleSessionTimeout > cfg.SessionTTL {
+		slog.Warn("IdleSessionTimeout exceeds SessionTTL; clamping to SessionTTL",
+			"idle_session_timeout", cfg.IdleSessionTimeout,
+			"session_ttl", cfg.SessionTTL)
+		cfg.IdleSessionTimeout = cfg.SessionTTL
+	}
 
 	// Create hooks for SDK integration
 	hooks := &server.Hooks{}
@@ -400,7 +441,12 @@ func New(
 		if cfg.SessionFactory == nil {
 			return nil, fmt.Errorf("SessionManagementV2 is enabled but no SessionFactory was provided")
 		}
-		vmcpSessMgr = sessionmanager.New(sessionManager, cfg.SessionFactory, backendRegistry)
+		limits := sessionmanager.Limits{
+			MaxSessions:          cfg.MaxSessions,
+			MaxSessionsPerClient: cfg.MaxSessionsPerClient,
+			IdleSessionTimeout:   cfg.IdleSessionTimeout,
+		}
+		vmcpSessMgr = sessionmanager.New(sessionManager, cfg.SessionFactory, backendRegistry, limits)
 		slog.Info("session-scoped backend lifecycle enabled")
 
 		// Warn about incompatible optimizer configuration and disable it
@@ -557,6 +603,13 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) {
 		slog.Info("audit middleware enabled for MCP endpoints")
	}
 
+	// Apply session limit middleware when V2 session management is active.
+	// Auth (applied below) wraps this handler and so runs FIRST in the request chain; over-limit initialize requests are still rejected before reaching the SDK, but not before auth.
+ if s.vmcpSessionMgr != nil && s.config.MaxSessions > 0 { + mcpHandler = s.sessionLimitMiddleware(mcpHandler) + slog.Info("session limit middleware enabled", "max_sessions", s.config.MaxSessions) + } + // Apply authentication middleware if configured (runs first in chain) if s.config.AuthMiddleware != nil { mcpHandler = s.config.AuthMiddleware(mcpHandler) @@ -575,6 +628,37 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) { return mux, nil } +// sessionLimitMiddleware is a best-effort fast-fail for new session requests +// (no Mcp-Session-Id header): it returns HTTP 503 + Retry-After before the +// request reaches the SDK when the global session cap appears to be reached. +// Existing sessions (with a valid Mcp-Session-Id) are never affected. +// +// This check is intentionally optimistic (non-atomic): it avoids the overhead +// of routing and SDK processing for clearly-over-limit requests, but it does +// not guarantee strict enforcement under concurrent load. Strict enforcement +// is provided atomically by sessionmanager.Manager.Generate(), which uses an +// increment-first reservation to prevent races between concurrent initialize +// requests. +func (s *Server) sessionLimitMiddleware(next http.Handler) http.Handler { + // Resolve the concrete manager once so we can call ActiveSessionCount(). + mgr, _ := s.vmcpSessionMgr.(*sessionmanager.Manager) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Mcp-Session-Id") == "" && mgr != nil { + if mgr.ActiveSessionCount() >= s.config.MaxSessions { + w.Header().Set("Retry-After", strconv.Itoa(defaultRetryAfterSeconds)) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte( + `{"error":{"code":-32000,"message":"Maximum concurrent sessions exceeded. 
` + + `Please try again later or contact administrator."}}`, + )) + return + } + } + next.ServeHTTP(w, r) + }) +} + // Start starts the Virtual MCP Server and begins serving requests. // //nolint:gocyclo // Complexity from health monitoring and startup orchestration is acceptable @@ -667,6 +751,19 @@ func (s *Server) Start(ctx context.Context) error { } } + // Start idle session reaper if V2 session management is active with an idle timeout. + if mgr, ok := s.vmcpSessionMgr.(*sessionmanager.Manager); ok && s.config.IdleSessionTimeout > 0 { + idleCtx, idleCancel := context.WithCancel(ctx) + mgr.StartIdleReaper(idleCtx, defaultIdleCheckInterval) + slog.Info("idle session reaper started", + "idle_timeout", s.config.IdleSessionTimeout, + "check_interval", defaultIdleCheckInterval) + s.shutdownFuncs = append(s.shutdownFuncs, func(context.Context) error { + idleCancel() + return nil + }) + } + // Start status reporter if configured if s.statusReporter != nil { shutdown, err := s.statusReporter.Start(ctx) diff --git a/pkg/vmcp/server/session_management_v2_integration_test.go b/pkg/vmcp/server/session_management_v2_integration_test.go index 3c487ebafd..bf72cdb210 100644 --- a/pkg/vmcp/server/session_management_v2_integration_test.go +++ b/pkg/vmcp/server/session_management_v2_integration_test.go @@ -218,6 +218,55 @@ func buildV2Server( return ts } +// buildV2ServerWithLimits is like buildV2Server but accepts an explicit MaxSessions cap. 
+func buildV2ServerWithLimits( + t *testing.T, + factory vmcpsession.MultiSessionFactory, + maxSessions int, +) *httptest.Server { + t.Helper() + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockBackendClient := mocks.NewMockBackendClient(ctrl) + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) + + emptyAggCaps := &aggregator.AggregatedCapabilities{} + mockBackendRegistry.EXPECT().List(gomock.Any()).Return(nil).AnyTimes() + mockDiscoveryMgr.EXPECT().Discover(gomock.Any(), gomock.Any()).Return(emptyAggCaps, nil).AnyTimes() + mockDiscoveryMgr.EXPECT().Stop().AnyTimes() + + rt := router.NewDefaultRouter() + + srv, err := server.New( + context.Background(), + &server.Config{ + Host: "127.0.0.1", + Port: 0, + SessionTTL: 5 * time.Minute, + SessionManagementV2: true, + SessionFactory: factory, + MaxSessions: maxSessions, + }, + rt, + mockBackendClient, + mockDiscoveryMgr, + mockBackendRegistry, + nil, + ) + require.NoError(t, err) + + handler, err := srv.Handler(context.Background()) + require.NoError(t, err) + + ts := httptest.NewServer(handler) + t.Cleanup(ts.Close) + + return ts +} + // postMCP sends a JSON-RPC POST to /mcp and returns the response. func postMCP(t *testing.T, baseURL string, body map[string]any, sessionID string) *http.Response { t.Helper() @@ -474,3 +523,72 @@ func TestIntegration_SessionManagementV2_OldPathUnused(t *testing.T) { "MakeSessionWithID should NOT be called when SessionManagementV2 is false", ) } + +// TestIntegration_SessionManagementV2_SessionLimitMiddleware verifies that the +// global session cap (MaxSessions) is enforced end-to-end: once the cap is +// reached every new initialize request gets HTTP 503 with a Retry-After header +// and a JSON error body, while existing sessions are unaffected. 
+func TestIntegration_SessionManagementV2_SessionLimitMiddleware(t *testing.T) { + t.Parallel() + + const maxSessions = 2 + + factory := newV2FakeFactory([]vmcp.Tool{{Name: "noop"}}) + ts := buildV2ServerWithLimits(t, factory, maxSessions) + + initReq := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2025-06-18", + "capabilities": map[string]any{}, + "clientInfo": map[string]any{"name": "test", "version": "1.0"}, + }, + } + + // Fill the pool to exactly MaxSessions. + sessionIDs := make([]string, maxSessions) + for i := range maxSessions { + resp := postMCP(t, ts.URL, initReq, "") + defer resp.Body.Close() //nolint:gocritic // deferred inside loop is intentional for test cleanup + require.Equal(t, http.StatusOK, resp.StatusCode, "session %d should succeed", i+1) + id := resp.Header.Get("Mcp-Session-Id") + require.NotEmpty(t, id, "session %d should return a session ID", i+1) + sessionIDs[i] = id + } + + // The next initialize request must be rejected with 503. + overResp := postMCP(t, ts.URL, initReq, "") + defer overResp.Body.Close() + + assert.Equal(t, http.StatusServiceUnavailable, overResp.StatusCode, + "initialize beyond MaxSessions must return 503") + assert.NotEmpty(t, overResp.Header.Get("Retry-After"), + "503 response must include Retry-After header") + assert.Equal(t, "application/json", overResp.Header.Get("Content-Type")) + + var errBody struct { + Error struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` + } + require.NoError(t, json.NewDecoder(overResp.Body).Decode(&errBody)) + assert.Equal(t, -32000, errBody.Error.Code) + assert.NotEmpty(t, errBody.Error.Message) + + // Existing sessions must still be valid (DELETE returns 200, not 404/503). 
+ for _, id := range sessionIDs { + req, err := http.NewRequestWithContext( + context.Background(), http.MethodDelete, ts.URL+"/mcp", http.NoBody, + ) + require.NoError(t, err) + req.Header.Set("Mcp-Session-Id", id) + delResp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + delResp.Body.Close() + assert.Equal(t, http.StatusOK, delResp.StatusCode, + "existing session %s should still be terminable after cap is hit", id) + } +} diff --git a/pkg/vmcp/server/sessionmanager/session_manager.go b/pkg/vmcp/server/sessionmanager/session_manager.go index f41ec991f0..31b66d0f91 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager.go +++ b/pkg/vmcp/server/sessionmanager/session_manager.go @@ -18,6 +18,9 @@ import ( "errors" "fmt" "log/slog" + "sync" + "sync/atomic" + "time" "github.com/google/uuid" "github.com/mark3labs/mcp-go/mcp" @@ -39,8 +42,47 @@ const ( // MetadataValTrue is the string value stored under MetadataKeyTerminated // when a session has been terminated. MetadataValTrue = "true" + + // DefaultMaxSessions is the default global concurrent session limit. + // 0 disables the global limit. + DefaultMaxSessions = 100 + + // DefaultMaxSessionsPerClient is the default per-identity session limit. + // 0 disables the per-client limit. + DefaultMaxSessionsPerClient = 10 + + // DefaultIdleSessionTimeout is the default duration after which inactive + // sessions are proactively expired. Must be ≤ session TTL. + // 0 disables idle expiry. + DefaultIdleSessionTimeout = 5 * time.Minute + + // defaultIdleCheckInterval is how often the idle reaper scans for idle sessions. + defaultIdleCheckInterval = time.Minute ) +// ErrSessionLimitReached is returned when the global session limit is hit. +var ErrSessionLimitReached = errors.New("maximum concurrent sessions exceeded") + +// ErrPerClientSessionLimitReached is returned when the per-client session limit is hit. 
+var ErrPerClientSessionLimitReached = errors.New("maximum sessions per client exceeded") + +// Limits configures resource-exhaustion protections for the Manager. +type Limits struct { + // MaxSessions is the maximum number of concurrent sessions globally. + // 0 means unlimited. + MaxSessions int + + // MaxSessionsPerClient is the maximum concurrent sessions per client identity, + // keyed by auth.Identity.Subject. Anonymous clients (no Subject) are not limited. + // 0 means unlimited. + MaxSessionsPerClient int + + // IdleSessionTimeout is the maximum duration a session may be inactive + // before it is proactively expired. Must be ≤ the session TTL. + // 0 disables idle expiry. + IdleSessionTimeout time.Duration +} + // Manager bridges the domain session lifecycle (MultiSession / MultiSessionFactory) // to the mark3labs SDK's SessionIdManager interface. // @@ -68,19 +110,40 @@ type Manager struct { storage *transportsession.Manager factory vmcpsession.MultiSessionFactory backendRegistry vmcp.BackendRegistry + limits Limits + + // perClientMu guards perClientCounts and sessionSubject. + perClientMu sync.Mutex + perClientCounts map[string]int // subject → active session count + sessionSubject map[string]string // sessionID → subject (for decrement on Terminate) + + // idleActivityMu guards idleActivity. + idleActivityMu sync.RWMutex + idleActivity map[string]time.Time // sessionID → last active time + + // activeSessionCount tracks sessions that have been generated but not yet + // terminated, excluding terminated placeholders left for TTL cleanup. + // This gives an accurate count for global limit enforcement, unlike + // storage.Count() which includes those terminated placeholders. + activeSessionCount atomic.Int64 } -// New creates a Manager backed by the given transport -// manager, session factory, and backend registry. +// New creates a Manager backed by the given transport manager, session factory, +// backend registry, and resource-exhaustion limits. 
func New( storage *transportsession.Manager, factory vmcpsession.MultiSessionFactory, backendRegistry vmcp.BackendRegistry, + limits Limits, ) *Manager { return &Manager{ storage: storage, factory: factory, backendRegistry: backendRegistry, + limits: limits, + perClientCounts: make(map[string]int), + sessionSubject: make(map[string]string), + idleActivity: make(map[string]time.Time), } } @@ -93,6 +156,22 @@ func New( // The placeholder is replaced by CreateSession() in Phase 2 once context // is available via the OnRegisterSession hook. func (sm *Manager) Generate() string { + // Atomically claim a slot before allocating storage. Incrementing first + // (rather than Load → check → Add) eliminates the TOCTOU race where + // concurrent initialize requests all observe Count < MaxSessions and all + // proceed past the cap. If the incremented value exceeds the cap, or if + // storage allocation fails, the slot is released immediately. + if sm.limits.MaxSessions > 0 { + if int(sm.activeSessionCount.Add(1)) > sm.limits.MaxSessions { + sm.activeSessionCount.Add(-1) + slog.Warn("Manager: session limit reached, rejecting new session", + "active", sm.activeSessionCount.Load(), + "max", sm.limits.MaxSessions, + "error", ErrSessionLimitReached) + return "" + } + } + sessionID := uuid.New().String() if err := sm.storage.AddWithID(sessionID); err != nil { @@ -101,10 +180,17 @@ func (sm *Manager) Generate() string { sessionID = uuid.New().String() if err := sm.storage.AddWithID(sessionID); err != nil { slog.Error("Manager: failed to store placeholder session on retry", "session_id", sessionID, "error", err) + if sm.limits.MaxSessions > 0 { + sm.activeSessionCount.Add(-1) + } return "" } } + if sm.limits.MaxSessions <= 0 { + // Unlimited: count is not pre-incremented above, so increment here. 
+ sm.activeSessionCount.Add(1) + } slog.Debug("Manager: generated placeholder session", "session_id", sessionID) return sessionID } @@ -151,6 +237,12 @@ func (sm *Manager) CreateSession( // Resolve the caller identity (may be nil for anonymous access). identity, _ := auth.IdentityFromContext(ctx) + // Enforce per-client session limit for identified callers. + perClientIncremented, err := sm.enforcePerClientLimit(sessionID, identity) + if err != nil { + return nil, err + } + // Note: Token hash and salt are computed and stored by the session factory // (MakeSessionWithID below). Token binding enforcement happens at the session // level via validateCaller(), which uses HMAC-SHA256 with a per-session salt. @@ -168,6 +260,9 @@ func (sm *Manager) CreateSession( allowAnonymous := vmcpsession.ShouldAllowAnonymous(identity) sess, err := sm.factory.MakeSessionWithID(ctx, sessionID, identity, allowAnonymous, backends) if err != nil { + if perClientIncremented { + sm.decrementPerClientCount(sessionID) + } return nil, fmt.Errorf("Manager.CreateSession: failed to create multi-session: %w", err) } @@ -180,6 +275,9 @@ func (sm *Manager) CreateSession( placeholder2, exists := sm.storage.Get(sessionID) if !exists { _ = sess.Close() + if perClientIncremented { + sm.decrementPerClientCount(sessionID) + } return nil, fmt.Errorf( "Manager.CreateSession: placeholder for session %q disappeared during backend init (terminated concurrently)", sessionID, @@ -187,6 +285,9 @@ func (sm *Manager) CreateSession( } if placeholder2.GetMetadata()[MetadataKeyTerminated] == MetadataValTrue { _ = sess.Close() + if perClientIncremented { + sm.decrementPerClientCount(sessionID) + } return nil, fmt.Errorf( "Manager.CreateSession: session %q was terminated during backend init (marked after first check)", sessionID, @@ -200,9 +301,15 @@ func (sm *Manager) CreateSession( if err := sm.storage.UpsertSession(sess); err != nil { // Best-effort close of the newly created session to release backend connections. 
 		_ = sess.Close()
+		if perClientIncremented {
+			sm.decrementPerClientCount(sessionID)
+		}
 		return nil, fmt.Errorf("Manager.CreateSession: failed to replace placeholder: %w", err)
 	}
 
+	// Session is fully established — start the idle clock.
+	sm.resetIdleActivity(sessionID)
+
 	slog.Debug("Manager: created multi-session",
 		"session_id", sessionID,
 		"backend_count", len(backends))
@@ -263,6 +370,11 @@ func (sm *Manager) Terminate(sessionID string) (isNotAllowed bool, err error) {
 	sess, exists := sm.storage.Get(sessionID)
 	if !exists {
 		slog.Debug("Manager.Terminate: session not found (already expired?)", "session_id", sessionID)
+		// Storage entry already gone (e.g. TTL cleanup raced Terminate); drop
+		// leftover per-client/idle map entries. NOTE(review): activeSessionCount
+		// is not decremented by TTL cleanup, so TTL-expired sessions can leak it.
+		sm.decrementPerClientCount(sessionID)
+		sm.removeIdleActivity(sessionID)
 		return false, nil
 	}
 
@@ -276,6 +388,9 @@
 		if deleteErr := sm.storage.Delete(sessionID); deleteErr != nil {
 			return false, fmt.Errorf("Manager.Terminate: failed to delete session from storage: %w", deleteErr)
 		}
+		sm.activeSessionCount.Add(-1)
+		sm.decrementPerClientCount(sessionID)
+		sm.removeIdleActivity(sessionID)
 	} else {
 		// Placeholder session (not yet upgraded to MultiSession).
 		//
@@ -294,6 +409,7 @@
 		// We mark (not delete) so Validate() can return isTerminated=true, which
 		// lets the SDK distinguish "actively terminated" from "never existed".
 		// TTL cleanup will remove the placeholder later.
+ sm.activeSessionCount.Add(-1) sess.SetMetadata(MetadataKeyTerminated, MetadataValTrue) if replaceErr := sm.storage.UpsertSession(sess); replaceErr != nil { slog.Warn("Manager.Terminate: failed to persist terminated flag for placeholder; attempting delete fallback", @@ -385,6 +501,10 @@ func (sm *Manager) GetAdaptedTools(sessionID string) ([]mcpserver.ServerTool, er return mcp.NewToolResultError(callErr.Error()), nil } + // Reset idle clock after the tool call completes so long-running tools + // are not reaped mid-execution by the idle reaper. + sm.resetIdleActivity(capturedSessionID) + return &mcp.CallToolResult{ Result: mcp.Result{ Meta: conversion.ToMCPMeta(result.Meta), @@ -404,3 +524,134 @@ func (sm *Manager) GetAdaptedTools(sessionID string) ([]mcpserver.ServerTool, er return sdkTools, nil } + +// ActiveSessionCount returns the number of sessions that have been generated +// but not yet terminated. Unlike storage.Count(), this excludes terminated +// placeholders left in storage for TTL cleanup, giving an accurate measure +// for global session limit enforcement. +func (sm *Manager) ActiveSessionCount() int { + return int(sm.activeSessionCount.Load()) +} + +// --------------------------------------------------------------------------- +// Per-client session limit helpers +// --------------------------------------------------------------------------- + +// enforcePerClientLimit checks and increments the per-client session count for the +// given identity. Returns (true, nil) when the count was incremented, (false, nil) +// for anonymous sessions (not subject to limiting), and (false, err) when the limit +// is exceeded. The caller must call decrementPerClientCount on any failure path when +// the returned bool is true. 
+func (sm *Manager) enforcePerClientLimit(sessionID string, identity *auth.Identity) (bool, error) { + subject := identitySubject(identity) + if sm.limits.MaxSessionsPerClient <= 0 || subject == "" { + return false, nil + } + sm.perClientMu.Lock() + defer sm.perClientMu.Unlock() + if sm.perClientCounts[subject] >= sm.limits.MaxSessionsPerClient { + return false, fmt.Errorf("%w: subject %q", ErrPerClientSessionLimitReached, subject) + } + sm.perClientCounts[subject]++ + sm.sessionSubject[sessionID] = subject + return true, nil +} + +// decrementPerClientCount removes the per-client counter entry for sessionID. +// It is safe to call even if the session was never counted (anonymous sessions). +func (sm *Manager) decrementPerClientCount(sessionID string) { + sm.perClientMu.Lock() + defer sm.perClientMu.Unlock() + subject, ok := sm.sessionSubject[sessionID] + if !ok { + return + } + delete(sm.sessionSubject, sessionID) + if sm.perClientCounts[subject] > 0 { + sm.perClientCounts[subject]-- + } +} + +// identitySubject returns the Subject claim for identity-based rate limiting. +// Returns "" for nil identities or identities without a Subject, which opts +// them out of per-client limiting. +func identitySubject(identity *auth.Identity) string { + if identity == nil { + return "" + } + return identity.Subject +} + +// --------------------------------------------------------------------------- +// Idle session timeout helpers +// --------------------------------------------------------------------------- + +// resetIdleActivity records the current time as the last-active timestamp for +// sessionID. Called on session creation and on every tool call. +// No-op when IdleSessionTimeout is zero (idle tracking disabled). 
+func (sm *Manager) resetIdleActivity(sessionID string) { + if sm.limits.IdleSessionTimeout <= 0 { + return + } + sm.idleActivityMu.Lock() + sm.idleActivity[sessionID] = time.Now() + sm.idleActivityMu.Unlock() +} + +// removeIdleActivity removes the idle-tracking entry for sessionID. +// Called from Terminate() so the reaper does not attempt to re-terminate. +func (sm *Manager) removeIdleActivity(sessionID string) { + sm.idleActivityMu.Lock() + delete(sm.idleActivity, sessionID) + sm.idleActivityMu.Unlock() +} + +// reapIdleSessions terminates any sessions that have been inactive longer than +// the configured IdleSessionTimeout. +func (sm *Manager) reapIdleSessions() { + cutoff := time.Now().Add(-sm.limits.IdleSessionTimeout) + + sm.idleActivityMu.RLock() + var toTerminate []string + for sessionID, lastActive := range sm.idleActivity { + if lastActive.Before(cutoff) { + toTerminate = append(toTerminate, sessionID) + } + } + sm.idleActivityMu.RUnlock() + + for _, sessionID := range toTerminate { + slog.Info("Manager: terminating idle session", + "session_id", sessionID, + "idle_timeout", sm.limits.IdleSessionTimeout) + if _, err := sm.Terminate(sessionID); err != nil { + slog.Warn("Manager: failed to terminate idle session", + "session_id", sessionID, "error", err) + } + } +} + +// StartIdleReaper starts a background goroutine that periodically calls +// reapIdleSessions. It is a no-op when IdleSessionTimeout is zero (disabled). +// The goroutine is stopped when ctx is cancelled; the caller should add a +// cancel to shutdownFuncs to ensure cleanup on server Stop(). 
+func (sm *Manager) StartIdleReaper(ctx context.Context, interval time.Duration) { + if sm.limits.IdleSessionTimeout <= 0 { + return + } + if interval <= 0 { + interval = defaultIdleCheckInterval + } + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + sm.reapIdleSessions() + } + } + }() +} diff --git a/pkg/vmcp/server/sessionmanager/session_manager_test.go b/pkg/vmcp/server/sessionmanager/session_manager_test.go index 3f891c6293..7ca68a1d95 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager_test.go +++ b/pkg/vmcp/server/sessionmanager/session_manager_test.go @@ -204,7 +204,7 @@ func newTestTransportManager(t *testing.T) *transportsession.Manager { return mgr } -// newTestVMCPSessionManager is a convenience constructor for tests. +// newTestVMCPSessionManager is a convenience constructor for tests using default (zero) Limits. func newTestVMCPSessionManager( t *testing.T, factory vmcpsession.MultiSessionFactory, @@ -212,7 +212,19 @@ func newTestVMCPSessionManager( ) (*Manager, *transportsession.Manager) { t.Helper() storage := newTestTransportManager(t) - return New(storage, factory, registry), storage + return New(storage, factory, registry, Limits{}), storage +} + +// newTestVMCPSessionManagerWithLimits creates a Manager with explicit resource limits. 
+func newTestVMCPSessionManagerWithLimits( + t *testing.T, + factory vmcpsession.MultiSessionFactory, + registry vmcp.BackendRegistry, + limits Limits, +) (*Manager, *transportsession.Manager) { + t.Helper() + storage := newTestTransportManager(t) + return New(storage, factory, registry, limits), storage } // --------------------------------------------------------------------------- @@ -252,7 +264,7 @@ func TestVMCPSessionManager_Generate(t *testing.T) { t.Cleanup(func() { _ = failingMgr.Stop() }) factory := newFakeFactory(nil) - sm := New(failingMgr, factory, newFakeRegistry()) + sm := New(failingMgr, factory, newFakeRegistry(), Limits{}) id := sm.Generate() assert.Empty(t, id, "Generate() should return '' when storage is unavailable") @@ -634,7 +646,7 @@ func TestVMCPSessionManager_Terminate(t *testing.T) { failingStorage, ) t.Cleanup(func() { _ = storage.Stop() }) - sm := New(storage, factory, registry) + sm := New(storage, factory, registry, Limits{}) // Generate a placeholder (first Store, succeeds). sessionID := sm.Generate() @@ -671,7 +683,7 @@ func TestVMCPSessionManager_Terminate(t *testing.T) { failingStorage, ) t.Cleanup(func() { _ = storage.Stop() }) - sm := New(storage, factory, registry) + sm := New(storage, factory, registry, Limits{}) // Generate a placeholder (first Store, succeeds). 
sessionID := sm.Generate() @@ -1010,3 +1022,323 @@ func newCallToolRequest(name string, args map[string]any) mcp.CallToolRequest { req.Params.Arguments = args return req } + +// --------------------------------------------------------------------------- +// Tests: per-client session limit +// --------------------------------------------------------------------------- + +func TestVMCPSessionManager_PerClientLimit(t *testing.T) { + t.Parallel() + + t.Run("allows sessions up to MaxSessionsPerClient", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory(nil) + registry := newFakeRegistry() + sm, _ := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessionsPerClient: 2}) + + identity := &auth.Identity{Subject: "user-1"} + ctx := auth.WithIdentity(context.Background(), identity) + + id1 := sm.Generate() + _, err := sm.CreateSession(ctx, id1) + require.NoError(t, err) + + id2 := sm.Generate() + _, err = sm.CreateSession(ctx, id2) + require.NoError(t, err) + }) + + t.Run("rejects session when per-client limit reached", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory(nil) + registry := newFakeRegistry() + sm, _ := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessionsPerClient: 1}) + + identity := &auth.Identity{Subject: "user-1"} + ctx := auth.WithIdentity(context.Background(), identity) + + id1 := sm.Generate() + _, err := sm.CreateSession(ctx, id1) + require.NoError(t, err) + + id2 := sm.Generate() + _, err = sm.CreateSession(ctx, id2) + require.ErrorIs(t, err, ErrPerClientSessionLimitReached) + }) + + t.Run("count is decremented after Terminate, allowing new session", func(t *testing.T) { + t.Parallel() + + factory := newFakeFactory(nil) + registry := newFakeRegistry() + sm, _ := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessionsPerClient: 1}) + + identity := &auth.Identity{Subject: "user-1"} + ctx := auth.WithIdentity(context.Background(), identity) + + id1 := 
sm.Generate()
+		_, err := sm.CreateSession(ctx, id1)
+		require.NoError(t, err)
+
+		_, err = sm.Terminate(id1)
+		require.NoError(t, err)
+
+		// Terminating id1 must release its per-client slot, so a fresh session
+		// for the same subject fits under the limit again.
+		id2 := sm.Generate()
+		_, err = sm.CreateSession(ctx, id2)
+		require.NoError(t, err, "should allow new session after previous was terminated")
+	})
+
+	t.Run("anonymous sessions (no Subject) are not limited", func(t *testing.T) {
+		t.Parallel()
+
+		factory := newFakeFactory(nil)
+		registry := newFakeRegistry()
+		sm, _ := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessionsPerClient: 1})
+
+		// No identity in context → anonymous.
+		ctx := context.Background()
+
+		id1 := sm.Generate()
+		_, err := sm.CreateSession(ctx, id1)
+		require.NoError(t, err)
+
+		// A second anonymous session exceeds MaxSessionsPerClient=1 but must
+		// still be admitted, since the per-client limit keys on Identity.Subject.
+		id2 := sm.Generate()
+		_, err = sm.CreateSession(ctx, id2)
+		require.NoError(t, err, "anonymous sessions should not be subject to per-client limit")
+	})
+
+	t.Run("different subjects have independent counts", func(t *testing.T) {
+		t.Parallel()
+
+		factory := newFakeFactory(nil)
+		registry := newFakeRegistry()
+		sm, _ := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessionsPerClient: 1})
+
+		ctx1 := auth.WithIdentity(context.Background(), &auth.Identity{Subject: "user-a"})
+		ctx2 := auth.WithIdentity(context.Background(), &auth.Identity{Subject: "user-b"})
+
+		idA := sm.Generate()
+		_, err := sm.CreateSession(ctx1, idA)
+		require.NoError(t, err)
+
+		idB := sm.Generate()
+		_, err = sm.CreateSession(ctx2, idB)
+		require.NoError(t, err, "user-b should have its own independent count")
+	})
+}
+
+// ---------------------------------------------------------------------------
+// Tests: idle session reaper
+// ---------------------------------------------------------------------------
+
+// TestVMCPSessionManager_IdleReaper verifies reapIdleSessions: sessions whose
+// last-activity timestamp is older than Limits.IdleSessionTimeout are removed
+// from storage, recently active sessions survive, and a zero timeout disables
+// idle tracking entirely (empty idle map, reaper is a no-op).
+func TestVMCPSessionManager_IdleReaper(t *testing.T) {
+	t.Parallel()
+
+	t.Run("terminates session that exceeds idle timeout", func(t *testing.T) {
+		t.Parallel()
+
+		factory := newFakeFactory(nil)
+		registry := newFakeRegistry()
+		idleTimeout := 5 * time.Minute
+		sm, storage := newTestVMCPSessionManagerWithLimits(t, factory, registry,
+			Limits{IdleSessionTimeout: idleTimeout})
+
+		ctx := context.Background()
+		id := sm.Generate()
+		_, err := sm.CreateSession(ctx, id)
+		require.NoError(t, err)
+
+		// Back-date the idle timestamp so the session appears past the timeout
+		// without any real sleep, making the test immune to CI scheduling jitter.
+		sm.idleActivityMu.Lock()
+		sm.idleActivity[id] = time.Now().Add(-(idleTimeout + time.Second))
+		sm.idleActivityMu.Unlock()
+
+		sm.reapIdleSessions()
+
+		_, exists := storage.Get(id)
+		assert.False(t, exists, "idle session should have been reaped")
+	})
+
+	t.Run("does not terminate session active within timeout", func(t *testing.T) {
+		t.Parallel()
+
+		factory := newFakeFactory([]vmcp.Tool{{Name: "noop"}})
+		registry := newFakeRegistry()
+		idleTimeout := 200 * time.Millisecond
+		sm, storage := newTestVMCPSessionManagerWithLimits(t, factory, registry,
+			Limits{IdleSessionTimeout: idleTimeout})
+
+		ctx := context.Background()
+		id := sm.Generate()
+		_, err := sm.CreateSession(ctx, id)
+		require.NoError(t, err)
+
+		// Simulate activity by resetting the idle clock.
+		sm.resetIdleActivity(id)
+
+		// Reap immediately — session should survive.
+		sm.reapIdleSessions()
+
+		_, exists := storage.Get(id)
+		assert.True(t, exists, "recently active session should not be reaped")
+	})
+
+	t.Run("reaper is no-op when IdleSessionTimeout is zero", func(t *testing.T) {
+		t.Parallel()
+
+		factory := newFakeFactory(nil)
+		registry := newFakeRegistry()
+		sm, storage := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{})
+
+		ctx := context.Background()
+		id := sm.Generate()
+		_, err := sm.CreateSession(ctx, id)
+		require.NoError(t, err)
+
+		// Idle map should be empty when timeout is disabled.
+		sm.idleActivityMu.RLock()
+		idleCount := len(sm.idleActivity)
+		sm.idleActivityMu.RUnlock()
+		assert.Equal(t, 0, idleCount, "idle map should be empty when timeout is disabled")
+
+		sm.reapIdleSessions() // should not panic or touch storage
+
+		_, exists := storage.Get(id)
+		assert.True(t, exists, "session should still exist when idle reaper is disabled")
+	})
+}
+
+// ---------------------------------------------------------------------------
+// Tests: ActiveSessionCount / global limit accuracy
+// ---------------------------------------------------------------------------
+
+// TestVMCPSessionManager_ActiveSessionCount verifies the global session
+// accounting: Generate increments the active count (and returns "" once
+// Limits.MaxSessions is reached), Terminate decrements it for both real
+// MultiSessions and placeholders, and failed/TTL-evicted sessions do not
+// leak entries in the per-client or idle-tracking maps.
+func TestVMCPSessionManager_ActiveSessionCount(t *testing.T) {
+	t.Parallel()
+
+	t.Run("increments on Generate and decrements on Terminate for MultiSession", func(t *testing.T) {
+		t.Parallel()
+
+		factory := newFakeFactory(nil)
+		registry := newFakeRegistry()
+		sm, _ := newTestVMCPSessionManager(t, factory, registry)
+
+		assert.Equal(t, 0, sm.ActiveSessionCount())
+
+		id := sm.Generate()
+		assert.Equal(t, 1, sm.ActiveSessionCount())
+
+		_, err := sm.CreateSession(context.Background(), id)
+		require.NoError(t, err)
+		assert.Equal(t, 1, sm.ActiveSessionCount(), "CreateSession should not change the count")
+
+		_, err = sm.Terminate(id)
+		require.NoError(t, err)
+		assert.Equal(t, 0, sm.ActiveSessionCount())
+	})
+
+	t.Run("decrements on Terminate for placeholder (terminated but not deleted)", func(t *testing.T) {
+		t.Parallel()
+
+		factory := newFakeFactory(nil)
+		registry := newFakeRegistry()
+		sm, _ := newTestVMCPSessionManager(t, factory, registry)
+
+		id := sm.Generate()
+		assert.Equal(t, 1, sm.ActiveSessionCount())
+
+		// Terminate the placeholder before CreateSession — it is marked terminated,
+		// not deleted, but the active count must still drop.
+		_, err := sm.Terminate(id)
+		require.NoError(t, err)
+		assert.Equal(t, 0, sm.ActiveSessionCount(),
+			"terminated placeholder must not count towards active sessions")
+	})
+
+	t.Run("Generate returns empty string and does not increment count when global limit reached", func(t *testing.T) {
+		t.Parallel()
+
+		factory := newFakeFactory(nil)
+		registry := newFakeRegistry()
+		sm, _ := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessions: 1})
+
+		// First generate succeeds; active count becomes 1 (== MaxSessions).
+		id := sm.Generate()
+		require.NotEmpty(t, id)
+		assert.Equal(t, 1, sm.ActiveSessionCount())
+
+		// Second generate must be rejected because the limit is reached.
+		id2 := sm.Generate()
+		assert.Empty(t, id2, "Generate must return empty string when global limit is reached")
+		assert.Equal(t, 1, sm.ActiveSessionCount(), "rejected Generate must not increment active count")
+	})
+
+	t.Run("rejected CreateSession (per-client limit) does not leak into active count", func(t *testing.T) {
+		t.Parallel()
+
+		factory := newFakeFactory(nil)
+		registry := newFakeRegistry()
+		sm, _ := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessionsPerClient: 1})
+
+		identity := &auth.Identity{Subject: "user-x"}
+		ctx := auth.WithIdentity(context.Background(), identity)
+
+		// First session: succeeds.
+		id1 := sm.Generate()
+		_, err := sm.CreateSession(ctx, id1)
+		require.NoError(t, err)
+		assert.Equal(t, 1, sm.ActiveSessionCount())
+
+		// Second generate: count = 2.
+		id2 := sm.Generate()
+		assert.Equal(t, 2, sm.ActiveSessionCount())
+
+		// CreateSession fails (per-client limit). The server will call Terminate(id2).
+		_, err = sm.CreateSession(ctx, id2)
+		require.ErrorIs(t, err, ErrPerClientSessionLimitReached)
+		_, _ = sm.Terminate(id2) // server-side cleanup
+
+		// Active count must return to 1 (only the first session remains).
+		assert.Equal(t, 1, sm.ActiveSessionCount(),
+			"failed registration must not permanently consume the global session budget")
+	})
+
+	t.Run("Terminate cleans up in-memory maps when storage entry already removed by TTL", func(t *testing.T) {
+		t.Parallel()
+
+		factory := newFakeFactory(nil)
+		registry := newFakeRegistry()
+		sm, storage := newTestVMCPSessionManagerWithLimits(t, factory, registry, Limits{MaxSessionsPerClient: 5})
+
+		identity := &auth.Identity{Subject: "user-ttl"}
+		ctx := auth.WithIdentity(context.Background(), identity)
+
+		id := sm.Generate()
+		require.NotEmpty(t, id)
+		_, err := sm.CreateSession(ctx, id)
+		require.NoError(t, err)
+
+		// Simulate TTL eviction by deleting directly from the transport storage,
+		// bypassing sm.Terminate() (so sessionSubject/idleActivity are NOT cleaned up yet).
+		require.NoError(t, storage.Delete(id))
+
+		// Now call Terminate() — storage.Get returns !exists. The fix must still
+		// clean up the in-memory maps so per-client counts and idle entries don't leak.
+		_, err = sm.Terminate(id)
+		require.NoError(t, err)
+
+		// Per-client count for this identity must be back to zero.
+		sm.perClientMu.Lock()
+		count := sm.perClientCounts[identity.Subject]
+		sm.perClientMu.Unlock()
+		assert.Equal(t, 0, count, "per-client count must be cleaned up even when storage entry was already gone")
+
+		// Idle activity must be removed.
+		sm.idleActivityMu.RLock()
+		_, hasIdle := sm.idleActivity[id]
+		sm.idleActivityMu.RUnlock()
+		assert.False(t, hasIdle, "idle activity entry must be cleaned up even when storage entry was already gone")
+	})
+}