diff --git a/auth/tokenprovider/cached.go b/auth/tokenprovider/cached.go
new file mode 100644
index 0000000..b59e883
--- /dev/null
+++ b/auth/tokenprovider/cached.go
@@ -0,0 +1,86 @@
+package tokenprovider
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/rs/zerolog/log"
+)
+
+// CachedTokenProvider wraps another provider and caches the last token it returned
+type CachedTokenProvider struct {
+	provider TokenProvider
+	cache    *Token
+	mutex    sync.RWMutex
+	// RefreshThreshold is how long before expiry a cached token is refreshed (default 5 minutes)
+	RefreshThreshold time.Duration
+}
+
+// NewCachedTokenProvider creates a caching wrapper around any token provider
+func NewCachedTokenProvider(provider TokenProvider) *CachedTokenProvider {
+	return &CachedTokenProvider{
+		provider:         provider,
+		RefreshThreshold: 5 * time.Minute,
+	}
+}
+
+// GetToken returns the cached token while it is still valid, otherwise fetches and caches a new one
+func (p *CachedTokenProvider) GetToken(ctx context.Context) (*Token, error) {
+	// Fast path: read the cache under the read lock only
+	p.mutex.RLock()
+	cached := p.cache
+	p.mutex.RUnlock()
+
+	if cached != nil && !p.shouldRefresh(cached) {
+		log.Debug().Msgf("cached token provider: using cached token for provider %s", p.provider.Name())
+		return cached, nil
+	}
+
+	// Need to refresh; the write lock is held for the whole fetch
+	p.mutex.Lock()
+	defer p.mutex.Unlock()
+
+	// Re-check after acquiring the write lock: another caller may have refreshed already
+	if p.cache != nil && !p.shouldRefresh(p.cache) {
+		return p.cache, nil
+	}
+
+	log.Debug().Msgf("cached token provider: fetching new token from provider %s", p.provider.Name())
+	token, err := p.provider.GetToken(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("cached token provider: failed to get token: %w", err)
+	}
+
+	p.cache = token
+	return token, nil
+}
+
+// shouldRefresh reports whether token is nil or within RefreshThreshold of its expiry
+func (p *CachedTokenProvider) shouldRefresh(token *Token) bool {
+	if token == nil {
+		return true
+	}
+
+	// A zero ExpiresAt is treated as "never expires"
+	if token.ExpiresAt.IsZero() {
+		return false
+	}
+
+	// Refresh if within threshold of expiry
+	refreshAt := token.ExpiresAt.Add(-p.RefreshThreshold)
+	return time.Now().After(refreshAt)
+}
+
+// Name returns the provider name in the form "cached[<wrapped provider name>]"
+func (p *CachedTokenProvider) Name() string {
+	return fmt.Sprintf("cached[%s]", p.provider.Name())
+}
+
+// ClearCache drops the cached token so the next GetToken call fetches a new one
+func (p *CachedTokenProvider) ClearCache() {
+	p.mutex.Lock()
+	p.cache = nil
+	p.mutex.Unlock()
+}
diff --git a/auth/tokenprovider/exchange.go b/auth/tokenprovider/exchange.go
new file mode 100644
index 0000000..8c0bba6
--- /dev/null
+++ b/auth/tokenprovider/exchange.go
@@ -0,0 +1,225 @@
+package tokenprovider
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"strings"
+	"time"
+
+	"github.com/golang-jwt/jwt/v5"
+	"github.com/rs/zerolog/log"
+)
+
+// FederationProvider wraps another token provider and automatically handles token exchange
+type FederationProvider struct {
+	baseProvider   TokenProvider
+	databricksHost string
+	clientID       string // For SP-wide federation
+	httpClient     *http.Client
+	// Settings for token exchange
+	returnOriginalTokenIfAuthenticated bool
+}
+
+// NewFederationProvider creates a federation provider that wraps another provider
+// It automatically detects when token exchange is needed and falls back gracefully
+func NewFederationProvider(baseProvider TokenProvider, databricksHost string) *FederationProvider {
+	return &FederationProvider{
+		baseProvider:                       baseProvider,
+		databricksHost:                     databricksHost,
+		httpClient:                         &http.Client{Timeout: 30 * time.Second},
+		returnOriginalTokenIfAuthenticated: true,
+	}
+}
+
+// NewFederationProviderWithClientID creates a provider for SP-wide federation (M2M)
+func NewFederationProviderWithClientID(baseProvider TokenProvider, databricksHost, clientID string) *FederationProvider {
+	return &FederationProvider{
+		baseProvider:                       baseProvider,
+		databricksHost:                     databricksHost,
+		clientID:                           clientID,
+		httpClient:                         &http.Client{Timeout: 30 * time.Second},
+		returnOriginalTokenIfAuthenticated: true,
+	}
+}
+
+// GetToken gets token from base provider and exchanges if needed
+// The exchange is best-effort: when it fails, the failure is logged and the
+// base provider's token is returned unchanged.
+func (p *FederationProvider) GetToken(ctx context.Context) (*Token, error) {
+	// Get token from base provider
+	baseToken, err := p.baseProvider.GetToken(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("federation provider: failed to get base token: %w", err)
+	}
+	if baseToken == nil {
+		return nil, fmt.Errorf("federation provider: base provider %s returned no token", p.baseProvider.Name())
+	}
+
+	// Check if token is a JWT and needs exchange
+	if p.needsTokenExchange(baseToken.AccessToken) {
+		log.Debug().Msgf("federation provider: attempting token exchange for %s", p.baseProvider.Name())
+
+		// Try token exchange
+		exchangedToken, err := p.tryTokenExchange(ctx, baseToken.AccessToken)
+		if err != nil {
+			log.Warn().Err(err).Msg("federation provider: token exchange failed, using original token")
+			return baseToken, nil // Fall back to original token
+		}
+
+		log.Debug().Msg("federation provider: token exchange successful")
+		return exchangedToken, nil
+	}
+
+	// Use original token
+	return baseToken, nil
+}
+
+// needsTokenExchange determines if a token needs exchange by checking if it's from a different issuer
+// The token is parsed without signature verification; the result only decides
+// whether to attempt an exchange and must not be treated as authentication.
+func (p *FederationProvider) needsTokenExchange(tokenString string) bool {
+	// Try to parse as JWT
+	token, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{})
+	if err != nil {
+		log.Debug().Err(err).Msg("federation provider: not a JWT token, skipping exchange")
+		return false
+	}
+
+	claims, ok := token.Claims.(jwt.MapClaims)
+	if !ok {
+		return false
+	}
+
+	issuer, ok := claims["iss"].(string)
+	if !ok {
+		return false
+	}
+
+	// Check if issuer is different from Databricks host
+	return !p.isSameHost(issuer, p.databricksHost)
+}
+
+// tryTokenExchange attempts to exchange the token with Databricks
+// It POSTs an RFC 8693 token-exchange form to the host's oidc/v1/token endpoint and
+// returns an error for transport failures, non-200 responses, or unusable bodies.
+func (p *FederationProvider) tryTokenExchange(ctx context.Context, subjectToken string) (*Token, error) {
+	// Build exchange URL - add scheme if not present
+	exchangeURL := p.databricksHost
+	if !strings.HasPrefix(exchangeURL, "http://") && !strings.HasPrefix(exchangeURL, "https://") {
+		exchangeURL = "https://" + exchangeURL
+	}
+	if !strings.HasSuffix(exchangeURL, "/") {
+		exchangeURL += "/"
+	}
+	exchangeURL += "oidc/v1/token"
+
+	// Prepare form data for token exchange
+	data := url.Values{}
+	data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
+	data.Set("scope", "sql")
+	data.Set("subject_token_type", "urn:ietf:params:oauth:token-type:jwt")
+	data.Set("subject_token", subjectToken)
+
+	if p.returnOriginalTokenIfAuthenticated {
+		data.Set("return_original_token_if_authenticated", "true")
+	}
+
+	// Add client_id for SP-wide federation
+	if p.clientID != "" {
+		data.Set("client_id", p.clientID)
+	}
+
+	// Create request
+	req, err := http.NewRequestWithContext(ctx, "POST", exchangeURL, strings.NewReader(data.Encode()))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	req.Header.Set("Accept", "*/*")
+
+	// Make request
+	resp, err := p.httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("request failed: %w", err)
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read response: %w", err)
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("exchange failed with status %d: %s", resp.StatusCode, string(body))
+	}
+
+	// Parse response
+	var tokenResp struct {
+		AccessToken string `json:"access_token"`
+		TokenType   string `json:"token_type"`
+		ExpiresIn   int    `json:"expires_in"`
+		Scope       string `json:"scope"`
+	}
+
+	if err := json.Unmarshal(body, &tokenResp); err != nil {
+		return nil, fmt.Errorf("failed to parse response: %w", err)
+	}
+
+	// A 200 response without a token is unusable; report it so GetToken falls back
+	if tokenResp.AccessToken == "" {
+		return nil, fmt.Errorf("exchange response missing access_token (status %d)", resp.StatusCode)
+	}
+
+	token := &Token{
+		AccessToken: tokenResp.AccessToken,
+		TokenType:   tokenResp.TokenType,
+		Scopes:      strings.Fields(tokenResp.Scope),
+	}
+
+	if tokenResp.ExpiresIn > 0 {
+		token.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
+	}
+
+	return token, nil
+}
+
+// isSameHost compares two URLs to see if they have the same host
+func (p *FederationProvider) isSameHost(url1, url2 string) bool {
+	// Add scheme to url2 if it doesn't have one (databricksHost may not have scheme)
+	parsedURL2 := url2
+	if !strings.HasPrefix(url2, "http://") && !strings.HasPrefix(url2, "https://") {
+		parsedURL2 = "https://" + url2
+	}
+
+	u1, err1 := url.Parse(url1)
+	u2, err2 := url.Parse(parsedURL2)
+
+	if err1 != nil || err2 != nil {
+		return false
+	}
+
+	// Use Hostname() instead of Host to ignore port differences
+	// This handles cases like "host.com:443" == "host.com" for HTTPS
+	// Host names are case-insensitive, so compare with EqualFold
+	return strings.EqualFold(u1.Hostname(), u2.Hostname())
+}
+
+// Name returns the provider name
+func (p *FederationProvider) Name() string {
+	baseName := p.baseProvider.Name()
+	if p.clientID != "" {
+		// Truncate the client ID for readability; only slice when it is long
+		// enough, since clientID[:8] panics on IDs shorter than 8 bytes
+		shortID := p.clientID
+		if len(shortID) > 8 {
+			shortID = shortID[:8]
+		}
+		return fmt.Sprintf("federation[%s,sp:%s]", baseName, shortID)
+	}
+	return fmt.Sprintf("federation[%s]", baseName)
+}
diff --git a/auth/tokenprovider/federation_test.go b/auth/tokenprovider/federation_test.go
new file mode 100644
index 0000000..554b733
--- /dev/null
+++ b/auth/tokenprovider/federation_test.go
@@ -0,0 +1,354 @@
+package tokenprovider
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"github.com/golang-jwt/jwt/v5"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// Helper function to create JWT tokens for testing
+func createTestJWT(issuer, audience string, expiryHours int) string {
+	claims := jwt.MapClaims{
+		"iss": issuer,
+		"aud": audience,
+		"exp": time.Now().Add(time.Duration(expiryHours) * time.Hour).Unix(),
+		"sub": "test-user",
+	}
+	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+	tokenString, _ := token.SignedString([]byte("test-secret"))
+	return tokenString
+}
+
+func TestFederationProvider_HostComparison(t *testing.T) {
+	tests := []struct {
+		name           string
+		issuer         string
+		databricksHost string
+		shouldExchange bool
+	}{
+		{
+			name:           "same_host_no_port",
+			issuer:         "https://test.databricks.com",
+			databricksHost: "test.databricks.com",
+			shouldExchange: false,
+		},
+		{
+			name:           "same_host_with_port_443",
+			issuer:         "https://test.databricks.com:443",
+			databricksHost: "test.databricks.com",
+			shouldExchange: false,
+		},
+		{
+			name:           "same_host_both_with_port",
+			issuer:         "https://test.databricks.com:443",
+			databricksHost: "test.databricks.com:443",
+			shouldExchange: false,
+		},
+		{
+			name:           "different_host_azure",
+			issuer:         "https://login.microsoftonline.com/tenant-id/",
+			databricksHost: "test.databricks.com",
+			shouldExchange: true,
+		},
+		{
+			name:           "different_host_google",
+			issuer:         "https://accounts.google.com",
+			databricksHost: "test.databricks.com",
+			shouldExchange: true,
+		},
+		{
+			name:           "different_host_aws",
+			issuer:         "https://cognito-identity.amazonaws.com",
+			databricksHost: "test.databricks.com",
+			shouldExchange: true,
+		},
+		{
+			name:           "different_databricks_host",
+			issuer:         "https://test1.databricks.com",
+			databricksHost: "test2.databricks.com",
+			shouldExchange: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Create a JWT token with the specified issuer
+			jwtToken := createTestJWT(tt.issuer, "databricks", 1)
+
+			// Create a mock base provider
+			baseProvider := NewStaticTokenProvider(jwtToken)
+
+			// Create federation provider
+			fedProvider := NewFederationProvider(baseProvider, tt.databricksHost)
+
+			// Check if token needs exchange
+			needsExchange := fedProvider.needsTokenExchange(jwtToken)
+			assert.Equal(t, tt.shouldExchange, needsExchange,
+				"issuer=%s, host=%s, expected shouldExchange=%v, got=%v",
+				tt.issuer, tt.databricksHost, tt.shouldExchange, needsExchange)
+		})
+	}
+}
+
+func TestFederationProvider_TokenExchangeSuccess(t *testing.T) {
+	// Create mock token exchange server
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Verify request method and path
+		assert.Equal(t, "POST", r.Method)
+		assert.Contains(t, r.URL.Path, "/oidc/v1/token")
+
+		// Verify headers
+		assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
+		assert.Equal(t, "*/*", r.Header.Get("Accept"))
+
+		// Parse form data
+		err := r.ParseForm()
+		assert.NoError(t, err)
+
+		// Verify form parameters
+		assert.Equal(t, "urn:ietf:params:oauth:grant-type:token-exchange", r.FormValue("grant_type"))
+		assert.Equal(t, "sql", r.FormValue("scope"))
+		assert.Equal(t, "urn:ietf:params:oauth:token-type:jwt", r.FormValue("subject_token_type"))
+		assert.NotEmpty(t, r.FormValue("subject_token"))
+		assert.Equal(t, "true", r.FormValue("return_original_token_if_authenticated"))
+
+		// Return successful token response
+		response := map[string]interface{}{
+			"access_token": "exchanged-databricks-token",
+			"token_type":   "Bearer",
+			"expires_in":   3600,
+			"scope":        "sql",
+		}
+		w.Header().Set("Content-Type", "application/json")
+		_ = json.NewEncoder(w).Encode(response)
+	}))
+	defer server.Close()
+
+	// Create external token with different issuer
+	externalToken := createTestJWT("https://login.microsoftonline.com/tenant-id/", "databricks", 1)
+	baseProvider := NewStaticTokenProvider(externalToken)
+
+	// Create federation provider pointing to mock server
+	// Use full URL including http:// scheme for test server
+	fedProvider := NewFederationProvider(baseProvider, server.URL)
+
+	// Get token - should trigger exchange
+	ctx := context.Background()
+	token, err := fedProvider.GetToken(ctx)
+
+	require.NoError(t, err)
+	assert.Equal(t, "exchanged-databricks-token", token.AccessToken)
+	assert.Equal(t, "Bearer", token.TokenType)
+	assert.False(t, token.ExpiresAt.IsZero())
+	assert.Contains(t, token.Scopes, "sql")
+}
+
+func TestFederationProvider_TokenExchangeWithClientID(t *testing.T) {
+	clientID := "test-client-id-12345"
+
+	// Create mock server that checks for client_id
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		err := r.ParseForm()
+		assert.NoError(t, err)
+
+		// Verify client_id is present
+		assert.Equal(t, clientID, r.FormValue("client_id"))
+
+		response := map[string]interface{}{
+			"access_token": "sp-wide-federation-token",
+			"token_type":   "Bearer",
+			"expires_in":   3600,
+		}
+		w.Header().Set("Content-Type", "application/json")
+		_ = json.NewEncoder(w).Encode(response)
+	}))
+	defer server.Close()
+
+	externalToken := createTestJWT("https://login.microsoftonline.com/tenant-id/", "databricks", 1)
+	baseProvider := NewStaticTokenProvider(externalToken)
+
+	fedProvider := NewFederationProviderWithClientID(baseProvider, server.URL, clientID)
+
+	ctx := context.Background()
+	token, err := fedProvider.GetToken(ctx)
+
+	require.NoError(t, err)
+	assert.Equal(t, "sp-wide-federation-token", token.AccessToken)
+}
+
+func TestFederationProvider_TokenExchangeFailureFallback(t *testing.T) {
+	// Create mock server that returns error
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusBadRequest)
+		_, _ = w.Write([]byte(`{"error": "invalid_request"}`))
+	}))
+	defer server.Close()
+
+	externalToken := createTestJWT("https://login.microsoftonline.com/tenant-id/", "databricks", 1)
+	baseProvider := NewStaticTokenProvider(externalToken)
+
+	fedProvider := NewFederationProvider(baseProvider, server.URL)
+
+	ctx := context.Background()
+	token, err := fedProvider.GetToken(ctx)
+
+	// Should not error - falls back to external token
+	require.NoError(t, err)
+	assert.Equal(t, externalToken, token.AccessToken, "Should fall back to original token on exchange failure")
+	assert.Equal(t, "Bearer", token.TokenType)
+}
+
+func TestFederationProvider_NoExchangeWhenSameIssuer(t *testing.T) {
+	// Create token with Databricks as issuer
+	databricksHost := "test.databricks.com"
+	databricksToken := createTestJWT("https://"+databricksHost, "databricks", 1)
+	baseProvider := NewStaticTokenProvider(databricksToken)
+
+	fedProvider := NewFederationProvider(baseProvider, databricksHost)
+
+	ctx := context.Background()
+	token, err := fedProvider.GetToken(ctx)
+
+	// Should not exchange - just return original token
+	require.NoError(t, err)
+	assert.Equal(t, databricksToken, token.AccessToken, "Should use original token when issuer matches")
+}
+
+func TestFederationProvider_NonJWTToken(t *testing.T) {
+	// Use a non-JWT token (e.g., opaque PAT)
+	opaqueToken := "dapi1234567890abcdef"
+	baseProvider := NewStaticTokenProvider(opaqueToken)
+
+	fedProvider := NewFederationProvider(baseProvider, "test.databricks.com")
+
+	ctx := context.Background()
+	token, err := fedProvider.GetToken(ctx)
+
+	// Should not error - just pass through non-JWT token
+	require.NoError(t, err)
+	assert.Equal(t, opaqueToken, token.AccessToken, "Should pass through non-JWT tokens")
+}
+
+func TestFederationProvider_ProviderName(t *testing.T) {
+	baseProvider := NewStaticTokenProvider("test-token")
+
+	t.Run("without_client_id", func(t *testing.T) {
+		fedProvider := NewFederationProvider(baseProvider, "test.databricks.com")
+		assert.Equal(t, "federation[static]", fedProvider.Name())
+	})
+
+	t.Run("with_client_id", func(t *testing.T) {
+		fedProvider := NewFederationProviderWithClientID(baseProvider, "test.databricks.com", "client-12345678-more")
+		// Should truncate client ID to first 8 chars
+		assert.Equal(t, "federation[static,sp:client-1]", fedProvider.Name())
+	})
+
+	t.Run("with_short_client_id", func(t *testing.T) {
+		fedProvider := NewFederationProviderWithClientID(baseProvider, "test.databricks.com", "abc")
+		// IDs shorter than 8 chars must be used as-is rather than panicking
+		assert.Equal(t, "federation[static,sp:abc]", fedProvider.Name())
+	})
+}
+
+func TestFederationProvider_CachedIntegration(t *testing.T) {
+	callCount := 0
+	exchangeCount := 0
+
+	// Mock server that counts exchanges
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		exchangeCount++
+		response := map[string]interface{}{
+			"access_token": "databricks-token",
+			"token_type":   "Bearer",
+			"expires_in":   3600,
+		}
+		w.Header().Set("Content-Type", "application/json")
+		_ = json.NewEncoder(w).Encode(response)
+	}))
+	defer server.Close()
+
+	// External provider that counts calls
+	externalProvider := &mockProvider{
+		tokenFunc: func() (*Token, error) {
+			callCount++
+			externalToken := createTestJWT("https://login.microsoftonline.com/tenant/", "databricks", 1)
+			return &Token{
+				AccessToken: externalToken,
+				TokenType:   "Bearer",
+				ExpiresAt:   time.Now().Add(1 * time.Hour),
+			}, nil
+		},
+		name: "external",
+	}
+
+	fedProvider := NewFederationProvider(externalProvider, server.URL)
+	cachedProvider := NewCachedTokenProvider(fedProvider)
+
+	ctx := context.Background()
+
+	// First call - should call external provider and exchange
+	token1, err1 := cachedProvider.GetToken(ctx)
+	require.NoError(t, err1)
+	assert.Equal(t, "databricks-token", token1.AccessToken)
+	assert.Equal(t, 1, callCount, "External provider should be called once")
+	assert.Equal(t, 1, exchangeCount, "Token should be exchanged once")
+
+	// Second call - should use cache
+	token2, err2 := cachedProvider.GetToken(ctx)
+	require.NoError(t, err2)
+	assert.Equal(t, "databricks-token", token2.AccessToken)
+	assert.Equal(t, 1, callCount, "External provider should still be called only once (cached)")
+	assert.Equal(t, 1, exchangeCount, "Token should still be exchanged only once (cached)")
+}
+
+func TestFederationProvider_InvalidJWT(t *testing.T) {
+	// Test with various invalid JWT formats
+	testCases := []string{
+		"not.a.jwt",
+		"invalid-token-format",
+		"",
+	}
+
+	for _, invalidToken := range testCases {
+		t.Run("invalid_jwt_"+invalidToken, func(t *testing.T) {
+			baseProvider := NewStaticTokenProvider(invalidToken)
+			fedProvider := NewFederationProvider(baseProvider, "test.databricks.com")
+
+			// Should not need exchange for invalid JWT
+			needsExchange := fedProvider.needsTokenExchange(invalidToken)
+			assert.False(t, needsExchange, "Invalid JWT should not require exchange")
+		})
+	}
+}
+
+func TestFederationProvider_RealWorldIssuers(t *testing.T) {
+	// Test with real-world identity provider issuers
+	issuers := map[string]string{
+		"azure_ad":    "https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/v2.0",
+		"google":      "https://accounts.google.com",
+		"aws_cognito": "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_example",
+		"okta":        "https://dev-12345.okta.com/oauth2/default",
+		"auth0":       "https://dev-12345.auth0.com/",
+		"github":      "https://token.actions.githubusercontent.com",
+	}
+
+	databricksHost := "test.databricks.com"
+
+	for name, issuer := range issuers {
+		t.Run(name, func(t *testing.T) {
+			jwtToken := createTestJWT(issuer, "databricks", 1)
+			baseProvider := NewStaticTokenProvider(jwtToken)
+			fedProvider := NewFederationProvider(baseProvider, databricksHost)
+
+			needsExchange := fedProvider.needsTokenExchange(jwtToken)
+			assert.True(t, needsExchange, "Token from %s should require exchange", name)
+		})
+	}
+}
diff --git a/connector.go b/connector.go
index fce7797..da5f7cb 100644
--- a/connector.go
+++ b/connector.go
@@ -323,3 +323,30 @@ func WithStaticToken(token string) ConnOption {
 		}
 	}
 }
