Trainer_LLM/llm-api/old strukture/llm_api1.1.6.py

422 lines
14 KiB
Python

import logging
import os
from datetime import datetime, date
from typing import List, Dict, Any, Optional
from uuid import uuid4

import requests
from fastapi import FastAPI, Query, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct, PointIdsList
from sentence_transformers import SentenceTransformer
# Module version (incremented for this revision); announced on import for debugging.
__version__ = "1.1.6"
print(
    "[DEBUG] llm_api.py version {} loaded from {}".format(__version__, __file__),
    flush=True,
)

# --- Ollama configuration (each value can be overridden via environment) ---
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434/api/generate")
OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "mistral:latest")

# --- MediaWiki configuration (each value can be overridden via environment) ---
WIKI_API_URL = os.environ.get("WIKI_API_URL", "https://karatetrainer.net/api.php")
WIKI_BOT_USER = os.environ.get("WIKI_BOT_USER", "")
WIKI_BOT_PASSWORD = os.environ.get("WIKI_BOT_PASSWORD", "")
# FastAPI application instance; the interactive docs are served under
# /docs and /redoc, the schema under /openapi.json.
app = FastAPI(
    title="KI Trainerassistent API",
    description="Lokale API für Trainingsplanung",
    version=__version__,
    docs_url="/docs",
    redoc_url="/redoc",
    openapi_url="/openapi.json",
)
# Global fallback handler for exceptions that no endpoint handled itself.
@app.exception_handler(Exception)
async def unicorn_exception_handler(request: Request, exc: Exception):
    """Log an unhandled exception and answer with a generic HTTP 500.

    The client only ever sees a generic message; the actual exception
    (including its traceback) is written to the log so failures are not
    silently discarded.

    Args:
        request: The request during which the exception occurred.
        exc: The exception that was not handled elsewhere.

    Returns:
        A JSON response with status 500 and a generic ``detail`` message.
    """
    logging.getLogger(__name__).error(
        "Unhandled exception while processing %s %s",
        request.method,
        request.url.path,
        exc_info=exc,
    )
    return JSONResponse(
        status_code=500,
        content={"detail": "Interner Serverfehler. Bitte später erneut versuchen."},
    )
# Shared HTTP session for all MediaWiki API calls. It is module-global, so
# cookies set by one call (e.g. the login endpoint) are reused by later calls.
wiki_session = requests.Session()
# Health check for MediaWiki
@app.get("/import/wiki/health")
def wiki_health():
    """Check whether the MediaWiki server is reachable.

    Sends a ``siteinfo`` query to ``WIKI_API_URL`` through the shared
    ``wiki_session``.

    Returns:
        ``{"status": "ok", "server": <name>}`` where ``<name>`` is the
        ``servername`` reported by the wiki, or ``None`` if the answer does
        not contain one.

    Raises:
        HTTPException: 502 if the request fails, the wiki answers with an
            HTTP error status, or the body is not valid JSON.
    """
    params = {"action": "query", "meta": "siteinfo", "siprop": "general", "format": "json"}
    try:
        r = wiki_session.get(WIKI_API_URL, params=params, timeout=5)
        r.raise_for_status()
        resp = r.json()
    except Exception as e:
        # Chain the cause so the original error stays visible in tracebacks.
        raise HTTPException(status_code=502, detail=f"Wiki nicht erreichbar: {e}") from e

    # Connectivity is proven at this point. The server name is optional, so an
    # unexpected payload shape must not turn a healthy answer into an error.
    server = None
    if isinstance(resp, dict):
        query = resp.get("query")
        general = query.get("general") if isinstance(query, dict) else None
        if isinstance(general, dict):
            server = general.get("servername") or None
    return {"status": "ok", "server": server}
# ------------------------
# MediaWiki login endpoint
# ------------------------
class WikiLoginRequest(BaseModel):
    """Request body of the MediaWiki login endpoint."""

    # Credentials forwarded to the wiki as ``lgname`` / ``lgpassword``.
    username: str
    password: str
