Trainer_LLM/scripts/backfill_capability_facets.py
Lars fa8a92208a
All checks were successful
Deploy Trainer_LLM to llm-node / deploy (push) Successful in 2s
scripts/backfill_capability_facets.py aktualisiert
2025-08-11 19:21:56 +02:00

110 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Idempotentes Backfill-Skript für Capability-Facetten in Qdrant.
- Kompatibel mit qdrant-client 1.15.x: **kein** WithPayloadSelector-Import nötig
- Liest alle Punkte der Collection mit Payload (scroll, with_payload=True)
- Schreibt folgende Felder pro Point nach:
* capability_keys
* capability_ge1 .. capability_ge5
* capability_eq1 .. capability_eq5
Hinweis: Das Skript setzt KEINE Vektoren neu, es aktualisiert nur Payload-Felder.
"""
import os
from typing import Dict, Any, List, Tuple, Optional
from qdrant_client import QdrantClient
from qdrant_client.models import Filter # nur für API-Kompatibilität; wird hier leer genutzt
# Name of the Qdrant collection to backfill (env EXERCISE_COLLECTION, default "exercises").
COLL = os.getenv("EXERCISE_COLLECTION", "exercises")
# Host and port used to construct the QdrantClient in main().
QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
# Page size passed as `limit` to every scroll() call (env BACKFILL_BATCH, default 256).
BATCH = int(os.getenv("BACKFILL_BATCH", "256"))
def _facet_capabilities(caps: Dict[str, Any]) -> Dict[str, List[str]]:
caps = caps or {}
def names_where(pred) -> List[str]:
out = []
for k, v in caps.items():
try:
iv = int(v)
except Exception:
iv = 0
if pred(iv):
s = str(k).strip()
if s:
out.append(s)
# stabil sortieren
return sorted({s for s in out}, key=str.casefold)
all_keys = sorted({str(k).strip() for k in caps.keys() if str(k).strip()}, key=str.casefold)
return {
"capability_keys": all_keys,
# >= N
"capability_ge1": names_where(lambda lv: lv >= 1),
"capability_ge2": names_where(lambda lv: lv >= 2),
"capability_ge3": names_where(lambda lv: lv >= 3),
"capability_ge4": names_where(lambda lv: lv >= 4),
"capability_ge5": names_where(lambda lv: lv >= 5),
# == N
"capability_eq1": names_where(lambda lv: lv == 1),
"capability_eq2": names_where(lambda lv: lv == 2),
"capability_eq3": names_where(lambda lv: lv == 3),
"capability_eq4": names_where(lambda lv: lv == 4),
"capability_eq5": names_where(lambda lv: lv == 5),
}
def main() -> None:
    """Backfill the capability facet payload fields for all points in COLL.

    Scrolls through the collection page by page (BATCH points per request),
    derives the facet lists from each point's ``capabilities`` payload via
    ``_facet_capabilities`` and writes them back with ``set_payload`` only
    when at least one stored facet field differs from the computed value.
    Progress is printed after every page.
    """
    client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)

    # Sanity check: the collection must exist
    info = client.get_collection(COLL)
    print(f"[Backfill] Collection '{COLL}' ok vectors={info.config.params.vectors}")

    updated = 0
    offset = None
    page = 0
    while True:
        page += 1
        points, offset = client.scroll(
            collection_name=COLL,
            scroll_filter=None,  # no filter: visit every point
            offset=offset,
            limit=BATCH,
            with_payload=True,
        )
        if not points:
            break

        for pt in points:
            pld = pt.payload or {}
            caps = pld.get("capabilities") or {}
            facets = _facet_capabilities(caps)

            # Only write when a facet field is missing or out of date
            needs_write = any(pld.get(k) != v for k, v in facets.items())
            if not needs_write:
                continue

            # The facets are specific to this point, so set_payload is
            # issued per point.
            client.set_payload(collection_name=COLL, points=[pt.id], payload=facets)
            updated += 1

        print(f"[Backfill] page={page} processed={len(points)} updated_total={updated}")

        # scroll() hands back None as the next offset once the last page has
        # been delivered. Feeding None back in would start over at the first
        # page and loop forever, so stop here.
        if offset is None:
            break

    print(f"[Backfill] done. total_updated={updated}")
# Run the backfill only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()