# Trainer_LLM/llm-api/embed_router.py
#
# (Source-viewer metadata: 128 lines, 4.9 KiB, Python)

# Kommentarzeile 2
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from typing import List, Optional
from uuid import uuid4
from clients import model, qdrant
from qdrant_client.models import PointStruct, VectorParams, Distance, PointIdsList
import requests, os
# Router that the endpoint functions in this module register themselves on;
# presumably mounted by the application via include_router — verify.
router: APIRouter = APIRouter()
# Models
class ChunkInput(BaseModel):
    """One text chunk to embed, together with the metadata stored as its payload."""

    # Required: the text that gets embedded and where it came from.
    text: str
    source: str

    # Optional descriptive metadata; an empty string means "not set".
    source_type: str = ""
    title: str = ""
    version: str = ""
    related_to: str = ""
    # Pydantic copies mutable defaults per instance, so this list is not shared.
    tags: List[str] = []
    owner: str = ""

    # Optional metadata where None means "not provided".
    context_tag: Optional[str] = None
    imported_at: Optional[str] = None
    chunk_index: Optional[int] = None
    category: Optional[str] = None

class EmbedRequest(BaseModel):
    """Request body for /embed: the chunks to store and the target collection."""

    # Chunks to embed; each becomes one Qdrant point.
    chunks: List[ChunkInput]
    # Qdrant collection to write into (created on first use by /embed).
    collection: str = "default"

class PromptRequest(BaseModel):
    """Request body for /prompt: a query plus retrieval settings."""

    # Free-text question; the /prompt endpoint rejects whitespace-only values.
    query: str = Field(..., description="Suchanfrage")
    # Number of retrieved chunks used as LLM context, bounded to 1..10.
    context_limit: int = Field(
        default=3,
        ge=1,
        le=10,
        description="Anzahl Kontext-Dokumente",
    )
    # Qdrant collection searched for context.
    collection: str = Field(
        default="default",
        description="Qdrant-Collection",
    )

class PromptResponse(BaseModel):
    """Response body for /prompt."""

    # Text returned by the LLM service.
    answer: str
    # Retrieved chunk texts, newline-joined, that were sent as context.
    context: str
    # Collection the context was retrieved from.
    collection: str

class DeleteResponse(BaseModel):
    """Response body shared by the delete endpoints."""

    # Human-readable outcome message.
    status: str
    # Number of points deleted (0 when nothing matched or not counted).
    count: int
    # Collection the operation targeted.
    collection: str
    # Optional echo of the filters used; the name `type` is part of the
    # response schema, hence the builtin-shadowing field name.
    source: Optional[str] = None
    type: Optional[str] = None

# Endpoints
@router.post("/embed")
def embed_texts(data: EmbedRequest):
    """Embed the request's chunks and upsert them into a Qdrant collection.

    The target collection is created on first use with cosine distance and a
    vector size equal to the embedding model's output dimension. Every chunk
    becomes one point whose id is a random UUID4 string and whose payload is
    the chunk's full field dict.

    Args:
        data: Chunks to embed and the name of the target collection.

    Returns:
        dict with ``status``, the number of stored points (``count``) and the
        ``collection`` name.
    """
    if not qdrant.collection_exists(data.collection):
        try:
            # create_collection rather than recreate_collection: the latter
            # first drops any existing collection, so a concurrent request
            # that created (and filled) it after our existence check would
            # silently lose its points.
            qdrant.create_collection(
                collection_name=data.collection,
                vectors_config=VectorParams(
                    size=model.get_sentence_embedding_dimension(),
                    distance=Distance.COSINE,
                ),
            )
        except Exception:
            # Losing the creation race to another request is fine; any other
            # failure is a real error and is re-raised.
            if not qdrant.collection_exists(data.collection):
                raise

    # Nothing to encode or store: skip the model and upsert calls entirely.
    if not data.chunks:
        return {"status": "✅ embeddings saved", "count": 0, "collection": data.collection}

    embeddings = model.encode([chunk.text for chunk in data.chunks]).tolist()
    points = [
        PointStruct(id=str(uuid4()), vector=emb, payload=chunk.dict())
        for emb, chunk in zip(embeddings, data.chunks)
    ]
    qdrant.upsert(collection_name=data.collection, points=points)
    return {"status": "✅ embeddings saved", "count": len(points), "collection": data.collection}

@router.get("/search")
def search_text(
    query: str = Query(..., min_length=1),
    limit: int = Query(3, ge=1),
    collection: str = Query("default"),
):
    """Return the ``limit`` nearest stored chunks to ``query``.

    Args:
        query: Non-empty search text; it is embedded with the same model
            used by /embed.
        limit: Maximum number of hits to return (at least 1).
        collection: Qdrant collection to search.

    Returns:
        list of dicts with the hit ``score`` and its stored ``text`` (empty
        string when the point has no text payload).

    Raises:
        HTTPException: 404 if the collection does not exist, matching the
            delete endpoints, instead of letting the Qdrant client's error
            propagate.
    """
    if not qdrant.collection_exists(collection):
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")
    vec = model.encode(query).tolist()
    res = qdrant.search(collection_name=collection, query_vector=vec, limit=limit)
    # A hit's payload is Optional in the client model; treat None as empty.
    return [{"score": r.score, "text": (r.payload or {}).get("text", "")} for r in res]

