#!/usr/bin/env python3
# llm_api.py — Version 1.1.11
|
|
|
|
from fastapi import FastAPI, Query, HTTPException, Request
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import BaseModel, Field
|
|
from typing import List, Dict, Any, Optional
|
|
from sentence_transformers import SentenceTransformer
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.models import VectorParams, Distance, PointStruct, PointIdsList
|
|
from uuid import uuid4
|
|
import requests
|
|
import os
|
|
from datetime import datetime, date
|
|
|
|
# Version bumped
__version__ = "1.1.11"
# Startup trace: shows which version was imported and from which file.
print(
    f"[DEBUG] llm_api.py version {__version__} loaded from {__file__}",
    flush=True,
)
|
|
|
|
# FastAPI application; the interactive docs and the OpenAPI schema are served
# under the explicit paths given below.
app = FastAPI(
    title="KI Trainerassistent API",
    description="Lokale API für Trainingsplanung",
    version=__version__,
    docs_url="/docs",
    redoc_url="/redoc",
    openapi_url="/openapi.json",
)
|
|
|
|
# Global error handler
@app.exception_handler(Exception)
async def unicorn_exception_handler(request: Request, exc: Exception):
    """Answer any exception routed to this handler with a generic JSON 500.

    Neither ``request`` nor ``exc`` is inspected, so the response body never
    contains details about the failure.
    """
    body = {"detail": "Interner Serverfehler. Bitte später erneut versuchen."}
    return JSONResponse(status_code=500, content=body)
|
|
|
|
# --------------------------------
# Models for embed/search
# --------------------------------
|
|
# One text chunk as accepted by POST /embed. `text` is what gets embedded;
# the complete model is stored as the payload of the resulting point.
class ChunkInput(BaseModel):
    # Required: the chunk content and where it came from.
    text: str
    source: str

    # Descriptive metadata, empty unless the client supplies it.
    source_type: str = ""
    title: str = ""
    version: str = ""
    related_to: str = ""
    tags: List[str] = []
    owner: str = ""

    # Optional metadata; `imported_at` is filled in by /embed when missing.
    context_tag: Optional[str] = None
    imported_at: Optional[str] = None
    chunk_index: Optional[int] = None
    category: Optional[str] = None
|
|
|
|
# Request body of POST /embed: the chunks to store and the target collection.
class EmbedRequest(BaseModel):
    chunks: List[ChunkInput]
    # Collection the chunks are written to; "default" when not given.
    collection: str = "default"
|
|
|
|
# Request body of POST /prompt.
class PromptRequest(BaseModel):
    # The user's question; the endpoint rejects blank strings.
    query: str
    # Number of search hits used as context; the endpoint accepts 1 to 10.
    context_limit: int = 3
    # Collection that is searched for context.
    collection: str = "default"
|
|
|
|
# Response of POST /embed: a status text, the number of stored points and the
# collection they were written to.
class EmbedResponse(BaseModel):
    status: str
    count: int
    collection: str
|
|
|
|
# One hit returned by GET /search.
class SearchResultItem(BaseModel):
    # Similarity score reported by Qdrant. The collections created by this
    # service use cosine distance, whose scores range from -1 to 1, so no
    # lower bound of 0 may be enforced here: a weakly related hit with a
    # negative score would otherwise fail response validation.
    score: float
    # The stored chunk text of the hit.
    text: str
|
|
|
|
# Response of POST /prompt: the LLM answer, the context text that was sent
# along with the question, and the collection the context came from.
class PromptResponse(BaseModel):
    answer: str
    context: str
    collection: str
|
|
|
|
# Response shared by the two delete endpoints.
class DeleteResponse(BaseModel):
    status: str
    # Number of deleted points (0 when a whole collection is dropped).
    count: int
    collection: str
    # Optional echo fields; the delete endpoints in this file leave them unset.
    source: Optional[str] = None
    type: Optional[str] = None
