from fastapi import FastAPI, Query, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct, PointIdsList
from uuid import uuid4
import logging
import requests
import os
from datetime import datetime, date

# Bumped on every release.
__version__ = "1.1.6"
print(f"[DEBUG] llm_api.py version {__version__} loaded from {__file__}", flush=True)

logger = logging.getLogger(__name__)

# Ollama configuration
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/api/generate")
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "mistral:latest")

# -----------------------
# MediaWiki configuration
# -----------------------
WIKI_API_URL = os.getenv("WIKI_API_URL", "https://karatetrainer.net/api.php")
WIKI_BOT_USER = os.getenv("WIKI_BOT_USER", "")
WIKI_BOT_PASSWORD = os.getenv("WIKI_BOT_PASSWORD", "")

# FastAPI instance
app = FastAPI(
    title="KI Trainerassistent API",
    description="Lokale API für Trainingsplanung",
    version=__version__,
    docs_url="/docs",
    redoc_url="/redoc",
    openapi_url="/openapi.json",
)


# Global error handler: log the full traceback server-side (the old version
# silently swallowed it, making failures undiagnosable) and return a generic
# message so internals are never leaked to the client.
@app.exception_handler(Exception)
async def unicorn_exception_handler(request: Request, exc: Exception):
    logger.exception("Unhandled error on %s %s", request.method, request.url.path)
    return JSONResponse(
        status_code=500,
        content={"detail": "Interner Serverfehler. Bitte später erneut versuchen."},
    )


# Shared session for the MediaWiki API (keeps login cookies between calls).
wiki_session = requests.Session()


# Health check for MediaWiki
@app.get("/import/wiki/health")
def wiki_health():
    """Check whether the MediaWiki server is reachable.

    Returns ``{"status": "ok", "server": <name-or-None>}`` on success;
    raises HTTP 502 when the wiki cannot be reached or returns invalid JSON.
    """
    params = {"action": "query", "meta": "siteinfo", "siprop": "general", "format": "json"}
    try:
        r = wiki_session.get(WIKI_API_URL, params=params, timeout=5)
        r.raise_for_status()
        resp = r.json()
    # Narrowed from bare `except Exception`: only transport errors and
    # JSON-decode errors (ValueError) should map to 502.
    except (requests.RequestException, ValueError) as e:
        raise HTTPException(status_code=502, detail=f"Wiki nicht erreichbar: {e}")
    # Try to read the server name, but still report OK when it is missing.
    server = resp.get("query", {}).get("general", {}).get("servername")
    return {"status": "ok", "server": server if server else None}


# ------------------------
# MediaWiki login endpoint
# ------------------------
class WikiLoginRequest(BaseModel):
    username: str
    password: str


class WikiLoginResponse(BaseModel):
    status: str
    message: Optional[str] = None


@app.post("/import/wiki/login", response_model=WikiLoginResponse)
def wiki_login(data: WikiLoginRequest):
    """Log in against the MediaWiki API and keep the session cookies.

    Step 1 fetches a login token, step 2 posts the credentials with that
    token.  Transport/parse errors map to HTTP 502; a rejected login is
    reported as ``status="failed"`` with the wiki's reason.
    """
    # Step 1: fetch a login token
    params_token = {"action": "query", "meta": "tokens", "type": "login", "format": "json"}
    try:
        resp1 = wiki_session.get(WIKI_API_URL, params=params_token, timeout=10)
        resp1.raise_for_status()
        token = resp1.json().get("query", {}).get("tokens", {}).get("logintoken")
        if not token:
            raise ValueError("Kein Login-Token erhalten")
    except (requests.RequestException, ValueError) as e:
        raise HTTPException(status_code=502, detail=f"Fehler Token abrufen: {e}")
    # Step 2: log in with the token
    login_data = {
        "action": "login",
        "format": "json",
        "lgname": data.username,
        "lgpassword": data.password,
        "lgtoken": token,
    }
    try:
        resp2 = wiki_session.post(WIKI_API_URL, data=login_data, timeout=10)
        resp2.raise_for_status()
        result = resp2.json().get("login", {})
        if result.get("result") != "Success":
            return WikiLoginResponse(status="failed", message=result.get("reason", "Login fehlgeschlagen"))
    except (requests.RequestException, ValueError) as e:
        raise HTTPException(status_code=502, detail=f"Fehler Login: {e}")
    return WikiLoginResponse(status="success")
