All checks were successful
Deploy Trainer_LLM to llm-node / deploy (push) Successful in 2s
479 lines
19 KiB
Python
479 lines
19 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
plan_router.py – v0.13.1 (WP-15)
|
||
|
||
Minimal-CRUD + List/Filter für Templates & Pläne.
|
||
|
||
Änderungen ggü. v0.13.0
|
||
- Serverseitiger Zeitfenster-Filter über `created_at_ts` (FLOAT) bleibt erhalten.
|
||
- Lokaler Fallback-Zeitfilter wird DEAKTIVIERT, sobald ein serverseitiger Range aktiv ist
|
||
(verhindert false negatives).
|
||
- `plan_section_names` wird beim POST /plan materialisiert und für Filters genutzt.
|
||
"""
|
||
from fastapi import APIRouter, HTTPException, Query
|
||
from pydantic import BaseModel, Field
|
||
from typing import List, Optional, Dict, Any
|
||
from uuid import uuid4
|
||
from datetime import datetime, timezone
|
||
import hashlib
|
||
import json
|
||
import os
|
||
|
||
from clients import model, qdrant
|
||
from qdrant_client.models import (
|
||
PointStruct, Filter, FieldCondition, MatchValue,
|
||
VectorParams, Distance, Range
|
||
)
|
||
|
||
# Shared router for every template/plan endpoint defined below (OpenAPI tag "plans").
router: APIRouter = APIRouter(tags=["plans"])
|
||
|
||
# -----------------
|
||
# Konfiguration
|
||
# -----------------
|
||
# Qdrant collection names, each overridable through the environment.
# Plans: PLAN_COLLECTION wins when set and non-empty, otherwise
# QDRANT_COLLECTION_PLANS, otherwise "plans".
PLAN_COLLECTION = (
    os.getenv("PLAN_COLLECTION")
    or os.getenv("QDRANT_COLLECTION_PLANS", "plans")
)
PLAN_TEMPLATE_COLLECTION = os.getenv("PLAN_TEMPLATE_COLLECTION", "plan_templates")
# NOTE(review): not referenced by the endpoints visible in this file.
PLAN_SESSION_COLLECTION = os.getenv("PLAN_SESSION_COLLECTION", "plan_sessions")
# Looked up by the optional strict-mode exercise check in POST /plan.
EXERCISE_COLLECTION = os.getenv("EXERCISE_COLLECTION", "exercises")
|
||
|
||
# -----------------
|
||
# Modelle
|
||
# -----------------
|
||
class TemplateSection(BaseModel):
    # One section of a plan template (e.g. "Warmup") with its planned duration.
    name: str
    target_minutes: int
    # Keyword buckets. POST /plan_templates normalises each list (trim,
    # case-insensitive de-duplication, sort) and flattens all of them into the
    # top-level section_* facet fields used for filtering.
    must_keywords: List[str] = []
    ideal_keywords: List[str] = []  # desirable
    supplement_keywords: List[str] = []  # supplementary
    forbid_keywords: List[str] = []
    # Capability name -> target value; not interpreted by the code shown in this router.
    capability_targets: Dict[str, int] = {}
|
||
|
||
class PlanTemplate(BaseModel):
    # Reusable plan structure, stored in PLAN_TEMPLATE_COLLECTION.
    # A random UUID4 string is generated when the client does not send an id.
    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str
    # discipline / age_group / target_group are exact-match filters in GET /plan_templates.
    discipline: str
    age_group: str
    target_group: str
    total_minutes: int
    sections: List[TemplateSection] = []
    goals: List[str] = []  # normalised with _norm_list on create
    equipment_allowed: List[str] = []
    created_by: str
    version: str
|
||
|
||
class PlanItem(BaseModel):
    # One exercise slot inside a plan section.
    # exercise_external_id is checked against EXERCISE_COLLECTION ("external_id")
    # when PLAN_STRICT_EXERCISES is enabled.
    exercise_external_id: str
    duration: int  # part of the plan fingerprint
    why: str  # free-text rationale for choosing this exercise
|
||
|
||
class PlanSection(BaseModel):
    # A named block of a concrete plan; its name is materialised into
    # plan_section_names on POST /plan for exact-match filtering.
    name: str
    items: List[PlanItem] = []
    minutes: int
|
||
|
||
class Plan(BaseModel):
    # A concrete training plan, stored in PLAN_COLLECTION.
    # A random UUID4 string is generated when the client does not send an id.
    id: str = Field(default_factory=lambda: str(uuid4()))
    template_id: Optional[str] = None  # must reference an existing template when set
    title: str
    discipline: str
    age_group: str
    target_group: str
    total_minutes: int
    sections: List[PlanSection] = []
    goals: List[str] = []
    capability_summary: Dict[str, int] = {}
    # Not computed by the code shown in this router.
    novelty_against_last_n: Optional[float] = None
    # sha256 content hash used for POST /plan idempotency; filled in server-side if omitted.
    fingerprint: Optional[str] = None
    created_by: str
    # Defaults to the current time as an aware UTC datetime.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    source: str = "API"
|
||
|
||
class PlanTemplateList(BaseModel):
    # One page of GET /plan_templates.
    items: List[PlanTemplate]
    limit: int  # echo of the requested page size
    offset: int  # echo of the requested start offset
    count: int  # number of items in this page (not a grand total)
|
||
|
||
class PlanList(BaseModel):
    # One page of GET /plans.
    items: List[Plan]
    limit: int  # echo of the requested page size
    offset: int  # echo of the requested start offset
    count: int  # number of items in this page (not a grand total)
|
||
|
||
# -----------------
|
||
# Helpers
|
||
# -----------------
|
||
|
||
def _ensure_collection(name: str):
    """Create the Qdrant collection ``name`` if it does not exist yet.

    The collection is configured like the one in exercise_router: cosine
    distance, with a vector size taken from the shared embedding model so it
    matches the vectors produced by ``_embed``.

    ``create_collection`` is used instead of ``recreate_collection``. The latter
    deletes an existing collection first, so if another worker created (and
    already filled) the collection between the existence check and this call,
    that data would be wiped. ``recreate_collection`` is also deprecated in
    qdrant-client.
    """
    if qdrant.collection_exists(name):
        return
    try:
        qdrant.create_collection(
            collection_name=name,
            vectors_config=VectorParams(size=model.get_sentence_embedding_dimension(), distance=Distance.COSINE),
        )
    except Exception:
        # Most likely another worker won the creation race; that is fine as long
        # as the collection exists now. Anything else is a real error.
        if not qdrant.collection_exists(name):
            raise
|
||
|
||
def _norm_list(xs: List[str]) -> List[str]:
|
||
"""Trimmen, casefolded deduplizieren, stabil sortieren."""
|
||
seen, out = set(), []
|
||
for x in xs or []:
|
||
s = str(x).strip()
|
||
k = s.casefold()
|
||
if s and k not in seen:
|
||
seen.add(k)
|
||
out.append(s)
|
||
return sorted(out, key=str.casefold)
|
||
|
||
def _template_embed_text(tpl: PlanTemplate) -> str:
    """Build the text that is embedded for a template.

    Name, discipline, age group, target group, all goals and all section names
    are joined with ". "; empty values are skipped.
    """
    pieces = [
        tpl.name,
        tpl.discipline,
        tpl.age_group,
        tpl.target_group,
        *tpl.goals,
        *(section.name for section in tpl.sections),
    ]
    return ". ".join(piece for piece in pieces if piece)
|
||
|
||
def _plan_embed_text(p: Plan) -> str:
    """Build the text that is embedded for a plan.

    Title, discipline, age group, target group, all goals and all section names
    are joined with ". "; empty values are skipped.
    """
    fragments: List[str] = [p.title, p.discipline, p.age_group, p.target_group]
    fragments.extend(p.goals)
    fragments.extend(section.name for section in p.sections)
    return ". ".join(frag for frag in fragments if frag)
|
||
|
||
def _embed(text: str):
    """Encode ``text`` with the shared embedding model and return it as a list.

    ``None``/empty input is encoded as "". The model output is converted with
    ``.tolist()`` (presumably a NumPy array) so it can be sent to Qdrant.
    """
    vector = model.encode(text or "")
    return vector.tolist()
|
||
|
||
def _fingerprint_for_plan(p: Plan) -> str:
    """Return the sha256 hex digest identifying a plan's content.

    Only the title, ``total_minutes`` and the ordered list of
    (exercise_external_id, duration) pairs across all sections are hashed.
    Section names, goals and other metadata do not influence it. The JSON is
    serialised with sorted keys so the digest is stable; changing this format
    would break idempotency for already-stored plans.
    """
    items: List[Dict[str, Any]] = []
    for section in p.sections:
        for item in section.items or []:
            items.append({"exercise_external_id": item.exercise_external_id, "duration": int(item.duration)})
    canonical = json.dumps(
        {"title": p.title, "total_minutes": int(p.total_minutes), "items": items},
        sort_keys=True,
        ensure_ascii=False,
    )
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
|
||
|
||
def _get_by_field(collection: str, key: str, value: Any) -> Optional[Dict[str, Any]]:
    """Return a copy of the payload of the first point whose payload ``key`` equals ``value``.

    Returns ``None`` when nothing matches. If the payload has no ``id`` field,
    the Qdrant point id (as a string) is filled in.
    """
    points, _next_page = qdrant.scroll(
        collection_name=collection,
        scroll_filter=Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))]),
        limit=1,
        with_payload=True,
    )
    if not points:
        return None
    hit = points[0]
    record = dict(hit.payload or {})
    record.setdefault("id", str(hit.id))
    return record
|
||
|
||
def _as_model(model_cls, payload: Dict[str, Any]):
|
||
"""Unbekannte Payload-Felder herausfiltern (Pydantic v1/v2 kompatibel)."""
|
||
fields = getattr(model_cls, "model_fields", None) or getattr(model_cls, "__fields__", {})
|
||
allowed = set(fields.keys())
|
||
data = {k: payload[k] for k in payload.keys() if k in allowed}
|
||
return model_cls(**data)
|
||
|
||
def _truthy(val: Optional[str]) -> bool:
|
||
return str(val or "").strip().lower() in {"1", "true", "yes", "on"}
|
||
|
||
def _exists_in_collection(collection: str, key: str, value: Any) -> bool:
    """Return True if at least one point in ``collection`` has payload ``key`` equal to ``value``.

    Payloads are not fetched (``with_payload=False``); only existence matters.
    """
    condition = FieldCondition(key=key, match=MatchValue(value=value))
    hits, _next_page = qdrant.scroll(
        collection_name=collection,
        scroll_filter=Filter(must=[condition]),
        limit=1,
        with_payload=False,
    )
    return bool(hits)
|
||
|
||
# -----------------
|
||
# Endpoints: Templates
|
||
# -----------------
|
||
# Keyword buckets of a TemplateSection, in the order their facets are materialised.
_TEMPLATE_KEYWORD_FIELDS = ("must_keywords", "ideal_keywords", "supplement_keywords", "forbid_keywords")


def _template_facets(sections: List[Dict[str, Any]]) -> Dict[str, List[str]]:
    """Flatten section names and keyword buckets into top-level ``section_*`` facet lists.

    Each facet is normalised with ``_norm_list`` so the exact-match filters in
    ``list_plan_templates`` can hit it.
    """
    facets: Dict[str, List[str]] = {
        "section_names": _norm_list([s.get("name", "") for s in sections]),
    }
    for field in _TEMPLATE_KEYWORD_FIELDS:
        facets["section_" + field] = _norm_list([kw for s in sections for kw in (s.get(field) or [])])
    return facets


@router.post(
    "/plan_templates",
    response_model=PlanTemplate,
    summary="Create a plan template",
    description=(
        "Erstellt ein Plan-Template (Strukturplanung).\n\n"
        "• Mehrere Sections erlaubt.\n"
        "• Section-Felder: must/ideal/supplement/forbid keywords + capability_targets.\n"
        "• Materialisierte Facettenfelder (section_*) werden intern geschrieben, um Qdrant-Filter zu beschleunigen."
    ),
)
def create_plan_template(t: PlanTemplate):
    """Store a plan template (upsert by id) and return it as stored.

    Goals and every section keyword list are normalised with ``_norm_list``,
    and the ``section_*`` facet fields are added to the payload for filtering.
    The response and the embedding are built from the normalised data, so the
    POST response matches what GET /plan_templates/{id} returns later.
    """
    _ensure_collection(PLAN_TEMPLATE_COLLECTION)

    payload = t.model_dump()
    payload["goals"] = _norm_list(payload.get("goals"))
    sections = payload.get("sections") or []
    for s in sections:
        for field in _TEMPLATE_KEYWORD_FIELDS:
            s[field] = _norm_list(s.get(field) or [])
    payload.update(_template_facets(sections))

    # Schema view of the stored payload (the facet fields are filtered out).
    stored = _as_model(PlanTemplate, payload)
    vec = _embed(_template_embed_text(stored))
    qdrant.upsert(
        collection_name=PLAN_TEMPLATE_COLLECTION,
        points=[PointStruct(id=str(stored.id), vector=vec, payload=payload)],
    )
    return stored
|
||
|
||
@router.get(
    "/plan_templates/{tpl_id}",
    response_model=PlanTemplate,
    summary="Read a plan template by id",
    description="Liest ein Template anhand seiner ID und gibt nur die Schemafelder zurück (zusätzliche Payload wird herausgefiltert).",
)
def get_plan_template(tpl_id: str):
    """Return the template whose payload ``id`` equals ``tpl_id``; 404 if absent.

    Payload keys outside the PlanTemplate schema (e.g. ``section_*`` facets) are dropped.
    """
    _ensure_collection(PLAN_TEMPLATE_COLLECTION)
    record = _get_by_field(PLAN_TEMPLATE_COLLECTION, "id", tpl_id)
    if record:
        return _as_model(PlanTemplate, record)
    raise HTTPException(status_code=404, detail="not found")
|
||
|
||
@router.get(
    "/plan_templates",
    response_model=PlanTemplateList,
    summary="List plan templates (filterable)",
    description=(
        "Listet Plan-Templates mit Filtern.\n\n"
        "**Filter** (exakte Matches, KEYWORD-Felder):\n"
        "- discipline, age_group, target_group\n"
        "- section: Section-Name (nutzt materialisierte `section_names`)\n"
        "- goal: Ziel (nutzt `goals`)\n"
        "- keyword: trifft auf beliebige Section-Keyword-Felder (must/ideal/supplement/forbid).\n\n"
        "**Pagination:** limit/offset. Feld `count` entspricht der Anzahl zurückgegebener Items (keine Gesamtsumme)."
    ),
)
def list_plan_templates(
    discipline: Optional[str] = Query(None, description="Filter: Disziplin (exaktes KEYWORD-Match)", example="Karate"),
    age_group: Optional[str] = Query(None, description="Filter: Altersgruppe", example="Teenager"),
    target_group: Optional[str] = Query(None, description="Filter: Zielgruppe", example="Breitensport"),
    section: Optional[str] = Query(None, description="Filter: Section-Name (materialisiert)", example="Warmup"),
    goal: Optional[str] = Query(None, description="Filter: Trainingsziel", example="Technik"),
    keyword: Optional[str] = Query(None, description="Filter: Keyword in must/ideal/supplement/forbid", example="Koordination"),
    limit: int = Query(20, ge=1, le=200, description="Max. Anzahl Items"),
    offset: int = Query(0, ge=0, description="Start-Offset für Paging"),
):
    """List templates filtered server-side by Qdrant.

    Every supplied exact-match filter must hold (``must``). ``keyword`` is
    matched against the four materialised ``section_*_keywords`` facets and
    placed in ``should``, i.e. at least one facet has to contain it.
    """
    _ensure_collection(PLAN_TEMPLATE_COLLECTION)

    exact_filters = (
        ("discipline", discipline),
        ("age_group", age_group),
        ("target_group", target_group),
        ("section_names", section),
        ("goals", goal),
    )
    must: List[Any] = [
        FieldCondition(key=field, match=MatchValue(value=wanted))
        for field, wanted in exact_filters
        if wanted
    ]
    should: List[Any] = []
    if keyword:
        should = [
            FieldCondition(key=field, match=MatchValue(value=keyword))
            for field in (
                "section_must_keywords",
                "section_ideal_keywords",
                "section_supplement_keywords",
                "section_forbid_keywords",
            )
        ]

    flt = Filter(must=must or None, should=should or None) if (must or should) else None

    # Fetch the first offset+limit matches in one scroll call, then slice out the page.
    window = max(offset + limit, 1)
    points, _next_page = qdrant.scroll(
        collection_name=PLAN_TEMPLATE_COLLECTION, scroll_filter=flt, limit=window, with_payload=True
    )

    page: List[PlanTemplate] = []
    for point in points[offset:offset + limit]:
        data = dict(point.payload or {})
        data.setdefault("id", str(point.id))
        page.append(_as_model(PlanTemplate, data))
    return PlanTemplateList(items=page, limit=limit, offset=offset, count=len(page))
|
||
|
||
# -----------------
|
||
# Endpoints: Pläne
|
||
# -----------------
|
||
def _created_at_utc(value: Any) -> datetime:
    """Normalise a ``created_at`` value to a timezone-aware UTC datetime.

    * aware datetime -> converted to UTC
    * naive datetime -> taken as UTC (like the model's UTC default) rather than
      being interpreted in the server's local time zone
    * ISO-8601 string (trailing ``Z`` accepted) -> parsed, then as above
    * anything else / unparseable -> current UTC time
    """
    if isinstance(value, str):
        try:
            value = datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            value = None
    if not isinstance(value, datetime):
        return datetime.now(timezone.utc)
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)


def _check_plan_references(p: Plan) -> None:
    """Validate a plan's references before it is stored.

    Raises HTTPException(422) if ``template_id`` is set but unknown, or, when
    the env flag PLAN_STRICT_EXERCISES is truthy, if any referenced
    ``exercise_external_id`` is absent from EXERCISE_COLLECTION.
    """
    if p.template_id:
        # Ensure the template collection exists so an unknown id yields a 422
        # instead of the scroll failing on a missing collection.
        _ensure_collection(PLAN_TEMPLATE_COLLECTION)
        if not _exists_in_collection(PLAN_TEMPLATE_COLLECTION, "id", p.template_id):
            raise HTTPException(status_code=422, detail=f"Unknown template_id: {p.template_id}")

    if not _truthy(os.getenv("PLAN_STRICT_EXERCISES")):
        return
    referenced = {
        (it.exercise_external_id or "").strip()
        for sec in p.sections or []
        for it in sec.items or []
    }
    referenced.discard("")
    # One lookup per distinct id; the resulting list is already sorted and unique.
    missing = [
        exid
        for exid in sorted(referenced)
        if not _exists_in_collection(EXERCISE_COLLECTION, "external_id", exid)
    ]
    if missing:
        raise HTTPException(status_code=422, detail={"error": "unknown exercise_external_id", "missing": missing})


@router.post(
    "/plan",
    response_model=Plan,
    summary="Create a concrete training plan",
    description=(
        "Erstellt einen konkreten Trainingsplan.\n\n"
        "Idempotenz: gleicher Fingerprint (title + items) → gleicher Plan (kein Duplikat).\n"
        "Optional: Validierung von template_id und Exercises (Strict-Mode)."
    ),
)
def create_plan(p: Plan):
    """Store a concrete plan, idempotently by fingerprint.

    Steps: validate references, resolve the fingerprint (client value or
    content hash), and return the already-stored plan if that fingerprint
    exists. Otherwise normalise goals, pin ``created_at`` to UTC, add the
    numeric ``created_at_ts`` and materialised ``plan_section_names``, then
    upsert. Raises HTTPException(422) for unknown template/exercise references.
    """
    _ensure_collection(PLAN_COLLECTION)
    _check_plan_references(p)

    # A client-supplied fingerprint is kept; otherwise the content hash is used.
    p.fingerprint = p.fingerprint or _fingerprint_for_plan(p)
    existing = _get_by_field(PLAN_COLLECTION, "fingerprint", p.fingerprint)
    if existing:
        return _as_model(Plan, existing)

    p.goals = _norm_list(p.goals)
    created = _created_at_utc(p.created_at)
    p.created_at = created  # the response carries the same UTC value that is stored

    payload = p.model_dump()
    payload["created_at"] = created.isoformat()
    # Numeric copy for the server-side range filter in GET /plans.
    payload["created_at_ts"] = float(created.timestamp())
    # Materialised section names for exact-match filtering (GET /plans?section=...).
    payload["plan_section_names"] = _norm_list([s.name for s in p.sections or []])

    vec = _embed(_plan_embed_text(p))
    qdrant.upsert(collection_name=PLAN_COLLECTION, points=[PointStruct(id=str(p.id), vector=vec, payload=payload)])
    return p
|
||
|
||
@router.get(
    "/plan/{plan_id}",
    response_model=Plan,
    summary="Read a training plan by id",
    description="Liest einen Plan anhand seiner ID. `created_at` wird (falls ISO-String) zu `datetime` geparst.",
)
def get_plan(plan_id: str):
    """Return the plan whose payload ``id`` equals ``plan_id``; 404 if absent.

    A string ``created_at`` is pre-parsed with ``datetime.fromisoformat``. If
    that fails, the string is left unchanged and Pydantic validation of
    ``Plan`` decides.
    """
    _ensure_collection(PLAN_COLLECTION)
    record = _get_by_field(PLAN_COLLECTION, "id", plan_id)
    if not record:
        raise HTTPException(status_code=404, detail="not found")

    raw_created = record.get("created_at")
    if isinstance(raw_created, str):
        try:
            record["created_at"] = datetime.fromisoformat(raw_created)
        except Exception:
            # NOTE(review): on Python < 3.11 a trailing "Z" lands here.
            pass
    return _as_model(Plan, record)
|
||
|
||
def _parse_iso_bound(value: Optional[str]) -> Optional[float]:
    """Convert an ISO-8601 query bound to a POSIX timestamp.

    Returns ``None`` when ``value`` is empty or not valid ISO-8601, so a bad
    bound is ignored instead of failing the request. A trailing ``Z`` means
    UTC. A bound without a UTC offset is also taken as UTC (not the server's
    local time zone), so results do not depend on where the service runs.
    """
    if not value:
        return None
    try:
        moment = datetime.fromisoformat(value.replace("Z", "+00:00"))
    except ValueError:
        return None
    if moment.tzinfo is None:
        moment = moment.replace(tzinfo=timezone.utc)
    return moment.timestamp()


@router.get(
    "/plans",
    response_model=PlanList,
    summary="List training plans (filterable)",
    description=(
        "Listet Trainingspläne mit Filtern.\n\n"
        "**Filter** (exakte Matches, KEYWORD-Felder):\n"
        "- created_by, discipline, age_group, target_group, goal\n"
        "- section: Section-Name (nutzt materialisiertes `plan_section_names`)\n"
        "- created_from / created_to: ISO-8601 Zeitfenster → serverseitiger Range-Filter über `created_at_ts` (FLOAT).\n\n"
        "**Pagination:** limit/offset. Feld `count` entspricht der Anzahl zurückgegebener Items (keine Gesamtsumme)."
    ),
)
def list_plans(
    created_by: Optional[str] = Query(None, description="Filter: Ersteller", example="tester"),
    discipline: Optional[str] = Query(None, description="Filter: Disziplin", example="Karate"),
    age_group: Optional[str] = Query(None, description="Filter: Altersgruppe", example="Teenager"),
    target_group: Optional[str] = Query(None, description="Filter: Zielgruppe", example="Breitensport"),
    goal: Optional[str] = Query(None, description="Filter: Trainingsziel", example="Technik"),
    section: Optional[str] = Query(None, description="Filter: Section-Name", example="Warmup"),
    created_from: Optional[str] = Query(None, description="Ab-Zeitpunkt (ISO 8601, z. B. 2025-08-12T00:00:00Z)", example="2025-08-12T00:00:00Z"),
    created_to: Optional[str] = Query(None, description="Bis-Zeitpunkt (ISO 8601)", example="2025-08-13T00:00:00Z"),
    limit: int = Query(20, ge=1, le=200, description="Max. Anzahl Items"),
    offset: int = Query(0, ge=0, description="Start-Offset für Paging"),
):
    """List plans, with all filtering done server-side by Qdrant.

    Exact-match filters are AND-combined. ``created_from`` / ``created_to``
    become one Range condition on ``created_at_ts``. Each bound is parsed on
    its own, so a malformed bound is ignored without discarding the other
    one. Filtering every fetched point server-side also keeps pagination
    correct; no local post-filter shortens pages.
    NOTE(review): points lacking ``created_at_ts`` (legacy data) are not
    matched once a time bound is active.

    The first ``offset + limit`` matches are fetched and the page is sliced
    locally; ``count`` is the page size, not a total.
    """
    _ensure_collection(PLAN_COLLECTION)

    exact_filters = (
        ("created_by", created_by),
        ("discipline", discipline),
        ("age_group", age_group),
        ("target_group", target_group),
        ("goals", goal),
        ("plan_section_names", section),
    )
    must: List[Any] = [
        FieldCondition(key=field, match=MatchValue(value=wanted))
        for field, wanted in exact_filters
        if wanted
    ]

    # Time window on the numeric created_at_ts written by POST /plan.
    range_args: Dict[str, float] = {}
    ts_from = _parse_iso_bound(created_from)
    if ts_from is not None:
        range_args["gte"] = ts_from
    ts_to = _parse_iso_bound(created_to)
    if ts_to is not None:
        range_args["lte"] = ts_to
    if range_args:
        must.append(FieldCondition(key="created_at_ts", range=Range(**range_args)))

    flt = Filter(must=must) if must else None

    fetch_n = max(offset + limit, 1)
    pts, _next_page = qdrant.scroll(collection_name=PLAN_COLLECTION, scroll_filter=flt, limit=fetch_n, with_payload=True)

    items: List[Plan] = []
    for point in pts[offset:offset + limit]:
        payload = dict(point.payload or {})
        payload.setdefault("id", str(point.id))
        items.append(_as_model(Plan, payload))
    return PlanList(items=items, limit=limit, offset=offset, count=len(items))
|