# Trainer_LLM/llm-api/archiv/llm_api_bk2.py
# (archived backup: 197 lines, 5.1 KiB, Python)
from fastapi import FastAPI, Query, HTTPException
from pydantic import BaseModel
from typing import List
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct
from fastapi import HTTPException
from uuid import uuid4
import requests
from datetime import datetime
from qdrant_client.models import PointIdsList
app = FastAPI()

# Configuration
model = SentenceTransformer("all-MiniLM-L6-v2")  # sentence embedder, 384-dim vectors (matches VectorParams below)
qdrant = QdrantClient(host="localhost", port=6333)  # local Qdrant vector store
OLLAMA_URL = "http://localhost:11434/api/generate"  # local Ollama generation endpoint
OLLAMA_MODEL = "mistral"  # model name sent to Ollama
# Data models
from typing import List, Dict, Any
class ChunkInput(BaseModel):
    # One text chunk plus its metadata, to be embedded and stored in Qdrant.
    text: str  # the raw chunk text (embedded and kept in the payload)
    source: str  # origin identifier, e.g. a file path or URL
    source_type: str = "file"  # kind of source; stored as payload key "source_type"
    title: str | None = None  # human-readable document title
    version: str | None = None  # optional document/content version
    related_to: str | None = None  # optional link to a related source/document
    tags: List[str] = []  # free-form labels (pydantic copies mutable defaults per instance)
    owner: str | None = None  # responsible person/team
    context_tag: str | None = None  # optional grouping/context marker
    imported_at: str | None = None  # ISO timestamp; filled with utcnow() at embed time if omitted
    chunk_index: int | None = None  # position of this chunk within its source document
    category: str | None = None  # defaults to the target collection name at embed time
class EmbedRequest(BaseModel):
    # Request body for POST /embed: chunks to vectorize plus target collection.
    chunks: List[ChunkInput]  # chunks to embed and upsert
    collection: str = "default"  # Qdrant collection name (created on demand)
class PromptRequest(BaseModel):
    # Request body for POST /prompt: RAG question answering via Ollama.
    query: str  # the user's question
    context_limit: int = 3  # number of top-scoring chunks to include as context
    collection: str = "default"  # Qdrant collection to search
@app.delete("/delete-source")
def delete_by_source(
    collection: str = Query(...),
    source: str = Query(...),
    type: str = Query(None)
):
    """
    Delete all points in *collection* whose payload "source" equals *source*,
    optionally narrowed by source type.

    Raises:
        HTTPException 404: if the collection does not exist.
    """
    if not qdrant.collection_exists(collection):
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")

    must = [{"key": "source", "match": {"value": source}}]
    if type:
        # BUGFIX: /embed stores this value under payload key "source_type",
        # not "type" — filtering on "type" never matched anything.
        must.append({"key": "source_type", "match": {"value": type}})

    # Paginate the scroll: a single call with a fixed limit could silently
    # miss points beyond the first page.
    point_ids = []
    offset = None
    while True:
        points, offset = qdrant.scroll(
            collection_name=collection,
            scroll_filter={"must": must},
            limit=1000,
            offset=offset,
            with_payload=False,   # ids are all we need
            with_vectors=False,
        )
        point_ids.extend(str(point.id) for point in points)  # always cast to string
        if offset is None:  # Qdrant signals the last page with a None offset
            break

    if not point_ids:
        return {"status": "🔍 Keine passenden Einträge gefunden."}

    qdrant.delete(
        collection_name=collection,
        points_selector=PointIdsList(points=point_ids)
    )
    return {
        "status": "🗑️ gelöscht",
        "count": len(point_ids),
        "collection": collection,
        "source": source,
        "type": type
    }
@app.delete("/delete-collection")
def delete_collection(collection: str = Query(...)):
    """
    Drop an entire Qdrant collection, including all of its points.

    Raises:
        HTTPException 404: if no collection with that name exists.
    """
    if qdrant.collection_exists(collection):
        qdrant.delete_collection(collection_name=collection)
        return {"status": "🗑️ gelöscht", "collection": collection}
    raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")
