# Trainer_LLM/llm-api/archiv/llm_apiV20.py
#
# 168 lines
# 6.6 KiB
# Python

from fastapi import FastAPI, Query, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.openapi.utils import get_openapi
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct, PointIdsList
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
from uuid import uuid4
import requests
import os
from datetime import datetime
# Version number of this API (incremented with each revision).
__version__ = "1.0.20"

# Startup trace so the log shows which file and version were actually imported.
print(
    "[DEBUG] llm_api.py version {} loaded from {}".format(__version__, __file__),
    flush=True,
)

# Ollama configuration; both values can be overridden via environment variables.
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434/api/generate")
OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "mistral:latest")
# FastAPI application instance. Interactive docs are served at /docs (Swagger UI)
# and /redoc; the raw schema at /openapi.json.
app = FastAPI(
    title="KI Trainerassistent API",
    version=__version__,
    description="Lokale API für Karate- & Gewaltschutz-Trainingsplanung",
    openapi_url="/openapi.json",
    docs_url="/docs",
    redoc_url="/redoc",
)
# Global fallback error handler.
@app.exception_handler(Exception)
async def unicorn_exception_handler(request: Request, exc: Exception):
    """Convert an exception routed to this handler into a generic HTTP 500.

    The client only receives a generic error message. The exception itself is
    logged (message plus traceback) so that server errors can be diagnosed.

    Args:
        request: The request during which ``exc`` was raised.
        exc: The exception to report.

    Returns:
        A ``JSONResponse`` with status 500 and a generic ``detail`` message.
    """
    import traceback

    # Log instead of swallowing: without this, every 500 is undiagnosable.
    # Uses the same "[ERROR] ..." print style as the endpoints below.
    print(
        f"[ERROR] Unhandled exception on {request.method} {request.url.path}: {exc!r}",
        flush=True,
    )
    traceback.print_exception(type(exc), exc, exc.__traceback__)
    return JSONResponse(
        status_code=500,
        content={"detail": "Interner Serverfehler. Bitte später erneut versuchen."},
    )
# Data models

# A single text chunk submitted to /embed.
# `text` is the string that gets embedded; `source` identifies where it came from.
# The remaining fields are optional metadata with empty/None defaults.
class ChunkInput(BaseModel):
    text: str
    source: str

    # Optional string metadata (empty string when not supplied).
    source_type: str = ""
    title: str = ""
    version: str = ""
    related_to: str = ""
    tags: List[str] = []
    owner: str = ""

    # Optional metadata that stays None when not supplied.
    context_tag: Optional[str] = None
    imported_at: Optional[str] = None
    chunk_index: Optional[int] = None
    category: Optional[str] = None
# Request body for /embed: the chunks to store and the target Qdrant collection.
class EmbedRequest(BaseModel):
    chunks: List[ChunkInput]
    collection: str = "default"
# Request body for /prompt.
class PromptRequest(BaseModel):
    query: str  # question sent to the LLM
    context_limit: int = 3  # how many stored chunks are retrieved as context
    collection: str = "default"  # Qdrant collection searched for context
# Response of /embed.
class EmbedResponse(BaseModel):
    status: str
    count: int  # number of points upserted
    collection: str
# One hit returned by /search.
class SearchResultItem(BaseModel):
    # Similarity score exactly as reported by Qdrant. No lower bound is enforced:
    # collections are created with Distance.COSINE, whose score is a cosine
    # similarity in [-1, 1]. A `ge=0` constraint would make any negatively
    # scored hit fail response validation and turn the whole search into a 500.
    score: float
    text: str
# Response of /prompt: the LLM answer plus the retrieved context that was sent with it.
class PromptResponse(BaseModel):
    answer: str
    context: str
    collection: str
# Response shared by the delete endpoints.
# `source` and `type` echo the /delete-source filter and stay None for /delete-collection.
class DeleteResponse(BaseModel):
    status: str
    count: int  # number of deleted points (always 0 for /delete-collection)
    collection: str
    source: Optional[str] = None
    type: Optional[str] = None