+
+// WithFederatedTokenProvider sets up authentication using token federation
+// It wraps the base provider and automatically handles token exchange if needed
+// The host is read from the config when this option is applied, so apply it after
+// the option that sets the host. A nil baseProvider leaves the config unchanged.
+func WithFederatedTokenProvider(baseProvider tokenprovider.TokenProvider) ConnOption {
+	return func(c *config.Config) {
+		if baseProvider != nil {
+			// Wrap with federation provider that auto-detects need for token exchange
+			federationProvider := tokenprovider.NewFederationProvider(baseProvider, c.Host)
+			c.Authenticator = tokenprovider.NewAuthenticator(federationProvider)
+		}
+	}
+}
+
+// WithFederatedTokenProviderAndClientID sets up SP-wide token federation
+// The host is read from the config when this option is applied, so apply it after
+// the option that sets the host. A nil baseProvider leaves the config unchanged.
+func WithFederatedTokenProviderAndClientID(baseProvider tokenprovider.TokenProvider, clientID string) ConnOption {
+	return func(c *config.Config) {
+		if baseProvider != nil {
+			// Wrap with federation provider for SP-wide federation
+			federationProvider := tokenprovider.NewFederationProviderWithClientID(baseProvider, c.Host, clientID)
+			c.Authenticator = tokenprovider.NewAuthenticator(federationProvider)
+		}
+	}
+}