updated RAG Search functionality

This commit is contained in:
Gaumit Kauts
2026-02-15 12:13:33 -07:00
parent d5f9f26643
commit 89534fb836
5 changed files with 182 additions and 12 deletions

View File

@@ -7,6 +7,8 @@ import hashlib
import json
import os
import uuid
import math
import re
from pathlib import Path
import io
import zipfile
@@ -38,6 +40,7 @@ from db_queries import (
list_rag_chunks,
list_user_history,
search_rag_chunks,
search_rag_chunks_vector,
update_audio_post,
upload_storage_object,
upsert_archive_metadata,
@@ -61,6 +64,10 @@ WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8")
ARCHIVE_BUCKET = os.getenv("SUPABASE_BUCKET", os.getenv("SUPABASE_ARCHIVE_BUCKET", "archives"))
_whisper_model: WhisperModel | None = None
_openai_client = None
# Default length of the vectors built by _local_embedding.
# NOTE(review): presumably must equal the vector column size in the database
# and the output size of OPENAI_EMBEDDING_MODEL — confirm before changing.
EMBEDDING_DIM = 1536
# Embedding backend selector read from the EMBEDDING_PROVIDER env var
# (trimmed, lower-cased; "local" when unset or empty). _embed_text only
# special-cases "openai"; every other value uses the local hashed embedding.
EMBEDDING_PROVIDER = (os.getenv("EMBEDDING_PROVIDER") or "local").strip().lower()
# Model name passed to the OpenAI embeddings endpoint; overridable through
# the OPENAI_EMBEDDING_MODEL env var.
OPENAI_EMBEDDING_MODEL = (os.getenv("OPENAI_EMBEDDING_MODEL") or "text-embedding-3-small").strip()
def _model() -> WhisperModel:
@@ -103,6 +110,62 @@ def _build_prompt(transcript_text: str, title: str) -> str:
f"{transcript_text}\n\n"
"Answer user questions grounded in this transcript."
)
def _local_embedding(text: str, dim: int = EMBEDDING_DIM) -> list[float]:
    """
    Free fallback embedding: hashed bag-of-words + bi-grams, L2-normalized.

    This is weaker than model embeddings but keeps vector search functional.
    Tokens are lower-cased runs of Unicode letters and digits, so non-English
    text is embedded as whole words instead of being dropped or split at every
    non-ASCII character. For pure-ASCII input the token stream, and therefore
    the resulting vector, is the same as with an ``[a-z0-9]+`` tokenizer.

    Args:
        text: Text to embed. Empty text or text without any letter/digit
            yields an all-zero vector.
        dim: Number of hash buckets, i.e. the length of the returned vector.

    Returns:
        A list of ``dim`` floats. Unit L2 norm whenever at least one token
        was found, all zeros otherwise.

    Raises:
        ValueError: If ``dim`` is not a positive integer.
    """
    if dim <= 0:
        raise ValueError("dim must be a positive integer.")

    vec = [0.0] * dim
    if not text:
        return vec

    # ``[^\W_]`` is "word character except underscore": Unicode letters and
    # digits. On ASCII text this matches exactly what ``[a-z0-9]`` matches
    # after lower-casing.
    tokens = re.findall(r"[^\W_]+", text.lower())
    if not tokens:
        return vec

    def _bucket(key: str) -> int:
        # Stable (process-independent) hash of the feature key into [0, dim).
        digest = hashlib.sha256(key.encode("utf-8")).digest()
        return int.from_bytes(digest, "big") % dim

    # Unigrams carry full weight; the "u:"/"b:" prefixes keep unigram and
    # bigram features from hashing identically by construction.
    for tok in tokens:
        vec[_bucket(f"u:{tok}")] += 1.0
    # Adjacent-token bigrams add a weaker word-order signal.
    for left, right in zip(tokens, tokens[1:]):
        vec[_bucket(f"b:{left}_{right}")] += 0.5

    norm = math.sqrt(sum(v * v for v in vec))
    if norm > 0:
        vec = [v / norm for v in vec]
    return vec
def _openai_embedding(text: str) -> list[float] | None:
    """Return an OpenAI embedding for *text*, or None when unavailable.

    The client is created lazily on first use and cached in the module-level
    ``_openai_client``. This function never raises: any failure (missing
    package, network error, API error) is logged and reported as None so the
    caller can fall back to another embedding.

    Args:
        text: Text to embed.

    Returns:
        The embedding vector of the first (only) input as returned by the
        API, or None if *text* is empty, ``OPENAI_API_KEY`` is not set, or
        the request fails.
    """
    global _openai_client
    # The embeddings endpoint rejects empty input; skip the doomed round trip.
    if not text:
        return None
    api_key = (os.getenv("OPENAI_API_KEY") or "").strip()
    if not api_key:
        return None
    try:
        if _openai_client is None:
            # Imported lazily so the dependency is only needed when the
            # OpenAI provider is actually used.
            from openai import OpenAI

            _openai_client = OpenAI(api_key=api_key)
        response = _openai_client.embeddings.create(
            model=OPENAI_EMBEDDING_MODEL,
            input=text,
        )
        return response.data[0].embedding
    except Exception:
        # Best-effort by design, but leave a trace instead of failing
        # silently so misconfiguration and outages are diagnosable.
        import logging

        logging.getLogger(__name__).warning(
            "OpenAI embedding request failed; returning None.", exc_info=True
        )
        return None
def _embed_text(text: str) -> list[float]:
    """Embed *text* with the configured provider.

    When ``EMBEDDING_PROVIDER`` is ``"openai"`` the OpenAI embedding is used
    if it comes back non-empty; in every other case (different provider, or
    the OpenAI helper returned None/empty) the local hashed embedding is
    returned instead.
    """
    wants_openai = EMBEDDING_PROVIDER == "openai"
    remote_vec = _openai_embedding(text) if wants_openai else None
    return remote_vec if remote_vec else _local_embedding(text)
def _add_audio_url(post: Dict[str, Any]) -> Dict[str, Any]:
"""Add signed audio URL to post if ready"""
if post.get("status") == "ready":
@@ -281,7 +344,7 @@ def api_upload_post():
"end_sec": float(seg.end),
"text": segment_text,
"confidence": float(seg.avg_logprob) if seg.avg_logprob is not None else None,
"embedding": None,
"embedding": _embed_text(segment_text),
}
)
@@ -376,8 +439,21 @@ def api_rag_search():
return _error("'q' is required.", 400)
try:
rows = search_rag_chunks(user_id=user_id, query_text=query_text, page=page, limit=limit)
return jsonify({"results": rows, "page": page, "limit": min(max(1, limit), 100)})
query_embedding = _embed_text(query_text)
rows = search_rag_chunks_vector(user_id=user_id, query_embedding=query_embedding, limit=limit)
# Fallback in case vector path is unavailable or empty.
mode = "vector"
if not rows:
rows = search_rag_chunks(user_id=user_id, query_text=query_text, page=page, limit=limit)
mode = "text_fallback"
return jsonify({
"results": rows,
"page": page,
"limit": min(max(1, limit), 100),
"mode": mode
})
except Exception as e:
return _error(str(e), 500)