# -*- coding: utf-8 -*-
"""
plan_router.py – v0.13.4 (WP-15)

Changes compared to v0.13.3
- Idempotent POST /plan: if a plan with the same fingerprint already exists and
  the new request carries a later `created_at`, the stored plan is updated with
  the newer `created_at` and `created_at_ts` (no duplicate, but "fresh" in time).
- /plans: multi-page scrolling stays active; the time-window filter is robust
  (server-side range + local fallback).
"""
import hashlib
import json
import logging
import os
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List, Optional
from uuid import uuid4

from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field
from qdrant_client.models import (
    PointStruct,
    Filter,
    FieldCondition,
    MatchValue,
    VectorParams,
    Distance,
    Range,
)

from clients import model, qdrant

logger = logging.getLogger(__name__)

router = APIRouter(tags=["plans"])

# -----------------
# Configuration
# -----------------
PLAN_COLLECTION = os.getenv("PLAN_COLLECTION") or os.getenv("QDRANT_COLLECTION_PLANS", "plans")
PLAN_TEMPLATE_COLLECTION = os.getenv("PLAN_TEMPLATE_COLLECTION", "plan_templates")
PLAN_SESSION_COLLECTION = os.getenv("PLAN_SESSION_COLLECTION", "plan_sessions")
EXERCISE_COLLECTION = os.getenv("EXERCISE_COLLECTION", "exercises")

# Per-section keyword fields of a template. Each one is normalised on write and
# mirrored into a flat top-level "section_<field>" payload key used for filtering.
_KEYWORD_FIELDS = ("must_keywords", "ideal_keywords", "supplement_keywords", "forbid_keywords")


# -----------------
# Models
# -----------------
# NOTE: the models are documented with comments rather than docstrings on
# purpose: Pydantic would publish a class docstring in the OpenAPI schema.

# One section of a plan template: time budget plus keyword/capability constraints.
class TemplateSection(BaseModel):
    name: str
    target_minutes: int
    must_keywords: List[str] = []
    ideal_keywords: List[str] = []
    supplement_keywords: List[str] = []
    forbid_keywords: List[str] = []
    capability_targets: Dict[str, int] = {}


# Structural plan template; `id` defaults to a fresh UUID4 string.
class PlanTemplate(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid4()))
    name: str
    discipline: str
    age_group: str
    target_group: str
    total_minutes: int
    sections: List[TemplateSection] = []
    goals: List[str] = []
    equipment_allowed: List[str] = []
    created_by: str
    version: str


# A single exercise reference inside a plan section.
class PlanItem(BaseModel):
    exercise_external_id: str
    duration: int
    why: str


# A named section of a concrete plan with its items.
class PlanSection(BaseModel):
    name: str
    items: List[PlanItem] = []
    minutes: int


# Concrete training plan; `id` defaults to a UUID4 string and `created_at`
# defaults to the current UTC time.
class Plan(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid4()))
    template_id: Optional[str] = None
    title: str
    discipline: str
    age_group: str
    target_group: str
    total_minutes: int
    sections: List[PlanSection] = []
    goals: List[str] = []
    capability_summary: Dict[str, int] = {}
    novelty_against_last_n: Optional[float] = None
    fingerprint: Optional[str] = None
    created_by: str
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    source: str = "API"


# List response for templates; `count` is the number of returned items.
class PlanTemplateList(BaseModel):
    items: List[PlanTemplate]
    limit: int
    offset: int
    count: int


# List response for plans; `count` is the number of returned items.
class PlanList(BaseModel):
    items: List[Plan]
    limit: int
    offset: int
    count: int


# -----------------
# Helpers
# -----------------
def _ensure_collection(name: str) -> None:
    """Create the Qdrant collection *name* if it does not exist yet.

    Uses ``create_collection`` instead of ``recreate_collection``: the latter
    deletes an existing collection first, so a collection created by a
    concurrent request between the existence check and the call would be wiped.
    """
    if qdrant.collection_exists(name):
        return
    try:
        qdrant.create_collection(
            collection_name=name,
            vectors_config=VectorParams(
                size=model.get_sentence_embedding_dimension(),
                distance=Distance.COSINE,
            ),
        )
    except Exception:
        # Tolerate losing the creation race; re-raise every other failure.
        if not qdrant.collection_exists(name):
            raise


