diff --git a/README.md b/README.md index 31d6dcb..6443df4 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,9 @@ BACKEND_UPLOAD_DIR=uploads WHISPER_MODEL=base WHISPER_DEVICE=cpu WHISPER_COMPUTE_TYPE=int8 +EMBEDDING_PROVIDER=local +OPENAI_EMBEDDING_MODEL=text-embedding-3-small +# OPENAI_API_KEY=... # only needed if EMBEDDING_PROVIDER=openai ``` Notes: diff --git a/backend/api_routes.py b/backend/api_routes.py index 7ad8841..a7475cd 100644 --- a/backend/api_routes.py +++ b/backend/api_routes.py @@ -7,6 +7,8 @@ import hashlib import json import os import uuid +import math +import re from pathlib import Path import io import zipfile @@ -38,6 +40,7 @@ from db_queries import ( list_rag_chunks, list_user_history, search_rag_chunks, + search_rag_chunks_vector, update_audio_post, upload_storage_object, upsert_archive_metadata, @@ -61,6 +64,10 @@ WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8") ARCHIVE_BUCKET = os.getenv("SUPABASE_BUCKET", os.getenv("SUPABASE_ARCHIVE_BUCKET", "archives")) _whisper_model: WhisperModel | None = None +_openai_client = None +EMBEDDING_DIM = 1536 +EMBEDDING_PROVIDER = (os.getenv("EMBEDDING_PROVIDER") or "local").strip().lower() +OPENAI_EMBEDDING_MODEL = (os.getenv("OPENAI_EMBEDDING_MODEL") or "text-embedding-3-small").strip() def _model() -> WhisperModel: @@ -103,6 +110,62 @@ def _build_prompt(transcript_text: str, title: str) -> str: f"{transcript_text}\n\n" "Answer user questions grounded in this transcript." ) + + +def _local_embedding(text: str, dim: int = EMBEDDING_DIM) -> list[float]: + """ + Free fallback embedding: hashed bag-of-words + bi-grams, L2-normalized. + This is weaker than model embeddings but keeps vector search functional. 
+ """ + if not text: + return [0.0] * dim + + vec = [0.0] * dim + tokens = re.findall(r"[a-z0-9]+", text.lower()) + if not tokens: + return vec + + for i, tok in enumerate(tokens): + idx = int(hashlib.sha256(f"u:{tok}".encode("utf-8")).hexdigest(), 16) % dim + vec[idx] += 1.0 + if i < len(tokens) - 1: + bigram = f"{tok}_{tokens[i+1]}" + bidx = int(hashlib.sha256(f"b:{bigram}".encode("utf-8")).hexdigest(), 16) % dim + vec[bidx] += 0.5 + + norm = math.sqrt(sum(v * v for v in vec)) + if norm > 0: + vec = [v / norm for v in vec] + return vec + + +def _openai_embedding(text: str) -> list[float] | None: + global _openai_client + api_key = (os.getenv("OPENAI_API_KEY") or "").strip() + if not api_key: + return None + + try: + if _openai_client is None: + from openai import OpenAI + _openai_client = OpenAI(api_key=api_key) + response = _openai_client.embeddings.create( + model=OPENAI_EMBEDDING_MODEL, + input=text, + ) + return response.data[0].embedding + except Exception: + return None + + +def _embed_text(text: str) -> list[float]: + if EMBEDDING_PROVIDER == "openai": + emb = _openai_embedding(text) + if emb: + return emb + return _local_embedding(text) + + def _add_audio_url(post: Dict[str, Any]) -> Dict[str, Any]: """Add signed audio URL to post if ready""" if post.get("status") == "ready": @@ -281,7 +344,7 @@ def api_upload_post(): "end_sec": float(seg.end), "text": segment_text, "confidence": float(seg.avg_logprob) if seg.avg_logprob is not None else None, - "embedding": None, + "embedding": _embed_text(segment_text), } ) @@ -376,8 +439,21 @@ def api_rag_search(): return _error("'q' is required.", 400) try: - rows = search_rag_chunks(user_id=user_id, query_text=query_text, page=page, limit=limit) - return jsonify({"results": rows, "page": page, "limit": min(max(1, limit), 100)}) + query_embedding = _embed_text(query_text) + rows = search_rag_chunks_vector(user_id=user_id, query_embedding=query_embedding, limit=limit) + + # Fallback in case vector path is unavailable 
or empty. + mode = "vector" + if not rows: + rows = search_rag_chunks(user_id=user_id, query_text=query_text, page=page, limit=limit) + mode = "text_fallback" + + return jsonify({ + "results": rows, + "page": page, + "limit": min(max(1, limit), 100), + "mode": mode + }) except Exception as e: return _error(str(e), 500) diff --git a/backend/db_queries.py b/backend/db_queries.py index 5dd7edd..f68a84b 100644 --- a/backend/db_queries.py +++ b/backend/db_queries.py @@ -3,6 +3,8 @@ Supabase data layer aligned with TitanForge/schema.sql. """ import os +import math +import json from typing import Any, Dict, List, Optional, Tuple from dotenv import load_dotenv @@ -399,15 +401,104 @@ def search_rag_chunks_vector(user_id: int, query_embedding: List[float], limit: Vector search via SQL RPC function `match_rag_chunks` (pgvector). """ vector_text = "[" + ",".join(str(float(v)) for v in query_embedding) + "]" - response = supabase.rpc( - "match_rag_chunks", - { - "p_user_id": user_id, - "p_query_embedding": vector_text, - "p_match_count": min(max(1, limit), 100), - }, - ).execute() - return _rows(response) + safe_limit = min(max(1, limit), 100) + + try: + response = supabase.rpc( + "match_rag_chunks", + { + "p_user_id": user_id, + "p_query_embedding": vector_text, + "p_match_count": safe_limit, + }, + ).execute() + rows = _rows(response) + if rows: + return rows + except Exception: + pass + + # Fallback: pull candidate chunks and rank with cosine similarity in Python. 
def _parse_embedding(value: Any) -> Optional[List[float]]:
    """
    Convert a stored embedding into a list of floats.

    Accepts a list/tuple of numbers or a bracketed, comma-separated string
    such as ``"[0.1,0.2]"``; empty segments between commas are ignored.

    Returns:
        The parsed floats (possibly an empty list for ``"[]"``), or ``None``
        when ``value`` is ``None``, an unsupported type, a string that is not
        wrapped in ``[...]``, or contains a non-numeric item.
    """
    if isinstance(value, (list, tuple)):
        items = value
    elif isinstance(value, str):
        text = value.strip()
        # Only the bracketed form can hold a vector; anything else (including
        # the empty string) is not an embedding.
        if not (text.startswith("[") and text.endswith("]")):
            return None
        items = [part for part in text[1:-1].split(",") if part.strip()]
    else:
        return None

    try:
        return [float(item) for item in items]
    except (TypeError, ValueError, OverflowError):
        return None


def _normalize_vec(vec: List[float]) -> List[float]:
    """
    Return ``vec`` scaled to unit L2 norm.

    Returns an empty list for empty input and an all-zero list of the same
    length when the norm is zero.
    """
    if not vec:
        return []
    values = [float(v) for v in vec]
    norm = math.sqrt(sum(v * v for v in values))
    if norm <= 0:
        return [0.0] * len(values)
    return [v / norm for v in values]


def _cosine_similarity(a: List[float], b: List[float]) -> Optional[float]:
    """
    Return the cosine similarity of ``a`` and ``b``.

    Both norms are computed here, so neither input needs to be
    pre-normalized.

    Returns:
        ``dot(a, b) / (|a| * |b|)``, or ``None`` when either vector is empty,
        the lengths differ, or either vector has zero norm (the similarity is
        undefined in those cases).
    """
    # Vectors of different dimensionality are not comparable; scoring a
    # truncated prefix would yield a meaningless number.
    if not a or not b or len(a) != len(b):
        return None

    dot = 0.0
    a_sq = 0.0
    b_sq = 0.0
    for x, y in zip(a, b):
        x = float(x)
        y = float(y)
        dot += x * y
        a_sq += x * x
        b_sq += y * y

    if a_sq <= 0 or b_sq <= 0:
        return None
    return dot / (math.sqrt(a_sq) * math.sqrt(b_sq))
a/backend/uploads/0c2f5f8d-f1e0-4a99-bdcf-49398ce524e2_Clean_Energy.m4a b/backend/uploads/0c2f5f8d-f1e0-4a99-bdcf-49398ce524e2_Clean_Energy.m4a new file mode 100644 index 0000000..5cebdeb Binary files /dev/null and b/backend/uploads/0c2f5f8d-f1e0-4a99-bdcf-49398ce524e2_Clean_Energy.m4a differ diff --git a/backend/uploads/4a7d7a65-22a0-46cf-b69e-f9095aa00f6f_AI-Speech.m4a b/backend/uploads/4a7d7a65-22a0-46cf-b69e-f9095aa00f6f_AI-Speech.m4a new file mode 100644 index 0000000..6f7b9bf Binary files /dev/null and b/backend/uploads/4a7d7a65-22a0-46cf-b69e-f9095aa00f6f_AI-Speech.m4a differ