llm-api/plan_router.py aktualisiert
All checks were successful
Deploy Trainer_LLM to llm-node / deploy (push) Successful in 2s
All checks were successful
Deploy Trainer_LLM to llm-node / deploy (push) Successful in 2s
This commit is contained in:
parent
93cdde13a7
commit
5e2591fb56
|
|
@ -1,9 +1,14 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""
|
"""
|
||||||
plan_router.py – v0.13.0 (WP-15)
|
plan_router.py – v0.13.1 (WP-15)
|
||||||
|
|
||||||
Minimal-CRUD + List/Filter für Templates & Pläne.
|
Minimal-CRUD + List/Filter für Templates & Pläne.
|
||||||
Fix: Zeitfenster-Filter per Qdrant-Range über `created_at_ts` (FLOAT).
|
|
||||||
|
Änderungen ggü. v0.13.0
|
||||||
|
- Serverseitiger Zeitfenster-Filter über `created_at_ts` (FLOAT) bleibt erhalten.
|
||||||
|
- Lokaler Fallback-Zeitfilter wird DEAKTIVIERT, sobald ein serverseitiger Range aktiv ist
|
||||||
|
(verhindert false negatives).
|
||||||
|
- `plan_section_names` wird beim POST /plan materialisiert und für Filters genutzt.
|
||||||
"""
|
"""
|
||||||
from fastapi import APIRouter, HTTPException, Query
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
@ -99,14 +104,15 @@ class PlanList(BaseModel):
|
||||||
# -----------------
|
# -----------------
|
||||||
|
|
||||||
def _ensure_collection(name: str):
|
def _ensure_collection(name: str):
|
||||||
|
"""Falls Collection fehlt, analog exercise_router anlegen."""
|
||||||
if not qdrant.collection_exists(name):
|
if not qdrant.collection_exists(name):
|
||||||
qdrant.recreate_collection(
|
qdrant.recreate_collection(
|
||||||
collection_name=name,
|
collection_name=name,
|
||||||
vectors_config=VectorParams(size=model.get_sentence_embedding_dimension(), distance=Distance.COSINE),
|
vectors_config=VectorParams(size=model.get_sentence_embedding_dimension(), distance=Distance.COSINE),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _norm_list(xs: List[str]) -> List[str]:
|
def _norm_list(xs: List[str]) -> List[str]:
|
||||||
|
"""Trimmen, casefolded deduplizieren, stabil sortieren."""
|
||||||
seen, out = set(), []
|
seen, out = set(), []
|
||||||
for x in xs or []:
|
for x in xs or []:
|
||||||
s = str(x).strip()
|
s = str(x).strip()
|
||||||
|
|
@ -116,26 +122,23 @@ def _norm_list(xs: List[str]) -> List[str]:
|
||||||
out.append(s)
|
out.append(s)
|
||||||
return sorted(out, key=str.casefold)
|
return sorted(out, key=str.casefold)
|
||||||
|
|
||||||
|
|
||||||
def _template_embed_text(tpl: PlanTemplate) -> str:
|
def _template_embed_text(tpl: PlanTemplate) -> str:
|
||||||
parts = [tpl.name, tpl.discipline, tpl.age_group, tpl.target_group]
|
parts = [tpl.name, tpl.discipline, tpl.age_group, tpl.target_group]
|
||||||
parts += tpl.goals
|
parts += tpl.goals
|
||||||
parts += [s.name for s in tpl.sections]
|
parts += [s.name for s in tpl.sections]
|
||||||
return ". ".join([p for p in parts if p])
|
return ". ".join([p for p in parts if p])
|
||||||
|
|
||||||
|
|
||||||
def _plan_embed_text(p: Plan) -> str:
|
def _plan_embed_text(p: Plan) -> str:
|
||||||
parts = [p.title, p.discipline, p.age_group, p.target_group]
|
parts = [p.title, p.discipline, p.age_group, p.target_group]
|
||||||
parts += p.goals
|
parts += p.goals
|
||||||
parts += [s.name for s in p.sections]
|
parts += [s.name for s in p.sections]
|
||||||
return ". ".join([p for p in parts if p])
|
return ". ".join([p for p in parts if p])
|
||||||
|
|
||||||
|
|
||||||
def _embed(text: str):
|
def _embed(text: str):
|
||||||
return model.encode(text or "").tolist()
|
return model.encode(text or "").tolist()
|
||||||
|
|
||||||
|
|
||||||
def _fingerprint_for_plan(p: Plan) -> str:
|
def _fingerprint_for_plan(p: Plan) -> str:
|
||||||
|
"""sha256(title, total_minutes, sections.items.exercise_external_id, sections.items.duration)"""
|
||||||
core = {
|
core = {
|
||||||
"title": p.title,
|
"title": p.title,
|
||||||
"total_minutes": int(p.total_minutes),
|
"total_minutes": int(p.total_minutes),
|
||||||
|
|
@ -148,7 +151,6 @@ def _fingerprint_for_plan(p: Plan) -> str:
|
||||||
raw = json.dumps(core, sort_keys=True, ensure_ascii=False)
|
raw = json.dumps(core, sort_keys=True, ensure_ascii=False)
|
||||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def _get_by_field(collection: str, key: str, value: Any) -> Optional[Dict[str, Any]]:
|
def _get_by_field(collection: str, key: str, value: Any) -> Optional[Dict[str, Any]]:
|
||||||
flt = Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))])
|
flt = Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))])
|
||||||
pts, _ = qdrant.scroll(collection_name=collection, scroll_filter=flt, limit=1, with_payload=True)
|
pts, _ = qdrant.scroll(collection_name=collection, scroll_filter=flt, limit=1, with_payload=True)
|
||||||
|
|
@ -158,18 +160,16 @@ def _get_by_field(collection: str, key: str, value: Any) -> Optional[Dict[str, A
|
||||||
payload.setdefault("id", str(pts[0].id))
|
payload.setdefault("id", str(pts[0].id))
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
||||||
def _as_model(model_cls, payload: Dict[str, Any]):
|
def _as_model(model_cls, payload: Dict[str, Any]):
|
||||||
|
"""Unbekannte Payload-Felder herausfiltern (Pydantic v1/v2 kompatibel)."""
|
||||||
fields = getattr(model_cls, "model_fields", None) or getattr(model_cls, "__fields__", {})
|
fields = getattr(model_cls, "model_fields", None) or getattr(model_cls, "__fields__", {})
|
||||||
allowed = set(fields.keys())
|
allowed = set(fields.keys())
|
||||||
data = {k: payload[k] for k in payload.keys() if k in allowed}
|
data = {k: payload[k] for k in payload.keys() if k in allowed}
|
||||||
return model_cls(**data)
|
return model_cls(**data)
|
||||||
|
|
||||||
|
|
||||||
def _truthy(val: Optional[str]) -> bool:
|
def _truthy(val: Optional[str]) -> bool:
|
||||||
return str(val or "").strip().lower() in {"1", "true", "yes", "on"}
|
return str(val or "").strip().lower() in {"1", "true", "yes", "on"}
|
||||||
|
|
||||||
|
|
||||||
def _exists_in_collection(collection: str, key: str, value: Any) -> bool:
|
def _exists_in_collection(collection: str, key: str, value: Any) -> bool:
|
||||||
flt = Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))])
|
flt = Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))])
|
||||||
pts, _ = qdrant.scroll(collection_name=collection, scroll_filter=flt, limit=1, with_payload=False)
|
pts, _ = qdrant.scroll(collection_name=collection, scroll_filter=flt, limit=1, with_payload=False)
|
||||||
|
|
@ -211,7 +211,6 @@ def create_plan_template(t: PlanTemplate):
|
||||||
qdrant.upsert(collection_name=PLAN_TEMPLATE_COLLECTION, points=[PointStruct(id=str(t.id), vector=vec, payload=payload)])
|
qdrant.upsert(collection_name=PLAN_TEMPLATE_COLLECTION, points=[PointStruct(id=str(t.id), vector=vec, payload=payload)])
|
||||||
return t
|
return t
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/plan_templates/{tpl_id}",
|
"/plan_templates/{tpl_id}",
|
||||||
response_model=PlanTemplate,
|
response_model=PlanTemplate,
|
||||||
|
|
@ -225,7 +224,6 @@ def get_plan_template(tpl_id: str):
|
||||||
raise HTTPException(status_code=404, detail="not found")
|
raise HTTPException(status_code=404, detail="not found")
|
||||||
return _as_model(PlanTemplate, found)
|
return _as_model(PlanTemplate, found)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/plan_templates",
|
"/plan_templates",
|
||||||
response_model=PlanTemplateList,
|
response_model=PlanTemplateList,
|
||||||
|
|
@ -333,7 +331,6 @@ def create_plan(p: Plan):
|
||||||
if isinstance(dt, datetime):
|
if isinstance(dt, datetime):
|
||||||
dt = dt.astimezone(timezone.utc).isoformat()
|
dt = dt.astimezone(timezone.utc).isoformat()
|
||||||
elif isinstance(dt, str):
|
elif isinstance(dt, str):
|
||||||
# sicherheitshalber nach UTC normalisieren
|
|
||||||
try:
|
try:
|
||||||
_ = datetime.fromisoformat(dt.replace("Z", "+00:00"))
|
_ = datetime.fromisoformat(dt.replace("Z", "+00:00"))
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -361,7 +358,6 @@ def create_plan(p: Plan):
|
||||||
qdrant.upsert(collection_name=PLAN_COLLECTION, points=[PointStruct(id=str(p.id), vector=vec, payload=payload)])
|
qdrant.upsert(collection_name=PLAN_COLLECTION, points=[PointStruct(id=str(p.id), vector=vec, payload=payload)])
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/plan/{plan_id}",
|
"/plan/{plan_id}",
|
||||||
response_model=Plan,
|
response_model=Plan,
|
||||||
|
|
@ -380,7 +376,6 @@ def get_plan(plan_id: str):
|
||||||
pass
|
pass
|
||||||
return _as_model(Plan, found)
|
return _as_model(Plan, found)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/plans",
|
"/plans",
|
||||||
response_model=PlanList,
|
response_model=PlanList,
|
||||||
|
|
@ -430,7 +425,9 @@ def list_plans(
|
||||||
range_args["lte"] = float(datetime.fromisoformat(created_to.replace("Z", "+00:00")).timestamp())
|
range_args["lte"] = float(datetime.fromisoformat(created_to.replace("Z", "+00:00")).timestamp())
|
||||||
except Exception:
|
except Exception:
|
||||||
range_args = {}
|
range_args = {}
|
||||||
if range_args:
|
|
||||||
|
applied_server_range = bool(range_args)
|
||||||
|
if applied_server_range:
|
||||||
must.append(FieldCondition(key="created_at_ts", range=Range(**range_args)))
|
must.append(FieldCondition(key="created_at_ts", range=Range(**range_args)))
|
||||||
|
|
||||||
flt = Filter(must=must or None) if must else None
|
flt = Filter(must=must or None) if must else None
|
||||||
|
|
@ -438,8 +435,10 @@ def list_plans(
|
||||||
fetch_n = max(offset + limit, 1)
|
fetch_n = max(offset + limit, 1)
|
||||||
pts, _ = qdrant.scroll(collection_name=PLAN_COLLECTION, scroll_filter=flt, limit=fetch_n, with_payload=True)
|
pts, _ = qdrant.scroll(collection_name=PLAN_COLLECTION, scroll_filter=flt, limit=fetch_n, with_payload=True)
|
||||||
|
|
||||||
# Fallback: lokaler Zeitfilter (für Alt-Daten ohne created_at_ts)
|
# Fallback: nur wenn KEIN serverseitiger Range aktiv war (Alt-Daten ohne created_at_ts)
|
||||||
def _in_window(py: Dict[str, Any]) -> bool:
|
def _in_window(py: Dict[str, Any]) -> bool:
|
||||||
|
if applied_server_range:
|
||||||
|
return True
|
||||||
if not (created_from or created_to):
|
if not (created_from or created_to):
|
||||||
return True
|
return True
|
||||||
ts = py.get("created_at")
|
ts = py.get("created_at")
|
||||||
|
|
@ -476,4 +475,4 @@ def list_plans(
|
||||||
|
|
||||||
sliced = payloads[offset:offset+limit]
|
sliced = payloads[offset:offset+limit]
|
||||||
items = [_as_model(Plan, x) for x in sliced]
|
items = [_as_model(Plan, x) for x in sliced]
|
||||||
return PlanList(items=items, limit=limit, offset=offset, count=len(items))
|
return PlanList(items=items, limit=limit, offset=offset, count=len(items))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user