Trainer_LLM/llm-api/plan_router.py
Lars 36c82ac942
All checks were successful
Deploy Trainer_LLM to llm-node / deploy (push) Successful in 1s
llm-api/plan_router.py aktualisiert
2025-08-12 16:25:32 +02:00

263 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
plan_router.py v0.11.0 (WP-15)
Minimal CRUD for plan templates & plans (POST/GET) + idempotency via fingerprint.

Endpoints:
- POST /plan_templates, GET /plan_templates/{tpl_id}
- POST /plan, GET /plan/{plan_id}

Extensions:
- optional section fields ideal/supplement + materialized facet fields
- reference validation: template_id (must exist when it is set)
- optional strict mode: env var PLAN_STRICT_EXERCISES checks exercise_external_id
  against the exercise collection
"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from uuid import uuid4
from datetime import datetime, timezone
import hashlib
import json
import os
from clients import model, qdrant
from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue, VectorParams, Distance
router = APIRouter(tags=["plans"])

# -----------------
# Configuration
# -----------------
# Qdrant collection names, each overridable through an environment variable.
# For plans, PLAN_COLLECTION wins when it is set to a non-empty value;
# otherwise QDRANT_COLLECTION_PLANS is consulted, falling back to "plans".
PLAN_COLLECTION = (
    os.getenv("PLAN_COLLECTION")
    or os.getenv("QDRANT_COLLECTION_PLANS", "plans")
)
PLAN_TEMPLATE_COLLECTION = os.getenv("PLAN_TEMPLATE_COLLECTION", "plan_templates")
# Not used by the endpoints defined in this file.
PLAN_SESSION_COLLECTION = os.getenv("PLAN_SESSION_COLLECTION", "plan_sessions")
EXERCISE_COLLECTION = os.getenv("EXERCISE_COLLECTION", "exercises")
# -----------------
# Modelle
# -----------------
class TemplateSection(BaseModel):
    """One section of a plan template.

    ``create_plan_template`` normalizes the four keyword lists and also copies
    them into flattened ``section_*_keywords`` payload fields.
    """
    name: str
    target_minutes: int  # presumably the planned length of the section in minutes — confirm
    must_keywords: List[str] = []
    ideal_keywords: List[str] = []  # optional
    supplement_keywords: List[str] = []  # optional
    forbid_keywords: List[str] = []
    # str -> int mapping (presumably capability name -> target level); its
    # meaning is not defined in this module.
    capability_targets: Dict[str, int] = {}
class PlanTemplate(BaseModel):
    """A reusable plan template, stored in the PLAN_TEMPLATE_COLLECTION Qdrant collection."""
    # Random UUID4 string unless the client supplies its own id; lookups use
    # this value via the "id" payload field.
    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str
    discipline: str
    age_group: str
    target_group: str
    total_minutes: int
    sections: List[TemplateSection] = []
    goals: List[str] = []  # normalized by create_plan_template (_norm_list)
    equipment_allowed: List[str] = []
    created_by: str
    version: str
class PlanItem(BaseModel):
    """A single exercise entry inside a plan section."""
    # Exercise reference; only checked against EXERCISE_COLLECTION ("external_id")
    # when PLAN_STRICT_EXERCISES is enabled. Part of the plan fingerprint.
    exercise_external_id: str
    duration: int  # unit not stated here — presumably minutes; confirm. Part of the fingerprint.
    why: str  # presumably a free-text rationale for the item; not part of the fingerprint
class PlanSection(BaseModel):
    """A named section of a concrete plan, holding its exercise items."""
    name: str
    items: List[PlanItem] = []
    minutes: int  # presumably the planned length of the section in minutes — confirm
class Plan(BaseModel):
    """A concrete plan, written by ``create_plan`` to PLAN_COLLECTION and read by ``get_plan``."""
    id: str = Field(default_factory=lambda: str(uuid4()))
    # Optional reference to a PlanTemplate; create_plan rejects unknown ids with 422.
    template_id: Optional[str] = None
    title: str
    discipline: str
    age_group: str
    target_group: str
    total_minutes: int
    sections: List[PlanSection] = []
    goals: List[str] = []
    capability_summary: Dict[str, int] = {}
    novelty_against_last_n: Optional[float] = None  # not computed by the endpoints in this file
    # SHA-256 hex digest used for idempotent creation; create_plan computes it
    # when absent. NOTE(review): a client-supplied value is trusted as-is.
    fingerprint: Optional[str] = None
    created_by: str
    # Defaults to the current time as a timezone-aware UTC datetime.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    source: str = "API"
