diff --git a/openviking/retrieve/hierarchical_retriever.py b/openviking/retrieve/hierarchical_retriever.py index 62782bf36..1306870c0 100644 --- a/openviking/retrieve/hierarchical_retriever.py +++ b/openviking/retrieve/hierarchical_retriever.py @@ -273,26 +273,32 @@ def _rerank_scores( if not self._rerank_client or not documents: return fallback_scores + rerank_documents = [ + (index, document) for index, document in enumerate(documents) if document.strip() + ] + if not rerank_documents: + return fallback_scores + try: - scores = self._rerank_client.rerank_batch(query, documents) + scores = self._rerank_client.rerank_batch( + query, [document for _, document in rerank_documents] + ) except Exception as e: logger.warning( "[HierarchicalRetriever] Rerank failed, fallback to vector scores: %s", e ) return fallback_scores - if not scores or len(scores) != len(documents): + if not scores or len(scores) != len(rerank_documents): logger.warning( "[HierarchicalRetriever] Invalid rerank result, fallback to vector scores" ) return fallback_scores - normalized_scores: List[float] = [] - for score, fallback in zip(scores, fallback_scores, strict=True): + normalized_scores = list(fallback_scores) + for score, (index, _) in zip(scores, rerank_documents, strict=True): if isinstance(score, (int, float)): - normalized_scores.append(float(score)) - else: - normalized_scores.append(fallback) + normalized_scores[index] = float(score) return normalized_scores def _merge_starting_points( diff --git a/tests/retrieve/test_hierarchical_retriever_rerank.py b/tests/retrieve/test_hierarchical_retriever_rerank.py index 1c5ab89de..2b1e4c5d8 100644 --- a/tests/retrieve/test_hierarchical_retriever_rerank.py +++ b/tests/retrieve/test_hierarchical_retriever_rerank.py @@ -300,6 +300,29 @@ def test_merge_starting_points_prefers_rerank_scores_in_thinking_mode(monkeypatc assert fake_client.calls == [("hello", ["root A", "root B"])] +def test_rerank_scores_preserves_fallbacks_for_empty_documents(monkeypatch): + fake_client = FakeRerankClient([0.95, 0.05]) + monkeypatch.setattr( + "openviking.retrieve.hierarchical_retriever.RerankClient.from_config", + lambda config: fake_client, + ) + + retriever = HierarchicalRetriever( + storage=DummyStorage(), + embedder=DummyEmbedder(), + rerank_config=_config(), + ) + + scores = retriever._rerank_scores( + "hello", + ["root A", "", " ", "root D"], + [0.2, 0.8, 0.7, 0.4], + ) + + assert scores == [0.95, 0.8, 0.7, 0.05] + assert fake_client.calls == [("hello", ["root A", "root D"])] + + @pytest.mark.asyncio async def test_retrieve_uses_rerank_scores_in_thinking_mode(monkeypatch): fake_client = FakeRerankClient([0.95, 0.05, 0.11, 0.95])