128 lines
4.9 KiB
Python
128 lines
4.9 KiB
Python
# Kommentarzeile 2
|
|
from fastapi import APIRouter, HTTPException, Query
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import BaseModel, Field
|
|
from typing import List, Optional
|
|
from uuid import uuid4
|
|
from clients import model, qdrant
|
|
from qdrant_client.models import PointStruct, VectorParams, Distance, PointIdsList
|
|
import requests, os
|
|
|
|
router = APIRouter()
|
|
|
|
# Models
|
|
class ChunkInput(BaseModel):
    """One text chunk together with its descriptive metadata.

    Only ``text`` and ``source`` are required; every other field falls back
    to the default declared below. The ``/embed`` endpoint encodes ``text``
    and stores the whole model as the point payload.
    """

    # Required fields.
    text: str
    source: str

    # Free-form metadata; empty when the caller does not supply it.
    source_type: str = ""
    title: str = ""
    version: str = ""
    related_to: str = ""
    tags: List[str] = []
    owner: str = ""

    # Optional metadata; None when the caller does not supply it.
    context_tag: Optional[str] = None
    imported_at: Optional[str] = None
    chunk_index: Optional[int] = None
    category: Optional[str] = None
|
|
|
|
class EmbedRequest(BaseModel):
    """Request body for ``POST /embed``."""

    # Chunks to embed and store.
    chunks: List[ChunkInput]
    # Name of the target collection; "default" when omitted.
    collection: str = "default"
|
|
|
|
class PromptRequest(BaseModel):
    """Request body for ``POST /prompt``."""

    # The search query / question (required).
    query: str = Field(..., description="Suchanfrage")
    # Number of context documents to retrieve; constrained to 1..10.
    context_limit: int = Field(
        3,
        ge=1,
        le=10,
        description="Anzahl Kontext-Dokumente",
    )
    # Collection that is searched for context.
    collection: str = Field(
        "default",
        description="Qdrant-Collection",
    )
|
|
|
|
class PromptResponse(BaseModel):
    """Response body of ``POST /prompt``."""

    # Text returned by the LLM service.
    answer: str
    # Context text that was placed into the LLM prompt.
    context: str
    # Collection the context was retrieved from.
    collection: str
|
|
|
|
class DeleteResponse(BaseModel):
    """Response body shared by the delete endpoints."""

    # Human-readable outcome message.
    status: str
    # Number of entries reported as deleted.
    count: int
    # Collection the operation was applied to.
    collection: str

    # Optional filter values; both default to None.
    source: Optional[str] = None
    type: Optional[str] = None
|
|
|
|
# Endpoints
|
|
@router.post("/embed")
def embed_texts(data: EmbedRequest):
    """Embed the submitted chunks and upsert them into a Qdrant collection.

    The collection is created on first use, with a vector size taken from
    the embedding model and cosine distance. Every chunk is stored as one
    point with a fresh UUID as id and the chunk's fields as payload.

    Args:
        data: Chunks to embed and the name of the target collection.

    Returns:
        A dict with a status message, the number of stored points and the
        collection name.

    Raises:
        HTTPException: 400 if ``data.chunks`` is empty.
    """
    # Reject empty requests up front so no empty collection is created and
    # no empty upsert is sent to Qdrant.
    if not data.chunks:
        raise HTTPException(status_code=400, detail="'chunks' darf nicht leer sein.")

    if not qdrant.collection_exists(data.collection):
        # create_collection rather than the deprecated recreate_collection:
        # the latter drops an existing collection first, so two concurrent
        # requests could wipe each other's freshly stored data.
        qdrant.create_collection(
            collection_name=data.collection,
            vectors_config=VectorParams(
                size=model.get_sentence_embedding_dimension(),
                distance=Distance.COSINE,
            ),
        )

    embeddings = model.encode([c.text for c in data.chunks]).tolist()
    points = [
        PointStruct(id=str(uuid4()), vector=emb, payload=c.dict())
        for emb, c in zip(embeddings, data.chunks)
    ]
    qdrant.upsert(collection_name=data.collection, points=points)
    return {"status": "✅ embeddings saved", "count": len(points), "collection": data.collection}
|
|
|
|
@router.get("/search")
def search_text(
    query: str = Query(..., min_length=1),
    limit: int = Query(3, ge=1),
    collection: str = Query("default"),
):
    """Run a vector search for ``query`` in a collection.

    Args:
        query: Search text; must be non-empty.
        limit: Maximum number of hits to return (at least 1).
        collection: Name of the collection to search.

    Returns:
        A list of dicts, one per hit, with the hit's ``score`` and the
        ``text`` value of its payload (empty string when absent).

    Raises:
        HTTPException: 404 if the collection does not exist.
    """
    # Same 404 behaviour as the delete endpoints instead of an unhandled
    # client error for an unknown collection.
    if not qdrant.collection_exists(collection):
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")

    vec = model.encode(query).tolist()
    res = qdrant.search(collection_name=collection, query_vector=vec, limit=limit)
    # A hit's payload may be None; treat that like a payload without text.
    return [{"score": r.score, "text": (r.payload or {}).get("text", "")} for r in res]
