# -*- coding: utf-8 -*-
# Reconstructed from a git patch (commit 5dbe887ce37c6838d05a6aec7fd154e88132128e,
# author Lars, 2025-08-12) that added llm-api/plan_router.py.
"""
plan_router.py – v0.9.0 (WP-15)

Minimal CRUD for plan templates and plans (POST/GET) plus idempotency via a
content fingerprint. No existing API signatures were changed. Qdrant client
usage follows the style of exercise_router.
"""
import hashlib
import json
import os
from datetime import datetime, timezone
from typing import List, Optional, Dict, Any
from uuid import uuid4

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from qdrant_client.models import (
    PointStruct,
    Filter,
    FieldCondition,
    MatchValue,
    VectorParams,
    Distance,
)

from clients import model, qdrant

router = APIRouter(tags=["plans"])

# -----------------
# Configuration
# -----------------
# PLAN_COLLECTION wins when set to a non-empty value; otherwise
# QDRANT_COLLECTION_PLANS is used, and finally the literal "plans".
PLAN_COLLECTION = os.getenv("PLAN_COLLECTION") or os.getenv("QDRANT_COLLECTION_PLANS", "plans")
PLAN_TEMPLATE_COLLECTION = os.getenv("PLAN_TEMPLATE_COLLECTION", "plan_templates")
# Not referenced by any function in this module.
PLAN_SESSION_COLLECTION = os.getenv("PLAN_SESSION_COLLECTION", "plan_sessions")


# -----------------
# Models
# -----------------
class TemplateSection(BaseModel):
    """One section of a plan template."""

    name: str
    target_minutes: int
    must_keywords: List[str] = []
    forbid_keywords: List[str] = []
    capability_targets: Dict[str, int] = {}


class PlanTemplate(BaseModel):
    """A plan template; ``id`` defaults to a fresh UUID4 string."""

    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str
    discipline: str
    age_group: str
    target_group: str
    total_minutes: int
    sections: List[TemplateSection] = []
    goals: List[str] = []
    equipment_allowed: List[str] = []
    created_by: str
    version: str


class PlanItem(BaseModel):
    """A single exercise reference inside a plan section."""

    exercise_external_id: str
    duration: int
    why: str


class PlanSection(BaseModel):
    """A named section of a plan holding its items."""

    name: str
    items: List[PlanItem] = []
    minutes: int


