diff --git a/auth/oauth/oauth.go b/auth/oauth/oauth.go index 0df9d5c4..80313fa2 100644 --- a/auth/oauth/oauth.go +++ b/auth/oauth/oauth.go @@ -85,6 +85,8 @@ var databricksAWSDomains []string = []string{ } var databricksAzureDomains []string = []string{ + ".staging.azuredatabricks.net", + ".dev.azuredatabricks.net", ".azuredatabricks.net", ".databricks.azure.cn", ".databricks.azure.us", diff --git a/auth/tokenprovider/authenticator.go b/auth/tokenprovider/authenticator.go new file mode 100644 index 00000000..3955a4c9 --- /dev/null +++ b/auth/tokenprovider/authenticator.go @@ -0,0 +1,44 @@ +package tokenprovider + +import ( + "context" + "fmt" + "net/http" + + "github.com/databricks/databricks-sql-go/auth" + "github.com/rs/zerolog/log" +) + +// TokenProviderAuthenticator implements auth.Authenticator using a TokenProvider +type TokenProviderAuthenticator struct { + provider TokenProvider +} + +// NewAuthenticator creates an authenticator from a token provider +func NewAuthenticator(provider TokenProvider) auth.Authenticator { + return &TokenProviderAuthenticator{ + provider: provider, + } +} + +// Authenticate implements auth.Authenticator +func (a *TokenProviderAuthenticator) Authenticate(r *http.Request) error { + ctx := r.Context() + if ctx == nil { + ctx = context.Background() + } + + token, err := a.provider.GetToken(ctx) + if err != nil { + return fmt.Errorf("token provider authenticator: failed to get token: %w", err) + } + + if token.AccessToken == "" { + return fmt.Errorf("token provider authenticator: empty access token") + } + + token.SetAuthHeader(r) + log.Debug().Msgf("token provider authenticator: authenticated using provider %s", a.provider.Name()) + + return nil +} diff --git a/auth/tokenprovider/authenticator_test.go b/auth/tokenprovider/authenticator_test.go new file mode 100644 index 00000000..a47dd6bc --- /dev/null +++ b/auth/tokenprovider/authenticator_test.go @@ -0,0 +1,135 @@ +package tokenprovider + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTokenProviderAuthenticator(t *testing.T) { + t.Run("successful_authentication", func(t *testing.T) { + provider := NewStaticTokenProvider("test-token-123") + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + require.NoError(t, err) + assert.Equal(t, "Bearer test-token-123", req.Header.Get("Authorization")) + }) + + t.Run("authentication_with_custom_token_type", func(t *testing.T) { + provider := NewStaticTokenProviderWithType("test-token", "MAC") + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + require.NoError(t, err) + assert.Equal(t, "MAC test-token", req.Header.Get("Authorization")) + }) + + t.Run("authentication_error_propagation", func(t *testing.T) { + provider := &mockProvider{ + tokenFunc: func() (*Token, error) { + return nil, errors.New("provider failed") + }, + } + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "provider failed") + assert.Empty(t, req.Header.Get("Authorization")) + }) + + t.Run("empty_token_error", func(t *testing.T) { + provider := &mockProvider{ + tokenFunc: func() (*Token, error) { + return &Token{ + AccessToken: "", + TokenType: "Bearer", + }, nil + }, + } + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "empty access token") + assert.Empty(t, req.Header.Get("Authorization")) + }) + + t.Run("uses_request_context", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + provider := &mockProvider{ + tokenFunc: func() (*Token, error) { + // This would normally check context cancellation + return &Token{ + AccessToken: "test-token", + TokenType: "Bearer", + }, nil + }, + } + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequestWithContext(ctx, "GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + + // Even with cancelled context, this should work as our mock doesn't check it + require.NoError(t, err) + assert.Equal(t, "Bearer test-token", req.Header.Get("Authorization")) + }) + + t.Run("external_token_integration", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "external-token-456", nil + } + provider := NewExternalTokenProvider(tokenFunc) + authenticator := NewAuthenticator(provider) + + req, _ := http.NewRequest("POST", "http://example.com/api", nil) + err := authenticator.Authenticate(req) + + require.NoError(t, err) + assert.Equal(t, "Bearer external-token-456", req.Header.Get("Authorization")) + }) + + t.Run("cached_provider_integration", func(t *testing.T) { + callCount := 0 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + return &Token{ + AccessToken: "cached-token", + TokenType: "Bearer", + }, nil + }, + name: "test", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + authenticator := NewAuthenticator(cachedProvider) + + // Multiple authentication attempts + for i := 0; i < 3; i++ { + req, _ := http.NewRequest("GET", "http://example.com", nil) + err := authenticator.Authenticate(req) + require.NoError(t, err) + assert.Equal(t, "Bearer cached-token", req.Header.Get("Authorization")) + } + + // Should only call base provider once due to caching + assert.Equal(t, 1, callCount) + }) +} diff --git a/auth/tokenprovider/cached.go b/auth/tokenprovider/cached.go new file mode 100644 index 00000000..b59e883e --- /dev/null +++ b/auth/tokenprovider/cached.go @@ -0,0 +1,86 @@ +package tokenprovider + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/rs/zerolog/log" +) + +// CachedTokenProvider wraps another provider and caches tokens +type CachedTokenProvider struct { + provider TokenProvider + cache *Token + mutex sync.RWMutex + // RefreshThreshold determines when to refresh (default 5 minutes before expiry) + RefreshThreshold time.Duration +} + +// NewCachedTokenProvider creates a caching wrapper around any token provider +func NewCachedTokenProvider(provider TokenProvider) *CachedTokenProvider { + return &CachedTokenProvider{ + provider: provider, + RefreshThreshold: 5 * time.Minute, + } +} + +// GetToken retrieves a token, using cache if available and valid +func (p *CachedTokenProvider) GetToken(ctx context.Context) (*Token, error) { + // Try to get from cache first + p.mutex.RLock() + cached := p.cache + p.mutex.RUnlock() + + if cached != nil && !p.shouldRefresh(cached) { + log.Debug().Msgf("cached token provider: using cached token for provider %s", p.provider.Name()) + return cached, nil + } + + // Need to refresh + p.mutex.Lock() + defer p.mutex.Unlock() + + // Double-check after acquiring write lock + if p.cache != nil && !p.shouldRefresh(p.cache) { + return p.cache, nil + } + + log.Debug().Msgf("cached token provider: fetching new token from provider %s", p.provider.Name()) + token, err := p.provider.GetToken(ctx) + if err != nil { + return nil, fmt.Errorf("cached token provider: failed to get token: %w", err) + } + + p.cache = token + return token, nil +} + +// shouldRefresh determines if a token should be refreshed +func (p *CachedTokenProvider) shouldRefresh(token *Token) bool { + if token == nil { + return true + } + + // If no expiry time, assume token doesn't expire + if token.ExpiresAt.IsZero() { + return false + } + + // Refresh if within threshold of expiry + refreshAt := token.ExpiresAt.Add(-p.RefreshThreshold) + return time.Now().After(refreshAt) +} + +// Name returns the provider name +func (p *CachedTokenProvider) Name() string { + return fmt.Sprintf("cached[%s]", p.provider.Name()) +} + +// ClearCache clears the cached token +func (p *CachedTokenProvider) ClearCache() { + p.mutex.Lock() + p.cache = nil + p.mutex.Unlock() +} diff --git a/auth/tokenprovider/exchange.go b/auth/tokenprovider/exchange.go new file mode 100644 index 00000000..8c0bba60 --- /dev/null +++ b/auth/tokenprovider/exchange.go @@ -0,0 +1,204 @@ +package tokenprovider + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/rs/zerolog/log" +) + +// FederationProvider wraps another token provider and automatically handles token exchange +type FederationProvider struct { + baseProvider TokenProvider + databricksHost string + clientID string // For SP-wide federation + httpClient *http.Client + // Settings for token exchange + returnOriginalTokenIfAuthenticated bool +} + +// NewFederationProvider creates a federation provider that wraps another provider +// It automatically detects when token exchange is needed and falls back gracefully +func NewFederationProvider(baseProvider TokenProvider, databricksHost string) *FederationProvider { + return &FederationProvider{ + baseProvider: baseProvider, + databricksHost: databricksHost, + httpClient: &http.Client{Timeout: 30 * time.Second}, + returnOriginalTokenIfAuthenticated: true, + } +} + +// NewFederationProviderWithClientID creates a provider for SP-wide federation (M2M) +func NewFederationProviderWithClientID(baseProvider TokenProvider, databricksHost, clientID string) *FederationProvider { + return &FederationProvider{ + baseProvider: baseProvider, + databricksHost: databricksHost, + clientID: clientID, + httpClient: &http.Client{Timeout: 30 * time.Second}, + returnOriginalTokenIfAuthenticated: true, + } +} + +// GetToken gets token from base provider and exchanges if needed +func (p *FederationProvider) GetToken(ctx context.Context) (*Token, error) { + // Get token from base provider + baseToken, err := p.baseProvider.GetToken(ctx) + if err != nil { + return nil, fmt.Errorf("federation provider: failed to get base token: %w", err) + } + + // Check if token is a JWT and needs exchange + if p.needsTokenExchange(baseToken.AccessToken) { + log.Debug().Msgf("federation provider: attempting token exchange for %s", p.baseProvider.Name()) + + // Try token exchange + exchangedToken, err := p.tryTokenExchange(ctx, baseToken.AccessToken) + if err != nil { + log.Warn().Err(err).Msg("federation provider: token exchange failed, using original token") + return baseToken, nil // Fall back to original token + } + + log.Debug().Msg("federation provider: token exchange successful") + return exchangedToken, nil + } + + // Use original token + return baseToken, nil +} + +// needsTokenExchange determines if a token needs exchange by checking if it's from a different issuer +func (p *FederationProvider) needsTokenExchange(tokenString string) bool { + // Try to parse as JWT + token, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{}) + if err != nil { + log.Debug().Err(err).Msg("federation provider: not a JWT token, skipping exchange") + return false + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return false + } + + issuer, ok := claims["iss"].(string) + if !ok { + return false + } + + // Check if issuer is different from Databricks host + return !p.isSameHost(issuer, p.databricksHost) +} + +// tryTokenExchange attempts to exchange the token with Databricks +func (p *FederationProvider) tryTokenExchange(ctx context.Context, subjectToken string) (*Token, error) { + // Build exchange URL - add scheme if not present + exchangeURL := p.databricksHost + if !strings.HasPrefix(exchangeURL, "http://") && !strings.HasPrefix(exchangeURL, "https://") { + exchangeURL = "https://" + exchangeURL + } + if !strings.HasSuffix(exchangeURL, "/") { + exchangeURL += "/" + } + exchangeURL += "oidc/v1/token" + + // Prepare form data for token exchange + data := url.Values{} + data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange") + data.Set("scope", "sql") + data.Set("subject_token_type", "urn:ietf:params:oauth:token-type:jwt") + data.Set("subject_token", subjectToken) + + if p.returnOriginalTokenIfAuthenticated { + data.Set("return_original_token_if_authenticated", "true") + } + + // Add client_id for SP-wide federation + if p.clientID != "" { + data.Set("client_id", p.clientID) + } + + // Create request + req, err := http.NewRequestWithContext(ctx, "POST", exchangeURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "*/*") + + // Make request + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("exchange failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var tokenResp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + } + + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + token := &Token{ + AccessToken: tokenResp.AccessToken, + TokenType: tokenResp.TokenType, + Scopes: strings.Fields(tokenResp.Scope), + } + + if tokenResp.ExpiresIn > 0 { + token.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + + return token, nil +} + +// isSameHost compares two URLs to see if they have the same host +func (p *FederationProvider) isSameHost(url1, url2 string) bool { + // Add scheme to url2 if it doesn't have one (databricksHost may not have scheme) + parsedURL2 := url2 + if !strings.HasPrefix(url2, "http://") && !strings.HasPrefix(url2, "https://") { + parsedURL2 = "https://" + url2 + } + + u1, err1 := url.Parse(url1) + u2, err2 := url.Parse(parsedURL2) + + if err1 != nil || err2 != nil { + return false + } + + // Use Hostname() instead of Host to ignore port differences + // This handles cases like "host.com:443" == "host.com" for HTTPS + return u1.Hostname() == u2.Hostname() +} + +// Name returns the provider name +func (p *FederationProvider) Name() string { + baseName := p.baseProvider.Name() + if p.clientID != "" { + return fmt.Sprintf("federation[%s,sp:%s]", baseName, p.clientID[:8]) // Truncate client ID for readability + } + return fmt.Sprintf("federation[%s]", baseName) +} diff --git a/auth/tokenprovider/external.go b/auth/tokenprovider/external.go new file mode 100644 index 00000000..0e511234 --- /dev/null +++ b/auth/tokenprovider/external.go @@ -0,0 +1,56 @@ +package tokenprovider + +import ( + "context" + "fmt" + "time" +) + +// ExternalTokenProvider provides tokens from an external source (passthrough) +type ExternalTokenProvider struct { + tokenFunc func() (string, error) + tokenType string +} + +// NewExternalTokenProvider creates a provider that gets tokens from an external function +func NewExternalTokenProvider(tokenFunc func() (string, error)) *ExternalTokenProvider { + return &ExternalTokenProvider{ + tokenFunc: tokenFunc, + tokenType: "Bearer", + } +} + +// NewExternalTokenProviderWithType creates a provider with a custom token type +func NewExternalTokenProviderWithType(tokenFunc func() (string, error), tokenType string) *ExternalTokenProvider { + return &ExternalTokenProvider{ + tokenFunc: tokenFunc, + tokenType: tokenType, + } +} + +// GetToken retrieves the token from the external source +func (p *ExternalTokenProvider) GetToken(ctx context.Context) (*Token, error) { + if p.tokenFunc == nil { + return nil, fmt.Errorf("external token provider: token function is nil") + } + + accessToken, err := p.tokenFunc() + if err != nil { + return nil, fmt.Errorf("external token provider: failed to get token: %w", err) + } + + if accessToken == "" { + return nil, fmt.Errorf("external token provider: empty token returned") + } + + return &Token{ + AccessToken: accessToken, + TokenType: p.tokenType, + ExpiresAt: time.Time{}, // External tokens don't provide expiry info + }, nil +} + +// Name returns the provider name +func (p *ExternalTokenProvider) Name() string { + return "external" +} diff --git a/auth/tokenprovider/federation_test.go b/auth/tokenprovider/federation_test.go new file mode 100644 index 00000000..554b7333 --- /dev/null +++ b/auth/tokenprovider/federation_test.go @@ -0,0 +1,348 @@ +package tokenprovider + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Helper function to create JWT tokens for testing +func createTestJWT(issuer, audience string, expiryHours int) string { + claims := jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "exp": time.Now().Add(time.Duration(expiryHours) * time.Hour).Unix(), + "sub": "test-user", + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, _ := token.SignedString([]byte("test-secret")) + return tokenString +} + +func TestFederationProvider_HostComparison(t *testing.T) { + tests := []struct { + name string + issuer string + databricksHost string + shouldExchange bool + }{ + { + name: "same_host_no_port", + issuer: "https://test.databricks.com", + databricksHost: "test.databricks.com", + shouldExchange: false, + }, + { + name: "same_host_with_port_443", + issuer: "https://test.databricks.com:443", + databricksHost: "test.databricks.com", + shouldExchange: false, + }, + { + name: "same_host_both_with_port", + issuer: "https://test.databricks.com:443", + databricksHost: "test.databricks.com:443", + shouldExchange: false, + }, + { + name: "different_host_azure", + issuer: "https://login.microsoftonline.com/tenant-id/", + databricksHost: "test.databricks.com", + shouldExchange: true, + }, + { + name: "different_host_google", + issuer: "https://accounts.google.com", + databricksHost: "test.databricks.com", + shouldExchange: true, + }, + { + name: "different_host_aws", + issuer: "https://cognito-identity.amazonaws.com", + databricksHost: "test.databricks.com", + shouldExchange: true, + }, + { + name: "different_databricks_host", + issuer: "https://test1.databricks.com", + databricksHost: "test2.databricks.com", + shouldExchange: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a JWT token with the specified issuer + jwtToken := createTestJWT(tt.issuer, "databricks", 1) + + // Create a mock base provider + baseProvider := NewStaticTokenProvider(jwtToken) + + // Create federation provider + fedProvider := NewFederationProvider(baseProvider, tt.databricksHost) + + // Check if token needs exchange + needsExchange := fedProvider.needsTokenExchange(jwtToken) + assert.Equal(t, tt.shouldExchange, needsExchange, + "issuer=%s, host=%s, expected shouldExchange=%v, got=%v", + tt.issuer, tt.databricksHost, tt.shouldExchange, needsExchange) + }) + } +} + +func TestFederationProvider_TokenExchangeSuccess(t *testing.T) { + // Create mock token exchange server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method and path + assert.Equal(t, "POST", r.Method) + assert.Contains(t, r.URL.Path, "/oidc/v1/token") + + // Verify headers + assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + assert.Equal(t, "*/*", r.Header.Get("Accept")) + + // Parse form data + err := r.ParseForm() + require.NoError(t, err) + + // Verify form parameters + assert.Equal(t, "urn:ietf:params:oauth:grant-type:token-exchange", r.FormValue("grant_type")) + assert.Equal(t, "sql", r.FormValue("scope")) + assert.Equal(t, "urn:ietf:params:oauth:token-type:jwt", r.FormValue("subject_token_type")) + assert.NotEmpty(t, r.FormValue("subject_token")) + assert.Equal(t, "true", r.FormValue("return_original_token_if_authenticated")) + + // Return successful token response + response := map[string]interface{}{ + "access_token": "exchanged-databricks-token", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "sql", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + // Create external token with different issuer + externalToken := createTestJWT("https://login.microsoftonline.com/tenant-id/", "databricks", 1) + baseProvider := NewStaticTokenProvider(externalToken) + + // Create federation provider pointing to mock server + // Use full URL including http:// scheme for test server + fedProvider := NewFederationProvider(baseProvider, server.URL) + + // Get token - should trigger exchange + ctx := context.Background() + token, err := fedProvider.GetToken(ctx) + + require.NoError(t, err) + assert.Equal(t, "exchanged-databricks-token", token.AccessToken) + assert.Equal(t, "Bearer", token.TokenType) + assert.False(t, token.ExpiresAt.IsZero()) + assert.Contains(t, token.Scopes, "sql") +} + +func TestFederationProvider_TokenExchangeWithClientID(t *testing.T) { + clientID := "test-client-id-12345" + + // Create mock server that checks for client_id + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + require.NoError(t, err) + + // Verify client_id is present + assert.Equal(t, clientID, r.FormValue("client_id")) + + response := map[string]interface{}{ + "access_token": "sp-wide-federation-token", + "token_type": "Bearer", + "expires_in": 3600, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + externalToken := createTestJWT("https://login.microsoftonline.com/tenant-id/", "databricks", 1) + baseProvider := NewStaticTokenProvider(externalToken) + + fedProvider := NewFederationProviderWithClientID(baseProvider, server.URL, clientID) + + ctx := context.Background() + token, err := fedProvider.GetToken(ctx) + + require.NoError(t, err) + assert.Equal(t, "sp-wide-federation-token", token.AccessToken) +} + +func TestFederationProvider_TokenExchangeFailureFallback(t *testing.T) { + // Create mock server that returns error + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error": "invalid_request"}`)) + })) + defer server.Close() + + externalToken := createTestJWT("https://login.microsoftonline.com/tenant-id/", "databricks", 1) + baseProvider := NewStaticTokenProvider(externalToken) + + fedProvider := NewFederationProvider(baseProvider, server.URL) + + ctx := context.Background() + token, err := fedProvider.GetToken(ctx) + + // Should not error - falls back to external token + require.NoError(t, err) + assert.Equal(t, externalToken, token.AccessToken, "Should fall back to original token on exchange failure") + assert.Equal(t, "Bearer", token.TokenType) +} + +func TestFederationProvider_NoExchangeWhenSameIssuer(t *testing.T) { + // Create token with Databricks as issuer + databricksHost := "test.databricks.com" + databricksToken := createTestJWT("https://"+databricksHost, "databricks", 1) + baseProvider := NewStaticTokenProvider(databricksToken) + + fedProvider := NewFederationProvider(baseProvider, databricksHost) + + ctx := context.Background() + token, err := fedProvider.GetToken(ctx) + + // Should not exchange - just return original token + require.NoError(t, err) + assert.Equal(t, databricksToken, token.AccessToken, "Should use original token when issuer matches") +} + +func TestFederationProvider_NonJWTToken(t *testing.T) { + // Use a non-JWT token (e.g., opaque PAT) + opaqueToken := "dapi1234567890abcdef" + baseProvider := NewStaticTokenProvider(opaqueToken) + + fedProvider := NewFederationProvider(baseProvider, "test.databricks.com") + + ctx := context.Background() + token, err := fedProvider.GetToken(ctx) + + // Should not error - just pass through non-JWT token + require.NoError(t, err) + assert.Equal(t, opaqueToken, token.AccessToken, "Should pass through non-JWT tokens") +} + +func TestFederationProvider_ProviderName(t *testing.T) { + baseProvider := NewStaticTokenProvider("test-token") + + t.Run("without_client_id", func(t *testing.T) { + fedProvider := NewFederationProvider(baseProvider, "test.databricks.com") + assert.Equal(t, "federation[static]", fedProvider.Name()) + }) + + t.Run("with_client_id", func(t *testing.T) { + fedProvider := NewFederationProviderWithClientID(baseProvider, "test.databricks.com", "client-12345678-more") + // Should truncate client ID to first 8 chars + assert.Equal(t, "federation[static,sp:client-1]", fedProvider.Name()) + }) +} + +func TestFederationProvider_CachedIntegration(t *testing.T) { + callCount := 0 + exchangeCount := 0 + + // Mock server that counts exchanges + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + exchangeCount++ + response := map[string]interface{}{ + "access_token": "databricks-token", + "token_type": "Bearer", + "expires_in": 3600, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + // External provider that counts calls + externalProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + externalToken := createTestJWT("https://login.microsoftonline.com/tenant/", "databricks", 1) + return &Token{ + AccessToken: externalToken, + TokenType: "Bearer", + ExpiresAt: time.Now().Add(1 * time.Hour), + }, nil + }, + name: "external", + } + + fedProvider := NewFederationProvider(externalProvider, server.URL) + cachedProvider := NewCachedTokenProvider(fedProvider) + + ctx := context.Background() + + // First call - should call external provider and exchange + token1, err1 := cachedProvider.GetToken(ctx) + require.NoError(t, err1) + assert.Equal(t, "databricks-token", token1.AccessToken) + assert.Equal(t, 1, callCount, "External provider should be called once") + assert.Equal(t, 1, exchangeCount, "Token should be exchanged once") + + // Second call - should use cache + token2, err2 := cachedProvider.GetToken(ctx) + require.NoError(t, err2) + assert.Equal(t, "databricks-token", token2.AccessToken) + assert.Equal(t, 1, callCount, "External provider should still be called only once (cached)") + assert.Equal(t, 1, exchangeCount, "Token should still be exchanged only once (cached)") +} + +func TestFederationProvider_InvalidJWT(t *testing.T) { + // Test with various invalid JWT formats + testCases := []string{ + "not.a.jwt", + "invalid-token-format", + "", + } + + for _, invalidToken := range testCases { + t.Run("invalid_jwt_"+invalidToken, func(t *testing.T) { + baseProvider := NewStaticTokenProvider(invalidToken) + fedProvider := NewFederationProvider(baseProvider, "test.databricks.com") + + // Should not need exchange for invalid JWT + needsExchange := fedProvider.needsTokenExchange(invalidToken) + assert.False(t, needsExchange, "Invalid JWT should not require exchange") + }) + } +} + +func TestFederationProvider_RealWorldIssuers(t *testing.T) { + // Test with real-world identity provider issuers + issuers := map[string]string{ + "azure_ad": "https://login.microsoftonline.com/72f988bf-86f1-41af-91ab-2d7cd011db47/v2.0", + "google": "https://accounts.google.com", + "aws_cognito": "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_example", + "okta": "https://dev-12345.okta.com/oauth2/default", + "auth0": "https://dev-12345.auth0.com/", + "github": "https://token.actions.githubusercontent.com", + } + + databricksHost := "test.databricks.com" + + for name, issuer := range issuers { + t.Run(name, func(t *testing.T) { + jwtToken := createTestJWT(issuer, "databricks", 1) + baseProvider := NewStaticTokenProvider(jwtToken) + fedProvider := NewFederationProvider(baseProvider, databricksHost) + + needsExchange := fedProvider.needsTokenExchange(jwtToken) + assert.True(t, needsExchange, "Token from %s should require exchange", name) + }) + } +} diff --git a/auth/tokenprovider/provider.go b/auth/tokenprovider/provider.go new file mode 100644 index 00000000..3e94d6ef --- /dev/null +++ b/auth/tokenprovider/provider.go @@ -0,0 +1,43 @@ +package tokenprovider + +import ( + "context" + "net/http" + "time" +) + +// TokenProvider is the interface for providing tokens from various sources +type TokenProvider interface { + // GetToken retrieves a valid access token + GetToken(ctx context.Context) (*Token, error) + + // Name returns the provider name for logging/debugging + Name() string +} + +// Token represents an access token with metadata +type Token struct { + AccessToken string + TokenType string + ExpiresAt time.Time + RefreshToken string + Scopes []string +} + +// IsExpired checks if the token has expired +func (t *Token) IsExpired() bool { + if t.ExpiresAt.IsZero() { + return false // No expiry means token doesn't expire + } + // Consider token expired 5 minutes before actual expiry for safety + return time.Now().Add(5 * time.Minute).After(t.ExpiresAt) +} + +// SetAuthHeader sets the Authorization header on an HTTP request +func (t *Token) SetAuthHeader(r *http.Request) { + tokenType := t.TokenType + if tokenType == "" { + tokenType = "Bearer" + } + r.Header.Set("Authorization", tokenType+" "+t.AccessToken) +} diff --git a/auth/tokenprovider/provider_test.go b/auth/tokenprovider/provider_test.go new file mode 100644 index 00000000..5acb5538 --- /dev/null +++ b/auth/tokenprovider/provider_test.go @@ -0,0 +1,423 @@ +package tokenprovider + +import ( + "context" + "errors" + "net/http" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestToken_IsExpired(t *testing.T) { + tests := []struct { + name string + token *Token + expected bool + }{ + { + name: "token_without_expiry", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Time{}, + }, + expected: false, + }, + { + name: "token_expired", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Now().Add(-10 * time.Minute), + }, + expected: true, + }, + { + name: "token_not_expired", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Now().Add(10 * time.Minute), + }, + expected: false, + }, + { + name: "token_expires_within_5_minutes", + token: &Token{ + AccessToken: "test-token", + ExpiresAt: time.Now().Add(3 * time.Minute), + }, + expected: true, // Should be considered expired due to 5-minute buffer + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.token.IsExpired()) + }) + } +} + +func TestToken_SetAuthHeader(t *testing.T) { + tests := []struct { + name string + token *Token + expectedHeader string + }{ + { + name: "bearer_token", + token: &Token{ + AccessToken: "test-access-token", + TokenType: "Bearer", + }, + expectedHeader: "Bearer test-access-token", + }, + { + name: "default_to_bearer", + token: &Token{ + AccessToken: "test-access-token", + TokenType: "", + }, + expectedHeader: "Bearer test-access-token", + }, + { + name: "custom_token_type", + token: &Token{ + AccessToken: "test-access-token", + TokenType: "CustomAuth", + }, + expectedHeader: "CustomAuth test-access-token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com", nil) + tt.token.SetAuthHeader(req) + assert.Equal(t, tt.expectedHeader, req.Header.Get("Authorization")) + }) + } +} + +func TestStaticTokenProvider(t *testing.T) { + t.Run("valid_token", func(t *testing.T) { + provider := NewStaticTokenProvider("static-token-123") + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "static-token-123", token.AccessToken) + assert.Equal(t, "Bearer", token.TokenType) + assert.True(t, token.ExpiresAt.IsZero()) + assert.Equal(t, "static", provider.Name()) + }) + + t.Run("empty_token_error", func(t *testing.T) { + provider := NewStaticTokenProvider("") + token, err := provider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "token is empty") + }) + + t.Run("custom_token_type", func(t *testing.T) { + provider := NewStaticTokenProviderWithType("static-token", "CustomAuth") + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "static-token", token.AccessToken) + assert.Equal(t, "CustomAuth", token.TokenType) + }) + + t.Run("multiple_calls_same_token", func(t *testing.T) { + provider := NewStaticTokenProvider("static-token") + + token1, err1 := provider.GetToken(context.Background()) + token2, err2 := provider.GetToken(context.Background()) + + require.NoError(t, err1) + require.NoError(t, err2) + assert.Equal(t, token1.AccessToken, token2.AccessToken) + }) +} + +func TestExternalTokenProvider(t *testing.T) { + t.Run("successful_token_retrieval", func(t *testing.T) { + callCount := 0 + tokenFunc := func() (string, error) { + callCount++ + return "external-token-" + string(rune(callCount)), nil + } + + provider := NewExternalTokenProvider(tokenFunc) + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "external-token-\x01", token.AccessToken) + assert.Equal(t, "Bearer", token.TokenType) + assert.Equal(t, "external", provider.Name()) + }) + + t.Run("token_function_error", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "", errors.New("failed to retrieve token") + } + + provider := NewExternalTokenProvider(tokenFunc) + token, err := provider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "failed to get token") + }) + + t.Run("empty_token_error", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "", nil + } + + provider := NewExternalTokenProvider(tokenFunc) + token, err := provider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "empty token returned") + }) + + t.Run("nil_function_error", func(t *testing.T) { + provider := NewExternalTokenProvider(nil) + token, err := provider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "token function is nil") + }) + + t.Run("custom_token_type", func(t *testing.T) { + tokenFunc := func() (string, error) { + return "external-token", nil + } + + provider := NewExternalTokenProviderWithType(tokenFunc, "MAC") + token, err := provider.GetToken(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "external-token", token.AccessToken) + assert.Equal(t, "MAC", token.TokenType) + }) + + t.Run("different_token_each_call", func(t *testing.T) { + counter := 0 + tokenFunc := func() (string, error) { + counter++ + return "token-" + string(rune(counter)), nil + } + + provider := NewExternalTokenProvider(tokenFunc) + + token1, err1 := provider.GetToken(context.Background()) + token2, err2 := provider.GetToken(context.Background()) + + require.NoError(t, err1) + require.NoError(t, err2) + assert.NotEqual(t, token1.AccessToken, token2.AccessToken) + assert.Equal(t, "token-\x01", token1.AccessToken) + assert.Equal(t, "token-\x02", token2.AccessToken) + }) +} + +func TestCachedTokenProvider(t *testing.T) { + t.Run("caches_valid_token", func(t *testing.T) { + callCount := 0 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + return &Token{ + AccessToken: "cached-token", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(1 * time.Hour), + }, nil + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + + // First call - should fetch from base provider + token1, err1 := cachedProvider.GetToken(context.Background()) + require.NoError(t, err1) + assert.Equal(t, "cached-token", token1.AccessToken) + assert.Equal(t, 1, callCount) + + // Second call - should use cache + token2, err2 := cachedProvider.GetToken(context.Background()) + require.NoError(t, err2) + assert.Equal(t, "cached-token", token2.AccessToken) + assert.Equal(t, 1, callCount) // Should still be 1 + }) + + t.Run("refreshes_expired_token", func(t *testing.T) { + callCount := 0 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + // Return token that expires soon + return &Token{ + AccessToken: "token-" + string(rune(callCount)), + TokenType: "Bearer", + ExpiresAt: time.Now().Add(2 * time.Minute), // Within refresh threshold + }, nil + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + cachedProvider.RefreshThreshold = 5 * time.Minute + + // First call + token1, err1 := cachedProvider.GetToken(context.Background()) + require.NoError(t, err1) + assert.Equal(t, "token-\x01", token1.AccessToken) + assert.Equal(t, 1, callCount) + + // Second call - should refresh because token expires within threshold + token2, err2 := cachedProvider.GetToken(context.Background()) + require.NoError(t, err2) + assert.Equal(t, "token-\x02", token2.AccessToken) + assert.Equal(t, 2, callCount) + }) + + t.Run("handles_provider_error", func(t *testing.T) { + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + return nil, errors.New("provider error") + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + token, err := cachedProvider.GetToken(context.Background()) + + assert.Error(t, err) + assert.Nil(t, token) + assert.Contains(t, err.Error(), "provider error") + }) + + t.Run("no_expiry_token_not_refreshed", func(t *testing.T) { + callCount := 0 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + return &Token{ + AccessToken: "permanent-token", + TokenType: "Bearer", + ExpiresAt: time.Time{}, // No expiry + }, nil + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + + // Multiple calls should all use cache + for i := 0; i < 5; i++ { + token, err := cachedProvider.GetToken(context.Background()) + require.NoError(t, err) + assert.Equal(t, "permanent-token", token.AccessToken) + } + + assert.Equal(t, 1, callCount) // Should only be called once + }) + + t.Run("clear_cache", func(t *testing.T) { + callCount := 0 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + callCount++ + return &Token{ + AccessToken: "token-" + string(rune(callCount)), + TokenType: "Bearer", + ExpiresAt: time.Now().Add(1 * time.Hour), + }, nil + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + + // First call + token1, _ := cachedProvider.GetToken(context.Background()) + assert.Equal(t, "token-\x01", token1.AccessToken) + assert.Equal(t, 1, callCount) + + // Clear cache + cachedProvider.ClearCache() + + // Next call should fetch new token + token2, _ := cachedProvider.GetToken(context.Background()) + assert.Equal(t, "token-\x02", token2.AccessToken) + assert.Equal(t, 2, callCount) + }) + + t.Run("concurrent_access", func(t *testing.T) { + var callCount atomic.Int32 + baseProvider := &mockProvider{ + tokenFunc: func() (*Token, error) { + // Simulate slow token fetch + time.Sleep(100 * time.Millisecond) + callCount.Add(1) + return &Token{ + AccessToken: "concurrent-token", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(1 * time.Hour), + }, nil + }, + name: "mock", + } + + cachedProvider := NewCachedTokenProvider(baseProvider) + + // Launch multiple goroutines + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + token, err := cachedProvider.GetToken(context.Background()) + assert.NoError(t, err) + assert.Equal(t, "concurrent-token", token.AccessToken) + }() + } + + wg.Wait() + + // Should only fetch token once despite concurrent access + assert.Equal(t, int32(1), callCount.Load()) + }) + + t.Run("provider_name", func(t *testing.T) { + baseProvider := &mockProvider{name: "test-provider"} + cachedProvider := NewCachedTokenProvider(baseProvider) + + assert.Equal(t, "cached[test-provider]", cachedProvider.Name()) + }) +} + +// Mock provider for testing +type mockProvider struct { + tokenFunc func() (*Token, error) + name string +} + +func (m *mockProvider) GetToken(ctx context.Context) (*Token, error) { + if m.tokenFunc != nil { + return m.tokenFunc() + } + return nil, errors.New("not implemented") +} + +func (m *mockProvider) Name() string { + return m.name +} diff --git a/auth/tokenprovider/static.go b/auth/tokenprovider/static.go new file mode 100644 index 00000000..46079ba0 --- /dev/null +++ b/auth/tokenprovider/static.go @@ -0,0 +1,47 @@ +package tokenprovider + +import ( + "context" + "fmt" + "time" +) + +// StaticTokenProvider provides a static token that never changes +type StaticTokenProvider struct { + token string + tokenType string +} + +// NewStaticTokenProvider creates a provider with a static token +func NewStaticTokenProvider(token string) *StaticTokenProvider { + return &StaticTokenProvider{ + token: token, + tokenType: "Bearer", + } +} + +// NewStaticTokenProviderWithType creates a provider with a static token and custom type +func NewStaticTokenProviderWithType(token string, tokenType string) *StaticTokenProvider { + return &StaticTokenProvider{ + token: token, + tokenType: tokenType, + } +} + +// GetToken returns the static token +func (p *StaticTokenProvider) GetToken(ctx context.Context) (*Token, error) { + if p.token == "" { + return nil, fmt.Errorf("static token provider: token is empty") + } + + return &Token{ + AccessToken: p.token, + TokenType: p.tokenType, + ExpiresAt: time.Time{}, // Static tokens don't expire + }, nil +} + +// Name returns the provider name +func (p *StaticTokenProvider) Name() string { + return "static" +} diff --git a/connector.go b/connector.go index 21a5f178..2b5cac60 100644 --- a/connector.go +++ b/connector.go @@ -12,6 +12,7 @@ import ( "github.com/databricks/databricks-sql-go/auth" "github.com/databricks/databricks-sql-go/auth/oauth/m2m" "github.com/databricks/databricks-sql-go/auth/pat" + "github.com/databricks/databricks-sql-go/auth/tokenprovider" "github.com/databricks/databricks-sql-go/driverctx" dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" @@ -274,3 +275,55 @@ func WithClientCredentials(clientID, clientSecret string) ConnOption { } } } + +// WithTokenProvider sets up authentication using a custom token provider +func WithTokenProvider(provider tokenprovider.TokenProvider) ConnOption { + return func(c *config.Config) { + if provider != nil { + c.Authenticator = tokenprovider.NewAuthenticator(provider) + } + } +} + +// WithExternalToken sets up authentication using an external token function (passthrough) +func WithExternalToken(tokenFunc func() (string, error)) ConnOption { + return func(c *config.Config) { + if tokenFunc != nil { + provider := tokenprovider.NewExternalTokenProvider(tokenFunc) + c.Authenticator = tokenprovider.NewAuthenticator(provider) + } + } +} + +// WithStaticToken sets up authentication using a static token +func WithStaticToken(token string) ConnOption { + return func(c *config.Config) { + if token != "" { + provider := tokenprovider.NewStaticTokenProvider(token) + c.Authenticator = tokenprovider.NewAuthenticator(provider) + } + } +} + +// WithFederatedTokenProvider sets up authentication using token federation +// It wraps the base provider and automatically handles token exchange if needed +func WithFederatedTokenProvider(baseProvider tokenprovider.TokenProvider) ConnOption { + return func(c *config.Config) { + if baseProvider != nil { + // Wrap with federation provider that auto-detects need for token exchange + federationProvider := tokenprovider.NewFederationProvider(baseProvider, c.Host) + c.Authenticator = tokenprovider.NewAuthenticator(federationProvider) + } + } +} + +// WithFederatedTokenProviderAndClientID sets up SP-wide token federation +func WithFederatedTokenProviderAndClientID(baseProvider tokenprovider.TokenProvider, clientID string) ConnOption { + return func(c *config.Config) { + if baseProvider != nil { + // Wrap with federation provider for SP-wide federation + federationProvider := tokenprovider.NewFederationProviderWithClientID(baseProvider, c.Host, clientID) + c.Authenticator = tokenprovider.NewAuthenticator(federationProvider) + } + } +} diff --git a/examples/browser_oauth_federation/main.go b/examples/browser_oauth_federation/main.go new file mode 100644 index 00000000..760ba48c --- /dev/null +++ b/examples/browser_oauth_federation/main.go @@ -0,0 +1,454 @@ +package main + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "os" + "strings" + "time" + + dbsql "github.com/databricks/databricks-sql-go" + "github.com/databricks/databricks-sql-go/auth/oauth/u2m" + "github.com/databricks/databricks-sql-go/auth/tokenprovider" + "github.com/joho/godotenv" +) + +func main() { + err := godotenv.Load() + if err != nil { + log.Printf("Warning: .env file not found: %v", err) + } + + fmt.Println("Browser OAuth with Token Federation Test") + fmt.Println("=========================================") + fmt.Println() + + // Get test mode from environment + testMode := os.Getenv("TEST_MODE") + if testMode == "" { + fmt.Println("TEST_MODE not set. Available modes:") + fmt.Println(" passthrough - Account-wide WIF Auth_Flow=0 (Token passthrough)") + fmt.Println(" u2m_federation - Account-wide WIF Auth_Flow=2 (U2M with federation)") + fmt.Println(" u2m_native - Native U2M without federation (baseline)") + fmt.Println(" external_token - Manual token passthrough (for testing exchange)") + os.Exit(1) + } + + switch testMode { + case "passthrough": + testTokenPassthrough() + case "u2m_federation": + testU2MWithFederation() + case "u2m_native": + testU2MNative() + case "external_token": + testExternalTokenWithFederation() + default: + log.Fatalf("Unknown test mode: %s", testMode) + } +} + +// testU2MNative tests native U2M OAuth without token federation (baseline) +func testU2MNative() { + fmt.Println("Test: Native U2M OAuth (Baseline - No Federation)") + fmt.Println("--------------------------------------------------") + fmt.Println("This uses Databricks' built-in OAuth without token exchange") + fmt.Println() + + host := os.Getenv("DATABRICKS_HOST") + httpPath := os.Getenv("DATABRICKS_HTTPPATH") + + if host == "" || httpPath == "" { + log.Fatal("DATABRICKS_HOST and DATABRICKS_HTTPPATH must be set") + } + + fmt.Printf("Host: %s\n", host) + fmt.Printf("HTTP Path: %s\n", httpPath) + fmt.Println() + + // Create U2M authenticator + authenticator, err := u2m.NewAuthenticator(host, 2*time.Minute) + if err != nil { + log.Fatal(err) + } + + // Create connector with native OAuth + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithHTTPPath(httpPath), + dbsql.WithAuthenticator(authenticator), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + // Test connection + fmt.Println("Testing connection with browser OAuth...") + + // First try ping to see if we can establish connection + fmt.Println("Attempting to ping database...") + if err := db.Ping(); err != nil { + log.Fatalf("Ping failed: %v", err) + } + fmt.Println("✓ Ping successful") + + if err := testConnection(db); err != nil { + log.Fatalf("Connection test failed: %v", err) + } + + fmt.Println() + fmt.Println("✓ Native U2M OAuth test PASSED") +} + +// testU2MWithFederation tests U2M OAuth with token federation (Account-wide WIF) +func testU2MWithFederation() { + fmt.Println("Test: U2M OAuth with Token Federation (Account-wide WIF)") + fmt.Println("---------------------------------------------------------") + fmt.Println("This tests Auth_Flow=2: Browser OAuth token → Federation exchange → Databricks token") + fmt.Println() + + host := os.Getenv("DATABRICKS_HOST") + httpPath := os.Getenv("DATABRICKS_HTTPPATH") + externalIdpHost := os.Getenv("EXTERNAL_IDP_HOST") + + if host == "" || httpPath == "" { + log.Fatal("DATABRICKS_HOST and DATABRICKS_HTTPPATH must be set") + } + + if externalIdpHost == "" { + fmt.Println("WARNING: EXTERNAL_IDP_HOST not set, using Databricks host") + externalIdpHost = host + } + + fmt.Printf("Databricks Host: %s\n", host) + fmt.Printf("HTTP Path: %s\n", httpPath) + fmt.Printf("External IdP Host: %s\n", externalIdpHost) + fmt.Println() + + // Step 1: Get token from external IdP using browser OAuth + fmt.Println("Step 1: Getting token from external IdP via browser OAuth...") + baseAuthenticator, err := u2m.NewAuthenticator(externalIdpHost, 2*time.Minute) + if err != nil { + log.Fatalf("Failed to create U2M authenticator: %v", err) + } + + // Wrap U2M authenticator as a token provider + u2mProvider := &U2MTokenProvider{authenticator: baseAuthenticator} + + // Step 2: Wrap with federation provider for automatic token exchange + fmt.Println("Step 2: Setting up federation provider for automatic token exchange...") + federationProvider := tokenprovider.NewFederationProvider(u2mProvider, host) + cachedProvider := tokenprovider.NewCachedTokenProvider(federationProvider) + + // Create connector with federated authentication + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithHTTPPath(httpPath), + dbsql.WithTokenProvider(cachedProvider), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + // Test connection + fmt.Println() + fmt.Println("Step 3: Testing connection (will trigger browser OAuth and token exchange)...") + if err := testConnection(db); err != nil { + log.Fatalf("Connection test failed: %v", err) + } + + fmt.Println() + fmt.Println("✓ U2M with Token Federation test PASSED") + fmt.Println() + fmt.Println("Token flow: Browser OAuth → External IdP Token → Token Exchange → Databricks Token") +} + +// testTokenPassthrough tests manual token passthrough with federation +func testTokenPassthrough() { + fmt.Println("Test: Token Passthrough with Federation (Auth_Flow=0)") + fmt.Println("------------------------------------------------------") + fmt.Println("This tests passing an external token that gets exchanged automatically") + fmt.Println() + + host := os.Getenv("DATABRICKS_HOST") + httpPath := os.Getenv("DATABRICKS_HTTPPATH") + externalToken := os.Getenv("EXTERNAL_TOKEN") + + if host == "" || httpPath == "" { + log.Fatal("DATABRICKS_HOST and DATABRICKS_HTTPPATH must be set") + } + + if externalToken == "" { + log.Fatal("EXTERNAL_TOKEN must be set (get from external IdP)") + } + + fmt.Printf("Host: %s\n", host) + fmt.Printf("Token issuer: %s\n", getTokenIssuer(externalToken)) + fmt.Println() + + // Create static token provider + baseProvider := tokenprovider.NewStaticTokenProvider(externalToken) + + // Wrap with federation provider + federationProvider := tokenprovider.NewFederationProvider(baseProvider, host) + cachedProvider := tokenprovider.NewCachedTokenProvider(federationProvider) + + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithHTTPPath(httpPath), + dbsql.WithTokenProvider(cachedProvider), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + fmt.Println("Testing connection with token passthrough...") + if err := testConnection(db); err != nil { + log.Fatalf("Connection test failed: %v", err) + } + + fmt.Println() + fmt.Println("✓ Token Passthrough with Federation test PASSED") +} + +// testExternalTokenWithFederation tests manual token exchange process +func testExternalTokenWithFederation() { + fmt.Println("Test: Manual Token Exchange (for debugging)") + fmt.Println("-------------------------------------------") + fmt.Println() + + host := os.Getenv("DATABRICKS_HOST") + httpPath := os.Getenv("DATABRICKS_HTTPPATH") + externalToken := os.Getenv("EXTERNAL_TOKEN") + + if host == "" || httpPath == "" || externalToken == "" { + log.Fatal("DATABRICKS_HOST, DATABRICKS_HTTPPATH, and EXTERNAL_TOKEN must be set") + } + + fmt.Printf("Host: %s\n", host) + fmt.Printf("Token issuer: %s\n", getTokenIssuer(externalToken)) + fmt.Println() + + // Manual token exchange + fmt.Println("Step 1: Manually exchanging token...") + exchangedToken, err := manualTokenExchange(host, externalToken) + if err != nil { + log.Fatalf("Token exchange failed: %v", err) + } + + fmt.Printf("✓ Token exchange successful\n") + fmt.Printf(" Exchanged token length: %d chars\n", len(exchangedToken)) + fmt.Println() + + // Test connection with exchanged token + fmt.Println("Step 2: Testing connection with exchanged token...") + provider := tokenprovider.NewStaticTokenProvider(exchangedToken) + cachedProvider := tokenprovider.NewCachedTokenProvider(provider) + + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithHTTPPath(httpPath), + dbsql.WithTokenProvider(cachedProvider), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + if err := testConnection(db); err != nil { + log.Fatalf("Connection test failed: %v", err) + } + + fmt.Println() + fmt.Println("✓ Manual Token Exchange test PASSED") +} + +// Helper: Test database connection with queries +func testConnection(db *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + // Test 1: Simple query + var result int + err := db.QueryRowContext(ctx, "SELECT 1").Scan(&result) + if err != nil { + return fmt.Errorf("simple query failed: %w", err) + } + fmt.Printf("✓ SELECT 1 returned: %d\n", result) + + // Test 2: Range query + rows, err := db.QueryContext(ctx, "SELECT * FROM RANGE(5)") + if err != nil { + return fmt.Errorf("range query failed: %w", err) + } + defer rows.Close() + + count := 0 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return fmt.Errorf("scan failed: %w", err) + } + count++ + } + fmt.Printf("✓ SELECT FROM RANGE(5) returned %d rows\n", count) + + // Test 3: Current user query + var username string + err = db.QueryRowContext(ctx, "SELECT CURRENT_USER()").Scan(&username) + if err != nil { + return fmt.Errorf("current user query failed: %w", err) + } + fmt.Printf("✓ Connected as user: %s\n", username) + + return nil +} + +// Helper: Manual token exchange (for debugging/testing) +func manualTokenExchange(databricksHost, subjectToken string) (string, error) { + exchangeURL := databricksHost + if !strings.HasPrefix(exchangeURL, "http://") && !strings.HasPrefix(exchangeURL, "https://") { + exchangeURL = "https://" + exchangeURL + } + if !strings.HasSuffix(exchangeURL, "/") { + exchangeURL += "/" + } + exchangeURL += "oidc/v1/token" + + data := url.Values{} + data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange") + data.Set("scope", "sql") + data.Set("subject_token_type", "urn:ietf:params:oauth:token-type:jwt") + data.Set("subject_token", subjectToken) + + req, err := http.NewRequest("POST", exchangeURL, strings.NewReader(data.Encode())) + if err != nil { + return "", err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "*/*") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("exchange failed with status %d: %s", resp.StatusCode, string(body)) + } + + var tokenResp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + } + + if err := json.Unmarshal(body, &tokenResp); err != nil { + return "", err + } + + return tokenResp.AccessToken, nil +} + +// Helper: Get token issuer from JWT (for logging) +func getTokenIssuer(tokenString string) string { + parts := strings.Split(tokenString, ".") + if len(parts) < 2 { + return "not a JWT" + } + + // Decode payload (second part) + payload, err := decodeBase64(parts[1]) + if err != nil { + return "invalid JWT" + } + + var claims map[string]interface{} + if err := json.Unmarshal(payload, &claims); err != nil { + return "invalid JWT" + } + + if iss, ok := claims["iss"].(string); ok { + return iss + } + + return "unknown" +} + +func decodeBase64(s string) ([]byte, error) { + // Add padding if needed + switch len(s) % 4 { + case 2: + s += "==" + case 3: + s += "=" + } + return io.ReadAll(strings.NewReader(s)) +} + +// U2MTokenProvider wraps U2M authenticator as a TokenProvider +type U2MTokenProvider struct { + authenticator interface { + Authenticate(*http.Request) error + } +} + +func (p *U2MTokenProvider) GetToken(ctx context.Context) (*tokenprovider.Token, error) { + // Create a dummy request to trigger authentication + req, err := http.NewRequestWithContext(ctx, "GET", "http://dummy", nil) + if err != nil { + return nil, err + } + + // Authenticate will add Authorization header + if err := p.authenticator.Authenticate(req); err != nil { + return nil, err + } + + // Extract token from Authorization header + authHeader := req.Header.Get("Authorization") + if authHeader == "" { + return nil, fmt.Errorf("no authorization header set") + } + + // Parse "Bearer " + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid authorization header format") + } + + return &tokenprovider.Token{ + AccessToken: parts[1], + TokenType: parts[0], + }, nil +} + +func (p *U2MTokenProvider) Name() string { + return "u2m-browser-oauth" +} diff --git a/examples/token_federation/main.go b/examples/token_federation/main.go new file mode 100644 index 00000000..e2deeaff --- /dev/null +++ b/examples/token_federation/main.go @@ -0,0 +1,344 @@ +package main + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "strconv" + "time" + + dbsql "github.com/databricks/databricks-sql-go" + "github.com/databricks/databricks-sql-go/auth/tokenprovider" +) + +func main() { + // Get configuration from environment + host := os.Getenv("DATABRICKS_HOST") + httpPath := os.Getenv("DATABRICKS_HTTPPATH") + port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT")) + if err != nil { + port = 443 + } + + fmt.Println("Token Federation Examples") + fmt.Println("=========================") + + // Choose which example to run based on environment variable + example := os.Getenv("TOKEN_EXAMPLE") + if example == "" { + example = "static" + } + + switch example { + case "static": + runStaticTokenExample(host, httpPath, port) + case "external": + runExternalTokenExample(host, httpPath, port) + case "cached": + runCachedTokenExample(host, httpPath, port) + case "custom": + runCustomProviderExample(host, httpPath, port) + case "oauth": + runOAuthServiceExample(host, httpPath, port) + default: + log.Fatalf("Unknown example: %s", example) + } +} + +// Example 1: Static token (simplest case) +func runStaticTokenExample(host, httpPath string, port int) { + fmt.Println("\nExample 1: Static Token Provider") + fmt.Println("---------------------------------") + + token := os.Getenv("DATABRICKS_ACCESS_TOKEN") + if token == "" { + log.Fatal("DATABRICKS_ACCESS_TOKEN not set") + } + + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithPort(port), + dbsql.WithHTTPPath(httpPath), + dbsql.WithStaticToken(token), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + // Test the connection + var result int + err = db.QueryRow("SELECT 1").Scan(&result) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("✓ Connected successfully using static token\n") + fmt.Printf("✓ Test query result: %d\n", result) +} + +// Example 2: External token provider (token passthrough) +func runExternalTokenExample(host, httpPath string, port int) { + fmt.Println("\nExample 2: External Token Provider (Passthrough)") + fmt.Println("------------------------------------------------") + + // Simulate getting token from external source + tokenFunc := func() (string, error) { + // In real scenario, this could: + // - Read from a file + // - Call another service + // - Retrieve from a secret manager + // - Get from environment variable + token := os.Getenv("DATABRICKS_ACCESS_TOKEN") + if token == "" { + return "", fmt.Errorf("no token available") + } + fmt.Println(" → Fetching token from external source...") + return token, nil + } + + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithPort(port), + dbsql.WithHTTPPath(httpPath), + dbsql.WithExternalToken(tokenFunc), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + // Test the connection + var result int + err = db.QueryRow("SELECT 2").Scan(&result) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("✓ Connected successfully using external token provider\n") + fmt.Printf("✓ Test query result: %d\n", result) +} + +// Example 3: Cached token provider +func runCachedTokenExample(host, httpPath string, port int) { + fmt.Println("\nExample 3: Cached Token Provider") + fmt.Println("--------------------------------") + + callCount := 0 + // Create a token provider that tracks how many times it's called + baseProvider := tokenprovider.NewExternalTokenProvider(func() (string, error) { + callCount++ + fmt.Printf(" → Token provider called (count: %d)\n", callCount) + token := os.Getenv("DATABRICKS_ACCESS_TOKEN") + if token == "" { + return "", fmt.Errorf("no token available") + } + return token, nil + }) + + // Wrap with caching + cachedProvider := tokenprovider.NewCachedTokenProvider(baseProvider) + + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithPort(port), + dbsql.WithHTTPPath(httpPath), + dbsql.WithTokenProvider(cachedProvider), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + // Run multiple queries - token should only be fetched once due to caching + for i := 1; i <= 3; i++ { + var result int + err = db.QueryRow(fmt.Sprintf("SELECT %d", i)).Scan(&result) + if err != nil { + log.Fatal(err) + } + fmt.Printf("✓ Query %d result: %d\n", i, result) + } + + fmt.Printf("✓ Token was fetched %d time(s) (should be 1 due to caching)\n", callCount) +} + +// Example 4: Custom token provider with expiry +func runCustomProviderExample(host, httpPath string, port int) { + fmt.Println("\nExample 4: Custom Token Provider with Expiry") + fmt.Println("--------------------------------------------") + + // Custom provider that simulates token with expiry + provider := &CustomExpiringTokenProvider{ + baseToken: os.Getenv("DATABRICKS_ACCESS_TOKEN"), + expiry: 1 * time.Hour, + } + + // Wrap with caching to handle refresh + cachedProvider := tokenprovider.NewCachedTokenProvider(provider) + + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithPort(port), + dbsql.WithHTTPPath(httpPath), + dbsql.WithTokenProvider(cachedProvider), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + var result int + err = db.QueryRow("SELECT 42").Scan(&result) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("✓ Connected with custom provider\n") + fmt.Printf("✓ Token expires at: %s\n", provider.lastToken.ExpiresAt.Format(time.RFC3339)) + fmt.Printf("✓ Test query result: %d\n", result) +} + +// Example 5: OAuth service token provider +func runOAuthServiceExample(host, httpPath string, port int) { + fmt.Println("\nExample 5: OAuth Service Token Provider") + fmt.Println("---------------------------------------") + + oauthEndpoint := os.Getenv("OAUTH_TOKEN_ENDPOINT") + clientID := os.Getenv("OAUTH_CLIENT_ID") + clientSecret := os.Getenv("OAUTH_CLIENT_SECRET") + + if oauthEndpoint == "" || clientID == "" || clientSecret == "" { + fmt.Println("⚠ Skipping OAuth example (OAUTH_TOKEN_ENDPOINT, OAUTH_CLIENT_ID, or OAUTH_CLIENT_SECRET not set)") + return + } + + provider := &OAuthServiceTokenProvider{ + endpoint: oauthEndpoint, + clientID: clientID, + clientSecret: clientSecret, + } + + // Wrap with caching for efficiency + cachedProvider := tokenprovider.NewCachedTokenProvider(provider) + + connector, err := dbsql.NewConnector( + dbsql.WithServerHostname(host), + dbsql.WithPort(port), + dbsql.WithHTTPPath(httpPath), + dbsql.WithTokenProvider(cachedProvider), + ) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + var result string + err = db.QueryRow("SELECT 'OAuth Success'").Scan(&result) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("✓ Connected with OAuth service token\n") + fmt.Printf("✓ Test query result: %s\n", result) +} + +// CustomExpiringTokenProvider simulates a provider with token expiry +type CustomExpiringTokenProvider struct { + baseToken string + expiry time.Duration + lastToken *tokenprovider.Token +} + +func (p *CustomExpiringTokenProvider) GetToken(ctx context.Context) (*tokenprovider.Token, error) { + if p.baseToken == "" { + return nil, fmt.Errorf("no base token configured") + } + + fmt.Println(" → Generating new token with expiry...") + p.lastToken = &tokenprovider.Token{ + AccessToken: p.baseToken, + TokenType: "Bearer", + ExpiresAt: time.Now().Add(p.expiry), + } + + return p.lastToken, nil +} + +func (p *CustomExpiringTokenProvider) Name() string { + return "custom-expiring" +} + +// OAuthServiceTokenProvider gets tokens from an OAuth service +type OAuthServiceTokenProvider struct { + endpoint string + clientID string + clientSecret string +} + +func (p *OAuthServiceTokenProvider) GetToken(ctx context.Context) (*tokenprovider.Token, error) { + fmt.Printf(" → Fetching token from OAuth service: %s\n", p.endpoint) + + // Create OAuth request + req, err := http.NewRequestWithContext(ctx, "POST", p.endpoint, nil) + if err != nil { + return nil, err + } + + req.SetBasicAuth(p.clientID, p.clientSecret) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Make request + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("OAuth service returned %d: %s", resp.StatusCode, body) + } + + // Parse response + var tokenResp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + } + + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return nil, err + } + + token := &tokenprovider.Token{ + AccessToken: tokenResp.AccessToken, + TokenType: tokenResp.TokenType, + } + + if tokenResp.ExpiresIn > 0 { + token.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + + return token, nil +} + +func (p *OAuthServiceTokenProvider) Name() string { + return "oauth-service" +} diff --git a/go.mod b/go.mod index d9a517c5..1d1fdc78 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/apache/arrow/go/v12 v12.0.1 github.com/apache/thrift v0.17.0 github.com/coreos/go-oidc/v3 v3.5.0 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/joho/godotenv v1.4.0 github.com/mattn/go-isatty v0.0.20 github.com/pierrec/lz4/v4 v4.1.15 diff --git a/go.sum b/go.sum index edeb89ee..670487a8 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,8 @@ github.com/go-jose/go-jose/v3 v3.0.4/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQr github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk= github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=