Skip to content

Commit 2c58923

Browse files
Copilotaepfli
andcommitted
Improve selector header test with table-driven approach and better coverage
Co-authored-by: aepfli <[email protected]>
1 parent aafc30c commit 2c58923

File tree

1 file changed

+193
-100
lines changed

1 file changed

+193
-100
lines changed

providers/flagd/pkg/service/in_process/service_selector_header_test.go

Lines changed: 193 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"net"
7+
"sync"
78
"testing"
89
"time"
910

@@ -13,114 +14,206 @@ import (
1314
"google.golang.org/grpc/metadata"
1415
)
1516

16-
// Test that the flagd-selector header is sent in gRPC metadata
17-
func TestSelectorHeaderIsSent(t *testing.T) {
18-
// given
19-
host := "localhost"
20-
port := 8091
21-
selector := "source=test,app=selector-test"
22-
headerReceived := make(chan string, 1)
23-
24-
listen, err := net.Listen("tcp", fmt.Sprintf("%s:%d", host, port))
25-
if err != nil {
26-
t.Fatal(err)
27-
}
28-
29-
// Mock server that captures the flagd-selector header
30-
mockServer := &selectorHeaderCapturingServer{
31-
listener: listen,
32-
headerReceived: headerReceived,
33-
mockResponse: &v1.SyncFlagsResponse{
34-
FlagConfiguration: flagRsp,
35-
},
36-
}
37-
38-
inProcessService := NewInProcessService(Configuration{
39-
Host: host,
40-
Port: port,
41-
Selector: selector,
42-
TLSEnabled: false,
43-
})
44-
45-
// when
46-
go func() {
47-
server := grpc.NewServer()
48-
syncv1grpc.RegisterFlagSyncServiceServer(server, mockServer)
49-
if err := server.Serve(mockServer.listener); err != nil {
50-
t.Logf("Server exited with error: %v", err)
51-
}
52-
}()
53-
54-
// Initialize service
55-
err = inProcessService.Init()
56-
if err != nil {
57-
t.Fatal(err)
58-
}
59-
60-
// then - verify that the flagd-selector header was sent
61-
select {
62-
case receivedSelector := <-headerReceived:
63-
if receivedSelector != selector {
64-
t.Fatalf("Expected selector header to be %q, but got %q", selector, receivedSelector)
65-
}
66-
case <-time.After(3 * time.Second):
67-
t.Fatal("Timeout waiting for flagd-selector header to be received")
68-
}
69-
70-
inProcessService.Shutdown()
71-
}
72-
73-
// Mock server that captures the flagd-selector header from incoming requests
17+
// TestSelectorHeader verifies that the flagd-selector header is sent correctly in gRPC metadata
18+
func TestSelectorHeader(t *testing.T) {
19+
tests := []struct {
20+
name string
21+
selector string
22+
expectHeader bool
23+
expectedValue string
24+
}{
25+
{
26+
name: "selector header is sent when configured",
27+
selector: "source=database,app=myapp",
28+
expectHeader: true,
29+
expectedValue: "source=database,app=myapp",
30+
},
31+
{
32+
name: "no selector header when selector is empty",
33+
selector: "",
34+
expectHeader: false,
35+
expectedValue: "",
36+
},
37+
{
38+
name: "selector header with complex value",
39+
selector: "source=test,environment=production,region=us-east",
40+
expectHeader: true,
41+
expectedValue: "source=test,environment=production,region=us-east",
42+
},
43+
}
44+
45+
for _, tt := range tests {
46+
t.Run(tt.name, func(t *testing.T) {
47+
// given
48+
port := findFreePort(t)
49+
listen, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
50+
if err != nil {
51+
t.Fatalf("Failed to create listener: %v", err)
52+
}
53+
defer func() {
54+
// Listener will be closed by GracefulStop, so ignore "use of closed network connection" errors
55+
_ = listen.Close()
56+
}()
57+
58+
headerReceived := make(chan string, 1)
59+
mockServer := &selectorHeaderCapturingServer{
60+
headerReceived: headerReceived,
61+
mockResponse: &v1.SyncFlagsResponse{
62+
FlagConfiguration: flagRsp,
63+
},
64+
}
65+
66+
grpcServer := grpc.NewServer()
67+
syncv1grpc.RegisterFlagSyncServiceServer(grpcServer, mockServer)
68+
69+
serverDone := make(chan struct{})
70+
go func() {
71+
defer close(serverDone)
72+
if err := grpcServer.Serve(listen); err != nil {
73+
t.Logf("Server exited: %v", err)
74+
}
75+
}()
76+
defer func() {
77+
grpcServer.GracefulStop()
78+
<-serverDone
79+
}()
80+
81+
inProcessService := NewInProcessService(Configuration{
82+
Host: "localhost",
83+
Port: port,
84+
Selector: tt.selector,
85+
TLSEnabled: false,
86+
})
87+
88+
// when
89+
err = inProcessService.Init()
90+
if err != nil {
91+
t.Fatalf("Failed to initialize service: %v", err)
92+
}
93+
defer inProcessService.Shutdown()
94+
95+
// Wait for provider to be ready
96+
select {
97+
case <-inProcessService.events:
98+
// Provider ready event
99+
case <-time.After(2 * time.Second):
100+
t.Fatal("Timeout waiting for provider ready event")
101+
}
102+
103+
// then - verify the flagd-selector header
104+
select {
105+
case receivedSelector := <-headerReceived:
106+
if tt.expectHeader {
107+
if receivedSelector != tt.expectedValue {
108+
t.Errorf("Expected selector header to be %q, but got %q", tt.expectedValue, receivedSelector)
109+
}
110+
} else {
111+
if receivedSelector != "" {
112+
t.Errorf("Expected no selector header, but got %q", receivedSelector)
113+
}
114+
}
115+
case <-time.After(3 * time.Second):
116+
if tt.expectHeader {
117+
t.Fatal("Timeout waiting for flagd-selector header")
118+
}
119+
}
120+
})
121+
}
122+
}
123+
124+
// findFreePort finds an available port for testing
125+
func findFreePort(t *testing.T) int {
126+
t.Helper()
127+
listener, err := net.Listen("tcp", "localhost:0")
128+
if err != nil {
129+
t.Fatalf("Failed to find free port: %v", err)
130+
}
131+
port := listener.Addr().(*net.TCPAddr).Port
132+
if err := listener.Close(); err != nil {
133+
t.Fatalf("Failed to close listener: %v", err)
134+
}
135+
return port
136+
}
137+
138+
// selectorHeaderCapturingServer captures the flagd-selector header from incoming requests
74139
type selectorHeaderCapturingServer struct {
75-
listener net.Listener
76-
headerReceived chan string
77-
mockResponse *v1.SyncFlagsResponse
140+
syncv1grpc.UnimplementedFlagSyncServiceServer
141+
headerReceived chan string
142+
mockResponse *v1.SyncFlagsResponse
143+
mu sync.Mutex
78144
}
79145

80146
func (s *selectorHeaderCapturingServer) SyncFlags(req *v1.SyncFlagsRequest, stream syncv1grpc.FlagSyncService_SyncFlagsServer) error {
81-
// Extract metadata from context
82-
md, ok := metadata.FromIncomingContext(stream.Context())
83-
if ok {
84-
// Check for flagd-selector header
85-
if values := md.Get("flagd-selector"); len(values) > 0 {
86-
s.headerReceived <- values[0]
87-
} else {
88-
s.headerReceived <- ""
89-
}
90-
} else {
91-
s.headerReceived <- ""
92-
}
93-
94-
// Send mock response
95-
err := stream.Send(s.mockResponse)
96-
if err != nil {
97-
return err
98-
}
99-
100-
// Keep stream open for a bit
101-
time.Sleep(1 * time.Second)
102-
return nil
147+
s.mu.Lock()
148+
defer s.mu.Unlock()
149+
150+
// Extract and capture the flagd-selector header
151+
md, ok := metadata.FromIncomingContext(stream.Context())
152+
if ok {
153+
if values := md.Get("flagd-selector"); len(values) > 0 {
154+
select {
155+
case s.headerReceived <- values[0]:
156+
default:
157+
// Channel full, skip
158+
}
159+
} else {
160+
select {
161+
case s.headerReceived <- "":
162+
default:
163+
// Channel full, skip
164+
}
165+
}
166+
} else {
167+
select {
168+
case s.headerReceived <- "":
169+
default:
170+
// Channel full, skip
171+
}
172+
}
173+
174+
// Send mock response
175+
if err := stream.Send(s.mockResponse); err != nil {
176+
return err
177+
}
178+
179+
// Keep stream open briefly
180+
time.Sleep(500 * time.Millisecond)
181+
return nil
103182
}
104183

105184
func (s *selectorHeaderCapturingServer) FetchAllFlags(ctx context.Context, req *v1.FetchAllFlagsRequest) (*v1.FetchAllFlagsResponse, error) {
106-
// Extract metadata from context
107-
md, ok := metadata.FromIncomingContext(ctx)
108-
if ok {
109-
// Check for flagd-selector header
110-
if values := md.Get("flagd-selector"); len(values) > 0 {
111-
s.headerReceived <- values[0]
112-
} else {
113-
s.headerReceived <- ""
114-
}
115-
} else {
116-
s.headerReceived <- ""
117-
}
118-
119-
return &v1.FetchAllFlagsResponse{
120-
FlagConfiguration: flagRsp,
121-
}, nil
185+
s.mu.Lock()
186+
defer s.mu.Unlock()
187+
188+
// Extract and capture the flagd-selector header
189+
md, ok := metadata.FromIncomingContext(ctx)
190+
if ok {
191+
if values := md.Get("flagd-selector"); len(values) > 0 {
192+
select {
193+
case s.headerReceived <- values[0]:
194+
default:
195+
// Channel full, skip
196+
}
197+
} else {
198+
select {
199+
case s.headerReceived <- "":
200+
default:
201+
// Channel full, skip
202+
}
203+
}
204+
} else {
205+
select {
206+
case s.headerReceived <- "":
207+
default:
208+
// Channel full, skip
209+
}
210+
}
211+
212+
return &v1.FetchAllFlagsResponse{
213+
FlagConfiguration: flagRsp,
214+
}, nil
122215
}
123216

124217
func (s *selectorHeaderCapturingServer) GetMetadata(ctx context.Context, req *v1.GetMetadataRequest) (*v1.GetMetadataResponse, error) {
125-
return &v1.GetMetadataResponse{}, nil
218+
return &v1.GetMetadataResponse{}, nil
126219
}

0 commit comments

Comments
 (0)