Trainer_LLM/scripts/backfill_capability_facets.py
Lars a6d68134cd
All checks were successful
Deploy Trainer_LLM to llm-node / deploy (push) Successful in 2s
scripts/backfill_capability_facets.py aktualisiert
2025-08-11 19:24:17 +02:00

96 lines
3.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Backfill Capability-Facetten in Qdrant v1.2
Fix: beendet korrekt, wenn `next_page_offset` (offset) None ist.
"""
import os
from typing import Dict, Any, List
from qdrant_client import QdrantClient
# Runtime configuration, read once from the environment when the module loads.
COLL = os.environ.get("EXERCISE_COLLECTION", "exercises")  # target Qdrant collection
QDRANT_HOST = os.environ.get("QDRANT_HOST", "localhost")
QDRANT_PORT = int(os.environ.get("QDRANT_PORT", "6333"))
BATCH = int(os.environ.get("BACKFILL_BATCH", "256"))  # points requested per scroll page
def _facet_capabilities(caps: Dict[str, Any]) -> Dict[str, List[str]]:
caps = caps or {}
def names_where(pred) -> List[str]:
out = []
for k, v in caps.items():
try:
iv = int(v)
except Exception:
iv = 0
if pred(iv):
s = str(k).strip()
if s:
out.append(s)
return sorted({s for s in out}, key=str.casefold)
all_keys = sorted({str(k).strip() for k in caps.keys() if str(k).strip()}, key=str.casefold)
return {
"capability_keys": all_keys,
"capability_ge1": names_where(lambda lv: lv >= 1),
"capability_ge2": names_where(lambda lv: lv >= 2),
"capability_ge3": names_where(lambda lv: lv >= 3),
"capability_ge4": names_where(lambda lv: lv >= 4),
"capability_ge5": names_where(lambda lv: lv >= 5),
"capability_eq1": names_where(lambda lv: lv == 1),
"capability_eq2": names_where(lambda lv: lv == 2),
"capability_eq3": names_where(lambda lv: lv == 3),
"capability_eq4": names_where(lambda lv: lv == 4),
"capability_eq5": names_where(lambda lv: lv == 5),
}
def main() -> None:
    """Backfill the capability facet fields on every point of ``COLL``.

    Scrolls through the collection ``BATCH`` points at a time, recomputes the
    facets from each point's ``capabilities`` payload and calls ``set_payload``
    only for points whose stored facet values differ. Stops on an empty page
    or when ``scroll`` hands back no next offset.
    """
    client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
    info = client.get_collection(COLL)
    print(f"[Backfill] Collection '{COLL}' ok vectors={info.config.params.vectors}")

    def _sync(point) -> bool:
        """Write recomputed facets for *point*; return True if a write happened."""
        current = point.payload or {}
        wanted = _facet_capabilities(current.get("capabilities") or {})
        # Only write when at least one facet field differs from what is stored.
        if all(current.get(field) == value for field, value in wanted.items()):
            return False
        client.set_payload(collection_name=COLL, points=[point.id], payload=wanted)
        return True

    updated_total = 0
    page = 0
    cursor = None
    keep_going = True
    while keep_going:
        points, cursor_after = client.scroll(
            collection_name=COLL,
            scroll_filter=None,
            offset=cursor,
            limit=BATCH,
            with_payload=True,
        )
        page += 1
        if not points:
            print("[Backfill] no more points done")
            break

        updated_page = 0
        for pt in points:
            if _sync(pt):
                updated_page += 1
        updated_total += updated_page
        print(f"[Backfill] page={page} processed={len(points)} updated_page={updated_page} updated_total={updated_total}")

        # A None next offset marks the last page: finish after processing it.
        keep_going = cursor_after is not None
        cursor = cursor_after

    print(f"[Backfill] done. total_updated={updated_total}")
# Run the backfill only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()