def _norm_list(xs: Optional[List[str]]) -> List[str]:
    """Return *xs* stripped, de-duplicated case-insensitively and sorted.

    Empty entries are dropped; the first spelling of each value wins.
    ``None`` is treated as an empty list.
    """
    seen, out = set(), []
    for x in xs or []:
        s = str(x).strip()
        k = s.casefold()
        if s and k not in seen:
            seen.add(k)
            out.append(s)
    return sorted(out, key=str.casefold)


def _template_embed_text(tpl: PlanTemplate) -> str:
    """Build the text that is embedded as the vector of a template."""
    parts = [tpl.name, tpl.discipline, tpl.age_group, tpl.target_group]
    parts += tpl.goals
    parts += [s.name for s in tpl.sections]
    return ". ".join([part for part in parts if part])


def _plan_embed_text(p: Plan) -> str:
    """Build the text that is embedded as the vector of a plan."""
    parts = [p.title, p.discipline, p.age_group, p.target_group]
    parts += p.goals
    parts += [s.name for s in p.sections]
    # The loop variable is deliberately not called `p` (would shadow the plan).
    return ". ".join([part for part in parts if part])


def _embed(text: str):
    """Encode *text* (``None`` becomes ``""``) with the shared model and return a list."""
    return model.encode(text or "").tolist()


def _fingerprint_for_plan(p: Plan) -> str:
    """sha256(title, total_minutes, sections.items.exercise_external_id, sections.items.duration)"""
    core = {
        "title": p.title,
        "total_minutes": int(p.total_minutes),
        "items": [
            {"exercise_external_id": it.exercise_external_id, "duration": int(it.duration)}
            for sec in p.sections
            for it in (sec.items or [])
        ],
    }
    raw = json.dumps(core, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()


def _field_filter(key: str, value: Any) -> Filter:
    """Return a filter that requires payload field *key* to match *value* exactly."""
    return Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))])


def _get_by_field(collection: str, key: str, value: Any):
    """Return the first point whose payload field *key* equals *value*.

    Returns ``None`` when nothing matches, otherwise
    ``{"id": <point id>, "payload": <payload copy>}`` where the payload is
    guaranteed to contain an ``"id"`` entry.
    """
    pts, _ = qdrant.scroll(
        collection_name=collection,
        scroll_filter=_field_filter(key, value),
        limit=1,
        with_payload=True,
    )
    if not pts:
        return None
    point = pts[0]
    payload = dict(point.payload or {})
    payload.setdefault("id", str(point.id))
    return {"id": point.id, "payload": payload}


def _as_model(model_cls, payload: Dict[str, Any]):
    """Instantiate *model_cls* from *payload*, ignoring keys that are not model fields.

    Supports both ``model_fields`` and the older ``__fields__`` attribute.
    """
    fields = getattr(model_cls, "model_fields", None) or getattr(model_cls, "__fields__", {})
    allowed = set(fields.keys())
    data = {k: v for k, v in payload.items() if k in allowed}
    return model_cls(**data)


def _truthy(val: Optional[str]) -> bool:
    """Return True for the flag spellings 1/true/yes/on (case-insensitive)."""
    return str(val or "").strip().lower() in {"1", "true", "yes", "on"}


def _exists_in_collection(collection: str, key: str, value: Any) -> bool:
    """Return True if at least one point has payload field *key* equal to *value*."""
    pts, _ = qdrant.scroll(
        collection_name=collection,
        scroll_filter=_field_filter(key, value),
        limit=1,
        with_payload=False,
    )
    return bool(pts)


def _try_iso_to_ts(iso_str: Optional[str]) -> Optional[float]:
    """Parse an ISO-8601 string into epoch seconds; return ``None`` if that fails.

    A trailing ``Z`` is accepted. Values without a UTC offset are interpreted by
    ``datetime.timestamp()`` in the server's local time zone.
    """
    if not iso_str:
        return None
    try:
        return float(datetime.fromisoformat(iso_str.replace("Z", "+00:00")).timestamp())
    except (ValueError, TypeError, AttributeError, OverflowError, OSError):
        return None


def _parse_iso_to_ts(iso_str: str) -> float:
    """Parse an ISO-8601 string into epoch seconds, falling back to "now" (UTC)."""
    ts = _try_iso_to_ts(iso_str)
    if ts is None:
        return float(datetime.now(timezone.utc).timestamp())
    return ts


