updated RAG Search functionality
This commit is contained in:
@@ -45,6 +45,9 @@ BACKEND_UPLOAD_DIR=uploads
|
|||||||
WHISPER_MODEL=base
|
WHISPER_MODEL=base
|
||||||
WHISPER_DEVICE=cpu
|
WHISPER_DEVICE=cpu
|
||||||
WHISPER_COMPUTE_TYPE=int8
|
WHISPER_COMPUTE_TYPE=int8
|
||||||
|
EMBEDDING_PROVIDER=local
|
||||||
|
OPENAI_EMBEDDING_MODEL=text-embedding-3-small
|
||||||
|
# OPENAI_API_KEY=... # only needed if EMBEDDING_PROVIDER=openai
|
||||||
```
|
```
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
import math
|
||||||
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import io
|
import io
|
||||||
import zipfile
|
import zipfile
|
||||||
@@ -38,6 +40,7 @@ from db_queries import (
|
|||||||
list_rag_chunks,
|
list_rag_chunks,
|
||||||
list_user_history,
|
list_user_history,
|
||||||
search_rag_chunks,
|
search_rag_chunks,
|
||||||
|
search_rag_chunks_vector,
|
||||||
update_audio_post,
|
update_audio_post,
|
||||||
upload_storage_object,
|
upload_storage_object,
|
||||||
upsert_archive_metadata,
|
upsert_archive_metadata,
|
||||||
@@ -61,6 +64,10 @@ WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8")
|
|||||||
ARCHIVE_BUCKET = os.getenv("SUPABASE_BUCKET", os.getenv("SUPABASE_ARCHIVE_BUCKET", "archives"))
|
ARCHIVE_BUCKET = os.getenv("SUPABASE_BUCKET", os.getenv("SUPABASE_ARCHIVE_BUCKET", "archives"))
|
||||||
|
|
||||||
_whisper_model: WhisperModel | None = None
|
_whisper_model: WhisperModel | None = None
|
||||||
|
_openai_client = None
|
||||||
|
EMBEDDING_DIM = 1536
|
||||||
|
EMBEDDING_PROVIDER = (os.getenv("EMBEDDING_PROVIDER") or "local").strip().lower()
|
||||||
|
OPENAI_EMBEDDING_MODEL = (os.getenv("OPENAI_EMBEDDING_MODEL") or "text-embedding-3-small").strip()
|
||||||
|
|
||||||
|
|
||||||
def _model() -> WhisperModel:
|
def _model() -> WhisperModel:
|
||||||
@@ -103,6 +110,62 @@ def _build_prompt(transcript_text: str, title: str) -> str:
|
|||||||
f"{transcript_text}\n\n"
|
f"{transcript_text}\n\n"
|
||||||
"Answer user questions grounded in this transcript."
|
"Answer user questions grounded in this transcript."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _local_embedding(text: str, dim: int = EMBEDDING_DIM) -> list[float]:
|
||||||
|
"""
|
||||||
|
Free fallback embedding: hashed bag-of-words + bi-grams, L2-normalized.
|
||||||
|
This is weaker than model embeddings but keeps vector search functional.
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return [0.0] * dim
|
||||||
|
|
||||||
|
vec = [0.0] * dim
|
||||||
|
tokens = re.findall(r"[a-z0-9]+", text.lower())
|
||||||
|
if not tokens:
|
||||||
|
return vec
|
||||||
|
|
||||||
|
for i, tok in enumerate(tokens):
|
||||||
|
idx = int(hashlib.sha256(f"u:{tok}".encode("utf-8")).hexdigest(), 16) % dim
|
||||||
|
vec[idx] += 1.0
|
||||||
|
if i < len(tokens) - 1:
|
||||||
|
bigram = f"{tok}_{tokens[i+1]}"
|
||||||
|
bidx = int(hashlib.sha256(f"b:{bigram}".encode("utf-8")).hexdigest(), 16) % dim
|
||||||
|
vec[bidx] += 0.5
|
||||||
|
|
||||||
|
norm = math.sqrt(sum(v * v for v in vec))
|
||||||
|
if norm > 0:
|
||||||
|
vec = [v / norm for v in vec]
|
||||||
|
return vec
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_embedding(text: str) -> list[float] | None:
|
||||||
|
global _openai_client
|
||||||
|
api_key = (os.getenv("OPENAI_API_KEY") or "").strip()
|
||||||
|
if not api_key:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if _openai_client is None:
|
||||||
|
from openai import OpenAI
|
||||||
|
_openai_client = OpenAI(api_key=api_key)
|
||||||
|
response = _openai_client.embeddings.create(
|
||||||
|
model=OPENAI_EMBEDDING_MODEL,
|
||||||
|
input=text,
|
||||||
|
)
|
||||||
|
return response.data[0].embedding
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_text(text: str) -> list[float]:
|
||||||
|
if EMBEDDING_PROVIDER == "openai":
|
||||||
|
emb = _openai_embedding(text)
|
||||||
|
if emb:
|
||||||
|
return emb
|
||||||
|
return _local_embedding(text)
|
||||||
|
|
||||||
|
|
||||||
def _add_audio_url(post: Dict[str, Any]) -> Dict[str, Any]:
|
def _add_audio_url(post: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Add signed audio URL to post if ready"""
|
"""Add signed audio URL to post if ready"""
|
||||||
if post.get("status") == "ready":
|
if post.get("status") == "ready":
|
||||||
@@ -281,7 +344,7 @@ def api_upload_post():
|
|||||||
"end_sec": float(seg.end),
|
"end_sec": float(seg.end),
|
||||||
"text": segment_text,
|
"text": segment_text,
|
||||||
"confidence": float(seg.avg_logprob) if seg.avg_logprob is not None else None,
|
"confidence": float(seg.avg_logprob) if seg.avg_logprob is not None else None,
|
||||||
"embedding": None,
|
"embedding": _embed_text(segment_text),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -376,8 +439,21 @@ def api_rag_search():
|
|||||||
return _error("'q' is required.", 400)
|
return _error("'q' is required.", 400)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rows = search_rag_chunks(user_id=user_id, query_text=query_text, page=page, limit=limit)
|
query_embedding = _embed_text(query_text)
|
||||||
return jsonify({"results": rows, "page": page, "limit": min(max(1, limit), 100)})
|
rows = search_rag_chunks_vector(user_id=user_id, query_embedding=query_embedding, limit=limit)
|
||||||
|
|
||||||
|
# Fallback in case vector path is unavailable or empty.
|
||||||
|
mode = "vector"
|
||||||
|
if not rows:
|
||||||
|
rows = search_rag_chunks(user_id=user_id, query_text=query_text, page=page, limit=limit)
|
||||||
|
mode = "text_fallback"
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
"results": rows,
|
||||||
|
"page": page,
|
||||||
|
"limit": min(max(1, limit), 100),
|
||||||
|
"mode": mode
|
||||||
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return _error(str(e), 500)
|
return _error(str(e), 500)
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ Supabase data layer aligned with TitanForge/schema.sql.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import math
|
||||||
|
import json
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
@@ -399,15 +401,104 @@ def search_rag_chunks_vector(user_id: int, query_embedding: List[float], limit:
|
|||||||
Vector search via SQL RPC function `match_rag_chunks` (pgvector).
|
Vector search via SQL RPC function `match_rag_chunks` (pgvector).
|
||||||
"""
|
"""
|
||||||
vector_text = "[" + ",".join(str(float(v)) for v in query_embedding) + "]"
|
vector_text = "[" + ",".join(str(float(v)) for v in query_embedding) + "]"
|
||||||
response = supabase.rpc(
|
safe_limit = min(max(1, limit), 100)
|
||||||
"match_rag_chunks",
|
|
||||||
{
|
try:
|
||||||
"p_user_id": user_id,
|
response = supabase.rpc(
|
||||||
"p_query_embedding": vector_text,
|
"match_rag_chunks",
|
||||||
"p_match_count": min(max(1, limit), 100),
|
{
|
||||||
},
|
"p_user_id": user_id,
|
||||||
).execute()
|
"p_query_embedding": vector_text,
|
||||||
return _rows(response)
|
"p_match_count": safe_limit,
|
||||||
|
},
|
||||||
|
).execute()
|
||||||
|
rows = _rows(response)
|
||||||
|
if rows:
|
||||||
|
return rows
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fallback: pull candidate chunks and rank with cosine similarity in Python.
|
||||||
|
response = (
|
||||||
|
supabase.table("rag_chunks")
|
||||||
|
.select(
|
||||||
|
"chunk_id, post_id, start_sec, end_sec, text, confidence, created_at, embedding, "
|
||||||
|
"audio_posts!inner(post_id, user_id, title, visibility, created_at)"
|
||||||
|
)
|
||||||
|
.eq("audio_posts.user_id", user_id)
|
||||||
|
.limit(3000)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
candidates = _rows(response)
|
||||||
|
if not candidates:
|
||||||
|
return []
|
||||||
|
|
||||||
|
q = _normalize_vec(query_embedding)
|
||||||
|
ranked = []
|
||||||
|
for row in candidates:
|
||||||
|
emb = _parse_embedding(row.get("embedding"))
|
||||||
|
if not emb:
|
||||||
|
continue
|
||||||
|
score = _cosine_similarity(q, emb)
|
||||||
|
if score is None:
|
||||||
|
continue
|
||||||
|
out = dict(row)
|
||||||
|
out["similarity"] = score
|
||||||
|
out.pop("embedding", None)
|
||||||
|
ranked.append(out)
|
||||||
|
|
||||||
|
ranked.sort(key=lambda r: r.get("similarity", 0.0), reverse=True)
|
||||||
|
return ranked[:safe_limit]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_embedding(value: Any) -> Optional[List[float]]:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, list):
|
||||||
|
try:
|
||||||
|
return [float(v) for v in value]
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
if isinstance(value, str):
|
||||||
|
text = value.strip()
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
if text.startswith("[") and text.endswith("]"):
|
||||||
|
return [float(v) for v in text[1:-1].split(",") if v.strip()]
|
||||||
|
parsed = json.loads(text)
|
||||||
|
if isinstance(parsed, list):
|
||||||
|
return [float(v) for v in parsed]
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_vec(vec: List[float]) -> List[float]:
|
||||||
|
if not vec:
|
||||||
|
return []
|
||||||
|
norm = math.sqrt(sum(float(v) * float(v) for v in vec))
|
||||||
|
if norm <= 0:
|
||||||
|
return [0.0 for _ in vec]
|
||||||
|
return [float(v) / norm for v in vec]
|
||||||
|
|
||||||
|
|
||||||
|
def _cosine_similarity(a: List[float], b: List[float]) -> Optional[float]:
|
||||||
|
if not a or not b:
|
||||||
|
return None
|
||||||
|
n = min(len(a), len(b))
|
||||||
|
if n == 0:
|
||||||
|
return None
|
||||||
|
dot = 0.0
|
||||||
|
bnorm = 0.0
|
||||||
|
for i in range(n):
|
||||||
|
av = float(a[i])
|
||||||
|
bv = float(b[i])
|
||||||
|
dot += av * bv
|
||||||
|
bnorm += bv * bv
|
||||||
|
if bnorm <= 0:
|
||||||
|
return None
|
||||||
|
return dot / math.sqrt(bnorm)
|
||||||
|
|
||||||
|
|
||||||
# ==================== Audit Log ====================
|
# ==================== Audit Log ====================
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user