diff --git a/app/core/retriever.py b/app/core/retriever.py index 119dc9c..fefd7a8 100644 --- a/app/core/retriever.py +++ b/app/core/retriever.py @@ -14,7 +14,7 @@ import app.core.graph_adapter as ga try: import yaml # type: ignore[import] -except Exception: # pragma: no cover - Fallback falls PyYAML nicht installiert ist +except Exception: # pragma: no cover - Fallback, falls PyYAML nicht installiert ist yaml = None # type: ignore[assignment] @@ -77,7 +77,13 @@ def _get_query_vector(req: QueryRequest) -> List[float]: settings = get_settings() model_name = settings.MODEL_NAME - return ec.embed_text(req.query, model_name=model_name) + + # Kompatibel mit Fakes in Unit-Tests (ohne model_name-Parameter) + try: + return ec.embed_text(req.query, model_name=model_name) # type: ignore[call-arg] + except TypeError: + # Fallback: einfache Signatur embed_text(text) + return ec.embed_text(req.query) # type: ignore[call-arg] def _semantic_hits( @@ -90,15 +96,15 @@ def _semantic_hits( """Führt eine semantische Suche über mindnet_chunks aus und liefert Roh-Treffer. Rückgabeformat: Liste von (point_id, score, payload) + + Erwartetes Format von qp.search_chunks_by_vector: + List[Tuple[str, float, dict]] """ flt = filters or None - hits = qp.search_chunks_by_vector(client, prefix, vector, top=top_k, filters=flt) + raw_hits = qp.search_chunks_by_vector(client, prefix, vector, top=top_k, filters=flt) results: List[Tuple[str, float, Dict[str, Any]]] = [] - for point in hits: - pid = str(point.id) - score = float(point.score) - payload = dict(point.payload or {}) - results.append((pid, score, payload)) + for pid, score, payload in raw_hits: + results.append((str(pid), float(score), dict(payload or {}))) return results @@ -150,7 +156,7 @@ def _extract_expand_options(req: QueryRequest) -> Tuple[int, List[str] | None]: return 0, None depth = 1 - edge_types = None + edge_types: List[str] | None = None # Pydantic-Modell oder Objekt mit Attributen if hasattr(expand, "depth") or hasattr(expand, "edge_types"): @@ -279,7 +285,7 @@ def hybrid_retrieve(req: QueryRequest) -> QueryResponse: if depth and depth > 0: # Seeds: stabile IDs aus dem Payload (chunk_id bevorzugt, sonst note_id) seed_ids: List[str] = [] - for _, _score, payload in hits: + for _pid, _score, payload in hits: key = payload.get("chunk_id") or payload.get("note_id") if key and key not in seed_ids: seed_ids.append(key)