diff --git a/args.go b/args.go index f2ea0f0e6..fe31a1671 100644 --- a/args.go +++ b/args.go @@ -173,6 +173,8 @@ const ( ArgKubernetesSSOIssuerURL = "sso-issuer-url" // ArgKubernetesSSOClientID is the OIDC client ID for cluster SSO configuration. ArgKubernetesSSOClientID = "sso-client-id" + // ArgKubernetesSSOLocalServerPort is the port to use for the local server which handles SSO authentication flow. + ArgKubernetesSSOLocalServerPort = "sso-local-server-port" // ArgSurgeUpgrade is a cluster's surge-upgrade argument. ArgSurgeUpgrade = "surge-upgrade" // ArgCommandUpsert is an upsert for a resource to be created or updated argument. @@ -328,7 +330,8 @@ const ( // ArgTriggerDeployment indicates whether to trigger a deployment ArgTriggerDeployment = "trigger-deployment" // ArgVersion is the version of the command to use - ArgVersion = "version" + ArgVersion = "version" + ArgKubernetesSSOURL = "sso-url" // ArgVerbose enables verbose output ArgVerbose = "verbose" diff --git a/commands/doit.go b/commands/doit.go index 927fc2a44..213e7dc83 100644 --- a/commands/doit.go +++ b/commands/doit.go @@ -140,7 +140,7 @@ func initConfig() { } // in case we ever want to change this, or let folks configure it... -func defaultConfigHome() string { +var defaultConfigHome = func() string { cfgDir, err := os.UserConfigDir() checkErr(err) diff --git a/commands/kubernetes.go b/commands/kubernetes.go index fd95268a0..57b84ff33 100644 --- a/commands/kubernetes.go +++ b/commands/kubernetes.go @@ -18,6 +18,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "os" "path/filepath" "sort" @@ -28,12 +29,14 @@ import ( "github.com/blang/semver" "github.com/digitalocean/godo" "github.com/google/uuid" + "github.com/pkg/browser" "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/digitalocean/doctl" "github.com/digitalocean/doctl/commands/displayers" "github.com/digitalocean/doctl/do" + "github.com/digitalocean/doctl/internal/kubernetes/sso" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" kubeerrors "k8s.io/apimachinery/pkg/util/errors" @@ -61,6 +64,14 @@ A typical workflow is to use ` + "`" + `doctl kubernetes cluster create` + "`" + The commands under ` + "`" + `doctl kubernetes options` + "`" + ` retrieve values used while creating clusters, such as the list of regions where cluster creation is supported.` ) +func init() { + // Borrowed from https://github.com/int128/kubelogin: + // In credential plugin mode, some browser launcher writes a message to stdout + // and it may break the credential json for client-go. + // This prevents the browser launcher from breaking the credential json. + browser.Stdout = os.Stderr +} + var getCurrentAuthContextFn = defaultGetCurrentAuthContextFn func defaultGetCurrentAuthContextFn() string { @@ -207,6 +218,9 @@ func (p *kubeconfigProvider) ConfigPath() string { // KubernetesCommandService is used to execute Kubernetes commands. type KubernetesCommandService struct { KubeconfigProvider KubeconfigProvider + + // to be used for stubbing in testss + ssoLogin func(ctx context.Context, clientID, issuerURL string, opts ...sso.LocalOIDCLoginOption) (string, time.Time, error) } func kubernetesCommandService() *KubernetesCommandService { @@ -214,6 +228,7 @@ func kubernetesCommandService() *KubernetesCommandService { KubeconfigProvider: &kubeconfigProvider{ pathOptions: clientcmd.NewDefaultPathOptions(), }, + ssoLogin: sso.GetIDToken, } } @@ -473,8 +488,11 @@ Returns the raw YAML for the specified cluster's kubeconfig.`, Writer, aliasOpt( cmdShowConfig.Example = `The following example shows the kubeconfig YAML for a cluster named ` + "`" + `example-cluster` + "`" + `: doctl kubernetes cluster kubeconfig show example-cluster` execCredDesc := "INTERNAL: This hidden command is for printing a cluster's exec credential" - cmdExecCredential := CmdBuilder(cmd, k8sCmdService.RunKubernetesKubeconfigExecCredential, "exec-credential ", execCredDesc, execCredDesc, Writer, hiddenCmd()) + cmdExecCredential := CmdBuilder(cmd, k8sCmdService.RunKubernetesKubeconfigExecCredential, "exec-credential ", execCredDesc, execCredDesc, Writer) //, hiddenCmd()) AddStringFlag(cmdExecCredential, doctl.ArgVersion, "", "", "") + AddStringFlag(cmdExecCredential, doctl.ArgKubernetesSSOIssuerURL, "", "", "") + AddStringFlag(cmdExecCredential, doctl.ArgKubernetesSSOClientID, "", "", "") + AddIntFlag(cmdExecCredential, doctl.ArgKubernetesSSOLocalServerPort, "", 8080, "") cmdSaveConfig := CmdBuilder(cmd, k8sCmdService.RunKubernetesKubeconfigSave, "save ", "Save a cluster's credentials to your local kubeconfig", ` Adds the credentials for the specified cluster to your local kubeconfig. After this, your kubectl installation can directly manage the specified cluster. @@ -1264,10 +1282,13 @@ func cachedExecCredentialPath(id string) string { return filepath.Join(kubeconfigCachePath(), id+".json") } +func cachedSSOExecCredentialPath(id string) string { + return filepath.Join(kubeconfigCachePath(), id+"_sso.json") +} + // loadCachedExecCredential attempts to load the cached exec credential from disk. Never errors // Returns nil if there's no credential, if it failed to load it, or if it's expired. -func loadCachedExecCredential(id string) (*clientauthentication.ExecCredential, error) { - path := cachedExecCredentialPath(id) +func loadCachedExecCredential(path string) (*clientauthentication.ExecCredential, error) { f, err := os.Open(path) if err != nil { if os.IsNotExist(err) { @@ -1300,7 +1321,7 @@ func loadCachedExecCredential(id string) (*clientauthentication.ExecCredential, } // cacheExecCredential caches an ExecCredential to the doctl cache directory -func cacheExecCredential(id string, execCredential *clientauthentication.ExecCredential) error { +func cacheExecCredential(path string, execCredential *clientauthentication.ExecCredential) error { // Don't bother caching if there's no expiration set if execCredential.Status.ExpirationTimestamp.IsZero() { return nil @@ -1311,7 +1332,6 @@ func cacheExecCredential(id string, execCredential *clientauthentication.ExecCre return err } - path := cachedExecCredentialPath(id) f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR|os.O_TRUNC, os.FileMode(0600)) if err != nil { return err @@ -1337,31 +1357,70 @@ func (s *KubernetesCommandService) RunKubernetesKubeconfigExecCredential(c *CmdC return fmt.Errorf("Invalid version %q, expected 'v1beta1'", version) } + var isSSO bool + ssoIssuerURL, err := c.Doit.GetString(c.NS, doctl.ArgKubernetesSSOIssuerURL) + if err != nil { + return fmt.Errorf("Checking %s flag: %v", doctl.ArgKubernetesSSOIssuerURL, err) + } + ssoClientID, err := c.Doit.GetString(c.NS, doctl.ArgKubernetesSSOClientID) + if err != nil { + return fmt.Errorf("Checking %s flag: %v", doctl.ArgKubernetesSSOClientID, err) + } + if (ssoIssuerURL != "" && ssoClientID == "") || (ssoIssuerURL == "" && ssoClientID != "") { + return fmt.Errorf("Invalid SSO configuration: issuer URL and client ID must be provided together") + } + isSSO = (ssoIssuerURL != "" && ssoClientID != "") + kube := c.Kubernetes() + // it's important that we don't print anything to stdout since this command + // is used by kubectl which relies on stdout to contain _only_ the credential + logger := log.New(os.Stderr, "doctl: ", log.LstdFlags) clusterID := c.Args[0] - execCredential, err := loadCachedExecCredential(clusterID) + cachePath := cachedExecCredentialPath(clusterID) + if isSSO { + // store SSO credentials separately so that, if a user switches from SSO to token auth (or vice versa), + // we stop using the cached credentials and instead fetch new ones + cachePath = cachedSSOExecCredentialPath(clusterID) + } + execCredential, err := loadCachedExecCredential(cachePath) if err != nil && Verbose { warn("%v", err) } if execCredential != nil { + logger.Println("Using cached credential") return json.NewEncoder(c.Out).Encode(execCredential) } - credentials, err := kube.GetCredentials(clusterID) - if err != nil { - if errResponse, ok := err.(*godo.ErrorResponse); ok { - return fmt.Errorf("Failed to fetch credentials for cluster %q: %v", clusterID, errResponse.Message) + var token string + var expiry time.Time + if isSSO { + logger.Println("SSO login") + ssoLocalServerPort, err := c.Doit.GetInt(c.NS, doctl.ArgKubernetesSSOLocalServerPort) + if err != nil { + return fmt.Errorf("Checking %s flag: %v", doctl.ArgKubernetesSSOLocalServerPort, err) + } + + if ssoLocalServerPort <= 1024 || ssoLocalServerPort > 65535 { + return fmt.Errorf("Invalid %s flag: %d", doctl.ArgKubernetesSSOLocalServerPort, ssoLocalServerPort) } - return err - } - status := &clientauthentication.ExecCredentialStatus{ - ClientCertificateData: string(credentials.ClientCertificateData), - ClientKeyData: string(credentials.ClientKeyData), - ExpirationTimestamp: &metav1.Time{Time: credentials.ExpiresAt}, - Token: credentials.Token, + token, expiry, err = s.ssoLogin(context.Background(), ssoClientID, ssoIssuerURL, sso.WithLocalServerPort(uint16(ssoLocalServerPort)), sso.WithLogger(logger)) + if err != nil { + return fmt.Errorf("Failed to get ID token: %w", err) + } + } else { + logger.Println("DO PAT login") + credentials, err := kube.GetCredentials(clusterID) + if err != nil { + if errResponse, ok := err.(*godo.ErrorResponse); ok { + return fmt.Errorf("Failed to fetch credentials for cluster %q: %v", clusterID, errResponse.Message) + } + return err + } + expiry = credentials.ExpiresAt + token = credentials.Token } execCredential = &clientauthentication.ExecCredential{ @@ -1369,11 +1428,14 @@ func (s *KubernetesCommandService) RunKubernetesKubeconfigExecCredential(c *CmdC Kind: execCredentialKind, APIVersion: clientauthentication.SchemeGroupVersion.String(), }, - Status: status, + Status: &clientauthentication.ExecCredentialStatus{ + Token: token, + ExpirationTimestamp: &metav1.Time{Time: expiry}, + }, } // Don't error out when caching credentials, just print it if we're being verbose - if err := cacheExecCredential(clusterID, execCredential); err != nil && Verbose { + if err := cacheExecCredential(cachePath, execCredential); err != nil && Verbose { warn("%v", err) } @@ -1408,10 +1470,14 @@ func (s *KubernetesCommandService) RunKubernetesKubeconfigSave(c *CmdConfig) err return err } - path := cachedExecCredentialPath(kubeconfigParams.clusterID) - _, err = os.Stat(path) - if err == nil { - os.Remove(path) + for _, path := range []string{ + cachedExecCredentialPath(kubeconfigParams.clusterID), + cachedSSOExecCredentialPath(kubeconfigParams.clusterID), + } { + _, err = os.Stat(path) + if err == nil { + os.Remove(path) + } } return s.writeOrAddToKubeconfig(kubeconfigParams, remoteKubeconfig, setCurrentContext) diff --git a/commands/kubernetes_test.go b/commands/kubernetes_test.go index a16cc5694..285a50151 100644 --- a/commands/kubernetes_test.go +++ b/commands/kubernetes_test.go @@ -1,19 +1,28 @@ package commands import ( + "bytes" + "context" + "encoding/json" "fmt" + "os" + "path/filepath" "sort" "testing" + "time" "github.com/digitalocean/godo" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + clientauthentication "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1" "k8s.io/client-go/tools/clientcmd" clientcmdapi "k8s.io/client-go/tools/clientcmd/api" "github.com/digitalocean/doctl" "github.com/digitalocean/doctl/do" + "github.com/digitalocean/doctl/internal/kubernetes/sso" ) var ( @@ -474,6 +483,157 @@ func TestKubernetesKubeconfigShow(t *testing.T) { }) } +func TestRunKubernetesKubeconfigExecCredential(t *testing.T) { + clusterIDWithCachedSSOCreds := "cluster-id-with-cached-sso-creds" + clusterIDWithCachedTokenCreds := "cluster-id-with-cached-token-creds" + + testRoot := t.TempDir() + origConfigHomeFn := defaultConfigHome + defaultConfigHome = func() string { + return filepath.Join(testRoot, "doctl") + } + t.Cleanup(func() { defaultConfigHome = origConfigHomeFn }) + + execCredCacheDir := filepath.Join(testRoot, "doctl", "cache", "exec-credential") + require.NoError(t, os.MkdirAll(execCredCacheDir, 0o700)) + + // Truncate so JSON cache round-trips match (metav1 time encoding drops sub-second precision). + expiryCached := time.Now().Add(time.Hour).UTC().Truncate(time.Second) + expiryNew := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second) + cachedCred := &clientauthentication.ExecCredential{ + TypeMeta: metav1.TypeMeta{ + Kind: execCredentialKind, + APIVersion: clientauthentication.SchemeGroupVersion.String(), + }, + Status: &clientauthentication.ExecCredentialStatus{ + Token: "cached", + ExpirationTimestamp: &metav1.Time{Time: expiryCached}, + }, + } + fToken, err := os.Create(filepath.Join(execCredCacheDir, clusterIDWithCachedTokenCreds+".json")) + require.NoError(t, err) + require.NoError(t, json.NewEncoder(fToken).Encode(cachedCred)) + require.NoError(t, fToken.Close()) + + fSSO, err := os.Create(filepath.Join(execCredCacheDir, clusterIDWithCachedSSOCreds+"_sso.json")) + require.NoError(t, err) + require.NoError(t, json.NewEncoder(fSSO).Encode(cachedCred)) + require.NoError(t, fSSO.Close()) + + type execCredCase struct { + name string + setArgs func(config *CmdConfig) + prepareMocks func(t *testing.T, tm *tcMocks) + wantToken string + wantExp time.Time + assertCache func(t *testing.T, cacheDir string, got clientauthentication.ExecCredential) + } + + tests := []execCredCase{ + { + name: "new token login", + setArgs: func(config *CmdConfig) { + // double-duty: also ensures that cached SSO cred is not used when token cred is requested + config.Args = []string{clusterIDWithCachedSSOCreds} + config.Doit.Set(config.NS, doctl.ArgVersion, "v1beta1") + }, + prepareMocks: func(t *testing.T, tm *tcMocks) { + tm.kubernetes.EXPECT().GetCredentials(clusterIDWithCachedSSOCreds).Return(&do.KubernetesClusterCredentials{ + KubernetesClusterCredentials: &godo.KubernetesClusterCredentials{ + Token: "do-token", + ExpiresAt: expiryNew, + }, + }, nil) + }, + wantToken: "do-token", + wantExp: expiryNew, + assertCache: func(t *testing.T, cacheDir string, got clientauthentication.ExecCredential) { + p := filepath.Join(cacheDir, clusterIDWithCachedSSOCreds+".json") + onDisk, err := os.ReadFile(p) + require.NoError(t, err) + var cached clientauthentication.ExecCredential + require.NoError(t, json.Unmarshal(onDisk, &cached)) + require.Equal(t, got.Status.Token, cached.Status.Token) + require.Equal(t, got.Status.ExpirationTimestamp.UTC(), cached.Status.ExpirationTimestamp.UTC()) + }, + }, + { + name: "token cache hit", + setArgs: func(config *CmdConfig) { + config.Args = []string{clusterIDWithCachedTokenCreds} + config.Doit.Set(config.NS, doctl.ArgVersion, "v1beta1") + }, + wantToken: "cached", + wantExp: expiryCached, + }, + { + name: "new sso login", + setArgs: func(config *CmdConfig) { + // double-duty: also ensures that cached token cred is not used when SSO cred is requested + config.Args = []string{clusterIDWithCachedTokenCreds} + config.Doit.Set(config.NS, doctl.ArgVersion, "v1beta1") + config.Doit.Set(config.NS, doctl.ArgKubernetesSSOIssuerURL, "https://issuer.example") + config.Doit.Set(config.NS, doctl.ArgKubernetesSSOClientID, "oidc-client-id") + config.Doit.Set(config.NS, doctl.ArgKubernetesSSOLocalServerPort, 8080) + }, + wantToken: "oidc-id-token", + wantExp: expiryNew, + assertCache: func(t *testing.T, cacheDir string, got clientauthentication.ExecCredential) { + p := filepath.Join(cacheDir, clusterIDWithCachedTokenCreds+"_sso.json") + onDisk, err := os.ReadFile(p) + require.NoError(t, err) + var cached clientauthentication.ExecCredential + require.NoError(t, json.Unmarshal(onDisk, &cached)) + require.Equal(t, got.Status.Token, cached.Status.Token) + require.Equal(t, got.Status.ExpirationTimestamp.UTC(), cached.Status.ExpirationTimestamp.UTC()) + }, + }, + { + name: "sso cache hit", + setArgs: func(config *CmdConfig) { + config.Args = []string{clusterIDWithCachedSSOCreds} + config.Doit.Set(config.NS, doctl.ArgVersion, "v1beta1") + config.Doit.Set(config.NS, doctl.ArgKubernetesSSOIssuerURL, "https://issuer.example") + config.Doit.Set(config.NS, doctl.ArgKubernetesSSOClientID, "oidc-client-id") + config.Doit.Set(config.NS, doctl.ArgKubernetesSSOLocalServerPort, 8080) + }, + wantToken: "cached", + wantExp: expiryCached, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + withTestClient(t, func(config *CmdConfig, tm *tcMocks) { + if tt.prepareMocks != nil { + tt.prepareMocks(t, tm) + } + + var buf bytes.Buffer + config.Out = &buf + tt.setArgs(config) + + svc := kubernetesCommandService() + svc.ssoLogin = func(ctx context.Context, clientID, issuerURL string, opts ...sso.LocalOIDCLoginOption) (string, time.Time, error) { + return "oidc-id-token", expiryNew, nil + } + + err := svc.RunKubernetesKubeconfigExecCredential(config) + require.NoError(t, err) + + var got clientauthentication.ExecCredential + require.NoError(t, json.Unmarshal(buf.Bytes(), &got)) + require.Equal(t, tt.wantToken, got.Status.Token) + require.Equal(t, tt.wantExp.UTC(), got.Status.ExpirationTimestamp.UTC()) + + if tt.assertCache != nil { + tt.assertCache(t, execCredCacheDir, got) + } + }) + }) + } +} + func TestKubernetesList(t *testing.T) { withTestClient(t, func(config *CmdConfig, tm *tcMocks) { tm.kubernetes.EXPECT().List().Return(testClusterList, nil) diff --git a/go.mod b/go.mod index 2beb42a39..26a39754b 100644 --- a/go.mod +++ b/go.mod @@ -46,6 +46,7 @@ require ( github.com/charmbracelet/bubbles v0.13.1-0.20220731172002-8f6516082803 github.com/charmbracelet/bubbletea v0.22.0 github.com/charmbracelet/lipgloss v0.5.0 + github.com/coreos/go-oidc v2.5.0+incompatible github.com/erikgeiser/promptkit v0.7.1-0.20220721185625-1f33bc73d091 github.com/joho/godotenv v1.4.0 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 @@ -108,6 +109,7 @@ require ( github.com/pelletier/go-toml v1.9.5 // indirect github.com/pelletier/go-toml/v2 v2.0.0-beta.8 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/pquerna/cachecontrol v0.2.0 // indirect github.com/rivo/uniseg v0.4.2 // indirect github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect @@ -125,6 +127,7 @@ require ( golang.org/x/text v0.23.0 // indirect golang.org/x/time v0.15.0 // indirect golang.org/x/tools v0.26.0 // indirect + gopkg.in/go-jose/go-jose.v2 v2.6.3 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/ini.v1 v1.66.4 // indirect k8s.io/klog/v2 v2.90.1 // indirect diff --git a/go.sum b/go.sum index d92c8e07c..f1a94f481 100644 --- a/go.sum +++ b/go.sum @@ -82,6 +82,8 @@ github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFqu github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= +github.com/coreos/go-oidc v2.5.0+incompatible h1:6W0vGJR3Tu0r0PwfmjOrRZSlfxeEln8dsejt3ZWIvwo= +github.com/coreos/go-oidc v2.5.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc= github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -359,6 +361,8 @@ github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qR github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pquerna/cachecontrol v0.2.0 h1:vBXSNuE5MYP9IJ5kjsdo8uq+w41jSPgvba2DEnkRx9k= +github.com/pquerna/cachecontrol v0.2.0/go.mod h1:NrUG3Z7Rdu85UNR3vm7SOsl1nFIeSiQnrHV5K9mBcUI= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= @@ -763,6 +767,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/go-jose/go-jose.v2 v2.6.3 h1:nt80fvSDlhKWQgSWyHyy5CfmlQr+asih51R8PTWNKKs= +gopkg.in/go-jose/go-jose.v2 v2.6.3/go.mod h1:zzZDPkNNw/c9IE7Z9jr11mBZQhKQTMzoEEIoEdZlFBI= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/ini.v1 v1.66.4 h1:SsAcf+mM7mRZo2nJNGt8mZCjG8ZRaNGMURJw7BsIST4= diff --git a/internal/kubernetes/sso/auth_error.html b/internal/kubernetes/sso/auth_error.html new file mode 100644 index 000000000..6c8ab2682 --- /dev/null +++ b/internal/kubernetes/sso/auth_error.html @@ -0,0 +1,164 @@ + + + + + + Authentication failed · DigitalOcean + + + +
+ +
+ +

Authentication failed

+

You may close this window.

+
+ DigitalOcean · Kubernetes +
+
+
+ + diff --git a/internal/kubernetes/sso/auth_success.html b/internal/kubernetes/sso/auth_success.html new file mode 100644 index 000000000..c94b0e7ca --- /dev/null +++ b/internal/kubernetes/sso/auth_success.html @@ -0,0 +1,162 @@ + + + + + + Authentication successful · DigitalOcean + + + +
+ +
+ +

Authentication successful

+

You may close this window.

+
+ DigitalOcean · Kubernetes +
+
+
+ + diff --git a/internal/kubernetes/sso/sso.go b/internal/kubernetes/sso/sso.go new file mode 100644 index 000000000..51ed650fc --- /dev/null +++ b/internal/kubernetes/sso/sso.go @@ -0,0 +1,279 @@ +package sso + +import ( + "context" + _ "embed" + "errors" + "fmt" + "log" + "net/http" + "time" + + "golang.org/x/oauth2" + "golang.org/x/sync/errgroup" + + "github.com/coreos/go-oidc" + "github.com/google/uuid" + "github.com/pkg/browser" +) + +var ( + //go:embed auth_success.html + authSuccessHTML []byte + + //go:embed auth_error.html + authErrorHTML []byte + + now = time.Now +) + +const ( + defaultLocalServerPort uint16 = 8080 +) + +// GetIDToken obtains an ID token from an OIDC provider, following the Authorization Code Flow with PKCE: +// https://auth0.com/docs/get-started/authentication-and-authorization-flow/authorization-code-flow-with-pkce +func GetIDToken(ctx context.Context, clientID, issuerURL string, opts ...LocalOIDCLoginOption) (string, time.Time, error) { + ssoTool, err := newLocalOIDCLogin(clientID, issuerURL, opts...) + if err != nil { + return "", time.Time{}, fmt.Errorf("setting up SSO login tool: %w", err) + } + return ssoTool.getIDToken(ctx) +} + +type localOIDCLogin struct { + port uint16 + logger *log.Logger + + // set up on creation + oauth2Config oauth2.Config + state string + codeVerifier string + nonce string + provider *oidc.Provider + + ssoServer *ssoServer + + // only for stubbing in tests + openURL func(url string) error +} + +// LocalOIDCLoginOption is a function that can be used to configure a local OIDC login tool. +type LocalOIDCLoginOption func(*localOIDCLogin) + +// WithLocalServerPort sets the port to use for the local server which handles SSO authentication flow. +func WithLocalServerPort(port uint16) func(*localOIDCLogin) { + return func(l *localOIDCLogin) { + l.port = port + } +} + +// WithLogger sets the logger to use for the local OIDC login tool. +func WithLogger(logger *log.Logger) func(*localOIDCLogin) { + return func(l *localOIDCLogin) { + l.logger = logger + } +} + +// NewLocalOIDCLogin creates a new local OIDC login tool. +func newLocalOIDCLogin(clientID, issuerURL string, opts ...LocalOIDCLoginOption) (*localOIDCLogin, error) { + t := &localOIDCLogin{ + port: defaultLocalServerPort, + openURL: browser.OpenURL, + logger: log.Default(), + } + for _, opt := range opts { + opt(t) + } + + provider, err := oidc.NewProvider(context.Background(), issuerURL) + if err != nil { + return nil, fmt.Errorf("creating OIDC provider: %w", err) + } + + oauth2Config := oauth2.Config{ + ClientID: clientID, + Endpoint: provider.Endpoint(), + RedirectURL: t.redirectURL(), + Scopes: []string{oidc.ScopeOpenID, "email", "team_role"}, + } + + state := uuid.New().String() + codeVerifier := oauth2.GenerateVerifier() + codeChallenge := oauth2.S256ChallengeOption(codeVerifier) + nonce := uuid.New().String() + + ssoServer := &ssoServer{ + authCodeURL: oauth2Config.AuthCodeURL(state, codeChallenge, oidc.Nonce(nonce)), + logger: t.logger, + authCodeCh: make(chan authCodeResponse), + errorCh: make(chan error), + } + + t.oauth2Config = oauth2Config + t.state = state + t.codeVerifier = codeVerifier + t.nonce = nonce + t.ssoServer = ssoServer + t.provider = provider + + return t, nil +} + +func (t *localOIDCLogin) redirectURL() string { + return fmt.Sprintf("http://localhost:%d/callback", t.port) +} + +func (t *localOIDCLogin) loginURL() string { + return fmt.Sprintf("http://localhost:%d/login", t.port) +} + +func (t *localOIDCLogin) getAuthCode(ctx context.Context) (string, error) { + var group errgroup.Group + + server := &http.Server{ + Handler: t.ssoServer, + Addr: fmt.Sprintf("127.0.0.1:%d", t.port), + } + defer server.Close() + + group.Go(func() error { + t.logger.Printf("Starting local server for OIDC authentication on port %d\n", t.port) + if err := server.ListenAndServe(); err != nil { + if err != http.ErrServerClosed { + return fmt.Errorf("web server shutdown unexpectedly: %w", err) + } + } + + return nil + }) + + var authCode authCodeResponse + group.Go(func() error { + defer func() { + ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) + defer cancel() + t.logger.Println("Shutting down local OIDC authentication server") + _ = server.Shutdown(ctx) + }() + + select { + case code := <-t.ssoServer.authCodeCh: + authCode = code + return nil + case err := <-t.ssoServer.errorCh: + return err + case <-ctx.Done(): + return ctx.Err() + } + }) + + t.logger.Printf("Opening login URL in the default browser. If it didn't open automatically, paste this URL into your preferred browser:\n\t%s\n", t.loginURL()) + if err := t.openURL(t.loginURL()); err != nil { + return "", fmt.Errorf("opening OIDC login URL in browser: %v", err) + } + + if err := group.Wait(); err != nil { + return "", err + } + + if authCode.code == "" { + return "", errors.New("no authorization code received") + } + + if authCode.state != t.state { + return "", errors.New("authorzation flow state mismatch") + } + + return authCode.code, nil +} + +func (t *localOIDCLogin) getIDToken(ctx context.Context) (string, time.Time, error) { + code, err := t.getAuthCode(ctx) + if err != nil { + return "", time.Time{}, fmt.Errorf("getting authorization code: %w", err) + } + + t.logger.Println("Received an authorization code, exchanging for ID token") + token, err := t.oauth2Config.Exchange(ctx, code, oauth2.S256ChallengeOption(t.codeVerifier), oauth2.VerifierOption(t.codeVerifier)) + if err != nil { + return "", time.Time{}, fmt.Errorf("exchanging authorization code for ID token: %w", err) + } + + idToken, ok := token.Extra("id_token").(string) + if !ok || idToken == "" { + return "", time.Time{}, errors.New("no ID token found") + } + + verifier := t.provider.Verifier(&oidc.Config{ClientID: t.oauth2Config.ClientID, Now: now}) + verifiedIDToken, err := verifier.Verify(ctx, idToken) + if err != nil { + return "", time.Time{}, fmt.Errorf("could not verify ID token: %w", err) + } + if t.nonce != verifiedIDToken.Nonce { + return "", time.Time{}, fmt.Errorf("nonce did not match (wants %s but got %s)", t.nonce, verifiedIDToken.Nonce) + } + + return idToken, token.Expiry, nil +} + +type authCodeResponse struct { + code string + state string +} + +type ssoServer struct { + authCodeURL string + logger *log.Logger + authCodeCh chan authCodeResponse + errorCh chan error +} + +func (s *ssoServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + switch { + case r.Method == "GET" && r.URL.Path == "/callback" && q.Get("error") != "": + s.handleError(w, r) + case r.Method == "GET" && r.URL.Path == "/callback" && q.Get("code") != "": + s.handleAuthorizationCode(w, r) + case r.Method == "GET" && r.URL.Path == "/login": + s.handleLogin(w, r) + case r.Method == "GET" && r.URL.Path == "/error": + s.handleError(w, r) + default: + http.NotFound(w, r) + } +} + +func (s *ssoServer) handleError(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + code := q.Get("error") + if code == "" { + code = "unknown_error" + } + desc := q.Get("error_description") + if desc != "" { + s.errorCh <- fmt.Errorf("%s: %s", code, desc) + } else { + s.errorCh <- errors.New(code) + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write(authErrorHTML) +} + +func (s *ssoServer) handleAuthorizationCode(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + s.authCodeCh <- authCodeResponse{ + code: q.Get("code"), + state: q.Get("state"), + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write(authSuccessHTML) +} + +func (s *ssoServer) handleLogin(w http.ResponseWriter, r *http.Request) { + s.logger.Println("Redirecting to IDP login URL") + http.Redirect(w, r, s.authCodeURL, http.StatusFound) +} diff --git a/internal/kubernetes/sso/sso_test.go b/internal/kubernetes/sso/sso_test.go new file mode 100644 index 000000000..4b4c2566b --- /dev/null +++ b/internal/kubernetes/sso/sso_test.go @@ -0,0 +1,308 @@ +package sso + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + "time" + + jose "gopkg.in/go-jose/go-jose.v2" + + "github.com/stretchr/testify/require" +) + +type fakeIDP struct { + issuerURL string + clientID string + authCode string + + // RSA key used to sign id_token and publish JWKS. + privKey *rsa.PrivateKey + + mu sync.Mutex + lastNonce string + + // inject custom handlers to simulate errors + authorizeHandler func(w http.ResponseWriter, r *http.Request) +} + +func (p *fakeIDP) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/.well-known/openid-configuration" && r.Method == http.MethodGet: + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "issuer": p.issuerURL, + "authorization_endpoint": p.issuerURL + "/oauth/authorize", + "token_endpoint": p.issuerURL + "/oauth/token", + "jwks_uri": p.issuerURL + "/jwks", + "id_token_signing_alg_values_supported": []string{"RS256"}, + }) + case r.URL.Path == "/jwks" && r.Method == http.MethodGet: + pub := jose.JSONWebKey{Key: &p.privKey.PublicKey, KeyID: "test-kid", Algorithm: string(jose.RS256), Use: "sig"} + set := jose.JSONWebKeySet{Keys: []jose.JSONWebKey{pub}} + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(set); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + case r.URL.Path == "/oauth/authorize" && r.Method == http.MethodGet: + if p.authorizeHandler != nil { + p.authorizeHandler(w, r) + return + } + p.mu.Lock() + p.lastNonce = r.URL.Query().Get("nonce") + p.mu.Unlock() + redirectURI := r.URL.Query().Get("redirect_uri") + state := r.URL.Query().Get("state") + u, err := url.Parse(redirectURI) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + q := u.Query() + q.Set("code", p.authCode) + q.Set("state", state) + u.RawQuery = q.Encode() + http.Redirect(w, r, u.String(), http.StatusFound) + case r.URL.Path == "/oauth/token" && r.Method == http.MethodPost: + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if r.FormValue("grant_type") != "authorization_code" { + http.Error(w, "unexpected grant_type", http.StatusBadRequest) + return + } + if r.FormValue("code") != p.authCode { + http.Error(w, "unexpected code", http.StatusBadRequest) + return + } + if r.FormValue("client_id") != p.clientID { + http.Error(w, "unexpected client_id", http.StatusBadRequest) + return + } + p.mu.Lock() + nonce := p.lastNonce + p.mu.Unlock() + claims, err := json.Marshal(map[string]any{ + "iss": p.issuerURL, + "sub": "test-user", + "aud": p.clientID, + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + "nonce": nonce, + }) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + signer, err := jose.NewSigner( + jose.SigningKey{Algorithm: jose.RS256, Key: p.privKey}, + (&jose.SignerOptions{}).WithHeader(jose.HeaderKey("kid"), "test-kid"), + ) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + object, err := signer.Sign(claims) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + idTokenStr, err := object.CompactSerialize() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = fmt.Fprintf(w, `{ + "access_token": "fake-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "id_token": %q + }`, idTokenStr) + default: + http.NotFound(w, r) + } +} + +func authorizeHandlerAccessDenied(w http.ResponseWriter, r *http.Request) { + redirectURI := r.URL.Query().Get("redirect_uri") + u, err := url.Parse(redirectURI) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + q := u.Query() + q.Set("error", "access_denied") + q.Set("error_description", "user declined") + u.RawQuery = q.Encode() + http.Redirect(w, r, u.String(), http.StatusFound) +} + +func authorizeHandlerWrongState(authCode string) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + redirectURI := r.URL.Query().Get("redirect_uri") + u, err := url.Parse(redirectURI) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + q := u.Query() + q.Set("code", authCode) + q.Set("state", "wrong-state") + u.RawQuery = q.Encode() + http.Redirect(w, r, u.String(), http.StatusFound) + } +} + +func pickFreePort(t *testing.T) uint16 { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + return uint16(ln.Addr().(*net.TCPAddr).Port) +} + +func waitForLocalServer(ctx context.Context, port uint16) error { + const ( + pollInterval = 25 * time.Millisecond + dialTimeout = 2 * time.Second + ) + addr := fmt.Sprintf("127.0.0.1:%d", port) + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return fmt.Errorf("timeout : %w", ctx.Err()) + case <-ticker.C: + conn, err := net.DialTimeout("tcp", addr, dialTimeout) + if err == nil { + conn.Close() + return nil + } + } + } +} + +func TestGetIDToken(t *testing.T) { + const ( + clientID = "test-client-id" + authCode = "test-auth-code-from-idp" + ) + + tests := []struct { + name string + authorizeHandler func(http.ResponseWriter, *http.Request) + wantErr bool + errSubstring string + }{ + { + name: "success", + authorizeHandler: nil, + wantErr: false, + }, + { + name: "callback oauth error", + authorizeHandler: authorizeHandlerAccessDenied, + wantErr: true, + errSubstring: "access_denied", + }, + { + name: "state mismatch", + authorizeHandler: authorizeHandlerWrongState(authCode), + wantErr: true, + errSubstring: "state mismatch", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + idp := &fakeIDP{ + clientID: clientID, + authCode: authCode, + privKey: privKey, + authorizeHandler: tt.authorizeHandler, + } + srv := httptest.NewServer(idp) + defer srv.Close() + idp.issuerURL = srv.URL + + port := pickFreePort(t) + browserDone := make(chan error, 1) + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + login, err := newLocalOIDCLogin(clientID, idp.issuerURL, + WithLocalServerPort(port), + WithLogger(log.New(io.Discard, "", 0)), + ) + require.NoError(t, err) + + login.openURL = func(loginURL string) error { + go func() { + if err := waitForLocalServer(ctx, port); err != nil { + browserDone <- err + return + } + client := &http.Client{Timeout: 20 * time.Second} + resp, err := client.Get(loginURL) + if err != nil { + browserDone <- err + return + } + resp.Body.Close() + browserDone <- nil + }() + return nil + } + + gotToken, exp, err := login.getIDToken(ctx) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errSubstring) + } else { + require.NoError(t, err) + require.NotEmpty(t, gotToken) + parts := strings.Split(gotToken, ".") + require.Len(t, parts, 3, "compact JWS should have three segments") + rawPayload, err := base64.RawURLEncoding.DecodeString(parts[1]) + require.NoError(t, err) + var claims struct { + Iss string `json:"iss"` + Aud string `json:"aud"` + Nonce string `json:"nonce"` + } + require.NoError(t, json.Unmarshal(rawPayload, &claims)) + require.Equal(t, idp.issuerURL, claims.Iss) + require.Equal(t, clientID, claims.Aud) + require.Equal(t, login.nonce, claims.Nonce) + require.True(t, exp.After(time.Now()), "expected non-zero token expiry from IdP response") + } + + select { + case berr := <-browserDone: + require.NoError(t, berr) + case <-time.After(5 * time.Second): + t.Fatal("browser simulation did not complete") + } + }) + } +} diff --git a/vendor/github.com/coreos/go-oidc/.gitignore b/vendor/github.com/coreos/go-oidc/.gitignore new file mode 100644 index 000000000..c96f2f47b --- /dev/null +++ b/vendor/github.com/coreos/go-oidc/.gitignore @@ -0,0 +1,2 @@ +/bin +/gopath diff --git a/vendor/github.com/coreos/go-oidc/.travis.yml b/vendor/github.com/coreos/go-oidc/.travis.yml new file mode 100644 index 000000000..9f0b06010 --- /dev/null +++ b/vendor/github.com/coreos/go-oidc/.travis.yml @@ -0,0 +1,18 @@ +language: go + +go: + - "1.14" + - "1.15" +arch: + - AMD64 + - ppc64le +install: + - go get -v -t github.com/coreos/go-oidc/... + - go get golang.org/x/tools/cmd/cover + - go get golang.org/x/lint/golint + +script: + - ./test + +notifications: + email: false diff --git a/vendor/github.com/coreos/go-oidc/CONTRIBUTING.md b/vendor/github.com/coreos/go-oidc/CONTRIBUTING.md new file mode 100644 index 000000000..6662073a8 --- /dev/null +++ b/vendor/github.com/coreos/go-oidc/CONTRIBUTING.md @@ -0,0 +1,71 @@ +# How to Contribute + +CoreOS projects are [Apache 2.0 licensed](LICENSE) and accept contributions via +GitHub pull requests. This document outlines some of the conventions on +development workflow, commit message formatting, contact points and other +resources to make it easier to get your contribution accepted. + +# Certificate of Origin + +By contributing to this project you agree to the Developer Certificate of +Origin (DCO). This document was created by the Linux Kernel community and is a +simple statement that you, as a contributor, have the legal right to make the +contribution. See the [DCO](DCO) file for details. + +# Email and Chat + +The project currently uses the general CoreOS email list and IRC channel: +- Email: [coreos-dev](https://groups.google.com/forum/#!forum/coreos-dev) +- IRC: #[coreos](irc://irc.freenode.org:6667/#coreos) IRC channel on freenode.org + +Please avoid emailing maintainers found in the MAINTAINERS file directly. They +are very busy and read the mailing lists. + +## Getting Started + +- Fork the repository on GitHub +- Read the [README](README.md) for build and test instructions +- Play with the project, submit bugs, submit patches! + +## Contribution Flow + +This is a rough outline of what a contributor's workflow looks like: + +- Create a topic branch from where you want to base your work (usually master). +- Make commits of logical units. +- Make sure your commit messages are in the proper format (see below). +- Push your changes to a topic branch in your fork of the repository. +- Make sure the tests pass, and add any new tests as appropriate. +- Submit a pull request to the original repository. + +Thanks for your contributions! + +### Format of the Commit Message + +We follow a rough convention for commit messages that is designed to answer two +questions: what changed and why. The subject line should feature the what and +the body of the commit should describe the why. + +``` +scripts: add the test-cluster command + +this uses tmux to setup a test cluster that you can easily kill and +start for debugging. + +Fixes #38 +``` + +The format can be described more formally as follows: + +``` +: + + + +