llm-api/plan_session_router.py hinzugefügt
All checks were successful
Deploy Trainer_LLM to llm-node / deploy (push) Successful in 1s

This commit is contained in:
Lars 2025-08-12 13:08:29 +02:00
parent 5c51d3bc4f
commit 40b1151023

View File

@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
"""
plan_session_router.py v0.1.0 (WP-15)
CRUD-Minimum für Plan-Sessions (POST/GET).
Kompatibel zum Qdrant-Client-Stil der bestehenden Router.
"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field, conint
from typing import List, Optional, Dict, Any
from uuid import uuid4
from datetime import datetime, timezone
import os
from clients import model, qdrant
from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue, VectorParams, Distance
router = APIRouter(tags=["plan_sessions"])
# -----------------
# Konfiguration
# -----------------
PLAN_SESSION_COLLECTION = os.getenv("PLAN_SESSION_COLLECTION", "plan_sessions")
# -----------------
# Modelle
# -----------------
class Feedback(BaseModel):
rating: conint(ge=1, le=5)
notes: str
class PlanSession(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
plan_id: str
executed_at: datetime
location: str
coach: str
group_label: str
feedback: Feedback
used_equipment: List[str] = []
# -----------------
# Helpers
# -----------------
def _ensure_collection(name: str):
if not qdrant.collection_exists(name):
qdrant.recreate_collection(
collection_name=name,
vectors_config=VectorParams(size=model.get_sentence_embedding_dimension(), distance=Distance.COSINE),
)
def _norm_list(xs: List[str]) -> List[str]:
seen, out = set(), []
for x in xs or []:
s = str(x).strip()
k = s.casefold()
if s and k not in seen:
seen.add(k)
out.append(s)
return sorted(out, key=str.casefold)
def _session_embed_text(s: PlanSession) -> str:
parts = [s.plan_id, s.location, s.coach, s.group_label, s.feedback.notes]
return ". ".join([p for p in parts if p])
def _embed(text: str):
return model.encode(text or "").tolist()
def _get_by_field(collection: str, key: str, value: Any) -> Optional[Dict[str, Any]]:
flt = Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))])
pts, _ = qdrant.scroll(collection_name=collection, scroll_filter=flt, limit=1, with_payload=True)
if not pts:
return None
payload = dict(pts[0].payload or {})
payload.setdefault("id", str(pts[0].id))
return payload
# -----------------
# Endpoints
# -----------------
@router.post("/plan_sessions", response_model=PlanSession)
def create_plan_session(s: PlanSession):
_ensure_collection(PLAN_SESSION_COLLECTION)
# Normalisieren
s.used_equipment = _norm_list(s.used_equipment)
payload = s.model_dump()
# ISO8601 für executed_at sicherstellen
if isinstance(payload.get("executed_at"), datetime):
payload["executed_at"] = payload["executed_at"].astimezone(timezone.utc).isoformat()
vec = _embed(_session_embed_text(s))
qdrant.upsert(collection_name=PLAN_SESSION_COLLECTION, points=[PointStruct(id=str(s.id), vector=vec, payload=payload)])
return s
@router.get("/plan_sessions/{session_id}", response_model=PlanSession)
def get_plan_session(session_id: str):
_ensure_collection(PLAN_SESSION_COLLECTION)
found = _get_by_field(PLAN_SESSION_COLLECTION, "id", session_id)
if not found:
raise HTTPException(status_code=404, detail="not found")
if isinstance(found.get("executed_at"), str):
try:
found["executed_at"] = datetime.fromisoformat(found["executed_at"])
except Exception:
pass
return PlanSession(**found)