Trainer_LLM/llm-api/plan_router.py
Lars 81473e20eb
All checks were successful
Deploy Trainer_LLM to llm-node / deploy (push) Successful in 2s
llm-api/plan_router.py aktualisiert
2025-08-12 12:46:48 +02:00

232 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
plan_router.py v0.10.0 (WP-15)
Minimal-CRUD für Plan-Templates & Pläne (POST/GET) + Idempotenz via Fingerprint.
Erweiterung ggü. v0.9.0: optionale Section-Felder ideal/supplement + materialisierte
Facettenfelder für Qdrant-Indizes; robustere Payload→Model-Konvertierung.
"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from uuid import uuid4
from datetime import datetime, timezone
import hashlib
import json
import os
from clients import model, qdrant
from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue, VectorParams, Distance
router = APIRouter(tags=["plans"])

# -----------------
# Configuration
# -----------------
# Qdrant collection names, each overridable via environment variables.
# For plans, PLAN_COLLECTION wins when set to a non-empty value; otherwise
# QDRANT_COLLECTION_PLANS is consulted, falling back to "plans".
PLAN_COLLECTION = os.environ.get("PLAN_COLLECTION") or os.environ.get("QDRANT_COLLECTION_PLANS", "plans")
PLAN_TEMPLATE_COLLECTION = os.environ.get("PLAN_TEMPLATE_COLLECTION", "plan_templates")
# NOTE(review): not referenced by any endpoint visible in this file.
PLAN_SESSION_COLLECTION = os.environ.get("PLAN_SESSION_COLLECTION", "plan_sessions")
# -----------------
# Modelle
# -----------------
class TemplateSection(BaseModel):
    """A named section of a ``PlanTemplate``.

    Holds the section's target duration, its keyword lists and capability
    targets. ``create_plan_template`` normalizes all four keyword lists with
    ``_norm_list`` (trim, case-insensitive dedup, sort) before storing them.
    """
    name: str
    target_minutes: int
    # Pydantic copies mutable defaults per instance, so ``= []`` / ``= {}`` is safe here.
    must_keywords: List[str] = []
    ideal_keywords: List[str] = [] # new in v0.10.0 (optional)
    supplement_keywords: List[str] = [] # new in v0.10.0 (optional)
    forbid_keywords: List[str] = []
    capability_targets: Dict[str, int] = {}  # NOTE(review): presumably capability name -> target value; confirm
class PlanTemplate(BaseModel):
    """A plan template.

    ``create_plan_template`` persists it in ``PLAN_TEMPLATE_COLLECTION``,
    using ``id`` as the Qdrant point id.
    """
    # Random UUID4 string unless supplied by the client.
    # NOTE: Qdrant point ids must be UUIDs or unsigned integers, so a
    # client-supplied non-UUID id makes the upsert fail.
    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str
    discipline: str
    age_group: str
    target_group: str
    total_minutes: int
    sections: List[TemplateSection] = []
    goals: List[str] = []
    equipment_allowed: List[str] = []
    created_by: str
    version: str  # free-form string; no format is enforced here
class PlanItem(BaseModel):
    """A single exercise entry inside a ``PlanSection``."""
    exercise_external_id: str  # part of the plan fingerprint (see _fingerprint_for_plan)
    duration: int  # NOTE(review): unit not stated here - presumably minutes; confirm. Part of the fingerprint.
    why: str  # NOTE(review): presumably a free-text rationale for choosing the exercise
class PlanSection(BaseModel):
    """A named section of a ``Plan`` holding an ordered list of exercise items."""
    name: str
    items: List[PlanItem] = []  # order matters: it feeds the plan fingerprint
    minutes: int  # NOTE(review): not checked against the sum of item durations anywhere in this file
class Plan(BaseModel):
    """A concrete plan, persisted by ``create_plan`` in ``PLAN_COLLECTION``.

    ``id`` is used as the Qdrant point id. ``fingerprint`` drives idempotent creation.
    """
    # Random UUID4 string unless supplied.
    # NOTE: Qdrant point ids must be UUIDs or unsigned integers.
    id: str = Field(default_factory=lambda: str(uuid4()))
    template_id: Optional[str] = None  # NOTE(review): presumably the id of the source PlanTemplate
    title: str
    discipline: str
    age_group: str
    target_group: str
    total_minutes: int
    sections: List[PlanSection] = []
    goals: List[str] = []
    capability_summary: Dict[str, int] = {}
    novelty_against_last_n: Optional[float] = None  # not computed anywhere in this file
    fingerprint: Optional[str] = None  # filled by create_plan when absent; a client-supplied value is trusted as-is
    created_by: str
    # Timezone-aware UTC "now" by default; create_plan stores it as an ISO-8601 string.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    source: str = "API"
