rag vector embeddings

This commit is contained in:
Gaumit Kauts
2026-02-15 01:21:05 -07:00
parent 6e5b4850b9
commit 5f471c21be
5 changed files with 65 additions and 6 deletions

View File

@@ -7,7 +7,9 @@ import hashlib
import json
import os
import uuid
import re
from pathlib import Path
from typing import Any, List
from dotenv import load_dotenv
from faster_whisper import WhisperModel
@@ -34,6 +36,7 @@ from db_queries import (
list_rag_chunks,
list_user_history,
search_rag_chunks,
search_rag_chunks_vector,
update_audio_post,
upload_storage_object,
upsert_archive_metadata,
@@ -96,7 +99,7 @@ def _build_prompt(transcript_text: str, title: str) -> str:
f"{transcript_text}\n\n"
"Answer user questions grounded in this transcript."
)
def _add_audio_url(post: Dict[str, Any]) -> Dict[str, Any]:
def _add_audio_url(post: dict[str, Any]) -> dict[str, Any]:
"""Add signed audio URL to post if ready"""
if post.get("status") == "ready":
try:
@@ -107,6 +110,29 @@ def _add_audio_url(post: Dict[str, Any]) -> Dict[str, Any]:
return post
def _local_embedding(text: str, dimensions: int = 1536) -> List[float]:
"""
Free deterministic embedding fallback (offline).
Replace with model-based embeddings later if needed.
"""
vector = [0.0] * dimensions
tokens = re.findall(r"[A-Za-z0-9']+", text.lower())
if not tokens:
return vector
for token in tokens:
digest = hashlib.sha256(token.encode("utf-8")).digest()
idx = int.from_bytes(digest[:4], "big") % dimensions
sign = 1.0 if (digest[4] & 1) == 0 else -1.0
weight = 1.0 + (digest[5] / 255.0) * 0.25
vector[idx] += sign * weight
norm = sum(v * v for v in vector) ** 0.5
if norm > 0:
vector = [v / norm for v in vector]
return vector
@api.get("/health")
def health():
@@ -274,7 +300,7 @@ def api_upload_post():
"end_sec": float(seg.end),
"text": segment_text,
"confidence": float(seg.avg_logprob) if seg.avg_logprob is not None else None,
"embedding": None,
"embedding": _local_embedding(segment_text),
}
)
@@ -360,17 +386,31 @@ def api_user_history(user_id: int):
def api_rag_search():
query_text = (request.args.get("q") or "").strip()
user_id = request.args.get("user_id", type=int)
query_embedding_raw = request.args.get("query_embedding")
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=30, type=int)
if not user_id:
return _error("'user_id' is required.", 400)
if not query_text:
return _error("'q' is required.", 400)
try:
if query_embedding_raw:
try:
parsed = json.loads(query_embedding_raw)
if not isinstance(parsed, list):
return _error("'query_embedding' must be a JSON array.", 400)
query_embedding = [float(v) for v in parsed]
except Exception:
return _error("Invalid 'query_embedding'. Example: [0.1,0.2,...]", 400)
rows = search_rag_chunks_vector(user_id=user_id, query_embedding=query_embedding, limit=limit)
return jsonify({"results": rows, "mode": "vector", "limit": min(max(1, limit), 100)})
if not query_text:
return _error("'q' is required when 'query_embedding' is not provided.", 400)
rows = search_rag_chunks(user_id=user_id, query_text=query_text, page=page, limit=limit)
return jsonify({"results": rows, "page": page, "limit": min(max(1, limit), 100)})
return jsonify({"results": rows, "mode": "text", "page": page, "limit": min(max(1, limit), 100)})
except Exception as e:
return _error(str(e), 500)

View File

@@ -322,6 +322,9 @@ def add_rag_chunks(post_id: int, chunks: List[Dict[str, Any]]) -> List[Dict[str,
rows = []
for c in chunks:
embedding = c.get("embedding")
if isinstance(embedding, list):
embedding = "[" + ",".join(str(float(v)) for v in embedding) + "]"
rows.append(
{
"post_id": post_id,
@@ -329,7 +332,7 @@ def add_rag_chunks(post_id: int, chunks: List[Dict[str, Any]]) -> List[Dict[str,
"end_sec": c.get("end_sec"),
"text": c.get("text"),
"confidence": c.get("confidence"),
"embedding": c.get("embedding"),
"embedding": embedding,
}
)
@@ -367,6 +370,22 @@ def search_rag_chunks(user_id: int, query_text: str, page: int = 1, limit: int =
return _rows(response)
def search_rag_chunks_vector(user_id: int, query_embedding: List[float], limit: int = 30) -> List[Dict[str, Any]]:
"""
Vector search via SQL RPC function `match_rag_chunks` (pgvector).
"""
vector_text = "[" + ",".join(str(float(v)) for v in query_embedding) + "]"
response = supabase.rpc(
"match_rag_chunks",
{
"p_user_id": user_id,
"p_query_embedding": vector_text,
"p_match_count": min(max(1, limit), 100),
},
).execute()
return _rows(response)
# ==================== Audit Log ====================
def add_audit_log(payload: Dict[str, Any]) -> Dict[str, Any]: