"""
app/core/retriever.py — semantic / edge-aware / hybrid retriever (WP-04)

Purpose:
    Candidate discovery via vector search in *_chunks, optional edge expansion
    and combined ranking, returning the top-K hits.
Compatibility:
    Python 3.12+, qdrant-client 1.x
Version:
    0.1.0 (initial version)
Date:
    2025-10-07
Related modules:
    - app/core/graph_adapter.py (expand)
    - app/core/ranking.py (combine_scores)
    - app/core/qdrant_points.py (search_chunks_by_vector)
Usage:
    from app.core.retriever import hybrid_retrieve
Changelog:
    0.1.0 (2025-10-07) – initial version.
"""

from __future__ import annotations
import time
from typing import Dict, List, Optional, Tuple
from qdrant_client import QdrantClient

from app.models.dto import QueryRequest, QueryResponse, QueryHit
from app.core.ranking import combine_scores
from app.core.graph_adapter import expand
from app.core import qdrant_points as qp
from app.config import get_settings

# Hybrid retrieval requests ``top_k * _SEED_OVERSAMPLING`` semantic seeds so the
# combined re-ranking has a broader candidate base than the final result size.
_SEED_OVERSAMPLING = 3


def _elapsed_ms(started: float) -> int:
    """Return the whole milliseconds elapsed since *started*.

    *started* must be a ``time.perf_counter()`` reading. A monotonic clock is
    used on purpose: wall-clock time (``time.time()``) can jump when the system
    clock is adjusted, which would produce wrong or negative latencies.
    """
    return int((time.perf_counter() - started) * 1000)


def _make_hit(
    pid,
    payload: Optional[Dict],
    *,
    semantic: float,
    edge: float,
    centrality: float,
    total: float,
) -> QueryHit:
    """Build one chunk-level ``QueryHit`` from a point id, its payload and scores.

    Args:
        pid: Point id of the chunk; passed through as ``node_id``.
        payload: Payload of the point. ``None`` is tolerated and treated as an
            empty payload, so ``note_id`` and the ``source`` fields become
            ``None`` instead of raising ``AttributeError``.
        semantic: Semantic (vector search) score.
        edge: Edge bonus.
        centrality: Centrality bonus.
        total: Final score used for ranking.

    Returns:
        The populated ``QueryHit`` (``paths`` is always ``None`` for now).
    """
    # NOTE(review): whether search_chunks_by_vector can actually yield a None
    # payload is not visible here; the fallback is purely defensive.
    data = payload or {}
    return QueryHit(
        node_id=pid,
        note_id=data.get("note_id"),
        semantic_score=float(semantic),
        edge_bonus=float(edge),
        centrality_bonus=float(centrality),
        total_score=float(total),
        paths=None,
        source={"path": data.get("path"), "section": data.get("section_title")},
    )


def _require_query_vector(req: QueryRequest) -> List[float]:
    """Return ``req.query_vector`` or fail if it is missing/empty.

    For the quick test without integrated embeddings the caller has to supply
    the query vector. Later, the embedding call (text → 384d) can be hooked in
    here.

    Raises:
        ValueError: If ``req.query_vector`` is ``None`` or empty.
    """
    if not req.query_vector:
        raise ValueError(
            "query_vector fehlt. Für den Quick-Test ohne Embeddings bitte einen 384d-Vektor übergeben."
        )
    return req.query_vector


def semantic_retrieve(req: QueryRequest) -> QueryResponse:
    """Return semantic candidates only, without edge expansion (depth=0).

    Args:
        req: Query request; ``query_vector``, ``top_k`` and ``filters`` are read.

    Returns:
        ``QueryResponse`` with ``used_mode="semantic"``, one hit per search
        result and the measured latency in milliseconds.

    Raises:
        ValueError: If the request carries no query vector.
    """
    started = time.perf_counter()
    s = get_settings()
    client = QdrantClient(url=s.QDRANT_URL, api_key=s.QDRANT_API_KEY)

    q_vec = _require_query_vector(req)
    raw_hits = qp.search_chunks_by_vector(
        client, s.COLLECTION_PREFIX, q_vec, top=req.top_k, filters=req.filters
    )

    # The total score is the raw (un-normalised) semantic score here; that is
    # good enough for a quick check.
    results: List[QueryHit] = [
        _make_hit(
            pid,
            payload,
            semantic=s_score,
            edge=0.0,
            centrality=0.0,
            total=s_score,
        )
        for pid, s_score, payload in raw_hits
    ]

    return QueryResponse(
        results=results, used_mode="semantic", latency_ms=_elapsed_ms(started)
    )


def hybrid_retrieve(req: QueryRequest) -> QueryResponse:
    """Combine semantic search, edge expansion and combined ranking.

    Args:
        req: Query request; ``query_vector``, ``top_k``, ``filters`` and the
            optional ``expand`` mapping (keys ``edge_types`` and ``depth``,
            default depth 1) are read.

    Returns:
        ``QueryResponse`` with ``used_mode="hybrid"``, at most ``top_k`` hits
        in the order produced by ``combine_scores`` and the measured latency
        in milliseconds.

    Raises:
        ValueError: If the request carries no query vector.
    """
    started = time.perf_counter()
    s = get_settings()
    client = QdrantClient(url=s.QDRANT_URL, api_key=s.QDRANT_API_KEY)

    q_vec = _require_query_vector(req)

    # 1) Semantic seeds (oversampled for a broader base).
    raw_hits = qp.search_chunks_by_vector(
        client,
        s.COLLECTION_PREFIX,
        q_vec,
        top=req.top_k * _SEED_OVERSAMPLING,
        filters=req.filters,
    )
    id2payload = {pid: payload for pid, _, payload in raw_hits}
    seeds = [pid for pid, _, _ in raw_hits]

    # 2) Edge expansion. A missing/empty ``expand`` means: all edge types, depth 1.
    expand_opts = req.expand or {}
    edge_types = expand_opts.get("edge_types")
    depth = expand_opts.get("depth", 1)
    sg = expand(client, s.COLLECTION_PREFIX, seeds, depth=depth, edge_types=edge_types)

    edge_bonus_map = {pid: sg.aggregate_edge_bonus(pid) for pid in seeds}
    centrality_map = {pid: sg.centrality_bonus(pid) for pid in seeds}

    # 3) Combined ranking.
    scored = combine_scores(
        raw_hits,
        edge_bonus_map,
        centrality_map,
        w_sem=s.RETRIEVER_W_SEM,
        w_edge=s.RETRIEVER_W_EDGE,
        w_cent=s.RETRIEVER_W_CENT,
    )

    # 4) Response objects (chunk level), cut back to the requested top_k.
    results: List[QueryHit] = [
        _make_hit(
            pid,
            id2payload[pid],
            semantic=s_score,
            edge=e,
            centrality=c,
            total=total,
        )
        for pid, total, e, c, s_score in scored[: req.top_k]
    ]

    return QueryResponse(
        results=results, used_mode="hybrid", latency_ms=_elapsed_ms(started)
    )