All checks were successful
Deploy mindnet to llm-node / deploy (push) Successful in 2s
118 lines
4.3 KiB
Python
118 lines
4.3 KiB
Python
"""
|
||
app/core/retriever.py — Semantischer/Edge-Aware/Hybrid Retriever (WP-04)
|
||
|
||
Zweck:
|
||
Kandidatenfindung via Vektorsuche in *_chunks, optionale Edge-Expansion und
|
||
kombiniertes Ranking zur Rückgabe von Top-K Treffern.
|
||
Kompatibilität:
|
||
Python 3.12+, qdrant-client 1.x
|
||
Version:
|
||
0.1.0 (Erstanlage)
|
||
Stand:
|
||
2025-10-07
|
||
Bezug:
|
||
- app/core/graph_adapter.py (expand)
|
||
- app/core/ranking.py (combine_scores)
|
||
- app/core/qdrant_points.py (search_chunks_by_vector)
|
||
Nutzung:
|
||
from app.core.retriever import hybrid_retrieve
|
||
Änderungsverlauf:
|
||
0.1.0 (2025-10-07) – Erstanlage.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
import time
|
||
from typing import Dict, List, Optional, Tuple
|
||
from qdrant_client import QdrantClient
|
||
|
||
from app.models.dto import QueryRequest, QueryResponse, QueryHit
|
||
from app.core.ranking import combine_scores
|
||
from app.core.graph_adapter import expand
|
||
from app.core import qdrant_points as qp
|
||
from app.config import get_settings
|
||
|
||
|
||
def _require_query_vector(req: QueryRequest) -> List[float]:
|
||
"""
|
||
Für den Schnelltest ohne eingebundene Embeddings muss query_vector gesetzt sein.
|
||
Später kann hier der Embed-Aufruf (Text → 384d) angebunden werden.
|
||
"""
|
||
if not req.query_vector:
|
||
raise ValueError(
|
||
"query_vector fehlt. Für den Quick-Test ohne Embeddings bitte einen 384d-Vektor übergeben."
|
||
)
|
||
return req.query_vector
|
||
|
||
|
||
def semantic_retrieve(req: QueryRequest) -> QueryResponse:
    """Return semantic-only hits for *req* (no edge expansion, i.e. depth=0).

    Runs a vector search through ``qp.search_chunks_by_vector`` and maps each
    returned ``(point_id, score, payload)`` tuple to a ``QueryHit`` with zero
    edge and centrality bonus.

    Args:
        req: Query request; ``query_vector``, ``top_k`` and ``filters`` are read.

    Returns:
        ``QueryResponse`` with ``used_mode="semantic"`` and the elapsed time in
        milliseconds.

    Raises:
        ValueError: if ``req.query_vector`` is missing (raised by
            ``_require_query_vector``).
    """
    # perf_counter is monotonic, so the reported latency cannot be skewed by
    # wall-clock adjustments (NTP sync etc.) the way time.time() can.
    t0 = time.perf_counter()

    # Validate the request before building a client, so a bad request fails fast.
    q_vec = _require_query_vector(req)

    settings = get_settings()
    client = QdrantClient(url=settings.QDRANT_URL, api_key=settings.QDRANT_API_KEY)

    raw_hits = qp.search_chunks_by_vector(
        client, settings.COLLECTION_PREFIX, q_vec, top=req.top_k, filters=req.filters
    )

    results: List[QueryHit] = [
        QueryHit(
            node_id=pid,
            note_id=payload.get("note_id"),
            semantic_score=float(score),
            edge_bonus=0.0,
            centrality_bonus=0.0,
            # Raw (un-normalized) similarity; sufficient for a quick check.
            total_score=float(score),
            paths=None,
            source={"path": payload.get("path"), "section": payload.get("section_title")},
        )
        for pid, score, payload in raw_hits
    ]

    latency_ms = int((time.perf_counter() - t0) * 1000)
    return QueryResponse(results=results, used_mode="semantic", latency_ms=latency_ms)
|
||
|
||
|
||
def hybrid_retrieve(req: QueryRequest) -> QueryResponse:
    """Return hits ranked by semantic score plus edge and centrality bonuses.

    Steps:
      1. Vector search for ``top_k * 3`` seed chunks (a wider base for re-ranking).
      2. Expand the seeds over the graph via ``expand``. ``req.expand`` may
         supply ``edge_types`` and ``depth`` (depth defaults to 1).
      3. Combine the semantic score, edge bonus and centrality bonus with the
         ``RETRIEVER_W_SEM`` / ``RETRIEVER_W_EDGE`` / ``RETRIEVER_W_CENT``
         weights from settings via ``combine_scores``.
      4. Return the first ``top_k`` combined results as chunk-level ``QueryHit``s.

    Args:
        req: Query request; ``query_vector``, ``top_k``, ``filters`` and
            ``expand`` are read.

    Returns:
        ``QueryResponse`` with ``used_mode="hybrid"`` and the elapsed time in
        milliseconds. If the vector search yields no seeds, the result list is
        empty and no graph expansion is performed.

    Raises:
        ValueError: if ``req.query_vector`` is missing (raised by
            ``_require_query_vector``).
    """
    # Monotonic clock: latency must not be affected by wall-clock adjustments.
    t0 = time.perf_counter()

    # Validate the request before building a client, so a bad request fails fast.
    q_vec = _require_query_vector(req)

    settings = get_settings()
    client = QdrantClient(url=settings.QDRANT_URL, api_key=settings.QDRANT_API_KEY)

    # 1) Semantic seeds (top_k * 3 for a broader base).
    raw_hits = qp.search_chunks_by_vector(
        client, settings.COLLECTION_PREFIX, q_vec, top=req.top_k * 3, filters=req.filters
    )
    if not raw_hits:
        # Nothing to expand or rank: skip the graph round-trip instead of
        # calling expand()/combine_scores() with an empty seed list.
        latency_ms = int((time.perf_counter() - t0) * 1000)
        return QueryResponse(results=[], used_mode="hybrid", latency_ms=latency_ms)

    id2payload = {pid: payload for pid, _, payload in raw_hits}
    seeds = [pid for pid, _, _ in raw_hits]

    # 2) Edge expansion. A missing or empty expand config falls back to
    # edge_types=None and depth=1.
    expand_cfg = req.expand or {}
    edge_types = expand_cfg.get("edge_types")
    depth = expand_cfg.get("depth", 1)
    sg = expand(client, settings.COLLECTION_PREFIX, seeds, depth=depth, edge_types=edge_types)

    edge_bonus_map = {pid: sg.aggregate_edge_bonus(pid) for pid in seeds}
    centrality_map = {pid: sg.centrality_bonus(pid) for pid in seeds}

    # 3) Combined ranking.
    scored = combine_scores(
        raw_hits,
        edge_bonus_map,
        centrality_map,
        w_sem=settings.RETRIEVER_W_SEM,
        w_edge=settings.RETRIEVER_W_EDGE,
        w_cent=settings.RETRIEVER_W_CENT,
    )

    # 4) Response objects (chunk level). Tuple layout as unpacked from
    # combine_scores: (pid, total, edge_bonus, centrality_bonus, semantic_score).
    results: List[QueryHit] = []
    for pid, total, edge_bonus, cent_bonus, sem_score in scored[: req.top_k]:
        payload = id2payload[pid]
        results.append(
            QueryHit(
                node_id=pid,
                note_id=payload.get("note_id"),
                semantic_score=float(sem_score),
                edge_bonus=float(edge_bonus),
                centrality_bonus=float(cent_bonus),
                total_score=float(total),
                paths=None,
                source={"path": payload.get("path"), "section": payload.get("section_title")},
            )
        )

    latency_ms = int((time.perf_counter() - t0) * 1000)
    return QueryResponse(results=results, used_mode="hybrid", latency_ms=latency_ms)
|