llm-api/plan_router.py hinzugefügt
All checks were successful
Deploy Trainer_LLM to llm-node / deploy (push) Successful in 2s
All checks were successful
Deploy Trainer_LLM to llm-node / deploy (push) Successful in 2s
This commit is contained in:
parent
4552e33cb3
commit
5dbe887ce3
209
llm-api/plan_router.py
Normal file
209
llm-api/plan_router.py
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
-*- coding: utf-8 -*-
|
||||
"""
|
||||
plan_router.py – v0.9.0 (WP-15)
|
||||
|
||||
Minimal-CRUD für Plan-Templates & Pläne (POST/GET) + Idempotenz via Fingerprint.
|
||||
Keine bestehenden API-Signaturen geändert. Qdrant-Client-Stil wie exercise_router.
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Any
|
||||
from uuid import uuid4
|
||||
from datetime import datetime, timezone
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
|
||||
from clients import model, qdrant
|
||||
from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue, VectorParams, Distance
|
||||
|
||||
router = APIRouter(tags=["plans"])
|
||||
|
||||
# -----------------
|
||||
# Konfiguration
|
||||
# -----------------
|
||||
PLAN_COLLECTION = os.getenv("PLAN_COLLECTION") or os.getenv("QDRANT_COLLECTION_PLANS", "plans")
|
||||
PLAN_TEMPLATE_COLLECTION = os.getenv("PLAN_TEMPLATE_COLLECTION", "plan_templates")
|
||||
PLAN_SESSION_COLLECTION = os.getenv("PLAN_SESSION_COLLECTION", "plan_sessions")
|
||||
|
||||
# -----------------
|
||||
# Modelle
|
||||
# -----------------
|
||||
class TemplateSection(BaseModel):
|
||||
name: str
|
||||
target_minutes: int
|
||||
must_keywords: List[str] = []
|
||||
forbid_keywords: List[str] = []
|
||||
capability_targets: Dict[str, int] = {}
|
||||
|
||||
class PlanTemplate(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
name: str
|
||||
discipline: str
|
||||
age_group: str
|
||||
target_group: str
|
||||
total_minutes: int
|
||||
sections: List[TemplateSection] = []
|
||||
goals: List[str] = []
|
||||
equipment_allowed: List[str] = []
|
||||
created_by: str
|
||||
version: str
|
||||
|
||||
class PlanItem(BaseModel):
|
||||
exercise_external_id: str
|
||||
duration: int
|
||||
why: str
|
||||
|
||||
class PlanSection(BaseModel):
|
||||
name: str
|
||||
items: List[PlanItem] = []
|
||||
minutes: int
|
||||
|
||||
class Plan(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
template_id: Optional[str] = None
|
||||
title: str
|
||||
discipline: str
|
||||
age_group: str
|
||||
target_group: str
|
||||
total_minutes: int
|
||||
sections: List[PlanSection] = []
|
||||
goals: List[str] = []
|
||||
capability_summary: Dict[str, int] = {}
|
||||
novelty_against_last_n: Optional[float] = None
|
||||
fingerprint: Optional[str] = None
|
||||
created_by: str
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
source: str = "API"
|
||||
|
||||
# -----------------
|
||||
# Helpers
|
||||
# -----------------
|
||||
|
||||
def _ensure_collection(name: str):
|
||||
# Legt Collection an, wenn sie fehlt (analog exercise_router)
|
||||
if not qdrant.collection_exists(name):
|
||||
qdrant.recreate_collection(
|
||||
collection_name=name,
|
||||
vectors_config=VectorParams(size=model.get_sentence_embedding_dimension(), distance=Distance.COSINE),
|
||||
)
|
||||
|
||||
|
||||
def _norm_list(xs: List[str]) -> List[str]:
|
||||
seen, out = set(), []
|
||||
for x in xs or []:
|
||||
s = str(x).strip()
|
||||
k = s.casefold()
|
||||
if s and k not in seen:
|
||||
seen.add(k)
|
||||
out.append(s)
|
||||
return sorted(out, key=str.casefold)
|
||||
|
||||
|
||||
def _template_embed_text(tpl: PlanTemplate) -> str:
|
||||
parts = [tpl.name, tpl.discipline, tpl.age_group, tpl.target_group]
|
||||
parts += tpl.goals
|
||||
parts += [s.name for s in tpl.sections]
|
||||
return ". ".join([p for p in parts if p])
|
||||
|
||||
|
||||
def _plan_embed_text(p: Plan) -> str:
|
||||
parts = [p.title, p.discipline, p.age_group, p.target_group]
|
||||
parts += p.goals
|
||||
parts += [s.name for s in p.sections]
|
||||
return ". ".join([p for p in parts if p])
|
||||
|
||||
|
||||
def _embed(text: str):
|
||||
return model.encode(text or "").tolist()
|
||||
|
||||
|
||||
def _fingerprint_for_plan(p: Plan) -> str:
|
||||
"""sha256(title, total_minutes, sections.items.exercise_external_id, sections.items.duration)"""
|
||||
core = {
|
||||
"title": p.title,
|
||||
"total_minutes": int(p.total_minutes),
|
||||
"items": [
|
||||
{
|
||||
"exercise_external_id": it.exercise_external_id,
|
||||
"duration": int(it.duration),
|
||||
}
|
||||
for sec in p.sections
|
||||
for it in (sec.items or [])
|
||||
],
|
||||
}
|
||||
raw = json.dumps(core, sort_keys=True, ensure_ascii=False)
|
||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _get_by_field(collection: str, key: str, value: Any) -> Optional[Dict[str, Any]]:
|
||||
flt = Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))])
|
||||
pts, _ = qdrant.scroll(collection_name=collection, scroll_filter=flt, limit=1, with_payload=True)
|
||||
if not pts:
|
||||
return None
|
||||
payload = dict(pts[0].payload or {})
|
||||
payload.setdefault("id", str(pts[0].id))
|
||||
return payload
|
||||
|
||||
# -----------------
|
||||
# Endpoints
|
||||
# -----------------
|
||||
@router.post("/plan_templates", response_model=PlanTemplate)
|
||||
def create_plan_template(t: PlanTemplate):
|
||||
_ensure_collection(PLAN_TEMPLATE_COLLECTION)
|
||||
payload = t.model_dump()
|
||||
payload["goals"] = _norm_list(payload.get("goals"))
|
||||
for s in payload.get("sections", []) or []:
|
||||
s["must_keywords"] = _norm_list(s.get("must_keywords") or [])
|
||||
s["forbid_keywords"] = _norm_list(s.get("forbid_keywords") or [])
|
||||
vec = _embed(_template_embed_text(t))
|
||||
qdrant.upsert(collection_name=PLAN_TEMPLATE_COLLECTION, points=[PointStruct(id=str(t.id), vector=vec, payload=payload)])
|
||||
return t
|
||||
|
||||
|
||||
@router.get("/plan_templates/{tpl_id}", response_model=PlanTemplate)
|
||||
def get_plan_template(tpl_id: str):
|
||||
_ensure_collection(PLAN_TEMPLATE_COLLECTION)
|
||||
found = _get_by_field(PLAN_TEMPLATE_COLLECTION, "id", tpl_id)
|
||||
if not found:
|
||||
raise HTTPException(status_code=404, detail="not found")
|
||||
return PlanTemplate(**found)
|
||||
|
||||
|
||||
@router.post("/plan", response_model=Plan)
|
||||
def create_plan(p: Plan):
|
||||
_ensure_collection(PLAN_COLLECTION)
|
||||
# Fingerprint
|
||||
fp = _fingerprint_for_plan(p)
|
||||
p.fingerprint = p.fingerprint or fp
|
||||
|
||||
# Idempotenz
|
||||
existing = _get_by_field(PLAN_COLLECTION, "fingerprint", p.fingerprint)
|
||||
if existing:
|
||||
return Plan(**existing)
|
||||
|
||||
# Normalisieren
|
||||
p.goals = _norm_list(p.goals)
|
||||
payload = p.model_dump()
|
||||
# ISO8601 sicherstellen
|
||||
if isinstance(payload.get("created_at"), datetime):
|
||||
payload["created_at"] = payload["created_at"].astimezone(timezone.utc).isoformat()
|
||||
|
||||
vec = _embed(_plan_embed_text(p))
|
||||
qdrant.upsert(collection_name=PLAN_COLLECTION, points=[PointStruct(id=str(p.id), vector=vec, payload=payload)])
|
||||
return p
|
||||
|
||||
|
||||
@router.get("/plan/{plan_id}", response_model=Plan)
|
||||
def get_plan(plan_id: str):
|
||||
_ensure_collection(PLAN_COLLECTION)
|
||||
found = _get_by_field(PLAN_COLLECTION, "id", plan_id)
|
||||
if not found:
|
||||
raise HTTPException(status_code=404, detail="not found")
|
||||
# created_at zurück in datetime (ISO)
|
||||
if isinstance(found.get("created_at"), str):
|
||||
try:
|
||||
found["created_at"] = datetime.fromisoformat(found["created_at"])
|
||||
except Exception:
|
||||
pass
|
||||
return Plan(**found)
|
||||
Loading…
Reference in New Issue
Block a user