"""Semantic and hybrid (graph-boosted) retrieval on top of Qdrant chunk search."""
from __future__ import annotations

import logging
import math
import time
from typing import Any, Dict, List, Tuple

from app.config import get_settings
from app.models.dto import QueryRequest, QueryResponse, QueryHit
import app.core.qdrant as qdr
import app.core.qdrant_points as qp
import app.services.embeddings_client as ec
import app.core.graph_adapter as ga

logger = logging.getLogger(__name__)


def _get_client_and_prefix() -> Tuple[Any, str]:
    """Return ``(client, prefix)`` built from ``QdrantConfig.from_env()``."""
    cfg = qdr.QdrantConfig.from_env()
    client = qdr.get_client(cfg)
    return client, cfg.prefix


def _get_query_vector(req: QueryRequest) -> List[float]:
    """Return the query vector for the request.

    - If ``req.query_vector`` is set it is used unchanged (it must be a list).
    - Otherwise, if ``req.query`` (text) is set, ``ec.embed_text(req.query)``
      is called.

    Raises:
        ValueError: if ``query_vector`` is not a list, or if neither
            ``query_vector`` nor ``query`` is provided.
    """
    if req.query_vector is not None:
        if not isinstance(req.query_vector, list):
            raise ValueError("query_vector muss eine Liste von floats sein")
        return req.query_vector

    if req.query:
        return ec.embed_text(req.query)

    raise ValueError(
        "Weder query_vector noch query gesetzt – mindestens eines ist erforderlich"
    )


def _semantic_hits(
    client: Any,
    prefix: str,
    vector: List[float],
    top_k: int,
    filters: Dict | None,
):
    """Thin wrapper around ``qp.search_chunks_by_vector``.

    An empty/falsy ``filters`` value is forwarded as ``None``.
    """
    flt = filters or None
    return qp.search_chunks_by_vector(client, prefix, vector, top=top_k, filters=flt)


def _resolve_top_k(req: QueryRequest) -> int:
    """Return the effective ``top_k``.

    A positive integer ``req.top_k`` wins; otherwise the ``RETRIEVER_TOP_K``
    setting is used (default 10), clamped to at least 1.
    """
    if isinstance(req.top_k, int) and req.top_k > 0:
        return req.top_k

    settings = get_settings()
    return max(1, int(getattr(settings, "RETRIEVER_TOP_K", 10)))


def _compute_total_score(
    semantic_score: float,
    payload: Dict[str, Any],
    edge_bonus: float = 0.0,
    cent_bonus: float = 0.0,
) -> Tuple[float, float, float]:
    """Combine the semantic score, the retriever weight and the graph bonuses.

    Formula::

        total_score = semantic_score * max(retriever_weight, 0.0)
                      + edge_bonus + cent_bonus

    ``retriever_weight`` is read from the chunk payload. If it is missing,
    cannot be converted to a float, or is not finite (NaN/inf), 1.0 is used
    so that a broken payload value cannot poison the ranking.

    Returns:
        ``(total_score, edge_bonus, cent_bonus)`` with the bonuses as floats.
    """
    raw_weight = payload.get("retriever_weight", 1.0)
    try:
        weight = float(raw_weight)
    except (TypeError, ValueError, OverflowError):
        weight = 1.0

    if not math.isfinite(weight):
        # NaN would slip through the "< 0" clamp and make the sort undefined.
        weight = 1.0
    elif weight < 0.0:
        weight = 0.0

    total = float(semantic_score) * weight + edge_bonus + cent_bonus
    return total, float(edge_bonus), float(cent_bonus)


def _coerce_depth(raw_depth: Any) -> int:
    """Convert a raw ``depth`` value to a non-negative int.

    ``None`` (not provided) and unparsable values default to 1; an explicit
    0 is preserved so callers can switch expansion off; negatives become 0.
    """
    if raw_depth is None:
        return 1
    try:
        depth = int(raw_depth)
    except (TypeError, ValueError, OverflowError):
        return 1
    return max(depth, 0)


def _extract_expand_options(req: QueryRequest) -> Tuple[int, List[str] | None]:
    """Extract ``depth`` and ``edge_types`` from ``req.expand``.

    - If ``expand`` is missing or falsy: ``(0, None)`` (no expansion).
    - Supports objects with ``depth`` / ``edge_types`` attributes
      (e.g. Pydantic models) as well as plain dicts.
    - A missing depth defaults to 1; an explicit ``depth=0`` disables
      expansion; negative depths are clamped to 0.
    - ``edge_types`` is normalised to a list: a single string becomes a
      one-element list (instead of being split into characters), other
      iterables are materialised, non-iterables become ``None``.
    """
    expand = getattr(req, "expand", None)
    if not expand:
        return 0, None

    raw_depth: Any = None
    edge_types: Any = None

    if hasattr(expand, "depth") or hasattr(expand, "edge_types"):
        # Pydantic model or any object exposing the options as attributes.
        raw_depth = getattr(expand, "depth", None)
        edge_types = getattr(expand, "edge_types", None)
    elif isinstance(expand, dict):
        raw_depth = expand.get("depth")
        edge_types = expand.get("edge_types")

    depth = _coerce_depth(raw_depth)

    if isinstance(edge_types, str):
        edge_types = [edge_types] if edge_types else []
    elif edge_types is not None and not isinstance(edge_types, list):
        try:
            edge_types = list(edge_types)
        except TypeError:
            edge_types = None

    return depth, edge_types


