updated RAG Search functionality
This commit is contained in:
@@ -7,6 +7,8 @@ import hashlib
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import math
|
||||
import re
|
||||
from pathlib import Path
|
||||
import io
|
||||
import zipfile
|
||||
@@ -38,6 +40,7 @@ from db_queries import (
|
||||
list_rag_chunks,
|
||||
list_user_history,
|
||||
search_rag_chunks,
|
||||
search_rag_chunks_vector,
|
||||
update_audio_post,
|
||||
upload_storage_object,
|
||||
upsert_archive_metadata,
|
||||
@@ -61,6 +64,10 @@ WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8")
|
||||
ARCHIVE_BUCKET = os.getenv("SUPABASE_BUCKET", os.getenv("SUPABASE_ARCHIVE_BUCKET", "archives"))
|
||||
|
||||
_whisper_model: WhisperModel | None = None
|
||||
_openai_client = None
|
||||
EMBEDDING_DIM = 1536
|
||||
EMBEDDING_PROVIDER = (os.getenv("EMBEDDING_PROVIDER") or "local").strip().lower()
|
||||
OPENAI_EMBEDDING_MODEL = (os.getenv("OPENAI_EMBEDDING_MODEL") or "text-embedding-3-small").strip()
|
||||
|
||||
|
||||
def _model() -> WhisperModel:
|
||||
@@ -103,6 +110,62 @@ def _build_prompt(transcript_text: str, title: str) -> str:
|
||||
f"{transcript_text}\n\n"
|
||||
"Answer user questions grounded in this transcript."
|
||||
)
|
||||
|
||||
|
||||
def _local_embedding(text: str, dim: int = EMBEDDING_DIM) -> list[float]:
    """Build a dependency-free embedding of *text* via feature hashing.

    The text is lower-cased and split into runs of ASCII letters/digits
    (the ``[a-z0-9]+`` pattern); every other character acts as a separator.
    Each token adds 1.0 to a SHA-256-selected bucket and each adjacent token
    pair adds 0.5 to another bucket. The resulting vector is scaled to unit
    L2 length.

    Lower quality than a model embedding, but it keeps vector search usable
    when no embedding model is available.

    Args:
        text: Input text; a falsy value yields an all-zero vector.
        dim: Length of the returned vector.

    Returns:
        A list of ``dim`` floats: unit-length when at least one token was
        found, otherwise all zeros.
    """

    def _bucket(feature: str) -> int:
        # Hash the feature name to a stable slot in [0, dim).
        digest = hashlib.sha256(feature.encode("utf-8")).digest()
        return int.from_bytes(digest, "big") % dim

    weights = [0.0] * dim
    words = re.findall(r"[a-z0-9]+", text.lower()) if text else []
    if not words:
        return weights

    # Unigram features.
    for word in words:
        weights[_bucket(f"u:{word}")] += 1.0

    # Bigram features over adjacent tokens, at half the unigram weight.
    for left, right in zip(words, words[1:]):
        weights[_bucket(f"b:{left}_{right}")] += 0.5

    magnitude = math.sqrt(sum(w * w for w in weights))
    if magnitude > 0:
        return [w / magnitude for w in weights]
    return weights
|
||||
|
||||
|
||||
def _openai_embedding(text: str) -> list[float] | None:
|
||||
global _openai_client
|
||||
api_key = (os.getenv("OPENAI_API_KEY") or "").strip()
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
try:
|
||||
if _openai_client is None:
|
||||
from openai import OpenAI
|
||||
_openai_client = OpenAI(api_key=api_key)
|
||||
response = _openai_client.embeddings.create(
|
||||
model=OPENAI_EMBEDDING_MODEL,
|
||||
input=text,
|
||||
)
|
||||
return response.data[0].embedding
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _embed_text(text: str) -> list[float]:
    """Embed *text* with the configured provider, falling back to the local hash.

    When ``EMBEDDING_PROVIDER`` is ``"openai"`` the OpenAI embedding is tried
    first. If that provider is not selected, or it yields nothing (``None``
    or an empty vector), the local embedding is returned instead.
    """
    remote = _openai_embedding(text) if EMBEDDING_PROVIDER == "openai" else None
    return remote or _local_embedding(text)
|
||||
|
||||
|
||||
def _add_audio_url(post: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Add signed audio URL to post if ready"""
|
||||
if post.get("status") == "ready":
|
||||
@@ -281,7 +344,7 @@ def api_upload_post():
|
||||
"end_sec": float(seg.end),
|
||||
"text": segment_text,
|
||||
"confidence": float(seg.avg_logprob) if seg.avg_logprob is not None else None,
|
||||
"embedding": None,
|
||||
"embedding": _embed_text(segment_text),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -376,8 +439,21 @@ def api_rag_search():
|
||||
return _error("'q' is required.", 400)
|
||||
|
||||
try:
|
||||
rows = search_rag_chunks(user_id=user_id, query_text=query_text, page=page, limit=limit)
|
||||
return jsonify({"results": rows, "page": page, "limit": min(max(1, limit), 100)})
|
||||
query_embedding = _embed_text(query_text)
|
||||
rows = search_rag_chunks_vector(user_id=user_id, query_embedding=query_embedding, limit=limit)
|
||||
|
||||
# Fallback in case vector path is unavailable or empty.
|
||||
mode = "vector"
|
||||
if not rows:
|
||||
rows = search_rag_chunks(user_id=user_id, query_text=query_text, page=page, limit=limit)
|
||||
mode = "text_fallback"
|
||||
|
||||
return jsonify({
|
||||
"results": rows,
|
||||
"page": page,
|
||||
"limit": min(max(1, limit), 100),
|
||||
"mode": mode
|
||||
})
|
||||
except Exception as e:
|
||||
return _error(str(e), 500)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user