class WikiLoginResponse(BaseModel):
    """Result of a MediaWiki login attempt."""

    # "success" or "failed", as set by the login endpoint.
    status: str
    # Failure reason reported by the wiki; left unset on success.
    message: Optional[str] = None
@app.post("/import/wiki/login", response_model=WikiLoginResponse)
def wiki_login(data: WikiLoginRequest):
    """Log in to the MediaWiki API and keep the session cookies.

    The login is a two-step exchange on the shared ``wiki_session``: first a
    login token is requested, then the credentials are posted together with
    that token.

    Args:
        data: Username and password to log in with.

    Returns:
        ``WikiLoginResponse`` with ``status="success"``, or ``status="failed"``
        plus the reason reported by the wiki when it rejects the login.

    Raises:
        HTTPException: 502 if either request fails, no login token is
            returned, or a response cannot be processed.
    """
    # Step 1: fetch a login token
    params_token = {"action": "query", "meta": "tokens", "type": "login", "format": "json"}
    try:
        resp1 = wiki_session.get(WIKI_API_URL, params=params_token, timeout=10)
        resp1.raise_for_status()
        token = resp1.json().get("query", {}).get("tokens", {}).get("logintoken")
        if not token:
            raise ValueError("Kein Login-Token erhalten")
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Fehler Token abrufen: {e}") from e

    # Step 2: log in with the token
    login_data = {
        "action": "login",
        "format": "json",
        "lgname": data.username,
        "lgpassword": data.password,
        "lgtoken": token,
    }
    try:
        resp2 = wiki_session.post(WIKI_API_URL, data=login_data, timeout=10)
        resp2.raise_for_status()
        result = resp2.json().get("login", {})
        if result.get("result") != "Success":
            # A rejected login is a regular outcome, not a gateway error.
            return WikiLoginResponse(
                status="failed",
                message=result.get("reason", "Login fehlgeschlagen"),
            )
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Fehler Login: {e}") from e
    return WikiLoginResponse(status="success")
# ------------------------
# Models for embed/search
# ------------------------
class ChunkInput(BaseModel):
    """One text chunk plus metadata, as submitted to the embed endpoint."""

    # Text that is embedded and stored in the payload.
    text: str
    # Origin of the chunk; the delete-source endpoint can filter on it.
    source: str
    source_type: str = ""
    title: str = ""
    version: str = ""
    related_to: str = ""
    tags: List[str] = []
    owner: str = ""
    context_tag: Optional[str] = None
    # When omitted, the embed endpoint fills in the current UTC timestamp.
    imported_at: Optional[str] = None
    chunk_index: Optional[int] = None
    # When omitted, the embed endpoint falls back to the collection name.
    category: Optional[str] = None
class EmbedRequest(BaseModel):
    """Request body of the embed endpoint."""

    # Chunks to embed and store.
    chunks: List[ChunkInput]
    # Target Qdrant collection.
    collection: str = "default"
class PromptRequest(BaseModel):
    """Request body of the prompt endpoint."""

    # Question passed to the LLM and used for the context search.
    query: str
    # Number of search hits used as context; the endpoint accepts 1 to 10.
    context_limit: int = 3
    # Qdrant collection the context is searched in.
    collection: str = "default"
class EmbedResponse(BaseModel):
    """Shape of an embed result: status text, point count and collection."""

    status: str
    count: int
    collection: str
class SearchResultItem(BaseModel):
    """A single hit returned by the search endpoint."""

    # Similarity score as reported by Qdrant. No lower bound is enforced:
    # cosine similarity can be negative, and a ``ge=0`` constraint would make
    # response validation fail for such hits.
    score: float
    # Stored text of the matching chunk.
    text: str
class PromptResponse(BaseModel):
    """Answer of the prompt endpoint."""

    # Text generated by the LLM.
    answer: str
    # Context that was handed to the LLM (joined search hits).
    context: str
    # Collection the context was taken from.
    collection: str
class DeleteResponse(BaseModel):
    """Result of a delete operation."""

    status: str
    # Number of deleted points (0 when a whole collection was dropped).
    count: int
    collection: str
    source: Optional[str] = None
    type: Optional[str] = None
# ------------------------------------
# Models for exercises & plans
# ------------------------------------
class Exercise(BaseModel):
    """A single training exercise as stored in the ``exercises`` collection."""

    # Generated UUID string unless the client supplies one.
    id: str = Field(default_factory=lambda: str(uuid4()))
    # Title and summary together form the text that gets embedded.
    title: str
    summary: str
    short_description: str
    # Matched by the ``tags`` filter of the exercise listing.
    keywords: List[str] = []
    link: Optional[str] = None
    discipline: str
    group: Optional[str] = None
    age_group: str
    target_group: str
    min_participants: int
    duration_minutes: int
    capabilities: Dict[str, int] = {}
    category: str
    purpose: str
    execution: str
    notes: str
    preparation: str
    method: str
    equipment: List[str] = []
class PhaseExercise(BaseModel):
    """An exercise entry inside a plan phase."""

    # Presumably refers to ``Exercise.id`` — not validated here; verify against callers.
    exercise_id: str
    cond_load: Dict[str, Any] = {}
    coord_load: Dict[str, Any] = {}
    instructions: str = ""
class PlanPhase(BaseModel):
    """One phase of a training plan with its exercises."""

    name: str
    duration_minutes: int
    method: str
    method_notes: str = ""
    exercises: List[PhaseExercise]
class TrainingPlan(BaseModel):
    """A complete training plan consisting of several phases."""

    # Generated UUID string unless the client supplies one.
    id: str = Field(default_factory=lambda: str(uuid4()))
    # Title and short description together form the text that gets embedded.
    title: str
    short_description: str
    collection: str
    discipline: str
    group: Optional[str] = None
    dojo: str
    date: date
    plan_duration_weeks: int
    focus_areas: List[str] = []
    predecessor_plan_id: Optional[str] = None
    age_group: str
    # Creation time; defaults to the current (naive) UTC time.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    phases: List[PlanPhase]
# ----------------------------------
# Embedding model and Qdrant client
# ----------------------------------
model = SentenceTransformer("all-MiniLM-L6-v2")
qdrant = QdrantClient(
    host=os.getenv("QDRANT_HOST", "localhost"),
    port=int(os.getenv("QDRANT_PORT", 6333)),
)

# Name of the collection that stores training plans.
PLAN_COLL = "training_plans"

# Make sure the exercise and training-plan collections exist at startup.
# ``create_collection`` is used instead of the deprecated
# ``recreate_collection``: the latter drops an existing collection, which
# would destroy data if the collection appears between check and call.
for _coll_name in ("exercises", PLAN_COLL):
    if not qdrant.collection_exists(_coll_name):
        qdrant.create_collection(
            collection_name=_coll_name,
            vectors_config=VectorParams(
                size=model.get_sentence_embedding_dimension(),
                distance=Distance.COSINE,
            ),
        )
# -----------------------
# Endpoints for exercises
# -----------------------
@app.post("/exercise", response_model=Exercise)
def create_exercise(ex: Exercise):
    """Store an exercise in the ``exercises`` collection.

    The embedding is computed from the exercise title and summary; the full
    exercise is stored as payload under the exercise id.

    Args:
        ex: The exercise to store.

    Returns:
        The stored exercise (including its generated id, if none was given).
    """
    # The collection may have been dropped at runtime, so check again here.
    # ``create_collection`` never deletes existing data, unlike the
    # deprecated ``recreate_collection``.
    if not qdrant.collection_exists("exercises"):
        qdrant.create_collection(
            collection_name="exercises",
            vectors_config=VectorParams(
                size=model.get_sentence_embedding_dimension(),
                distance=Distance.COSINE,
            ),
        )
    vec = model.encode(f"{ex.title}. {ex.summary}").tolist()
    point = PointStruct(id=ex.id, vector=vec, payload=ex.dict())
    qdrant.upsert(collection_name="exercises", points=[point])
    return ex