@router.post("/prompt", response_model=PromptResponse)
def prompt(data: PromptRequest):
    """Answer a query with the configured LLM, using retrieved chunks as context.

    The ``context_limit`` nearest chunks from ``data.collection`` are joined
    with newlines and sent, together with the query, to the LLM endpoint in
    the ``OLLAMA_URL`` environment variable (model from ``OLLAMA_MODEL``).

    Args:
        data: Query, number of context chunks, and collection to search.

    Returns:
        PromptResponse with the LLM's ``response`` text, the context used,
        and the collection name.

    Raises:
        HTTPException: 400 for an empty/whitespace query; 500 when
            ``OLLAMA_URL`` is not set; 404 when the collection does not
            exist; 502 when the LLM request fails or returns invalid JSON.
    """
    if not data.query.strip():
        raise HTTPException(status_code=400, detail="'query' darf nicht leer sein.")

    # Check configuration before spending time on embedding and search.
    llm_url = os.getenv("OLLAMA_URL")
    if not llm_url:
        raise HTTPException(status_code=500, detail="LLM-Service-URL nicht konfiguriert.")

    if not qdrant.collection_exists(data.collection):
        raise HTTPException(status_code=404, detail=f"Collection '{data.collection}' nicht gefunden.")

    hits = qdrant.search(
        collection_name=data.collection,
        query_vector=model.encode(data.query).tolist(),
        limit=data.context_limit,
    )
    # A hit's payload is Optional in the client model; treat None as empty.
    context = "\n".join((h.payload or {}).get("text", "") for h in hits)

    payload = {
        # NOTE(review): OLLAMA_MODEL is not validated; if unset, null is sent
        # and the LLM service decides how to handle it.
        "model": os.getenv("OLLAMA_MODEL"),
        "prompt": f"Context:\n{context}\nQuestion: {data.query}",
        "stream": False,
    }
    try:
        r = requests.post(llm_url, json=payload, timeout=30)
        r.raise_for_status()
        # Parse inside the try so a non-JSON reply becomes a 502, not a 500.
        body = r.json()
    except (requests.RequestException, ValueError) as e:
        raise HTTPException(status_code=502, detail=f"LLM-Service-Fehler: {e}") from e

    answer = body.get("response", "") if isinstance(body, dict) else ""
    return PromptResponse(answer=answer, context=context, collection=data.collection)

@router.delete("/delete-source", response_model=DeleteResponse)
def delete_by_source(
    collection: str = Query(...),
    source: Optional[str] = Query(None),
    type: Optional[str] = Query(None),
):
    """Delete every point in ``collection`` whose payload matches all filters.

    Args:
        collection: Qdrant collection to delete from.
        source: Match points whose ``source`` payload equals this value.
        type: Match points whose ``type`` payload equals this value.

    Returns:
        DeleteResponse with the number of deleted points and the filters used.

    Raises:
        HTTPException: 404 if the collection does not exist; 400 if neither
            ``source`` nor ``type`` is given.
    """
    if not qdrant.collection_exists(collection):
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")

    filt = []
    if source:
        filt.append({"key": "source", "match": {"value": source}})
    if type:
        # NOTE(review): /embed stores chunk metadata under "source_type", not
        # "type", so this only matches points written by other code paths
        # that set a "type" payload key — confirm the intended key.
        filt.append({"key": "type", "match": {"value": type}})
    if not filt:
        raise HTTPException(status_code=400, detail="Mindestens ein Filterparameter muss angegeben werden.")

    # Page through all matches. A single scroll call returns at most `limit`
    # points, so larger match sets would otherwise be only partly deleted.
    ids = []
    offset = None
    while True:
        pts, offset = qdrant.scroll(
            collection_name=collection,
            scroll_filter={"must": filt},
            limit=10000,
            offset=offset,
            with_payload=False,  # only ids are needed
            with_vectors=False,
        )
        # Keep ids as returned; str() would turn integer ids into invalid
        # UUID strings.
        ids.extend(p.id for p in pts)
        if offset is None:
            break

    if not ids:
        return DeleteResponse(
            status="🔍 Keine Einträge gefunden.",
            count=0,
            collection=collection,
            source=source,
            type=type,
        )
    qdrant.delete(collection_name=collection, points_selector=PointIdsList(points=ids))
    return DeleteResponse(
        status="🗑️ gelöscht",
        count=len(ids),
        collection=collection,
        source=source,
        type=type,
    )

# Delete entire collection
@router.delete("/delete-collection", response_model=DeleteResponse)
def delete_collection(
    collection: str = Query(...)
):
    """Drop an entire Qdrant collection.

    Args:
        collection: Name of the collection to drop.

    Returns:
        DeleteResponse with status "🗑️ gelöscht"; ``count`` is always 0
        because the removed points are not counted.

    Raises:
        HTTPException: 404 if the collection does not exist.
    """
    if qdrant.collection_exists(collection):
        qdrant.delete_collection(collection_name=collection)
        return DeleteResponse(status="🗑️ gelöscht", count=0, collection=collection)
    raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")