From e48bdb2401ec17da0240d8a515b17a9cc204c346 Mon Sep 17 00:00:00 2001 From: Lars Date: Tue, 2 Dec 2025 17:32:40 +0100 Subject: [PATCH] =?UTF-8?q?tests/test=5Fretriever=5Fbasic.py=20hinzugef?= =?UTF-8?q?=C3=BCgt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_retriever_basic.py | 64 +++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 tests/test_retriever_basic.py diff --git a/tests/test_retriever_basic.py b/tests/test_retriever_basic.py new file mode 100644 index 0000000..d250459 --- /dev/null +++ b/tests/test_retriever_basic.py @@ -0,0 +1,64 @@ +from app.core import retriever as r +from app.models.dto import QueryRequest +import app.core.qdrant as qdr +import app.core.qdrant_points as qp +import app.services.embeddings_client as ec + + +class _DummyClient: + """Minimaler Platzhalter für QdrantClient in Unit-Tests.""" + pass + + +def _fake_get_client(cfg): + return _DummyClient() + + +def _fake_embed_text(text: str): + # Liefert stabilen Vektor; Größe ist für diesen Test egal, + # da search_chunks_by_vector ebenfalls gefaked wird. + return [0.0] * 16 + + +def _fake_search_chunks_by_vector(client, prefix, vector, top=10, filters=None): + # einfache Hitliste in absteigender Score-Reihenfolge + return [ + ("chunk:1", 0.9, {"note_id": "note:1", "path": "a.md", "section_title": "S1"}), + ("chunk:2", 0.7, {"note_id": "note:2", "path": "b.md", "section_title": "S2"}), + ("chunk:3", 0.4, {"note_id": "note:3", "path": "c.md", "section_title": "S3"}), + ] + + +def test_semantic_retrieve_basic(monkeypatch): + # Qdrant- und Embedding-Aufrufe faken + monkeypatch.setattr(qdr, "get_client", _fake_get_client) + monkeypatch.setattr(ec, "embed_text", _fake_embed_text) + monkeypatch.setattr(qp, "search_chunks_by_vector", _fake_search_chunks_by_vector) + + req = QueryRequest(mode="semantic", query="karate trainingsplan", top_k=2) + resp = r.semantic_retrieve(req) + + assert resp.used_mode == "semantic" + assert len(resp.results) == 2 + + # Scores absteigend sortiert + scores = [h.total_score for h in resp.results] + assert scores[0] >= scores[1] + # Korrekte Zuordnung von note_id und Quelle + assert resp.results[0].note_id == "note:1" + assert resp.results[0].source["path"] == "a.md" + + +def test_hybrid_retrieve_basic(monkeypatch): + # Qdrant- und Embedding-Aufrufe faken + monkeypatch.setattr(qdr, "get_client", _fake_get_client) + monkeypatch.setattr(qp, "search_chunks_by_vector", _fake_search_chunks_by_vector) + + # Im Hybrid-Modus arbeiten wir im Step-1-Stand nur mit query_vector + req = QueryRequest(mode="hybrid", query_vector=[0.0] * 16, top_k=2) + resp = r.hybrid_retrieve(req) + + assert resp.used_mode == "hybrid" + assert len(resp.results) == 2 + scores = [h.total_score for h in resp.results] + assert scores[0] >= scores[1]