rag vector embeddings
This commit is contained in:
@@ -7,7 +7,9 @@ import hashlib
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from faster_whisper import WhisperModel
|
||||
@@ -34,6 +36,7 @@ from db_queries import (
|
||||
list_rag_chunks,
|
||||
list_user_history,
|
||||
search_rag_chunks,
|
||||
search_rag_chunks_vector,
|
||||
update_audio_post,
|
||||
upload_storage_object,
|
||||
upsert_archive_metadata,
|
||||
@@ -96,7 +99,7 @@ def _build_prompt(transcript_text: str, title: str) -> str:
|
||||
f"{transcript_text}\n\n"
|
||||
"Answer user questions grounded in this transcript."
|
||||
)
|
||||
def _add_audio_url(post: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _add_audio_url(post: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Add signed audio URL to post if ready"""
|
||||
if post.get("status") == "ready":
|
||||
try:
|
||||
@@ -107,6 +110,29 @@ def _add_audio_url(post: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return post
|
||||
|
||||
|
||||
def _local_embedding(text: str, dimensions: int = 1536) -> List[float]:
|
||||
"""
|
||||
Free deterministic embedding fallback (offline).
|
||||
Replace with model-based embeddings later if needed.
|
||||
"""
|
||||
vector = [0.0] * dimensions
|
||||
tokens = re.findall(r"[A-Za-z0-9']+", text.lower())
|
||||
if not tokens:
|
||||
return vector
|
||||
|
||||
for token in tokens:
|
||||
digest = hashlib.sha256(token.encode("utf-8")).digest()
|
||||
idx = int.from_bytes(digest[:4], "big") % dimensions
|
||||
sign = 1.0 if (digest[4] & 1) == 0 else -1.0
|
||||
weight = 1.0 + (digest[5] / 255.0) * 0.25
|
||||
vector[idx] += sign * weight
|
||||
|
||||
norm = sum(v * v for v in vector) ** 0.5
|
||||
if norm > 0:
|
||||
vector = [v / norm for v in vector]
|
||||
return vector
|
||||
|
||||
|
||||
|
||||
@api.get("/health")
|
||||
def health():
|
||||
@@ -274,7 +300,7 @@ def api_upload_post():
|
||||
"end_sec": float(seg.end),
|
||||
"text": segment_text,
|
||||
"confidence": float(seg.avg_logprob) if seg.avg_logprob is not None else None,
|
||||
"embedding": None,
|
||||
"embedding": _local_embedding(segment_text),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -360,17 +386,31 @@ def api_user_history(user_id: int):
|
||||
def api_rag_search():
    """Search a user's RAG chunks, by vector similarity or by text.

    Query params:
        user_id (int, required): Owner of the chunks to search.
        query_embedding (JSON array, optional): If given, vector search is
            used and ``q``/``page`` are ignored.
        q (str): Text query; required when ``query_embedding`` is absent.
        page (int, default 1): Page number for text search.
        limit (int, default 30): Max results, clamped to [1, 100].

    Returns:
        JSON ``{"results": [...], "mode": "vector"|"text", ...}`` on success;
        a 400 error payload on bad input, 500 on unexpected failure.
    """
    query_text = (request.args.get("q") or "").strip()
    user_id = request.args.get("user_id", type=int)
    query_embedding_raw = request.args.get("query_embedding")
    page = request.args.get("page", default=1, type=int)
    limit = request.args.get("limit", default=30, type=int)

    if not user_id:
        return _error("'user_id' is required.", 400)

    # Clamp once, up front, so the limit reported in the response is the
    # limit actually passed to the query (previously the response advertised
    # the clamped value while the unclamped one was sent to the database).
    limit = min(max(1, limit), 100)

    try:
        if query_embedding_raw:
            try:
                parsed = json.loads(query_embedding_raw)
                if not isinstance(parsed, list):
                    return _error("'query_embedding' must be a JSON array.", 400)
                query_embedding = [float(v) for v in parsed]
            except Exception:
                # Covers both malformed JSON and non-numeric elements.
                return _error("Invalid 'query_embedding'. Example: [0.1,0.2,...]", 400)

            rows = search_rag_chunks_vector(user_id=user_id, query_embedding=query_embedding, limit=limit)
            return jsonify({"results": rows, "mode": "vector", "limit": limit})

        if not query_text:
            return _error("'q' is required when 'query_embedding' is not provided.", 400)

        rows = search_rag_chunks(user_id=user_id, query_text=query_text, page=page, limit=limit)
        return jsonify({"results": rows, "mode": "text", "page": page, "limit": limit})
    except Exception as e:
        # Boundary handler: surface unexpected failures as a 500 payload
        # rather than letting the framework return an opaque error page.
        return _error(str(e), 500)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user