llm-api/plan_session_router.py hinzugefügt
All checks were successful
Deploy Trainer_LLM to llm-node / deploy (push) Successful in 1s
All checks were successful
Deploy Trainer_LLM to llm-node / deploy (push) Successful in 1s
This commit is contained in:
parent
5c51d3bc4f
commit
40b1151023
112
llm-api/plan_session_router.py
Normal file
112
llm-api/plan_session_router.py
Normal file
|
|
@ -0,0 +1,112 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
plan_session_router.py – v0.1.0 (WP-15)
|
||||||
|
|
||||||
|
CRUD-Minimum für Plan-Sessions (POST/GET).
|
||||||
|
Kompatibel zum Qdrant-Client-Stil der bestehenden Router.
|
||||||
|
"""
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel, Field, conint
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from uuid import uuid4
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import os
|
||||||
|
|
||||||
|
from clients import model, qdrant
|
||||||
|
from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue, VectorParams, Distance
|
||||||
|
|
||||||
|
router = APIRouter(tags=["plan_sessions"])
|
||||||
|
|
||||||
|
# -----------------
|
||||||
|
# Konfiguration
|
||||||
|
# -----------------
|
||||||
|
PLAN_SESSION_COLLECTION = os.getenv("PLAN_SESSION_COLLECTION", "plan_sessions")
|
||||||
|
|
||||||
|
# -----------------
|
||||||
|
# Modelle
|
||||||
|
# -----------------
|
||||||
|
class Feedback(BaseModel):
|
||||||
|
rating: conint(ge=1, le=5)
|
||||||
|
notes: str
|
||||||
|
|
||||||
|
class PlanSession(BaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||||
|
plan_id: str
|
||||||
|
executed_at: datetime
|
||||||
|
location: str
|
||||||
|
coach: str
|
||||||
|
group_label: str
|
||||||
|
feedback: Feedback
|
||||||
|
used_equipment: List[str] = []
|
||||||
|
|
||||||
|
# -----------------
|
||||||
|
# Helpers
|
||||||
|
# -----------------
|
||||||
|
|
||||||
|
def _ensure_collection(name: str):
|
||||||
|
if not qdrant.collection_exists(name):
|
||||||
|
qdrant.recreate_collection(
|
||||||
|
collection_name=name,
|
||||||
|
vectors_config=VectorParams(size=model.get_sentence_embedding_dimension(), distance=Distance.COSINE),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _norm_list(xs: List[str]) -> List[str]:
|
||||||
|
seen, out = set(), []
|
||||||
|
for x in xs or []:
|
||||||
|
s = str(x).strip()
|
||||||
|
k = s.casefold()
|
||||||
|
if s and k not in seen:
|
||||||
|
seen.add(k)
|
||||||
|
out.append(s)
|
||||||
|
return sorted(out, key=str.casefold)
|
||||||
|
|
||||||
|
|
||||||
|
def _session_embed_text(s: PlanSession) -> str:
|
||||||
|
parts = [s.plan_id, s.location, s.coach, s.group_label, s.feedback.notes]
|
||||||
|
return ". ".join([p for p in parts if p])
|
||||||
|
|
||||||
|
|
||||||
|
def _embed(text: str):
|
||||||
|
return model.encode(text or "").tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_by_field(collection: str, key: str, value: Any) -> Optional[Dict[str, Any]]:
|
||||||
|
flt = Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))])
|
||||||
|
pts, _ = qdrant.scroll(collection_name=collection, scroll_filter=flt, limit=1, with_payload=True)
|
||||||
|
if not pts:
|
||||||
|
return None
|
||||||
|
payload = dict(pts[0].payload or {})
|
||||||
|
payload.setdefault("id", str(pts[0].id))
|
||||||
|
return payload
|
||||||
|
|
||||||
|
# -----------------
|
||||||
|
# Endpoints
|
||||||
|
# -----------------
|
||||||
|
@router.post("/plan_sessions", response_model=PlanSession)
|
||||||
|
def create_plan_session(s: PlanSession):
|
||||||
|
_ensure_collection(PLAN_SESSION_COLLECTION)
|
||||||
|
# Normalisieren
|
||||||
|
s.used_equipment = _norm_list(s.used_equipment)
|
||||||
|
payload = s.model_dump()
|
||||||
|
# ISO8601 für executed_at sicherstellen
|
||||||
|
if isinstance(payload.get("executed_at"), datetime):
|
||||||
|
payload["executed_at"] = payload["executed_at"].astimezone(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
vec = _embed(_session_embed_text(s))
|
||||||
|
qdrant.upsert(collection_name=PLAN_SESSION_COLLECTION, points=[PointStruct(id=str(s.id), vector=vec, payload=payload)])
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/plan_sessions/{session_id}", response_model=PlanSession)
|
||||||
|
def get_plan_session(session_id: str):
|
||||||
|
_ensure_collection(PLAN_SESSION_COLLECTION)
|
||||||
|
found = _get_by_field(PLAN_SESSION_COLLECTION, "id", session_id)
|
||||||
|
if not found:
|
||||||
|
raise HTTPException(status_code=404, detail="not found")
|
||||||
|
if isinstance(found.get("executed_at"), str):
|
||||||
|
try:
|
||||||
|
found["executed_at"] = datetime.fromisoformat(found["executed_at"])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return PlanSession(**found)
|
||||||
Loading…
Reference in New Issue
Block a user