|
|
|
|
# --------------------------------
# Models for exercises & training plans
# --------------------------------
|
|
# A single exercise. It is stored in Qdrant with `id` as the point id and the
# whole model as payload; the embedding is built from `title` and `summary`.
class Exercise(BaseModel):
    # Generated as a random UUID string when the client sends none.
    id: str = Field(default_factory=lambda: str(uuid4()))

    # Descriptive texts
    title: str
    summary: str
    short_description: str
    # Matched by the `tags` filter of GET /exercise.
    keywords: List[str] = []
    link: Optional[str] = None

    # Classification; `discipline` and `group` are filterable in GET /exercise.
    discipline: str
    group: Optional[str] = None
    age_group: str
    target_group: str

    # Organisational data
    min_participants: int
    duration_minutes: int
    # Integer value per capability name — value scale not visible here.
    capabilities: Dict[str, int] = {}
    category: str

    # Free-text sections describing the exercise
    purpose: str
    execution: str
    notes: str
    preparation: str
    method: str
    equipment: List[str] = []
|
|
|
|
# An exercise as it appears inside one phase of a training plan.
class PhaseExercise(BaseModel):
    # presumably the `id` of an Exercise — TODO confirm against callers
    exercise_id: str
    # Free-form load descriptions; their keys are not fixed by this model.
    cond_load: Dict[str, Any] = {}
    coord_load: Dict[str, Any] = {}
    instructions: str = ""
|
|
|
|
# One phase of a training plan together with the exercises used in it.
class PlanPhase(BaseModel):
    name: str
    duration_minutes: int
    method: str
    method_notes: str = ""
    exercises: List[PhaseExercise]
|
|
|
|
# A complete training plan. It is stored in Qdrant with `id` as the point id
# and the whole model as payload; the embedding is built from `title` and
# `short_description`.
class TrainingPlan(BaseModel):
    # Generated as a random UUID string when the client sends none.
    id: str = Field(default_factory=lambda: str(uuid4()))
    title: str
    short_description: str
    collection: str

    # `discipline`, `group` and `dojo` are the filters offered by GET /plan.
    discipline: str
    group: Optional[str] = None
    dojo: str

    date: date
    plan_duration_weeks: int
    focus_areas: List[str] = []
    # presumably the `id` of the plan this one follows — TODO confirm
    predecessor_plan_id: Optional[str] = None
    age_group: str
    # Defaults to the creation time as a naive UTC timestamp.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    phases: List[PlanPhase]
|
|
|
|
# ----------------------------------
# Embedding model and Qdrant client
# ----------------------------------
|
|
# Sentence-embedding model shared by every endpoint that vectorizes text.
model = SentenceTransformer("all-MiniLM-L6-v2")

# Qdrant connection; host and port are read from the environment and fall
# back to a local instance on port 6333.
qdrant = QdrantClient(
    host=os.getenv("QDRANT_HOST", "localhost"),
    port=int(os.getenv("QDRANT_PORT", 6333)),
)
|
|
|
|
# Collection names
EXERCISE_COLL = "exercises"  # points written by POST /exercise
PLAN_COLL = "training_plans"  # points written by POST /plan
|
|
|
|
# Make sure the two built-in collections exist. A missing one is created with
# cosine distance and the vector size reported by the embedding model.
for _coll_name in (EXERCISE_COLL, PLAN_COLL):
    if qdrant.collection_exists(_coll_name):
        continue
    qdrant.recreate_collection(
        collection_name=_coll_name,
        vectors_config=VectorParams(
            size=model.get_sentence_embedding_dimension(),
            distance=Distance.COSINE,
        ),
    )
|
|
|
|
# ----------------------------------
# Endpoints for exercises
# ----------------------------------
|
|
# POST /exercise — embeds "<title>. <summary>", upserts the exercise under its
# own id with the full model as payload, and returns the exercise unchanged.
@app.post("/exercise", response_model=Exercise)
def create_exercise(ex: Exercise):
    embedding_text = f"{ex.title}. {ex.summary}"
    vector = model.encode(embedding_text).tolist()
    stored_point = PointStruct(id=ex.id, vector=vector, payload=ex.dict())
    qdrant.upsert(collection_name=EXERCISE_COLL, points=[stored_point])
    return ex