@app.get("/exercise", response_model=List[Exercise])
def list_exercises(
    discipline: Optional[str] = Query(None),
    group: Optional[str] = Query(None),
    tags: Optional[str] = Query(None)
):
    """List stored exercises, optionally filtered.

    Args:
        discipline: Only return exercises with this discipline.
        group: Only return exercises with this group.
        tags: Comma-separated keywords; every keyword must match.

    Returns:
        The matching exercises (at most 10000). An empty list if the
        ``exercises`` collection does not exist.
    """
    # Same behaviour as the plan listing: a missing collection is "no data".
    if not qdrant.collection_exists("exercises"):
        return []

    filters = []
    if discipline:
        filters.append({"key": "discipline", "match": {"value": discipline}})
    if group:
        filters.append({"key": "group", "match": {"value": group}})
    if tags:
        for raw_tag in tags.split(","):
            tag = raw_tag.strip()
            # Skip empty fragments ("a,,b", trailing comma); an empty keyword
            # would otherwise become a filter that matches nothing.
            if tag:
                filters.append({"key": "keywords", "match": {"value": tag}})

    pts, _ = qdrant.scroll(
        collection_name="exercises",
        scroll_filter={"must": filters} if filters else None,
        limit=10000,
    )
    return [Exercise(**pt.payload) for pt in pts]
# ------------------
# Existing endpoints
# ------------------
@app.post("/embed", response_model=EmbedResponse)
def embed_texts(data: EmbedRequest):
    """Embed text chunks and store them in a Qdrant collection.

    The target collection is created on demand. Each chunk is stored as one
    point with a fresh UUID and its metadata as payload.

    Args:
        data: Chunks to embed and the name of the target collection.

    Returns:
        Status text, number of stored points and the collection name.

    Raises:
        HTTPException: 400 if no chunks were supplied.
    """
    # Reject empty input up front instead of sending an empty upsert.
    if not data.chunks:
        raise HTTPException(status_code=400, detail="'chunks' darf nicht leer sein.")

    collection_name = data.collection
    if not qdrant.collection_exists(collection_name):
        # Take the vector size from the model (as the other collections do)
        # rather than hard-coding it, so a model change cannot cause a mismatch.
        qdrant.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(
                size=model.get_sentence_embedding_dimension(),
                distance=Distance.COSINE,
            ),
        )

    embeddings = model.encode([c.text for c in data.chunks]).tolist()
    points = []
    for chunk, vector in zip(data.chunks, embeddings):
        payload = {
            "text": chunk.text,
            "source": chunk.source,
            "source_type": chunk.source_type,
            "title": chunk.title,
            "version": chunk.version,
            "related_to": chunk.related_to,
            "tags": chunk.tags,
            "owner": chunk.owner,
            "context_tag": chunk.context_tag,
            # Default to "now" (UTC) when the client gave no import time.
            "imported_at": chunk.imported_at or datetime.utcnow().isoformat(),
            "chunk_index": chunk.chunk_index,
            # Default the category to the collection name.
            "category": chunk.category or data.collection,
        }
        points.append(PointStruct(id=str(uuid4()), vector=vector, payload=payload))
    qdrant.upsert(collection_name=collection_name, points=points)
    return {"status":"✅ embeddings saved","count":len(points),"collection":collection_name}
