"""High-level RAG retrieval with reranking and KG fusion.""" from typing import Any import structlog from qdrant_client import QdrantClient from qdrant_client.models import ( FieldCondition, Filter, MatchValue, SparseVector, ) from .collection_manager import QdrantCollectionManager logger = structlog.get_logger() class RAGRetriever: # pylint: disable=too-few-public-methods """High-level RAG retrieval with reranking and KG fusion""" def __init__( self, qdrant_client: QdrantClient, neo4j_client: Any = None, reranker_model: str | None = None, ) -> None: self.collection_manager = QdrantCollectionManager(qdrant_client) self.neo4j_client = neo4j_client self.reranker_model = reranker_model async def search( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals self, query: str, collections: list[str], dense_vector: list[float], sparse_vector: SparseVector, k: int = 10, alpha: float = 0.5, beta: float = 0.3, # pylint: disable=unused-argument gamma: float = 0.2, # pylint: disable=unused-argument tax_year: str | None = None, jurisdiction: str | None = None, ) -> dict[str, Any]: """Perform comprehensive RAG search with KG fusion""" # Build filter conditions filter_conditions = self._build_filter(tax_year, jurisdiction) # Search each collection all_chunks = [] for collection in collections: chunks = await self.collection_manager.hybrid_search( collection_name=collection, dense_vector=dense_vector, sparse_vector=sparse_vector, limit=k, alpha=alpha, filter_conditions=filter_conditions, ) # Add collection info to chunks for chunk in chunks: chunk["collection"] = collection all_chunks.extend(chunks) # Re-rank if reranker is available if self.reranker_model and len(all_chunks) > k: all_chunks = await self._rerank_chunks(query, all_chunks, k) # Sort by score and take top k all_chunks.sort(key=lambda x: x["score"], reverse=True) top_chunks = all_chunks[:k] # Get KG hints if Neo4j client is available kg_hints = [] if self.neo4j_client: kg_hints = await self._get_kg_hints(query, top_chunks) # Extract citations citations = self._extract_citations(top_chunks) # Calculate calibrated confidence calibrated_confidence = self._calculate_confidence(top_chunks) return { "chunks": top_chunks, "citations": citations, "kg_hints": kg_hints, "calibrated_confidence": calibrated_confidence, } def _build_filter( self, tax_year: str | None = None, jurisdiction: str | None = None ) -> Filter | None: """Build Qdrant filter conditions""" conditions = [] if jurisdiction: conditions.append( FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction)) ) if tax_year: conditions.append( FieldCondition(key="tax_years", match=MatchValue(value=tax_year)) ) # Always require PII-free content conditions.append(FieldCondition(key="pii_free", match=MatchValue(value=True))) if conditions: return Filter(must=conditions) # type: ignore return None async def _rerank_chunks( # pylint: disable=unused-argument self, query: str, chunks: list[dict[str, Any]], k: int ) -> list[dict[str, Any]]: """Rerank chunks using cross-encoder model""" try: # This would integrate with a reranking service # For now, return original chunks logger.debug("Reranking not implemented, returning original order") return chunks except Exception as e: # pylint: disable=broad-exception-caught logger.warning("Reranking failed, using original order", error=str(e)) return chunks async def _get_kg_hints( # pylint: disable=unused-argument self, query: str, chunks: list[dict[str, Any]] ) -> list[dict[str, Any]]: """Get knowledge graph hints 
related to the query""" try: # Extract potential rule/formula references from chunks hints = [] for chunk in chunks: payload = chunk.get("payload", {}) topic_tags = payload.get("topic_tags", []) # Look for tax rules related to the topics if topic_tags and self.neo4j_client: kg_query = """ MATCH (r:Rule)-[:APPLIES_TO]->(topic) WHERE topic.name IN $topics AND r.retracted_at IS NULL RETURN r.rule_id as rule_id, r.formula as formula_id, collect(id(topic)) as node_ids LIMIT 5 """ kg_results = await self.neo4j_client.run_query( kg_query, {"topics": topic_tags} ) for result in kg_results: hints.append( { "rule_id": result["rule_id"], "formula_id": result["formula_id"], "node_ids": result["node_ids"], } ) return hints[:5] # Limit to top 5 hints except Exception as e: # pylint: disable=broad-exception-caught logger.warning("Failed to get KG hints", error=str(e)) return [] def _extract_citations(self, chunks: list[dict[str, Any]]) -> list[dict[str, Any]]: """Extract citation information from chunks""" citations = [] seen_docs = set() for chunk in chunks: payload = chunk.get("payload", {}) # Extract document reference doc_id = payload.get("doc_id") url = payload.get("url") section_id = payload.get("section_id") page = payload.get("page") bbox = payload.get("bbox") # Create citation key to avoid duplicates citation_key = doc_id or url if citation_key and citation_key not in seen_docs: citation = {} if doc_id: citation["doc_id"] = doc_id if url: citation["url"] = url if section_id: citation["section_id"] = section_id if page: citation["page"] = page if bbox: citation["bbox"] = bbox citations.append(citation) seen_docs.add(citation_key) return citations def _calculate_confidence(self, chunks: list[dict[str, Any]]) -> float: """Calculate calibrated confidence score""" if not chunks: return 0.0 # Simple confidence calculation based on top scores top_scores = [chunk["score"] for chunk in chunks[:3]] if not top_scores: return 0.0 # Average of top 3 scores with diminishing returns weights = [0.5, 0.3, 0.2] weighted_score = sum( score * weight for score, weight in zip( top_scores, weights[: len(top_scores)], strict=False ) ) # Apply calibration (simple temperature scaling) # In production, this would use learned calibration parameters temperature = 1.2 calibrated = weighted_score / temperature return min(max(calibrated, 0.0), 1.0) # type: ignore
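

if __name__ == "__main__":  # pragma: no cover
    # Usage sketch (illustrative only, not part of the production path): shows
    # one way RAGRetriever.search might be invoked. The Qdrant URL, the
    # collection name "irs_publications", the 768-dimensional placeholder
    # dense vector, and the toy sparse vector are all assumptions for
    # demonstration; in practice the dense/sparse vectors come from the
    # embedding pipeline that feeds this retriever.
    import asyncio

    async def _demo() -> None:
        retriever = RAGRetriever(QdrantClient(url="http://localhost:6333"))
        result = await retriever.search(
            query="standard deduction 2023",
            collections=["irs_publications"],  # hypothetical collection name
            dense_vector=[0.0] * 768,  # placeholder embedding
            sparse_vector=SparseVector(indices=[1, 7], values=[0.4, 0.9]),
            k=5,
            tax_year="2023",
            jurisdiction="US",
        )
        print(result["calibrated_confidence"], len(result["citations"]))

    asyncio.run(_demo())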