mindnet/app/routers/chat.py

235 lines
7.9 KiB
Python

"""
app/routers/chat.py — RAG Endpunkt (WP-06 Hybrid Router v2)
Update: Robusteres LLM-Parsing für Small Language Models (SLMs).
"""
from fastapi import APIRouter, HTTPException, Depends
from typing import List, Dict, Any
import time
import uuid
import logging
import yaml
from pathlib import Path
from app.config import get_settings
from app.models.dto import ChatRequest, ChatResponse, QueryRequest, QueryHit
from app.services.llm_service import LLMService
from app.core.retriever import Retriever
# Router object on which this module's chat endpoint is registered.
router = APIRouter()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# --- Helper: Config Loader ---
# Cache for the decision config; None until get_full_config() fills it.
_DECISION_CONFIG_CACHE = None
def _load_decision_config() -> Dict[str, Any]:
    """Read the decision configuration from the configured YAML file.

    The file location is taken from ``settings.DECISION_CONFIG_PATH``. A
    minimal built-in configuration (a single ``FACT`` strategy without
    trigger keywords) is returned instead when the file does not exist,
    cannot be read or parsed, or does not hold a mapping at its top level.

    Returns:
        The parsed configuration mapping, or the built-in default.
    """
    settings = get_settings()
    path = Path(settings.DECISION_CONFIG_PATH)
    default_config = {
        "strategies": {
            "FACT": {"trigger_keywords": []}
        }
    }

    if not path.exists():
        logger.warning(f"Decision config not found at {path}, using defaults.")
        return default_config

    try:
        with open(path, "r", encoding="utf-8") as f:
            loaded = yaml.safe_load(f)
    except Exception as e:
        logger.error(f"Failed to load decision config: {e}")
        return default_config

    # safe_load yields None for an empty document and can yield a list or a
    # scalar; every caller uses .get() on the result, so only a mapping is
    # usable as configuration.
    if not isinstance(loaded, dict):
        logger.warning(
            f"Decision config at {path} is empty or not a mapping, using defaults."
        )
        return default_config

    return loaded
def get_full_config() -> Dict[str, Any]:
    """Return the decision configuration, loading it on first use.

    The loaded value is stored in the module-level cache and handed out
    again on later calls for as long as the cache is not None.
    """
    global _DECISION_CONFIG_CACHE

    # Cache already filled: hand it out without touching the file again.
    if _DECISION_CONFIG_CACHE is not None:
        return _DECISION_CONFIG_CACHE

    _DECISION_CONFIG_CACHE = _load_decision_config()
    return _DECISION_CONFIG_CACHE
def get_decision_strategy(intent: str) -> Dict[str, Any]:
    """Look up the strategy entry configured for *intent*.

    Falls back to the ``FACT`` strategy when *intent* has no entry, and to
    an empty dict when ``FACT`` is not configured either.
    """
    all_strategies = get_full_config().get("strategies", {})
    if intent in all_strategies:
        return all_strategies[intent]
    return all_strategies.get("FACT", {})
# --- Dependencies ---
def get_llm_service():
    """Dependency provider: construct and return a new LLMService."""
    service = LLMService()
    return service
def get_retriever():
    """Dependency provider: construct and return a new Retriever."""
    retriever = Retriever()
    return retriever
# --- Logic ---
def _build_enriched_context(hits: List[QueryHit]) -> str:
context_parts = []
for i, hit in enumerate(hits, 1):
source = hit.source or {}
content = (
source.get("text") or source.get("content") or
source.get("page_content") or source.get("chunk_text") or
"[Kein Text]"
)
title = hit.note_id or "Unbekannt"
note_type = source.get("type", "unknown").upper()
entry = (
f"### QUELLE {i}: {title}\n"
f"TYP: [{note_type}] (Score: {hit.total_score:.2f})\n"
f"INHALT:\n{content}\n"
)
context_parts.append(entry)
return "\n\n".join(context_parts)
async def _classify_intent(query: str, llm: LLMService) -> str:
    """
    Hybrid Router v2: pick the strategy (intent) that should handle *query*.

    1. Fast path, keyword check: every strategy except FACT is scanned for
       trigger keywords contained in the lower-cased query; the strategy
       owning the longest matching keyword wins.
    2. Slow path, LLM fallback: only when no keyword matched and the config
       enables ``llm_fallback_enabled`` with a non-empty
       ``llm_router_prompt``. The raw LLM answer is searched for the
       configured strategy names (case-insensitively), which tolerates
       small models that wrap the name in extra text.

    Args:
        query: The user message to classify.
        llm: Service used for the fallback call (``generate_raw_response``).

    Returns:
        A strategy name from the config, or "FACT" when nothing matched.
    """
    config = get_full_config()
    # `or {}` covers keys that exist in the YAML but were left empty (null).
    strategies = config.get("strategies") or {}
    settings = config.get("settings") or {}

    query_lower = query.lower()
    best_intent = None
    max_match_length = 0

    # 1. FAST PATH: keywords
    for intent_name, strategy in strategies.items():
        # FACT is the default outcome and never selected by keyword.
        if intent_name == "FACT":
            continue
        keywords = (strategy or {}).get("trigger_keywords") or []
        for keyword in keywords:
            # Longest match wins so a specific phrase beats a short word.
            if keyword.lower() in query_lower and len(keyword) > max_match_length:
                max_match_length = len(keyword)
                best_intent = intent_name

    if best_intent:
        logger.info(f"Intent detected via KEYWORD: {best_intent}")
        return best_intent

    # 2. SLOW PATH: LLM router
    if not settings.get("llm_fallback_enabled", False):
        return "FACT"
    router_prompt_template = settings.get("llm_router_prompt", "")
    if not router_prompt_template:
        return "FACT"

    prompt = router_prompt_template.replace("{query}", query)
    logger.info("Keywords failed. Asking LLM for Intent...")

    # Short raw call
    raw_response = await llm.generate_raw_response(prompt)
    logger.info(f"LLM Router Raw Output: '{raw_response}'")  # Debugging

    # --- Robust parsing for SLMs ---
    # Look for the known strategy names anywhere in the output. Both sides
    # are upper-cased so a strategy key written in lower or mixed case in
    # the config can still be recognised; a missing response counts as "".
    llm_output_upper = str(raw_response or "").upper()
    found_intents = [
        strat_key for strat_key in strategies
        if str(strat_key).upper() in llm_output_upper
    ]

    if len(found_intents) == 1:
        # Unambiguous hit
        logger.info(f"Intent detected via LLM (Parsed): {found_intents[0]}")
        return found_intents[0]

    if found_intents:
        # Several names mentioned (e.g. "It is FACT or DECISION"): take the
        # first one in config order.
        logger.warning(f"LLM returned multiple intents {found_intents}. Using first match: {found_intents[0]}")
        return found_intents[0]

    logger.warning("LLM did not return a valid strategy name. Falling back to FACT.")
    return "FACT"
