mindnet/app/core/retriever.py
Lars 06795feed6
All checks were successful
Deploy mindnet to llm-node / deploy (push) Successful in 4s
Dateien nach "app/core" hochladen
2025-12-03 10:39:11 +01:00

132 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import time
from typing import Any, Dict, List, Tuple
from app.config import get_settings
from app.models.dto import QueryRequest, QueryResponse, QueryHit
import app.core.qdrant as qdr
import app.core.qdrant_points as qp
import app.services.embeddings_client as ec
def _get_client_and_prefix() -> Tuple[Any, str]:
    """Build a Qdrant client from the environment config.

    Returns:
        A tuple of (qdrant client, collection name prefix).
    """
    config = qdr.QdrantConfig.from_env()
    return qdr.get_client(config), config.prefix
def _get_query_vector(req: QueryRequest) -> List[float]:
if req.query_vector is not None:
if not isinstance(req.query_vector, list):
raise ValueError("query_vector muss eine Liste von floats sein")
return req.query_vector
if req.query:
return ec.embed_text(req.query)
raise ValueError("Weder query_vector noch query gesetzt mindestens eines ist erforderlich")
def _semantic_hits(
    client: Any,
    prefix: str,
    vector: List[float],
    top_k: int,
    filters: Dict | None,
):
    """Run a raw vector search over chunk points.

    A falsy ``filters`` value (e.g. an empty dict) is normalized to None
    before being handed to the Qdrant helper.
    """
    return qp.search_chunks_by_vector(
        client, prefix, vector, top=top_k, filters=filters or None
    )
def _resolve_top_k(req: QueryRequest) -> int:
if isinstance(req.top_k, int) and req.top_k > 0:
return req.top_k
s = get_settings()
return max(1, int(getattr(s, "RETRIEVER_TOP_K", 10)))
def _compute_total_score(semantic_score: float, payload: Dict[str, Any]) -> Tuple[float, float, float]:
"""Berechnet total_score auf Basis von semantic_score und retriever_weight.
Aktuelle Formel (Step 2):
total_score = semantic_score * max(retriever_weight, 0.0)
retriever_weight stammt aus dem Chunk-Payload und ist bereits aus types.yaml
abgeleitet. Falls nicht gesetzt, wird 1.0 angenommen.
edge_bonus und centrality_bonus bleiben in diesem Schritt 0.0.
"""
raw_weight = payload.get("retriever_weight", 1.0)
try:
weight = float(raw_weight)
except (TypeError, ValueError):
weight = 1.0
if weight < 0.0:
weight = 0.0
edge_bonus = 0.0
cent_bonus = 0.0
total = float(semantic_score) * weight + edge_bonus + cent_bonus
return total, edge_bonus, cent_bonus
def _build_hits_from_semantic(
    hits: List[Tuple[str, float, Dict[str, Any]]],
    top_k: int,
    used_mode: str,
) -> QueryResponse:
    """Turn raw (id, score, payload) hits into a scored, sorted QueryResponse.

    Applies the scoring formula, sorts by total score descending, truncates
    to ``top_k`` (at least one slot), and maps each hit onto a QueryHit DTO.
    ``latency_ms`` covers only this formatting/scoring step.
    """
    started = time.time()

    scored: List[Tuple[str, float, Dict[str, Any], float, float, float]] = []
    for point_id, sem_score, payload in hits:
        total, edge_bonus, cent_bonus = _compute_total_score(sem_score, payload)
        scored.append((point_id, float(sem_score), payload, total, edge_bonus, cent_bonus))

    # Best total_score first.
    scored.sort(key=lambda entry: entry[3], reverse=True)
    top_hits = scored[: max(1, top_k)]

    results: List[QueryHit] = [
        QueryHit(
            node_id=str(point_id),
            note_id=payload.get("note_id"),
            semantic_score=float(sem_score),
            edge_bonus=edge_bonus,
            centrality_bonus=cent_bonus,
            total_score=total,
            paths=None,
            source={
                "path": payload.get("path"),
                "section": payload.get("section_title"),
            },
        )
        for point_id, sem_score, payload, total, edge_bonus, cent_bonus in top_hits
    ]

    elapsed_ms = int((time.time() - started) * 1000)
    return QueryResponse(results=results, used_mode=used_mode, latency_ms=elapsed_ms)
def semantic_retrieve(req: QueryRequest) -> QueryResponse:
    """Pure semantic retrieval: resolve the query vector, search Qdrant, score hits."""
    limit = _resolve_top_k(req)
    query_vec = _get_query_vector(req)
    client, prefix = _get_client_and_prefix()
    raw_hits = _semantic_hits(client, prefix, query_vec, top_k=limit, filters=req.filters)
    return _build_hits_from_semantic(raw_hits, top_k=limit, used_mode="semantic")
def hybrid_retrieve(req: QueryRequest) -> QueryResponse:
    """Hybrid retrieval entry point.

    NOTE(review): currently identical to the semantic path except for the
    reported ``used_mode`` label — graph/hybrid signals are not yet wired in.
    """
    limit = _resolve_top_k(req)
    query_vec = _get_query_vector(req)
    client, prefix = _get_client_and_prefix()
    raw_hits = _semantic_hits(client, prefix, query_vec, top_k=limit, filters=req.filters)
    return _build_hits_from_semantic(raw_hits, top_k=limit, used_mode="hybrid")