Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions auth/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ var databricksAWSDomains []string = []string{
}

var databricksAzureDomains []string = []string{
".staging.azuredatabricks.net",
".dev.azuredatabricks.net",
".azuredatabricks.net",
".databricks.azure.cn",
".databricks.azure.us",
Expand Down
44 changes: 44 additions & 0 deletions auth/tokenprovider/authenticator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package tokenprovider

import (
"context"
"fmt"
"net/http"

"github.com/databricks/databricks-sql-go/auth"
"github.com/rs/zerolog/log"
)

// TokenProviderAuthenticator implements auth.Authenticator using a TokenProvider
type TokenProviderAuthenticator struct {
provider TokenProvider
}

// NewAuthenticator creates an authenticator from a token provider
func NewAuthenticator(provider TokenProvider) auth.Authenticator {
return &TokenProviderAuthenticator{
provider: provider,
}
}

// Authenticate implements auth.Authenticator
func (a *TokenProviderAuthenticator) Authenticate(r *http.Request) error {
ctx := r.Context()
if ctx == nil {
ctx = context.Background()
}

token, err := a.provider.GetToken(ctx)
if err != nil {
return fmt.Errorf("token provider authenticator: failed to get token: %w", err)
}

if token.AccessToken == "" {
return fmt.Errorf("token provider authenticator: empty access token")
}

token.SetAuthHeader(r)
log.Debug().Msgf("token provider authenticator: authenticated using provider %s", a.provider.Name())

return nil
}
135 changes: 135 additions & 0 deletions auth/tokenprovider/authenticator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package tokenprovider

import (
"context"
"errors"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestTokenProviderAuthenticator(t *testing.T) {
t.Run("successful_authentication", func(t *testing.T) {
provider := NewStaticTokenProvider("test-token-123")
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequest("GET", "http://example.com", nil)
err := authenticator.Authenticate(req)

require.NoError(t, err)
assert.Equal(t, "Bearer test-token-123", req.Header.Get("Authorization"))
})

t.Run("authentication_with_custom_token_type", func(t *testing.T) {
provider := NewStaticTokenProviderWithType("test-token", "MAC")
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequest("GET", "http://example.com", nil)
err := authenticator.Authenticate(req)

require.NoError(t, err)
assert.Equal(t, "MAC test-token", req.Header.Get("Authorization"))
})

t.Run("authentication_error_propagation", func(t *testing.T) {
provider := &mockProvider{
tokenFunc: func() (*Token, error) {
return nil, errors.New("provider failed")
},
}
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequest("GET", "http://example.com", nil)
err := authenticator.Authenticate(req)

assert.Error(t, err)
assert.Contains(t, err.Error(), "provider failed")
assert.Empty(t, req.Header.Get("Authorization"))
})

t.Run("empty_token_error", func(t *testing.T) {
provider := &mockProvider{
tokenFunc: func() (*Token, error) {
return &Token{
AccessToken: "",
TokenType: "Bearer",
}, nil
},
}
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequest("GET", "http://example.com", nil)
err := authenticator.Authenticate(req)

assert.Error(t, err)
assert.Contains(t, err.Error(), "empty access token")
assert.Empty(t, req.Header.Get("Authorization"))
})

t.Run("uses_request_context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately

provider := &mockProvider{
tokenFunc: func() (*Token, error) {
// This would normally check context cancellation
return &Token{
AccessToken: "test-token",
TokenType: "Bearer",
}, nil
},
}
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequestWithContext(ctx, "GET", "http://example.com", nil)
err := authenticator.Authenticate(req)

// Even with cancelled context, this should work as our mock doesn't check it
require.NoError(t, err)
assert.Equal(t, "Bearer test-token", req.Header.Get("Authorization"))
})

t.Run("external_token_integration", func(t *testing.T) {
tokenFunc := func() (string, error) {
return "external-token-456", nil
}
provider := NewExternalTokenProvider(tokenFunc)
authenticator := NewAuthenticator(provider)

req, _ := http.NewRequest("POST", "http://example.com/api", nil)
err := authenticator.Authenticate(req)

require.NoError(t, err)
assert.Equal(t, "Bearer external-token-456", req.Header.Get("Authorization"))
})

t.Run("cached_provider_integration", func(t *testing.T) {
callCount := 0
baseProvider := &mockProvider{
tokenFunc: func() (*Token, error) {
callCount++
return &Token{
AccessToken: "cached-token",
TokenType: "Bearer",
}, nil
},
name: "test",
}

cachedProvider := NewCachedTokenProvider(baseProvider)
authenticator := NewAuthenticator(cachedProvider)

// Multiple authentication attempts
for i := 0; i < 3; i++ {
req, _ := http.NewRequest("GET", "http://example.com", nil)
err := authenticator.Authenticate(req)
require.NoError(t, err)
assert.Equal(t, "Bearer cached-token", req.Header.Get("Authorization"))
}

// Should only call base provider once due to caching
assert.Equal(t, 1, callCount)
})
}
86 changes: 86 additions & 0 deletions auth/tokenprovider/cached.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package tokenprovider

import (
"context"
"fmt"
"sync"
"time"

"github.com/rs/zerolog/log"
)

// CachedTokenProvider wraps another provider and caches tokens
type CachedTokenProvider struct {
provider TokenProvider
cache *Token
mutex sync.RWMutex
// RefreshThreshold determines when to refresh (default 5 minutes before expiry)
RefreshThreshold time.Duration
}

// NewCachedTokenProvider creates a caching wrapper around any token provider
func NewCachedTokenProvider(provider TokenProvider) *CachedTokenProvider {
return &CachedTokenProvider{
provider: provider,
RefreshThreshold: 5 * time.Minute,
}
}

// GetToken retrieves a token, using cache if available and valid
func (p *CachedTokenProvider) GetToken(ctx context.Context) (*Token, error) {
// Try to get from cache first
p.mutex.RLock()
cached := p.cache
p.mutex.RUnlock()

if cached != nil && !p.shouldRefresh(cached) {
log.Debug().Msgf("cached token provider: using cached token for provider %s", p.provider.Name())
return cached, nil
}

// Need to refresh
p.mutex.Lock()
defer p.mutex.Unlock()

// Double-check after acquiring write lock
if p.cache != nil && !p.shouldRefresh(p.cache) {
return p.cache, nil
}

log.Debug().Msgf("cached token provider: fetching new token from provider %s", p.provider.Name())
token, err := p.provider.GetToken(ctx)
if err != nil {
return nil, fmt.Errorf("cached token provider: failed to get token: %w", err)
}

p.cache = token
return token, nil
}

// shouldRefresh determines if a token should be refreshed
func (p *CachedTokenProvider) shouldRefresh(token *Token) bool {
if token == nil {
return true
}

// If no expiry time, assume token doesn't expire
if token.ExpiresAt.IsZero() {
return false
}

// Refresh if within threshold of expiry
refreshAt := token.ExpiresAt.Add(-p.RefreshThreshold)
return time.Now().After(refreshAt)
}

// Name returns the provider name
func (p *CachedTokenProvider) Name() string {
return fmt.Sprintf("cached[%s]", p.provider.Name())
}

// ClearCache clears the cached token
func (p *CachedTokenProvider) ClearCache() {
p.mutex.Lock()
p.cache = nil
p.mutex.Unlock()
}
Loading
Loading