RAG vector embeddings

This commit is contained in:
Gaumit Kauts
2026-02-15 01:21:05 -07:00
parent 6e5b4850b9
commit 5f471c21be
5 changed files with 65 additions and 6 deletions

View File

@@ -7,7 +7,9 @@ import hashlib
import json import json
import os import os
import uuid import uuid
import re
from pathlib import Path from pathlib import Path
from typing import Any, List
from dotenv import load_dotenv from dotenv import load_dotenv
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
@@ -34,6 +36,7 @@ from db_queries import (
list_rag_chunks, list_rag_chunks,
list_user_history, list_user_history,
search_rag_chunks, search_rag_chunks,
search_rag_chunks_vector,
update_audio_post, update_audio_post,
upload_storage_object, upload_storage_object,
upsert_archive_metadata, upsert_archive_metadata,
@@ -96,7 +99,7 @@ def _build_prompt(transcript_text: str, title: str) -> str:
f"{transcript_text}\n\n" f"{transcript_text}\n\n"
"Answer user questions grounded in this transcript." "Answer user questions grounded in this transcript."
) )
def _add_audio_url(post: Dict[str, Any]) -> Dict[str, Any]: def _add_audio_url(post: dict[str, Any]) -> dict[str, Any]:
"""Add signed audio URL to post if ready""" """Add signed audio URL to post if ready"""
if post.get("status") == "ready": if post.get("status") == "ready":
try: try:
@@ -107,6 +110,29 @@ def _add_audio_url(post: Dict[str, Any]) -> Dict[str, Any]:
return post return post
def _local_embedding(text: str, dimensions: int = 1536) -> List[float]:
"""
Free deterministic embedding fallback (offline).
Replace with model-based embeddings later if needed.
"""
vector = [0.0] * dimensions
tokens = re.findall(r"[A-Za-z0-9']+", text.lower())
if not tokens:
return vector
for token in tokens:
digest = hashlib.sha256(token.encode("utf-8")).digest()
idx = int.from_bytes(digest[:4], "big") % dimensions
sign = 1.0 if (digest[4] & 1) == 0 else -1.0
weight = 1.0 + (digest[5] / 255.0) * 0.25
vector[idx] += sign * weight
norm = sum(v * v for v in vector) ** 0.5
if norm > 0:
vector = [v / norm for v in vector]
return vector
@api.get("/health") @api.get("/health")
def health(): def health():
@@ -274,7 +300,7 @@ def api_upload_post():
"end_sec": float(seg.end), "end_sec": float(seg.end),
"text": segment_text, "text": segment_text,
"confidence": float(seg.avg_logprob) if seg.avg_logprob is not None else None, "confidence": float(seg.avg_logprob) if seg.avg_logprob is not None else None,
"embedding": None, "embedding": _local_embedding(segment_text),
} }
) )
@@ -360,17 +386,31 @@ def api_user_history(user_id: int):
def api_rag_search(): def api_rag_search():
query_text = (request.args.get("q") or "").strip() query_text = (request.args.get("q") or "").strip()
user_id = request.args.get("user_id", type=int) user_id = request.args.get("user_id", type=int)
query_embedding_raw = request.args.get("query_embedding")
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=30, type=int) limit = request.args.get("limit", default=30, type=int)
if not user_id: if not user_id:
return _error("'user_id' is required.", 400) return _error("'user_id' is required.", 400)
if not query_text:
return _error("'q' is required.", 400)
try: try:
if query_embedding_raw:
try:
parsed = json.loads(query_embedding_raw)
if not isinstance(parsed, list):
return _error("'query_embedding' must be a JSON array.", 400)
query_embedding = [float(v) for v in parsed]
except Exception:
return _error("Invalid 'query_embedding'. Example: [0.1,0.2,...]", 400)
rows = search_rag_chunks_vector(user_id=user_id, query_embedding=query_embedding, limit=limit)
return jsonify({"results": rows, "mode": "vector", "limit": min(max(1, limit), 100)})
if not query_text:
return _error("'q' is required when 'query_embedding' is not provided.", 400)
rows = search_rag_chunks(user_id=user_id, query_text=query_text, page=page, limit=limit) rows = search_rag_chunks(user_id=user_id, query_text=query_text, page=page, limit=limit)
return jsonify({"results": rows, "page": page, "limit": min(max(1, limit), 100)}) return jsonify({"results": rows, "mode": "text", "page": page, "limit": min(max(1, limit), 100)})
except Exception as e: except Exception as e:
return _error(str(e), 500) return _error(str(e), 500)

View File

@@ -322,6 +322,9 @@ def add_rag_chunks(post_id: int, chunks: List[Dict[str, Any]]) -> List[Dict[str,
rows = [] rows = []
for c in chunks: for c in chunks:
embedding = c.get("embedding")
if isinstance(embedding, list):
embedding = "[" + ",".join(str(float(v)) for v in embedding) + "]"
rows.append( rows.append(
{ {
"post_id": post_id, "post_id": post_id,
@@ -329,7 +332,7 @@ def add_rag_chunks(post_id: int, chunks: List[Dict[str, Any]]) -> List[Dict[str,
"end_sec": c.get("end_sec"), "end_sec": c.get("end_sec"),
"text": c.get("text"), "text": c.get("text"),
"confidence": c.get("confidence"), "confidence": c.get("confidence"),
"embedding": c.get("embedding"), "embedding": embedding,
} }
) )
@@ -367,6 +370,22 @@ def search_rag_chunks(user_id: int, query_text: str, page: int = 1, limit: int =
return _rows(response) return _rows(response)
def search_rag_chunks_vector(user_id: int, query_embedding: List[float], limit: int = 30) -> List[Dict[str, Any]]:
    """
    Vector search via SQL RPC function `match_rag_chunks` (pgvector).
    """
    # pgvector expects the query vector serialized as a bracketed text literal, e.g. "[0.1,0.2]".
    components = ",".join(str(float(value)) for value in query_embedding)
    # Clamp the requested row count into the 1..100 range before calling the RPC.
    match_count = min(max(1, limit), 100)
    payload = {
        "p_user_id": user_id,
        "p_query_embedding": f"[{components}]",
        "p_match_count": match_count,
    }
    return _rows(supabase.rpc("match_rag_chunks", payload).execute())
# ==================== Audit Log ==================== # ==================== Audit Log ====================
def add_audit_log(payload: Dict[str, Any]) -> Dict[str, Any]: def add_audit_log(payload: Dict[str, Any]) -> Dict[str, Any]: