updated RAG Search functionality

This commit is contained in:
Gaumit Kauts
2026-02-15 12:13:33 -07:00
parent d5f9f26643
commit 89534fb836
5 changed files with 182 additions and 12 deletions

View File

@@ -45,6 +45,9 @@ BACKEND_UPLOAD_DIR=uploads
WHISPER_MODEL=base WHISPER_MODEL=base
WHISPER_DEVICE=cpu WHISPER_DEVICE=cpu
WHISPER_COMPUTE_TYPE=int8 WHISPER_COMPUTE_TYPE=int8
EMBEDDING_PROVIDER=local
OPENAI_EMBEDDING_MODEL=text-embedding-3-small
# OPENAI_API_KEY=... # only needed if EMBEDDING_PROVIDER=openai
```
Notes:

View File

@@ -7,6 +7,8 @@ import hashlib
import json import json
import os import os
import uuid import uuid
import math
import re
from pathlib import Path from pathlib import Path
import io import io
import zipfile import zipfile
@@ -38,6 +40,7 @@ from db_queries import (
list_rag_chunks, list_rag_chunks,
list_user_history, list_user_history,
search_rag_chunks, search_rag_chunks,
search_rag_chunks_vector,
update_audio_post, update_audio_post,
upload_storage_object, upload_storage_object,
upsert_archive_metadata, upsert_archive_metadata,
@@ -61,6 +64,10 @@ WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8")
ARCHIVE_BUCKET = os.getenv("SUPABASE_BUCKET", os.getenv("SUPABASE_ARCHIVE_BUCKET", "archives")) ARCHIVE_BUCKET = os.getenv("SUPABASE_BUCKET", os.getenv("SUPABASE_ARCHIVE_BUCKET", "archives"))
_whisper_model: WhisperModel | None = None _whisper_model: WhisperModel | None = None
_openai_client = None
EMBEDDING_DIM = 1536
EMBEDDING_PROVIDER = (os.getenv("EMBEDDING_PROVIDER") or "local").strip().lower()
OPENAI_EMBEDDING_MODEL = (os.getenv("OPENAI_EMBEDDING_MODEL") or "text-embedding-3-small").strip()
def _model() -> WhisperModel: def _model() -> WhisperModel:
@@ -103,6 +110,62 @@ def _build_prompt(transcript_text: str, title: str) -> str:
f"{transcript_text}\n\n" f"{transcript_text}\n\n"
"Answer user questions grounded in this transcript." "Answer user questions grounded in this transcript."
) )
def _local_embedding(text: str, dim: int = EMBEDDING_DIM) -> list[float]:
    """Free fallback embedding: hashed bag-of-words plus bi-grams, L2-normalized.

    Weaker than model embeddings, but it keeps vector search functional
    without any API key. Unigrams contribute weight 1.0, bi-grams 0.5;
    each feature is bucketed into ``dim`` slots via SHA-256.
    """
    out = [0.0] * dim
    if not text:
        return out
    words = re.findall(r"[a-z0-9]+", text.lower())
    if not words:
        return out

    def _bucket(key: str) -> int:
        # Stable hash -> slot index; "u:"/"b:" prefixes keep the two
        # feature spaces from colliding on identical token text.
        return int(hashlib.sha256(key.encode("utf-8")).hexdigest(), 16) % dim

    for word in words:
        out[_bucket(f"u:{word}")] += 1.0
    for left, right in zip(words, words[1:]):
        out[_bucket(f"b:{left}_{right}")] += 0.5

    length = math.sqrt(sum(component * component for component in out))
    if length > 0:
        out = [component / length for component in out]
    return out
def _openai_embedding(text: str) -> list[float] | None:
global _openai_client
api_key = (os.getenv("OPENAI_API_KEY") or "").strip()
if not api_key:
return None
try:
if _openai_client is None:
from openai import OpenAI
_openai_client = OpenAI(api_key=api_key)
response = _openai_client.embeddings.create(
model=OPENAI_EMBEDDING_MODEL,
input=text,
)
return response.data[0].embedding
except Exception:
return None
def _embed_text(text: str) -> list[float]:
    """Return an embedding vector for *text*.

    Prefers the OpenAI API when EMBEDDING_PROVIDER == "openai" and the call
    succeeds; otherwise (or for any other provider) uses the free local
    hashed embedding so search keeps working.
    """
    remote: list[float] | None = None
    if EMBEDDING_PROVIDER == "openai":
        remote = _openai_embedding(text)
    return remote if remote else _local_embedding(text)
def _add_audio_url(post: Dict[str, Any]) -> Dict[str, Any]: def _add_audio_url(post: Dict[str, Any]) -> Dict[str, Any]:
"""Add signed audio URL to post if ready""" """Add signed audio URL to post if ready"""
if post.get("status") == "ready": if post.get("status") == "ready":
@@ -281,7 +344,7 @@ def api_upload_post():
"end_sec": float(seg.end), "end_sec": float(seg.end),
"text": segment_text, "text": segment_text,
"confidence": float(seg.avg_logprob) if seg.avg_logprob is not None else None, "confidence": float(seg.avg_logprob) if seg.avg_logprob is not None else None,
"embedding": None, "embedding": _embed_text(segment_text),
} }
) )
@@ -376,8 +439,21 @@ def api_rag_search():
return _error("'q' is required.", 400) return _error("'q' is required.", 400)
try: try:
query_embedding = _embed_text(query_text)
rows = search_rag_chunks_vector(user_id=user_id, query_embedding=query_embedding, limit=limit)
# Fallback in case vector path is unavailable or empty.
mode = "vector"
if not rows:
rows = search_rag_chunks(user_id=user_id, query_text=query_text, page=page, limit=limit) rows = search_rag_chunks(user_id=user_id, query_text=query_text, page=page, limit=limit)
return jsonify({"results": rows, "page": page, "limit": min(max(1, limit), 100)}) mode = "text_fallback"
return jsonify({
"results": rows,
"page": page,
"limit": min(max(1, limit), 100),
"mode": mode
})
except Exception as e: except Exception as e:
return _error(str(e), 500) return _error(str(e), 500)

View File

@@ -3,6 +3,8 @@ Supabase data layer aligned with TitanForge/schema.sql.
""" """
import os import os
import math
import json
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -399,15 +401,104 @@ def search_rag_chunks_vector(user_id: int, query_embedding: List[float], limit:
Vector search via SQL RPC function `match_rag_chunks` (pgvector). Vector search via SQL RPC function `match_rag_chunks` (pgvector).
""" """
vector_text = "[" + ",".join(str(float(v)) for v in query_embedding) + "]" vector_text = "[" + ",".join(str(float(v)) for v in query_embedding) + "]"
safe_limit = min(max(1, limit), 100)
try:
response = supabase.rpc( response = supabase.rpc(
"match_rag_chunks", "match_rag_chunks",
{ {
"p_user_id": user_id, "p_user_id": user_id,
"p_query_embedding": vector_text, "p_query_embedding": vector_text,
"p_match_count": min(max(1, limit), 100), "p_match_count": safe_limit,
}, },
).execute() ).execute()
return _rows(response) rows = _rows(response)
if rows:
return rows
except Exception:
pass
# Fallback: pull candidate chunks and rank with cosine similarity in Python.
response = (
supabase.table("rag_chunks")
.select(
"chunk_id, post_id, start_sec, end_sec, text, confidence, created_at, embedding, "
"audio_posts!inner(post_id, user_id, title, visibility, created_at)"
)
.eq("audio_posts.user_id", user_id)
.limit(3000)
.execute()
)
candidates = _rows(response)
if not candidates:
return []
q = _normalize_vec(query_embedding)
ranked = []
for row in candidates:
emb = _parse_embedding(row.get("embedding"))
if not emb:
continue
score = _cosine_similarity(q, emb)
if score is None:
continue
out = dict(row)
out["similarity"] = score
out.pop("embedding", None)
ranked.append(out)
ranked.sort(key=lambda r: r.get("similarity", 0.0), reverse=True)
return ranked[:safe_limit]
def _parse_embedding(value: Any) -> Optional[List[float]]:
if value is None:
return None
if isinstance(value, list):
try:
return [float(v) for v in value]
except Exception:
return None
if isinstance(value, str):
text = value.strip()
if not text:
return None
try:
if text.startswith("[") and text.endswith("]"):
return [float(v) for v in text[1:-1].split(",") if v.strip()]
parsed = json.loads(text)
if isinstance(parsed, list):
return [float(v) for v in parsed]
except Exception:
return None
return None
def _normalize_vec(vec: List[float]) -> List[float]:
if not vec:
return []
norm = math.sqrt(sum(float(v) * float(v) for v in vec))
if norm <= 0:
return [0.0 for _ in vec]
return [float(v) / norm for v in vec]
def _cosine_similarity(a: List[float], b: List[float]) -> Optional[float]:
if not a or not b:
return None
n = min(len(a), len(b))
if n == 0:
return None
dot = 0.0
bnorm = 0.0
for i in range(n):
av = float(a[i])
bv = float(b[i])
dot += av * bv
bnorm += bv * bv
if bnorm <= 0:
return None
return dot / math.sqrt(bnorm)
# ==================== Audit Log ==================== # ==================== Audit Log ====================