ai-tax-agent/libs/rag/collection_manager.py

"""Manage Qdrant collections for RAG."""
from typing import Any
import structlog
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
Filter,
PointStruct,
SparseVector,
VectorParams,
)
from .pii_detector import PIIDetector
logger = structlog.get_logger()
class QdrantCollectionManager:
    """Manage Qdrant collections for RAG"""

    def __init__(self, client: QdrantClient):
        self.client = client
        self.pii_detector = PIIDetector()

    async def ensure_collection(
        self,
        collection_name: str,
        vector_size: int = 384,
        distance: Distance = Distance.COSINE,
        sparse_vector_config: dict[str, Any] | None = None,
    ) -> bool:
        """Ensure collection exists with proper configuration"""
        try:
            # Check if collection exists
            collections = self.client.get_collections().collections
            if any(c.name == collection_name for c in collections):
                logger.debug("Collection already exists", collection=collection_name)
                return True

            # Create collection with dense vectors
            vector_config = VectorParams(size=vector_size, distance=distance)

            # Add sparse vector configuration if provided, under the
            # named vector "sparse" that search_sparse() queries
            sparse_vectors_config = None
            if sparse_vector_config:
                sparse_vectors_config = {"sparse": sparse_vector_config}

            self.client.create_collection(
                collection_name=collection_name,
                vectors_config=vector_config,
                sparse_vectors_config=sparse_vectors_config,  # type: ignore
            )
            logger.info("Created collection", collection=collection_name)
            return True
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Failed to create collection", collection=collection_name, error=str(e)
            )
            return False
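
    # Illustrative only: the parameter above is annotated as a dict, while
    # create_collection ultimately expects SparseVectorParams values, so a
    # call would look like (names assumed, not used elsewhere in this module):
    #     from qdrant_client.models import SparseVectorParams
    #     await manager.ensure_collection(
    #         "tax_docs", sparse_vector_config=SparseVectorParams()
    #     )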

    async def upsert_points(
        self, collection_name: str, points: list[PointStruct]
    ) -> bool:
        """Upsert points to collection"""
        try:
            # Validate all points are PII-free before they reach Qdrant
            for point in points:
                if point.payload and not point.payload.get("pii_free", False):
                    logger.warning("Point not marked as PII-free", point_id=point.id)
                    return False

            self.client.upsert(collection_name=collection_name, points=points)
            logger.info(
                "Upserted points", collection=collection_name, count=len(points)
            )
            return True
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Failed to upsert points", collection=collection_name, error=str(e)
            )
            return False
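
    # The payload contract assumed above (illustrative values): a point whose
    # payload lacks a truthy "pii_free" flag is rejected, while payload-less
    # points pass the check unchallenged.
    #     PointStruct(
    #         id=1,
    #         vector=[0.1] * 384,
    #         payload={"pii_free": True, "text": "...", "source": "hmrc"},
    #     )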

    async def search_dense(  # pylint: disable=too-many-arguments,too-many-positional-arguments
        self,
        collection_name: str,
        query_vector: list[float],
        limit: int = 10,
        filter_conditions: Filter | None = None,
        score_threshold: float | None = None,
    ) -> list[dict[str, Any]]:
        """Search using dense vectors"""
        try:
            search_result = self.client.search(
                collection_name=collection_name,
                query_vector=query_vector,
                query_filter=filter_conditions,
                limit=limit,
                score_threshold=score_threshold,
                with_payload=True,
                with_vectors=False,
            )
            return [
                {"id": hit.id, "score": hit.score, "payload": hit.payload}
                for hit in search_result
            ]
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Dense search failed", collection=collection_name, error=str(e)
            )
            return []
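
    # An illustrative value for `filter_conditions`, built from qdrant-client's
    # standard filter models (the "source" field name is assumed):
    #     from qdrant_client.models import FieldCondition, MatchValue
    #     Filter(must=[FieldCondition(key="source", match=MatchValue(value="hmrc"))])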

    async def search_sparse(
        self,
        collection_name: str,
        query_vector: SparseVector,
        limit: int = 10,
        filter_conditions: Filter | None = None,
    ) -> list[dict[str, Any]]:
        """Search using sparse vectors"""
        try:
            # client.search has no `using` parameter; the query is routed to
            # the named "sparse" vector via NamedSparseVector instead
            search_result = self.client.search(
                collection_name=collection_name,
                query_vector=NamedSparseVector(name="sparse", vector=query_vector),
                query_filter=filter_conditions,
                limit=limit,
                with_payload=True,
                with_vectors=False,
            )
            return [
                {"id": hit.id, "score": hit.score, "payload": hit.payload}
                for hit in search_result
            ]
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Sparse search failed", collection=collection_name, error=str(e)
            )
            return []

    async def hybrid_search(  # pylint: disable=too-many-arguments,too-many-positional-arguments
        self,
        collection_name: str,
        dense_vector: list[float],
        sparse_vector: SparseVector,
        limit: int = 10,
        alpha: float = 0.5,
        filter_conditions: Filter | None = None,
    ) -> list[dict[str, Any]]:
        """Perform hybrid search combining dense and sparse results"""
        # Get dense results
        dense_results = await self.search_dense(
            collection_name=collection_name,
            query_vector=dense_vector,
            limit=limit * 2,  # Get more results for fusion
            filter_conditions=filter_conditions,
        )

        # Get sparse results
        sparse_results = await self.search_sparse(
            collection_name=collection_name,
            query_vector=sparse_vector,
            limit=limit * 2,
            filter_conditions=filter_conditions,
        )

        # Combine and re-rank results
        return self._fuse_results(dense_results, sparse_results, alpha, limit)

    def _fuse_results(  # pylint: disable=too-many-locals
        self,
        dense_results: list[dict[str, Any]],
        sparse_results: list[dict[str, Any]],
        alpha: float,
        limit: int,
    ) -> list[dict[str, Any]]:
        """Fuse dense and sparse search results"""
        # Create score maps
        dense_scores = {result["id"]: result["score"] for result in dense_results}
        sparse_scores = {result["id"]: result["score"] for result in sparse_results}

        # Get all unique IDs
        all_ids = set(dense_scores.keys()) | set(sparse_scores.keys())

        # Normalize each result set by its maximum score (max normalization,
        # computed once rather than per document) so the two score scales
        # are comparable before blending
        max_dense = max(dense_scores.values()) if dense_scores else 0.0
        max_sparse = max(sparse_scores.values()) if sparse_scores else 0.0

        # Calculate hybrid scores
        hybrid_results = []
        for doc_id in all_ids:
            dense_score = dense_scores.get(doc_id, 0.0)
            sparse_score = sparse_scores.get(doc_id, 0.0)
            dense_score = dense_score / max_dense if max_dense > 0 else 0.0
            sparse_score = sparse_score / max_sparse if max_sparse > 0 else 0.0

            # Combine scores: alpha weights dense, (1 - alpha) sparse.
            # E.g. alpha=0.5 with normalized scores 1.0 (dense) and 0.5
            # (sparse) yields 0.5 * 1.0 + 0.5 * 0.5 = 0.75
            hybrid_score = alpha * dense_score + (1 - alpha) * sparse_score

            # Get payload from either result
            payload = None
            for result in dense_results + sparse_results:
                if result["id"] == doc_id:
                    payload = result["payload"]
                    break

            hybrid_results.append(
                {
                    "id": doc_id,
                    "score": hybrid_score,
                    "dense_score": dense_score,
                    "sparse_score": sparse_score,
                    "payload": payload,
                }
            )

        # Sort by hybrid score and return top results
        hybrid_results.sort(key=lambda x: x["score"], reverse=True)
        return hybrid_results[:limit]
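

# A minimal usage sketch, kept as a comment. Everything here is illustrative:
# the collection name, vector size, and query vectors are placeholders, and
# real sparse search additionally requires the collection to have been created
# with sparse_vector_config set.
#
#     from qdrant_client import QdrantClient
#     from qdrant_client.models import SparseVector
#
#     client = QdrantClient(url="http://localhost:6333")
#     manager = QdrantCollectionManager(client)
#
#     await manager.ensure_collection("tax_docs", vector_size=384)
#     hits = await manager.hybrid_search(
#         collection_name="tax_docs",
#         dense_vector=[0.1] * 384,
#         sparse_vector=SparseVector(indices=[7, 42], values=[0.8, 0.3]),
#         limit=5,
#         alpha=0.5,  # 0.0 weights sparse only, 1.0 dense only
#     )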