diff --git a/app/core/retriever.py b/app/core/retriever.py index 38b86c0..b12750d 100644 --- a/app/core/retriever.py +++ b/app/core/retriever.py @@ -1,3 +1,9 @@ +""" +app/core/retriever.py — Hybrider Such-Algorithmus + +Version: + 0.5.1 (WP-05 Fix: Wrapper-Class added) +""" from __future__ import annotations import os @@ -22,24 +28,18 @@ import app.core.graph_adapter as ga try: import yaml # type: ignore[import] -except Exception: # pragma: no cover - Fallback, falls PyYAML nicht installiert ist +except Exception: # pragma: no cover yaml = None # type: ignore[assignment] @lru_cache def _get_scoring_weights() -> Tuple[float, float, float]: - """Liefert (semantic_weight, edge_weight, centrality_weight) für den Retriever. - - Priorität: - 1. Werte aus config/retriever.yaml (falls vorhanden und gültig). - 2. Fallback auf Settings.RETRIEVER_W_* (ENV-basiert). - """ + """Liefert (semantic_weight, edge_weight, centrality_weight) für den Retriever.""" settings = get_settings() sem = float(getattr(settings, "RETRIEVER_W_SEM", 1.0)) edge = float(getattr(settings, "RETRIEVER_W_EDGE", 0.0)) cent = float(getattr(settings, "RETRIEVER_W_CENT", 0.0)) - # YAML-Override, falls konfiguriert config_path = os.getenv("MINDNET_RETRIEVER_CONFIG", "config/retriever.yaml") if yaml is None: return sem, edge, cent @@ -52,22 +52,19 @@ def _get_scoring_weights() -> Tuple[float, float, float]: edge = float(scoring.get("edge_weight", edge)) cent = float(scoring.get("centrality_weight", cent)) except Exception: - # Bei Fehlern in der YAML-Konfiguration defensiv auf Defaults zurückfallen return sem, edge, cent return sem, edge, cent def _get_client_and_prefix() -> Tuple[Any, str]: - """Liefert (QdrantClient, prefix) basierend auf QdrantConfig.from_env().""" + """Liefert (QdrantClient, prefix).""" cfg = qdr.QdrantConfig.from_env() client = qdr.get_client(cfg) return client, cfg.prefix def _get_query_vector(req: QueryRequest) -> List[float]: - """ - Liefert den Query-Vektor aus dem Request. - """ + """Liefert den Query-Vektor aus dem Request.""" if req.query_vector: return list(req.query_vector) @@ -78,9 +75,9 @@ def _get_query_vector(req: QueryRequest) -> List[float]: model_name = settings.MODEL_NAME try: - return ec.embed_text(req.query, model_name=model_name) # type: ignore[call-arg] + return ec.embed_text(req.query, model_name=model_name) except TypeError: - return ec.embed_text(req.query) # type: ignore[call-arg] + return ec.embed_text(req.query) def _semantic_hits( @@ -90,7 +87,7 @@ def _semantic_hits( top_k: int, filters: Dict[str, Any] | None = None, ) -> List[Tuple[str, float, Dict[str, Any]]]: - """Führt eine semantische Suche über mindnet_chunks aus und liefert Roh-Treffer.""" + """Führt eine semantische Suche aus.""" flt = filters or None raw_hits = qp.search_chunks_by_vector(client, prefix, vector, top=top_k, filters=flt) results: List[Tuple[str, float, Dict[str, Any]]] = [] @@ -105,7 +102,7 @@ def _compute_total_score( edge_bonus: float = 0.0, cent_bonus: float = 0.0, ) -> Tuple[float, float, float]: - """Berechnet total_score aus semantic_score, retriever_weight und Graph-Boni.""" + """Berechnet total_score.""" raw_weight = payload.get("retriever_weight", 1.0) try: weight = float(raw_weight) @@ -129,10 +126,7 @@ def _build_explanation( subgraph: Optional[ga.Subgraph], node_key: Optional[str] ) -> Explanation: - """ - Erstellt ein detailliertes Explanation-Objekt für einen Treffer. - Analysiert Scores, Typen und Kanten (Incoming & Outgoing). - """ + """Erstellt ein Explanation-Objekt.""" sem_w, edge_w, cent_w = _get_scoring_weights() try: @@ -155,100 +149,49 @@ def _build_explanation( reasons: List[Reason] = [] edges_dto: List[EdgeDTO] = [] - # 1. Semantische Gründe if semantic_score > 0.85: - reasons.append(Reason( - kind="semantic", - message="Sehr hohe textuelle Übereinstimmung.", - score_impact=breakdown.semantic_contribution - )) + reasons.append(Reason(kind="semantic", message="Sehr hohe textuelle Übereinstimmung.", score_impact=breakdown.semantic_contribution)) elif semantic_score > 0.70: - reasons.append(Reason( - kind="semantic", - message="Gute textuelle Übereinstimmung.", - score_impact=breakdown.semantic_contribution - )) + reasons.append(Reason(kind="semantic", message="Gute textuelle Übereinstimmung.", score_impact=breakdown.semantic_contribution)) - # 2. Typ-Gründe if type_weight != 1.0: msg = "Bevorzugt" if type_weight > 1.0 else "Leicht abgewertet" - reasons.append(Reason( - kind="type", - message=f"{msg} aufgrund des Typs '{note_type}' (Gewicht: {type_weight}).", - score_impact=(sem_w * semantic_score * (type_weight - 1.0)) - )) + reasons.append(Reason(kind="type", message=f"{msg} aufgrund des Typs '{note_type}'.", score_impact=(sem_w * semantic_score * (type_weight - 1.0)))) - # 3. Graph-Gründe (Edges) if subgraph and node_key and edge_bonus > 0: - # Wir sammeln die stärksten Kanten (egal ob rein oder raus), - # die zum Score beitragen. - - # A) Outgoing (Ich verweise auf...) - Das ist oft der Hub-Score if hasattr(subgraph, "get_outgoing_edges"): outgoing = subgraph.get_outgoing_edges(node_key) for edge in outgoing: target = edge.get("target", "Unknown") kind = edge.get("kind", "edge") weight = edge.get("weight", 0.0) - - # Nur relevante Kanten aufnehmen if weight > 0.05: - edges_dto.append(EdgeDTO( - id=f"{node_key}->{target}:{kind}", - kind=kind, source=node_key, target=target, weight=weight, direction="out" - )) + edges_dto.append(EdgeDTO(id=f"{node_key}->{target}:{kind}", kind=kind, source=node_key, target=target, weight=weight, direction="out")) - # B) Incoming (Ich werde verwiesen von...) if hasattr(subgraph, "get_incoming_edges"): incoming = subgraph.get_incoming_edges(node_key) for edge in incoming: src = edge.get("source", "Unknown") kind = edge.get("kind", "edge") weight = edge.get("weight", 0.0) - if weight > 0.05: - edges_dto.append(EdgeDTO( - id=f"{src}->{node_key}:{kind}", - kind=kind, source=src, target=node_key, weight=weight, direction="in" - )) + edges_dto.append(EdgeDTO(id=f"{src}->{node_key}:{kind}", kind=kind, source=src, target=node_key, weight=weight, direction="in")) - # Sortieren nach Gewicht und Top-3 als Reasons generieren all_edges = sorted(edges_dto, key=lambda e: e.weight, reverse=True) - for top_edge in all_edges[:3]: - # Impact schätzen (grob, da Edge-Bonus eine Summe ist) impact = edge_w * top_edge.weight - - if top_edge.direction == "out": - msg = f"Verweist auf '{top_edge.target}' via '{top_edge.kind}'" - else: - msg = f"Referenziert von '{top_edge.source}' via '{top_edge.kind}'" - - reasons.append(Reason( - kind="edge", - message=msg, - score_impact=impact, - details={"kind": top_edge.kind, "weight": top_edge.weight} - )) + dir_txt = "Verweist auf" if top_edge.direction == "out" else "Referenziert von" + tgt_txt = top_edge.target if top_edge.direction == "out" else top_edge.source + reasons.append(Reason(kind="edge", message=f"{dir_txt} '{tgt_txt}' via '{top_edge.kind}'", score_impact=impact, details={"kind": top_edge.kind})) - # 4. Centrality if cent_bonus > 0.01: - reasons.append(Reason( - kind="centrality", - message="Knoten liegt zentral im Kontext.", - score_impact=breakdown.centrality_contribution - )) + reasons.append(Reason(kind="centrality", message="Knoten liegt zentral im Kontext.", score_impact=breakdown.centrality_contribution)) - return Explanation( - breakdown=breakdown, - reasons=reasons, - related_edges=edges_dto if edges_dto else None - ) -# --- End Explanation Logic --- + return Explanation(breakdown=breakdown, reasons=reasons, related_edges=edges_dto if edges_dto else None) def _extract_expand_options(req: QueryRequest) -> Tuple[int, List[str] | None]: - """Extrahiert depth und edge_types aus req.expand.""" + """Extrahiert depth und edge_types.""" expand = getattr(req, "expand", None) if not expand: return 0, None @@ -278,14 +221,10 @@ def _build_hits_from_semantic( top_k: int, used_mode: str, subgraph: ga.Subgraph | None = None, - explain: bool = False, # WP-04b + explain: bool = False, ) -> QueryResponse: - """Baut aus Raw-Hits und optionalem Subgraph strukturierte QueryHits. - - WP-04b: Wenn explain=True, wird _build_explanation aufgerufen. - """ + """Baut strukturierte QueryHits.""" t0 = time.time() - enriched: List[Tuple[str, float, Dict[str, Any], float, float, float]] = [] for pid, semantic_score, payload in hits: @@ -303,26 +242,14 @@ def _build_hits_from_semantic( except Exception: cent_bonus = 0.0 - total, edge_bonus, cent_bonus = _compute_total_score( - semantic_score, - payload, - edge_bonus=edge_bonus, - cent_bonus=cent_bonus, - ) + total, edge_bonus, cent_bonus = _compute_total_score(semantic_score, payload, edge_bonus=edge_bonus, cent_bonus=cent_bonus) enriched.append((pid, float(semantic_score), payload, total, edge_bonus, cent_bonus)) - # Sortierung nach total_score absteigend enriched_sorted = sorted(enriched, key=lambda h: h[3], reverse=True) limited = enriched_sorted[: max(1, top_k)] results: List[QueryHit] = [] for pid, semantic_score, payload, total, edge_bonus, cent_bonus in limited: - note_id = payload.get("note_id") - path = payload.get("path") - section = payload.get("section") or payload.get("section_title") - node_key = payload.get("chunk_id") or payload.get("note_id") - - # WP-04b: Explanation bauen? explanation_obj = None if explain: explanation_obj = _build_explanation( @@ -331,59 +258,44 @@ def _build_hits_from_semantic( edge_bonus=edge_bonus, cent_bonus=cent_bonus, subgraph=subgraph, - node_key=node_key + node_key=payload.get("chunk_id") or payload.get("note_id") ) - results.append( - QueryHit( - node_id=str(pid), - note_id=note_id, - semantic_score=float(semantic_score), - edge_bonus=edge_bonus, - centrality_bonus=cent_bonus, - total_score=total, - paths=None, - source={ - "path": path, - "section": section, - }, - explanation=explanation_obj - ) - ) + results.append(QueryHit( + node_id=str(pid), + note_id=payload.get("note_id"), + semantic_score=float(semantic_score), + edge_bonus=edge_bonus, + centrality_bonus=cent_bonus, + total_score=total, + paths=None, + source={"path": payload.get("path"), "section": payload.get("section") or payload.get("section_title")}, + explanation=explanation_obj + )) dt = int((time.time() - t0) * 1000) return QueryResponse(results=results, used_mode=used_mode, latency_ms=dt) def semantic_retrieve(req: QueryRequest) -> QueryResponse: - """Reiner semantischer Retriever (ohne Edge-Expansion).""" + """Reiner semantischer Retriever.""" client, prefix = _get_client_and_prefix() vector = _get_query_vector(req) top_k = req.top_k or get_settings().RETRIEVER_TOP_K hits = _semantic_hits(client, prefix, vector, top_k=top_k, filters=req.filters) - - # explain Flag durchreichen - return _build_hits_from_semantic( - hits, - top_k=top_k, - used_mode="semantic", - subgraph=None, - explain=req.explain - ) + return _build_hits_from_semantic(hits, top_k=top_k, used_mode="semantic", subgraph=None, explain=req.explain) def hybrid_retrieve(req: QueryRequest) -> QueryResponse: """Hybrid-Retriever: semantische Suche + optionale Edge-Expansion.""" client, prefix = _get_client_and_prefix() - if req.query_vector: vector = list(req.query_vector) else: vector = _get_query_vector(req) top_k = req.top_k or get_settings().RETRIEVER_TOP_K - hits = _semantic_hits(client, prefix, vector, top_k=top_k, filters=req.filters) depth, edge_types = _extract_expand_options(req) @@ -394,18 +306,31 @@ def hybrid_retrieve(req: QueryRequest) -> QueryResponse: key = payload.get("chunk_id") or payload.get("note_id") if key and key not in seed_ids: seed_ids.append(key) - if seed_ids: try: subgraph = ga.expand(client, prefix, seed_ids, depth=depth, edge_types=edge_types) except Exception: subgraph = None - # explain Flag durchreichen - return _build_hits_from_semantic( - hits, - top_k=top_k, - used_mode="hybrid", - subgraph=subgraph, - explain=req.explain - ) \ No newline at end of file + return _build_hits_from_semantic(hits, top_k=top_k, used_mode="hybrid", subgraph=subgraph, explain=req.explain) + + +# --- WP-05 ADDITION: Wrapper Class for Chat Service --- +class Retriever: + """ + Wrapper-Klasse für WP-05 (Chat), die die existierende funktionale Logik nutzt. + Stellt sicher, dass WP-04 (/query) und WP-05 (/chat) dieselbe Basis verwenden. + """ + def __init__(self): + # Settings werden in den Funktionen via get_settings() geholt, + # daher ist hier kein State nötig. + pass + + async def search(self, request: QueryRequest) -> QueryResponse: + """ + Führt die Suche aus. + Mappt auf 'hybrid_retrieve' (synchron), daher trivialer Wrapper. + """ + # Da hybrid_retrieve synchron ist, blockiert es hier kurz den EventLoop. + # Für den aktuellen Scale ist das okay. + return hybrid_retrieve(request) \ No newline at end of file