# -----------------
# Helpers
# -----------------
def _ensure_collection(name: str):
    """Create the Qdrant collection *name* if it does not exist yet.

    The vector size comes from ``model.get_sentence_embedding_dimension()``
    and cosine distance is used. An existing collection is never modified.

    The previous implementation called ``recreate_collection``. That call
    deletes the collection before creating it, so two requests racing
    between the existence check and the creation could wipe data the first
    one had already written. ``create_collection`` never deletes anything.
    If another request wins the race, the resulting error is swallowed only
    when the collection now exists.
    """
    if qdrant.collection_exists(name):
        return
    try:
        qdrant.create_collection(
            collection_name=name,
            vectors_config=VectorParams(
                size=model.get_sentence_embedding_dimension(),
                distance=Distance.COSINE,
            ),
        )
    except Exception:
        # A concurrent request may have created it in the meantime; any other
        # failure (collection still missing) is propagated.
        if not qdrant.collection_exists(name):
            raise
def _norm_list(xs: List[str]) -> List[str]:
seen, out = set(), []
for x in xs or []:
s = str(x).strip()
k = s.casefold()
if s and k not in seen:
seen.add(k)
out.append(s)
return sorted(out, key=str.casefold)
def _template_embed_text(tpl: PlanTemplate) -> str:
    """Build the text that gets embedded for a plan template.

    Joins name, discipline, age group, target group, every goal and every
    section name with ". ", skipping empty values.
    """
    pieces = [tpl.name, tpl.discipline, tpl.age_group, tpl.target_group, *tpl.goals]
    pieces.extend(section.name for section in tpl.sections)
    return ". ".join(filter(None, pieces))
def _plan_embed_text(p: Plan) -> str:
    """Build the text that gets embedded for a plan.

    Joins title, discipline, age group, target group, every goal and every
    section name with ". ", skipping empty values.
    """
    pieces = [p.title, p.discipline, p.age_group, p.target_group, *p.goals]
    pieces.extend(section.name for section in p.sections)
    return ". ".join(filter(None, pieces))
def _embed(text: str):
    """Embed *text* with the ``model`` imported from ``clients``.

    A falsy *text* (``None`` or ``""``) is encoded as the empty string. The
    encoder output is converted with ``.tolist()`` (presumably a NumPy vector
    turned into a list of floats) before being handed to Qdrant.
    """
    vector = model.encode(text if text else "")
    return vector.tolist()
def _fingerprint_for_plan(p: Plan) -> str:
    """Return a content fingerprint of plan *p* as a hex SHA-256 digest.

    Hashed data: ``title``, ``total_minutes`` and, in section/item order, each
    item's ``exercise_external_id`` and ``duration``. It is serialized as JSON
    with sorted keys (UTF-8, non-ASCII kept as-is). Goals, section names and
    ``why`` texts do not influence the digest. Neither does how items are
    split across sections.
    """
    item_refs = []
    for section in p.sections:
        for item in section.items or []:
            item_refs.append({
                "duration": int(item.duration),
                "exercise_external_id": item.exercise_external_id,
            })
    canonical = json.dumps(
        {"items": item_refs, "title": p.title, "total_minutes": int(p.total_minutes)},
        ensure_ascii=False,
        sort_keys=True,
    )
    digest = hashlib.sha256(canonical.encode("utf-8"))
    return digest.hexdigest()
def _get_by_field(collection: str, key: str, value: Any) -> Optional[Dict[str, Any]]:
    """Return the payload of one point whose payload *key* equals *value*.

    At most one point is fetched. The returned dict is a copy of its payload,
    with the point id (as a string) filled in under "id" unless the payload
    already carries one. Returns ``None`` when nothing matches.
    """
    condition = FieldCondition(key=key, match=MatchValue(value=value))
    points, _next_offset = qdrant.scroll(
        collection_name=collection,
        scroll_filter=Filter(must=[condition]),
        limit=1,
        with_payload=True,
    )
    if not points:
        return None
    hit = points[0]
    result = dict(hit.payload or {})
    result.setdefault("id", str(hit.id))
    return result
