Trainer_LLM/llm-api/plan_router.py
Lars 5e2591fb56
All checks were successful
Deploy Trainer_LLM to llm-node / deploy (push) Successful in 2s
llm-api/plan_router.py aktualisiert
2025-08-13 11:17:36 +02:00

479 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
plan_router.py v0.13.1 (WP-15)
Minimal-CRUD + List/Filter für Templates & Pläne.
Änderungen ggü. v0.13.0
- Serverseitiger Zeitfenster-Filter über `created_at_ts` (FLOAT) bleibt erhalten.
- Lokaler Fallback-Zeitfilter wird DEAKTIVIERT, sobald ein serverseitiger Range aktiv ist
(verhindert false negatives).
- `plan_section_names` wird beim POST /plan materialisiert und für Filters genutzt.
"""
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from uuid import uuid4
from datetime import datetime, timezone
import hashlib
import json
import os
from clients import model, qdrant
from qdrant_client.models import (
PointStruct, Filter, FieldCondition, MatchValue,
VectorParams, Distance, Range
)
# Router for all plan and plan-template endpoints below (OpenAPI tag "plans").
router = APIRouter(tags=["plans"])
# -----------------
# Konfiguration
# -----------------
# Qdrant collection names, each overridable through an environment variable.
# For plans, PLAN_COLLECTION wins; if it is unset or empty,
# QDRANT_COLLECTION_PLANS is used, and finally the default "plans".
PLAN_COLLECTION = os.environ.get("PLAN_COLLECTION") or os.environ.get("QDRANT_COLLECTION_PLANS", "plans")
PLAN_TEMPLATE_COLLECTION = os.environ.get("PLAN_TEMPLATE_COLLECTION", "plan_templates")
PLAN_SESSION_COLLECTION = os.environ.get("PLAN_SESSION_COLLECTION", "plan_sessions")
EXERCISE_COLLECTION = os.environ.get("EXERCISE_COLLECTION", "exercises")
# -----------------
# Modelle
# -----------------
# One section of a plan template: a name, a target duration and keyword lists
# used to pick exercises, plus per-capability targets.
# create_plan_template normalizes the keyword lists and flattens them into
# section_* payload fields for filtering.
class TemplateSection(BaseModel):
    name: str
    target_minutes: int
    must_keywords: List[str] = []
    ideal_keywords: List[str] = []  # desirable
    supplement_keywords: List[str] = []  # supplementary
    forbid_keywords: List[str] = []
    capability_targets: Dict[str, int] = {}
# Reusable plan structure (sections, goals, allowed equipment).
# `id` defaults to a fresh UUID4 string; create_plan_template also uses it as
# the Qdrant point id.
class PlanTemplate(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str
    discipline: str
    age_group: str
    target_group: str
    total_minutes: int
    sections: List[TemplateSection] = []
    goals: List[str] = []
    equipment_allowed: List[str] = []
    created_by: str
    version: str
# One exercise slot inside a plan section.
# `exercise_external_id` is checked against the `external_id` payload field of
# EXERCISE_COLLECTION only when PLAN_STRICT_EXERCISES is enabled.
# `exercise_external_id` and `duration` both feed the plan fingerprint.
# NOTE(review): `duration` is presumably minutes — confirm.
class PlanItem(BaseModel):
    exercise_external_id: str
    duration: int
    why: str
# A named section of a concrete plan with its exercise items.
# Nothing in this module checks `minutes` against the sum of the item durations.
class PlanSection(BaseModel):
    name: str
    items: List[PlanItem] = []
    minutes: int
# A concrete training plan, optionally derived from a template (`template_id`).
# `fingerprint` is filled in by create_plan when the client does not send one
# and is used there for idempotent creation.
# `created_at` defaults to the current UTC time.
class Plan(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid4()))
    template_id: Optional[str] = None
    title: str
    discipline: str
    age_group: str
    target_group: str
    total_minutes: int
    sections: List[PlanSection] = []
    goals: List[str] = []
    capability_summary: Dict[str, int] = {}
    novelty_against_last_n: Optional[float] = None
    fingerprint: Optional[str] = None
    created_by: str
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    source: str = "API"
# One page of templates as returned by GET /plan_templates.
# `count` is len(items) for this page, not a total over the collection.
class PlanTemplateList(BaseModel):
    items: List[PlanTemplate]
    limit: int
    offset: int
    count: int
# One page of plans as returned by GET /plans.
# `count` is len(items) for this page, not a total over the collection.
class PlanList(BaseModel):
    items: List[Plan]
    limit: int
    offset: int
    count: int
# -----------------
# Helpers
# -----------------
def _ensure_collection(name: str):
    """Create the Qdrant collection ``name`` if it does not exist yet.

    The new collection uses cosine distance and the vector size of the shared
    sentence-embedding model (set up analogously to exercise_router).

    ``create_collection`` is used instead of the deprecated
    ``recreate_collection``. The latter drops and re-creates the collection,
    so losing a race with another worker that created it between the
    existence check and the call would wipe that worker's points.
    """
    if qdrant.collection_exists(name):
        return
    try:
        qdrant.create_collection(
            collection_name=name,
            vectors_config=VectorParams(size=model.get_sentence_embedding_dimension(), distance=Distance.COSINE),
        )
    except Exception:
        # Another worker may have created it concurrently. That is fine;
        # only propagate the error if the collection still does not exist.
        if not qdrant.collection_exists(name):
            raise
def _norm_list(xs: List[str]) -> List[str]:
    """Normalize a list of strings.

    Each entry is converted with ``str()`` and stripped, and empty results are
    dropped. Duplicates are detected case-insensitively (``str.casefold``),
    and the first spelling seen is kept. The result is sorted
    case-insensitively. ``None`` is treated like an empty list.
    """
    first_spelling: Dict[str, str] = {}
    for raw in xs or []:
        cleaned = str(raw).strip()
        if cleaned:
            first_spelling.setdefault(cleaned.casefold(), cleaned)
    return sorted(first_spelling.values(), key=str.casefold)
def _template_embed_text(tpl: PlanTemplate) -> str:
    """Build the text that gets embedded for a plan template.

    Name, discipline, age group, target group, every goal and every section
    name are joined with ". "; empty (falsy) parts are skipped.
    """
    candidates = [tpl.name, tpl.discipline, tpl.age_group, tpl.target_group]
    candidates.extend(tpl.goals)
    candidates.extend(section.name for section in tpl.sections)
    return ". ".join(filter(None, candidates))
def _plan_embed_text(p: Plan) -> str:
    """Build the text that gets embedded for a concrete plan.

    Title, discipline, age group, target group, every goal and every section
    name are joined with ". "; empty (falsy) parts are skipped.
    """
    parts = [p.title, p.discipline, p.age_group, p.target_group]
    parts += p.goals
    parts += [sec.name for sec in p.sections]
    # Loop variable deliberately named differently from the ``p`` parameter.
    return ". ".join(part for part in parts if part)
def _embed(text: str):
    """Encode ``text`` (``None`` becomes "") with the shared model and return the vector as a plain list."""
    vector = model.encode(text or "")
    return vector.tolist()
def _fingerprint_for_plan(p: Plan) -> str:
    """Return a deterministic SHA-256 hex digest of a plan's core content.

    Hashed: ``title``, ``total_minutes`` (as int) and, for every item of every
    section in order, ``exercise_external_id`` and ``duration`` (as int).
    The data is serialized as JSON with sorted keys, so the digest is stable.
    All other plan fields (goals, id, timestamps, ...) do not affect it.
    """
    flat_items = []
    for section in p.sections:
        for item in section.items or []:
            flat_items.append({
                "exercise_external_id": item.exercise_external_id,
                "duration": int(item.duration),
            })
    canonical = json.dumps(
        {"title": p.title, "total_minutes": int(p.total_minutes), "items": flat_items},
        sort_keys=True,
        ensure_ascii=False,
    )
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
def _get_by_field(collection: str, key: str, value: Any) -> Optional[Dict[str, Any]]:
    """Return a copy of the payload of the first point whose ``key`` equals ``value``, or None.

    Runs a single exact-match scroll limited to one point. If the payload has
    no ``"id"`` of its own, the point id (as a string) is added under that key.
    """
    match_filter = Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))])
    hits, _next_page = qdrant.scroll(
        collection_name=collection,
        scroll_filter=match_filter,
        limit=1,
        with_payload=True,
    )
    if not hits:
        return None
    first = hits[0]
    result = dict(first.payload or {})
    result.setdefault("id", str(first.id))
    return result
def _as_model(model_cls, payload: Dict[str, Any]):
    """Instantiate ``model_cls`` from ``payload``, dropping keys the model does not declare.

    Declared fields are read from ``model_fields`` (pydantic v2) or, if that
    is missing or empty, from ``__fields__`` (pydantic v1).
    """
    declared = getattr(model_cls, "model_fields", None)
    if not declared:
        declared = getattr(model_cls, "__fields__", {})
    filtered = {key: value for key, value in payload.items() if key in declared}
    return model_cls(**filtered)
def _truthy(val: Optional[str]) -> bool:
    """Interpret a flag string, case-insensitively and ignoring surrounding whitespace.

    Returns True for "1", "true", "yes" and "on"; anything else (including
    ``None`` and "") gives False.
    """
    normalized = str(val or "").strip().lower()
    return normalized in ("1", "true", "yes", "on")
def _exists_in_collection(collection: str, key: str, value: Any) -> bool:
    """Return True if at least one point in ``collection`` has payload ``key == value``.

    Runs a single exact-match scroll limited to one point, without fetching payloads.
    """
    condition = FieldCondition(key=key, match=MatchValue(value=value))
    hits, _next_page = qdrant.scroll(
        collection_name=collection,
        scroll_filter=Filter(must=[condition]),
        limit=1,
        with_payload=False,
    )
    return len(hits) > 0
# -----------------
# Endpoints: Templates
# -----------------
@router.post(
    "/plan_templates",
    response_model=PlanTemplate,
    summary="Create a plan template",
    description=(
        "Erstellt ein Plan-Template (Strukturplanung).\n\n"
        "• Mehrere Sections erlaubt.\n"
        "• Section-Felder: must/ideal/supplement/forbid keywords + capability_targets.\n"
        "• Materialisierte Facettenfelder (section_*) werden intern geschrieben, um Qdrant-Filter zu beschleunigen."
    ),
)
def create_plan_template(t: PlanTemplate):
    """Store a plan template in Qdrant and return it.

    Goals and every section's keyword lists are normalized with ``_norm_list``
    on the model itself. As a result the stored payload, the embedding text
    and the HTTP response all carry the same normalized lists; before, only
    the payload was normalized and the caller got the raw lists back.

    Besides the model fields, the payload gets flattened ``section_*`` fields:
    the section names and, for each keyword kind, the union across all
    sections. ``list_plan_templates`` filters on these fields.

    The template is upserted with its ``id`` as the Qdrant point id.
    """
    _ensure_collection(PLAN_TEMPLATE_COLLECTION)

    t.goals = _norm_list(t.goals)
    for sec in t.sections:
        sec.must_keywords = _norm_list(sec.must_keywords)
        sec.ideal_keywords = _norm_list(sec.ideal_keywords)
        sec.supplement_keywords = _norm_list(sec.supplement_keywords)
        sec.forbid_keywords = _norm_list(sec.forbid_keywords)

    payload = t.model_dump()
    sections = t.sections
    # Materialized facets (KEYWORD indexes) for exact-match filtering.
    payload["section_names"] = _norm_list([sec.name for sec in sections])
    payload["section_must_keywords"] = _norm_list([kw for sec in sections for kw in sec.must_keywords])
    payload["section_ideal_keywords"] = _norm_list([kw for sec in sections for kw in sec.ideal_keywords])
    payload["section_supplement_keywords"] = _norm_list([kw for sec in sections for kw in sec.supplement_keywords])
    payload["section_forbid_keywords"] = _norm_list([kw for sec in sections for kw in sec.forbid_keywords])

    vec = _embed(_template_embed_text(t))
    qdrant.upsert(collection_name=PLAN_TEMPLATE_COLLECTION, points=[PointStruct(id=str(t.id), vector=vec, payload=payload)])
    return t
@router.get(
    "/plan_templates/{tpl_id}",
    response_model=PlanTemplate,
    summary="Read a plan template by id",
    description="Liest ein Template anhand seiner ID und gibt nur die Schemafelder zurück (zusätzliche Payload wird herausgefiltert).",
)
def get_plan_template(tpl_id: str):
    """Return the template whose payload ``id`` equals ``tpl_id``, or respond 404.

    Extra payload keys (e.g. the materialized ``section_*`` facets) are
    stripped so that only the PlanTemplate schema fields are returned.
    """
    _ensure_collection(PLAN_TEMPLATE_COLLECTION)
    stored = _get_by_field(PLAN_TEMPLATE_COLLECTION, "id", tpl_id)
    if not stored:
        raise HTTPException(status_code=404, detail="not found")
    return _as_model(PlanTemplate, stored)
@router.get(
    "/plan_templates",
    response_model=PlanTemplateList,
    summary="List plan templates (filterable)",
    description=(
        "Listet Plan-Templates mit Filtern.\n\n"
        "**Filter** (exakte Matches, KEYWORD-Felder):\n"
        "- discipline, age_group, target_group\n"
        "- section: Section-Name (nutzt materialisierte `section_names`)\n"
        "- goal: Ziel (nutzt `goals`)\n"
        "- keyword: trifft auf beliebige Section-Keyword-Felder (must/ideal/supplement/forbid).\n\n"
        "**Pagination:** limit/offset. Feld `count` entspricht der Anzahl zurückgegebener Items (keine Gesamtsumme)."
    ),
)
def list_plan_templates(
    discipline: Optional[str] = Query(None, description="Filter: Disziplin (exaktes KEYWORD-Match)", example="Karate"),
    age_group: Optional[str] = Query(None, description="Filter: Altersgruppe", example="Teenager"),
    target_group: Optional[str] = Query(None, description="Filter: Zielgruppe", example="Breitensport"),
    section: Optional[str] = Query(None, description="Filter: Section-Name (materialisiert)", example="Warmup"),
    goal: Optional[str] = Query(None, description="Filter: Trainingsziel", example="Technik"),
    keyword: Optional[str] = Query(None, description="Filter: Keyword in must/ideal/supplement/forbid", example="Koordination"),
    limit: int = Query(20, ge=1, le=200, description="Max. Anzahl Items"),
    offset: int = Query(0, ge=0, description="Start-Offset für Paging"),
):
    """List templates matching every given exact-match filter.

    ``keyword`` becomes a Qdrant ``should`` group over the four flattened
    section keyword fields, so at least one of them must contain it.
    Paging fetches ``offset + limit`` points and slices locally, and
    ``count`` is the size of the returned page.
    """
    _ensure_collection(PLAN_TEMPLATE_COLLECTION)

    exact_matches = (
        ("discipline", discipline),
        ("age_group", age_group),
        ("target_group", target_group),
        ("section_names", section),
        ("goals", goal),
    )
    must: List[Any] = [
        FieldCondition(key=field, match=MatchValue(value=wanted))
        for field, wanted in exact_matches
        if wanted
    ]

    keyword_fields = (
        "section_must_keywords",
        "section_ideal_keywords",
        "section_supplement_keywords",
        "section_forbid_keywords",
    )
    should: List[Any] = (
        [FieldCondition(key=field, match=MatchValue(value=keyword)) for field in keyword_fields]
        if keyword
        else []
    )

    flt = Filter(must=must or None, should=should or None) if (must or should) else None

    window_end = offset + limit
    points, _ = qdrant.scroll(
        collection_name=PLAN_TEMPLATE_COLLECTION,
        scroll_filter=flt,
        limit=max(window_end, 1),
        with_payload=True,
    )

    items: List[PlanTemplate] = []
    for point in points[offset:window_end]:
        data = dict(point.payload or {})
        data.setdefault("id", str(point.id))
        items.append(_as_model(PlanTemplate, data))
    return PlanTemplateList(items=items, limit=limit, offset=offset, count=len(items))
# -----------------
# Endpoints: Pläne
# -----------------
def _created_at_fields(value: Any) -> tuple:
    """Normalize a ``created_at`` value to ``(iso_utc_string, posix_timestamp)``.

    A ``datetime`` is converted to UTC. An ISO-8601 string (a trailing ``Z``
    is accepted) is parsed first. Anything else, or an unparsable string,
    falls back to the current UTC time.

    NOTE(review): ``astimezone`` treats a naive datetime as server local time;
    confirm that is intended.
    """
    if isinstance(value, str):
        try:
            value = datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            value = None
    if not isinstance(value, datetime):
        value = datetime.now(timezone.utc)
    value = value.astimezone(timezone.utc)
    return value.isoformat(), float(value.timestamp())


@router.post(
    "/plan",
    response_model=Plan,
    summary="Create a concrete training plan",
    description=(
        "Erstellt einen konkreten Trainingsplan.\n\n"
        "Idempotenz: gleicher Fingerprint (title + items) → gleicher Plan (kein Duplikat).\n"
        "Optional: Validierung von template_id und Exercises (Strict-Mode)."
    ),
)
def create_plan(p: Plan):
    """Create a concrete training plan; idempotent via a content fingerprint.

    1. If ``template_id`` is set, a template with that payload ``id`` must
       exist, otherwise 422.
    2. If the env var ``PLAN_STRICT_EXERCISES`` is truthy, every non-empty
       ``exercise_external_id`` must exist in EXERCISE_COLLECTION. Otherwise
       the response is 422 listing the missing ids, sorted.
    3. The fingerprint is the client-supplied one or, if absent, the computed
       one. If a plan with that fingerprint already exists, it is returned
       instead of creating a duplicate.
    4. Otherwise the goals are normalized and ``created_at`` is stored as a UTC
       ISO string plus a float ``created_at_ts``, which the range filter in
       ``list_plans`` uses. The section names are materialized as
       ``plan_section_names`` and the plan is upserted with its embedding.
    """
    _ensure_collection(PLAN_COLLECTION)

    if p.template_id:
        # No template may have been created yet. Make sure the collection
        # exists so that an unknown id yields 422 rather than a Qdrant error
        # about a missing collection.
        _ensure_collection(PLAN_TEMPLATE_COLLECTION)
        if not _exists_in_collection(PLAN_TEMPLATE_COLLECTION, "id", p.template_id):
            raise HTTPException(status_code=422, detail=f"Unknown template_id: {p.template_id}")

    if _truthy(os.getenv("PLAN_STRICT_EXERCISES")):
        # Look up each distinct id once, not once per occurrence.
        wanted = sorted({
            (it.exercise_external_id or "").strip()
            for sec in p.sections or []
            for it in sec.items or []
        } - {""})
        missing = [exid for exid in wanted if not _exists_in_collection(EXERCISE_COLLECTION, "external_id", exid)]
        if missing:
            raise HTTPException(status_code=422, detail={"error": "unknown exercise_external_id", "missing": missing})

    # Fingerprint + idempotency
    fp = _fingerprint_for_plan(p)
    p.fingerprint = p.fingerprint or fp
    existing = _get_by_field(PLAN_COLLECTION, "fingerprint", p.fingerprint)
    if existing:
        return _as_model(Plan, existing)

    # Normalization + materialized fields
    p.goals = _norm_list(p.goals)
    payload = p.model_dump()
    payload["created_at"], payload["created_at_ts"] = _created_at_fields(payload.get("created_at"))
    # Section names are validated strings on the model; _norm_list strips,
    # deduplicates and drops empty names.
    payload["plan_section_names"] = _norm_list([sec.name for sec in p.sections])

    vec = _embed(_plan_embed_text(p))
    qdrant.upsert(collection_name=PLAN_COLLECTION, points=[PointStruct(id=str(p.id), vector=vec, payload=payload)])
    return p
@router.get(
    "/plan/{plan_id}",
    response_model=Plan,
    summary="Read a training plan by id",
    description="Liest einen Plan anhand seiner ID. `created_at` wird (falls ISO-String) zu `datetime` geparst.",
)
def get_plan(plan_id: str):
    """Return the plan whose payload ``id`` equals ``plan_id``, or respond 404.

    A string ``created_at`` is parsed to ``datetime``, accepting a trailing
    "Z" like the other ISO parsers in this module. Extra payload keys are
    stripped so that only the Plan schema fields are returned.
    """
    _ensure_collection(PLAN_COLLECTION)
    found = _get_by_field(PLAN_COLLECTION, "id", plan_id)
    if not found:
        raise HTTPException(status_code=404, detail="not found")
    raw_created = found.get("created_at")
    if isinstance(raw_created, str):
        try:
            # datetime.fromisoformat only accepts a "Z" suffix from Python 3.11 on.
            found["created_at"] = datetime.fromisoformat(raw_created.replace("Z", "+00:00"))
        except ValueError:
            # Keep the raw string; Plan's `created_at: datetime` field validates it.
            pass
    return _as_model(Plan, found)
def _iso_to_ts(value: Optional[str]) -> Optional[float]:
    """Parse an ISO-8601 string (trailing ``Z`` allowed) into a POSIX timestamp.

    Returns None for empty or unparsable input.

    NOTE(review): ``datetime.timestamp()`` interprets naive inputs as server
    local time; confirm that is intended.
    """
    if not value:
        return None
    try:
        return float(datetime.fromisoformat(value.replace("Z", "+00:00")).timestamp())
    except (ValueError, OverflowError, OSError):
        return None


@router.get(
    "/plans",
    response_model=PlanList,
    summary="List training plans (filterable)",
    description=(
        "Listet Trainingspläne mit Filtern.\n\n"
        "**Filter** (exakte Matches, KEYWORD-Felder):\n"
        "- created_by, discipline, age_group, target_group, goal\n"
        "- section: Section-Name (nutzt materialisiertes `plan_section_names`)\n"
        "- created_from / created_to: ISO-8601 Zeitfenster → serverseitiger Range-Filter über `created_at_ts` (FLOAT).\n\n"
        "**Pagination:** limit/offset. Feld `count` entspricht der Anzahl zurückgegebener Items (keine Gesamtsumme)."
    ),
)
def list_plans(
    created_by: Optional[str] = Query(None, description="Filter: Ersteller", example="tester"),
    discipline: Optional[str] = Query(None, description="Filter: Disziplin", example="Karate"),
    age_group: Optional[str] = Query(None, description="Filter: Altersgruppe", example="Teenager"),
    target_group: Optional[str] = Query(None, description="Filter: Zielgruppe", example="Breitensport"),
    goal: Optional[str] = Query(None, description="Filter: Trainingsziel", example="Technik"),
    section: Optional[str] = Query(None, description="Filter: Section-Name", example="Warmup"),
    created_from: Optional[str] = Query(None, description="Ab-Zeitpunkt (ISO 8601, z. B. 2025-08-12T00:00:00Z)", example="2025-08-12T00:00:00Z"),
    created_to: Optional[str] = Query(None, description="Bis-Zeitpunkt (ISO 8601)", example="2025-08-13T00:00:00Z"),
    limit: int = Query(20, ge=1, le=200, description="Max. Anzahl Items"),
    offset: int = Query(0, ge=0, description="Start-Offset für Paging"),
):
    """List plans filtered by exact KEYWORD matches and an optional time window.

    Equality filters become Qdrant ``must`` conditions. ``created_from`` and
    ``created_to`` become a server-side ``Range`` on the float field
    ``created_at_ts`` that ``create_plan`` writes.

    Each bound is parsed independently and an unparsable bound is ignored.
    Previously one bad bound also discarded the valid opposite bound from the
    server filter. The local fallback on the ``created_at`` payload field
    only runs when a window was requested but no bound could be parsed.

    Paging fetches ``offset + limit`` points and slices locally; ``count`` is
    the size of the returned page.

    NOTE(review): when the local fallback drops points, a page may hold fewer
    than ``limit`` items even though more matches exist.
    """
    _ensure_collection(PLAN_COLLECTION)

    must: List[Any] = [
        FieldCondition(key=field, match=MatchValue(value=wanted))
        for field, wanted in (
            ("created_by", created_by),
            ("discipline", discipline),
            ("age_group", age_group),
            ("target_group", target_group),
            ("goals", goal),
            ("plan_section_names", section),
        )
        if wanted
    ]

    # Range filter on the numeric (FLOAT) field; each bound is parsed on its own.
    range_args: Dict[str, float] = {}
    from_ts = _iso_to_ts(created_from)
    if from_ts is not None:
        range_args["gte"] = from_ts
    to_ts = _iso_to_ts(created_to)
    if to_ts is not None:
        range_args["lte"] = to_ts
    applied_server_range = bool(range_args)
    if applied_server_range:
        must.append(FieldCondition(key="created_at_ts", range=Range(**range_args)))

    flt = Filter(must=must) if must else None
    fetch_n = max(offset + limit, 1)
    pts, _ = qdrant.scroll(collection_name=PLAN_COLLECTION, scroll_filter=flt, limit=fetch_n, with_payload=True)

    def _in_window(py: Dict[str, Any]) -> bool:
        """Local fallback time filter; a no-op unless a window was requested but not applied server-side."""
        if applied_server_range or not (created_from or created_to):
            return True
        ts = py.get("created_at")
        if isinstance(ts, dict) and ts.get("$date"):
            ts = ts["$date"]
        if isinstance(ts, datetime):
            dt = ts
        elif isinstance(ts, str):
            try:
                dt = datetime.fromisoformat(ts.replace("Z", "+00:00"))
            except ValueError:
                return False
        else:
            return False
        ok = True
        if created_from:
            try:
                ok = ok and dt >= datetime.fromisoformat(created_from.replace("Z", "+00:00"))
            except (TypeError, ValueError):
                # Unparsable bound or naive/aware mismatch: this bound is ignored.
                pass
        if created_to:
            try:
                ok = ok and dt <= datetime.fromisoformat(created_to.replace("Z", "+00:00"))
            except (TypeError, ValueError):
                pass
        return ok

    payloads: List[Dict[str, Any]] = []
    for point in pts:
        py = dict(point.payload or {})
        py.setdefault("id", str(point.id))
        if _in_window(py):
            payloads.append(py)
    items = [_as_model(Plan, x) for x in payloads[offset:offset + limit]]
    return PlanList(items=items, limit=limit, offset=offset, count=len(items))