"""Manage Qdrant collections for RAG."""

from typing import Any

import structlog
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    Filter,
    NamedSparseVector,
    PointStruct,
    SparseVector,
    VectorParams,
)

from .pii_detector import PIIDetector

logger = structlog.get_logger()

# Name under which the sparse vector is registered in ``ensure_collection``
# and queried in ``search_sparse``; the two must agree.
_SPARSE_VECTOR_NAME = "sparse"


class QdrantCollectionManager:
    """Manage Qdrant collections for RAG.

    Every public method catches client errors, logs them, and returns a
    neutral value (``False`` or ``[]``) instead of raising.

    NOTE(review): the methods are ``async`` but call the synchronous
    ``QdrantClient`` directly, so each call blocks the event loop while the
    request is in flight. Consider ``AsyncQdrantClient`` or offloading to a
    thread if this runs inside a busy loop.
    """

    def __init__(self, client: QdrantClient):
        self.client = client
        # Not used by the methods in this module; presumably kept for callers
        # or future use -- verify before removing.
        self.pii_detector = PIIDetector()

    @staticmethod
    def _hits_to_dicts(hits: Any) -> list[dict[str, Any]]:
        """Convert scored points into plain ``{"id", "score", "payload"}`` dicts."""
        return [
            {"id": hit.id, "score": hit.score, "payload": hit.payload}
            for hit in hits
        ]

    async def ensure_collection(
        self,
        collection_name: str,
        vector_size: int = 384,
        distance: Distance = Distance.COSINE,
        sparse_vector_config: dict[str, Any] | None = None,
    ) -> bool:
        """Ensure a collection exists, creating it if it is missing.

        An existing collection is accepted as-is; its vector configuration is
        not compared against the arguments.

        Args:
            collection_name: Name of the collection to check or create.
            vector_size: Dimensionality of the (unnamed) dense vector.
            distance: Distance metric for the dense vector.
            sparse_vector_config: Optional sparse-vector configuration. When
                given, it is registered under the name ``"sparse"`` and passed
                to ``create_collection`` unchanged.

        Returns:
            True if the collection already existed or was created, False if a
            client call raised (the error is logged, not propagated).
        """
        try:
            existing = self.client.get_collections().collections
            if any(c.name == collection_name for c in existing):
                logger.debug("Collection already exists", collection=collection_name)
                return True

            sparse_vectors_config = None
            if sparse_vector_config:
                sparse_vectors_config = {_SPARSE_VECTOR_NAME: sparse_vector_config}

            self.client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(size=vector_size, distance=distance),
                sparse_vectors_config=sparse_vectors_config,  # type: ignore
            )
            logger.info("Created collection", collection=collection_name)
            return True
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Failed to create collection", collection=collection_name, error=str(e)
            )
            return False

    async def upsert_points(
        self, collection_name: str, points: list[PointStruct]
    ) -> bool:
        """Upsert points into a collection after a PII-marker check.

        Every point that has a non-empty payload must carry
        ``payload["pii_free"]`` set to a truthy value; otherwise nothing is
        upserted and False is returned.

        NOTE(review): points with a ``None`` or empty payload skip the check
        entirely. Confirm that this is intended.

        Args:
            collection_name: Target collection.
            points: Points to upsert.

        Returns:
            True on success, False if the PII check failed or the client
            raised (the error is logged, not propagated).
        """
        try:
            for point in points:
                if point.payload and not point.payload.get("pii_free", False):
                    logger.warning("Point not marked as PII-free", point_id=point.id)
                    return False

            self.client.upsert(collection_name=collection_name, points=points)
            logger.info(
                "Upserted points", collection=collection_name, count=len(points)
            )
            return True
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Failed to upsert points", collection=collection_name, error=str(e)
            )
            return False

    async def search_dense(  # pylint: disable=too-many-arguments,too-many-positional-arguments
        self,
        collection_name: str,
        query_vector: list[float],
        limit: int = 10,
        filter_conditions: Filter | None = None,
        score_threshold: float | None = None,
    ) -> list[dict[str, Any]]:
        """Search the collection's unnamed dense vector.

        Args:
            collection_name: Collection to search.
            query_vector: Dense query embedding.
            limit: Maximum number of hits.
            filter_conditions: Optional payload filter.
            score_threshold: Optional minimum score passed to Qdrant.

        Returns:
            Hits as ``{"id", "score", "payload"}`` dicts, or ``[]`` if the
            client raised (the error is logged).
        """
        try:
            hits = self.client.search(
                collection_name=collection_name,
                query_vector=query_vector,
                query_filter=filter_conditions,
                limit=limit,
                score_threshold=score_threshold,
                with_payload=True,
                with_vectors=False,
            )
            return self._hits_to_dicts(hits)
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Dense search failed", collection=collection_name, error=str(e)
            )
            return []

    async def search_sparse(
        self,
        collection_name: str,
        query_vector: SparseVector,
        limit: int = 10,
        filter_conditions: Filter | None = None,
    ) -> list[dict[str, Any]]:
        """Search the collection's sparse vector named ``"sparse"``.

        ``QdrantClient.search`` has no ``using`` argument. A sparse query must
        be wrapped in ``NamedSparseVector`` so the server knows which vector
        to search. Passing ``using=`` made the client reject the call, which
        silently turned every sparse search into ``[]``.

        Args:
            collection_name: Collection to search.
            query_vector: Sparse query vector (indices and values).
            limit: Maximum number of hits.
            filter_conditions: Optional payload filter.

        Returns:
            Hits as ``{"id", "score", "payload"}`` dicts, or ``[]`` if the
            client raised (the error is logged).
        """
        try:
            hits = self.client.search(
                collection_name=collection_name,
                query_vector=NamedSparseVector(
                    name=_SPARSE_VECTOR_NAME, vector=query_vector
                ),
                query_filter=filter_conditions,
                limit=limit,
                with_payload=True,
                with_vectors=False,
            )
            return self._hits_to_dicts(hits)
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Sparse search failed", collection=collection_name, error=str(e)
            )
            return []

    async def hybrid_search(  # pylint: disable=too-many-arguments,too-many-positional-arguments
        self,
        collection_name: str,
        dense_vector: list[float],
        sparse_vector: SparseVector,
        limit: int = 10,
        alpha: float = 0.5,
        filter_conditions: Filter | None = None,
    ) -> list[dict[str, Any]]:
        """Perform hybrid search combining dense and sparse results.

        Each search fetches ``2 * limit`` candidates so the fusion step has
        more documents to rank. The results are then fused by
        ``_fuse_results``.

        Args:
            collection_name: Collection to search.
            dense_vector: Dense query embedding.
            sparse_vector: Sparse query vector.
            limit: Maximum number of fused results returned.
            alpha: Weight of the dense score; the sparse score gets
                ``1 - alpha``.
            filter_conditions: Optional payload filter applied to both
                searches.

        Returns:
            Up to ``limit`` fused results, highest hybrid score first.
        """
        candidate_limit = limit * 2  # over-fetch for fusion
        dense_results = await self.search_dense(
            collection_name=collection_name,
            query_vector=dense_vector,
            limit=candidate_limit,
            filter_conditions=filter_conditions,
        )
        sparse_results = await self.search_sparse(
            collection_name=collection_name,
            query_vector=sparse_vector,
            limit=candidate_limit,
            filter_conditions=filter_conditions,
        )
        return self._fuse_results(dense_results, sparse_results, alpha, limit)

    def _fuse_results(  # pylint: disable=too-many-locals
        self,
        dense_results: list[dict[str, Any]],
        sparse_results: list[dict[str, Any]],
        alpha: float,
        limit: int,
    ) -> list[dict[str, Any]]:
        """Fuse dense and sparse search results by weighted score.

        Each list's scores are scaled by dividing by that list's maximum
        score. This is max-scaling, not min-max normalization. If that
        maximum is not positive, every score from that list becomes 0. A
        document missing from one list contributes 0 for that list. The
        hybrid score is ``alpha * dense + (1 - alpha) * sparse``.

        Args:
            dense_results: Dense hits as ``{"id", "score", "payload"}`` dicts.
            sparse_results: Sparse hits in the same shape.
            alpha: Weight of the dense score.
            limit: Maximum number of results returned.

        Returns:
            Up to ``limit`` dicts with ``id``, ``score`` (hybrid),
            ``dense_score``, ``sparse_score`` (both scaled) and ``payload``,
            sorted by hybrid score descending. Ties keep dense-first order.
        """
        dense_scores = {result["id"]: result["score"] for result in dense_results}
        sparse_scores = {result["id"]: result["score"] for result in sparse_results}

        # One pass instead of rescanning both lists per id. The first
        # occurrence wins, so a dense payload takes precedence over a sparse
        # one. Insertion order (dense first) also gives ties a stable order.
        payloads: dict[Any, Any] = {}
        for result in dense_results + sparse_results:
            payloads.setdefault(result["id"], result["payload"])

        # Loop-invariant: compute each list's maximum once.
        max_dense = max(dense_scores.values(), default=0.0)
        max_sparse = max(sparse_scores.values(), default=0.0)

        hybrid_results = []
        for doc_id, payload in payloads.items():
            dense_score = dense_scores.get(doc_id, 0.0)
            sparse_score = sparse_scores.get(doc_id, 0.0)
            dense_score = dense_score / max_dense if max_dense > 0 else 0.0
            sparse_score = sparse_score / max_sparse if max_sparse > 0 else 0.0

            hybrid_results.append(
                {
                    "id": doc_id,
                    "score": alpha * dense_score + (1 - alpha) * sparse_score,
                    "dense_score": dense_score,
                    "sparse_score": sparse_score,
                    "payload": payload,
                }
            )

        hybrid_results.sort(key=lambda item: item["score"], reverse=True)
        return hybrid_results[:limit]