"""Ranking test for the hybrid retriever (tests/test_retriever_edges.py).

All external collaborators (Qdrant client, embedding service, vector search
and graph expansion) are replaced by deterministic fakes so that the only
thing differentiating the hits is the graph-derived bonus.
"""

from app.core import retriever as r
from app.models.dto import QueryRequest
import app.core.qdrant as qdr
import app.core.qdrant_points as qp
import app.core.graph_adapter as ga
import app.services.embeddings_client as ec


class _DummyClient:
    """Placeholder returned in place of a real Qdrant client."""

    pass


def _fake_get_client(cfg):
    """Stand-in for ``qdr.get_client``; ignores *cfg* and returns a dummy."""
    return _DummyClient()


def _fake_embed_text(text: str):
    """Stand-in for ``ec.embed_text``; returns a constant zero vector."""
    # The dimension does not matter because search_chunks_by_vector is
    # faked as well and never inspects the vector.
    return [0.0] * 384


def _fake_search_chunks_by_vector(client, prefix, vector, top=10, filters=None):
    """Stand-in for ``qp.search_chunks_by_vector``.

    Returns three ``(point_id, score, payload)`` hits regardless of the
    arguments (``top`` and ``filters`` are ignored).
    """
    # Three hits with an identical semantic score (0.8). The payloads carry
    # no explicit retriever_weight -- presumably the retriever applies the
    # same default to all of them (TODO confirm) -- so only the graph
    # bonuses can differ between the hits.
    return [
        ("chunk:1", 0.8, {"note_id": "note:1", "chunk_id": "c1", "path": "a.md", "section_title": "S1"}),
        ("chunk:2", 0.8, {"note_id": "note:2", "chunk_id": "c2", "path": "b.md", "section_title": "S2"}),
        ("chunk:3", 0.8, {"note_id": "note:3", "chunk_id": "c3", "path": "c.md", "section_title": "S3"}),
    ]


class _DummySubgraph:
    """Minimal subgraph fake exposing per-node edge and centrality bonuses."""

    def __init__(self, edge_scores, cent_scores):
        # Both arguments are plain dicts mapping node id -> bonus value.
        self._edge = edge_scores
        self._cent = cent_scores

    def edge_bonus(self, node_id: str) -> float:
        """Return the edge bonus for *node_id*, or 0.0 if it is unknown."""
        return self._edge.get(node_id, 0.0)

    def centrality_bonus(self, node_id: str) -> float:
        """Return the centrality bonus for *node_id*, or 0.0 if unknown."""
        return self._cent.get(node_id, 0.0)


def _fake_expand(client, prefix, seeds, depth=1, edge_types=None):
    """Stand-in for ``ga.expand``; returns a subgraph with fixed bonuses."""
    # Weighting is c2 > c1 > c3 for both kinds of bonus.
    # NOTE(review): the keys are the chunk_id values of the faked hits; this
    # assumes the retriever looks bonuses up by chunk_id -- verify.
    edge_scores = {"c1": 0.1, "c2": 0.5, "c3": 0.0}
    cent_scores = {"c1": 0.05, "c2": 0.2, "c3": 0.0}
    return _DummySubgraph(edge_scores, cent_scores)


def test_hybrid_retrieve_uses_edge_scores(monkeypatch):
    """Hits with equal semantic scores must be ordered by their graph bonuses."""
    # Fake the Qdrant client, embeddings, chunk search and graph expansion.
    # NOTE(review): patching module attributes only takes effect if the
    # retriever resolves these names through the modules at call time --
    # confirm against app.core.retriever.
    monkeypatch.setattr(qdr, "get_client", _fake_get_client)
    monkeypatch.setattr(ec, "embed_text", _fake_embed_text)
    monkeypatch.setattr(qp, "search_chunks_by_vector", _fake_search_chunks_by_vector)
    monkeypatch.setattr(ga, "expand", _fake_expand)

    # expand.depth > 0 is meant to trigger graph expansion; the concrete
    # edge_types are irrelevant because _fake_expand ignores them.
    req = QueryRequest(
        mode="hybrid",
        query="karate",
        top_k=3,
        expand={"depth": 1, "edge_types": ["references"]},
    )
    resp = r.hybrid_retrieve(req)

    assert len(resp.results) == 3

    # Because every semantic score (and retriever weight) is the same, the
    # sum of edge_bonus + centrality_bonus decides the ranking.
    first = resp.results[0]
    second = resp.results[1]
    third = resp.results[2]

    assert first.note_id == "note:2"  # c2 has the highest graph bonuses
    assert second.note_id == "note:1"
    assert third.note_id == "note:3"

    assert first.total_score > second.total_score > third.total_score