|
|
|
|
# GET /exercise — lists stored exercises, optionally filtered.
#
# Query parameters (all optional, combined with AND):
#   discipline  exact match on the payload field "discipline"
#   group       exact match on the payload field "group"
#   tags        comma-separated list; every tag must match "keywords"
#
# Returns the matching exercises rebuilt from their payloads. A single scroll
# call with limit=10000 is issued, so at most 10000 exercises are returned.
@app.get("/exercise", response_model=List[Exercise])
def list_exercises(
    discipline: Optional[str] = Query(None),
    group: Optional[str] = Query(None),
    tags: Optional[str] = Query(None),
):
    filters = []
    if discipline:
        filters.append({"key": "discipline", "match": {"value": discipline}})
    if group:
        filters.append({"key": "group", "match": {"value": group}})
    if tags:
        for raw_tag in tags.split(","):
            tag = raw_tag.strip()
            # Skip empty tokens ("a,,b", trailing comma): filtering on an
            # empty keyword would match nothing and hide every exercise.
            if tag:
                filters.append({"key": "keywords", "match": {"value": tag}})

    scroll_kwargs = {"collection_name": EXERCISE_COLL, "limit": 10000}
    if filters:
        scroll_kwargs["scroll_filter"] = {"must": filters}
    pts, _ = qdrant.scroll(**scroll_kwargs)
    return [Exercise(**pt.payload) for pt in pts]
|
|
|
|
# ----------------------------------
# Endpoints for training plans
# ----------------------------------
|
|
# POST /plan — embeds "<title>. <short_description>", upserts the plan under
# its own id with the full model as payload, and returns the plan unchanged.
@app.post("/plan", response_model=TrainingPlan)
def create_plan(plan: TrainingPlan):
    embedding_text = f"{plan.title}. {plan.short_description}"
    vector = model.encode(embedding_text).tolist()
    stored_point = PointStruct(id=plan.id, vector=vector, payload=plan.dict())
    qdrant.upsert(collection_name=PLAN_COLL, points=[stored_point])
    return plan
|
|
|
|
# GET /plan — lists stored training plans. The optional query parameters
# discipline, group and dojo are exact-match filters combined with AND.
# A single scroll call with limit=10000 is issued.
@app.get("/plan", response_model=List[TrainingPlan])
def list_plans(
    discipline: Optional[str] = Query(None),
    group: Optional[str] = Query(None),
    dojo: Optional[str] = Query(None),
):
    requested = (("discipline", discipline), ("group", group), ("dojo", dojo))
    # Only parameters that were actually supplied become conditions.
    conditions = [
        {"key": field, "match": {"value": wanted}}
        for field, wanted in requested
        if wanted
    ]
    if conditions:
        points, _next = qdrant.scroll(
            collection_name=PLAN_COLL,
            scroll_filter={"must": conditions},
            limit=10000,
        )
    else:
        points, _next = qdrant.scroll(collection_name=PLAN_COLL, limit=10000)
    return [TrainingPlan(**point.payload) for point in points]
|
|
|
|
# ----------------------------------
# Endpoints for embed/search and deletion
# ----------------------------------
|
|
# POST /embed — embeds the text of every chunk and stores one point per chunk.
#
# The target collection is created (cosine distance, model vector size) when
# it does not exist yet. Every point gets a fresh UUID and the chunk's fields
# as payload; `imported_at` falls back to the current UTC time in ISO format.
# Returns the status text, the number of stored points and the collection.
@app.post("/embed", response_model=EmbedResponse)
def embed_texts(data: EmbedRequest):
    collection_name = data.collection
    if not qdrant.collection_exists(collection_name):
        qdrant.recreate_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(
                size=model.get_sentence_embedding_dimension(),
                distance=Distance.COSINE,
            ),
        )

    # Nothing to embed or store: skip the model and the upsert call instead of
    # sending an empty update to Qdrant.
    if not data.chunks:
        return EmbedResponse(status="✅ embeddings saved", count=0, collection=collection_name)

    embeddings = model.encode([c.text for c in data.chunks]).tolist()
    points = [
        PointStruct(
            id=str(uuid4()),
            vector=vector,
            payload={
                **chunk.dict(),
                "imported_at": chunk.imported_at or datetime.utcnow().isoformat(),
            },
        )
        for chunk, vector in zip(data.chunks, embeddings)
    ]
    qdrant.upsert(collection_name=collection_name, points=points)
    return EmbedResponse(status="✅ embeddings saved", count=len(points), collection=collection_name)