def _created_at_iso(value: Any) -> str:
    """Return *value* as an ISO-8601 string suitable for storage.

    Datetimes are converted to UTC; parseable strings are kept unchanged;
    anything else is replaced by the current UTC time.
    """
    if isinstance(value, datetime):
        # NOTE(review): astimezone() treats a naive datetime as server-local time.
        return value.astimezone(timezone.utc).isoformat()
    if isinstance(value, str) and _try_iso_to_ts(value) is not None:
        return value
    return datetime.now(timezone.utc).isoformat()


def _payload_created_ts(payload: Dict[str, Any]) -> Optional[float]:
    """Extract the creation time of a stored plan as epoch seconds.

    Prefers the numeric ``created_at_ts``; otherwise falls back to
    ``created_at`` given as ISO string, ``datetime`` or ``{"$date": ...}``.
    Returns ``None`` when no usable timestamp is present.
    """
    raw_ts = payload.get("created_at_ts")
    if isinstance(raw_ts, (int, float)):
        return float(raw_ts)
    created = payload.get("created_at")
    if isinstance(created, dict) and created.get("$date"):
        created = created["$date"]
    if isinstance(created, datetime):
        return float(created.timestamp())
    if isinstance(created, str):
        return _try_iso_to_ts(created)
    return None


def _scroll_collect(
    collection: str,
    flt: Optional[Filter],
    need: int,
    page: int = 256,
    predicate: Optional[Callable[[Any], bool]] = None,
):
    """Scroll *collection* page by page until *need* points are collected.

    Args:
        collection: Name of the Qdrant collection.
        flt: Optional server-side filter.
        need: Maximum number of points to return.
        page: Page size per scroll call, clamped to 1..1024.
        predicate: Optional local filter. When given, only accepted points
            count towards *need* and scrolling continues until enough accepted
            points were found or the collection is exhausted.

    Returns:
        A list of at most *need* points (with payload).
    """
    out: List[Any] = []
    offset = None
    page = max(1, min(page, 1024))
    while len(out) < need:
        # Without a predicate every fetched point is kept, so never over-fetch.
        batch = page if predicate is not None else min(page, need - len(out))
        pts, offset = qdrant.scroll(
            collection_name=collection,
            scroll_filter=flt,
            limit=batch,
            with_payload=True,
            offset=offset,
        )
        if not pts:
            break
        if predicate is None:
            out.extend(pts)
        else:
            out.extend(pt for pt in pts if predicate(pt))
        if offset is None:
            break
    return out[:need]


# -----------------
# Endpoints: Templates
# -----------------
@router.post(
    "/plan_templates",
    response_model=PlanTemplate,
    summary="Create a plan template",
    description=(
        "Erstellt ein Plan-Template (Strukturplanung).\n\n"
        "• Mehrere Sections erlaubt.\n"
        "• Section-Felder: must/ideal/supplement/forbid keywords + capability_targets.\n"
        "• Materialisierte Facettenfelder (section_*) werden intern geschrieben, um Qdrant-Filter zu beschleunigen."
    ),
)
def create_plan_template(t: PlanTemplate):
    """Store template *t* (upsert by id) and return it unchanged.

    Goals and section keyword lists are normalised in the stored payload, and
    flat ``section_*`` facet fields are written alongside for filtering.
    """
    _ensure_collection(PLAN_TEMPLATE_COLLECTION)
    payload = t.model_dump()
    payload["goals"] = _norm_list(payload.get("goals"))
    sections = payload.get("sections", []) or []
    for sec in sections:
        for field in _KEYWORD_FIELDS:
            sec[field] = _norm_list(sec.get(field) or [])

    # Materialised facets (KEYWORD indexes)
    payload["section_names"] = _norm_list([sec.get("name", "") for sec in sections])
    for field in _KEYWORD_FIELDS:
        payload[f"section_{field}"] = _norm_list(
            [kw for sec in sections for kw in (sec.get(field) or [])]
        )

    vec = _embed(_template_embed_text(t))
    qdrant.upsert(
        collection_name=PLAN_TEMPLATE_COLLECTION,
        points=[PointStruct(id=str(t.id), vector=vec, payload=payload)],
    )
    return t


@router.get(
    "/plan_templates/{tpl_id}",
    response_model=PlanTemplate,
    summary="Read a plan template by id",
    description="Liest ein Template anhand seiner ID und gibt nur die Schemafelder zurück (zusätzliche Payload wird herausgefiltert).",
)
def get_plan_template(tpl_id: str):
    """Return the template whose payload ``id`` equals *tpl_id*.

    Raises:
        HTTPException: 404 if no such template exists.
    """
    _ensure_collection(PLAN_TEMPLATE_COLLECTION)
    found = _get_by_field(PLAN_TEMPLATE_COLLECTION, "id", tpl_id)
    if not found:
        raise HTTPException(status_code=404, detail="not found")
    return _as_model(PlanTemplate, found["payload"])


@router.get(
    "/plan_templates",
    response_model=PlanTemplateList,
    summary="List plan templates (filterable)",
    description=(
        "Listet Plan-Templates mit Filtern.\n\n"
        "**Filter** (exakte Matches, KEYWORD-Felder):\n"
        "- discipline, age_group, target_group\n"
        "- section: Section-Name (nutzt materialisierte `section_names`)\n"
        "- goal: Ziel (nutzt `goals`)\n"
        "- keyword: trifft auf beliebige Section-Keyword-Felder (must/ideal/supplement/forbid).\n\n"
        "**Pagination:** limit/offset. Feld `count` entspricht der Anzahl zurückgegebener Items (keine Gesamtsumme)."
    ),
)
def list_plan_templates(
    discipline: Optional[str] = Query(None, description="Filter: Disziplin (exaktes KEYWORD-Match)", example="Karate"),
    age_group: Optional[str] = Query(None, description="Filter: Altersgruppe", example="Teenager"),
    target_group: Optional[str] = Query(None, description="Filter: Zielgruppe", example="Breitensport"),
    section: Optional[str] = Query(None, description="Filter: Section-Name (materialisiert)", example="Warmup"),
    goal: Optional[str] = Query(None, description="Filter: Trainingsziel", example="Technik"),
    keyword: Optional[str] = Query(None, description="Filter: Keyword in must/ideal/supplement/forbid", example="Koordination"),
    limit: int = Query(20, ge=1, le=200, description="Max. Anzahl Items"),
    offset: int = Query(0, ge=0, description="Start-Offset für Paging"),
):
    """List templates matching all given exact-match filters, paged by limit/offset."""
    _ensure_collection(PLAN_TEMPLATE_COLLECTION)

    exact_filters = (
        ("discipline", discipline),
        ("age_group", age_group),
        ("target_group", target_group),
        ("section_names", section),
        ("goals", goal),
    )
    must: List[Any] = [
        FieldCondition(key=key, match=MatchValue(value=value))
        for key, value in exact_filters
        if value
    ]
    # A keyword may match any of the materialised section keyword facets.
    should: List[Any] = []
    if keyword:
        should = [
            FieldCondition(key=f"section_{field}", match=MatchValue(value=keyword))
            for field in _KEYWORD_FIELDS
        ]

    flt = Filter(must=must or None, should=should or None) if (must or should) else None

    # Qdrant scroll has no numeric offset, so fetch offset+limit and slice locally.
    need = max(offset + limit, 1)
    pts = _scroll_collect(PLAN_TEMPLATE_COLLECTION, flt, need)

    items: List[PlanTemplate] = []
    for pt in pts[offset:offset + limit]:
        payload = dict(pt.payload or {})
        payload.setdefault("id", str(pt.id))
        items.append(_as_model(PlanTemplate, payload))

    return PlanTemplateList(items=items, limit=limit, offset=offset, count=len(items))