# Embedding model and Qdrant client; the Qdrant host/port come from the environment.
model = SentenceTransformer("all-MiniLM-L6-v2")
qdrant = QdrantClient(
    host=os.environ.get("QDRANT_HOST", "localhost"),
    port=int(os.environ.get("QDRANT_PORT", 6333)),
)
def _chunk_payload(chunk: ChunkInput) -> Dict[str, Any]:
    """Return every field of *chunk* as a plain dict for use as a Qdrant payload."""
    # pydantic v2 exposes model_dump(); v1 only has dict().
    if hasattr(chunk, "model_dump"):
        return chunk.model_dump()
    return chunk.dict()


# /embed
@app.post("/embed", response_model=EmbedResponse)
def embed_texts(data: EmbedRequest):
    """Embed the given chunks and upsert them into a Qdrant collection.

    The collection is created on first use, with the embedding model's vector
    size and cosine distance. Each chunk becomes one point with a fresh UUID.
    All of the chunk's fields (``text``, ``source`` and the optional metadata)
    are stored as the point payload.

    Args:
        data: The chunks to embed and the name of the target collection.

    Returns:
        EmbedResponse with the number of stored points and the collection name.

    Raises:
        HTTPException: 400 if ``chunks`` is empty.
    """
    if not data.chunks:
        raise HTTPException(status_code=400, detail="'chunks' darf nicht leer sein.")
    coll = data.collection
    if not qdrant.collection_exists(coll):
        # Use create_collection, not recreate_collection: recreate first drops an
        # existing collection. Two concurrent first requests could therefore
        # wipe each other's freshly upserted points.
        try:
            qdrant.create_collection(
                collection_name=coll,
                vectors_config=VectorParams(
                    size=model.get_sentence_embedding_dimension(),
                    distance=Distance.COSINE,
                ),
            )
        except Exception:
            # Another request may have created it in the meantime. Only
            # propagate the error if the collection still does not exist.
            if not qdrant.collection_exists(coll):
                raise
    embeddings = model.encode([c.text for c in data.chunks]).tolist()
    # Store the full chunk (not just text/source) so the metadata accepted by
    # ChunkInput is not silently discarded.
    points = [
        PointStruct(id=str(uuid4()), vector=vector, payload=_chunk_payload(chunk))
        for chunk, vector in zip(data.chunks, embeddings)
    ]
    qdrant.upsert(collection_name=coll, points=points)
    return EmbedResponse(status="✅ Saved", count=len(points), collection=coll)
# /search
@app.get("/search", response_model=List[SearchResultItem])
def search_text(
    query: str = Query(..., min_length=1),
    limit: int = Query(3, ge=1),
    collection: str = Query("default"),
):
    """Return the ``limit`` stored chunks most similar to ``query``.

    Args:
        query: Free-text search query (non-empty).
        limit: Maximum number of hits to return (>= 1).
        collection: Qdrant collection to search.

    Returns:
        A list of hits, each with Qdrant's score and the stored ``text`` payload.

    Raises:
        HTTPException: 404 if the collection does not exist.
    """
    # Explicit 404, consistent with the delete endpoints. Without this check a
    # missing collection surfaces as a generic 500 from the Qdrant client.
    if not qdrant.collection_exists(collection):
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")
    vec = model.encode(query).tolist()
    hits = qdrant.search(collection_name=collection, query_vector=vec, limit=limit)
    return [SearchResultItem(score=hit.score, text=hit.payload['text']) for hit in hits]
