Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 7 additions & 14 deletions internal/api/handlers/search_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import (
type SearchService interface {
SemanticSearch(ctx context.Context, query, tenantID string, limit int, minScore float64, cursor string) (
service.SearchResult, error)
SimilarFeedback(ctx context.Context, feedbackRecordID uuid.UUID, tenantID string, limit int,
minScore float64, cursor string) (service.SearchResult, error)
SimilarFeedback(ctx context.Context, feedbackRecordID uuid.UUID, limit int, minScore float64, cursor string) (
service.SearchResult, error)
}

// SearchHandler handles HTTP requests for semantic search and similar feedback.
Expand Down Expand Up @@ -157,27 +157,20 @@ func (h *SearchHandler) SimilarFeedback(w http.ResponseWriter, r *http.Request)
return
}

tenantID := r.URL.Query().Get("tenant_id")
if tenantID == "" {
response.RespondInvalidParams(w, r, response.InvalidParam{Name: "tenant_id", Reason: "is required"})

return
}

limit := parseLimit(r.URL.Query().Get("limit"), defaultSearchLimit, maxSearchLimit)
cursor := strings.TrimSpace(r.URL.Query().Get("cursor"))
minScore := parseMinScore(r.URL.Query().Get("min_score"))

res, err := h.service.SimilarFeedback(r.Context(), id, tenantID, limit, minScore, cursor)
res, err := h.service.SimilarFeedback(r.Context(), id, limit, minScore, cursor)
if err != nil {
if errors.Is(err, service.ErrEmbeddingNotFound) {
response.RespondNotFound(w, r, "Feedback record has no embedding for the current model")
if errors.Is(err, service.ErrMissingTenantID) {
response.RespondNotFound(w, r, "Source feedback record not found or has no tenant")

return
}

if errors.Is(err, service.ErrMissingTenantID) {
response.RespondInvalidParams(w, r, response.InvalidParam{Name: "tenant_id", Reason: "is required"})
if errors.Is(err, service.ErrEmbeddingNotFound) {
response.RespondNotFound(w, r, "Feedback record has no embedding for the current model")

return
}
Expand Down
63 changes: 50 additions & 13 deletions internal/api/handlers/search_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import (
type mockSearchService struct {
semanticFunc func(ctx context.Context, query, tenantID string, limit int, minScore float64,
cursor string) (service.SearchResult, error)
similarFunc func(ctx context.Context, feedbackRecordID uuid.UUID, tenantID string, limit int,
minScore float64, cursor string) (service.SearchResult, error)
similarFunc func(ctx context.Context, feedbackRecordID uuid.UUID, limit int, minScore float64,
cursor string) (service.SearchResult, error)
}

func (m *mockSearchService) SemanticSearch(
Expand All @@ -34,10 +34,10 @@ func (m *mockSearchService) SemanticSearch(
}

func (m *mockSearchService) SimilarFeedback(
ctx context.Context, feedbackRecordID uuid.UUID, tenantID string, limit int, minScore float64, cursor string,
ctx context.Context, feedbackRecordID uuid.UUID, limit int, minScore float64, cursor string,
) (service.SearchResult, error) {
if m.similarFunc != nil {
return m.similarFunc(ctx, feedbackRecordID, tenantID, limit, minScore, cursor)
return m.similarFunc(ctx, feedbackRecordID, limit, minScore, cursor)
}

return service.SearchResult{}, nil
Expand Down Expand Up @@ -159,26 +159,44 @@ func TestSearchHandler_SemanticSearch(t *testing.T) {
const similarURL = "http://test/v1/feedback-records/018e1234-5678-9abc-def0-123456789abc/similar"

func TestSearchHandler_SimilarFeedback(t *testing.T) {
t.Run("missing tenant_id returns 400", func(t *testing.T) {
handler := NewSearchHandler(&mockSearchService{})
t.Run("success derives tenant from source record", func(t *testing.T) {
id := uuid.MustParse("018e1234-5678-9abc-def0-123456789abc")
similarID := uuid.MustParse("018e1234-5678-9abc-def0-aaaaaaaaaaaa")
mock := &mockSearchService{
similarFunc: func(_ context.Context, fid uuid.UUID, limit int, minScore float64,
cursor string,
) (service.SearchResult, error) {
assert.Equal(t, id, fid)
assert.Equal(t, 10, limit)
assert.InDelta(t, 0.7, minScore, 1e-9)
assert.Empty(t, cursor)

return service.SearchResult{
Results: []models.FeedbackRecordWithScore{
{FeedbackRecordID: similarID, Score: 0.88, FieldLabel: "Similar field", ValueText: "Similar feedback text."},
},
}, nil
},
}
handler := NewSearchHandler(mock)
req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, similarURL, nil)
rec := httptest.NewRecorder()

req.SetPathValue("id", "018e1234-5678-9abc-def0-123456789abc")
req.SetPathValue("id", id.String())

handler.SimilarFeedback(rec, req)

assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.Equal(t, http.StatusOK, rec.Code)
})

t.Run("embedding not found returns 404", func(t *testing.T) {
mock := &mockSearchService{
similarFunc: func(_ context.Context, _ uuid.UUID, _ string, _ int, _ float64, _ string) (service.SearchResult, error) {
similarFunc: func(_ context.Context, _ uuid.UUID, _ int, _ float64, _ string) (service.SearchResult, error) {
return service.SearchResult{}, service.ErrEmbeddingNotFound
},
}
handler := NewSearchHandler(mock)
req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, similarURL+"?tenant_id=env-1", nil)
req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, similarURL, nil)
req.SetPathValue("id", "018e1234-5678-9abc-def0-123456789abc")

rec := httptest.NewRecorder()
Expand All @@ -188,16 +206,35 @@ func TestSearchHandler_SimilarFeedback(t *testing.T) {
assert.Equal(t, http.StatusNotFound, rec.Code)
})

t.Run("source record without tenant returns 404", func(t *testing.T) {
id := uuid.MustParse("018e1234-5678-9abc-def0-123456789abc")
mock := &mockSearchService{
similarFunc: func(_ context.Context, fid uuid.UUID, _ int, _ float64, _ string) (service.SearchResult, error) {
assert.Equal(t, id, fid)

return service.SearchResult{}, service.ErrMissingTenantID
},
}
handler := NewSearchHandler(mock)
req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, similarURL, nil)
req.SetPathValue("id", id.String())

rec := httptest.NewRecorder()

handler.SimilarFeedback(rec, req)

assert.Equal(t, http.StatusNotFound, rec.Code)
})

t.Run("success returns 200 with data and value", func(t *testing.T) {
id := uuid.MustParse("018e1234-5678-9abc-def0-123456789abc")
similarID := uuid.MustParse("018e1234-5678-9abc-def0-aaaaaaaaaaaa")
similarVal := "Similar feedback text."
mock := &mockSearchService{
similarFunc: func(_ context.Context, fid uuid.UUID, tenantID string, limit int, minScore float64,
similarFunc: func(_ context.Context, fid uuid.UUID, limit int, minScore float64,
cursor string,
) (service.SearchResult, error) {
assert.Equal(t, id, fid)
assert.Equal(t, "env-1", tenantID)
assert.Equal(t, 10, limit)
assert.InDelta(t, 0.7, minScore, 1e-9)
assert.Empty(t, cursor)
Expand All @@ -210,7 +247,7 @@ func TestSearchHandler_SimilarFeedback(t *testing.T) {
},
}
handler := NewSearchHandler(mock)
req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, similarURL+"?tenant_id=env-1&limit=10", nil)
req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, similarURL+"?limit=10", nil)
req.SetPathValue("id", id.String())

rec := httptest.NewRecorder()
Expand Down
31 changes: 25 additions & 6 deletions internal/repository/embeddings_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,21 +122,40 @@ var ErrEmbeddingNotFound = errors.New("embedding not found for feedback record a
func (r *EmbeddingsRepository) GetEmbeddingByFeedbackRecordAndModel(
ctx context.Context, feedbackRecordID uuid.UUID, model string,
) ([]float32, error) {
var vec pgvector.HalfVector
embedding, _, err := r.GetEmbeddingAndTenantByFeedbackRecordAndModel(ctx, feedbackRecordID, model)
if err != nil {
return nil, err
}

return embedding, nil
}

// GetEmbeddingAndTenantByFeedbackRecordAndModel returns the stored embedding and its feedback record tenant.
// Used by record-level similar feedback so the source record determines the tenant boundary for the search.
// Returns ErrEmbeddingNotFound when no embedding exists for the current model.
func (r *EmbeddingsRepository) GetEmbeddingAndTenantByFeedbackRecordAndModel(
ctx context.Context, feedbackRecordID uuid.UUID, model string,
) ([]float32, string, error) {
var (
vec pgvector.HalfVector
tenantID string
)

err := r.db.QueryRow(ctx,
`SELECT embedding FROM embeddings WHERE feedback_record_id = $1 AND model = $2`,
`SELECT e.embedding, fr.tenant_id FROM embeddings e
INNER JOIN feedback_records fr ON fr.id = e.feedback_record_id
WHERE e.feedback_record_id = $1 AND e.model = $2`,
feedbackRecordID, model,
).Scan(&vec)
).Scan(&vec, &tenantID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, ErrEmbeddingNotFound
return nil, "", ErrEmbeddingNotFound
}

return nil, fmt.Errorf("get embedding: %w", err)
return nil, "", fmt.Errorf("get embedding and tenant: %w", err)
}

return vec.Slice(), nil
return vec.Slice(), tenantID, nil
}

// GetEmbeddingByFeedbackRecordAndModelAndTenant returns the stored embedding only when the feedback record
Expand Down
41 changes: 27 additions & 14 deletions internal/service/search_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ var (
// EmbeddingsRepositoryForSearch provides the embedding read operations needed for semantic search.
// HasMore is true when there may be additional results (full page returned or full fetch limit consumed).
type EmbeddingsRepositoryForSearch interface {
GetEmbeddingByFeedbackRecordAndModelAndTenant(
ctx context.Context, feedbackRecordID uuid.UUID, model, tenantID string,
) ([]float32, error)
GetEmbeddingAndTenantByFeedbackRecordAndModel(
ctx context.Context, feedbackRecordID uuid.UUID, model string,
) ([]float32, string, error)
NearestFeedbackRecordsByEmbedding(
ctx context.Context, model string, queryEmbedding []float32, tenantID string, limit int, excludeID *uuid.UUID, minScore float64,
) ([]models.FeedbackRecordWithScore, bool, error)
Expand Down Expand Up @@ -149,24 +149,20 @@ func (s *SearchService) SemanticSearch(
return out, nil
}

// SimilarFeedback returns feedback record IDs and similarity scores for records similar to the given one, scoped to tenantID.
// Requires non-empty tenantID. Returns ErrEmbeddingNotFound when the record has no embedding for the current model.
// Uses cursor-based pagination.
// SimilarFeedback returns feedback record IDs and similarity scores for records similar to the given one.
// The tenant boundary is derived from the source record before running nearest-neighbor search.
// Returns ErrEmbeddingNotFound when the record has no embedding for the current model. Uses cursor-based pagination.
func (s *SearchService) SimilarFeedback(
ctx context.Context, feedbackRecordID uuid.UUID, tenantID string, limit int, minScore float64, cursor string,
ctx context.Context, feedbackRecordID uuid.UUID, limit int, minScore float64, cursor string,
) (SearchResult, error) {
out := SearchResult{}
if tenantID == "" {
return out, ErrMissingTenantID
}

// Load source embedding only if the record belongs to this tenant (tenant isolation).
embedding, err := s.embeddingsRepo.GetEmbeddingByFeedbackRecordAndModelAndTenant(ctx, feedbackRecordID, s.model, tenantID)
embedding, tenantID, err := s.getSimilarFeedbackSourceEmbedding(ctx, feedbackRecordID)
if err != nil {
if errors.Is(err, repository.ErrEmbeddingNotFound) {
s.logger.Debug("similar feedback: no embedding or tenant mismatch",
s.logger.Debug("similar feedback: no embedding",
"feedbackRecordId", feedbackRecordID.String(), "model", s.model)
//nolint:wrapcheck // return as-is so handler can map to 404

return out, err
}

Expand Down Expand Up @@ -214,6 +210,23 @@ func (s *SearchService) SimilarFeedback(
return out, nil
}

// getSimilarFeedbackSourceEmbedding looks up the embedding stored for feedbackRecordID under the
// service's configured model, along with the tenant ID the repository returns for that record.
//
// Any repository error is wrapped with %w, so callers can still match sentinel errors with
// errors.Is. When the lookup succeeds but the returned tenant ID is empty, ErrMissingTenantID is
// returned instead of the embedding.
func (s *SearchService) getSimilarFeedbackSourceEmbedding(
	ctx context.Context,
	feedbackRecordID uuid.UUID,
) ([]float32, string, error) {
	vec, tenant, lookupErr := s.embeddingsRepo.GetEmbeddingAndTenantByFeedbackRecordAndModel(
		ctx, feedbackRecordID, s.model)

	switch {
	case lookupErr != nil:
		return nil, "", fmt.Errorf("get embedding and tenant: %w", lookupErr)
	case tenant == "":
		// A source record without a tenant cannot define a tenant boundary for the search.
		return nil, "", ErrMissingTenantID
	default:
		return vec, tenant, nil
	}
}

func (s *SearchService) getQueryEmbeddingCached(ctx context.Context, query string) ([]float32, error) {
if vec, ok := s.queryCache.Get(query); ok {
if s.cacheMetrics != nil {
Expand Down
Loading
Loading