-
Notifications
You must be signed in to change notification settings - Fork 66
feat(flagd): Add flagd-selector gRPC metadata header to in-process service. #790
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
37de940
2033002
a3f2081
3e5eba6
d8b52b9
b59e9db
4cbc41c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| package process | ||
|
|
||
| import ( | ||
| "context" | ||
| googlegrpc "google.golang.org/grpc" | ||
| "google.golang.org/grpc/metadata" | ||
| ) | ||
|
|
||
| // selectorUnaryInterceptor adds the flagd-selector metadata header to unary gRPC calls | ||
| func selectorUnaryInterceptor(selector string) googlegrpc.UnaryClientInterceptor { | ||
| return func( | ||
| ctx context.Context, | ||
| method string, | ||
| req, reply interface{}, | ||
| cc *googlegrpc.ClientConn, | ||
| invoker googlegrpc.UnaryInvoker, | ||
| opts ...googlegrpc.CallOption, | ||
| ) error { | ||
| if selector != "" { | ||
| ctx = metadata.AppendToOutgoingContext(ctx, "flagd-selector", selector) | ||
| } | ||
| return invoker(ctx, method, req, reply, cc, opts...) | ||
| } | ||
| } | ||
|
|
||
| // selectorStreamInterceptor adds the flagd-selector metadata header to streaming gRPC calls | ||
| func selectorStreamInterceptor(selector string) googlegrpc.StreamClientInterceptor { | ||
| return func( | ||
| ctx context.Context, | ||
| desc *googlegrpc.StreamDesc, | ||
| cc *googlegrpc.ClientConn, | ||
| method string, | ||
| streamer googlegrpc.Streamer, | ||
| opts ...googlegrpc.CallOption, | ||
| ) (googlegrpc.ClientStream, error) { | ||
| if selector != "" { | ||
| ctx = metadata.AppendToOutgoingContext(ctx, "flagd-selector", selector) | ||
| } | ||
| return streamer(ctx, desc, cc, method, opts...) | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -57,7 +57,7 @@ | |||||||||||||||||||||
| }` | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Type aliases for interfaces required by this component - needed for mock generation with gomock | ||||||||||||||||||||||
| type FlagSyncServiceClient interface { | ||||||||||||||||||||||
| syncv1grpc.FlagSyncServiceClient | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
@@ -115,13 +115,26 @@ | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| // createConnection creates and configures the gRPC connection | ||||||||||||||||||||||
| func (g *Sync) createConnection() (*grpc.ClientConn, error) { | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| var grpcInterceptorDialOptions []grpc.DialOption | ||||||||||||||||||||||
| if g.Selector != "" { | ||||||||||||||||||||||
| grpcInterceptorDialOptions = append(grpcInterceptorDialOptions, | ||||||||||||||||||||||
| grpc.WithChainUnaryInterceptor(selectorUnaryInterceptor(g.Selector)), | ||||||||||||||||||||||
| grpc.WithChainStreamInterceptor(selectorStreamInterceptor(g.Selector)), | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if len(g.GrpcDialOptionsOverride) > 0 { | ||||||||||||||||||||||
| g.Logger.Debug("using provided gRPC DialOptions override") | ||||||||||||||||||||||
| return grpc.NewClient(g.URI, g.GrpcDialOptionsOverride...) | ||||||||||||||||||||||
| dialOptions := make([]grpc.DialOption, len(g.GrpcDialOptionsOverride)) | ||||||||||||||||||||||
| dialOptions = append(dialOptions, g.GrpcDialOptionsOverride...) | ||||||||||||||||||||||
| dialOptions = append(dialOptions, grpcInterceptorDialOptions...) | ||||||||||||||||||||||
| return grpc.NewClient(g.URI, dialOptions...) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Build standard dial options | ||||||||||||||||||||||
| dialOptions, err := g.buildDialOptions() | ||||||||||||||||||||||
| dialOptions = append(dialOptions, grpcInterceptorDialOptions...) | ||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||
| return nil, fmt.Errorf("failed to build dial options: %w", err) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
||||||||||||||||||||||
| dialOptions, err := g.buildDialOptions() | |
| dialOptions = append(dialOptions, grpcInterceptorDialOptions...) | |
| if err != nil { | |
| return nil, fmt.Errorf("failed to build dial options: %w", err) | |
| } | |
| dialOptions, err := g.buildDialOptions() | |
| if err != nil { | |
| return nil, fmt.Errorf("failed to build dial options: %w", err) | |
| } | |
| dialOptions = append(dialOptions, grpcInterceptorDialOptions...) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,219 @@ | ||
| package process | ||
aepfli marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| import ( | ||
| "context" | ||
| "fmt" | ||
| "net" | ||
| "sync" | ||
| "testing" | ||
| "time" | ||
|
|
||
| "buf.build/gen/go/open-feature/flagd/grpc/go/flagd/sync/v1/syncv1grpc" | ||
| v1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/sync/v1" | ||
| "google.golang.org/grpc" | ||
| "google.golang.org/grpc/metadata" | ||
| ) | ||
|
|
||
| // TestSelectorHeader verifies that the flagd-selector header is sent correctly in gRPC metadata | ||
| func TestSelectorHeader(t *testing.T) { | ||
| tests := []struct { | ||
| name string | ||
| selector string | ||
| expectHeader bool | ||
| expectedValue string | ||
| }{ | ||
| { | ||
| name: "selector header is sent when configured", | ||
| selector: "source=database,app=myapp", | ||
| expectHeader: true, | ||
| expectedValue: "source=database,app=myapp", | ||
| }, | ||
| { | ||
| name: "no selector header when selector is empty", | ||
| selector: "", | ||
| expectHeader: false, | ||
| expectedValue: "", | ||
| }, | ||
| { | ||
| name: "selector header with complex value", | ||
| selector: "source=test,environment=production,region=us-east", | ||
| expectHeader: true, | ||
| expectedValue: "source=test,environment=production,region=us-east", | ||
| }, | ||
| } | ||
|
|
||
| for _, tt := range tests { | ||
| t.Run(tt.name, func(t *testing.T) { | ||
| // given | ||
| port := findFreePort(t) | ||
| listen, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port)) | ||
| if err != nil { | ||
| t.Fatalf("Failed to create listener: %v", err) | ||
| } | ||
| defer func() { | ||
| // Listener will be closed by GracefulStop, so ignore "use of closed network connection" errors | ||
| _ = listen.Close() | ||
| }() | ||
|
|
||
| headerReceived := make(chan string, 1) | ||
| mockServer := &selectorHeaderCapturingServer{ | ||
| headerReceived: headerReceived, | ||
| mockResponse: &v1.SyncFlagsResponse{ | ||
| FlagConfiguration: flagRsp, | ||
| }, | ||
| } | ||
|
|
||
| grpcServer := grpc.NewServer() | ||
| syncv1grpc.RegisterFlagSyncServiceServer(grpcServer, mockServer) | ||
|
|
||
| serverDone := make(chan struct{}) | ||
| go func() { | ||
| defer close(serverDone) | ||
| if err := grpcServer.Serve(listen); err != nil { | ||
| t.Logf("Server exited: %v", err) | ||
| } | ||
| }() | ||
| defer func() { | ||
| grpcServer.GracefulStop() | ||
| <-serverDone | ||
| }() | ||
|
|
||
| inProcessService := NewInProcessService(Configuration{ | ||
| Host: "localhost", | ||
| Port: port, | ||
| Selector: tt.selector, | ||
| TLSEnabled: false, | ||
| }) | ||
|
|
||
| // when | ||
| err = inProcessService.Init() | ||
| if err != nil { | ||
| t.Fatalf("Failed to initialize service: %v", err) | ||
| } | ||
| defer inProcessService.Shutdown() | ||
|
|
||
| // Wait for provider to be ready | ||
| select { | ||
| case <-inProcessService.events: | ||
| // Provider ready event | ||
| case <-time.After(2 * time.Second): | ||
| t.Fatal("Timeout waiting for provider ready event") | ||
| } | ||
|
|
||
| // then - verify the flagd-selector header | ||
| select { | ||
| case receivedSelector := <-headerReceived: | ||
| if tt.expectHeader { | ||
| if receivedSelector != tt.expectedValue { | ||
| t.Errorf("Expected selector header to be %q, but got %q", tt.expectedValue, receivedSelector) | ||
| } | ||
| } else { | ||
| if receivedSelector != "" { | ||
| t.Errorf("Expected no selector header, but got %q", receivedSelector) | ||
| } | ||
| } | ||
| case <-time.After(3 * time.Second): | ||
| if tt.expectHeader { | ||
| t.Fatal("Timeout waiting for flagd-selector header") | ||
| } | ||
| } | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| // findFreePort finds an available port for testing | ||
| func findFreePort(t *testing.T) int { | ||
| t.Helper() | ||
| listener, err := net.Listen("tcp", "localhost:0") | ||
| if err != nil { | ||
| t.Fatalf("Failed to find free port: %v", err) | ||
| } | ||
| port := listener.Addr().(*net.TCPAddr).Port | ||
| if err := listener.Close(); err != nil { | ||
| t.Fatalf("Failed to close listener: %v", err) | ||
| } | ||
| return port | ||
| } | ||
|
|
||
| // selectorHeaderCapturingServer captures the flagd-selector header from incoming requests | ||
| type selectorHeaderCapturingServer struct { | ||
| syncv1grpc.UnimplementedFlagSyncServiceServer | ||
| headerReceived chan string | ||
| mockResponse *v1.SyncFlagsResponse | ||
| mu sync.Mutex | ||
| } | ||
|
|
||
| func (s *selectorHeaderCapturingServer) SyncFlags(req *v1.SyncFlagsRequest, stream syncv1grpc.FlagSyncService_SyncFlagsServer) error { | ||
| s.mu.Lock() | ||
| defer s.mu.Unlock() | ||
|
|
||
| // Extract and capture the flagd-selector header | ||
| md, ok := metadata.FromIncomingContext(stream.Context()) | ||
| if ok { | ||
| if values := md.Get("flagd-selector"); len(values) > 0 { | ||
| select { | ||
| case s.headerReceived <- values[0]: | ||
| default: | ||
| // Channel full, skip | ||
| } | ||
| } else { | ||
| select { | ||
| case s.headerReceived <- "": | ||
| default: | ||
| // Channel full, skip | ||
| } | ||
| } | ||
| } else { | ||
| select { | ||
| case s.headerReceived <- "": | ||
| default: | ||
| // Channel full, skip | ||
| } | ||
| } | ||
|
Comment on lines
+150
to
+172
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic for extracting the For example, you could create a method like this: func (s *selectorHeaderCapturingServer) captureHeader(ctx context.Context) {
md, _ := metadata.FromIncomingContext(ctx)
headerValue := ""
if values := md.Get("flagd-selector"); len(values) > 0 {
headerValue = values[0]
}
select {
case s.headerReceived <- headerValue:
default:
// Channel is full, which is acceptable in this test.
}
}You could then call |
||
|
|
||
| // Send mock response | ||
| if err := stream.Send(s.mockResponse); err != nil { | ||
| return err | ||
| } | ||
|
|
||
| // Keep stream open briefly | ||
| time.Sleep(500 * time.Millisecond) | ||
| return nil | ||
| } | ||
|
|
||
| func (s *selectorHeaderCapturingServer) FetchAllFlags(ctx context.Context, req *v1.FetchAllFlagsRequest) (*v1.FetchAllFlagsResponse, error) { | ||
| s.mu.Lock() | ||
| defer s.mu.Unlock() | ||
|
|
||
| // Extract and capture the flagd-selector header | ||
| md, ok := metadata.FromIncomingContext(ctx) | ||
| if ok { | ||
| if values := md.Get("flagd-selector"); len(values) > 0 { | ||
| select { | ||
| case s.headerReceived <- values[0]: | ||
| default: | ||
| // Channel full, skip | ||
| } | ||
| } else { | ||
| select { | ||
| case s.headerReceived <- "": | ||
| default: | ||
| // Channel full, skip | ||
| } | ||
| } | ||
| } else { | ||
| select { | ||
| case s.headerReceived <- "": | ||
| default: | ||
| // Channel full, skip | ||
| } | ||
| } | ||
|
|
||
| return &v1.FetchAllFlagsResponse{ | ||
| FlagConfiguration: flagRsp, | ||
| }, nil | ||
| } | ||
|
|
||
| func (s *selectorHeaderCapturingServer) GetMetadata(ctx context.Context, req *v1.GetMetadataRequest) (*v1.GetMetadataResponse, error) { | ||
| return &v1.GetMetadataResponse{}, nil | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a bug in how the
dialOptionsslice is constructed.make([]grpc.DialOption, len(g.GrpcDialOptionsOverride))creates a slice withlen(g.GrpcDialOptionsOverride)nilelements. Appending to this slice results in a slice that starts withnils, which will be passed togrpc.NewClient. This will likely cause connection issues or panics. To fix this, you should initialize an empty slice with appropriate capacity and then append the options.