|
|
|
|
# GET /search — semantic search over one collection.
#
# Query parameters:
#   query       search text, at least one character
#   limit       maximum number of hits, at least 1 (default 3)
#   collection  collection to search (default "default")
#
# Returns one item per hit with its score and stored text.
# Raises 404 when the collection does not exist, matching the delete endpoints.
@app.get("/search", response_model=List[SearchResultItem])
def search_text(
    query: str = Query(..., min_length=1),
    limit: int = Query(3, ge=1),
    collection: str = Query("default"),
):
    if not qdrant.collection_exists(collection):
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")
    vec = model.encode(query).tolist()
    res = qdrant.search(collection_name=collection, query_vector=vec, limit=limit)
    # Points without a "text" payload field (e.g. exercises or plans) yield an
    # empty text instead of a KeyError.
    return [
        SearchResultItem(score=r.score, text=(r.payload or {}).get("text", ""))
        for r in res
    ]
|
|
|
|
# POST /prompt — answers a question with the LLM, using search hits as context.
#
# Steps: validate the request, embed the query, fetch `context_limit` hits
# from the collection, join their texts into the context, and send context
# plus question to the generate endpoint configured via OLLAMA_URL and
# OLLAMA_MODEL.
#
# Raises:
#   400  blank query, or context_limit outside 1..10
#   404  collection does not exist
#   502  the LLM call failed or returned an unreadable body
@app.post("/prompt", response_model=PromptResponse)
def prompt(data: PromptRequest):
    if not data.query.strip():
        raise HTTPException(status_code=400, detail="'query' darf nicht leer sein.")
    if not (1 <= data.context_limit <= 10):
        raise HTTPException(status_code=400, detail="'context_limit' muss zwischen 1 und 10 liegen.")
    if not qdrant.collection_exists(data.collection):
        raise HTTPException(status_code=404, detail=f"Collection '{data.collection}' nicht gefunden.")

    # The embedding method is encode(), as used by /embed and /search; the
    # previous call to a non-existent encrypt() made every request fail.
    query_vector = model.encode(data.query).tolist()
    hits = qdrant.search(
        collection_name=data.collection,
        query_vector=query_vector,
        limit=data.context_limit,
    )
    # Hits without a "text" payload field contribute nothing to the context.
    hit_texts = ((h.payload or {}).get("text") for h in hits)
    context = "\n".join(text for text in hit_texts if text)

    try:
        r = requests.post(
            os.getenv("OLLAMA_URL", "http://localhost:11434/api/generate"),
            json={
                "model": os.getenv("OLLAMA_MODEL", "mistral:latest"),
                "prompt": f"Context:\n{context}\nQuestion: {data.query}",
                "stream": False,
            },
            timeout=30,
        )
        r.raise_for_status()
        # Parsed inside the try so that a non-JSON body is reported as 502 too.
        answer = r.json().get("response", "")
    except Exception as exc:
        # Deliberately broad: any failure of the upstream call maps to 502.
        raise HTTPException(status_code=502, detail="LLM-Service-Fehler.") from exc
    return PromptResponse(answer=answer, context=context, collection=data.collection)
|
|
|
|
# DELETE /delete-source — deletes the points of a collection that match ALL of
# the supplied filters (source, type, owner, category).
#
# Raises 404 when the collection does not exist and 400 when no filter is
# given. Matching ids are collected with one scroll call (limit=10000) and
# then deleted by id; the response reports how many ids were deleted.
#
# NOTE(review): chunks written by /embed carry "source_type", not "type" —
# confirm that filtering on the payload key "type" is intended.
@app.delete("/delete-source", response_model=DeleteResponse)
def delete_by_source(
    collection: str = Query(...),
    source: Optional[str] = Query(None),
    type: Optional[str] = Query(None),
    owner: Optional[str] = Query(None),
    category: Optional[str] = Query(None),
):
    if not qdrant.collection_exists(collection):
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")

    requested = (
        ("source", source),
        ("type", type),
        ("owner", owner),
        ("category", category),
    )
    # Only parameters that were actually supplied become conditions.
    conditions = [
        {"key": field, "match": {"value": wanted}}
        for field, wanted in requested
        if wanted
    ]
    if not conditions:
        raise HTTPException(status_code=400, detail="Mindestens ein Filterparameter muss angegeben werden.")

    matches, _next = qdrant.scroll(
        collection_name=collection,
        scroll_filter={"must": conditions},
        limit=10000,
    )
    ids = [str(match.id) for match in matches]
    if ids:
        qdrant.delete(collection_name=collection, points_selector=PointIdsList(points=ids))
        return DeleteResponse(status="🗑️ gelöscht", count=len(ids), collection=collection)
    return DeleteResponse(status="🔍 Keine passenden Einträge gefunden.", count=0, collection=collection)
