|
6 | 6 | import com.fasterxml.jackson.databind.node.ArrayNode; |
7 | 7 | import com.microsoft.semantickernel.data.vectorsearch.VectorOperations; |
8 | 8 | import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResult; |
9 | | -import com.microsoft.semantickernel.data.vectorsearch.queries.VectorSearchQuery; |
10 | | -import com.microsoft.semantickernel.data.vectorsearch.queries.VectorizedSearchQuery; |
11 | 9 | import com.microsoft.semantickernel.data.vectorstorage.VectorStoreRecordMapper; |
12 | 10 | import com.microsoft.semantickernel.data.vectorstorage.definition.DistanceFunction; |
13 | 11 | import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordDefinition; |
@@ -471,101 +469,94 @@ protected <Record> List<Record> getRecordsWithFilter(String collectionName, |
471 | 469 | * |
472 | 470 | * @param <Record> the record type |
473 | 471 | * @param collectionName the collection name |
474 | | - * @param query the vectorized search query, containing the vector and search options |
| 472 | + * @param vector the vector to search with |
| 473 | + * @param options the search options; when {@code null}, default options for the first vector field are used |
475 | 474 | * @param recordDefinition the record definition |
476 | 475 | * @param mapper the mapper, responsible for mapping the result set to the record type. |
477 | 476 | * @return the search results |
478 | 477 | */ |
479 | 478 | @Override |
480 | 479 | public <Record> List<VectorSearchResult<Record>> search(String collectionName, |
481 | | - VectorSearchQuery query, VectorStoreRecordDefinition recordDefinition, |
| 480 | + List<Float> vector, VectorSearchOptions options, |
| 481 | + VectorStoreRecordDefinition recordDefinition, |
482 | 482 | VectorStoreRecordMapper<Record, ResultSet> mapper) { |
483 | 483 | if (recordDefinition.getVectorFields().isEmpty()) { |
484 | 484 | throw new SKException("No vector fields defined. Cannot perform vector search"); |
485 | 485 | } |
486 | 486 |
|
487 | | - if (query instanceof VectorizedSearchQuery) { |
488 | | - VectorizedSearchQuery vectorizedSearchQuery = (VectorizedSearchQuery) query; |
489 | | - VectorSearchOptions options = query.getSearchOptions(); |
| 487 | + VectorStoreRecordVectorField firstVectorField = recordDefinition.getVectorFields() |
| 488 | + .get(0); |
| 489 | + if (options == null) { |
| 490 | + options = VectorSearchOptions.createDefault(firstVectorField.getName()); |
| 491 | + } |
490 | 492 |
|
491 | | - VectorStoreRecordVectorField firstVectorField = recordDefinition.getVectorFields() |
492 | | - .get(0); |
493 | | - if (options == null) { |
494 | | - options = VectorSearchOptions.createDefault(firstVectorField.getName()); |
| 493 | + VectorStoreRecordVectorField vectorField = options.getVectorFieldName() == null |
| 494 | + ? firstVectorField |
| 495 | + : (VectorStoreRecordVectorField) recordDefinition |
| 496 | + .getField(options.getVectorFieldName()); |
| 497 | + |
| 498 | + String filter = SQLVectorStoreRecordCollectionSearchMapping |
| 499 | + .buildFilter(options.getVectorSearchFilter(), recordDefinition); |
| 500 | + List<Object> parameters = SQLVectorStoreRecordCollectionSearchMapping |
| 501 | + .getFilterParameters(options.getVectorSearchFilter()); |
| 502 | + |
| 503 | + List<Record> records = getRecordsWithFilter(collectionName, recordDefinition, mapper, |
| 504 | + new GetRecordOptions(true), filter, parameters); |
| 505 | + List<VectorSearchResult<Record>> results = new ArrayList<>(); |
| 506 | + |
| 507 | + DistanceFunction distanceFunction = vectorField.getDistanceFunction() == null |
| 508 | + ? DistanceFunction.EUCLIDEAN_DISTANCE |
| 509 | + : vectorField.getDistanceFunction(); |
| 510 | + |
| 511 | + for (Record record : records) { |
| 512 | + List<Float> recordVector; |
| 513 | + try { |
| 514 | + String json = new ObjectMapper().writeValueAsString(record); |
| 515 | + ArrayNode arrayNode = (ArrayNode) new ObjectMapper().readTree(json) |
| 516 | + .get(vectorField.getEffectiveStorageName()); |
| 517 | + |
| 518 | + recordVector = Stream.iterate(0, i -> i + 1) |
| 519 | + .limit(arrayNode.size()) |
| 520 | + .map(i -> arrayNode.get(i).floatValue()) |
| 521 | + .collect(Collectors.toList()); |
| 522 | + } catch (JsonProcessingException e) { |
| 523 | + throw new RuntimeException(e); |
495 | 524 | } |
496 | 525 |
|
497 | | - VectorStoreRecordVectorField vectorField = options.getVectorFieldName() == null |
498 | | - ? firstVectorField |
499 | | - : (VectorStoreRecordVectorField) recordDefinition |
500 | | - .getField(options.getVectorFieldName()); |
501 | | - |
502 | | - String filter = SQLVectorStoreRecordCollectionSearchMapping |
503 | | - .buildFilter(options.getVectorSearchFilter(), recordDefinition); |
504 | | - List<Object> parameters = SQLVectorStoreRecordCollectionSearchMapping |
505 | | - .getFilterParameters(options.getVectorSearchFilter()); |
506 | | - |
507 | | - List<Record> records = getRecordsWithFilter(collectionName, recordDefinition, mapper, |
508 | | - new GetRecordOptions(true), filter, parameters); |
509 | | - List<VectorSearchResult<Record>> results = new ArrayList<>(); |
510 | | - |
511 | | - DistanceFunction distanceFunction = vectorField.getDistanceFunction() == null |
512 | | - ? DistanceFunction.EUCLIDEAN_DISTANCE |
513 | | - : vectorField.getDistanceFunction(); |
514 | | - |
515 | | - for (Record record : records) { |
516 | | - List<Float> vector; |
517 | | - try { |
518 | | - String json = new ObjectMapper().writeValueAsString(record); |
519 | | - ArrayNode arrayNode = (ArrayNode) new ObjectMapper().readTree(json) |
520 | | - .get(vectorField.getEffectiveStorageName()); |
521 | | - |
522 | | - vector = Stream.iterate(0, i -> i + 1) |
523 | | - .limit(arrayNode.size()) |
524 | | - .map(i -> arrayNode.get(i).floatValue()) |
525 | | - .collect(Collectors.toList()); |
526 | | - } catch (JsonProcessingException e) { |
527 | | - throw new RuntimeException(e); |
528 | | - } |
529 | | - |
530 | | - double score; |
531 | | - switch (distanceFunction) { |
532 | | - case COSINE_SIMILARITY: |
533 | | - score = VectorOperations.cosineSimilarity(vectorizedSearchQuery.getVector(), |
534 | | - vector); |
535 | | - break; |
536 | | - case COSINE_DISTANCE: |
537 | | - score = VectorOperations.cosineDistance(vectorizedSearchQuery.getVector(), |
538 | | - vector); |
539 | | - break; |
540 | | - case EUCLIDEAN_DISTANCE: |
541 | | - score = VectorOperations |
542 | | - .euclideanDistance(vectorizedSearchQuery.getVector(), vector); |
543 | | - break; |
544 | | - case DOT_PRODUCT: |
545 | | - score = VectorOperations.dot(vectorizedSearchQuery.getVector(), vector); |
546 | | - break; |
547 | | - default: |
548 | | - throw new SKException("Unsupported distance function"); |
549 | | - } |
550 | | - |
551 | | - results.add(new VectorSearchResult<>(record, score)); |
| 526 | + double score; |
| 527 | + switch (distanceFunction) { |
| 528 | + case COSINE_SIMILARITY: |
| 529 | + score = VectorOperations.cosineSimilarity(vector, recordVector); |
| 530 | + break; |
| 531 | + case COSINE_DISTANCE: |
| 532 | + score = VectorOperations.cosineDistance(vector, recordVector); |
| 533 | + break; |
| 534 | + case EUCLIDEAN_DISTANCE: |
| 535 | + score = VectorOperations |
| 536 | + .euclideanDistance(vector, recordVector); |
| 537 | + break; |
| 538 | + case DOT_PRODUCT: |
| 539 | + score = VectorOperations.dot(vector, recordVector); |
| 540 | + break; |
| 541 | + default: |
| 542 | + throw new SKException("Unsupported distance function"); |
552 | 543 | } |
553 | 544 |
|
554 | | - Comparator<VectorSearchResult<Record>> comparator = Comparator |
555 | | - .comparingDouble(VectorSearchResult::getScore); |
556 | | - // Higher scores are better |
557 | | - if (distanceFunction == DistanceFunction.COSINE_SIMILARITY |
558 | | - || distanceFunction == DistanceFunction.DOT_PRODUCT) { |
559 | | - comparator = comparator.reversed(); |
560 | | - } |
561 | | - return results.stream() |
562 | | - .sorted(comparator) |
563 | | - .skip(options.getOffset()) |
564 | | - .limit(options.getLimit()) |
565 | | - .collect(Collectors.toList()); |
| 545 | + results.add(new VectorSearchResult<>(record, score)); |
566 | 546 | } |
567 | 547 |
|
568 | | - throw new SKException("Unsupported query type"); |
| 548 | + Comparator<VectorSearchResult<Record>> comparator = Comparator |
| 549 | + .comparingDouble(VectorSearchResult::getScore); |
| 550 | + // Higher scores are better for cosine similarity and dot product, so those are sorted descending |
| 551 | + if (distanceFunction == DistanceFunction.COSINE_SIMILARITY |
| 552 | + || distanceFunction == DistanceFunction.DOT_PRODUCT) { |
| 553 | + comparator = comparator.reversed(); |
| 554 | + } |
| 555 | + return results.stream() |
| 556 | + .sorted(comparator) |
| 557 | + .skip(options.getOffset()) |
| 558 | + .limit(options.getLimit()) |
| 559 | + .collect(Collectors.toList()); |
569 | 560 | } |
570 | 561 |
|
571 | 562 | /** |
|
0 commit comments