Update db_queries.py

This commit is contained in:
Gaumit Kauts
2026-02-15 12:20:05 -07:00
parent 08e277ccc0
commit 9635025aed

View File

@@ -379,8 +379,28 @@ def list_rag_chunks(post_id: int, page: int = 1, limit: int = 200) -> List[Dict[
return _rows(response)
def _extract_post_row(row: Dict[str, Any]) -> Dict[str, Any]:
"""
Normalize joined post payload from Supabase. It can come back as:
- row["audio_posts"] = {...}
- row["audio_posts"] = [{...}]
- flat fields on row from RPC
"""
post = row.get("audio_posts")
if isinstance(post, list):
return post[0] if post else {}
if isinstance(post, dict):
return post
return {
"user_id": row.get("user_id"),
"visibility": row.get("visibility"),
"title": row.get("title"),
"created_at": row.get("created_at"),
}
def _can_access_post(row: Dict[str, Any], requester_user_id: int) -> bool:
post = row.get("audio_posts") or {}
post = _extract_post_row(row)
visibility = post.get("visibility")
owner_id = post.get("user_id")
return visibility == "public" or owner_id == requester_user_id
@@ -399,7 +419,11 @@ def search_rag_chunks(user_id: int, query_text: str, page: int = 1, limit: int =
.range(0, min(2000, end + 500))
.execute()
)
rows = [r for r in _rows(response) if _can_access_post(r, user_id)]
rows = []
for row in _rows(response):
if _can_access_post(row, user_id):
row["audio_posts"] = _extract_post_row(row)
rows.append(row)
return rows[start:end + 1]
@@ -432,7 +456,6 @@ def search_rag_chunks_vector(user_id: int, query_embedding: List[float], limit:
"chunk_id, post_id, start_sec, end_sec, text, confidence, created_at, embedding, "
"audio_posts!inner(post_id, user_id, title, visibility, created_at)"
)
.eq("audio_posts.user_id", user_id)
.limit(3000)
.execute()
)
@@ -453,6 +476,7 @@ def search_rag_chunks_vector(user_id: int, query_embedding: List[float], limit:
if score is None:
continue
out = dict(row)
out["audio_posts"] = _extract_post_row(row)
out["similarity"] = score
out.pop("embedding", None)
ranked.append(out)