mindnet/app/core/retriever.py
Lars 4b5c33bb90
All checks were successful
Deploy mindnet to llm-node / deploy (push) Successful in 2s
app/core/retriever.py hinzugefügt
2025-10-07 11:30:53 +02:00

118 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
app/core/retriever.py — Semantischer/Edge-Aware/Hybrid Retriever (WP-04)
Zweck:
Kandidatenfindung via Vektorsuche in *_chunks, optionale Edge-Expansion und
kombiniertes Ranking zur Rückgabe von Top-K Treffern.
Kompatibilität:
Python 3.12+, qdrant-client 1.x
Version:
0.1.0 (Erstanlage)
Stand:
2025-10-07
Bezug:
- app/core/graph_adapter.py (expand)
- app/core/ranking.py (combine_scores)
- app/core/qdrant_points.py (search_chunks_by_vector)
Nutzung:
from app.core.retriever import hybrid_retrieve
Änderungsverlauf:
0.1.0 (2025-10-07) Erstanlage.
"""
from __future__ import annotations
import time
from typing import Dict, List, Optional, Tuple
from qdrant_client import QdrantClient
from app.models.dto import QueryRequest, QueryResponse, QueryHit
from app.core.ranking import combine_scores
from app.core.graph_adapter import expand
from app.core import qdrant_points as qp
from app.config import get_settings
def _require_query_vector(req: QueryRequest) -> List[float]:
"""
Für den Schnelltest ohne eingebundene Embeddings muss query_vector gesetzt sein.
Später kann hier der Embed-Aufruf (Text → 384d) angebunden werden.
"""
if not req.query_vector:
raise ValueError(
"query_vector fehlt. Für den Quick-Test ohne Embeddings bitte einen 384d-Vektor übergeben."
)
return req.query_vector
def semantic_retrieve(req: QueryRequest) -> QueryResponse:
    """
    Purely semantic retrieval: vector search only, no edge expansion (depth=0).

    Args:
        req: Query request; ``query_vector``, ``top_k`` and ``filters`` are read.

    Returns:
        QueryResponse with one QueryHit per search hit, in the order returned by
        ``qp.search_chunks_by_vector``, with ``used_mode="semantic"`` and the
        elapsed time in milliseconds.

    Raises:
        ValueError: if ``req.query_vector`` is missing or empty.
    """
    # Monotonic clock: latency must not be skewed by wall-clock adjustments.
    t0 = time.perf_counter()
    s = get_settings()
    # Validate before opening a client so a bad request does not create one.
    q_vec = _require_query_vector(req)
    client = QdrantClient(url=s.QDRANT_URL, api_key=s.QDRANT_API_KEY)
    try:
        raw_hits = qp.search_chunks_by_vector(
            client, s.COLLECTION_PREFIX, q_vec, top=req.top_k, filters=req.filters
        )
    finally:
        # Release the per-call client's connections instead of leaving them to GC.
        client.close()

    results: List[QueryHit] = []
    for pid, sem_score, payload in raw_hits:
        # Tolerate a missing (None) payload instead of failing with AttributeError.
        payload = payload or {}
        score = float(sem_score)
        results.append(QueryHit(
            node_id=pid,
            note_id=payload.get("note_id"),
            semantic_score=score,
            edge_bonus=0.0,
            centrality_bonus=0.0,
            total_score=score,  # un-normalized raw similarity; fine for quick checks
            paths=None,
            source={"path": payload.get("path"), "section": payload.get("section_title")},
        ))
    latency_ms = int((time.perf_counter() - t0) * 1000)
    return QueryResponse(results=results, used_mode="semantic", latency_ms=latency_ms)
def hybrid_retrieve(req: QueryRequest) -> QueryResponse:
    """
    Hybrid retrieval: semantic seeds + edge expansion + combined ranking.

    Steps:
      1. Vector search for ``top_k * 3`` chunk candidates, giving a wider seed base.
      2. Expand the graph around those seeds via ``expand``. This uses
         ``req.expand["depth"]`` (default 1) and ``req.expand["edge_types"]``
         (default None) when ``req.expand`` is set.
      3. Compute an edge bonus and a centrality bonus per seed. Combine them with
         the semantic score via ``combine_scores``, weighted by the
         RETRIEVER_W_SEM / RETRIEVER_W_EDGE / RETRIEVER_W_CENT settings.
      4. Return the first ``top_k`` rows of that combined ranking as QueryHits.
         Those rows are presumably ordered best-first by ``combine_scores``
         (TODO confirm in app/core/ranking.py).

    Args:
        req: Query request; ``query_vector``, ``top_k``, ``filters`` and
            ``expand`` are read.

    Returns:
        QueryResponse with chunk-level hits, ``used_mode="hybrid"`` and the
        elapsed time in milliseconds.

    Raises:
        ValueError: if ``req.query_vector`` is missing or empty.
    """
    # Monotonic clock: latency must not be skewed by wall-clock adjustments.
    t0 = time.perf_counter()
    s = get_settings()
    # Validate before opening a client so a bad request does not create one.
    q_vec = _require_query_vector(req)

    # Same semantics as before: a falsy ``expand`` means defaults.
    expand_cfg = req.expand or {}
    edge_types = expand_cfg.get("edge_types")
    depth = expand_cfg.get("depth", 1)

    client = QdrantClient(url=s.QDRANT_URL, api_key=s.QDRANT_API_KEY)
    try:
        # 1) Semantic seeds: top_k * 3 for a broader candidate base.
        raw_hits = qp.search_chunks_by_vector(
            client, s.COLLECTION_PREFIX, q_vec, top=req.top_k * 3, filters=req.filters
        )
        seeds = [pid for (pid, _, _) in raw_hits]

        # 2) Edge expansion. The bonus lookups stay inside the try because the
        # returned subgraph may still use the client (not visible from here).
        sg = expand(client, s.COLLECTION_PREFIX, seeds, depth=depth, edge_types=edge_types)
        edge_bonus_map = {pid: sg.aggregate_edge_bonus(pid) for pid in seeds}
        centrality_map = {pid: sg.centrality_bonus(pid) for pid in seeds}
    finally:
        # Release the per-call client's connections instead of leaving them to GC.
        client.close()

    id2payload = {pid: payload for (pid, _, payload) in raw_hits}

    # 3) Combined ranking; rows are (pid, total, edge_bonus, centrality_bonus, semantic).
    scored = combine_scores(raw_hits, edge_bonus_map, centrality_map,
                            w_sem=s.RETRIEVER_W_SEM,
                            w_edge=s.RETRIEVER_W_EDGE,
                            w_cent=s.RETRIEVER_W_CENT)

    # 4) Response objects (chunk level).
    results: List[QueryHit] = []
    for pid, total, edge_bonus, cent_bonus, sem_score in scored[: req.top_k]:
        # Tolerant lookup: an id without a (non-None) payload must not raise KeyError,
        # e.g. if the ranking ever includes expanded neighbours that were not search hits.
        payload = id2payload.get(pid) or {}
        results.append(QueryHit(
            node_id=pid,
            note_id=payload.get("note_id"),
            semantic_score=float(sem_score),
            edge_bonus=float(edge_bonus),
            centrality_bonus=float(cent_bonus),
            total_score=float(total),
            paths=None,
            source={"path": payload.get("path"), "section": payload.get("section_title")},
        ))
    latency_ms = int((time.perf_counter() - t0) * 1000)
    return QueryResponse(results=results, used_mode="hybrid", latency_ms=latency_ms)