def _graph_bonuses(subgraph: Any | None, payload: Dict[str, Any]) -> Tuple[float, float]:
    """Return ``(edge_bonus, centrality_bonus)`` for one hit.

    Both are 0.0 when there is no subgraph, no usable node key, or the
    subgraph lookup fails (graph boosting is best-effort).
    """
    if subgraph is None:
        return 0.0, 0.0

    if isinstance(subgraph, ga.Subgraph):
        # For ga.Subgraph this module keys nodes by note_id.
        node_key = payload.get("note_id")
    else:
        # Other subgraph objects (e.g. test fakes): chunk_id first, then note_id.
        node_key = payload.get("chunk_id") or payload.get("note_id")

    if not node_key:
        return 0.0, 0.0

    try:
        edge_bonus = float(subgraph.edge_bonus(node_key))
    except Exception:
        logger.debug("edge_bonus lookup failed for %r", node_key, exc_info=True)
        edge_bonus = 0.0

    try:
        cent_bonus = float(subgraph.centrality_bonus(node_key))
    except Exception:
        logger.debug("centrality_bonus lookup failed for %r", node_key, exc_info=True)
        cent_bonus = 0.0

    return edge_bonus, cent_bonus


def _build_hits_from_semantic(
    hits: List[Tuple[str, float, Dict[str, Any]]],
    top_k: int,
    used_mode: str,
    subgraph: Any | None = None,
    started_at: float | None = None,
) -> QueryResponse:
    """Turn raw hits into a ``QueryResponse`` and apply the scoring.

    Args:
        hits: ``(point_id, semantic_score, payload)`` tuples.
        top_k: maximum number of results (at least one slot is always kept).
        used_mode: mode label stored in the response.
        subgraph: optional object offering ``edge_bonus(key)`` and
            ``centrality_bonus(key)``; ``None`` disables graph bonuses.
        started_at: optional ``time.perf_counter()`` timestamp taken by the
            caller. When given, ``latency_ms`` covers the time since then;
            otherwise only this function's own runtime is measured.
    """
    # perf_counter is monotonic, so the latency can never become negative.
    t0 = time.perf_counter() if started_at is None else started_at

    enriched: List[Tuple[str, float, Dict[str, Any], float, float, float]] = []
    for pid, semantic_score, payload in hits:
        edge_bonus, cent_bonus = _graph_bonuses(subgraph, payload)
        total, edge_bonus, cent_bonus = _compute_total_score(
            semantic_score,
            payload,
            edge_bonus=edge_bonus,
            cent_bonus=cent_bonus,
        )
        enriched.append(
            (pid, float(semantic_score), payload, total, edge_bonus, cent_bonus)
        )

    # Sort by total_score, descending (stable for equal scores).
    enriched.sort(key=lambda h: h[3], reverse=True)
    limited = enriched[: max(1, top_k)]

    results: List[QueryHit] = [
        QueryHit(
            node_id=str(pid),
            note_id=payload.get("note_id"),
            semantic_score=float(semantic_score),
            edge_bonus=edge_bonus,
            centrality_bonus=cent_bonus,
            total_score=total,
            paths=None,
            source={
                "path": payload.get("path"),
                "section": payload.get("section_title"),
            },
        )
        for pid, semantic_score, payload, total, edge_bonus, cent_bonus in limited
    ]

    latency_ms = int((time.perf_counter() - t0) * 1000)
    return QueryResponse(results=results, used_mode=used_mode, latency_ms=latency_ms)


def semantic_retrieve(req: QueryRequest) -> QueryResponse:
    """Pure semantic retriever (no edge expansion)."""
    started_at = time.perf_counter()
    top_k = _resolve_top_k(req)
    vector = _get_query_vector(req)
    client, prefix = _get_client_and_prefix()
    hits = _semantic_hits(client, prefix, vector, top_k=top_k, filters=req.filters)
    return _build_hits_from_semantic(
        hits,
        top_k=top_k,
        used_mode="semantic",
        subgraph=None,
        started_at=started_at,
    )


def hybrid_retrieve(req: QueryRequest) -> QueryResponse:
    """Hybrid retriever with optional edge expansion.

    - The base is the same semantic chunk search as in semantic mode.
    - If ``req.expand`` yields ``depth > 0``, a local subgraph is built via
      ``ga.expand`` from the hits' ``note_id`` values and used for scoring.
    - Expansion is optional: if ``ga.expand`` raises, the error is logged and
      the results are returned without graph bonuses.
    """
    started_at = time.perf_counter()
    top_k = _resolve_top_k(req)
    vector = _get_query_vector(req)
    client, prefix = _get_client_and_prefix()
    hits = _semantic_hits(client, prefix, vector, top_k=top_k, filters=req.filters)

    depth, edge_types = _extract_expand_options(req)
    subgraph = None

    if depth > 0:
        # Seeds: note_id from each payload, de-duplicated in first-seen order.
        seed_ids: List[str] = list(
            dict.fromkeys(
                note_id
                for _, _score, payload in hits
                if (note_id := payload.get("note_id"))
            )
        )

        if seed_ids:
            try:
                subgraph = ga.expand(
                    client,
                    prefix,
                    seed_ids,
                    depth=depth,
                    edge_types=edge_types,
                )
            except Exception:
                logger.warning(
                    "Graph expansion failed; continuing without graph bonuses",
                    exc_info=True,
                )
                subgraph = None

    return _build_hits_from_semantic(
        hits,
        top_k=top_k,
        used_mode="hybrid",
        subgraph=subgraph,
        started_at=started_at,
    )