Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,17 @@ jobs:
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
pgx-ssl-password: certpw
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
- pg-version: 18
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
pgx-test-oauth: "true"
pgx-ssl-password: certpw
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
- pg-version: cockroachdb
pgx-test-database: "postgresql://[email protected]:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"

Expand Down Expand Up @@ -115,6 +126,7 @@ jobs:
PGX_TEST_SCRAM_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-scram-password-conn-string }}
PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}
PGX_TEST_OAUTH: ${{ matrix.pgx-test-oauth }}
# TestConnectTLS fails. However, it succeeds if I connect to the CI server with upterm and run it. Give up on that test for now.
# PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
PGX_SSL_PASSWORD: ${{ matrix.pgx-ssl-password }}
Expand Down
11 changes: 11 additions & 0 deletions ci/setup_test.bash
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ then
sudo sh -c "echo \"listen_addresses = '127.0.0.1'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf"
sudo sh -c "cat testsetup/postgresql_ssl.conf >> /etc/postgresql/$PGVERSION/main/postgresql.conf"

if [ "$PGVERSION" -ge 18 ]; then
# Configure and Install OAuth validator for PostgreSQL 18+
sudo sh -c "cat testsetup/oauth_validator_module/postgresql.conf >> /etc/postgresql/$PGVERSION/main/postgresql.conf"
sudo sh -c "cat testsetup/oauth_validator_module/pg_hba.conf >> /etc/postgresql/$PGVERSION/main/pg_hba.conf"
(
cd testsetup/oauth_validator_module
sudo apt-get install -y gcc make libkrb5-dev
make && sudo make install
)
fi

cd testsetup

# Generate CA, server, and encrypted client certificates.
Expand Down
67 changes: 67 additions & 0 deletions pgconn/auth_oauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package pgconn

import (
"context"
"encoding/json"
"errors"
"fmt"

"github.com/jackc/pgx/v5/pgproto3"
)

func (c *PgConn) oauthAuth(ctx context.Context) error {
if c.config.OAuthTokenProvider == nil {
return errors.New("OAuth authentication required but no token provider configured")
}

token, err := c.config.OAuthTokenProvider(ctx)
if err != nil {
return fmt.Errorf("failed to obtain OAuth token: %w", err)
}

// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.1
initialResponse := []byte("n,,\x01auth=Bearer " + token + "\x01\x01")

saslInitialResponse := &pgproto3.SASLInitialResponse{
AuthMechanism: "OAUTHBEARER",
Data: initialResponse,
}
c.frontend.Send(saslInitialResponse)
err = c.flushWithPotentialWriteReadDeadlock()
if err != nil {
return err
}

msg, err := c.receiveMessage()
if err != nil {
return err
}

switch m := msg.(type) {
case *pgproto3.AuthenticationOk:
return nil
case *pgproto3.AuthenticationSASLContinue:
// Server sent error response in SASL continue
// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.2
// https://www.rfc-editor.org/rfc/rfc7628.html#section-3.2.3
errResponse := struct {
Status string `json:"status"`
Scope string `json:"scope"`
OpenIDConfiguration string `json:"openid-configuration"`
}{}
err := json.Unmarshal(m.Data, &errResponse)
if err != nil {
return fmt.Errorf("invalid OAuth error response from server: %w", err)
}

// Per RFC 7628 section 3.2.3, we should send a SASLResponse which only contains \x01.
// However, since the connection will be closed anyway, we can skip this
return fmt.Errorf("OAuth authentication failed: %s", errResponse.Status)

case *pgproto3.ErrorResponse:
return ErrorResponseToPgError(m)

default:
return fmt.Errorf("unexpected message type during OAuth auth: %T", msg)
}
}
4 changes: 4 additions & 0 deletions pgconn/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ type Config struct {
// that you close on FATAL errors by returning false.
OnPgError PgErrorHandler

// OAuthTokenProvider is a function that returns an OAuth token for authentication. If set, it will be used for
// OAUTHBEARER SASL authentication when the server requests it.
OAuthTokenProvider func(context.Context) (string, error)

createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}

Expand Down
15 changes: 14 additions & 1 deletion pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,20 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
return nil, newPerDialConnectError("failed to write password message", err)
}
case *pgproto3.AuthenticationSASL:
err = pgConn.scramAuth(msg.AuthMechanisms)
// Check if OAUTHBEARER is supported
serverSupportsOAuthBearer := false
for _, mech := range msg.AuthMechanisms {
if mech == "OAUTHBEARER" {
serverSupportsOAuthBearer = true
break
}
}

if serverSupportsOAuthBearer && pgConn.config.OAuthTokenProvider != nil {
err = pgConn.oauthAuth(ctx)
} else {
err = pgConn.scramAuth(msg.AuthMechanisms)
}
if err != nil {
pgConn.conn.Close()
return nil, newPerDialConnectError("failed SASL auth", err)
Expand Down
54 changes: 53 additions & 1 deletion pgconn/pgconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)