def _as_model(model_cls, payload: Dict[str, Any]):
"""Filtert unbekannte Felder heraus (Pydantic v1/v2 kompatibel)."""
fields = getattr(model_cls, "model_fields", None) or getattr(model_cls, "__fields__", {})
allowed = set(fields.keys())
data = {k: payload[k] for k in payload.keys() if k in allowed}
return model_cls(**data)
def _truthy(val: Optional[str]) -> bool:
return str(val or "").strip().lower() in {"1", "true", "yes", "on"}
def _exists_in_collection(collection: str, key: str, value: Any) -> bool:
    """Return True if some point in *collection* has payload *key* equal to *value*.

    Fetches at most one point and no payload.
    NOTE(review): the scroll presumably fails when *collection* itself does
    not exist; callers should check for that first.
    """
    match_filter = Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))])
    hits = qdrant.scroll(
        collection_name=collection,
        scroll_filter=match_filter,
        limit=1,
        with_payload=False,
    )[0]
    return len(hits) > 0
# -----------------
# Endpoints
# -----------------
@router.post("/plan_templates", response_model=PlanTemplate)
def create_plan_template(t: PlanTemplate):
    """Store a plan template and return it as stored.

    Normalization (``_norm_list``: stripped, de-duplicated case-insensitively,
    sorted) is applied to ``goals`` and to the four keyword lists of every
    section. The Qdrant payload additionally gets materialized facet fields:
    ``section_names`` and ``section_{must,ideal,supplement,forbid}_keywords``,
    which flatten the per-section values. These extra fields are not part of
    ``PlanTemplate``.

    Qdrant accepts only unsigned integers or UUIDs as point ids. A ``t.id``
    that parses as a UUID is used directly; any other string is mapped to a
    deterministic UUIDv5. Before, such ids made the upsert fail. Reads go
    through the "id" payload field, which keeps the original value, and
    re-posting the same id overwrites the stored template.

    Returns:
        The normalized template, i.e. the same data ``get_plan_template`` returns.
    """
    import uuid

    keyword_keys = ("must_keywords", "ideal_keywords", "supplement_keywords", "forbid_keywords")

    _ensure_collection(PLAN_TEMPLATE_COLLECTION)
    payload = t.model_dump()

    # Normalization
    payload["goals"] = _norm_list(payload.get("goals"))
    sections = payload.get("sections") or []
    for s in sections:
        for key in keyword_keys:
            s[key] = _norm_list(s.get(key) or [])

    # Materialized facet fields: flat lists meant for KEYWORD payload indexes
    # in Qdrant (index creation is not visible here).
    payload["section_names"] = _norm_list([s.get("name", "") for s in sections])
    for key in keyword_keys:
        payload[f"section_{key}"] = _norm_list([kw for s in sections for kw in (s.get(key) or [])])

    # Respond (and embed) with the normalized data rather than the raw request body.
    stored = _as_model(PlanTemplate, payload)

    raw_id = str(stored.id)
    try:
        point_id = str(uuid.UUID(raw_id))
    except ValueError:
        point_id = str(uuid.uuid5(uuid.NAMESPACE_URL, raw_id))

    vec = _embed(_template_embed_text(stored))
    qdrant.upsert(
        collection_name=PLAN_TEMPLATE_COLLECTION,
        points=[PointStruct(id=point_id, vector=vec, payload=payload)],
    )
    return stored
@router.get("/plan_templates/{tpl_id}", response_model=PlanTemplate)
def get_plan_template(tpl_id: str):
    """Return the plan template whose "id" payload field equals *tpl_id*.

    Creates the template collection first if it is missing.

    Raises:
        HTTPException(404): no template with that id is stored.
    """
    _ensure_collection(PLAN_TEMPLATE_COLLECTION)
    payload = _get_by_field(PLAN_TEMPLATE_COLLECTION, "id", tpl_id)
    if payload is None:
        raise HTTPException(status_code=404, detail="not found")
    return _as_model(PlanTemplate, payload)
