diff --git a/pkg/authserver/integration_test.go b/pkg/authserver/integration_test.go index 684569997e..90bfee9bd1 100644 --- a/pkg/authserver/integration_test.go +++ b/pkg/authserver/integration_test.go @@ -23,9 +23,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stacklok/toolhive/pkg/auth/upstreamtoken" servercrypto "github.com/stacklok/toolhive/pkg/authserver/server/crypto" "github.com/stacklok/toolhive/pkg/authserver/server/keys" "github.com/stacklok/toolhive/pkg/authserver/server/registration" + "github.com/stacklok/toolhive/pkg/authserver/server/session" "github.com/stacklok/toolhive/pkg/authserver/storage" "github.com/stacklok/toolhive/pkg/authserver/upstream" ) @@ -44,6 +46,7 @@ const ( type testServer struct { Server *httptest.Server PrivateKey *rsa.PrivateKey + authServer Server } // testServerOptions configures the test server setup. @@ -188,6 +191,7 @@ func setupTestServer(t *testing.T, opts ...testServerOption) *testServer { return &testServer{ Server: httpServer, PrivateKey: privateKey, + authServer: srv, } } @@ -862,6 +866,7 @@ func setupTestServerWithOIDCProvider(t *testing.T, m *mockoidc.MockOIDC) *testSe testServer: &testServer{ Server: httpServer, PrivateKey: privateKey, + authServer: srv, }, mockOIDC: m, } @@ -1162,3 +1167,184 @@ func TestIntegration_RefreshToken_ShortLivedAccessToken(t *testing.T) { require.True(t, ok) assert.Greater(t, int64(exp), time.Now().Unix(), "refreshed token exp must be in the future") } + +// TestIntegration_UpstreamTokenService_GetValidTokens tests the UpstreamTokenService +// end-to-end: a real auth server stores upstream tokens during the OAuth callback, +// and the service retrieves them by session ID extracted from the JWT. +func TestIntegration_UpstreamTokenService_GetValidTokens(t *testing.T) { + t.Parallel() + + m := startMockOIDC(t) + ts := setupTestServerWithMockOIDC(t, m) + + verifier := servercrypto.GeneratePKCEVerifier() + challenge := servercrypto.ComputePKCEChallenge(verifier) + + // Complete the full OAuth flow — this stores upstream tokens in the auth server's storage. + authCode, _ := completeAuthorizationFlow(t, ts.Server.URL, authorizationParams{ + ClientID: testClientID, + RedirectURI: testRedirectURI, + State: "upstream-svc-test", + Challenge: challenge, + Scope: "openid profile offline_access", + ResponseType: "code", + }) + + tokenData := exchangeCodeForTokens(t, ts.Server.URL, authCode, verifier, testAudience) + + // Extract tsid from the access token JWT — this is the session ID used by storage. + accessToken, ok := tokenData["access_token"].(string) + require.True(t, ok) + tsid := extractTSID(t, accessToken, ts.PrivateKey.Public()) + + // Create the UpstreamTokenService using the auth server's storage and refresher. + // This mirrors how vMCP would compose these in production. + svc := upstreamtoken.NewInProcessService( + ts.authServer.IDPTokenStorage(), + ts.authServer.UpstreamTokenRefresher(), + ) + + // The service should return the upstream access token stored during callback. + cred, err := svc.GetValidTokens(context.Background(), tsid) + require.NoError(t, err) + require.NotNil(t, cred) + assert.NotEmpty(t, cred.AccessToken, "upstream access token should be present") +} + +// TestIntegration_UpstreamTokenService_RefreshExpiredTokens verifies the transparent +// refresh path: upstream tokens are expired in storage, and the service uses the +// refresher (backed by mockoidc) to get fresh tokens without re-authentication. +func TestIntegration_UpstreamTokenService_RefreshExpiredTokens(t *testing.T) { + t.Parallel() + + m := startMockOIDC(t) + ts := setupTestServerWithMockOIDC(t, m) + + verifier := servercrypto.GeneratePKCEVerifier() + challenge := servercrypto.ComputePKCEChallenge(verifier) + + authCode, _ := completeAuthorizationFlow(t, ts.Server.URL, authorizationParams{ + ClientID: testClientID, + RedirectURI: testRedirectURI, + State: "upstream-refresh-test", + Challenge: challenge, + Scope: "openid profile offline_access", + ResponseType: "code", + }) + + tokenData := exchangeCodeForTokens(t, ts.Server.URL, authCode, verifier, testAudience) + + accessToken, ok := tokenData["access_token"].(string) + require.True(t, ok) + tsid := extractTSID(t, accessToken, ts.PrivateKey.Public()) + + stor := ts.authServer.IDPTokenStorage() + + // Read the stored tokens, then overwrite them with an expired ExpiresAt. + original, err := stor.GetUpstreamTokens(context.Background(), tsid) + require.NoError(t, err) + require.NotNil(t, original) + originalAccessToken := original.AccessToken + + // Queue a new user for mockoidc's refresh token endpoint response. + m.QueueUser(&mockoidc.MockUser{ + Subject: "mock-user-sub-123", + Email: "testuser@example.com", + }) + + // Store tokens back with ExpiresAt in the past to simulate expiry. + expired := &storage.UpstreamTokens{ + ProviderID: original.ProviderID, + AccessToken: original.AccessToken, + RefreshToken: original.RefreshToken, + IDToken: original.IDToken, + ExpiresAt: time.Now().Add(-1 * time.Hour), + UserID: original.UserID, + UpstreamSubject: original.UpstreamSubject, + ClientID: original.ClientID, + } + require.NoError(t, stor.StoreUpstreamTokens(context.Background(), tsid, expired)) + + // The service should transparently refresh the expired tokens. + svc := upstreamtoken.NewInProcessService(stor, ts.authServer.UpstreamTokenRefresher()) + + cred, err := svc.GetValidTokens(context.Background(), tsid) + require.NoError(t, err) + require.NotNil(t, cred) + assert.NotEmpty(t, cred.AccessToken, "refreshed upstream access token should be present") + + // Verify storage was updated with non-expired tokens after refresh. + refreshed, err := stor.GetUpstreamTokens(context.Background(), tsid) + require.NoError(t, err, "refreshed tokens should be retrievable without ErrExpired") + assert.True(t, refreshed.ExpiresAt.After(time.Now()), + "refreshed tokens should have a future expiry, got %v", refreshed.ExpiresAt) + _ = originalAccessToken // used only to confirm the flow completed +} + +// TestIntegration_UpstreamTokenService_SessionNotFound verifies that the service +// returns ErrSessionNotFound for a non-existent session. +func TestIntegration_UpstreamTokenService_SessionNotFound(t *testing.T) { + t.Parallel() + + m := startMockOIDC(t) + ts := setupTestServerWithMockOIDC(t, m) + + svc := upstreamtoken.NewInProcessService( + ts.authServer.IDPTokenStorage(), + ts.authServer.UpstreamTokenRefresher(), + ) + + cred, err := svc.GetValidTokens(context.Background(), "non-existent-session-id") + require.Error(t, err) + assert.ErrorIs(t, err, upstreamtoken.ErrSessionNotFound) + assert.Nil(t, cred) +} + +// TestIntegration_UpstreamTokenService_NoRefreshToken verifies that the service +// returns ErrNoRefreshToken when the upstream access token is expired but no +// refresh token is available. +func TestIntegration_UpstreamTokenService_NoRefreshToken(t *testing.T) { + t.Parallel() + + m := startMockOIDC(t) + ts := setupTestServerWithMockOIDC(t, m) + + stor := ts.authServer.IDPTokenStorage() + + // Store expired tokens without a refresh token. + sessionID := "no-refresh-session" + require.NoError(t, stor.StoreUpstreamTokens(context.Background(), sessionID, &storage.UpstreamTokens{ + ProviderID: "test", + AccessToken: "expired-access", + RefreshToken: "", // no refresh token + ExpiresAt: time.Now().Add(-1 * time.Hour), + UserID: "user-1", + UpstreamSubject: "sub-1", + ClientID: "client-1", + })) + + svc := upstreamtoken.NewInProcessService(stor, ts.authServer.UpstreamTokenRefresher()) + + cred, err := svc.GetValidTokens(context.Background(), sessionID) + require.Error(t, err) + assert.ErrorIs(t, err, upstreamtoken.ErrNoRefreshToken) + assert.Nil(t, cred) +} + +// extractTSID parses a JWT access token and extracts the tsid claim. +func extractTSID(t *testing.T, accessToken string, publicKey any) string { + t.Helper() + + parsed, err := jwt.ParseSigned(accessToken, []jose.SignatureAlgorithm{jose.RS256}) + require.NoError(t, err) + + var claims map[string]interface{} + err = parsed.Claims(publicKey, &claims) + require.NoError(t, err) + + tsid, ok := claims[session.TokenSessionIDClaimKey].(string) + require.True(t, ok, "tsid claim should be present in access token") + require.NotEmpty(t, tsid) + + return tsid +}