From 40b115102373d9fbbf775ca91ca45972d55b279a Mon Sep 17 00:00:00 2001 From: Lars Date: Tue, 12 Aug 2025 13:08:29 +0200 Subject: [PATCH] =?UTF-8?q?llm-api/plan=5Fsession=5Frouter.py=20hinzugef?= =?UTF-8?q?=C3=BCgt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm-api/plan_session_router.py | 112 +++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 llm-api/plan_session_router.py diff --git a/llm-api/plan_session_router.py b/llm-api/plan_session_router.py new file mode 100644 index 0000000..ddcfee6 --- /dev/null +++ b/llm-api/plan_session_router.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- +""" +plan_session_router.py – v0.1.0 (WP-15) + +CRUD-Minimum für Plan-Sessions (POST/GET). +Kompatibel zum Qdrant-Client-Stil der bestehenden Router. +""" +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel, Field, conint +from typing import List, Optional, Dict, Any +from uuid import uuid4 +from datetime import datetime, timezone +import os + +from clients import model, qdrant +from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue, VectorParams, Distance + +router = APIRouter(tags=["plan_sessions"]) + +# ----------------- +# Konfiguration +# ----------------- +PLAN_SESSION_COLLECTION = os.getenv("PLAN_SESSION_COLLECTION", "plan_sessions") + +# ----------------- +# Modelle +# ----------------- +class Feedback(BaseModel): + rating: conint(ge=1, le=5) + notes: str + +class PlanSession(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + plan_id: str + executed_at: datetime + location: str + coach: str + group_label: str + feedback: Feedback + used_equipment: List[str] = [] + +# ----------------- +# Helpers +# ----------------- + +def _ensure_collection(name: str): + if not qdrant.collection_exists(name): + qdrant.recreate_collection( + collection_name=name, + vectors_config=VectorParams(size=model.get_sentence_embedding_dimension(), distance=Distance.COSINE), + ) + + +def _norm_list(xs: List[str]) -> List[str]: + seen, out = set(), [] + for x in xs or []: + s = str(x).strip() + k = s.casefold() + if s and k not in seen: + seen.add(k) + out.append(s) + return sorted(out, key=str.casefold) + + +def _session_embed_text(s: PlanSession) -> str: + parts = [s.plan_id, s.location, s.coach, s.group_label, s.feedback.notes] + return ". ".join([p for p in parts if p]) + + +def _embed(text: str): + return model.encode(text or "").tolist() + + +def _get_by_field(collection: str, key: str, value: Any) -> Optional[Dict[str, Any]]: + flt = Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))]) + pts, _ = qdrant.scroll(collection_name=collection, scroll_filter=flt, limit=1, with_payload=True) + if not pts: + return None + payload = dict(pts[0].payload or {}) + payload.setdefault("id", str(pts[0].id)) + return payload + +# ----------------- +# Endpoints +# ----------------- +@router.post("/plan_sessions", response_model=PlanSession) +def create_plan_session(s: PlanSession): + _ensure_collection(PLAN_SESSION_COLLECTION) + # Normalisieren + s.used_equipment = _norm_list(s.used_equipment) + payload = s.model_dump() + # ISO8601 für executed_at sicherstellen + if isinstance(payload.get("executed_at"), datetime): + payload["executed_at"] = payload["executed_at"].astimezone(timezone.utc).isoformat() + + vec = _embed(_session_embed_text(s)) + qdrant.upsert(collection_name=PLAN_SESSION_COLLECTION, points=[PointStruct(id=str(s.id), vector=vec, payload=payload)]) + return s + + +@router.get("/plan_sessions/{session_id}", response_model=PlanSession) +def get_plan_session(session_id: str): + _ensure_collection(PLAN_SESSION_COLLECTION) + found = _get_by_field(PLAN_SESSION_COLLECTION, "id", session_id) + if not found: + raise HTTPException(status_code=404, detail="not found") + if isinstance(found.get("executed_at"), str): + try: + found["executed_at"] = datetime.fromisoformat(found["executed_at"]) + except Exception: + pass + return PlanSession(**found) \ No newline at end of file