mindnet/app/core/retriever.py
Lars 5e86398fcd
All checks were successful
Deploy mindnet to llm-node / deploy (push) Successful in 3s
Dateien nach "app/core" hochladen
2025-12-03 12:03:37 +01:00

234 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations

import logging
import math
import time
from typing import Any, Dict, List, Tuple, Iterable

from app.config import get_settings
from app.models.dto import QueryRequest, QueryResponse, QueryHit
import app.core.qdrant as qdr
import app.core.qdrant_points as qp
import app.services.embeddings_client as ec
import app.core.graph_adapter as ga
def _get_client_and_prefix() -> Tuple[Any, str]:
    """Build the Qdrant client from the environment-based configuration.

    Returns:
        A ``(client, prefix)`` pair: the client produced by
        ``qdr.get_client`` for ``qdr.QdrantConfig.from_env()`` together with
        that configuration's ``prefix`` attribute.
    """
    config = qdr.QdrantConfig.from_env()
    return qdr.get_client(config), config.prefix
def _get_query_vector(req: QueryRequest) -> List[float]:
"""
Liefert den Query-Vektor aus dem Request.
- Falls req.query_vector gesetzt ist, wird dieser unverändert genutzt.
- Falls req.query (Text) gesetzt ist, wird ec.embed_text(req.query) aufgerufen.
- Andernfalls: ValueError.
"""
if req.query_vector is not None:
if not isinstance(req.query_vector, list):
raise ValueError("query_vector muss eine Liste von floats sein")
return req.query_vector
if req.query:
return ec.embed_text(req.query)
raise ValueError("Weder query_vector noch query gesetzt mindestens eines ist erforderlich")
def _semantic_hits(
    client: Any,
    prefix: str,
    vector: List[float],
    top_k: int,
    filters: Dict | None,
):
    """Run the vector search over chunks.

    Thin wrapper around ``qp.search_chunks_by_vector``. A missing or empty
    ``filters`` mapping is forwarded as ``None``.

    Returns:
        Whatever ``qp.search_chunks_by_vector`` returns for the given
        arguments.
    """
    effective_filters = filters if filters else None
    return qp.search_chunks_by_vector(
        client,
        prefix,
        vector,
        top=top_k,
        filters=effective_filters,
    )
def _resolve_top_k(req: QueryRequest) -> int:
"""Ermittelt ein sinnvolles top_k."""
if isinstance(req.top_k, int) and req.top_k > 0:
return req.top_k
s = get_settings()
return max(1, int(getattr(s, "RETRIEVER_TOP_K", 10)))
def _compute_total_score(
semantic_score: float,
payload: Dict[str, Any],
edge_bonus: float = 0.0,
cent_bonus: float = 0.0,
) -> Tuple[float, float, float]:
"""Berechnet total_score auf Basis von semantic_score, retriever_weight und Graph-Boni.
Aktuelle Formel (Step 3):
total_score = semantic_score * max(retriever_weight, 0.0) + edge_bonus + cent_bonus
retriever_weight stammt aus dem Chunk-Payload und ist bereits aus types.yaml
abgeleitet. Falls nicht gesetzt oder nicht interpretierbar, wird 1.0 angenommen.
"""
raw_weight = payload.get("retriever_weight", 1.0)
try:
weight = float(raw_weight)
except (TypeError, ValueError):
weight = 1.0
if weight < 0.0:
weight = 0.0
total = float(semantic_score) * weight + edge_bonus + cent_bonus
return total, float(edge_bonus), float(cent_bonus)
def _extract_expand_options(req: QueryRequest) -> Tuple[int, List[str] | None]:
"""Extrahiert depth und edge_types aus req.expand, falls vorhanden.
- Falls expand nicht gesetzt ist: depth=0, edge_types=None (keine Expansion).
- Unterstützt sowohl Pydantic-Modelle als auch plain dicts.
"""
expand = getattr(req, "expand", None)
if not expand:
return 0, None
depth = 1
edge_types = None
# Pydantic-Modell oder Objekt mit Attributen
if hasattr(expand, "depth") or hasattr(expand, "edge_types"):
try:
depth_val = getattr(expand, "depth", 1) or 1
depth = int(depth_val)
except Exception:
depth = 1
edge_types = getattr(expand, "edge_types", None)
# plain dict aus FastAPI/Pydantic
elif isinstance(expand, dict):
try:
depth_val = expand.get("depth", 1) or 1
depth = int(depth_val)
except Exception:
depth = 1
edge_types = expand.get("edge_types")
if depth < 0:
depth = 0
if edge_types is not None and not isinstance(edge_types, list):
try:
edge_types = list(edge_types)
except Exception:
edge_types = None
return depth, edge_types
def _graph_bonuses_for(subgraph: Any | None, payload: Dict[str, Any]) -> Tuple[float, float]:
    """Return ``(edge_bonus, centrality_bonus)`` for one hit, 0.0 where unavailable."""
    if subgraph is None:
        return 0.0, 0.0

    # Prefer the chunk id as the node key, fall back to the note id.
    node_key = payload.get("chunk_id") or payload.get("note_id")
    if not node_key:
        return 0.0, 0.0

    # Graph bonuses are best-effort: the subgraph implementation is not
    # visible here, so any failure simply yields a zero bonus for that part.
    try:
        edge_bonus = float(subgraph.edge_bonus(node_key))
    except Exception:
        edge_bonus = 0.0
    try:
        cent_bonus = float(subgraph.centrality_bonus(node_key))
    except Exception:
        cent_bonus = 0.0
    return edge_bonus, cent_bonus