|
|
|
|
@router.post("/prompt", response_model=PromptResponse)
def prompt(data: PromptRequest):
    """Answer a query with an LLM, using retrieved chunks as context.

    The query is embedded and the ``context_limit`` best hits of the
    collection are joined (one per line) into a context block. That block
    and the query are posted to the LLM service configured through the
    ``OLLAMA_URL`` and ``OLLAMA_MODEL`` environment variables.

    Args:
        data: Query, number of context documents and collection name.

    Returns:
        PromptResponse with the LLM answer, the context used and the
        collection name.

    Raises:
        HTTPException: 400 for a blank query, 500 if the LLM URL or model
            is not configured, 404 if the collection does not exist, 502
            if the LLM request fails or returns an unreadable body.
    """
    if not data.query.strip():
        raise HTTPException(status_code=400, detail="'query' darf nicht leer sein.")

    # Validate configuration before doing any embedding or search work.
    llm_url = os.getenv("OLLAMA_URL")
    if not llm_url:
        raise HTTPException(status_code=500, detail="LLM-Service-URL nicht konfiguriert.")
    llm_model = os.getenv("OLLAMA_MODEL")
    if not llm_model:
        # Without this check a null model name would be sent to the service.
        raise HTTPException(status_code=500, detail="LLM-Modell nicht konfiguriert.")

    # Same 404 behaviour as the delete endpoints for an unknown collection.
    if not qdrant.collection_exists(data.collection):
        raise HTTPException(status_code=404, detail=f"Collection '{data.collection}' nicht gefunden.")

    hits = qdrant.search(
        collection_name=data.collection,
        query_vector=model.encode(data.query).tolist(),
        limit=data.context_limit,
    )
    # A hit's payload may be None; treat that like a payload without text.
    context = "\n".join((h.payload or {}).get("text", "") for h in hits)

    payload = {
        "model": llm_model,
        "prompt": f"Context:\n{context}\nQuestion: {data.query}",
        "stream": False,
    }
    try:
        r = requests.post(llm_url, json=payload, timeout=30)
        r.raise_for_status()
        # Decode inside the try so a non-JSON reply becomes a 502 as well.
        body = r.json()
    except (requests.RequestException, ValueError) as e:
        raise HTTPException(status_code=502, detail=f"LLM-Service-Fehler: {e}") from e

    # Only a JSON object can carry the "response" key.
    answer = body.get("response", "") if isinstance(body, dict) else ""
    return PromptResponse(answer=answer, context=context, collection=data.collection)
|
|
|
|
@router.delete("/delete-source", response_model=DeleteResponse)
def delete_by_source(
    collection: str = Query(...),
    source: Optional[str] = Query(None),
    type: Optional[str] = Query(None),
):
    """Delete all points of a collection that match the given filters.

    Args:
        collection: Name of the collection to delete from.
        source: If given, only points whose payload ``source`` equals it.
        type: If given, only points whose payload ``type`` equals it.

    Returns:
        DeleteResponse with the number of deleted points and the filter
        values that were applied.

    Raises:
        HTTPException: 404 if the collection does not exist, 400 if
            neither ``source`` nor ``type`` is given.
    """
    if not qdrant.collection_exists(collection):
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")

    filt = []
    if source:
        filt.append({"key": "source", "match": {"value": source}})
    if type:
        # NOTE(review): ChunkInput stores this value under "source_type",
        # not "type", so this condition presumably never matches points
        # written via /embed. Confirm the intended payload key.
        filt.append({"key": "type", "match": {"value": type}})
    if not filt:
        raise HTTPException(status_code=400, detail="Mindestens ein Filterparameter muss angegeben werden.")

    # Page through all matches; a single scroll call would silently stop
    # at its limit and leave the remaining points in place.
    ids = []
    offset = None
    while True:
        pts, offset = qdrant.scroll(
            collection_name=collection,
            scroll_filter={"must": filt},
            limit=10000,
            offset=offset,
            with_payload=False,  # only the ids are needed
            with_vectors=False,
        )
        # Keep ids in their native type: str() would corrupt integer ids.
        ids.extend(p.id for p in pts)
        if offset is None:
            break

    if not ids:
        return DeleteResponse(
            status="🔍 Keine Einträge gefunden.",
            count=0,
            collection=collection,
            source=source,
            type=type,
        )

    qdrant.delete(collection_name=collection, points_selector=PointIdsList(points=ids))
    return DeleteResponse(
        status="🗑️ gelöscht",
        count=len(ids),
        collection=collection,
        source=source,
        type=type,
    )
|
|
|
|
# Delete entire collection
|
|
@router.delete("/delete-collection", response_model=DeleteResponse)
def delete_collection(collection: str = Query(...)):
    """Drop a whole collection.

    Args:
        collection: Name of the collection to remove.

    Returns:
        DeleteResponse naming the collection; ``count`` is always 0.

    Raises:
        HTTPException: 404 if the collection does not exist.
    """
    exists = qdrant.collection_exists(collection)
    if exists:
        qdrant.delete_collection(collection_name=collection)
        return DeleteResponse(status="🗑️ gelöscht", count=0, collection=collection)

    raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")
|