@app.post("/embed")
def embed_texts(data: EmbedRequest):
    """
    Embed all chunks in the request and upsert them into the target collection.

    Creates the collection (384-dim, cosine distance, matching the
    all-MiniLM-L6-v2 embedder) if it does not exist yet. Each point gets a
    fresh UUID and carries the chunk's full metadata as payload.
    """
    collection_name = data.collection

    # ROBUSTNESS: an empty chunk list would otherwise hit the encoder and
    # upsert with zero points; short-circuit instead.
    if not data.chunks:
        return {"status": "✅ embeddings saved", "count": 0, "collection": collection_name}

    if not qdrant.collection_exists(collection_name):
        # create_collection (not the deprecated recreate_collection):
        # under the exists-guard the behavior is identical, but it can never
        # accidentally wipe a collection.
        qdrant.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=384, distance=Distance.COSINE)
        )

    embeddings = model.encode([chunk.text for chunk in data.chunks]).tolist()

    points = []
    for i, chunk in enumerate(data.chunks):
        payload = {
            "text": chunk.text,
            "source": chunk.source,
            "source_type": chunk.source_type,
            "title": chunk.title,
            "version": chunk.version,
            "related_to": chunk.related_to,
            "tags": chunk.tags,
            "owner": chunk.owner,
            "context_tag": chunk.context_tag,
            # default the import timestamp to "now" (UTC, ISO format)
            "imported_at": chunk.imported_at or datetime.utcnow().isoformat(),
            "chunk_index": chunk.chunk_index,
            # fall back to the collection name as category
            "category": chunk.category or data.collection
        }
        points.append(PointStruct(
            id=str(uuid4()),  # new random id per chunk — re-imports create new points
            vector=embeddings[i],
            payload=payload
        ))

    qdrant.upsert(collection_name=collection_name, points=points)
    return {
        "status": "✅ embeddings saved",
        "count": len(points),
        "collection": collection_name
    }
@app.get("/search")
def search_text(query: str = Query(...), limit: int = 3, collection: str = Query(...)):
    """Semantic search: embed the query and return the top hits with scores and texts."""
    query_vector = model.encode(query).tolist()
    hits = qdrant.search(collection_name=collection, query_vector=query_vector, limit=limit)
    response = []
    for hit in hits:
        response.append({"score": hit.score, "text": hit.payload["text"]})
    return response
@app.post("/prompt")
def generate_prompt(data: PromptRequest):
    """
    RAG answer generation: retrieve the top chunks for the query from Qdrant,
    build a context prompt, and let Ollama generate the answer.

    Raises:
        HTTPException 502: if the Ollama backend is unreachable, times out,
            or returns an error status.
    """
    query_vec = model.encode(data.query).tolist()
    results = qdrant.search(
        collection_name=data.collection,
        query_vector=query_vec,
        limit=data.context_limit
    )
    context = "\n".join([r.payload["text"] for r in results])

    full_prompt = f"""Beantworte die folgende Frage basierend auf dem Kontext:
Kontext:
{context}
Frage:
{data.query}
"""
    ollama_payload = {
        "model": OLLAMA_MODEL,
        "prompt": full_prompt,
        "stream": False
    }
    try:
        # BUGFIX: a missing timeout let a stalled Ollama hang this worker
        # forever; generation is slow, so allow a generous 120 s.
        response = requests.post(OLLAMA_URL, json=ollama_payload, timeout=120)
        response.raise_for_status()
    except requests.RequestException as err:
        # Surface upstream failures as a proper 502 instead of a bare 500.
        raise HTTPException(status_code=502, detail=f"Ollama-Backend nicht erreichbar: {err}") from err

    answer = response.json()["response"]
    return {
        "answer": answer,
        "context": context,
        "collection": data.collection
    }