@router.post("/plan", response_model=Plan)
def create_plan(p: Plan):
    """Create a plan, or return the stored plan that has the same fingerprint.

    Steps:
      1. If ``template_id`` is set, it must exist in the template collection.
      2. If env var ``PLAN_STRICT_EXERCISES`` is truthy, every non-blank
         ``exercise_external_id`` must exist in the exercise collection.
      3. Compute the fingerprint (a client-supplied one takes precedence).
      4. Idempotency: return the already stored plan with that fingerprint.
      5. Normalize goals, store ``created_at`` as UTC ISO-8601 and upsert.

    Raises:
        HTTPException(422): unknown ``template_id``, or (strict mode) unknown
            exercise ids, listed sorted and de-duplicated under "missing".
    """
    import uuid

    _ensure_collection(PLAN_COLLECTION)

    # 1) Template reference (only when set). If the template collection does
    #    not exist, no template can exist; checking that first yields a 422
    #    instead of a Qdrant "collection not found" failure (-> 500).
    if p.template_id:
        template_known = qdrant.collection_exists(PLAN_TEMPLATE_COLLECTION) and _exists_in_collection(
            PLAN_TEMPLATE_COLLECTION, "id", p.template_id
        )
        if not template_known:
            raise HTTPException(status_code=422, detail=f"Unknown template_id: {p.template_id}")

    # 2) Optional strict mode: every referenced exercise must exist.
    if _truthy(os.getenv("PLAN_STRICT_EXERCISES")):
        # A set, so each distinct id costs only one Qdrant lookup.
        referenced = {
            (it.exercise_external_id or "").strip()
            for sec in p.sections or []
            for it in sec.items or []
        }
        referenced.discard("")  # blank ids are not checked
        if referenced:
            exercises_exist = qdrant.collection_exists(EXERCISE_COLLECTION)
            missing = sorted(
                exid
                for exid in referenced
                if not (exercises_exist and _exists_in_collection(EXERCISE_COLLECTION, "external_id", exid))
            )
            if missing:
                raise HTTPException(status_code=422, detail={
                    "error": "unknown exercise_external_id",
                    "missing": missing,
                })

    # 3) Fingerprint.
    # NOTE(review): a client-supplied fingerprint wins over the computed one;
    # a stale value (e.g. echoed back from GET /plan) makes step 4 return the
    # old plan. Confirm this is intended.
    fp = _fingerprint_for_plan(p)
    p.fingerprint = p.fingerprint or fp

    # 4) Idempotency: an identical plan is already stored -> return it.
    existing = _get_by_field(PLAN_COLLECTION, "fingerprint", p.fingerprint)
    if existing:
        return _as_model(Plan, existing)

    # 5) Normalize & upsert
    p.goals = _norm_list(p.goals)
    payload = p.model_dump()
    created_at = payload.get("created_at")
    if isinstance(created_at, datetime):
        if created_at.tzinfo is None:
            # Treat naive timestamps as UTC. astimezone() would otherwise read
            # them in the server's local timezone and shift the stored value.
            created_at = created_at.replace(tzinfo=timezone.utc)
        payload["created_at"] = created_at.astimezone(timezone.utc).isoformat()

    # Qdrant point ids must be unsigned integers or UUIDs: use a UUID id as-is
    # and map any other string deterministically via UUIDv5. Lookups use the
    # "id" payload field, which keeps the original value.
    raw_id = str(p.id)
    try:
        point_id = str(uuid.UUID(raw_id))
    except ValueError:
        point_id = str(uuid.uuid5(uuid.NAMESPACE_URL, raw_id))

    vec = _embed(_plan_embed_text(p))
    qdrant.upsert(
        collection_name=PLAN_COLLECTION,
        points=[PointStruct(id=point_id, vector=vec, payload=payload)],
    )
    return p
@router.get("/plan/{plan_id}", response_model=Plan)
def get_plan(plan_id: str):
    """Return the plan whose "id" payload field equals *plan_id*.

    ``created_at`` is stored as an ISO-8601 string. It is parsed back into a
    ``datetime`` here when ``datetime.fromisoformat`` accepts it; otherwise
    the raw string is handed to the ``Plan`` model unchanged.

    Raises:
        HTTPException(404): no plan with that id is stored.
    """
    _ensure_collection(PLAN_COLLECTION)
    found = _get_by_field(PLAN_COLLECTION, "id", plan_id)
    if not found:
        raise HTTPException(status_code=404, detail="not found")
    raw_created = found.get("created_at")
    if isinstance(raw_created, str):
        try:
            found["created_at"] = datetime.fromisoformat(raw_created)
        except ValueError:
            # Unparseable here (e.g. a trailing "Z" before Python 3.11): leave
            # the string as-is for the model's own validation.
            pass
    return _as_model(Plan, found)