# mindnet/app/frontend/ui_graph_service.py

import uuid

from qdrant_client import QdrantClient, models
from streamlit_agraph import Node, Edge

from ui_config import GRAPH_COLORS, get_edge_color, SYSTEM_EDGES
class GraphExplorerService:
    """Build ego-network graphs for the Streamlit graph explorer from Qdrant.

    Notes, chunks and edges are read from the collections ``{prefix}_notes``,
    ``{prefix}_chunks`` and ``{prefix}_edges``. Edge endpoints
    (``source_id`` / ``target_id``) may reference a chunk, a note ID or a note
    title; they are resolved to note payloads and rendered as
    ``streamlit_agraph`` ``Node`` / ``Edge`` objects.
    """

    # Provenance values treated as explicit (non-inferred). Explicit edges win
    # deduplication against other provenances and are drawn solid; every other
    # provenance is drawn dashed.
    EXPLICIT_PROVENANCES = ("explicit", "rule")

    # Page sizes for the Qdrant scroll requests. ``chunk_scroll_limit`` caps the
    # chunks considered per lookup; ``edge_scroll_limit`` applies separately to
    # the outgoing and the incoming edge query. The next-page offset returned by
    # scroll is ignored, so anything beyond these limits is dropped.
    chunk_scroll_limit = 200
    edge_scroll_limit = 100

    def __init__(self, url, api_key=None, prefix="mindnet"):
        """Connect to Qdrant at *url* and derive collection names from *prefix*."""
        self.client = QdrantClient(url=url, api_key=api_key)
        self.prefix = prefix
        self.notes_col = f"{prefix}_notes"
        self.chunks_col = f"{prefix}_chunks"
        self.edges_col = f"{prefix}_edges"
        # note_id -> note payload. Only successful lookups are stored and
        # entries are never invalidated for the lifetime of this instance.
        self._note_cache = {}

    def get_ego_graph(self, center_note_id: str, depth: int = 2):
        """Return ``(nodes, edges)`` describing the neighbourhood of one note.

        Args:
            center_note_id: ``note_id`` of the note placed at the centre.
            depth: ``1`` returns the centre and its direct neighbours; any value
                greater than 1 additionally expands the neighbours of those
                neighbours. Levels beyond 2 are never expanded.

        Returns:
            A tuple ``(nodes, edges)`` of lists of ``Node`` and ``Edge``, or
            ``([], [])`` when the centre note does not exist.
        """
        nodes_dict = {}
        unique_edges = {}

        # Level 0: the centre note itself.
        center_note = self._fetch_note_cached(center_note_id)
        if not center_note:
            return [], []
        self._add_node_to_dict(nodes_dict, center_note, level=0)

        # Level 1: direct neighbours. The title is passed so that edges whose
        # target_id is the centre note's title are found as well.
        level_1_ids = set()
        l1_edges = self._find_connected_edges([center_note_id], center_note.get("title"))
        for edge_data in l1_edges:
            src_id, tgt_id = self._process_edge(edge_data, nodes_dict, unique_edges, current_depth=1)
            level_1_ids.update(note_id for note_id in (src_id, tgt_id) if note_id)
        level_1_ids.discard(center_note_id)

        # Level 2: edges touching any level-1 neighbour (no title matching here).
        if depth > 1 and level_1_ids:
            for edge_data in self._find_connected_edges_batch(list(level_1_ids)):
                self._process_edge(edge_data, nodes_dict, unique_edges, current_depth=2)

        final_edges = [self._build_edge(data) for data in unique_edges.values()]
        return list(nodes_dict.values()), final_edges

    @classmethod
    def _is_explicit(cls, provenance):
        """Return True if *provenance* is one of ``EXPLICIT_PROVENANCES``."""
        return provenance in cls.EXPLICIT_PROVENANCES

    def _build_edge(self, data):
        """Turn one deduplicated edge entry into an ``Edge`` for rendering."""
        kind = data["kind"]
        prov = data["provenance"]
        return Edge(
            source=data["source"],
            target=data["target"],
            label=kind,
            color=get_edge_color(kind),
            # Non-explicit (inferred) edges are drawn dashed.
            dashes=not self._is_explicit(prov),
            title=f"Provenance: {prov}\nType: {kind}",
        )

    @staticmethod
    def _non_system_kind_condition():
        """Build the condition that excludes edge kinds listed in SYSTEM_EDGES."""
        # The model field is named ``except`` (a Python keyword). Passing it via
        # its alias works regardless of the client's populate-by-name setting,
        # unlike the ``except_=`` keyword.
        return models.FieldCondition(
            key="kind", match=models.MatchExcept(**{"except": list(SYSTEM_EDGES)})
        )

    def _find_connected_edges(self, note_ids, note_title=None):
        """Find outgoing and incoming edges (excluding SYSTEM_EDGES kinds).

        Outgoing edges have a ``source_id`` equal to one of the note IDs or one
        of their chunk point IDs. Incoming edges have a ``target_id`` equal to
        one of the note IDs, one of the chunk point IDs, or *note_title*.

        Args:
            note_ids: Iterable of ``note_id`` values; an empty iterable yields
                no edges.
            note_title: Optional title matched against ``target_id``.

        Returns:
            Edge records, outgoing first and then incoming. A record can appear
            twice when it matches both queries; ``_process_edge`` deduplicates.
        """
        note_ids = list(note_ids)
        if not note_ids:
            return []

        chunks, _ = self.client.scroll(
            collection_name=self.chunks_col,
            scroll_filter=models.Filter(
                must=[models.FieldCondition(key="note_id", match=models.MatchAny(any=note_ids))]
            ),
            limit=self.chunk_scroll_limit,
            with_payload=False,  # only the point IDs are used
        )
        chunk_ids = [c.id for c in chunks]

        def any_of(key, values):
            return models.FieldCondition(key=key, match=models.MatchAny(any=values))

        # Note IDs are always included, so each "should" list is non-empty.
        out_refs = [any_of("source_id", note_ids)]
        in_refs = [any_of("target_id", note_ids)]
        if chunk_ids:
            out_refs.append(any_of("source_id", chunk_ids))
            in_refs.append(any_of("target_id", chunk_ids))
        if note_title:
            in_refs.append(
                models.FieldCondition(key="target_id", match=models.MatchValue(value=note_title))
            )

        results = []
        for refs in (out_refs, in_refs):
            # must: kind filter; should: at least one endpoint condition matches.
            records, _ = self.client.scroll(
                self.edges_col,
                scroll_filter=models.Filter(must=[self._non_system_kind_condition()], should=refs),
                limit=self.edge_scroll_limit,
                with_payload=True,
            )
            results.extend(records)
        return results

    def _find_connected_edges_batch(self, note_ids):
        """Find edges for the level-2 expansion (no title matching).

        Currently this delegates to ``_find_connected_edges`` without a title,
        so one pair of edge queries covers all *note_ids*.
        """
        return self._find_connected_edges(note_ids)

    def _process_edge(self, record, nodes_dict, unique_edges, current_depth):
        """Resolve one edge record and merge it into the graph being built.

        The edge is ignored in two cases: an endpoint cannot be resolved to a
        note, or both endpoints resolve to the same note. Otherwise both notes
        are added to *nodes_dict* at *current_depth* (unless already present).
        The edge is stored in *unique_edges* under
        ``(source_note_id, target_note_id)``. An explicit edge is never replaced
        by a non-explicit one for the same key; any other duplicate overwrites
        the stored entry.

        Returns:
            ``(source_note_id, target_note_id)`` if the edge was accepted,
            otherwise ``(None, None)``.
        """
        payload = record.payload or {}

        # Skip the target lookup entirely when the source is unresolvable.
        src_note = self._resolve_note_from_ref(payload.get("source_id"))
        if not src_note:
            return None, None
        tgt_note = self._resolve_note_from_ref(payload.get("target_id"))
        if not tgt_note:
            return None, None

        src_id = src_note["note_id"]
        tgt_id = tgt_note["note_id"]
        if src_id == tgt_id:
            # Edges inside a single note (e.g. between two of its chunks) are not drawn.
            return None, None

        self._add_node_to_dict(nodes_dict, src_note, level=current_depth)
        self._add_node_to_dict(nodes_dict, tgt_note, level=current_depth)

        kind = payload.get("kind")
        provenance = payload.get("provenance", "explicit")
        key = (src_id, tgt_id)
        existing = unique_edges.get(key)
        keep_existing = (
            existing is not None
            and self._is_explicit(existing["provenance"])
            and not self._is_explicit(provenance)
        )
        if not keep_existing:
            unique_edges[key] = {
                "source": src_id, "target": tgt_id, "kind": kind, "provenance": provenance
            }
        return src_id, tgt_id

    def _fetch_note_cached(self, note_id):
        """Return the note payload for *note_id*, or None if it does not exist.

        Hits are cached in ``self._note_cache``; misses are not cached. A falsy
        *note_id* returns None without querying Qdrant.
        """
        if not note_id:
            return None
        if note_id in self._note_cache:
            return self._note_cache[note_id]
        res, _ = self.client.scroll(
            collection_name=self.notes_col,
            scroll_filter=models.Filter(
                must=[models.FieldCondition(key="note_id", match=models.MatchValue(value=note_id))]
            ),
            limit=1,
            with_payload=True,
        )
        if not res:
            return None
        self._note_cache[note_id] = res[0].payload
        return res[0].payload

    def _fetch_note_by_title(self, title):
        """Return the first note whose ``title`` equals *title*, or None; caches hits by note_id."""
        res, _ = self.client.scroll(
            collection_name=self.notes_col,
            scroll_filter=models.Filter(
                must=[models.FieldCondition(key="title", match=models.MatchValue(value=title))]
            ),
            limit=1,
            with_payload=True,
        )
        if not res:
            return None
        payload = res[0].payload
        self._note_cache[payload["note_id"]] = payload
        return payload

    @staticmethod
    def _is_point_id(ref):
        """Return True if *ref* is a valid Qdrant point ID (unsigned int or UUID string)."""
        if isinstance(ref, bool):
            return False
        if isinstance(ref, int):
            return ref >= 0
        if isinstance(ref, str):
            try:
                uuid.UUID(ref)
            except ValueError:
                return False
            return True
        return False

    def _note_from_chunk_point(self, point_id):
        """Return the note owning the chunk point *point_id*, or None."""
        try:
            res = self.client.retrieve(self.chunks_col, ids=[point_id], with_payload=True)
        except Exception:
            # Best effort: a failed lookup just leaves the reference unresolved.
            return None
        if not res:
            return None
        return self._fetch_note_cached((res[0].payload or {}).get("note_id"))

    def _resolve_note_from_ref(self, ref_str):
        """Resolve an edge endpoint reference to a note payload, or None.

        Resolution order:
        1. If the ref contains ``#``, try the part before the first ``#`` as a
           note ID.
        2. Try the whole ref as a note ID.
        3. If the ref is a valid Qdrant point ID, try it as a chunk point and
           return that chunk's note. A ``#`` can never occur in a point ID, so
           section-style refs are not sent to ``retrieve``.
        4. Try the ref as a note title.
        """
        if ref_str is None or ref_str == "":
            return None
        if not isinstance(ref_str, str):
            # Non-string refs can only be (integer) chunk point IDs.
            return self._note_from_chunk_point(ref_str) if self._is_point_id(ref_str) else None

        if "#" in ref_str:
            note = self._fetch_note_cached(ref_str.split("#")[0])
            if note:
                return note

        note = self._fetch_note_cached(ref_str)
        if note:
            return note

        if self._is_point_id(ref_str):
            note = self._note_from_chunk_point(ref_str)
            if note:
                return note

        return self._fetch_note_by_title(ref_str)

    @staticmethod
    def _node_size(level):
        """Node size for a ring: 40 for the centre, 25 for level 1, 15 beyond."""
        if level == 0:
            return 40
        if level == 1:
            return 25
        return 15

    def _add_node_to_dict(self, node_dict, note_payload, level=1):
        """Add a ``Node`` for *note_payload* to *node_dict* unless its note_id is present.

        The first insertion wins. ``get_ego_graph`` processes rings from the
        centre outwards, so a node keeps the level closest to the centre.
        """
        nid = note_payload.get("note_id")
        if nid in node_dict:
            return

        ntype = note_payload.get("type", "default")
        node_dict[nid] = Node(
            id=nid,
            # Fall back to the note ID when the title is missing or empty.
            label=str(note_payload.get("title") or nid),
            size=self._node_size(level),
            color=GRAPH_COLORS.get(ntype, GRAPH_COLORS["default"]),
            shape="dot" if level > 0 else "diamond",
            title=f"Type: {ntype}\nLevel: {level}\nTags: {note_payload.get('tags')}",
            font={'color': 'black', 'face': 'arial', 'size': 14 if level < 2 else 10},
        )