@app.get("/search", response_model=List[SearchResultItem])
def search_text(query: str = Query(..., min_length=1), limit: int = Query(3, ge=1), collection: str = Query("default")):
    """Run a semantic search for ``query`` in a collection.

    Args:
        query: Search text (at least one character).
        limit: Maximum number of hits (at least 1).
        collection: Qdrant collection to search in.

    Returns:
        The hits with score and stored text. Hits whose payload has no
        ``text`` entry are returned with an empty text.

    Raises:
        HTTPException: 404 if the collection does not exist.
    """
    # Same 404 as the delete endpoints instead of a generic 500 from Qdrant.
    if not qdrant.collection_exists(collection):
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")
    vec = model.encode(query).tolist()
    res = qdrant.search(collection_name=collection, query_vector=vec, limit=limit)
    # Not every collection stores a "text" payload (e.g. exercises), so read
    # it defensively instead of failing with a KeyError.
    return [
        SearchResultItem(score=r.score, text=(r.payload or {}).get("text") or "")
        for r in res
    ]
@app.post("/prompt", response_model=PromptResponse)
def prompt(data: PromptRequest):
    """Answer a question with the LLM, using search hits as context.

    The most similar chunks from the collection are joined into a context
    block and sent together with the question to the Ollama endpoint.

    Args:
        data: Question, number of context hits and collection name.

    Returns:
        The LLM answer, the context that was used and the collection name.

    Raises:
        HTTPException: 400 for an empty query or a ``context_limit`` outside
            1..10, 404 if the collection does not exist, 502 if the LLM
            service fails or returns an unreadable answer.
    """
    if not data.query.strip():
        raise HTTPException(status_code=400, detail="'query' darf nicht leer sein.")
    if not (1 <= data.context_limit <= 10):
        raise HTTPException(status_code=400, detail="'context_limit' muss zwischen 1 und 10 liegen.")
    # Same 404 as the delete endpoints instead of a generic 500 from Qdrant.
    if not qdrant.collection_exists(data.collection):
        raise HTTPException(status_code=404, detail=f"Collection '{data.collection}' nicht gefunden.")

    hits = qdrant.search(
        collection_name=data.collection,
        query_vector=model.encode(data.query).tolist(),
        limit=data.context_limit,
    )
    # Payloads without a "text" entry contribute an empty line instead of a KeyError.
    context = "\n".join((h.payload or {}).get("text") or "" for h in hits)
    payload = {
        "model": OLLAMA_MODEL,
        "prompt": f"Context:\n{context}\nQuestion: {data.query}",
        "stream": False,
    }
    try:
        r = requests.post(OLLAMA_URL, json=payload, timeout=30)
        r.raise_for_status()
        # Parse inside the try block: an unreadable body is an LLM service
        # failure (502) as well, not an internal error.
        answer = r.json().get("response", "")
    except Exception as exc:
        raise HTTPException(status_code=502, detail="LLM-Service-Fehler.") from exc
    return PromptResponse(answer=answer, context=context, collection=data.collection)
@app.delete("/delete-source", response_model=DeleteResponse)
def delete_by_source(
    collection: str = Query(...),
    source: Optional[str] = Query(None),
    type: Optional[str] = Query(None),
    owner: Optional[str] = Query(None),
    category: Optional[str] = Query(None)
):
    """Delete all points of a collection that match the given filters.

    Args:
        collection: Collection to delete from.
        source: Match on the ``source`` payload field.
        type: Match on the ``type`` payload field.
        owner: Match on the ``owner`` payload field.
        category: Match on the ``category`` payload field.

    Returns:
        Status text, number of deleted points and the collection name.

    Raises:
        HTTPException: 404 if the collection does not exist, 400 if no
            filter parameter was given.
    """
    if not qdrant.collection_exists(collection):
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")

    # NOTE(review): the embed endpoint stores "source_type", not "type";
    # confirm which payload key the ``type`` filter is meant to match.
    candidates = {"source": source, "type": type, "owner": owner, "category": category}
    filt = [{"key": key, "match": {"value": value}} for key, value in candidates.items() if value]
    if not filt:
        raise HTTPException(status_code=400, detail="Mindestens ein Filterparameter muss angegeben werden.")

    # Collect the ids page by page so that more than one page of matches is
    # deleted as well. Only ids are needed, so payloads and vectors are skipped.
    ids = []
    offset = None
    while True:
        pts, offset = qdrant.scroll(
            collection_name=collection,
            scroll_filter={"must": filt},
            limit=10000,
            offset=offset,
            with_payload=False,
            with_vectors=False,
        )
        # Keep ids in their native type: integer ids turned into strings are
        # no longer valid point ids.
        ids.extend(p.id for p in pts)
        if offset is None:
            break

    if not ids:
        return DeleteResponse(status="🔍 Keine passenden Einträge gefunden.", count=0, collection=collection)
    qdrant.delete(collection_name=collection, points_selector=PointIdsList(points=ids))
    return DeleteResponse(status="🗑️ gelöscht", count=len(ids), collection=collection)
