"""FastAPI service: embed text chunks into Qdrant and answer questions via Ollama (RAG)."""
from datetime import datetime, timezone
from typing import List
from uuid import uuid4

import requests
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointIdsList, PointStruct, VectorParams
from sentence_transformers import SentenceTransformer
app = FastAPI()
|
|
|
|
# Konfiguration
|
|
model = SentenceTransformer("all-MiniLM-L6-v2")
|
|
qdrant = QdrantClient(host="localhost", port=6333)
|
|
OLLAMA_URL = "http://localhost:11434/api/generate"
|
|
OLLAMA_MODEL = "mistral"
|
|
|
|
# Datenmodelle
|
|
|
|
from typing import List, Dict, Any
|
|
|
|
class ChunkInput(BaseModel):
    """One text chunk plus its provenance metadata, as stored in Qdrant payloads."""

    text: str                       # chunk content that gets embedded
    source: str                     # originating document / identifier
    source_type: str = "file"       # stored as payload key "source_type" by /embed
    title: str | None = None
    version: str | None = None
    related_to: str | None = None
    # list[str] instead of List[str] for consistency with the `X | None`
    # (3.10+) unions used above; pydantic deep-copies defaults, so the
    # mutable default is safe here.
    tags: list[str] = []
    owner: str | None = None
    context_tag: str | None = None
    imported_at: str | None = None  # ISO timestamp; /embed fills it in when absent
    chunk_index: int | None = None
    category: str | None = None     # /embed defaults this to the target collection
class EmbedRequest(BaseModel):
    """Request body for /embed: the chunks to vectorize and the target collection."""

    # list[ChunkInput] instead of List[...] for consistency with the
    # `X | None` union syntax used elsewhere in this file.
    chunks: list[ChunkInput]
    collection: str = "default"
class PromptRequest(BaseModel):
    """Request body for /prompt: the user question and retrieval settings."""

    query: str
    context_limit: int = 3        # number of context chunks fetched from Qdrant
    collection: str = "default"   # Qdrant collection to search
@app.delete("/delete-source")
def delete_by_source(
    collection: str = Query(...),
    source: str = Query(...),
    type: str = Query(None)
):
    """Delete every point in *collection* whose payload matches *source*.

    Optionally narrows the match via the ``type`` query parameter, which is
    compared against the ``source_type`` payload field written by /embed.

    Raises:
        HTTPException: 404 if the collection does not exist.
    """
    if not qdrant.collection_exists(collection):
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")

    must = [{"key": "source", "match": {"value": source}}]
    if type:
        # BUG FIX: /embed stores this field under the payload key
        # "source_type", not "type" — filtering on "type" never matched.
        must.append({"key": "source_type", "match": {"value": type}})

    # NOTE(review): scroll() returns at most `limit` points; a source with
    # more than 10000 chunks would need pagination via the returned offset.
    points, _next_offset = qdrant.scroll(
        collection_name=collection,
        scroll_filter={"must": must},
        limit=10000
    )

    if not points:
        return {"status": "🔍 Keine passenden Einträge gefunden."}

    # Point ids may be UUIDs or ints; cast to str for a uniform selector.
    point_ids = [str(point.id) for point in points]

    qdrant.delete(
        collection_name=collection,
        points_selector=PointIdsList(points=point_ids)
    )

    return {
        "status": "🗑️ gelöscht",
        "count": len(point_ids),
        "collection": collection,
        "source": source,
        "type": type
    }
@app.delete("/delete-collection")
def delete_collection(collection: str = Query(...)):
    """Drop an entire Qdrant collection.

    Raises:
        HTTPException: 404 if the collection does not exist.
    """
    exists = qdrant.collection_exists(collection)
    if not exists:
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")

    qdrant.delete_collection(collection_name=collection)
    return {"status": "🗑️ gelöscht", "collection": collection}
@app.post("/embed")
def embed_texts(data: EmbedRequest):
    """Embed the given chunks and upsert them into a Qdrant collection.

    Creates the collection on first use with 384-dim cosine vectors
    (matching all-MiniLM-L6-v2). Returns a summary with the stored count.
    """
    collection_name = data.collection

    # Guard: nothing to embed — avoid calling the model / upserting an
    # empty point list.
    if not data.chunks:
        return {"status": "✅ embeddings saved", "count": 0, "collection": collection_name}

    if not qdrant.collection_exists(collection_name):
        # create_collection instead of the deprecated recreate_collection:
        # the collection is known to be missing, so there is nothing to drop.
        qdrant.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=384, distance=Distance.COSINE)
        )

    # Batch-encode all chunk texts in one model call.
    embeddings = model.encode([chunk.text for chunk in data.chunks]).tolist()

    points = []
    for i, chunk in enumerate(data.chunks):
        payload = {
            "text": chunk.text,
            "source": chunk.source,
            "source_type": chunk.source_type,
            "title": chunk.title,
            "version": chunk.version,
            "related_to": chunk.related_to,
            "tags": chunk.tags,
            "owner": chunk.owner,
            "context_tag": chunk.context_tag,
            # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated.
            "imported_at": chunk.imported_at or datetime.now(timezone.utc).isoformat(),
            "chunk_index": chunk.chunk_index,
            "category": chunk.category or data.collection
        }

        points.append(PointStruct(
            id=str(uuid4()),
            vector=embeddings[i],
            payload=payload
        ))

    qdrant.upsert(collection_name=collection_name, points=points)

    return {
        "status": "✅ embeddings saved",
        "count": len(points),
        "collection": collection_name
    }
@app.get("/search")
def search_text(query: str = Query(...), limit: int = 3, collection: str = Query(...)):
    """Semantic search: return the top-`limit` scored chunks for `query`."""
    query_vector = model.encode(query).tolist()
    hits = qdrant.search(
        collection_name=collection,
        query_vector=query_vector,
        limit=limit,
    )

    matches = []
    for hit in hits:
        matches.append({"score": hit.score, "text": hit.payload["text"]})
    return matches
@app.post("/prompt")
def generate_prompt(data: PromptRequest):
    """RAG endpoint: retrieve context from Qdrant, then ask Ollama.

    Embeds the query, fetches the top `context_limit` chunks from the given
    collection, builds a German-language prompt around them, and forwards it
    to the Ollama /api/generate endpoint.

    Raises:
        requests.HTTPError: if Ollama responds with an error status.
        requests.Timeout: if Ollama does not respond in time.
    """
    query_vec = model.encode(data.query).tolist()
    results = qdrant.search(
        collection_name=data.collection,
        query_vector=query_vec,
        limit=data.context_limit
    )

    context = "\n".join([r.payload["text"] for r in results])
    full_prompt = f"""Beantworte die folgende Frage basierend auf dem Kontext:

Kontext:
{context}

Frage:
{data.query}
"""

    ollama_payload = {
        "model": OLLAMA_MODEL,
        "prompt": full_prompt,
        "stream": False
    }

    # Timeout so a hung/unreachable Ollama cannot block the worker forever:
    # 5 s to connect, generous 300 s read budget since generation is slow.
    response = requests.post(OLLAMA_URL, json=ollama_payload, timeout=(5, 300))
    response.raise_for_status()
    answer = response.json()["response"]

    return {
        "answer": answer,
        "context": context,
        "collection": data.collection
    }