# /prompt
@app.post("/prompt", response_model=PromptResponse)
def prompt(data: PromptRequest):
    """Answer ``data.query`` with the Ollama LLM, using retrieved chunks as context.

    The ``context_limit`` most similar chunks are fetched from Qdrant, joined
    with newlines, and sent to Ollama with the question in a single
    non-streaming request.

    Args:
        data: The question, the number of context chunks (1..10) and the collection.

    Returns:
        PromptResponse with the LLM answer, the context used and the collection.

    Raises:
        HTTPException:
            400 for an empty query or a ``context_limit`` outside 1..10.
            404 if the collection does not exist.
            502 if the Ollama request fails or its reply is not valid JSON.
    """
    if not data.query.strip():
        raise HTTPException(status_code=400, detail="'query' darf nicht leer sein.")
    if not (1 <= data.context_limit <= 10):
        raise HTTPException(status_code=400, detail="'context_limit' muss zwischen 1 und 10 liegen.")
    if not qdrant.collection_exists(data.collection):
        raise HTTPException(status_code=404, detail=f"Collection '{data.collection}' nicht gefunden.")

    hits = qdrant.search(
        collection_name=data.collection,
        query_vector=model.encode(data.query).tolist(),
        limit=data.context_limit,
    )
    context = '\n'.join(h.payload['text'] for h in hits)
    payload = {
        'model': OLLAMA_MODEL,
        'prompt': f"Context:\n{context}\nQuestion: {data.query}",
        'stream': False,
    }
    try:
        r = requests.post(OLLAMA_URL, json=payload, timeout=30)
        r.raise_for_status()
        # Parse inside the try: a non-JSON reply is an upstream failure (502),
        # not an internal server error.
        answer = r.json().get('response', '')
    except (requests.RequestException, ValueError) as exc:
        print(f"[ERROR] LLM request failed: {exc}", flush=True)
        raise HTTPException(status_code=502, detail="LLM-Service-Fehler.") from exc
    return PromptResponse(answer=answer, context=context, collection=data.collection)
# /delete-source (new routine following the original, working scroll-then-delete logic)
@app.delete("/delete-source", response_model=DeleteResponse)
def delete_by_source(
    collection: str = Query(...),
    source: str = Query(...),
    type: Optional[str] = Query(None),
):
    """Delete every point whose payload ``source`` (and optionally ``type``) matches.

    Matching point IDs are collected with paginated ``scroll`` calls and then
    removed in a single ``delete`` call, so that the response can report how
    many points were deleted.

    Args:
        collection: Name of the Qdrant collection.
        source: Value the payload field ``source`` must equal.
        type: If given (and non-empty), the payload field ``type`` must equal it too.

    Returns:
        DeleteResponse with the number of deleted points (0 if nothing matched).

    Raises:
        HTTPException: 404 if the collection does not exist, 500 if scrolling
            or deleting fails.
    """
    if not qdrant.collection_exists(collection):
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")

    # Filter conditions, built with the typed Qdrant models.
    conditions = [FieldCondition(key="source", match=MatchValue(value=source))]
    if type:
        conditions.append(FieldCondition(key="type", match=MatchValue(value=type)))
    match_filter = Filter(must=conditions)

    # Collect IDs page by page. One scroll call returns at most `limit` points,
    # so keep following the returned offset until it is None. Otherwise every
    # match beyond the first page would silently survive the delete.
    point_ids = []
    offset = None
    try:
        while True:
            points, offset = qdrant.scroll(
                collection_name=collection,
                scroll_filter=match_filter,
                limit=10000,
                offset=offset,
                with_payload=False,  # only IDs are needed
                with_vectors=False,
            )
            # Keep IDs in their native type: converting integer IDs to strings
            # would make Qdrant reject them.
            point_ids.extend(pt.id for pt in points)
            if offset is None:
                break
    except Exception as exc:
        print(f"[ERROR] Scroll failed: {exc}", flush=True)
        raise HTTPException(status_code=500, detail="Fehler beim Abrufen der Punkte vor dem Löschen.") from exc

    if not point_ids:
        return DeleteResponse(status="🔍 Keine passenden Einträge gefunden.", count=0, collection=collection, source=source, type=type)

    # Delete by explicit ID list (PointIdsList(points=...)).
    try:
        qdrant.delete(
            collection_name=collection,
            points_selector=PointIdsList(points=point_ids),
        )
    except Exception as exc:
        print(f"[ERROR] Delete failed: {exc}", flush=True)
        raise HTTPException(status_code=500, detail="Fehler beim Löschen nach Source.") from exc
    return DeleteResponse(status="🗑️ gelöscht", count=len(point_ids), collection=collection, source=source, type=type)
# /delete-collection
# Drops an entire Qdrant collection. Responds with 404 if it does not exist.
# `count` is always 0 here: the number of removed points is not determined.
@app.delete("/delete-collection", response_model=DeleteResponse)
def delete_collection(collection: str = Query(...)):
    found = qdrant.collection_exists(collection)
    if found:
        qdrant.delete_collection(collection_name=collection)
        return DeleteResponse(status="🗑️ gelöscht", count=0, collection=collection)
    raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")