# Trainer_LLM/llm-api/exercise_router.py
# (Git web-UI metadata removed so the file parses: commit 32577a7fda by Lars,
# 2025-08-11 19:35:28 +02:00, "llm-api/exercise_router.py aktualisiert";
# deploy check successful; 416 lines, 14 KiB, Python. The web UI also warned
# about ambiguous Unicode characters — the emoji/typographic quotes below are
# intentional.)
# -*- coding: utf-8 -*-
"""
exercise_router.py v1.7.0
Neu:
- Endpoint **POST /exercise/search**: kombinierbare Filter (discipline, duration, equipment any/all, keywords any/all,
capability_geN / capability_eqN + names) + optionaler Vektor-Query (query-Text). Ausgabe inkl. Score.
- Facetten erweitert: neben capability_ge1..ge5 jetzt auch capability_eq1..eq5.
- Idempotenz-Fix & Payload-Scroll (aus v1.6.2) beibehalten.
- API-Signaturen bestehender Routen unverändert.
Hinweis: Die „eq/ge“-Felder werden beim Upsert gesetzt; für Alt-Punkte einmal das Backfill laufen lassen.
"""
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
from uuid import uuid4

from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointIdsList,
    PointStruct,
    Range,
    VectorParams,
)

from clients import model, qdrant
router = APIRouter()
# =========================
# Models
# =========================
class Exercise(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
# Upsert-Metadaten
external_id: Optional[str] = None
fingerprint: Optional[str] = None
source: Optional[str] = None
imported_at: Optional[datetime] = None
# Domain-Felder
title: str
summary: str
short_description: str
keywords: List[str] = []
link: Optional[str] = None
discipline: str
group: Optional[str] = None
age_group: str
target_group: str
min_participants: int
duration_minutes: int
capabilities: Dict[str, int] = {}
category: str
purpose: str
execution: str
notes: str
preparation: str
method: str
equipment: List[str] = []
class DeleteResponse(BaseModel):
status: str
count: int
collection: str
class ExerciseSearchRequest(BaseModel):
# Optionaler Semantik-Query (Vektor)
query: Optional[str] = None
limit: int = Field(default=20, ge=1, le=200)
offset: int = Field(default=0, ge=0)
# Einfache Filter
discipline: Optional[str] = None
target_group: Optional[str] = None
age_group: Optional[str] = None
max_duration: Optional[int] = Field(default=None, ge=0)
# Listen-Filter
equipment_any: Optional[List[str]] = None # mindestens eins muss passen
equipment_all: Optional[List[str]] = None # alle müssen passen
keywords_any: Optional[List[str]] = None
keywords_all: Optional[List[str]] = None
# Capabilities (Namen + Level-Operator)
capability_names: Optional[List[str]] = None
capability_ge_level: Optional[int] = Field(default=None, ge=1, le=5)
capability_eq_level: Optional[int] = Field(default=None, ge=1, le=5)
class ExerciseSearchHit(BaseModel):
id: str
score: Optional[float] = None
payload: Exercise
class ExerciseSearchResponse(BaseModel):
hits: List[ExerciseSearchHit]
# =========================
# Helpers
# =========================
COLLECTION = os.getenv("EXERCISE_COLLECTION", "exercises")
def _ensure_collection():
if not qdrant.collection_exists(COLLECTION):
qdrant.recreate_collection(
collection_name=COLLECTION,
vectors_config=VectorParams(
size=model.get_sentence_embedding_dimension(),
distance=Distance.COSINE,
),
)
def _lookup_by_external_id(external_id: str) -> Optional[Dict[str, Any]]:
_ensure_collection()
flt = Filter(must=[FieldCondition(key="external_id", match=MatchValue(value=external_id))])
pts, _ = qdrant.scroll(
collection_name=COLLECTION,
scroll_filter=flt,
limit=1,
with_payload=True,
)
if not pts:
return None
doc = dict(pts[0].payload or {})
doc.setdefault("id", str(pts[0].id))
return doc
_DEF_EMBED_FIELDS = ("title", "summary", "short_description", "purpose", "execution", "notes")
def _make_vector_from_exercise(ex: Exercise) -> List[float]:
text = ". ".join([getattr(ex, f, "") for f in _DEF_EMBED_FIELDS if getattr(ex, f, None)])
return model.encode(text).tolist()
def _make_vector_from_query(query: str) -> List[float]:
return model.encode(query).tolist()
def _norm_list(xs: List[Any]) -> List[str]:
out = []
seen = set()
for x in xs or []:
s = str(x).strip()
if not s:
continue
key = s.casefold()
if key in seen:
continue
seen.add(key)
out.append(s)
return sorted(out, key=str.casefold)
def _facet_capabilities(caps: Dict[str, Any]) -> Dict[str, List[str]]:
caps = caps or {}
def names_where(pred) -> List[str]:
out = []
for k, v in caps.items():
try:
iv = int(v)
except Exception:
iv = 0
if pred(iv):
t = str(k).strip()
if t:
out.append(t)
return sorted({t for t in out}, key=str.casefold)
all_keys = sorted({str(k).strip() for k in caps.keys() if str(k).strip()}, key=str.casefold)
return {
"capability_keys": all_keys,
# >= N
"capability_ge1": names_where(lambda lv: lv >= 1),
"capability_ge2": names_where(lambda lv: lv >= 2),
"capability_ge3": names_where(lambda lv: lv >= 3),
"capability_ge4": names_where(lambda lv: lv >= 4),
"capability_ge5": names_where(lambda lv: lv >= 5),
# == N
"capability_eq1": names_where(lambda lv: lv == 1),
"capability_eq2": names_where(lambda lv: lv == 2),
"capability_eq3": names_where(lambda lv: lv == 3),
"capability_eq4": names_where(lambda lv: lv == 4),
"capability_eq5": names_where(lambda lv: lv == 5),
}
def _response_strip_extras(payload: Dict[str, Any]) -> Dict[str, Any]:
allowed = set(Exercise.model_fields.keys())
return {k: v for k, v in payload.items() if k in allowed}
def _build_filter(req: ExerciseSearchRequest) -> Filter:
must: List[Any] = []
should: List[Any] = []
if req.discipline:
must.append(FieldCondition(key="discipline", match=MatchValue(value=req.discipline)))
if req.target_group:
must.append(FieldCondition(key="target_group", match=MatchValue(value=req.target_group)))
if req.age_group:
must.append(FieldCondition(key="age_group", match=MatchValue(value=req.age_group)))
if req.max_duration is not None:
# Range ohne Import zusätzlicher Modelle: Qdrant akzeptiert auch {'range': {'lte': n}} per JSON;
# über Client-Modell tun wir es hier nicht, da wir Filter primär für Keyword-Felder nutzen.
must.append({"key": "duration_minutes", "range": {"lte": int(req.max_duration)}})
# equipment
if req.equipment_all:
for it in req.equipment_all:
must.append(FieldCondition(key="equipment", match=MatchValue(value=it)))
if req.equipment_any:
# OR: über 'should' Liste
for it in req.equipment_any:
should.append(FieldCondition(key="equipment", match=MatchValue(value=it)))
# keywords
if req.keywords_all:
for it in req.keywords_all:
must.append(FieldCondition(key="keywords", match=MatchValue(value=it)))
if req.keywords_any:
for it in req.keywords_any:
should.append(FieldCondition(key="keywords", match=MatchValue(value=it)))
# capabilities (ge/eq)
if req.capability_names:
names = [s for s in req.capability_names if s and s.strip()]
if req.capability_eq_level:
key = f"capability_eq{int(req.capability_eq_level)}"
for n in names:
must.append(FieldCondition(key=key, match=MatchValue(value=n)))
elif req.capability_ge_level:
key = f"capability_ge{int(req.capability_ge_level)}"
for n in names:
must.append(FieldCondition(key=key, match=MatchValue(value=n)))
else:
# Default: Level >=1 (alle vorhanden)
for n in names:
must.append(FieldCondition(key="capability_ge1", match=MatchValue(value=n)))
flt = Filter(must=must)
if should:
# qdrant: 'should' mit implizitem minimum_should_match=1
flt.should = should
return flt
# =========================
# Endpoints
# =========================
@router.get("/exercise/by-external-id")
def get_exercise_by_external_id(external_id: str = Query(..., min_length=3)):
found = _lookup_by_external_id(external_id)
if not found:
raise HTTPException(status_code=404, detail="not found")
return found
@router.post("/exercise", response_model=Exercise)
def create_or_update_exercise(ex: Exercise):
_ensure_collection()
point_id = ex.id
if ex.external_id:
prior = _lookup_by_external_id(ex.external_id)
if prior:
point_id = prior.get("id", point_id)
vector = _make_vector_from_exercise(ex)
payload: Dict[str, Any] = ex.model_dump()
payload["id"] = str(point_id)
payload["keywords"] = _norm_list(payload.get("keywords") or [])
payload["equipment"] = _norm_list(payload.get("equipment") or [])
payload.update(_facet_capabilities(payload.get("capabilities") or {}))
qdrant.upsert(
collection_name=COLLECTION,
points=[PointStruct(id=str(point_id), vector=vector, payload=payload)],
)
return Exercise(**_response_strip_extras(payload))
@router.get("/exercise/{exercise_id}", response_model=Exercise)
def get_exercise(exercise_id: str):
_ensure_collection()
pts, _ = qdrant.scroll(
collection_name=COLLECTION,
scroll_filter=Filter(must=[FieldCondition(key="id", match=MatchValue(value=exercise_id))]),
limit=1,
with_payload=True,
)
if not pts:
raise HTTPException(status_code=404, detail="not found")
payload = dict(pts[0].payload or {})
payload.setdefault("id", str(pts[0].id))
return Exercise(**_response_strip_extras(payload))
@router.post("/exercise/search", response_model=ExerciseSearchResponse)
def search_exercises(req: ExerciseSearchRequest) -> ExerciseSearchResponse:
_ensure_collection()
flt = _build_filter(req)
hits: List[ExerciseSearchHit] = []
if req.query:
vec = _make_vector_from_query(req.query)
# qdrant_client.search unterstützt offset/limit
res = qdrant.search(
collection_name=COLLECTION,
query_vector=vec,
limit=req.limit,
offset=req.offset,
query_filter=flt,
)
for h in res:
payload = dict(h.payload or {})
payload.setdefault("id", str(h.id))
hits.append(ExerciseSearchHit(id=str(h.id), score=float(h.score or 0.0), payload=Exercise(**_response_strip_extras(payload))))
else:
# Filter-only: per Scroll (ohne Score); einfache Paginierung via offset/limit
# Hole offset+limit Punkte und simuliere Score=None
collected = 0
skipped = 0
next_offset = None
while collected < req.limit:
page, next_offset = qdrant.scroll(
collection_name=COLLECTION,
scroll_filter=flt,
offset=next_offset,
limit=max(1, min(256, req.limit - collected + req.offset - skipped)),
with_payload=True,
)
if not page:
break
for pt in page:
if skipped < req.offset:
skipped += 1
continue
payload = dict(pt.payload or {})
payload.setdefault("id", str(pt.id))
hits.append(ExerciseSearchHit(id=str(pt.id), score=None, payload=Exercise(**_response_strip_extras(payload))))
collected += 1
if collected >= req.limit:
break
if next_offset is None:
break
return ExerciseSearchResponse(hits=hits)
@router.delete("/exercise/delete-by-external-id", response_model=DeleteResponse)
def delete_by_external_id(external_id: str = Query(...)):
_ensure_collection()
flt = Filter(must=[FieldCondition(key="external_id", match=MatchValue(value=external_id))])
pts, _ = qdrant.scroll(collection_name=COLLECTION, scroll_filter=flt, limit=10000, with_payload=False)
ids = [str(p.id) for p in pts]
if not ids:
return DeleteResponse(status="🔍 Keine Einträge gefunden.", count=0, collection=COLLECTION)
qdrant.delete(collection_name=COLLECTION, points_selector=PointIdsList(points=ids))
return DeleteResponse(status="🗑️ gelöscht", count=len(ids), collection=COLLECTION)
@router.delete("/exercise/delete-collection", response_model=DeleteResponse)
def delete_collection(collection: str = Query(default=COLLECTION)):
if not qdrant.collection_exists(collection):
raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")
qdrant.delete_collection(collection_name=collection)
return DeleteResponse(status="🗑️ gelöscht", count=0, collection=collection)
# ---------------------------
# OPTIONAL: simple self-test (can also be used as a separate script)
# ---------------------------
TEST_DOC = """
Speicher als tests/test_exercise_search.py und mit pytest laufen lassen.
import os, requests
BASE = os.getenv("API_BASE", "http://localhost:8000")
# 1) Filter-only
r = requests.post(f"{BASE}/exercise/search", json={
"discipline": "Karate",
"max_duration": 12,
"equipment_any": ["Bälle"],
"capability_names": ["Reaktionsfähigkeit"],
"capability_ge_level": 2,
"limit": 5
})
r.raise_for_status()
js = r.json()
assert "hits" in js
for h in js["hits"]:
p = h["payload"]
assert p["discipline"] == "Karate"
assert p["duration_minutes"] <= 12
# 2) Vector + Filter
r = requests.post(f"{BASE}/exercise/search", json={
"query": "Aufwärmen 10min, Reaktionsfähigkeit, Teenager, Bälle",
"discipline": "Karate",
"limit": 3
})
r.raise_for_status()
js = r.json(); assert len(js["hits"]) <= 3
"""