def _build_hits_from_semantic(
    hits: List[Tuple[str, float, Dict[str, Any]]],
    top_k: int,
    used_mode: str,
    subgraph: Any | None = None,
) -> QueryResponse:
    """Score raw hits, rank them and wrap them in a ``QueryResponse``.

    Args:
        hits: ``(point_id, semantic_score, payload)`` triples. A ``None``
            payload is treated as an empty dict.
        top_k: Maximum number of results to keep (at least one is kept).
        used_mode: Mode label copied into the response.
        subgraph: Optional object providing ``edge_bonus(key)`` and
            ``centrality_bonus(key)``; when ``None`` both bonuses are 0.0.

    Returns:
        A ``QueryResponse`` whose results are ordered by descending total
        score. ``latency_ms`` covers only the scoring and response building
        done inside this function.
    """
    # Monotonic clock: wall-clock time can jump and distort the duration.
    started = time.perf_counter()

    scored: List[Tuple[str, float, Dict[str, Any], float, float, float]] = []
    for pid, semantic_score, payload in hits:
        payload = payload or {}
        edge_bonus, cent_bonus = _graph_bonuses_for(subgraph, payload)
        total, edge_bonus, cent_bonus = _compute_total_score(
            semantic_score,
            payload,
            edge_bonus=edge_bonus,
            cent_bonus=cent_bonus,
        )
        scored.append((pid, float(semantic_score), payload, total, edge_bonus, cent_bonus))

    # Highest total score first; the sort is stable, so ties keep input order.
    scored.sort(key=lambda item: item[3], reverse=True)

    results: List[QueryHit] = [
        QueryHit(
            node_id=str(pid),
            note_id=payload.get("note_id"),
            semantic_score=float(semantic_score),
            edge_bonus=edge_bonus,
            centrality_bonus=cent_bonus,
            total_score=total,
            paths=None,
            source={
                "path": payload.get("path"),
                "section": payload.get("section_title"),
            },
        )
        for pid, semantic_score, payload, total, edge_bonus, cent_bonus in scored[: max(1, top_k)]
    ]

    elapsed_ms = int((time.perf_counter() - started) * 1000)
    return QueryResponse(results=results, used_mode=used_mode, latency_ms=elapsed_ms)
def semantic_retrieve(req: QueryRequest) -> QueryResponse:
    """Answer a query using vector similarity only.

    Resolves the result limit and the query vector, runs the chunk search
    with ``req.filters`` and builds the response without any graph subgraph,
    so both graph bonuses stay at zero. The response is labelled with mode
    ``"semantic"``.
    """
    limit = _resolve_top_k(req)
    query_vector = _get_query_vector(req)
    client, prefix = _get_client_and_prefix()
    raw_hits = _semantic_hits(
        client,
        prefix,
        query_vector,
        top_k=limit,
        filters=req.filters,
    )
    return _build_hits_from_semantic(
        raw_hits,
        top_k=limit,
        used_mode="semantic",
        subgraph=None,
    )
def hybrid_retrieve(req: QueryRequest) -> QueryResponse:
    """Answer a query with semantic search plus optional edge expansion.

    The semantic chunk hits form the basis, exactly as in the semantic mode.
    If ``req.expand`` yields a depth greater than 0, a local subgraph is
    built via ``ga.expand`` from the hits' stable ids and passed on so it
    can contribute graph bonuses to the scores.

    Edge expansion is optional: if ``ga.expand`` raises, the failure is
    logged and the response is built without graph bonuses.

    Returns:
        A ``QueryResponse`` labelled with mode ``"hybrid"``.
    """
    top_k = _resolve_top_k(req)
    vector = _get_query_vector(req)
    client, prefix = _get_client_and_prefix()
    hits = _semantic_hits(client, prefix, vector, top_k=top_k, filters=req.filters)

    depth, edge_types = _extract_expand_options(req)
    subgraph = None
    if depth > 0:
        # Seeds: stable ids from the payload (chunk_id preferred, else
        # note_id), de-duplicated while keeping first-seen order.
        seed_ids: List[str] = []
        seen = set()
        for _pid, _score, payload in hits:
            payload = payload or {}
            key = payload.get("chunk_id") or payload.get("note_id")
            if key and key not in seen:
                seen.add(key)
                seed_ids.append(key)

        if seed_ids:
            try:
                subgraph = ga.expand(client, prefix, seed_ids, depth=depth, edge_types=edge_types)
            except Exception:
                # Deliberate best-effort boundary: keep serving semantic
                # results, but leave a trace instead of failing silently.
                logging.getLogger(__name__).warning(
                    "Edge expansion failed; continuing without graph bonuses",
                    exc_info=True,
                )
                subgraph = None

    return _build_hits_from_semantic(hits, top_k=top_k, used_mode="hybrid", subgraph=subgraph)