diff --git a/internal/api/handlers/search_handler.go b/internal/api/handlers/search_handler.go index 8499f15..8e1e35b 100644 --- a/internal/api/handlers/search_handler.go +++ b/internal/api/handlers/search_handler.go @@ -20,8 +20,8 @@ import ( type SearchService interface { SemanticSearch(ctx context.Context, query, tenantID string, limit int, minScore float64, cursor string) ( service.SearchResult, error) - SimilarFeedback(ctx context.Context, feedbackRecordID uuid.UUID, tenantID string, limit int, - minScore float64, cursor string) (service.SearchResult, error) + SimilarFeedback(ctx context.Context, feedbackRecordID uuid.UUID, limit int, minScore float64, cursor string) ( + service.SearchResult, error) } // SearchHandler handles HTTP requests for semantic search and similar feedback. @@ -157,27 +157,20 @@ func (h *SearchHandler) SimilarFeedback(w http.ResponseWriter, r *http.Request) return } - tenantID := r.URL.Query().Get("tenant_id") - if tenantID == "" { - response.RespondInvalidParams(w, r, response.InvalidParam{Name: "tenant_id", Reason: "is required"}) - - return - } - limit := parseLimit(r.URL.Query().Get("limit"), defaultSearchLimit, maxSearchLimit) cursor := strings.TrimSpace(r.URL.Query().Get("cursor")) minScore := parseMinScore(r.URL.Query().Get("min_score")) - res, err := h.service.SimilarFeedback(r.Context(), id, tenantID, limit, minScore, cursor) + res, err := h.service.SimilarFeedback(r.Context(), id, limit, minScore, cursor) if err != nil { - if errors.Is(err, service.ErrEmbeddingNotFound) { - response.RespondNotFound(w, r, "Feedback record has no embedding for the current model") + if errors.Is(err, service.ErrMissingTenantID) { + response.RespondNotFound(w, r, "Source feedback record not found or has no tenant") return } - if errors.Is(err, service.ErrMissingTenantID) { - response.RespondInvalidParams(w, r, response.InvalidParam{Name: "tenant_id", Reason: "is required"}) + if errors.Is(err, service.ErrEmbeddingNotFound) { + response.RespondNotFound(w, r, "Feedback record has no embedding for the current model") return } diff --git a/internal/api/handlers/search_handler_test.go b/internal/api/handlers/search_handler_test.go index 7d5064c..31efd40 100644 --- a/internal/api/handlers/search_handler_test.go +++ b/internal/api/handlers/search_handler_test.go @@ -19,8 +19,8 @@ import ( type mockSearchService struct { semanticFunc func(ctx context.Context, query, tenantID string, limit int, minScore float64, cursor string) (service.SearchResult, error) - similarFunc func(ctx context.Context, feedbackRecordID uuid.UUID, tenantID string, limit int, - minScore float64, cursor string) (service.SearchResult, error) + similarFunc func(ctx context.Context, feedbackRecordID uuid.UUID, limit int, minScore float64, + cursor string) (service.SearchResult, error) } func (m *mockSearchService) SemanticSearch( @@ -34,10 +34,10 @@ func (m *mockSearchService) SemanticSearch( } func (m *mockSearchService) SimilarFeedback( - ctx context.Context, feedbackRecordID uuid.UUID, tenantID string, limit int, minScore float64, cursor string, + ctx context.Context, feedbackRecordID uuid.UUID, limit int, minScore float64, cursor string, ) (service.SearchResult, error) { if m.similarFunc != nil { - return m.similarFunc(ctx, feedbackRecordID, tenantID, limit, minScore, cursor) + return m.similarFunc(ctx, feedbackRecordID, limit, minScore, cursor) } return service.SearchResult{}, nil @@ -159,26 +159,44 @@ func TestSearchHandler_SemanticSearch(t *testing.T) { const similarURL = "http://test/v1/feedback-records/018e1234-5678-9abc-def0-123456789abc/similar" func TestSearchHandler_SimilarFeedback(t *testing.T) { - t.Run("missing tenant_id returns 400", func(t *testing.T) { - handler := NewSearchHandler(&mockSearchService{}) + t.Run("success derives tenant from source record", func(t *testing.T) { + id := uuid.MustParse("018e1234-5678-9abc-def0-123456789abc") + similarID := uuid.MustParse("018e1234-5678-9abc-def0-aaaaaaaaaaaa") + mock := &mockSearchService{ + similarFunc: func(_ context.Context, fid uuid.UUID, limit int, minScore float64, + cursor string, + ) (service.SearchResult, error) { + assert.Equal(t, id, fid) + assert.Equal(t, 10, limit) + assert.InDelta(t, 0.7, minScore, 1e-9) + assert.Empty(t, cursor) + + return service.SearchResult{ + Results: []models.FeedbackRecordWithScore{ + {FeedbackRecordID: similarID, Score: 0.88, FieldLabel: "Similar field", ValueText: "Similar feedback text."}, + }, + }, nil + }, + } + handler := NewSearchHandler(mock) req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, similarURL, nil) rec := httptest.NewRecorder() - req.SetPathValue("id", "018e1234-5678-9abc-def0-123456789abc") + req.SetPathValue("id", id.String()) handler.SimilarFeedback(rec, req) - assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Equal(t, http.StatusOK, rec.Code) }) t.Run("embedding not found returns 404", func(t *testing.T) { mock := &mockSearchService{ - similarFunc: func(_ context.Context, _ uuid.UUID, _ string, _ int, _ float64, _ string) (service.SearchResult, error) { + similarFunc: func(_ context.Context, _ uuid.UUID, _ int, _ float64, _ string) (service.SearchResult, error) { return service.SearchResult{}, service.ErrEmbeddingNotFound }, } handler := NewSearchHandler(mock) - req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, similarURL+"?tenant_id=env-1", nil) + req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, similarURL, nil) req.SetPathValue("id", "018e1234-5678-9abc-def0-123456789abc") rec := httptest.NewRecorder() @@ -188,16 +206,35 @@ func TestSearchHandler_SimilarFeedback(t *testing.T) { assert.Equal(t, http.StatusNotFound, rec.Code) }) + t.Run("source record without tenant returns 404", func(t *testing.T) { + id := uuid.MustParse("018e1234-5678-9abc-def0-123456789abc") + mock := &mockSearchService{ + similarFunc: func(_ context.Context, fid uuid.UUID, _ int, _ float64, _ string) (service.SearchResult, error) { + assert.Equal(t, id, fid) + + return service.SearchResult{}, service.ErrMissingTenantID + }, + } + handler := NewSearchHandler(mock) + req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, similarURL, nil) + req.SetPathValue("id", id.String()) + + rec := httptest.NewRecorder() + + handler.SimilarFeedback(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + }) + t.Run("success returns 200 with data and value", func(t *testing.T) { id := uuid.MustParse("018e1234-5678-9abc-def0-123456789abc") similarID := uuid.MustParse("018e1234-5678-9abc-def0-aaaaaaaaaaaa") similarVal := "Similar feedback text." mock := &mockSearchService{ - similarFunc: func(_ context.Context, fid uuid.UUID, tenantID string, limit int, minScore float64, + similarFunc: func(_ context.Context, fid uuid.UUID, limit int, minScore float64, cursor string, ) (service.SearchResult, error) { assert.Equal(t, id, fid) - assert.Equal(t, "env-1", tenantID) assert.Equal(t, 10, limit) assert.InDelta(t, 0.7, minScore, 1e-9) assert.Empty(t, cursor) @@ -210,7 +247,7 @@ func TestSearchHandler_SimilarFeedback(t *testing.T) { }, } handler := NewSearchHandler(mock) - req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, similarURL+"?tenant_id=env-1&limit=10", nil) + req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, similarURL+"?limit=10", nil) req.SetPathValue("id", id.String()) rec := httptest.NewRecorder() diff --git a/internal/repository/embeddings_repository.go b/internal/repository/embeddings_repository.go index 006c252..cf2263c 100644 --- a/internal/repository/embeddings_repository.go +++ b/internal/repository/embeddings_repository.go @@ -122,21 +122,40 @@ var ErrEmbeddingNotFound = errors.New("embedding not found for feedback record a func (r *EmbeddingsRepository) GetEmbeddingByFeedbackRecordAndModel( ctx context.Context, feedbackRecordID uuid.UUID, model string, ) ([]float32, error) { - var vec pgvector.HalfVector + embedding, _, err := r.GetEmbeddingAndTenantByFeedbackRecordAndModel(ctx, feedbackRecordID, model) + if err != nil { + return nil, err + } + + return embedding, nil +} + +// GetEmbeddingAndTenantByFeedbackRecordAndModel returns the stored embedding and its feedback record tenant. +// Used by record-level similar feedback so the source record determines the tenant boundary for the search. +// Returns ErrEmbeddingNotFound when no embedding exists for the current model. +func (r *EmbeddingsRepository) GetEmbeddingAndTenantByFeedbackRecordAndModel( + ctx context.Context, feedbackRecordID uuid.UUID, model string, +) ([]float32, string, error) { + var ( + vec pgvector.HalfVector + tenantID string + ) err := r.db.QueryRow(ctx, - `SELECT embedding FROM embeddings WHERE feedback_record_id = $1 AND model = $2`, + `SELECT e.embedding, fr.tenant_id FROM embeddings e + INNER JOIN feedback_records fr ON fr.id = e.feedback_record_id + WHERE e.feedback_record_id = $1 AND e.model = $2`, feedbackRecordID, model, - ).Scan(&vec) + ).Scan(&vec, &tenantID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return nil, ErrEmbeddingNotFound + return nil, "", ErrEmbeddingNotFound } - return nil, fmt.Errorf("get embedding: %w", err) + return nil, "", fmt.Errorf("get embedding and tenant: %w", err) } - return vec.Slice(), nil + return vec.Slice(), tenantID, nil } // GetEmbeddingByFeedbackRecordAndModelAndTenant returns the stored embedding only when the feedback record diff --git a/internal/service/search_service.go b/internal/service/search_service.go index 056c878..4da1772 100644 --- a/internal/service/search_service.go +++ b/internal/service/search_service.go @@ -28,9 +28,9 @@ var ( // EmbeddingsRepositoryForSearch provides the embedding read operations needed for semantic search. // HasMore is true when there may be additional results (full page returned or full fetch limit consumed). type EmbeddingsRepositoryForSearch interface { - GetEmbeddingByFeedbackRecordAndModelAndTenant( - ctx context.Context, feedbackRecordID uuid.UUID, model, tenantID string, - ) ([]float32, error) + GetEmbeddingAndTenantByFeedbackRecordAndModel( + ctx context.Context, feedbackRecordID uuid.UUID, model string, + ) ([]float32, string, error) NearestFeedbackRecordsByEmbedding( ctx context.Context, model string, queryEmbedding []float32, tenantID string, limit int, excludeID *uuid.UUID, minScore float64, ) ([]models.FeedbackRecordWithScore, bool, error) @@ -149,24 +149,20 @@ func (s *SearchService) SemanticSearch( return out, nil } -// SimilarFeedback returns feedback record IDs and similarity scores for records similar to the given one, scoped to tenantID. -// Requires non-empty tenantID. Returns ErrEmbeddingNotFound when the record has no embedding for the current model. -// Uses cursor-based pagination. +// SimilarFeedback returns feedback record IDs and similarity scores for records similar to the given one. +// The tenant boundary is derived from the source record before running nearest-neighbor search. +// Returns ErrEmbeddingNotFound when the record has no embedding for the current model. Uses cursor-based pagination. func (s *SearchService) SimilarFeedback( - ctx context.Context, feedbackRecordID uuid.UUID, tenantID string, limit int, minScore float64, cursor string, + ctx context.Context, feedbackRecordID uuid.UUID, limit int, minScore float64, cursor string, ) (SearchResult, error) { out := SearchResult{} - if tenantID == "" { - return out, ErrMissingTenantID - } - // Load source embedding only if the record belongs to this tenant (tenant isolation). - embedding, err := s.embeddingsRepo.GetEmbeddingByFeedbackRecordAndModelAndTenant(ctx, feedbackRecordID, s.model, tenantID) + embedding, tenantID, err := s.getSimilarFeedbackSourceEmbedding(ctx, feedbackRecordID) if err != nil { if errors.Is(err, repository.ErrEmbeddingNotFound) { - s.logger.Debug("similar feedback: no embedding or tenant mismatch", + s.logger.Debug("similar feedback: no embedding", "feedbackRecordId", feedbackRecordID.String(), "model", s.model) - //nolint:wrapcheck // return as-is so handler can map to 404 + return out, err } @@ -214,6 +210,23 @@ func (s *SearchService) SimilarFeedback( return out, nil } +func (s *SearchService) getSimilarFeedbackSourceEmbedding( + ctx context.Context, + feedbackRecordID uuid.UUID, +) ([]float32, string, error) { + embedding, resolvedTenantID, err := s.embeddingsRepo.GetEmbeddingAndTenantByFeedbackRecordAndModel( + ctx, feedbackRecordID, s.model) + if err != nil { + return nil, "", fmt.Errorf("get embedding and tenant: %w", err) + } + + if resolvedTenantID == "" { + return nil, "", ErrMissingTenantID + } + + return embedding, resolvedTenantID, nil +} + func (s *SearchService) getQueryEmbeddingCached(ctx context.Context, query string) ([]float32, error) { if vec, ok := s.queryCache.Get(query); ok { if s.cacheMetrics != nil { diff --git a/internal/service/search_service_test.go b/internal/service/search_service_test.go index f270be7..0371867 100644 --- a/internal/service/search_service_test.go +++ b/internal/service/search_service_test.go @@ -35,8 +35,8 @@ func (m *mockEmbeddingClient) CreateEmbeddingForQuery(ctx context.Context, input } type mockEmbeddingsRepoForSearch struct { - getEmbeddingByTenantFunc func(ctx context.Context, feedbackRecordID uuid.UUID, model, tenantID string) ([]float32, error) - nearestFunc func( + getEmbeddingAndTenantFunc func(ctx context.Context, feedbackRecordID uuid.UUID, model string) ([]float32, string, error) + nearestFunc func( ctx context.Context, model string, queryEmbedding []float32, tenantID string, limit int, excludeID *uuid.UUID, minScore float64, ) ([]models.FeedbackRecordWithScore, bool, error) @@ -46,14 +46,14 @@ type mockEmbeddingsRepoForSearch struct { ) ([]models.FeedbackRecordWithScore, bool, error) } -func (m *mockEmbeddingsRepoForSearch) GetEmbeddingByFeedbackRecordAndModelAndTenant( - ctx context.Context, feedbackRecordID uuid.UUID, model, tenantID string, -) ([]float32, error) { - if m.getEmbeddingByTenantFunc != nil { - return m.getEmbeddingByTenantFunc(ctx, feedbackRecordID, model, tenantID) +func (m *mockEmbeddingsRepoForSearch) GetEmbeddingAndTenantByFeedbackRecordAndModel( + ctx context.Context, feedbackRecordID uuid.UUID, model string, +) ([]float32, string, error) { + if m.getEmbeddingAndTenantFunc != nil { + return m.getEmbeddingAndTenantFunc(ctx, feedbackRecordID, model) } - return nil, repository.ErrEmbeddingNotFound + return nil, "", repository.ErrEmbeddingNotFound } func (m *mockEmbeddingsRepoForSearch) NearestFeedbackRecordsByEmbedding( @@ -151,55 +151,25 @@ func TestSearchService_SemanticSearch(t *testing.T) { } func TestSearchService_SimilarFeedback(t *testing.T) { - t.Run("missing tenantID returns ErrMissingTenantID", func(t *testing.T) { - svc := NewSearchService(SearchServiceParams{ - EmbeddingClient: &mockEmbeddingClient{}, - EmbeddingsRepo: &mockEmbeddingsRepoForSearch{}, - Model: "test-model", - }) - res, err := svc.SimilarFeedback(context.Background(), uuid.MustParse("018e1234-5678-9abc-def0-123456789abc"), "", 10, 0, "") - assert.Empty(t, res.Results) - assert.ErrorIs(t, err, ErrMissingTenantID) - }) - - t.Run("embedding not found returns ErrEmbeddingNotFound", func(t *testing.T) { - rid := uuid.MustParse("018e1234-5678-9abc-def0-123456789abc") - svc := NewSearchService(SearchServiceParams{ - EmbeddingClient: &mockEmbeddingClient{}, - EmbeddingsRepo: &mockEmbeddingsRepoForSearch{ - getEmbeddingByTenantFunc: func(_ context.Context, id uuid.UUID, _, tenantID string) ([]float32, error) { - assert.Equal(t, rid, id) - assert.Equal(t, "env-1", tenantID) - - return nil, repository.ErrEmbeddingNotFound - }, - }, - Model: "test-model", - }) - res, err := svc.SimilarFeedback(context.Background(), rid, "env-1", 10, 0, "") - assert.Empty(t, res.Results) - assert.ErrorIs(t, err, repository.ErrEmbeddingNotFound) - }) - - t.Run("success returns results and excludes source record", func(t *testing.T) { + t.Run("derives tenant from source record", func(t *testing.T) { sourceID := uuid.MustParse("018e1234-5678-9abc-def0-123456789abc") similarID := uuid.MustParse("018e1234-5678-9abc-def0-aaaaaaaaaaaa") + sourceTenantID := " env-1 " svc := NewSearchService(SearchServiceParams{ EmbeddingClient: &mockEmbeddingClient{}, EmbeddingsRepo: &mockEmbeddingsRepoForSearch{ - getEmbeddingByTenantFunc: func(_ context.Context, id uuid.UUID, model, tenantID string) ([]float32, error) { + getEmbeddingAndTenantFunc: func(_ context.Context, id uuid.UUID, model string) ([]float32, string, error) { assert.Equal(t, sourceID, id) assert.Equal(t, "test-model", model) - assert.Equal(t, "env-1", tenantID) - return []float32{0.1, 0.2}, nil + return []float32{0.1, 0.2}, sourceTenantID, nil }, nearestFunc: func( _ context.Context, model string, _ []float32, tenantID string, limit int, excludeID *uuid.UUID, minScore float64, ) ([]models.FeedbackRecordWithScore, bool, error) { assert.Equal(t, "test-model", model) - assert.Equal(t, "env-1", tenantID) + assert.Equal(t, sourceTenantID, tenantID) assert.Equal(t, 10, limit) require.NotNil(t, excludeID) assert.Equal(t, sourceID, *excludeID) @@ -212,11 +182,29 @@ func TestSearchService_SimilarFeedback(t *testing.T) { }, Model: "test-model", }) - res, err := svc.SimilarFeedback(context.Background(), sourceID, "env-1", 10, 0.5, "") + res, err := svc.SimilarFeedback(context.Background(), sourceID, 10, 0.5, "") require.NoError(t, err) require.Len(t, res.Results, 1) assert.Equal(t, similarID, res.Results[0].FeedbackRecordID) - assert.InDelta(t, 0.88, res.Results[0].Score, 1e-9) + }) + + t.Run("embedding not found returns ErrEmbeddingNotFound", func(t *testing.T) { + rid := uuid.MustParse("018e1234-5678-9abc-def0-123456789abc") + svc := NewSearchService(SearchServiceParams{ + EmbeddingClient: &mockEmbeddingClient{}, + EmbeddingsRepo: &mockEmbeddingsRepoForSearch{ + getEmbeddingAndTenantFunc: func(_ context.Context, id uuid.UUID, model string) ([]float32, string, error) { + assert.Equal(t, rid, id) + assert.Equal(t, "test-model", model) + + return nil, "", repository.ErrEmbeddingNotFound + }, + }, + Model: "test-model", + }) + res, err := svc.SimilarFeedback(context.Background(), rid, 10, 0, "") + assert.Empty(t, res.Results) + assert.ErrorIs(t, err, repository.ErrEmbeddingNotFound) }) } diff --git a/openapi.yaml b/openapi.yaml index 830f537..4b12fbd 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -750,7 +750,7 @@ paths: **Only available when embeddings are configured** (EMBEDDING_PROVIDER and EMBEDDING_MODEL set). Supported providers: openai, google (Gemini Developer API / Google AI Studio), google-gemini (Gemini Enterprise Agent Platform API). When embeddings are disabled, this endpoint returns 503 Service Unavailable. - The source feedback record must belong to the given tenant_id (enforced). + Hub derives the tenant from the source feedback record and scopes the nearest-neighbor search to that tenant. operationId: similar-feedback-records parameters: - name: id @@ -761,14 +761,6 @@ paths: type: string format: uuid example: "018e1234-5678-9abc-def0-123456789abc" - - name: tenant_id - in: query - description: Tenant ID (required for isolation; must match feedback record tenant_id) - required: true - schema: - type: string - minLength: 1 - example: "org-123" - name: limit in: query description: Number of results to return (default 10, max 100). Consistent with list endpoints. @@ -802,13 +794,13 @@ paths: schema: $ref: '#/components/schemas/SemanticSearchResponse' "400": - description: Bad Request (e.g. missing or empty tenant_id, invalid cursor) + description: Bad Request (e.g. invalid cursor) content: application/problem+json: schema: $ref: '#/components/schemas/ErrorModel' "404": - description: Not Found (feedback record has no embedding for the current model) + description: Not Found (feedback record has no embedding for the current model, or source record has no tenant) content: application/problem+json: schema: