""" Mehrstufiges Retrieval für Planungs-Übungssuche (S1b). Stufen: S1b-0 Kandidaten-Pool (Profil-Signale, Volltext, Progressions-Nachfolger) S1b-1 Profil-Vorselektion → Top-K vor teurem Hybrid-Score S1b-2 Hybrid-Score (Volltext, Graph, Skills, Plan, Profil, Wiederholung) """ from __future__ import annotations from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple from planning_exercise_profiles import ( PlanningTargetProfile, load_exercise_match_profiles_bulk, score_exercise_against_target, ) _RAW_POOL_LIMIT = 500 _PROFILE_PRESELECT_LIMIT = 160 def _skill_jaccard(a: Set[int], b: Set[int]) -> float: if not a or not b: return 0.0 inter = len(a & b) union = len(a | b) return inter / union if union else 0.0 def _top_weight_keys(weights: Mapping[int, float], limit: int) -> List[int]: if not weights: return [] return [ int(k) for k, _ in sorted(weights.items(), key=lambda x: -float(x[1]))[:limit] if int(k) > 0 ] def _target_profile_signals(target: PlanningTargetProfile) -> Tuple[List[int], List[int], List[int]]: skill_ids = _top_weight_keys(target.skill_weights, 8) for sid in _top_weight_keys(target.skill_gap_weights, 6): if sid not in skill_ids: skill_ids.append(sid) focus_ids = _top_weight_keys(target.focus_area_ids, 6) style_ids = _top_weight_keys(target.style_direction_ids, 4) return skill_ids[:12], focus_ids, style_ids def fetch_retrieval_candidate_rows( cur, *, vis_sql: str, vis_params: Sequence[Any], query: str, exercise_kind_any: Optional[List[str]], target: PlanningTargetProfile, progression_successor_ids: Set[int], anchor_skill_ids: Set[int], raw_pool_limit: int = _RAW_POOL_LIMIT, ) -> List[Dict[str, Any]]: """S1b-0: Profil-geführter Kandidaten-Pool.""" where = [vis_sql, "COALESCE(e.status, '') <> %s"] params: List[Any] = [] if query: ft_select = "ts_rank_cd(e.search_vector, plainto_tsquery('german', %s)) AS ft_rank" # SELECT-Platzhalter steht im SQL vor WHERE — Query zuerst binden. params.append(query) else: ft_select = "0.0::float AS ft_rank" params.extend(vis_params) params.append("archived") ek_filtered: List[str] = [] if exercise_kind_any: for raw in exercise_kind_any: s = str(raw or "").strip().lower() if s in ("simple", "combination") and s not in ek_filtered: ek_filtered.append(s) if ek_filtered: ph = ",".join(["%s"] * len(ek_filtered)) where.append(f"(LOWER(TRIM(COALESCE(e.exercise_kind::text,''))) IN ({ph}))") params.extend(ek_filtered) skill_ids, focus_ids, style_ids = _target_profile_signals(target) if not skill_ids and anchor_skill_ids: skill_ids = sorted(anchor_skill_ids)[:10] profile_clauses: List[str] = [] if skill_ids: ph = ",".join(["%s"] * len(skill_ids)) profile_clauses.append( f"EXISTS (SELECT 1 FROM exercise_skills es WHERE es.exercise_id = e.id AND es.skill_id IN ({ph}))" ) params.extend(skill_ids) if focus_ids: ph = ",".join(["%s"] * len(focus_ids)) profile_clauses.append( f"EXISTS (SELECT 1 FROM exercise_focus_areas efa WHERE efa.exercise_id = e.id AND efa.focus_area_id IN ({ph}))" ) params.extend(focus_ids) if style_ids: ph = ",".join(["%s"] * len(style_ids)) profile_clauses.append( f"EXISTS (SELECT 1 FROM exercise_style_directions esd WHERE esd.exercise_id = e.id AND esd.style_direction_id IN ({ph}))" ) params.extend(style_ids) if progression_successor_ids: ph = ",".join(["%s"] * len(progression_successor_ids)) profile_clauses.append(f"e.id IN ({ph})") params.extend(sorted(progression_successor_ids)) if query: profile_clauses.append("e.search_vector @@ plainto_tsquery('german', %s)") params.append(query) use_profile_pool = bool(profile_clauses) if use_profile_pool: where.append(f"({' OR '.join(profile_clauses)})") order_by = "e.updated_at DESC, e.id DESC" if query: order_by = "ft_rank DESC NULLS LAST, e.updated_at DESC, e.id DESC" sql = f""" SELECT e.id, e.title, e.summary, ( SELECT fa.name FROM exercise_focus_areas efa JOIN focus_areas fa ON fa.id = efa.focus_area_id WHERE efa.exercise_id = e.id ORDER BY efa.is_primary DESC NULLS LAST, fa.name ASC LIMIT 1 ) AS primary_focus_name, {ft_select} FROM exercises e WHERE {' AND '.join(where)} ORDER BY {order_by} LIMIT %s """ params.append(int(raw_pool_limit)) cur.execute(sql, params) rows = [dict(r) for r in cur.fetchall()] if rows or not use_profile_pool: return rows return _fetch_broad_fallback_pool( cur, vis_sql=vis_sql, vis_params=vis_params, query=query, ek_filtered=ek_filtered, raw_pool_limit=raw_pool_limit, ) def _fetch_broad_fallback_pool( cur, *, vis_sql: str, vis_params: Sequence[Any], query: str, ek_filtered: List[str], raw_pool_limit: int, ) -> List[Dict[str, Any]]: fallback_where = [vis_sql, "COALESCE(e.status, '') <> %s"] fallback_params: List[Any] = list(vis_params) fallback_params.append("archived") if ek_filtered: ph = ",".join(["%s"] * len(ek_filtered)) fallback_where.append(f"(LOWER(TRIM(COALESCE(e.exercise_kind::text,''))) IN ({ph}))") fallback_params.extend(ek_filtered) if query: ft_fb = "ts_rank_cd(e.search_vector, plainto_tsquery('german', %s)) AS ft_rank" fb_order = "ft_rank DESC NULLS LAST, e.updated_at DESC, e.id DESC" fallback_params.insert(0, query) else: ft_fb = "0.0::float AS ft_rank" fb_order = "e.updated_at DESC, e.id DESC" fb_sql = f""" SELECT e.id, e.title, e.summary, ( SELECT fa.name FROM exercise_focus_areas efa JOIN focus_areas fa ON fa.id = efa.focus_area_id WHERE efa.exercise_id = e.id ORDER BY efa.is_primary DESC NULLS LAST, fa.name ASC LIMIT 1 ) AS primary_focus_name, {ft_fb} FROM exercises e WHERE {' AND '.join(fallback_where)} ORDER BY {fb_order} LIMIT %s """ fallback_params.append(int(raw_pool_limit)) cur.execute(fb_sql, fallback_params) return [dict(r) for r in cur.fetchall()] def profile_preselect_rows( cur, rows: Sequence[Dict[str, Any]], *, target: PlanningTargetProfile, intent: str, progression_successor_ids: Set[int], query: str, preselect_limit: int = _PROFILE_PRESELECT_LIMIT, ) -> Tuple[List[Dict[str, Any]], bool]: """S1b-1: Profil-Score auf Pool, Top-K für Hybrid.""" if len(rows) <= preselect_limit: return list(rows), False cand_ids = [int(r["id"]) for r in rows] match_profiles = load_exercise_match_profiles_bulk(cur, cand_ids) scored: List[Tuple[float, Dict[str, Any]]] = [] row_by_id = {int(r["id"]): r for r in rows} must_keep: Set[int] = set(int(x) for x in progression_successor_ids) if query: max_ft = max(float(r.get("ft_rank") or 0.0) for r in rows) or 0.0 if max_ft > 0: for r in rows: if float(r.get("ft_rank") or 0.0) / max_ft >= 0.5: must_keep.add(int(r["id"])) for eid in cand_ids: emp = match_profiles.get(eid) profile_score = 0.0 if emp: profile_score, _ = score_exercise_against_target(emp, target, intent=intent) scored.append((profile_score, row_by_id[eid])) scored.sort(key=lambda x: (-x[0], str(x[1].get("title") or ""))) selected: List[Dict[str, Any]] = [] seen: Set[int] = set() for _, row in scored: eid = int(row["id"]) if eid in seen: continue seen.add(eid) selected.append(row) if len(selected) >= preselect_limit: break for eid in must_keep: if eid in seen: continue row = row_by_id.get(eid) if row: selected.append(row) seen.add(eid) return selected, True def hybrid_score_planning_hits( cur, rows: Sequence[Dict[str, Any]], *, query: str, intent: str, intent_weights: Mapping[str, float], target: PlanningTargetProfile, pack: Mapping[str, Any], ) -> Tuple[List[Dict[str, Any]], Dict[int, Set[int]]]: """S1b-2: Hybrid-Score auf vorselektiertem Pool.""" planned_set = set(pack.get("planned_exercise_ids") or []) group_recent_set = set(pack.get("group_recent_exercise_ids") or []) progression_set = set(pack.get("progression_successor_ids") or []) anchor_skills = set(pack.get("anchor_skill_ids") or []) anchor_id = pack.get("anchor_exercise_id") progression_notes = pack.get("progression_edge_notes") or {} last_planned_skills: Set[int] = set() planned_ids = pack.get("planned_exercise_ids") or [] if planned_ids: cur.execute( "SELECT skill_id FROM exercise_skills WHERE exercise_id = %s", (int(planned_ids[-1]),), ) last_planned_skills = {int(r["skill_id"]) for r in cur.fetchall() if r.get("skill_id")} cand_ids = [int(r["id"]) for r in rows] skills_by_ex: Dict[int, Set[int]] = {cid: set() for cid in cand_ids} match_profiles = load_exercise_match_profiles_bulk(cur, cand_ids) if cand_ids: ph = ",".join(["%s"] * len(cand_ids)) cur.execute( f"SELECT exercise_id, skill_id FROM exercise_skills WHERE exercise_id IN ({ph})", cand_ids, ) for r in cur.fetchall(): skills_by_ex.setdefault(int(r["exercise_id"]), set()).add(int(r["skill_id"])) max_ft = 0.0 scored_items: List[Dict[str, Any]] = [] for row in rows: eid = int(row["id"]) if anchor_id and eid == int(anchor_id): continue ft = float(row.get("ft_rank") or 0.0) if ft > max_ft: max_ft = ft scored_items.append( { "row": row, "eid": eid, "ft": ft, "skills": skills_by_ex.get(eid, set()), } ) weights = dict(intent_weights) hits: List[Dict[str, Any]] = [] for item in scored_items: eid = item["eid"] row = item["row"] ft_norm = (item["ft"] / max_ft) if max_ft > 0 else 0.0 prog_hit = 1.0 if eid in progression_set else 0.0 skill_sim = _skill_jaccard(anchor_skills, item["skills"]) if anchor_skills else 0.0 plan_aff = 0.0 if last_planned_skills and item["skills"]: plan_aff = _skill_jaccard(last_planned_skills, item["skills"]) repeat_unit = 1.0 if eid in planned_set else 0.0 repeat_group = 1.0 if eid in group_recent_set else 0.0 profile_score = 0.0 profile_reasons: List[str] = [] emp = match_profiles.get(eid) if emp: profile_score, profile_reasons = score_exercise_against_target( emp, target, intent=intent ) score = ( weights["fulltext"] * ft_norm + weights["progression"] * prog_hit + weights["skill"] * skill_sim + weights["plan"] * plan_aff + weights["profile"] * profile_score + weights["repeat_unit"] * repeat_unit + weights["repeat_group"] * repeat_group ) reasons: List[str] = [] if query and ft_norm >= 0.35: reasons.append("Volltext-Treffer") if prog_hit > 0: note = progression_notes.get(eid) reasons.append( f"Nachfolger im Progressionsgraph{f': {note}' if note else ''}" ) if skill_sim >= 0.2 and anchor_id: reasons.append("Fähigkeiten passen zur Anker-Übung") if plan_aff >= 0.25: reasons.append("Schließt an Skills der letzten geplanten Übung an") if repeat_unit > 0: reasons.append("Bereits in dieser Einheit eingeplant") if repeat_group > 0 and repeat_unit <= 0: reasons.append("Kürzlich in der Gruppe verwendet") for pr in profile_reasons: if pr not in reasons: reasons.append(pr) if score <= 0 and not reasons and not query: if prog_hit or skill_sim or plan_aff or profile_score: score = 0.05 + prog_hit * 0.3 + skill_sim * 0.2 + profile_score * 0.25 hits.append( { "id": eid, "title": row.get("title"), "summary": row.get("summary"), "focus_area": row.get("primary_focus_name"), "score": round(max(0.0, min(1.0, score)), 4), "reasons": reasons, } ) hits.sort(key=lambda h: (-h["score"], h.get("title") or "")) return hits, skills_by_ex def run_multistage_planning_retrieval( cur, *, vis_sql: str, vis_params: Sequence[Any], query: str, exercise_kind_any: Optional[List[str]], target: PlanningTargetProfile, intent: str, intent_weights: Mapping[str, float], pack: Mapping[str, Any], ) -> Tuple[List[Dict[str, Any]], Dict[int, Set[int]], bool]: """Orchestriert S1b-0 → S1b-1 → S1b-2.""" progression_set = set(pack.get("progression_successor_ids") or []) anchor_skills = set(pack.get("anchor_skill_ids") or []) rows = fetch_retrieval_candidate_rows( cur, vis_sql=vis_sql, vis_params=vis_params, query=query, exercise_kind_any=exercise_kind_any, target=target, progression_successor_ids=progression_set, anchor_skill_ids=anchor_skills, ) rows, preselect_applied = profile_preselect_rows( cur, rows, target=target, intent=intent, progression_successor_ids=progression_set, query=query, ) hits, skills_by_ex = hybrid_score_planning_hits( cur, rows, query=query, intent=intent, intent_weights=intent_weights, target=target, pack=pack, ) return hits, skills_by_ex, preselect_applied __all__ = [ "fetch_retrieval_candidate_rows", "hybrid_score_planning_hits", "profile_preselect_rows", "run_multistage_planning_retrieval", ]