Trainer_LLM/llm-api/archiv/app.py

from fastapi import FastAPI, Query
from pydantic import BaseModel
from typing import List
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct
import requests

app = FastAPI()

# Initialization
model = SentenceTransformer("all-MiniLM-L6-v2")
qdrant = QdrantClient(host="localhost", port=6333)
# COLLECTION = "karate-doku"
OLLAMA_URL = "http://localhost:11434/api/generate"
OLLAMA_MODEL = "mistral"  # can be changed later


# Request body for /embed
class EmbedRequest(BaseModel):
    texts: List[str]
    collection: str = "default"


# Request body for /prompt
class PromptRequest(BaseModel):
    query: str
    context_limit: int = 3
    collection: str = "default"
@app.post("/embed")
def embed_texts(data: EmbedRequest):
collection_name = data.collection
if not qdrant.collection_exists(collection_name):
qdrant.recreate_collection(
collection_name=collection_name,
vectors_config=VectorParams(size=384, distance=Distance.COSINE)
)
embeddings = model.encode(data.texts).tolist()
points = [
PointStruct(id=i, vector=vec, payload={"text": data.texts[i]})
for i, vec in enumerate(embeddings)
]
qdrant.upsert(collection_name=collection_name, points=points)
return {"status": "✅ embeddings saved", "count": len(points), "collection": collection_name}
@app.get("/search")
def search_text(query: str = Query(...), limit: int = 3, collection: str = Query(...)):
    vec = model.encode(query).tolist()
    results = qdrant.search(collection_name=collection, query_vector=vec, limit=limit)
    return [{"score": r.score, "text": r.payload["text"]} for r in results]
@app.post("/prompt")
def generate_prompt(data: PromptRequest):
query_vec = model.encode(data.query).tolist()
# Suche relevante Einträge aus der angegebenen Collection
results = qdrant.search(
collection_name=data.collection,
query_vector=query_vec,
limit=data.context_limit
)
# Kontext für den Prompt aus den gefundenen Texten zusammenbauen
context = "\n".join([r.payload["text"] for r in results])
full_prompt = f"""Beantworte die folgende Frage basierend auf dem Kontext:
Kontext:
{context}
Frage:
{data.query}
"""
# Anfrage an Ollama stellen
ollama_payload = {
"model": OLLAMA_MODEL,
"prompt": full_prompt,
"stream": False
}
response = requests.post(OLLAMA_URL, json=ollama_payload)
response.raise_for_status()
answer = response.json()["response"]
return {
"answer": answer,
"context": context,
"collection": data.collection
}
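
# Minimal sketch for running the API directly (assumes uvicorn is installed;
# host and port are illustrative, not part of the original file):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)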