const pgbouncerConnStringEnvVar = "PGX_TEST_PGBOUNCER_CONN_STRING"
const (
pgbouncerConnStringEnvVar = "PGX_TEST_PGBOUNCER_CONN_STRING"
// runOAuthTestEnvVar has to be set to "true" to run OAuth tests
runOAuthTestEnvVar = "PGX_TEST_OAUTH"
)

func TestConnect(t *testing.T) {
tests := []struct {
Expand Down Expand Up @@ -120,6 +124,54 @@ func TestConnectTLS(t *testing.T) {
closeConn(t, conn)
}

// TestConnectOAuth is separate from other connect tests because it specifically
// needs a configured OAuthTokenProvider. Further it's only available in Postgres
// 18+ and requires the dummy OAuth validator module installed.
func TestConnectOAuth(t *testing.T) {
if os.Getenv(runOAuthTestEnvVar) != "true" {
t.Skipf("Skipping as '%s=true' is not set", runOAuthTestEnvVar)
}

config, err := pgconn.ParseConfig("host=127.0.0.1 user=pgx_oauth dbname=pgx_test")
require.NoError(t, err)

// Configure OAuthTokenProvider for dummy validator.
// The dummy validator accepts any token and maps it to the user equal to the
// token string.
config.OAuthTokenProvider = func(ctx context.Context) (string, error) {
return "pgx_oauth", nil
}

conn, err := pgconn.ConnectConfig(context.Background(), config)
require.NoError(t, err)
defer closeConn(t, conn)

result := conn.ExecParams(context.Background(), "SELECT CURRENT_USER", nil, nil, nil, nil).Read()
require.NoError(t, result.Err)
require.Len(t, result.Rows, 1)
require.Len(t, result.Rows[0], 1)
require.Equalf(t, "pgx_oauth", string(result.Rows[0][0]), "not logged in as expected user.")
}

func TestConnectOAuthError(t *testing.T) {
if os.Getenv(runOAuthTestEnvVar) != "true" {
t.Skipf("Skipping as '%s=true' is not set", runOAuthTestEnvVar)
}

config, err := pgconn.ParseConfig("host=127.0.0.1 user=pgx_oauth dbname=pgx_test")
require.NoError(t, err)

// Configure OAuthTokenProvider for dummy validator.
// The dummy validator accepts any token and maps it to the user equal to the
// token string. In this case that token will be accepted but as there is no
// user 'INVALID_TOKEN' the connection should fail.
config.OAuthTokenProvider = func(ctx context.Context) (string, error) {
return "INVALID_TOKEN", nil
}

_, err = pgconn.ConnectConfig(context.Background(), config)
require.Error(t, err, "connect should return error for invalid token")
}
func TestConnectTLSPasswordProtectedClientCertWithSSLPassword(t *testing.T) {
t.Parallel()

Expand Down
17 changes: 17 additions & 0 deletions testsetup/oauth_validator_module/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
.PHONY = install clean

PG_CONFIG = pg_config
PKGLIBDIR = $(shell $(PG_CONFIG) --pkglibdir)
CPPFLAGS += -I$(shell $(PG_CONFIG) --includedir-server)
CFLAGS += -fPIC

dummy_validator.so: dummy_validator.o
$(CC) -shared $(CFLAGS) $(LDFLAGS) -o $@ $^

dummy_validator.o: dummy_validator.c

install:
install -D -m 755 dummy_validator.so $(PKGLIBDIR)/

clean:
rm -f dummy_validator.o dummy_validator.so
26 changes: 26 additions & 0 deletions testsetup/oauth_validator_module/dummy_validator.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include "postgres.h"
#include "fmgr.h"
#include "libpq/oauth.h"

PG_MODULE_MAGIC;

bool validate(const ValidatorModuleState *state, const char *token,
const char *role, ValidatorModuleResult *result) {

elog(LOG, "accept token '%s' for role '%s'", token, role);
char *authn_id = pstrdup(token);
result->authn_id = authn_id;
result->authorized = true;
return true;
}

const OAuthValidatorCallbacks callbacks = {
.magic = PG_OAUTH_VALIDATOR_MAGIC,
.startup_cb = NULL,
.shutdown_cb = NULL,
.validate_cb = validate,
};

const OAuthValidatorCallbacks *_PG_oauth_validator_module_init() {
return &callbacks;
}
1 change: 1 addition & 0 deletions testsetup/oauth_validator_module/pg_hba.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
host all pgx_oauth 127.0.0.1/32 oauth validator=dummy_validator issuer=https://example.com scope=
1 change: 1 addition & 0 deletions testsetup/oauth_validator_module/postgresql.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
oauth_validator_libraries = 'dummy_validator'
1 change: 1 addition & 0 deletions testsetup/postgresql_setup.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ create user pgx_md5 with superuser PASSWORD 'secret';
set password_encryption = 'scram-sha-256';
create user pgx_pw with superuser PASSWORD 'secret';
create user pgx_scram with superuser PASSWORD 'secret';
create user pgx_oauth with superuser;
\set whoami `whoami`
create user :whoami with superuser; -- unix domain socket user

Expand Down