# Trainer_LLM/llm-api/exercise_router.py — 184 lines, 6.0 KiB, Python

# Test eines Kommentars, um die Funktion des gitea testen zu können
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from uuid import uuid4
from datetime import datetime, date
from clients import model, qdrant
from qdrant_client.models import PointStruct, VectorParams, Distance, PointIdsList
import os
# Router that collects the exercise, training-plan and delete endpoints of
# this module; presumably mounted with `include_router` elsewhere — confirm.
router = APIRouter()
# ---- Models ----
class Exercise(BaseModel):
    """A single training exercise, stored as one Qdrant point by create_exercise."""
    # Generated UUID string unless the client supplies one; used as the
    # Qdrant point id when the exercise is stored.
    id: str = Field(default_factory=lambda: str(uuid4()))
    # title and summary together form the text that create_exercise embeds.
    title: str
    summary: str
    short_description: str
    # Free-form tags; the `tags` query filter of list_exercises matches these.
    # Pydantic copies mutable defaults per instance, so `[]` is not shared.
    keywords: List[str] = []
    link: Optional[str] = None
    discipline: str
    group: Optional[str] = None
    age_group: str
    target_group: str
    min_participants: int
    duration_minutes: int
    # Presumably capability name -> rating; TODO confirm keys and scale.
    capabilities: Dict[str, int] = {}
    category: str
    purpose: str
    execution: str
    notes: str
    preparation: str
    method: str
    equipment: List[str] = []
class PhaseExercise(BaseModel):
    """One exercise slot inside a plan phase."""
    # Presumably the `id` of a stored Exercise — not validated here.
    exercise_id: str
    # Free-form load parameters; presumably conditioning ("cond") and
    # coordination ("coord") load — TODO confirm the expected keys.
    cond_load: Dict[str, Any] = {}
    coord_load: Dict[str, Any] = {}
    instructions: str
class PlanPhase(BaseModel):
    """A named, timed section of a training plan with its exercises."""
    name: str
    duration_minutes: int
    # Training method for the phase, plus free-text notes about it.
    method: str
    method_notes: str
    exercises: List[PhaseExercise]
class TrainingPlan(BaseModel):
    """A training plan made of phases; stored as one Qdrant point by create_plan."""
    # Generated UUID string unless supplied; used as the Qdrant point id.
    id: str = Field(default_factory=lambda: str(uuid4()))
    # title and short_description together form the text create_plan embeds.
    title: str
    short_description: str
    # NOTE(review): create_plan always writes to "training_plans" and never
    # reads this field; presumably it was meant as the target collection.
    collection: str
    discipline: str
    group: Optional[str] = None
    dojo: str
    # Field named like its type. This only works because it has no default:
    # with `= value` the name would be rebound before the annotation is
    # evaluated, and the annotation would no longer be datetime.date.
    date: date
    plan_duration_weeks: int
    focus_areas: List[str] = []
    # Presumably the `id` of an earlier TrainingPlan — not validated here.
    predecessor_plan_id: Optional[str] = None
    age_group: str
    # NOTE(review): datetime.utcnow yields a naive datetime and is deprecated
    # since Python 3.12; datetime.now(timezone.utc) would change the
    # serialized format (adds an offset) — confirm before switching.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    phases: List[PlanPhase]
class DeleteResponse(BaseModel):
    """Response body of the delete endpoints."""
    # Human-readable result message.
    status: str
    # Number of points deleted (0 when nothing matched or when a whole
    # collection was dropped).
    count: int
    collection: str
    # Optional echo of the filter values used for a filtered delete; left as
    # None unless the endpoint fills them in.
    source: Optional[str] = None
    type: Optional[str] = None