# -----------------
# Endpoints: Plans
# -----------------
@router.post(
    "/plan",
    response_model=Plan,
    summary="Create a concrete training plan",
    description=(
        "Erstellt einen konkreten Trainingsplan.\n\n"
        "Idempotenz: gleicher Fingerprint (title + items) → gleicher Plan (kein Duplikat).\n"
        "Bei erneutem POST mit späterem `created_at` wird `created_at`/`created_at_ts` des bestehenden Plans aktualisiert."
    ),
)
def create_plan(p: Plan):
    """Create plan *p*, or return the stored plan with the same fingerprint.

    If a plan with the same fingerprint exists and the request carries a later
    ``created_at``, the stored ``created_at``/``created_at_ts`` are refreshed
    (best effort) instead of creating a duplicate.

    Raises:
        HTTPException: 422 if ``template_id`` is unknown, or (only when the
            environment variable ``PLAN_STRICT_EXERCISES`` is truthy) if an
            ``exercise_external_id`` is not found in the exercise collection.
    """
    _ensure_collection(PLAN_COLLECTION)

    # Validate the template reference (if set)
    if p.template_id and not _exists_in_collection(PLAN_TEMPLATE_COLLECTION, "id", p.template_id):
        raise HTTPException(status_code=422, detail=f"Unknown template_id: {p.template_id}")

    # Optional strict mode: validate exercises against EXERCISE_COLLECTION
    if _truthy(os.getenv("PLAN_STRICT_EXERCISES")):
        # De-duplicate first so each distinct id costs exactly one lookup.
        wanted = {
            (it.exercise_external_id or "").strip()
            for sec in p.sections or []
            for it in sec.items or []
        }
        wanted.discard("")
        missing = sorted(
            exid
            for exid in wanted
            if not _exists_in_collection(EXERCISE_COLLECTION, "external_id", exid)
        )
        if missing:
            raise HTTPException(
                status_code=422,
                detail={"error": "unknown exercise_external_id", "missing": missing},
            )

    # Fingerprint (a client-supplied value takes precedence)
    p.fingerprint = p.fingerprint or _fingerprint_for_plan(p)

    # Compute target ISO string + timestamp from the request (also for duplicates)
    created_iso = _created_at_iso(p.created_at)
    req_ts = float(_parse_iso_to_ts(created_iso))

    # Duplicate check
    existing = _get_by_field(PLAN_COLLECTION, "fingerprint", p.fingerprint)
    if existing:
        stored = existing["payload"]
        cur_ts = stored.get("created_at_ts")
        if cur_ts is None:
            cur_ts = _parse_iso_to_ts(str(stored.get("created_at", created_iso)))
        # A later created_at in the request refreshes the stored plan.
        if req_ts > float(cur_ts):
            try:
                qdrant.set_payload(
                    collection_name=PLAN_COLLECTION,
                    payload={"created_at": created_iso, "created_at_ts": req_ts},
                    points=[existing["id"]],
                )
            except Exception:
                # Best effort: still answer with the stored plan, but leave a trace.
                logger.exception("Could not refresh created_at of plan %s", existing["id"])
            else:
                # Mirror the update in the response object.
                stored["created_at"] = created_iso
                stored["created_at_ts"] = req_ts
        return _as_model(Plan, stored)

    # Create a new plan
    p.goals = _norm_list(p.goals)
    payload = p.model_dump()
    payload["created_at"] = created_iso
    payload["created_at_ts"] = req_ts

    # Materialise section names
    payload["plan_section_names"] = _norm_list([
        (sec.get("name") or "").strip()
        for sec in (payload.get("sections") or [])
        if isinstance(sec, dict)
    ])

    vec = _embed(_plan_embed_text(p))
    qdrant.upsert(
        collection_name=PLAN_COLLECTION,
        points=[PointStruct(id=str(p.id), vector=vec, payload=payload)],
    )
    return p


@router.get(
    "/plan/{plan_id}",
    response_model=Plan,
    summary="Read a training plan by id",
    description="Liest einen Plan anhand seiner ID. `created_at` wird (falls ISO-String) zu `datetime` geparst.",
)
def get_plan(plan_id: str):
    """Return the plan whose payload ``id`` equals *plan_id*.

    Raises:
        HTTPException: 404 if no such plan exists.
    """
    _ensure_collection(PLAN_COLLECTION)
    found = _get_by_field(PLAN_COLLECTION, "id", plan_id)
    if not found:
        raise HTTPException(status_code=404, detail="not found")
    payload = found["payload"]
    created = payload.get("created_at")
    if isinstance(created, str):
        try:
            # Accept a trailing "Z" like every other parser in this module.
            payload["created_at"] = datetime.fromisoformat(created.replace("Z", "+00:00"))
        except ValueError:
            pass  # leave the raw string for the model's own validation
    return _as_model(Plan, payload)


