diff --git a/scripts/backfill_capability_facets.py b/scripts/backfill_capability_facets.py index b8e89d6..291b185 100644 --- a/scripts/backfill_capability_facets.py +++ b/scripts/backfill_capability_facets.py @@ -1,62 +1,109 @@ #!/usr/bin/env python3 -# backfill_capability_facets.py -import os, math +# -*- coding: utf-8 -*- +""" +Idempotentes Backfill-Skript für Capability-Facetten in Qdrant. + +- Kompatibel mit qdrant-client 1.15.x: **kein** WithPayloadSelector-Import nötig +- Liest alle Punkte der Collection mit Payload (scroll, with_payload=True) +- Schreibt folgende Felder pro Point nach: + * capability_keys + * capability_ge1 .. capability_ge5 + * capability_eq1 .. capability_eq5 + +Hinweis: Das Skript setzt KEINE Vektoren neu, es aktualisiert nur Payload-Felder. +""" + +import os +from typing import Dict, Any, List, Tuple, Optional from qdrant_client import QdrantClient -from qdrant_client.models import Filter, WithPayloadSelector +from qdrant_client.models import Filter # nur für API-Kompatibilität; wird hier leer genutzt COLL = os.getenv("EXERCISE_COLLECTION", "exercises") -client = QdrantClient(host=os.getenv("QDRANT_HOST","localhost"), port=int(os.getenv("QDRANT_PORT","6333"))) +QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost") +QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333")) +BATCH = int(os.getenv("BACKFILL_BATCH", "256")) -def names_ge(caps, n): - out=[] - for k,v in (caps or {}).items(): - try: - if int(v) >= n: - out.append(k) - except Exception: - pass - return sorted(out) -scroll_filter = Filter(must=[]) # alles -offset = None -updated = 0 +def _facet_capabilities(caps: Dict[str, Any]) -> Dict[str, List[str]]: + caps = caps or {} -while True: - res = client.scroll( - collection_name=COLL, - scroll_filter=scroll_filter, - with_payload=WithPayloadSelector(enable=True), - limit=256, - offset=offset - ) - points, offset = res - if not points: - break + def names_where(pred) -> List[str]: + out = [] + for k, v in caps.items(): + try: + iv = int(v) + except Exception: + iv = 0 + if pred(iv): + s = str(k).strip() + if s: + out.append(s) + # stabil sortieren + return sorted({s for s in out}, key=str.casefold) - set_list = [] - for pt in points: - p = pt.payload or {} - caps = p.get("capabilities") or {} + all_keys = sorted({str(k).strip() for k in caps.keys() if str(k).strip()}, key=str.casefold) - cap_keys = sorted([k for k in caps.keys() if k]) - ge1 = names_ge(caps, 1) - ge2 = names_ge(caps, 2) - ge3 = names_ge(caps, 3) + return { + "capability_keys": all_keys, + # >= N + "capability_ge1": names_where(lambda lv: lv >= 1), + "capability_ge2": names_where(lambda lv: lv >= 2), + "capability_ge3": names_where(lambda lv: lv >= 3), + "capability_ge4": names_where(lambda lv: lv >= 4), + "capability_ge5": names_where(lambda lv: lv >= 5), + # == N + "capability_eq1": names_where(lambda lv: lv == 1), + "capability_eq2": names_where(lambda lv: lv == 2), + "capability_eq3": names_where(lambda lv: lv == 3), + "capability_eq4": names_where(lambda lv: lv == 4), + "capability_eq5": names_where(lambda lv: lv == 5), + } - # nur setzen, wenn fehlt oder abweicht - if p.get("capability_keys") != cap_keys or p.get("capability_ge1") != ge1 \ - or p.get("capability_ge2") != ge2 or p.get("capability_ge3") != ge3: - set_list.append((pt.id, { - "capability_keys": cap_keys, - "capability_ge1": ge1, - "capability_ge2": ge2, - "capability_ge3": ge3 - })) - if set_list: - client.set_payload(collection_name=COLL, payload={pid: pay for pid, pay in set_list}, points=[pid for pid,_ in set_list]) - updated += len(set_list) - print(f"[Backfill] updated {len(set_list)} points…") +def main() -> None: + client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT) -print(f"[Backfill] done. total={updated}") + # Sanity: Collection muss existieren + info = client.get_collection(COLL) + print(f"[Backfill] Collection '{COLL}' ok – vectors={info.config.params.vectors}") + updated = 0 + offset = None + page = 0 + + while True: + page += 1 + points, offset = client.scroll( + collection_name=COLL, + scroll_filter=None, # alles + offset=offset, + limit=BATCH, + with_payload=True, + ) + if not points: + break + + for pt in points: + pld = pt.payload or {} + caps = pld.get("capabilities") or {} + facets = _facet_capabilities(caps) + + # Nur schreiben, wenn sich etwas ändert oder Felder fehlen + need = False + for k, v in facets.items(): + if pld.get(k) != v: + need = True + break + if not need: + continue + + # set_payload: pro Punkt separat (per-Point Payload) + client.set_payload(collection_name=COLL, points=[pt.id], payload=facets) + updated += 1 + print(f"[Backfill] page={page} processed={len(points)} updated_total={updated}") + + print(f"[Backfill] done. total_updated={updated}") + + +if __name__ == "__main__": + main()