Trainer_LLM/llm-api/plan_session_router.py
Lars 40b1151023
All checks were successful
Deploy Trainer_LLM to llm-node / deploy (push) Successful in 1s
llm-api/plan_session_router.py hinzugefügt
2025-08-12 13:08:29 +02:00

112 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
plan_session_router.py v0.1.0 (WP-15)
Minimal CRUD for plan sessions (POST/GET).
Follows the Qdrant client usage style of the existing routers.
"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field, conint
from typing import List, Optional, Dict, Any
from uuid import uuid4
from datetime import datetime, timezone
import os
from clients import model, qdrant
from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue, VectorParams, Distance
# FastAPI router for the plan-session endpoints; the tag groups them in the OpenAPI docs.
router = APIRouter(tags=["plan_sessions"])
# -----------------
# Configuration
# -----------------
# Qdrant collection that holds plan sessions. Override with the
# PLAN_SESSION_COLLECTION environment variable; defaults to "plan_sessions".
PLAN_SESSION_COLLECTION = os.environ.get("PLAN_SESSION_COLLECTION", "plan_sessions")
# -----------------
# Models
# -----------------
class Feedback(BaseModel):
    """Feedback recorded for a plan session: a 1-5 rating plus free-text notes."""

    # Rating on an inclusive 1..5 scale. Declared with Field constraints instead of
    # conint(), which pydantic v2 discourages as an annotation (static type
    # checkers cannot interpret a call expression there). Validation is the same.
    rating: int = Field(..., ge=1, le=5)
    # Free-text notes from the coach.
    notes: str
def _new_session_id() -> str:
    """Return a fresh random UUID4 as a string (default factory for ``PlanSession.id``)."""
    return str(uuid4())


class PlanSession(BaseModel):
    """Record of one execution of a training plan: when, where, by whom, and how it went."""

    # Session identifier. It is also used as the Qdrant point id when the session
    # is stored; Qdrant only accepts UUIDs or unsigned integers as point ids.
    id: str = Field(default_factory=_new_session_id)
    # Identifier of the plan that was executed.
    plan_id: str
    # Time the session took place.
    executed_at: datetime
    location: str
    coach: str
    # Free-form label for the participating group (presumably a team/class name).
    group_label: str
    feedback: Feedback
    # Names of the equipment used; pydantic copies this default per instance.
    used_equipment: List[str] = []
# -----------------
# Helpers
# -----------------
def _ensure_collection(name: str):
    """Create the Qdrant collection *name* if it does not exist yet.

    The collection uses cosine distance. Its vector size is the embedding model's
    dimension (``model.get_sentence_embedding_dimension()``).

    Uses ``create_collection`` rather than ``recreate_collection``. The latter
    deletes an existing collection before creating it, so two concurrent "first"
    requests could wipe points that one of them had just written. It is also
    deprecated in current qdrant-client releases.
    """
    if qdrant.collection_exists(name):
        return
    try:
        qdrant.create_collection(
            collection_name=name,
            vectors_config=VectorParams(
                size=model.get_sentence_embedding_dimension(),
                distance=Distance.COSINE,
            ),
        )
    except Exception:
        # Another request may have created the collection between the existence
        # check and the create call; that is fine. Any other failure is re-raised.
        if not qdrant.collection_exists(name):
            raise
def _norm_list(xs: List[str]) -> List[str]:
seen, out = set(), []
for x in xs or []:
s = str(x).strip()
k = s.casefold()
if s and k not in seen:
seen.add(k)
out.append(s)
return sorted(out, key=str.casefold)
def _session_embed_text(s: PlanSession) -> str:
    """Build the text that is embedded for a session.

    Concatenates plan id, location, coach, group label and feedback notes
    (in that order) with ". " as separator, skipping empty values.
    """
    fields = (s.plan_id, s.location, s.coach, s.group_label, s.feedback.notes)
    return ". ".join(filter(None, fields))
def _embed(text: str):
    """Encode *text* with the shared embedding model and return it via ``.tolist()``.

    A falsy *text* (``None`` or "") is encoded as the empty string.
    NOTE(review): ``.tolist()`` presumably turns a numpy array from
    ``model.encode`` into a plain list of floats; confirm against ``clients``.
    """
    vector = model.encode(text if text else "")
    return vector.tolist()
def _get_by_field(collection: str, key: str, value: Any) -> Optional[Dict[str, Any]]:
    """Return the payload of the first point whose payload *key* equals *value*.

    Returns a copy of the payload, with ``"id"`` filled in from the point id if
    the payload lacks one. Returns ``None`` when no point matches.
    """
    condition = FieldCondition(key=key, match=MatchValue(value=value))
    points, _ = qdrant.scroll(
        collection_name=collection,
        scroll_filter=Filter(must=[condition]),
        limit=1,
        with_payload=True,
    )
    if points:
        first = points[0]
        record = dict(first.payload or {})
        record.setdefault("id", str(first.id))
        return record
    return None
# -----------------
# Endpoints
# -----------------
@router.post("/plan_sessions", response_model=PlanSession)
def create_plan_session(s: PlanSession):
    """Store a plan session in Qdrant and return it.

    The session's ``id`` is used as the Qdrant point id. The payload is the
    model dump with ``used_equipment`` normalized and ``executed_at`` stored as a
    UTC ISO-8601 string. The returned model carries the same normalized values,
    so it matches what ``GET /plan_sessions/{id}`` later returns.

    Raises:
        HTTPException(422): if ``s.id`` is not a UUID string. Qdrant only
            accepts UUIDs or unsigned integers as point ids. Such an id would
            otherwise make ``upsert`` fail and surface as a 500.
    """
    from uuid import UUID  # local import: the module itself only imports uuid4

    try:
        UUID(str(s.id))
    except ValueError:
        raise HTTPException(status_code=422, detail="id must be a UUID") from None

    _ensure_collection(PLAN_SESSION_COLLECTION)

    # Trim, de-duplicate case-insensitively and sort the equipment names.
    s.used_equipment = _norm_list(s.used_equipment)
    # Normalize to UTC and keep it on the model so the response matches storage.
    # NOTE(review): a naive executed_at is interpreted as the server's local time
    # by astimezone(); confirm that clients always send an explicit offset.
    s.executed_at = s.executed_at.astimezone(timezone.utc)

    payload = s.model_dump()
    payload["executed_at"] = s.executed_at.isoformat()

    vec = _embed(_session_embed_text(s))
    qdrant.upsert(
        collection_name=PLAN_SESSION_COLLECTION,
        points=[PointStruct(id=str(s.id), vector=vec, payload=payload)],
    )
    return s
@router.get("/plan_sessions/{session_id}", response_model=PlanSession)
def get_plan_session(session_id: str):
    """Return the plan session whose stored payload ``id`` equals *session_id*.

    Ensures the collection exists (creating it if needed), then looks the session
    up by its payload ``id`` field.

    Raises:
        HTTPException(404): if no stored session has that id.
    """
    _ensure_collection(PLAN_SESSION_COLLECTION)
    found = _get_by_field(PLAN_SESSION_COLLECTION, "id", session_id)
    if not found:
        raise HTTPException(status_code=404, detail="not found")

    executed_at = found.get("executed_at")
    if isinstance(executed_at, str):
        # Sessions are stored with executed_at as an ISO-8601 string.
        try:
            found["executed_at"] = datetime.fromisoformat(executed_at)
        except ValueError:
            # Best effort only: keep the raw string and let pydantic parse or
            # reject it. Other errors are no longer silently swallowed.
            pass
    return PlanSession(**found)