updated RAG Search functionality
This commit is contained in:
@@ -3,6 +3,8 @@ Supabase data layer aligned with TitanForge/schema.sql.
|
||||
"""
|
||||
|
||||
import os
|
||||
import math
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -399,15 +401,104 @@ def search_rag_chunks_vector(user_id: int, query_embedding: List[float], limit:
|
||||
Vector search via SQL RPC function `match_rag_chunks` (pgvector).
|
||||
"""
|
||||
vector_text = "[" + ",".join(str(float(v)) for v in query_embedding) + "]"
|
||||
response = supabase.rpc(
|
||||
"match_rag_chunks",
|
||||
{
|
||||
"p_user_id": user_id,
|
||||
"p_query_embedding": vector_text,
|
||||
"p_match_count": min(max(1, limit), 100),
|
||||
},
|
||||
).execute()
|
||||
return _rows(response)
|
||||
safe_limit = min(max(1, limit), 100)
|
||||
|
||||
try:
|
||||
response = supabase.rpc(
|
||||
"match_rag_chunks",
|
||||
{
|
||||
"p_user_id": user_id,
|
||||
"p_query_embedding": vector_text,
|
||||
"p_match_count": safe_limit,
|
||||
},
|
||||
).execute()
|
||||
rows = _rows(response)
|
||||
if rows:
|
||||
return rows
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: pull candidate chunks and rank with cosine similarity in Python.
|
||||
response = (
|
||||
supabase.table("rag_chunks")
|
||||
.select(
|
||||
"chunk_id, post_id, start_sec, end_sec, text, confidence, created_at, embedding, "
|
||||
"audio_posts!inner(post_id, user_id, title, visibility, created_at)"
|
||||
)
|
||||
.eq("audio_posts.user_id", user_id)
|
||||
.limit(3000)
|
||||
.execute()
|
||||
)
|
||||
candidates = _rows(response)
|
||||
if not candidates:
|
||||
return []
|
||||
|
||||
q = _normalize_vec(query_embedding)
|
||||
ranked = []
|
||||
for row in candidates:
|
||||
emb = _parse_embedding(row.get("embedding"))
|
||||
if not emb:
|
||||
continue
|
||||
score = _cosine_similarity(q, emb)
|
||||
if score is None:
|
||||
continue
|
||||
out = dict(row)
|
||||
out["similarity"] = score
|
||||
out.pop("embedding", None)
|
||||
ranked.append(out)
|
||||
|
||||
ranked.sort(key=lambda r: r.get("similarity", 0.0), reverse=True)
|
||||
return ranked[:safe_limit]
|
||||
|
||||
|
||||
def _parse_embedding(value: Any) -> Optional[List[float]]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, list):
|
||||
try:
|
||||
return [float(v) for v in value]
|
||||
except Exception:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return None
|
||||
try:
|
||||
if text.startswith("[") and text.endswith("]"):
|
||||
return [float(v) for v in text[1:-1].split(",") if v.strip()]
|
||||
parsed = json.loads(text)
|
||||
if isinstance(parsed, list):
|
||||
return [float(v) for v in parsed]
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_vec(vec: List[float]) -> List[float]:
|
||||
if not vec:
|
||||
return []
|
||||
norm = math.sqrt(sum(float(v) * float(v) for v in vec))
|
||||
if norm <= 0:
|
||||
return [0.0 for _ in vec]
|
||||
return [float(v) / norm for v in vec]
|
||||
|
||||
|
||||
def _cosine_similarity(a: List[float], b: List[float]) -> Optional[float]:
|
||||
if not a or not b:
|
||||
return None
|
||||
n = min(len(a), len(b))
|
||||
if n == 0:
|
||||
return None
|
||||
dot = 0.0
|
||||
bnorm = 0.0
|
||||
for i in range(n):
|
||||
av = float(a[i])
|
||||
bv = float(b[i])
|
||||
dot += av * bv
|
||||
bnorm += bv * bv
|
||||
if bnorm <= 0:
|
||||
return None
|
||||
return dot / math.sqrt(bnorm)
|
||||
|
||||
|
||||
# ==================== Audit Log ====================
|
||||
|
||||
Reference in New Issue
Block a user