# ---- CRUD Endpoints for Exercise ----
@router.post("/exercise", response_model=Exercise)
def create_exercise(ex: Exercise):
    """Store an exercise in the ``exercises`` Qdrant collection.

    The exercise is embedded from ``"<title>. <summary>"`` and upserted with
    its full field set as payload, keyed by ``ex.id``. Posting an existing id
    therefore replaces that point.

    Args:
        ex: The exercise to store.

    Returns:
        The stored exercise, unchanged.
    """
    # Create the collection on first use. `create_collection` (instead of the
    # deprecated `recreate_collection`) never drops an existing collection,
    # so two concurrent "first" requests cannot wipe each other's points;
    # the loser gets an error instead of silently losing data.
    if not qdrant.collection_exists("exercises"):
        qdrant.create_collection(
            collection_name="exercises",
            vectors_config=VectorParams(
                size=model.get_sentence_embedding_dimension(),
                distance=Distance.COSINE,
            ),
        )
    vec = model.encode(f"{ex.title}. {ex.summary}").tolist()
    point = PointStruct(id=ex.id, vector=vec, payload=ex.dict())
    qdrant.upsert(collection_name="exercises", points=[point])
    return ex
@router.get("/exercise", response_model=List[Exercise])
def list_exercises(
    discipline: Optional[str] = Query(None),
    group: Optional[str] = Query(None),
    tags: Optional[str] = Query(None)
):
    """List stored exercises, optionally filtered.

    All given filters are combined with AND (a Qdrant ``must`` clause).

    Args:
        discipline: Only return exercises whose ``discipline`` equals this.
        group: Only return exercises whose ``group`` equals this.
        tags: Comma-separated keywords. Each listed keyword must appear in the
            exercise's ``keywords``. Blank entries (e.g. from ``"a,,b"`` or a
            trailing comma) are ignored.

    Returns:
        All matching exercises. Returns an empty list if the ``exercises``
        collection has not been created yet.
    """
    # Nothing stored yet: return an empty list rather than letting scroll
    # fail on the missing collection (same behavior as list_plans).
    if not qdrant.collection_exists("exercises"):
        return []

    filters = []
    if discipline:
        filters.append({"key": "discipline", "match": {"value": discipline}})
    if group:
        filters.append({"key": "group", "match": {"value": group}})
    if tags:
        for tag in (t.strip() for t in tags.split(",")):
            # A blank tag would require an empty-string keyword and thus
            # empty the whole result; skip it.
            if tag:
                filters.append({"key": "keywords", "match": {"value": tag}})

    scroll_filter = {"must": filters} if filters else None
    points = []
    offset = None
    # One scroll call returns at most `limit` points; follow the returned
    # offset until Qdrant reports no further page, so results aren't truncated.
    while True:
        batch, offset = qdrant.scroll(
            collection_name="exercises",
            scroll_filter=scroll_filter,
            limit=10000,
            offset=offset,
        )
        points.extend(batch)
        if offset is None:
            break
    return [Exercise(**pt.payload) for pt in points]
# ---- CRUD Endpoints for TrainingPlan ----
@router.post("/plan", response_model=TrainingPlan)
def create_plan(plan: TrainingPlan):
    """Store a training plan in the ``training_plans`` Qdrant collection.

    The plan is embedded from ``"<title>. <short_description>"`` and upserted
    with its full field set as payload, keyed by ``plan.id``. Posting an
    existing id therefore replaces that point.

    Args:
        plan: The training plan to store.

    Returns:
        The stored plan, unchanged.
    """
    # NOTE(review): `plan.collection` is ignored; plans always go to
    # "training_plans" — confirm whether that field should pick the target.
    # Create the collection on first use. `create_collection` (instead of the
    # deprecated `recreate_collection`) never drops an existing collection,
    # so concurrent first requests cannot wipe each other's points.
    if not qdrant.collection_exists("training_plans"):
        qdrant.create_collection(
            collection_name="training_plans",
            vectors_config=VectorParams(
                size=model.get_sentence_embedding_dimension(),
                distance=Distance.COSINE,
            ),
        )
    vec = model.encode(f"{plan.title}. {plan.short_description}").tolist()
    point = PointStruct(id=plan.id, vector=vec, payload=plan.dict())
    qdrant.upsert(collection_name="training_plans", points=[point])
    return plan