@router.get(
    "/plans",
    response_model=PlanList,
    summary="List training plans (filterable)",
    description=(
        "Listet Trainingspläne mit Filtern.\n\n"
        "**Filter** (exakte Matches, KEYWORD-Felder):\n"
        "- created_by, discipline, age_group, target_group, goal\n"
        "- section: Section-Name (nutzt materialisiertes `plan_section_names`)\n"
        "- created_from / created_to: ISO-8601 Zeitfenster → serverseitiger Range-Filter über `created_at_ts` (FLOAT). "
        "Falls 0 Treffer: zweiter Durchlauf ohne Zeit-Range + lokale Zeitprüfung.\n\n"
        "**Pagination:** limit/offset. Feld `count` entspricht der Anzahl zurückgegebener Items (keine Gesamtsumme)."
    ),
)
def list_plans(
    created_by: Optional[str] = Query(None, description="Filter: Ersteller", example="tester"),
    discipline: Optional[str] = Query(None, description="Filter: Disziplin", example="Karate"),
    age_group: Optional[str] = Query(None, description="Filter: Altersgruppe", example="Teenager"),
    target_group: Optional[str] = Query(None, description="Filter: Zielgruppe", example="Breitensport"),
    goal: Optional[str] = Query(None, description="Filter: Trainingsziel", example="Technik"),
    section: Optional[str] = Query(None, description="Filter: Section-Name", example="Warmup"),
    created_from: Optional[str] = Query(None, description="Ab-Zeitpunkt (ISO 8601, z. B. 2025-08-12T00:00:00Z)", example="2025-08-12T00:00:00Z"),
    created_to: Optional[str] = Query(None, description="Bis-Zeitpunkt (ISO 8601)", example="2025-08-13T00:00:00Z"),
    limit: int = Query(20, ge=1, le=200, description="Max. Anzahl Items"),
    offset: int = Query(0, ge=0, description="Start-Offset für Paging"),
):
    """List plans matching all given filters, paged by limit/offset.

    The time window is applied server-side via a range on ``created_at_ts``
    when every supplied bound parses. If that yields no hits, or a bound does
    not parse, the window is checked locally while scrolling; unparseable
    bounds are ignored in that check.
    """
    _ensure_collection(PLAN_COLLECTION)

    # Base filter (without time)
    exact_filters = (
        ("created_by", created_by),
        ("discipline", discipline),
        ("age_group", age_group),
        ("target_group", target_group),
        ("goals", goal),
        ("plan_section_names", section),
    )
    base_must: List[Any] = [
        FieldCondition(key=key, match=MatchValue(value=value))
        for key, value in exact_filters
        if value
    ]
    base_filter = Filter(must=base_must) if base_must else None

    # Time bounds as epoch seconds (None = absent or unparseable)
    from_ts = _try_iso_to_ts(created_from)
    to_ts = _try_iso_to_ts(created_to)
    wants_window = bool(created_from or created_to)
    has_bad_bound = bool((created_from and from_ts is None) or (created_to and to_ts is None))

    # Server-side range only if every supplied bound could be parsed.
    range_args: Dict[str, float] = {}
    if not has_bad_bound:
        if from_ts is not None:
            range_args["gte"] = from_ts
        if to_ts is not None:
            range_args["lte"] = to_ts

    def _in_window(pt: Any) -> bool:
        # Compare epoch seconds: unlike comparing datetimes, this cannot fail
        # on a mix of offset-aware and naive values.
        ts = _payload_created_ts(dict(pt.payload or {}))
        if ts is None:
            return False
        if from_ts is not None and ts < from_ts:
            return False
        if to_ts is not None and ts > to_ts:
            return False
        return True

    # Qdrant scroll has no numeric offset, so fetch offset+limit and slice locally.
    need = max(offset + limit, 1)

    if range_args:
        # 1) Scroll with the server-side time range
        timed_must = base_must + [FieldCondition(key="created_at_ts", range=Range(**range_args))]
        pts = _scroll_collect(PLAN_COLLECTION, Filter(must=timed_must), need)
        if not pts:
            # 2) Fallback: 0 hits -> scroll without the range and filter locally
            #    (covers points that carry no usable `created_at_ts`).
            pts = _scroll_collect(PLAN_COLLECTION, base_filter, need, predicate=_in_window)
    elif wants_window:
        # A bound was supplied but could not be parsed: local check only.
        pts = _scroll_collect(PLAN_COLLECTION, base_filter, need, predicate=_in_window)
    else:
        pts = _scroll_collect(PLAN_COLLECTION, base_filter, need)

    items: List[Plan] = []
    for pt in pts[offset:offset + limit]:
        payload = dict(pt.payload or {})
        payload.setdefault("id", str(pt.id))
        items.append(_as_model(Plan, payload))

    return PlanList(items=items, limit=limit, offset=offset, count=len(items))