"""Hybrid retrieval service combining BM25 + vector search."""

import logging
import time
from datetime import datetime
from typing import List, Optional, Tuple

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams
from sentence_transformers import SentenceTransformer
from sqlalchemy import func, text
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.orm import Session

from app.config import settings
from app.models import Document, Entity, QueryLog, Relation

logger = logging.getLogger(__name__)

# Fallback embedding size, used only if the loaded model does not report
# its own dimension.
_DEFAULT_VECTOR_SIZE = 384


class RetrievalService:
    """Hybrid BM25 + vector retrieval with RRF fusion.

    Lexical candidates come from PostgreSQL full-text search over the
    ``document`` table; semantic candidates come from the per-domain Qdrant
    collection ``documents_{domain}``. The two ranked lists are fused with
    weighted Reciprocal Rank Fusion (see ``_rrf_merge``).

    NOTE(review): the SQLAlchemy ``Session``, ``QdrantClient`` and
    ``SentenceTransformer.encode`` calls used here are synchronous, so the
    ``async`` methods block the event loop while they run.
    """

    def __init__(self, session: Session):
        self.session = session
        # Mapping with optional "bm25" / "vector" weights (read via .get()
        # with 0.4 / 0.6 fallbacks in _rrf_merge).
        self.weights = settings.HYBRID_RETRIEVAL_WEIGHTS
        # NOTE(review): the model is loaded once per service instance; if a
        # service is constructed per request, consider sharing one model.
        self.embedding_model = SentenceTransformer(settings.EMBEDDING_MODEL)
        self.qdrant_client = QdrantClient(url=settings.QDRANT_URL)
        # Ask the loaded model for its embedding size instead of hard-coding
        # it: the previous constant (384) did not match bge-m3, whose dense
        # vectors are 1024-dimensional.
        self.vector_size = (
            self.embedding_model.get_sentence_embedding_dimension()
            or _DEFAULT_VECTOR_SIZE
        )

    async def hybrid_query(
        self,
        query_text: str,
        domain: str,
        top_k: int = 5,
        min_relevance: float = 0.5,
        extract_entities: bool = True,
    ) -> dict:
        """
        Perform hybrid query combining BM25 and vector search.

        Uses Reciprocal Rank Fusion (RRF) to merge results:
        score = Σ (weight_i * 1/(k + rank_i))

        Args:
            query_text: Free-text query.
            domain: Filters PostgreSQL rows and selects the
                ``documents_{domain}`` Qdrant collection.
            top_k: Number of fused results returned. Each retriever is asked
                for ``2 * top_k`` candidates.
            min_relevance: Currently not applied. Fused RRF scores are small:
                at most ``(w_bm25 + w_vector) / (k + 1)``, about 0.016 when
                the weights sum to 1 and k=60. A threshold on a 0-1 scale,
                such as the 0.5 default, would therefore discard every result.
            extract_entities: When true, attach entities and relations linked
                to the returned documents.

        Returns:
            Dict with keys ``query``, ``domain``, ``results``, ``entities``,
            ``relations``, ``total_results`` and ``latency_ms``.
        """
        # perf_counter is monotonic, so the measured latency is unaffected by
        # wall-clock adjustments (and avoids deprecated datetime.utcnow()).
        start = time.perf_counter()
        candidate_count = top_k * 2

        bm25_results = await self._bm25_search(query_text, domain, candidate_count)
        vector_results = await self._vector_search(query_text, domain, candidate_count)

        final_results = self._rrf_merge(bm25_results, vector_results)[:top_k]

        entities: List[dict] = []
        relations: List[dict] = []
        if extract_entities:
            entities, relations = await self._extract_entities_from_results(
                final_results, domain
            )

        # Log query for evaluation dataset building.
        await self._log_query(query_text, domain, final_results)

        latency_ms = (time.perf_counter() - start) * 1000

        return {
            "query": query_text,
            "domain": domain,
            "results": final_results,
            "entities": entities,
            "relations": relations,
            "total_results": len(final_results),
            "latency_ms": latency_ms,
        }

    async def _bm25_search(
        self,
        query: str,
        domain: str,
        limit: int,
    ) -> List[dict]:
        """Run full-text search over ``document.content`` using PostgreSQL FTS.

        Ranking uses ``ts_rank``, PostgreSQL's built-in text-search ranking,
        rather than a true BM25 implementation. The "bm25" naming is kept for
        callers.

        Returns:
            Up to ``limit`` result dicts in descending rank order, or an empty
            list if the query fails.
        """
        sql = text("""
            SELECT
                d.id,
                d.title,
                d.content,
                d.source,
                ts_rank(to_tsvector('english', d.content),
                        plainto_tsquery('english', :query)) as relevance_score,
                'bm25' as retrieval_method
            FROM document d
            WHERE d.domain = :domain
              AND to_tsvector('english', d.content) @@ plainto_tsquery('english', :query)
            ORDER BY relevance_score DESC
            LIMIT :limit
        """)
        try:
            rows = self.session.execute(
                sql,
                {"query": query, "domain": domain, "limit": limit},
            ).fetchall()
            return [
                {
                    "id": row.id,
                    "title": row.title,
                    "content": row.content,
                    "source": row.source,
                    "relevance_score": float(row.relevance_score),
                    "retrieval_method": "bm25",
                }
                for row in rows
            ]
        except Exception as e:
            logger.error(f"BM25 search error: {e}")
            # A failed statement aborts the PostgreSQL transaction. Roll back
            # so the entity lookups and query logging that follow on this same
            # session do not fail with "current transaction is aborted".
            self.session.rollback()
            return []

    async def _vector_search(
        self,
        query: str,
        domain: str,
        limit: int,
    ) -> List[dict]:
        """Run vector similarity search in Qdrant using the configured embedding model.

        The query is embedded with ``settings.EMBEDDING_MODEL`` and searched in
        the ``documents_{domain}`` collection. Points whose payload has no
        ``doc_id`` are skipped, because fusion keys candidates by document id.

        Returns:
            Up to ``limit`` result dicts in Qdrant's score order, or an empty
            list if the search fails.
        """
        try:
            query_embedding = self.embedding_model.encode(query, convert_to_numpy=True)

            hits = self.qdrant_client.search(
                collection_name=f"documents_{domain}",
                query_vector=query_embedding.tolist(),
                limit=limit,
                with_payload=True,
            )

            results = []
            for point in hits:
                # Qdrant may return a None payload; treat it as empty.
                payload = point.payload or {}
                doc_id = payload.get("doc_id")
                if doc_id is None:
                    logger.warning("Skipping Qdrant point %s without doc_id", point.id)
                    continue
                results.append({
                    "id": doc_id,
                    "title": payload.get("title", ""),
                    "content": payload.get("content", ""),
                    "source": payload.get("source", ""),
                    "relevance_score": float(point.score),
                    "retrieval_method": "vector",
                })
            return results
        except Exception as e:
            logger.error(f"Vector search error: {e}")
            return []

    def _rrf_merge(
        self,
        bm25_results: List[dict],
        vector_results: List[dict],
        k: int = 60,
    ) -> List[dict]:
        """Merge BM25 and vector results using weighted Reciprocal Rank Fusion.

        Each document scores ``w_bm25 / (k + bm25_rank)`` plus
        ``w_vector / (k + vector_rank)``, counting only the lists it appears
        in. Ranks are 1-based. If an id repeats within one list (e.g. several
        chunks of the same document), its best rank is used.

        Args:
            bm25_results: Lexical results, best first.
            vector_results: Vector results, best first.
            k: RRF smoothing constant. The default is the standard value of 60.

        Returns:
            Copies of the input result dicts in descending fused-score order,
            with ``relevance_score`` replaced by the fused score. When a
            document appears in both lists, its BM25 dict is used. Ties keep
            first-seen order (BM25 ids first).
        """
        w_bm25 = self.weights.get("bm25", 0.4)
        w_vector = self.weights.get("vector", 0.6)

        # Keep the two rankings separate so one list never overwrites the
        # other's rank for the same document.
        # NOTE(review): BM25 ids come from the database and vector ids from
        # the Qdrant payload; they must share a type/format for fusion to
        # match them.
        bm25_ranks: dict = {}
        for rank, result in enumerate(bm25_results, start=1):
            bm25_ranks.setdefault(result["id"], rank)

        vector_ranks: dict = {}
        for rank, result in enumerate(vector_results, start=1):
            vector_ranks.setdefault(result["id"], rank)

        # First occurrence wins, so BM25 dicts take precedence over vector ones.
        representatives: dict = {}
        for result in (*bm25_results, *vector_results):
            representatives.setdefault(result["id"], result)

        scored = []
        for doc_id in representatives:
            score = 0.0
            if doc_id in bm25_ranks:
                score += w_bm25 * (1 / (k + bm25_ranks[doc_id]))
            if doc_id in vector_ranks:
                score += w_vector * (1 / (k + vector_ranks[doc_id]))
            scored.append((doc_id, score))

        # sorted() is stable, so equal scores keep first-seen order.
        scored.sort(key=lambda item: item[1], reverse=True)

        merged = []
        for doc_id, score in scored:
            # Copy instead of mutating the retriever's dict in place.
            fused = dict(representatives[doc_id])
            fused["relevance_score"] = min(1.0, score)
            merged.append(fused)
        return merged

    async def _extract_entities_from_results(
        self,
        results: List[dict],
        domain: str,
    ) -> Tuple[List[dict], List[dict]]:
        """Collect entities and relations linked to the retrieved documents.

        Loads the result documents in ``domain`` in one query and gathers
        their ``entity_ids``. It then fetches those entities (also filtered by
        ``domain``) and every relation whose source or target is one of them.
        Relations are not domain-filtered.

        Returns:
            ``(entities, relations)`` as lists of dicts; ``([], [])`` on error.
        """
        try:
            doc_ids = [r.get("id") for r in results if r.get("id") is not None]
            if not doc_ids:
                return [], []

            # One IN query instead of one query per result.
            documents = self.session.query(Document).filter(
                Document.id.in_(doc_ids),
                Document.domain == domain,
            ).all()

            entity_id_set = set()
            for doc in documents:
                if doc.entity_ids:
                    entity_id_set.update(doc.entity_ids)
            if not entity_id_set:
                return [], []
            entity_id_list = list(entity_id_set)

            fetched_entities = self.session.query(Entity).filter(
                Entity.id.in_(entity_id_list),
                Entity.domain == domain,
            ).all()
            entities = [
                {
                    "entity_id": str(e.id),
                    "name": e.name,
                    "entity_type": e.entity_type,
                    "confidence": float(e.confidence),
                }
                for e in fetched_entities
            ]

            # Relations touching any collected entity (either endpoint).
            relation_rows = self.session.query(Relation).filter(
                (Relation.source_id.in_(entity_id_list))
                | (Relation.target_id.in_(entity_id_list))
            ).all()
            relations = [
                {
                    "source_id": str(r.source_id),
                    "relation_type": r.relation_type,
                    "target_id": str(r.target_id),
                    "strength": float(r.strength),
                }
                for r in relation_rows
            ]

            return entities, relations
        except Exception as e:
            logger.error(f"Entity extraction error: {e}")
            # Keep the session usable for the query-log insert that follows.
            self.session.rollback()
            return [], []

    async def _log_query(
        self,
        query_text: str,
        domain: str,
        results: List[dict],
    ) -> None:
        """Persist a QueryLog row (query, domain, result ids and scores) for evaluation.

        Commits the session on success; rolls it back and logs on failure.
        """
        try:
            query_log = QueryLog(
                query_text=query_text,
                domain=domain,
                retrieved_doc_ids=[result.get("id") for result in results],
                relevance_scores=[result.get("relevance_score", 0) for result in results],
            )
            self.session.add(query_log)
            self.session.commit()
        except Exception as e:
            logger.error(f"Query logging error: {e}")
            self.session.rollback()