diff --git a/providers/flagd/README.md b/providers/flagd/README.md index f73412c7c..7ab02db47 100644 --- a/providers/flagd/README.md +++ b/providers/flagd/README.md @@ -189,6 +189,25 @@ The flagd provider currently support following flag evaluation metadata, | `scope` | string | "selector" set for the associated source in flagd | | `providerID` | string | "providerID" set for the associated source in flagd | +## Selector Handling + +When using the in-process resolver with a gRPC sync source, the provider supports filtering flag configurations using a selector. The selector can be configured using the `WithSelector` option or the `FLAGD_SOURCE_SELECTOR` environment variable. + +### Header-based Selector (Recommended) + +The provider now sends the selector as a `flagd-selector` gRPC metadata header when communicating with flagd sync services. This approach is consistent with how selectors are handled across all flagd services (sync, evaluation, and OFREP). + +```go +provider, err := flagd.NewProvider( + flagd.WithInProcessResolver(), + flagd.WithSelector("source=database,app=myapp"), +) +``` + +### Backward Compatibility + +For backward compatibility with older flagd versions, the provider continues to include the selector in the gRPC request fields alongside the header. This dual approach ensures compatibility during the migration period until all flagd instances are updated. + ## Logging If not configured, logging falls back to the standard Go log package at error level only. diff --git a/providers/flagd/pkg/service/in_process/grpc_interceptors.go b/providers/flagd/pkg/service/in_process/grpc_interceptors.go new file mode 100644 index 000000000..3190d911a --- /dev/null +++ b/providers/flagd/pkg/service/in_process/grpc_interceptors.go @@ -0,0 +1,41 @@ +package process + +import ( + "context" + googlegrpc "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +// selectorUnaryInterceptor adds the flagd-selector metadata header to unary gRPC calls +func selectorUnaryInterceptor(selector string) googlegrpc.UnaryClientInterceptor { + return func( + ctx context.Context, + method string, + req, reply interface{}, + cc *googlegrpc.ClientConn, + invoker googlegrpc.UnaryInvoker, + opts ...googlegrpc.CallOption, + ) error { + if selector != "" { + ctx = metadata.AppendToOutgoingContext(ctx, "flagd-selector", selector) + } + return invoker(ctx, method, req, reply, cc, opts...) + } +} + +// selectorStreamInterceptor adds the flagd-selector metadata header to streaming gRPC calls +func selectorStreamInterceptor(selector string) googlegrpc.StreamClientInterceptor { + return func( + ctx context.Context, + desc *googlegrpc.StreamDesc, + cc *googlegrpc.ClientConn, + method string, + streamer googlegrpc.Streamer, + opts ...googlegrpc.CallOption, + ) (googlegrpc.ClientStream, error) { + if selector != "" { + ctx = metadata.AppendToOutgoingContext(ctx, "flagd-selector", selector) + } + return streamer(ctx, desc, cc, method, opts...) + } +} diff --git a/providers/flagd/pkg/service/in_process/grpc_sync.go b/providers/flagd/pkg/service/in_process/grpc_sync.go index 9b6b93caa..337bc49fc 100644 --- a/providers/flagd/pkg/service/in_process/grpc_sync.go +++ b/providers/flagd/pkg/service/in_process/grpc_sync.go @@ -115,13 +115,26 @@ func (g *Sync) Init(ctx context.Context) error { // createConnection creates and configures the gRPC connection func (g *Sync) createConnection() (*grpc.ClientConn, error) { + + var grpcInterceptorDialOptions []grpc.DialOption + if g.Selector != "" { + grpcInterceptorDialOptions = append(grpcInterceptorDialOptions, + grpc.WithChainUnaryInterceptor(selectorUnaryInterceptor(g.Selector)), + grpc.WithChainStreamInterceptor(selectorStreamInterceptor(g.Selector)), + ) + } + if len(g.GrpcDialOptionsOverride) > 0 { g.Logger.Debug("using provided gRPC DialOptions override") - return grpc.NewClient(g.URI, g.GrpcDialOptionsOverride...) + dialOptions := make([]grpc.DialOption, len(g.GrpcDialOptionsOverride)) + dialOptions = append(dialOptions, g.GrpcDialOptionsOverride...) + dialOptions = append(dialOptions, grpcInterceptorDialOptions...) + return grpc.NewClient(g.URI, dialOptions...) } // Build standard dial options dialOptions, err := g.buildDialOptions() + dialOptions = append(dialOptions, grpcInterceptorDialOptions...) if err != nil { return nil, fmt.Errorf("failed to build dial options: %w", err) } diff --git a/providers/flagd/pkg/service/in_process/service_selector_header_test.go b/providers/flagd/pkg/service/in_process/service_selector_header_test.go new file mode 100644 index 000000000..4aa361227 --- /dev/null +++ b/providers/flagd/pkg/service/in_process/service_selector_header_test.go @@ -0,0 +1,219 @@ +package process + +import ( + "context" + "fmt" + "net" + "sync" + "testing" + "time" + + "buf.build/gen/go/open-feature/flagd/grpc/go/flagd/sync/v1/syncv1grpc" + v1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/sync/v1" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +// TestSelectorHeader verifies that the flagd-selector header is sent correctly in gRPC metadata +func TestSelectorHeader(t *testing.T) { + tests := []struct { + name string + selector string + expectHeader bool + expectedValue string + }{ + { + name: "selector header is sent when configured", + selector: "source=database,app=myapp", + expectHeader: true, + expectedValue: "source=database,app=myapp", + }, + { + name: "no selector header when selector is empty", + selector: "", + expectHeader: false, + expectedValue: "", + }, + { + name: "selector header with complex value", + selector: "source=test,environment=production,region=us-east", + expectHeader: true, + expectedValue: "source=test,environment=production,region=us-east", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // given + port := findFreePort(t) + listen, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer func() { + // Listener will be closed by GracefulStop, so ignore "use of closed network connection" errors + _ = listen.Close() + }() + + headerReceived := make(chan string, 1) + mockServer := &selectorHeaderCapturingServer{ + headerReceived: headerReceived, + mockResponse: &v1.SyncFlagsResponse{ + FlagConfiguration: flagRsp, + }, + } + + grpcServer := grpc.NewServer() + syncv1grpc.RegisterFlagSyncServiceServer(grpcServer, mockServer) + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + if err := grpcServer.Serve(listen); err != nil { + t.Logf("Server exited: %v", err) + } + }() + defer func() { + grpcServer.GracefulStop() + <-serverDone + }() + + inProcessService := NewInProcessService(Configuration{ + Host: "localhost", + Port: port, + Selector: tt.selector, + TLSEnabled: false, + }) + + // when + err = inProcessService.Init() + if err != nil { + t.Fatalf("Failed to initialize service: %v", err) + } + defer inProcessService.Shutdown() + + // Wait for provider to be ready + select { + case <-inProcessService.events: + // Provider ready event + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for provider ready event") + } + + // then - verify the flagd-selector header + select { + case receivedSelector := <-headerReceived: + if tt.expectHeader { + if receivedSelector != tt.expectedValue { + t.Errorf("Expected selector header to be %q, but got %q", tt.expectedValue, receivedSelector) + } + } else { + if receivedSelector != "" { + t.Errorf("Expected no selector header, but got %q", receivedSelector) + } + } + case <-time.After(3 * time.Second): + if tt.expectHeader { + t.Fatal("Timeout waiting for flagd-selector header") + } + } + }) + } +} + +// findFreePort finds an available port for testing +func findFreePort(t *testing.T) int { + t.Helper() + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to find free port: %v", err) + } + port := listener.Addr().(*net.TCPAddr).Port + if err := listener.Close(); err != nil { + t.Fatalf("Failed to close listener: %v", err) + } + return port +} + +// selectorHeaderCapturingServer captures the flagd-selector header from incoming requests +type selectorHeaderCapturingServer struct { + syncv1grpc.UnimplementedFlagSyncServiceServer + headerReceived chan string + mockResponse *v1.SyncFlagsResponse + mu sync.Mutex +} + +func (s *selectorHeaderCapturingServer) SyncFlags(req *v1.SyncFlagsRequest, stream syncv1grpc.FlagSyncService_SyncFlagsServer) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Extract and capture the flagd-selector header + md, ok := metadata.FromIncomingContext(stream.Context()) + if ok { + if values := md.Get("flagd-selector"); len(values) > 0 { + select { + case s.headerReceived <- values[0]: + default: + // Channel full, skip + } + } else { + select { + case s.headerReceived <- "": + default: + // Channel full, skip + } + } + } else { + select { + case s.headerReceived <- "": + default: + // Channel full, skip + } + } + + // Send mock response + if err := stream.Send(s.mockResponse); err != nil { + return err + } + + // Keep stream open briefly + time.Sleep(500 * time.Millisecond) + return nil +} + +func (s *selectorHeaderCapturingServer) FetchAllFlags(ctx context.Context, req *v1.FetchAllFlagsRequest) (*v1.FetchAllFlagsResponse, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // Extract and capture the flagd-selector header + md, ok := metadata.FromIncomingContext(ctx) + if ok { + if values := md.Get("flagd-selector"); len(values) > 0 { + select { + case s.headerReceived <- values[0]: + default: + // Channel full, skip + } + } else { + select { + case s.headerReceived <- "": + default: + // Channel full, skip + } + } + } else { + select { + case s.headerReceived <- "": + default: + // Channel full, skip + } + } + + return &v1.FetchAllFlagsResponse{ + FlagConfiguration: flagRsp, + }, nil +} + +func (s *selectorHeaderCapturingServer) GetMetadata(ctx context.Context, req *v1.GetMetadataRequest) (*v1.GetMetadataResponse, error) { + return &v1.GetMetadataResponse{}, nil +}