@app.delete("/delete-collection", response_model=DeleteResponse)
def delete_collection(collection: str = Query(...)):
    """Drop an entire collection.

    Args:
        collection: Name of the collection to delete.

    Returns:
        Status text with ``count=0`` and the collection name.

    Raises:
        HTTPException: 404 if the collection does not exist.
    """
    exists = qdrant.collection_exists(collection)
    if exists:
        qdrant.delete_collection(collection_name=collection)
        return DeleteResponse(status="🗑️ gelöscht", count=0, collection=collection)
    raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")
# ----------------------------
# Endpoints for training plans
# ----------------------------
@app.post("/plan", response_model=TrainingPlan)
def create_plan(plan: TrainingPlan):
    """Store a training plan in the plan collection.

    The embedding is computed from the plan title and short description; the
    full plan is stored as payload under the plan id.

    Args:
        plan: The training plan to store.

    Returns:
        The stored plan (including generated id and creation time).
    """
    # The collection may have been dropped at runtime, so check again here.
    # ``create_collection`` never deletes existing data, unlike the
    # deprecated ``recreate_collection``.
    if not qdrant.collection_exists(PLAN_COLL):
        qdrant.create_collection(
            collection_name=PLAN_COLL,
            vectors_config=VectorParams(
                size=model.get_sentence_embedding_dimension(),
                distance=Distance.COSINE,
            ),
        )
    vec = model.encode(f"{plan.title}. {plan.short_description}").tolist()
    # NOTE(review): the payload contains date/datetime objects; confirm the
    # Qdrant client serializes them in the format the plan listing expects.
    payload = plan.dict()
    qdrant.upsert(
        collection_name=PLAN_COLL,
        points=[PointStruct(id=plan.id, vector=vec, payload=payload)],
    )
    return plan
@app.get("/plan", response_model=List[TrainingPlan])
def list_plans(
    collection: str = Query(PLAN_COLL),
    discipline: Optional[str] = Query(None),
    group: Optional[str] = Query(None),
    dojo: Optional[str] = Query(None)
):
    """List stored training plans, optionally filtered.

    Up to 10000 points are read from the collection and filtered in Python;
    a filter that is not given (or empty) is ignored.

    Args:
        collection: Collection to read the plans from.
        discipline: Only return plans with this discipline.
        group: Only return plans with this group.
        dojo: Only return plans with this dojo.

    Returns:
        The matching plans; an empty list if the collection does not exist.
    """
    if not qdrant.collection_exists(collection):
        return []

    def _keep(candidate: TrainingPlan) -> bool:
        # A plan passes when every supplied filter equals its field.
        return (
            (not discipline or candidate.discipline == discipline)
            and (not group or candidate.group == group)
            and (not dojo or candidate.dojo == dojo)
        )

    stored, _ = qdrant.scroll(collection_name=collection, limit=10000)
    # Every payload is parsed before it is filtered, exactly in stored order.
    parsed = (TrainingPlan(**point.payload) for point in stored)
    return [candidate for candidate in parsed if _keep(candidate)]