# NOTE(review): removed unreachable leftover fragment from a previous merge
# (a duplicated, broken "return {... 'server' ...}" residue of the health
# endpoint that sat at module level).

# ------------------------
# Models for embed/search
# ------------------------
class ChunkInput(BaseModel):
    # One text chunk plus the provenance metadata stored alongside its vector.
    text: str
    source: str
    source_type: str = ""
    title: str = ""
    version: str = ""
    related_to: str = ""
    tags: List[str] = []
    owner: str = ""
    context_tag: Optional[str] = None
    imported_at: Optional[str] = None
    chunk_index: Optional[int] = None
    category: Optional[str] = None


class EmbedRequest(BaseModel):
    chunks: List[ChunkInput]
    collection: str = "default"


class PromptRequest(BaseModel):
    query: str
    context_limit: int = 3
    collection: str = "default"


class EmbedResponse(BaseModel):
    status: str
    count: int
    collection: str


class SearchResultItem(BaseModel):
    score: float = Field(..., ge=0)
    text: str


class PromptResponse(BaseModel):
    answer: str
    context: str
    collection: str


class DeleteResponse(BaseModel):
    status: str
    count: int
    collection: str
    source: Optional[str] = None
    type: Optional[str] = None


# ------------------------------------
# Models for exercises & plans
# ------------------------------------
class Exercise(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid4()))
    title: str
    summary: str
    short_description: str
    keywords: List[str] = []
    link: Optional[str] = None
    discipline: str
    group: Optional[str] = None
    age_group: str
    target_group: str
    min_participants: int
    duration_minutes: int
    capabilities: Dict[str, int] = {}
    category: str
    purpose: str
    execution: str
    notes: str
    preparation: str
    method: str
    equipment: List[str] = []


class PhaseExercise(BaseModel):
    exercise_id: str
    cond_load: Dict[str, Any] = {}
    coord_load: Dict[str, Any] = {}
    instructions: str = ""


class PlanPhase(BaseModel):
    name: str
    duration_minutes: int
    method: str
    method_notes: str = ""
    exercises: List[PhaseExercise]


class TrainingPlan(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid4()))
    title: str
    short_description: str
    collection: str
    discipline: str
    group: Optional[str] = None
    dojo: str
    date: date
    plan_duration_weeks: int
    focus_areas: List[str] = []
    predecessor_plan_id: Optional[str] = None
    age_group: str
    # NOTE(review): datetime.utcnow() produces a naive timestamp and is
    # deprecated since 3.12 — consider datetime.now(timezone.utc); kept for
    # payload compatibility with already-stored plans.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    phases: List[PlanPhase]


# ----------------------------------
# Embedding model and Qdrant client
# ----------------------------------
model = SentenceTransformer("all-MiniLM-L6-v2")
qdrant = QdrantClient(
    host=os.getenv("QDRANT_HOST", "localhost"),
    port=int(os.getenv("QDRANT_PORT", 6333))
)


def _ensure_collection(name: str) -> None:
    """Create collection *name* (cosine distance, embedding-model vector size) if missing.

    Replaces four copy-pasted stanzas; also fixes /embed, which previously
    hard-coded size=384 instead of asking the model (identical value for
    all-MiniLM-L6-v2, but now consistent by construction).
    """
    if not qdrant.collection_exists(name):
        qdrant.recreate_collection(
            collection_name=name,
            vectors_config=VectorParams(
                size=model.get_sentence_embedding_dimension(),
                distance=Distance.COSINE,
            ),
        )


# Make sure the fixed collections exist at startup.
_ensure_collection("exercises")
PLAN_COLL = "training_plans"
_ensure_collection(PLAN_COLL)


# ----------------------
# Exercise endpoints
# ----------------------
@app.post("/exercise", response_model=Exercise)
def create_exercise(ex: Exercise):
    """Embed title + summary and upsert the exercise into 'exercises'."""
    _ensure_collection("exercises")
    vec = model.encode(f"{ex.title}. {ex.summary}").tolist()
    point = PointStruct(id=ex.id, vector=vec, payload=ex.dict())
    qdrant.upsert(collection_name="exercises", points=[point])
    return ex


@app.get("/exercise", response_model=List[Exercise])
def list_exercises(
    discipline: Optional[str] = Query(None),
    group: Optional[str] = Query(None),
    tags: Optional[str] = Query(None)
):
    """List exercises, optionally filtered by discipline, group and comma-separated tags."""
    filters = []
    if discipline:
        filters.append({"key": "discipline", "match": {"value": discipline}})
    if group:
        filters.append({"key": "group", "match": {"value": group}})
    if tags:
        # Every requested tag must match an entry of the 'keywords' payload list.
        for t in tags.split(","):
            filters.append({"key": "keywords", "match": {"value": t.strip()}})
    if filters:
        pts, _ = qdrant.scroll(
            collection_name="exercises",
            scroll_filter={"must": filters},
            limit=10000
        )
    else:
        pts, _ = qdrant.scroll(collection_name="exercises", limit=10000)
    return [Exercise(**pt.payload) for pt in pts]


# -----------------
# Embedding / search / prompt endpoints
# -----------------
@app.post("/embed")
def embed_texts(data: EmbedRequest):
    """Embed all chunks and store them with their metadata in the target collection."""
    collection_name = data.collection
    _ensure_collection(collection_name)
    embeddings = model.encode([c.text for c in data.chunks]).tolist()
    points = []
    for i, chunk in enumerate(data.chunks):
        payload = {
            "text": chunk.text,
            "source": chunk.source,
            "source_type": chunk.source_type,
            "title": chunk.title,
            "version": chunk.version,
            "related_to": chunk.related_to,
            "tags": chunk.tags,
            "owner": chunk.owner,
            "context_tag": chunk.context_tag,
            # Default the import timestamp to "now" when the client omits it.
            "imported_at": chunk.imported_at or datetime.utcnow().isoformat(),
            "chunk_index": chunk.chunk_index,
            "category": chunk.category or data.collection,
        }
        points.append(PointStruct(id=str(uuid4()), vector=embeddings[i], payload=payload))
    qdrant.upsert(collection_name=collection_name, points=points)
    return {"status": "✅ embeddings saved", "count": len(points), "collection": collection_name}


@app.get("/search", response_model=List[SearchResultItem])
def search_text(
    query: str = Query(..., min_length=1),
    limit: int = Query(3, ge=1),
    collection: str = Query("default")
):
    """Vector search: return the top *limit* chunks of *collection* for *query*."""
    vec = model.encode(query).tolist()
    res = qdrant.search(collection_name=collection, query_vector=vec, limit=limit)
    # .get() instead of ['text'] so points written without a 'text' payload
    # (foreign collections) don't turn into a 500.
    return [SearchResultItem(score=r.score, text=r.payload.get('text', '')) for r in res]


@app.post("/prompt", response_model=PromptResponse)
def prompt(data: PromptRequest):
    """RAG endpoint: retrieve context chunks from Qdrant, then ask Ollama.

    Raises 400 on empty query or out-of-range context_limit, 502 when the
    LLM service is unreachable.
    """
    if not data.query.strip():
        raise HTTPException(status_code=400, detail="'query' darf nicht leer sein.")
    if not (1 <= data.context_limit <= 10):
        raise HTTPException(status_code=400, detail="'context_limit' muss zwischen 1 und 10 liegen.")
    hits = qdrant.search(
        collection_name=data.collection,
        query_vector=model.encode(data.query).tolist(),
        limit=data.context_limit
    )
    context = "\n".join(h.payload.get('text', '') for h in hits)
    payload = {
        "model": OLLAMA_MODEL,
        "prompt": f"Context:\n{context}\nQuestion: {data.query}",
        "stream": False,
    }
    try:
        r = requests.post(OLLAMA_URL, json=payload, timeout=30)
        r.raise_for_status()
    except requests.RequestException:
        raise HTTPException(status_code=502, detail="LLM-Service-Fehler.")
    return PromptResponse(answer=r.json().get("response", ""), context=context, collection=data.collection)