|
|
|
|
# DELETE /delete-collection — drops a whole collection.
# Raises 404 when the collection does not exist; the response count is always 0.
@app.delete("/delete-collection", response_model=DeleteResponse)
def delete_collection(collection: str = Query(...)):
    collection_known = qdrant.collection_exists(collection)
    if not collection_known:
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")
    qdrant.delete_collection(collection_name=collection)
    return DeleteResponse(status="🗑️ gelöscht", count=0, collection=collection)
|
|
|
|
# ----------------------------------------------------------------
# MediaWiki login (v1.1.11)
# ----------------------------------------------------------------
|
|
# MediaWiki connection settings, overridable through the environment.
MEDIAWIKI_API_URL = os.getenv("MEDIAWIKI_API_URL", "https://www.Karatetrainer.de/api.php")
MEDIAWIKI_USER = os.getenv("MEDIAWIKI_USER", "LarsS@APIBot")
# SECURITY: the bot password is read from the environment only. A password
# that was ever committed as a default value in this file has to be treated
# as leaked and rotated.
MEDIAWIKI_PASSWORD = os.getenv("MEDIAWIKI_PASSWORD", "")
# Shared session: requests.Session keeps the cookies set during login.
wiki_session = requests.Session()
# Seconds to wait for the wiki before giving up; without a timeout a hanging
# wiki would block the request indefinitely.
_WIKI_TIMEOUT_SECONDS = 30


# POST /import/wiki/login — logs the shared session in to MediaWiki.
#
# Fetches a login token, then posts user name, password and token to the
# login action.
#
# Returns a status dict on success.
# Raises:
#   500  no password is configured
#   401  the wiki rejected the login (its reason is passed on)
#   502  the wiki could not be reached or sent an unexpected response
#
# NOTE(review): the blocking `requests` calls run inside an async handler and
# therefore hold up the event loop while waiting for the wiki.
@app.post("/import/wiki/login")
async def import_wiki_login():
    if not MEDIAWIKI_PASSWORD:
        raise HTTPException(status_code=500, detail="MEDIAWIKI_PASSWORD ist nicht konfiguriert.")

    try:
        params_token = {"action": "query", "meta": "tokens", "type": "login", "format": "json"}
        resp1 = wiki_session.get(MEDIAWIKI_API_URL, params=params_token, timeout=_WIKI_TIMEOUT_SECONDS)
        resp1.raise_for_status()
        token = resp1.json()["query"]["tokens"]["logintoken"]

        login_params = {"action": "login", "format": "json"}
        login_data = {"lgname": MEDIAWIKI_USER, "lgpassword": MEDIAWIKI_PASSWORD, "lgtoken": token}
        resp2 = wiki_session.post(
            MEDIAWIKI_API_URL,
            params=login_params,
            data=login_data,
            timeout=_WIKI_TIMEOUT_SECONDS,
        )
        resp2.raise_for_status()
        result = resp2.json().get("login", {})
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"Fehler bei Wiki-API-Aufruf: {str(e)}") from e
    except (KeyError, TypeError, ValueError, AttributeError) as e:
        # The wiki answered, but not with the JSON structure expected above.
        raise HTTPException(status_code=502, detail="Unerwartete Antwort der Wiki-API.") from e

    if result.get("result") == "Success":
        return {"status": "✅ MediaWiki login erfolgreich."}
    raise HTTPException(
        status_code=401,
        detail=f"Login fehlgeschlagen: {result.get('reason','unbekannter Fehler')}",
    )
|