# -----------------
# Helpers
# -----------------
def _ensure_collection(name: str):
    """Create the Qdrant collection *name* if it does not exist yet.

    The vector size comes from the embedding model and cosine distance is
    used, so vectors produced by ``_embed`` fit the collection.

    ``create_collection`` is used rather than ``recreate_collection``. The
    latter first deletes an existing collection, so if another worker created
    the collection (and wrote points) between the existence check and the call,
    that data would be wiped. ``recreate_collection`` is also deprecated in
    qdrant-client.

    Raises:
        Whatever the Qdrant client raises if creation fails and the collection
        still does not exist afterwards.
    """
    if qdrant.collection_exists(name):
        return
    try:
        qdrant.create_collection(
            collection_name=name,
            vectors_config=VectorParams(size=model.get_sentence_embedding_dimension(), distance=Distance.COSINE),
        )
    except Exception:
        # The client raises transport-specific errors (REST vs gRPC), so catch
        # broadly. Losing the creation race to a concurrent request is fine;
        # anything else is re-raised.
        if not qdrant.collection_exists(name):
            raise
def _norm_list(xs: List[str]) -> List[str]:
seen, out = set(), []
for x in xs or []:
s = str(x).strip()
k = s.casefold()
if s and k not in seen:
seen.add(k)
out.append(s)
return sorted(out, key=str.casefold)
def _template_embed_text(tpl: PlanTemplate) -> str:
    """Build the text that is embedded for a template.

    It contains name, discipline, age group, target group, all goals and all
    section names, joined with ". ". Empty values are skipped.
    """
    pieces = [tpl.name, tpl.discipline, tpl.age_group, tpl.target_group, *tpl.goals]
    pieces.extend(section.name for section in tpl.sections)
    return ". ".join(filter(None, pieces))
def _plan_embed_text(p: Plan) -> str:
    """Build the text that is embedded for a plan.

    It contains title, discipline, age group, target group, all goals and all
    section names, joined with ". ". Empty values are skipped.
    """
    pieces = [p.title, p.discipline, p.age_group, p.target_group, *p.goals]
    pieces.extend(section.name for section in p.sections)
    return ". ".join(filter(None, pieces))
def _embed(text: str):
    """Encode *text* with the shared embedding model and return it as a plain list.

    A falsy *text* (``None`` or ``""``) is encoded as the empty string.
    """
    source_text = text if text else ""
    vector = model.encode(source_text)
    return vector.tolist()
def _fingerprint_for_plan(p: Plan) -> str:
    """Return a SHA-256 hex digest identifying the content of plan *p*.

    Covered fields: ``title``, ``total_minutes`` and, for every item of every
    section in order, ``exercise_external_id`` and ``duration``. Section
    names, goals and other metadata do not influence the digest. The JSON is
    dumped with sorted keys so the digest does not depend on dict order.
    """
    flattened_items = []
    for section in p.sections:
        for item in section.items or []:
            flattened_items.append(
                {"exercise_external_id": item.exercise_external_id, "duration": int(item.duration)}
            )
    canonical = json.dumps(
        {"title": p.title, "total_minutes": int(p.total_minutes), "items": flattened_items},
        sort_keys=True,
        ensure_ascii=False,
    )
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
def _get_by_field(collection: str, key: str, value: Any) -> Optional[Dict[str, Any]]:
    """Return the payload of a point in *collection* whose payload field *key* equals *value*.

    Only the first point returned by ``scroll`` is considered. The result is a
    copy of its payload, with the point id added under ``"id"`` if the payload
    lacks that key. Returns ``None`` when no point matches.
    """
    condition = FieldCondition(key=key, match=MatchValue(value=value))
    hits, _next_offset = qdrant.scroll(
        collection_name=collection,
        scroll_filter=Filter(must=[condition]),
        limit=1,
        with_payload=True,
    )
    if hits:
        first = hits[0]
        record = dict(first.payload or {})
        record.setdefault("id", str(first.id))
        return record
    return None
def _as_model(model_cls, payload: Dict[str, Any]):
"""Filtert unbekannte Felder heraus (Pydantic v1/v2 kompatibel)."""
fields = getattr(model_cls, "model_fields", None) or getattr(model_cls, "__fields__", {})
allowed = set(fields.keys())
data = {k: payload[k] for k in payload.keys() if k in allowed}
return model_cls(**data)
# -----------------
# Endpoints
# -----------------
# Keyword-list fields of a TemplateSection that are normalized and materialized as facets.
_SECTION_KEYWORD_FIELDS = ("must_keywords", "ideal_keywords", "supplement_keywords", "forbid_keywords")


