llm-api/plan_router.py hinzugefügt
All checks were successful
Deploy Trainer_LLM to llm-node / deploy (push) Successful in 2s

This commit is contained in:
Lars 2025-08-12 10:24:36 +02:00
parent 4552e33cb3
commit 5dbe887ce3

209
llm-api/plan_router.py Normal file
View File

@ -0,0 +1,209 @@
-*- coding: utf-8 -*-
"""
plan_router.py v0.9.0 (WP-15)
Minimal-CRUD für Plan-Templates & Pläne (POST/GET) + Idempotenz via Fingerprint.
Keine bestehenden API-Signaturen geändert. Qdrant-Client-Stil wie exercise_router.
"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from uuid import uuid4
from datetime import datetime, timezone
import hashlib
import json
import os
from clients import model, qdrant
from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue, VectorParams, Distance
router = APIRouter(tags=["plans"])
# -----------------
# Konfiguration
# -----------------
PLAN_COLLECTION = os.getenv("PLAN_COLLECTION") or os.getenv("QDRANT_COLLECTION_PLANS", "plans")
PLAN_TEMPLATE_COLLECTION = os.getenv("PLAN_TEMPLATE_COLLECTION", "plan_templates")
PLAN_SESSION_COLLECTION = os.getenv("PLAN_SESSION_COLLECTION", "plan_sessions")
# -----------------
# Modelle
# -----------------
class TemplateSection(BaseModel):
name: str
target_minutes: int
must_keywords: List[str] = []
forbid_keywords: List[str] = []
capability_targets: Dict[str, int] = {}
class PlanTemplate(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
name: str
discipline: str
age_group: str
target_group: str
total_minutes: int
sections: List[TemplateSection] = []
goals: List[str] = []
equipment_allowed: List[str] = []
created_by: str
version: str
class PlanItem(BaseModel):
exercise_external_id: str
duration: int
why: str
class PlanSection(BaseModel):
name: str
items: List[PlanItem] = []
minutes: int
class Plan(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
template_id: Optional[str] = None
title: str
discipline: str
age_group: str
target_group: str
total_minutes: int
sections: List[PlanSection] = []
goals: List[str] = []
capability_summary: Dict[str, int] = {}
novelty_against_last_n: Optional[float] = None
fingerprint: Optional[str] = None
created_by: str
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
source: str = "API"
# -----------------
# Helpers
# -----------------
def _ensure_collection(name: str):
# Legt Collection an, wenn sie fehlt (analog exercise_router)
if not qdrant.collection_exists(name):
qdrant.recreate_collection(
collection_name=name,
vectors_config=VectorParams(size=model.get_sentence_embedding_dimension(), distance=Distance.COSINE),
)
def _norm_list(xs: List[str]) -> List[str]:
seen, out = set(), []
for x in xs or []:
s = str(x).strip()
k = s.casefold()
if s and k not in seen:
seen.add(k)
out.append(s)
return sorted(out, key=str.casefold)
def _template_embed_text(tpl: PlanTemplate) -> str:
parts = [tpl.name, tpl.discipline, tpl.age_group, tpl.target_group]
parts += tpl.goals
parts += [s.name for s in tpl.sections]
return ". ".join([p for p in parts if p])
def _plan_embed_text(p: Plan) -> str:
parts = [p.title, p.discipline, p.age_group, p.target_group]
parts += p.goals
parts += [s.name for s in p.sections]
return ". ".join([p for p in parts if p])
def _embed(text: str):
return model.encode(text or "").tolist()
def _fingerprint_for_plan(p: Plan) -> str:
"""sha256(title, total_minutes, sections.items.exercise_external_id, sections.items.duration)"""
core = {
"title": p.title,
"total_minutes": int(p.total_minutes),
"items": [
{
"exercise_external_id": it.exercise_external_id,
"duration": int(it.duration),
}
for sec in p.sections
for it in (sec.items or [])
],
}
raw = json.dumps(core, sort_keys=True, ensure_ascii=False)
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
def _get_by_field(collection: str, key: str, value: Any) -> Optional[Dict[str, Any]]:
flt = Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))])
pts, _ = qdrant.scroll(collection_name=collection, scroll_filter=flt, limit=1, with_payload=True)
if not pts:
return None
payload = dict(pts[0].payload or {})
payload.setdefault("id", str(pts[0].id))
return payload
# -----------------
# Endpoints
# -----------------
@router.post("/plan_templates", response_model=PlanTemplate)
def create_plan_template(t: PlanTemplate):
_ensure_collection(PLAN_TEMPLATE_COLLECTION)
payload = t.model_dump()
payload["goals"] = _norm_list(payload.get("goals"))
for s in payload.get("sections", []) or []:
s["must_keywords"] = _norm_list(s.get("must_keywords") or [])
s["forbid_keywords"] = _norm_list(s.get("forbid_keywords") or [])
vec = _embed(_template_embed_text(t))
qdrant.upsert(collection_name=PLAN_TEMPLATE_COLLECTION, points=[PointStruct(id=str(t.id), vector=vec, payload=payload)])
return t
@router.get("/plan_templates/{tpl_id}", response_model=PlanTemplate)
def get_plan_template(tpl_id: str):
_ensure_collection(PLAN_TEMPLATE_COLLECTION)
found = _get_by_field(PLAN_TEMPLATE_COLLECTION, "id", tpl_id)
if not found:
raise HTTPException(status_code=404, detail="not found")
return PlanTemplate(**found)
@router.post("/plan", response_model=Plan)
def create_plan(p: Plan):
_ensure_collection(PLAN_COLLECTION)
# Fingerprint
fp = _fingerprint_for_plan(p)
p.fingerprint = p.fingerprint or fp
# Idempotenz
existing = _get_by_field(PLAN_COLLECTION, "fingerprint", p.fingerprint)
if existing:
return Plan(**existing)
# Normalisieren
p.goals = _norm_list(p.goals)
payload = p.model_dump()
# ISO8601 sicherstellen
if isinstance(payload.get("created_at"), datetime):
payload["created_at"] = payload["created_at"].astimezone(timezone.utc).isoformat()
vec = _embed(_plan_embed_text(p))
qdrant.upsert(collection_name=PLAN_COLLECTION, points=[PointStruct(id=str(p.id), vector=vec, payload=payload)])
return p
@router.get("/plan/{plan_id}", response_model=Plan)
def get_plan(plan_id: str):
_ensure_collection(PLAN_COLLECTION)
found = _get_by_field(PLAN_COLLECTION, "id", plan_id)
if not found:
raise HTTPException(status_code=404, detail="not found")
# created_at zurück in datetime (ISO)
if isinstance(found.get("created_at"), str):
try:
found["created_at"] = datetime.fromisoformat(found["created_at"])
except Exception:
pass
return Plan(**found)