# -*- coding: utf-8 -*-
"""
plan_session_router.py – v0.1.0 (WP-15)

Minimal CRUD for plan sessions (POST/GET).
Compatible with the Qdrant client style of the existing routers.
"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field, conint
from typing import List, Optional, Dict, Any
from uuid import UUID, uuid4
from datetime import datetime, timezone
import os

from clients import model, qdrant
from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue, VectorParams, Distance

router = APIRouter(tags=["plan_sessions"])

# -----------------
# Configuration
# -----------------
PLAN_SESSION_COLLECTION = os.getenv("PLAN_SESSION_COLLECTION", "plan_sessions")


# -----------------
# Models
# -----------------
class Feedback(BaseModel):
    """Feedback attached to an executed plan session."""

    rating: conint(ge=1, le=5)  # 1..5 inclusive, enforced by pydantic
    notes: str


class PlanSession(BaseModel):
    """One execution of a plan, as stored in the plan-session collection."""

    # Also used as the Qdrant point ID, so it must be a UUID string.
    id: str = Field(default_factory=lambda: str(uuid4()))
    plan_id: str
    executed_at: datetime
    location: str
    coach: str
    group_label: str
    feedback: Feedback
    used_equipment: List[str] = Field(default_factory=list)


# -----------------
# Helpers
# -----------------
def _ensure_collection(name: str) -> None:
    """Create collection *name* if it does not exist yet.

    Uses ``create_collection`` (never ``recreate_collection``) so that an
    existing collection — and its data — can never be dropped here.
    """
    if qdrant.collection_exists(name):
        return
    try:
        qdrant.create_collection(
            collection_name=name,
            vectors_config=VectorParams(
                size=model.get_sentence_embedding_dimension(),
                distance=Distance.COSINE,
            ),
        )
    except Exception:
        # Another worker may have created it between the check and the create;
        # only that race is tolerated, any other failure propagates.
        if not qdrant.collection_exists(name):
            raise


def _norm_list(xs: List[str]) -> List[str]:
    """Strip entries, drop empties and case-insensitive duplicates, sort case-insensitively.

    The first spelling seen for each case-folded value is kept.
    """
    seen, out = set(), []
    for x in xs or []:
        s = str(x).strip()
        k = s.casefold()
        if s and k not in seen:
            seen.add(k)
            out.append(s)
    return sorted(out, key=str.casefold)


def _session_embed_text(s: PlanSession) -> str:
    """Build the text embedded for a session: its non-empty text fields joined by '. '."""
    parts = [s.plan_id, s.location, s.coach, s.group_label, s.feedback.notes]
    return ". ".join(p for p in parts if p)


def _embed(text: str) -> List[float]:
    """Encode *text* with the shared sentence model; ``None`` is treated as ``""``."""
    return model.encode(text or "").tolist()


def _get_by_field(collection: str, key: str, value: Any) -> Optional[Dict[str, Any]]:
    """Return the payload of the first point whose payload *key* equals *value*, or ``None``.

    The point ID is added to the payload as ``"id"`` if the payload lacks one.
    """
    flt = Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))])
    pts, _ = qdrant.scroll(collection_name=collection, scroll_filter=flt, limit=1, with_payload=True)
    if not pts:
        return None
    payload = dict(pts[0].payload or {})
    payload.setdefault("id", str(pts[0].id))
    return payload


def _as_point_id(value: str) -> Optional[str]:
    """Return *value* as a canonical UUID string, or ``None`` if it is not a UUID."""
    try:
        return str(UUID(value))
    except ValueError:
        return None


# -----------------
# Endpoints
# -----------------
@router.post("/plan_sessions", response_model=PlanSession)
def create_plan_session(s: PlanSession):
    """Store (insert or overwrite) a plan session and return it.

    Raises:
        HTTPException(422): if ``s.id`` is not a UUID (Qdrant rejects it as a point ID).
    """
    point_id = _as_point_id(s.id)
    if point_id is None:
        raise HTTPException(status_code=422, detail="id must be a UUID")

    _ensure_collection(PLAN_SESSION_COLLECTION)

    # Normalize
    s.used_equipment = _norm_list(s.used_equipment)

    payload = s.model_dump()
    # Store executed_at as an ISO-8601 string in UTC.
    # NOTE(review): astimezone() interprets a naive datetime as server-local time.
    if isinstance(payload.get("executed_at"), datetime):
        payload["executed_at"] = payload["executed_at"].astimezone(timezone.utc).isoformat()

    vec = _embed(_session_embed_text(s))
    qdrant.upsert(
        collection_name=PLAN_SESSION_COLLECTION,
        points=[PointStruct(id=point_id, vector=vec, payload=payload)],
    )
    return s


@router.get("/plan_sessions/{session_id}", response_model=PlanSession)
def get_plan_session(session_id: str):
    """Return the plan session with *session_id*.

    Raises:
        HTTPException(404): if the ID is not a UUID or no such session exists.
    """
    point_id = _as_point_id(session_id)
    if point_id is None:
        # Only UUIDs can be stored (see create_plan_session), so this cannot exist.
        raise HTTPException(status_code=404, detail="not found")

    _ensure_collection(PLAN_SESSION_COLLECTION)

    # The point ID equals the session ID, so fetch directly instead of
    # filter-scrolling on the (unindexed) payload "id" field.
    records = qdrant.retrieve(
        collection_name=PLAN_SESSION_COLLECTION,
        ids=[point_id],
        with_payload=True,
    )
    if not records:
        raise HTTPException(status_code=404, detail="not found")

    found = dict(records[0].payload or {})
    found.setdefault("id", str(records[0].id))
    # pydantic parses the stored ISO-8601 executed_at string back into a datetime.
    return PlanSession(**found)