def _normalize_template(t: PlanTemplate) -> None:
    """Normalize the goals and all per-section keyword lists of *t* in place via ``_norm_list``."""
    t.goals = _norm_list(t.goals)
    for sec in t.sections:
        for key in _SECTION_KEYWORD_FIELDS:
            setattr(sec, key, _norm_list(getattr(sec, key)))


def _template_facets(sections: List[Dict[str, Any]]) -> Dict[str, List[str]]:
    """Flatten section names and keyword lists into top-level lists (stable KEYWORD indexes in Qdrant)."""
    facets = {"section_names": _norm_list([s.get("name", "") for s in sections])}
    for key in _SECTION_KEYWORD_FIELDS:
        facets[f"section_{key}"] = _norm_list([kw for s in sections for kw in (s.get(key) or [])])
    return facets


@router.post("/plan_templates", response_model=PlanTemplate)
def create_plan_template(t: PlanTemplate):
    """Store plan template *t* in ``PLAN_TEMPLATE_COLLECTION`` and return it.

    Goals and all section keyword lists are normalized on the model itself,
    before dumping, just as ``create_plan`` does for plan goals. This way the
    stored payload, the embedding text and the POST response all agree, and
    they match what ``get_plan_template`` later returns. The payload also gets
    the materialized facet fields ``section_names`` and
    ``section_{must,ideal,supplement,forbid}_keywords``. The point id is ``t.id``.
    """
    _ensure_collection(PLAN_TEMPLATE_COLLECTION)
    _normalize_template(t)
    payload = t.model_dump()
    payload.update(_template_facets(payload.get("sections") or []))
    vec = _embed(_template_embed_text(t))
    qdrant.upsert(collection_name=PLAN_TEMPLATE_COLLECTION, points=[PointStruct(id=str(t.id), vector=vec, payload=payload)])
    return t
@router.get("/plan_templates/{tpl_id}", response_model=PlanTemplate)
def get_plan_template(tpl_id: str):
    """Return the plan template whose stored payload ``id`` equals *tpl_id*.

    Unknown payload keys (e.g. the materialized facet fields) are dropped.

    Raises:
        HTTPException: 404 when no template matches.
    """
    _ensure_collection(PLAN_TEMPLATE_COLLECTION)
    record = _get_by_field(PLAN_TEMPLATE_COLLECTION, "id", tpl_id)
    if record:
        return _as_model(PlanTemplate, record)
    raise HTTPException(status_code=404, detail="not found")
@router.post("/plan", response_model=Plan)
def create_plan(p: Plan):
    """Persist plan *p* idempotently and return it.

    If the request has no fingerprint, the one computed by
    ``_fingerprint_for_plan`` is assigned; a client-supplied fingerprint is kept
    as-is. If a plan with that fingerprint already exists in
    ``PLAN_COLLECTION``, the stored plan is returned and nothing is written.
    Otherwise goals are normalized, ``created_at`` is stored as a UTC ISO-8601
    string, and the plan is upserted with its embedding under point id ``p.id``.
    """
    _ensure_collection(PLAN_COLLECTION)

    computed_fp = _fingerprint_for_plan(p)
    p.fingerprint = p.fingerprint if p.fingerprint else computed_fp

    duplicate = _get_by_field(PLAN_COLLECTION, "fingerprint", p.fingerprint)
    if duplicate:
        return _as_model(Plan, duplicate)

    p.goals = _norm_list(p.goals)
    record = p.model_dump()
    stamp = record.get("created_at")
    if isinstance(stamp, datetime):
        record["created_at"] = stamp.astimezone(timezone.utc).isoformat()

    embedding = _embed(_plan_embed_text(p))
    point = PointStruct(id=str(p.id), vector=embedding, payload=record)
    qdrant.upsert(collection_name=PLAN_COLLECTION, points=[point])
    return p
@router.get("/plan/{plan_id}", response_model=Plan)
def get_plan(plan_id: str):
    """Return the plan whose stored payload ``id`` equals *plan_id*.

    Raises:
        HTTPException: 404 when no plan matches.
    """
    _ensure_collection(PLAN_COLLECTION)
    found = _get_by_field(PLAN_COLLECTION, "id", plan_id)
    if not found:
        raise HTTPException(status_code=404, detail="not found")
    # create_plan stores created_at as an ISO-8601 string; turn it back into a datetime.
    created_at = found.get("created_at")
    if isinstance(created_at, str):
        try:
            found["created_at"] = datetime.fromisoformat(created_at)
        except ValueError:
            # Best effort only. fromisoformat is stricter than the model's own
            # parsing (e.g. it rejects a "Z" suffix before Python 3.11), so leave
            # the string for the Plan model to validate.
            pass
    return _as_model(Plan, found)