class Plan(BaseModel):
    """A concrete plan.

    ``id`` defaults to a fresh UUID4 string and ``created_at`` to the current
    time in UTC. ``fingerprint`` is filled in by ``create_plan`` when the
    client does not supply one.
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    template_id: Optional[str] = None
    title: str
    discipline: str
    age_group: str
    target_group: str
    total_minutes: int
    sections: List[PlanSection] = []
    goals: List[str] = []
    capability_summary: Dict[str, int] = {}
    novelty_against_last_n: Optional[float] = None
    fingerprint: Optional[str] = None
    created_by: str
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    source: str = "API"


# -----------------
# Helpers
# -----------------
def _ensure_collection(name: str) -> None:
    """Create the Qdrant collection ``name`` if it does not exist yet.

    The vector size is taken from the embedding model and the distance
    metric is cosine.
    """
    if qdrant.collection_exists(name):
        return
    # Use create_collection rather than the deprecated recreate_collection:
    # the latter deletes an existing collection before creating it, so a
    # collection created concurrently between the check above and this call
    # would be wiped together with its data.
    qdrant.create_collection(
        collection_name=name,
        vectors_config=VectorParams(
            size=model.get_sentence_embedding_dimension(),
            distance=Distance.COSINE,
        ),
    )


def _norm_list(xs: Optional[List[str]]) -> List[str]:
    """Return a cleaned, de-duplicated, sorted copy of ``xs``.

    Each entry is converted with ``str`` and stripped; empty results are
    dropped. Duplicates are detected case-insensitively (``casefold``) and
    the first spelling encountered is kept. The result is sorted
    case-insensitively. ``None`` is treated as an empty list.
    """
    seen, out = set(), []
    for x in xs or []:
        s = str(x).strip()
        k = s.casefold()
        if s and k not in seen:
            seen.add(k)
            out.append(s)
    return sorted(out, key=str.casefold)


def _template_embed_text(tpl: PlanTemplate) -> str:
    """Build the text that is embedded for a template.

    Joins name, discipline, age group, target group, goals and section names
    with ``". "``, skipping empty parts.
    """
    parts = [tpl.name, tpl.discipline, tpl.age_group, tpl.target_group]
    parts += tpl.goals
    parts += [s.name for s in tpl.sections]
    return ". ".join([part for part in parts if part])


def _plan_embed_text(p: Plan) -> str:
    """Build the text that is embedded for a plan.

    Joins title, discipline, age group, target group, goals and section names
    with ``". "``, skipping empty parts.
    """
    parts = [p.title, p.discipline, p.age_group, p.target_group]
    parts += p.goals
    parts += [s.name for s in p.sections]
    # Loop variable deliberately not named ``p`` so it cannot be confused
    # with the plan parameter.
    return ". ".join([part for part in parts if part])


def _embed(text: str):
    """Encode ``text`` (or ``""`` when falsy) with the shared model.

    Returns the result of ``.tolist()`` on the encoder output.
    """
    return model.encode(text or "").tolist()


def _fingerprint_for_plan(p: Plan) -> str:
    """Return the SHA-256 hex digest identifying a plan's content.

    The digest covers the title, total_minutes and, for every item of every
    section in order, its exercise_external_id and duration. The data is
    serialised as JSON with sorted keys before hashing.
    """
    core = {
        "title": p.title,
        "total_minutes": int(p.total_minutes),
        "items": [
            {
                "exercise_external_id": it.exercise_external_id,
                "duration": int(it.duration),
            }
            for sec in p.sections
            for it in (sec.items or [])
        ],
    }
    raw = json.dumps(core, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()


def _get_by_field(collection: str, key: str, value: Any) -> Optional[Dict[str, Any]]:
    """Return the payload of the first point whose ``key`` equals ``value``.

    Scrolls ``collection`` with an exact-match filter and a limit of one.
    The returned dict is a copy of the payload; if it has no ``"id"`` entry
    the point id (as a string) is inserted. Returns ``None`` when nothing
    matches.
    """
    flt = Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))])
    pts, _ = qdrant.scroll(
        collection_name=collection,
        scroll_filter=flt,
        limit=1,
        with_payload=True,
    )
    if not pts:
        return None
    payload = dict(pts[0].payload or {})
    payload.setdefault("id", str(pts[0].id))
    return payload


# -----------------
# Endpoints
# -----------------
@router.post("/plan_templates", response_model=PlanTemplate)
def create_plan_template(t: PlanTemplate):
    """Store a plan template and return it as stored.

    Goals and each section's must/forbid keywords are normalised with
    ``_norm_list`` before the template is embedded and upserted.
    """
    _ensure_collection(PLAN_TEMPLATE_COLLECTION)
    payload = t.model_dump()
    payload["goals"] = _norm_list(payload.get("goals"))
    for section in payload.get("sections") or []:
        section["must_keywords"] = _norm_list(section.get("must_keywords"))
        section["forbid_keywords"] = _norm_list(section.get("forbid_keywords"))
    # Embed and return the normalised template so that the POST response
    # matches what a later GET returns, in line with create_plan, which
    # also embeds and returns normalised data.
    stored = PlanTemplate(**payload)
    vec = _embed(_template_embed_text(stored))
    qdrant.upsert(
        collection_name=PLAN_TEMPLATE_COLLECTION,
        points=[PointStruct(id=str(stored.id), vector=vec, payload=payload)],
    )
    return stored


@router.get("/plan_templates/{tpl_id}", response_model=PlanTemplate)
def get_plan_template(tpl_id: str):
    """Return the template whose payload ``id`` equals ``tpl_id``.

    Raises:
        HTTPException: 404 when no such template exists.
    """
    _ensure_collection(PLAN_TEMPLATE_COLLECTION)
    found = _get_by_field(PLAN_TEMPLATE_COLLECTION, "id", tpl_id)
    if not found:
        raise HTTPException(status_code=404, detail="not found")
    return PlanTemplate(**found)


@router.post("/plan", response_model=Plan)
def create_plan(p: Plan):
    """Store a plan, or return the existing plan with the same fingerprint.

    A fingerprint supplied by the client takes precedence; otherwise it is
    computed with ``_fingerprint_for_plan``. If a stored plan already carries
    that fingerprint it is returned unchanged and nothing is written.
    """
    _ensure_collection(PLAN_COLLECTION)
    # Fingerprint
    fp = _fingerprint_for_plan(p)
    p.fingerprint = p.fingerprint or fp

    # Idempotency
    existing = _get_by_field(PLAN_COLLECTION, "fingerprint", p.fingerprint)
    if existing:
        return Plan(**existing)

    # Normalise
    p.goals = _norm_list(p.goals)
    payload = p.model_dump()
    # Make sure created_at is stored as an ISO 8601 string in UTC
    if isinstance(payload.get("created_at"), datetime):
        payload["created_at"] = payload["created_at"].astimezone(timezone.utc).isoformat()

    vec = _embed(_plan_embed_text(p))
    qdrant.upsert(
        collection_name=PLAN_COLLECTION,
        points=[PointStruct(id=str(p.id), vector=vec, payload=payload)],
    )
    return p


@router.get("/plan/{plan_id}", response_model=Plan)
def get_plan(plan_id: str):
    """Return the plan whose payload ``id`` equals ``plan_id``.

    Raises:
        HTTPException: 404 when no such plan exists.
    """
    _ensure_collection(PLAN_COLLECTION)
    found = _get_by_field(PLAN_COLLECTION, "id", plan_id)
    if not found:
        raise HTTPException(status_code=404, detail="not found")
    # Convert created_at back from its ISO string into a datetime
    if isinstance(found.get("created_at"), str):
        try:
            found["created_at"] = datetime.fromisoformat(found["created_at"])
        except ValueError:
            # Best effort: keep the raw string and let the Plan model's own
            # validation accept or reject it.
            pass
    return Plan(**found)