Update db_queries.py
This commit is contained in:
@@ -379,8 +379,28 @@ def list_rag_chunks(post_id: int, page: int = 1, limit: int = 200) -> List[Dict[
|
|||||||
return _rows(response)
|
return _rows(response)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_post_row(row: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Normalize joined post payload from Supabase. It can come back as:
|
||||||
|
- row["audio_posts"] = {...}
|
||||||
|
- row["audio_posts"] = [{...}]
|
||||||
|
- flat fields on row from RPC
|
||||||
|
"""
|
||||||
|
post = row.get("audio_posts")
|
||||||
|
if isinstance(post, list):
|
||||||
|
return post[0] if post else {}
|
||||||
|
if isinstance(post, dict):
|
||||||
|
return post
|
||||||
|
return {
|
||||||
|
"user_id": row.get("user_id"),
|
||||||
|
"visibility": row.get("visibility"),
|
||||||
|
"title": row.get("title"),
|
||||||
|
"created_at": row.get("created_at"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _can_access_post(row: Dict[str, Any], requester_user_id: int) -> bool:
|
def _can_access_post(row: Dict[str, Any], requester_user_id: int) -> bool:
|
||||||
post = row.get("audio_posts") or {}
|
post = _extract_post_row(row)
|
||||||
visibility = post.get("visibility")
|
visibility = post.get("visibility")
|
||||||
owner_id = post.get("user_id")
|
owner_id = post.get("user_id")
|
||||||
return visibility == "public" or owner_id == requester_user_id
|
return visibility == "public" or owner_id == requester_user_id
|
||||||
@@ -399,7 +419,11 @@ def search_rag_chunks(user_id: int, query_text: str, page: int = 1, limit: int =
|
|||||||
.range(0, min(2000, end + 500))
|
.range(0, min(2000, end + 500))
|
||||||
.execute()
|
.execute()
|
||||||
)
|
)
|
||||||
rows = [r for r in _rows(response) if _can_access_post(r, user_id)]
|
rows = []
|
||||||
|
for row in _rows(response):
|
||||||
|
if _can_access_post(row, user_id):
|
||||||
|
row["audio_posts"] = _extract_post_row(row)
|
||||||
|
rows.append(row)
|
||||||
return rows[start:end + 1]
|
return rows[start:end + 1]
|
||||||
|
|
||||||
|
|
||||||
@@ -432,7 +456,6 @@ def search_rag_chunks_vector(user_id: int, query_embedding: List[float], limit:
|
|||||||
"chunk_id, post_id, start_sec, end_sec, text, confidence, created_at, embedding, "
|
"chunk_id, post_id, start_sec, end_sec, text, confidence, created_at, embedding, "
|
||||||
"audio_posts!inner(post_id, user_id, title, visibility, created_at)"
|
"audio_posts!inner(post_id, user_id, title, visibility, created_at)"
|
||||||
)
|
)
|
||||||
.eq("audio_posts.user_id", user_id)
|
|
||||||
.limit(3000)
|
.limit(3000)
|
||||||
.execute()
|
.execute()
|
||||||
)
|
)
|
||||||
@@ -453,6 +476,7 @@ def search_rag_chunks_vector(user_id: int, query_embedding: List[float], limit:
|
|||||||
if score is None:
|
if score is None:
|
||||||
continue
|
continue
|
||||||
out = dict(row)
|
out = dict(row)
|
||||||
|
out["audio_posts"] = _extract_post_row(row)
|
||||||
out["similarity"] = score
|
out["similarity"] = score
|
||||||
out.pop("embedding", None)
|
out.pop("embedding", None)
|
||||||
ranked.append(out)
|
ranked.append(out)
|
||||||
|
|||||||
Reference in New Issue
Block a user