@app.delete("/delete-source", response_model=DeleteResponse)
def delete_by_source(
    collection: str = Query(...),
    source: Optional[str] = Query(None),
    type: Optional[str] = Query(None),
    owner: Optional[str] = Query(None),
    category: Optional[str] = Query(None)
):
    """Delete all points in *collection* whose payload matches every given filter.

    At least one filter parameter is required; 404 if the collection is missing.
    """
    if not qdrant.collection_exists(collection):
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")
    filt = []
    if source:
        filt.append({"key": "source", "match": {"value": source}})
    if type:
        # Bug fix: /embed stores this value under the payload key
        # 'source_type', not 'type', so the old filter on 'type' could never
        # match.  The query-parameter name stays 'type' for API compatibility.
        filt.append({"key": "source_type", "match": {"value": type}})
    if owner:
        filt.append({"key": "owner", "match": {"value": owner}})
    if category:
        filt.append({"key": "category", "match": {"value": category}})
    if not filt:
        raise HTTPException(status_code=400, detail="Mindestens ein Filterparameter muss angegeben werden.")
    pts, _ = qdrant.scroll(collection_name=collection, scroll_filter={"must": filt}, limit=10000)
    ids = [str(p.id) for p in pts]
    if not ids:
        return DeleteResponse(status="🔍 Keine passenden Einträge gefunden.", count=0, collection=collection)
    qdrant.delete(collection_name=collection, points_selector=PointIdsList(points=ids))
    return DeleteResponse(status="🗑️ gelöscht", count=len(ids), collection=collection)


@app.delete("/delete-collection", response_model=DeleteResponse)
def delete_collection(collection: str = Query(...)):
    """Drop an entire Qdrant collection (404 if it does not exist)."""
    if not qdrant.collection_exists(collection):
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")
    qdrant.delete_collection(collection_name=collection)
    return DeleteResponse(status="🗑️ gelöscht", count=0, collection=collection)


# ------------------------
# TrainingPlan endpoints
# ------------------------
@app.post("/plan", response_model=TrainingPlan)
def create_plan(plan: TrainingPlan):
    """Embed title + short description and upsert the plan into PLAN_COLL."""
    _ensure_collection(PLAN_COLL)
    vec = model.encode(f"{plan.title}. {plan.short_description}").tolist()
    payload = plan.dict()
    # Bug fix: plan.dict() leaves raw date/datetime objects in the payload,
    # which are not JSON-serializable for Qdrant; store ISO strings instead
    # (pydantic parses them back in list_plans).
    payload["date"] = plan.date.isoformat()
    payload["created_at"] = plan.created_at.isoformat()
    qdrant.upsert(collection_name=PLAN_COLL, points=[PointStruct(id=plan.id, vector=vec, payload=payload)])
    return plan


@app.get("/plan", response_model=List[TrainingPlan])
def list_plans(
    collection: str = Query(PLAN_COLL),
    discipline: Optional[str] = Query(None),
    group: Optional[str] = Query(None),
    dojo: Optional[str] = Query(None)
):
    """List stored training plans, optionally filtered by discipline/group/dojo.

    Returns [] for a missing collection instead of erroring.
    """
    if not qdrant.collection_exists(collection):
        return []
    # Filtering is done in Python after a full scroll; fine at current scale.
    pts, _ = qdrant.scroll(collection_name=collection, limit=10000)
    result: List[TrainingPlan] = []
    for pt in pts:
        plan = TrainingPlan(**pt.payload)
        if discipline and plan.discipline != discipline:
            continue
        if group and plan.group != group:
            continue
        if dojo and plan.dojo != dojo:
            continue
        result.append(plan)
    return result