@router.post("/", response_model=ChatResponse)
async def chat_endpoint(
    request: ChatRequest,
    llm: LLMService = Depends(get_llm_service),
    retriever: Retriever = Depends(get_retriever)
):
    """Answer a chat message with retrieval-augmented generation.

    Steps:
      1. Classify the message into an intent and load its strategy.
      2. Run the primary hybrid retrieval for the message.
      3. If the strategy lists ``inject_types``, run a second retrieval
         filtered to those types and merge hits not already present.
      4. Build the context text from all hits.
      5. Fill the strategy's prompt template and call the LLM.

    Args:
        request: Incoming chat request (``message``, ``top_k``, ``explain``).
        llm: LLM service, injected via ``get_llm_service``.
        retriever: Retriever, injected via ``get_retriever``.

    Returns:
        ChatResponse with the answer, the hits used as sources, the latency
        in milliseconds and the detected intent.

    Raises:
        HTTPException: Passed through unchanged when raised during
            processing; any other error is reported with status 500.
    """
    # Monotonic clock: the latency cannot be skewed by wall-clock changes.
    start_time = time.perf_counter()
    query_id = str(uuid.uuid4())
    logger.info(f"Chat request [{query_id}]: {request.message[:50]}...")

    try:
        # 1. Intent Detection
        intent = await _classify_intent(request.message, llm)
        logger.info(f"[{query_id}] Final Intent: {intent}")

        # Strategy Load
        strategy = get_decision_strategy(intent)
        inject_types = strategy.get("inject_types", [])
        prompt_key = strategy.get("prompt_template", "rag_template")
        prepend_instr = strategy.get("prepend_instruction", "")

        # 2. Primary Retrieval
        query_req = QueryRequest(
            query=request.message,
            mode="hybrid",
            top_k=request.top_k,
            explain=request.explain
        )
        retrieve_result = await retriever.search(query_req)
        # Work on a copy so merging below does not mutate the retriever's
        # result object.
        hits = list(retrieve_result.results or [])

        # 3. Strategic Retrieval
        if inject_types:
            logger.info(f"[{query_id}] Executing Strategic Retrieval for types: {inject_types}...")
            strategy_req = QueryRequest(
                query=request.message,
                mode="hybrid",
                top_k=3,
                filters={"type": inject_types},
                explain=False
            )
            strategy_result = await retriever.search(strategy_req)

            existing_ids = {h.node_id for h in hits}
            for strat_hit in strategy_result.results or []:
                if strat_hit.node_id not in existing_ids:
                    hits.append(strat_hit)
                    # Remember the id so a node repeated within the
                    # strategic results is added only once.
                    existing_ids.add(strat_hit.node_id)

        # 4. Context Building
        if not hits:
            context_str = "Keine relevanten Notizen gefunden."
        else:
            context_str = _build_enriched_context(hits)

        # 5. Generation
        template = llm.prompts.get(prompt_key, "{context_str}\n\n{query}")
        system_prompt = llm.prompts.get("system_prompt", "")

        if prepend_instr:
            context_str = f"{prepend_instr}\n\n{context_str}"

        final_prompt = template.replace("{context_str}", context_str).replace("{query}", request.message)

        logger.info(f"[{query_id}] Sending to LLM (Intent: {intent}, Template: {prompt_key})...")

        # Pass the system prompt separately
        answer_text = await llm.generate_raw_response(prompt=final_prompt, system=system_prompt)

        duration_ms = int((time.perf_counter() - start_time) * 1000)

        return ChatResponse(
            query_id=query_id,
            answer=answer_text,
            sources=hits,
            latency_ms=duration_ms,
            intent=intent
        )

    except HTTPException:
        # Keep the status code and detail chosen further down the stack
        # instead of rewriting them into a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error in chat endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))