Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions providers/flagd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,25 @@ The flagd provider currently support following flag evaluation metadata,
| `scope` | string | "selector" set for the associated source in flagd |
| `providerID` | string | "providerID" set for the associated source in flagd |

## Selector Handling

When using the in-process resolver with a gRPC sync source, the provider supports filtering flag configurations using a selector. The selector can be configured using the `WithSelector` option or the `FLAGD_SOURCE_SELECTOR` environment variable.

### Header-based Selector (Recommended)

The provider now sends the selector as a `flagd-selector` gRPC metadata header when communicating with flagd sync services. This approach is consistent with how selectors are handled across all flagd services (sync, evaluation, and OFREP).

```go
provider, err := flagd.NewProvider(
flagd.WithInProcessResolver(),
flagd.WithSelector("source=database,app=myapp"),
)
```

### Backward Compatibility

For backward compatibility with older flagd versions, the provider continues to include the selector in the gRPC request fields alongside the header. This dual approach ensures compatibility during the migration period until all flagd instances are updated.

## Logging

If not configured, logging falls back to the standard Go log package at error level only.
Expand Down
41 changes: 41 additions & 0 deletions providers/flagd/pkg/service/in_process/grpc_interceptors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package process

import (
"context"
googlegrpc "google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

// selectorUnaryInterceptor adds the flagd-selector metadata header to unary gRPC calls
func selectorUnaryInterceptor(selector string) googlegrpc.UnaryClientInterceptor {
return func(
ctx context.Context,
method string,
req, reply interface{},
cc *googlegrpc.ClientConn,
invoker googlegrpc.UnaryInvoker,
opts ...googlegrpc.CallOption,
) error {
if selector != "" {
ctx = metadata.AppendToOutgoingContext(ctx, "flagd-selector", selector)
}
return invoker(ctx, method, req, reply, cc, opts...)
}
}

// selectorStreamInterceptor adds the flagd-selector metadata header to streaming gRPC calls
func selectorStreamInterceptor(selector string) googlegrpc.StreamClientInterceptor {
return func(
ctx context.Context,
desc *googlegrpc.StreamDesc,
cc *googlegrpc.ClientConn,
method string,
streamer googlegrpc.Streamer,
opts ...googlegrpc.CallOption,
) (googlegrpc.ClientStream, error) {
if selector != "" {
ctx = metadata.AppendToOutgoingContext(ctx, "flagd-selector", selector)
}
return streamer(ctx, desc, cc, method, opts...)
}
}
15 changes: 14 additions & 1 deletion providers/flagd/pkg/service/in_process/grpc_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
}`
)

// Type aliases for interfaces required by this component - needed for mock generation with gomock

Check failure on line 60 in providers/flagd/pkg/service/in_process/grpc_sync.go

View workflow job for this annotation

GitHub Actions / lint

ST1021: comment on exported type FlagSyncServiceClient should be of the form "FlagSyncServiceClient ..." (with optional leading article) (staticcheck)
type FlagSyncServiceClient interface {
syncv1grpc.FlagSyncServiceClient
}
Expand Down Expand Up @@ -115,13 +115,26 @@

// createConnection creates and configures the gRPC connection
func (g *Sync) createConnection() (*grpc.ClientConn, error) {

var grpcInterceptorDialOptions []grpc.DialOption
if g.Selector != "" {
grpcInterceptorDialOptions = append(grpcInterceptorDialOptions,
grpc.WithChainUnaryInterceptor(selectorUnaryInterceptor(g.Selector)),
grpc.WithChainStreamInterceptor(selectorStreamInterceptor(g.Selector)),
)
}

if len(g.GrpcDialOptionsOverride) > 0 {
g.Logger.Debug("using provided gRPC DialOptions override")
return grpc.NewClient(g.URI, g.GrpcDialOptionsOverride...)
dialOptions := make([]grpc.DialOption, len(g.GrpcDialOptionsOverride))
dialOptions = append(dialOptions, g.GrpcDialOptionsOverride...)
dialOptions = append(dialOptions, grpcInterceptorDialOptions...)
Comment on lines +129 to +131
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There's a bug in how the dialOptions slice is constructed. make([]grpc.DialOption, len(g.GrpcDialOptionsOverride)) creates a slice with len(g.GrpcDialOptionsOverride) nil elements. Appending to this slice results in a slice that starts with nils, which will be passed to grpc.NewClient. This will likely cause connection issues or panics. To fix this, you should initialize an empty slice with appropriate capacity and then append the options.

Suggested change
dialOptions := make([]grpc.DialOption, len(g.GrpcDialOptionsOverride))
dialOptions = append(dialOptions, g.GrpcDialOptionsOverride...)
dialOptions = append(dialOptions, grpcInterceptorDialOptions...)
dialOptions := make([]grpc.DialOption, 0, len(g.GrpcDialOptionsOverride)+len(grpcInterceptorDialOptions))
dialOptions = append(dialOptions, g.GrpcDialOptionsOverride...)
dialOptions = append(dialOptions, grpcInterceptorDialOptions...)

return grpc.NewClient(g.URI, dialOptions...)
}

// Build standard dial options
dialOptions, err := g.buildDialOptions()
dialOptions = append(dialOptions, grpcInterceptorDialOptions...)
if err != nil {
return nil, fmt.Errorf("failed to build dial options: %w", err)
}
Comment on lines 136 to 140
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It's standard practice in Go to check for an error immediately after the function call that returns it. This improves readability and ensures that you don't operate on potentially invalid data. Please move the error check to be right after the call to g.buildDialOptions().

Suggested change
dialOptions, err := g.buildDialOptions()
dialOptions = append(dialOptions, grpcInterceptorDialOptions...)
if err != nil {
return nil, fmt.Errorf("failed to build dial options: %w", err)
}
dialOptions, err := g.buildDialOptions()
if err != nil {
return nil, fmt.Errorf("failed to build dial options: %w", err)
}
dialOptions = append(dialOptions, grpcInterceptorDialOptions...)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
package process

import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"

"buf.build/gen/go/open-feature/flagd/grpc/go/flagd/sync/v1/syncv1grpc"
v1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/sync/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

// TestSelectorHeader verifies that the flagd-selector header is sent correctly in gRPC metadata
func TestSelectorHeader(t *testing.T) {
tests := []struct {
name string
selector string
expectHeader bool
expectedValue string
}{
{
name: "selector header is sent when configured",
selector: "source=database,app=myapp",
expectHeader: true,
expectedValue: "source=database,app=myapp",
},
{
name: "no selector header when selector is empty",
selector: "",
expectHeader: false,
expectedValue: "",
},
{
name: "selector header with complex value",
selector: "source=test,environment=production,region=us-east",
expectHeader: true,
expectedValue: "source=test,environment=production,region=us-east",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// given
port := findFreePort(t)
listen, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() {
// Listener will be closed by GracefulStop, so ignore "use of closed network connection" errors
_ = listen.Close()
}()

headerReceived := make(chan string, 1)
mockServer := &selectorHeaderCapturingServer{
headerReceived: headerReceived,
mockResponse: &v1.SyncFlagsResponse{
FlagConfiguration: flagRsp,
},
}

grpcServer := grpc.NewServer()
syncv1grpc.RegisterFlagSyncServiceServer(grpcServer, mockServer)

serverDone := make(chan struct{})
go func() {
defer close(serverDone)
if err := grpcServer.Serve(listen); err != nil {
t.Logf("Server exited: %v", err)
}
}()
defer func() {
grpcServer.GracefulStop()
<-serverDone
}()

inProcessService := NewInProcessService(Configuration{
Host: "localhost",
Port: port,
Selector: tt.selector,
TLSEnabled: false,
})

// when
err = inProcessService.Init()
if err != nil {
t.Fatalf("Failed to initialize service: %v", err)
}
defer inProcessService.Shutdown()

// Wait for provider to be ready
select {
case <-inProcessService.events:
// Provider ready event
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for provider ready event")
}

// then - verify the flagd-selector header
select {
case receivedSelector := <-headerReceived:
if tt.expectHeader {
if receivedSelector != tt.expectedValue {
t.Errorf("Expected selector header to be %q, but got %q", tt.expectedValue, receivedSelector)
}
} else {
if receivedSelector != "" {
t.Errorf("Expected no selector header, but got %q", receivedSelector)
}
}
case <-time.After(3 * time.Second):
if tt.expectHeader {
t.Fatal("Timeout waiting for flagd-selector header")
}
}
})
}
}

// findFreePort finds an available port for testing
func findFreePort(t *testing.T) int {
t.Helper()
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to find free port: %v", err)
}
port := listener.Addr().(*net.TCPAddr).Port
if err := listener.Close(); err != nil {
t.Fatalf("Failed to close listener: %v", err)
}
return port
}

// selectorHeaderCapturingServer captures the flagd-selector header from incoming requests
type selectorHeaderCapturingServer struct {
syncv1grpc.UnimplementedFlagSyncServiceServer
headerReceived chan string
mockResponse *v1.SyncFlagsResponse
mu sync.Mutex
}

func (s *selectorHeaderCapturingServer) SyncFlags(req *v1.SyncFlagsRequest, stream syncv1grpc.FlagSyncService_SyncFlagsServer) error {
s.mu.Lock()
defer s.mu.Unlock()

// Extract and capture the flagd-selector header
md, ok := metadata.FromIncomingContext(stream.Context())
if ok {
if values := md.Get("flagd-selector"); len(values) > 0 {
select {
case s.headerReceived <- values[0]:
default:
// Channel full, skip
}
} else {
select {
case s.headerReceived <- "":
default:
// Channel full, skip
}
}
} else {
select {
case s.headerReceived <- "":
default:
// Channel full, skip
}
}
Comment on lines +150 to +172
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic for extracting the flagd-selector header is duplicated in FetchAllFlags on lines 188-210. To improve maintainability and reduce code duplication, consider extracting this logic into a helper method on selectorHeaderCapturingServer.

For example, you could create a method like this:

func (s *selectorHeaderCapturingServer) captureHeader(ctx context.Context) {
    md, _ := metadata.FromIncomingContext(ctx)
    headerValue := ""
    if values := md.Get("flagd-selector"); len(values) > 0 {
        headerValue = values[0]
    }
    select {
    case s.headerReceived <- headerValue:
    default:
        // Channel is full, which is acceptable in this test.
    }
}

You could then call s.captureHeader(stream.Context()) here and s.captureHeader(ctx) in FetchAllFlags.


// Send mock response
if err := stream.Send(s.mockResponse); err != nil {
return err
}

// Keep stream open briefly
time.Sleep(500 * time.Millisecond)
return nil
}

func (s *selectorHeaderCapturingServer) FetchAllFlags(ctx context.Context, req *v1.FetchAllFlagsRequest) (*v1.FetchAllFlagsResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()

// Extract and capture the flagd-selector header
md, ok := metadata.FromIncomingContext(ctx)
if ok {
if values := md.Get("flagd-selector"); len(values) > 0 {
select {
case s.headerReceived <- values[0]:
default:
// Channel full, skip
}
} else {
select {
case s.headerReceived <- "":
default:
// Channel full, skip
}
}
} else {
select {
case s.headerReceived <- "":
default:
// Channel full, skip
}
}

return &v1.FetchAllFlagsResponse{
FlagConfiguration: flagRsp,
}, nil
}

func (s *selectorHeaderCapturingServer) GetMetadata(ctx context.Context, req *v1.GetMetadataRequest) (*v1.GetMetadataResponse, error) {
return &v1.GetMetadataResponse{}, nil
}
Loading