@router.get("/plan", response_model=List[TrainingPlan])
def list_plans(
    collection: str = Query("training_plans"),
    discipline: Optional[str] = Query(None),
    group: Optional[str] = Query(None),
    dojo: Optional[str] = Query(None)
):
    """List training plans from a collection, optionally filtered.

    Filtering happens in Python after each payload has been validated as a
    TrainingPlan; all given filters must match (AND).

    Args:
        collection: Qdrant collection to read from.
        discipline: Only return plans with this ``discipline``.
        group: Only return plans with this ``group``.
        dojo: Only return plans with this ``dojo``.

    Returns:
        Matching plans. Returns an empty list if the collection does not exist.
    """
    if not qdrant.collection_exists(collection):
        return []
    result = []
    offset = None
    # Follow scroll's next-page offset so collections larger than one page
    # are not silently truncated.
    while True:
        points, offset = qdrant.scroll(
            collection_name=collection,
            limit=10000,
            offset=offset,
        )
        for pt in points:
            pl = TrainingPlan(**pt.payload)
            if discipline and pl.discipline != discipline:
                continue
            if group and pl.group != group:
                continue
            if dojo and pl.dojo != dojo:
                continue
            result.append(pl)
        if offset is None:
            break
    return result
# ---- Delete Endpoints ----
@router.delete("/delete-source", response_model=DeleteResponse)
def delete_by_source(
    collection: str = Query(...),
    source: Optional[str] = Query(None),
    # Named `type` (shadowing the builtin) because it is the public query
    # parameter name; renaming it would change the HTTP API.
    type: Optional[str] = Query(None)
):
    """Delete every point whose payload matches the given source/type filter.

    Args:
        collection: Collection to delete from.
        source: Match points whose payload ``source`` equals this value.
        type: Match points whose payload ``type`` equals this value.
            When both filters are given, both must match.

    Returns:
        DeleteResponse with the number of deleted points and the filter
        values that were used.

    Raises:
        HTTPException: 404 if the collection does not exist; 400 if neither
            filter is given (prevents accidentally emptying a collection).
    """
    if not qdrant.collection_exists(collection):
        raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")
    filt = []
    if source:
        filt.append({"key": "source", "match": {"value": source}})
    if type:
        filt.append({"key": "type", "match": {"value": type}})
    if not filt:
        raise HTTPException(status_code=400, detail="Mindestens ein Filterparameter muss angegeben werden.")

    ids = []
    offset = None
    # Collect ids across all scroll pages; a single call only returned the
    # first page, leaving further matching points undeleted. Payloads are
    # not needed, so don't transfer them.
    while True:
        points, offset = qdrant.scroll(
            collection_name=collection,
            scroll_filter={"must": filt},
            limit=10000,
            offset=offset,
            with_payload=False,
        )
        ids.extend(str(p.id) for p in points)
        if offset is None:
            break

    if not ids:
        return DeleteResponse(
            status="🔍 Keine Einträge gefunden.",
            count=0,
            collection=collection,
            source=source,
            type=type,
        )
    qdrant.delete(collection_name=collection, points_selector=PointIdsList(points=ids))
    return DeleteResponse(
        status="🗑️ gelöscht",
        count=len(ids),
        collection=collection,
        source=source,
        type=type,
    )
@router.delete("/delete-collection", response_model=DeleteResponse)
def delete_collection(
    collection: str = Query(...)
):
    """Drop an entire Qdrant collection.

    Args:
        collection: Name of the collection to delete (required query param).

    Returns:
        DeleteResponse with ``count=0``; the number of removed points is
        not determined.

    Raises:
        HTTPException: 404 if the collection does not exist.
    """
    exists = qdrant.collection_exists(collection)
    if exists:
        qdrant.delete_collection(collection_name=collection)
        return DeleteResponse(status="🗑️ gelöscht", count=0, collection=collection)
    raise HTTPException(status_code=404, detail=f"Collection '{collection}' nicht gefunden.")