Files
ai-tax-agent/libs/rag/retriever.py
harkon b324ff09ef
Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
Initial commit
2025-10-11 08:41:36 +01:00

236 lines
7.9 KiB
Python

"""High-level RAG retrieval with reranking and KG fusion."""
from typing import Any
import structlog
from qdrant_client import QdrantClient
from qdrant_client.models import (
FieldCondition,
Filter,
MatchValue,
SparseVector,
)
from .collection_manager import QdrantCollectionManager
logger = structlog.get_logger()
class RAGRetriever:  # pylint: disable=too-few-public-methods
    """High-level RAG retrieval with reranking and KG fusion."""

    def __init__(
        self,
        qdrant_client: QdrantClient,
        neo4j_client: Any = None,
        reranker_model: str | None = None,
    ) -> None:
        """Initialize the retriever.

        Args:
            qdrant_client: Connected Qdrant client used for hybrid vector search.
            neo4j_client: Optional async Neo4j client; when provided, KG hints
                are fetched for the top results.
            reranker_model: Optional cross-encoder model identifier; when set,
                candidates are reranked before truncation to ``k``.
        """
        self.collection_manager = QdrantCollectionManager(qdrant_client)
        self.neo4j_client = neo4j_client
        self.reranker_model = reranker_model

    async def search(  # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals
        self,
        query: str,
        collections: list[str],
        dense_vector: list[float],
        sparse_vector: SparseVector,
        k: int = 10,
        alpha: float = 0.5,
        beta: float = 0.3,  # pylint: disable=unused-argument
        gamma: float = 0.2,  # pylint: disable=unused-argument
        tax_year: str | None = None,
        jurisdiction: str | None = None,
    ) -> dict[str, Any]:
        """Perform comprehensive RAG search with KG fusion.

        Args:
            query: Natural-language query (used for reranking / KG hints).
            collections: Qdrant collection names to search.
            dense_vector: Pre-computed dense embedding of the query.
            sparse_vector: Pre-computed sparse embedding of the query.
            k: Number of results to return.
            alpha: Dense/sparse fusion weight passed to hybrid search.
            beta: Reserved fusion weight (currently unused).
            gamma: Reserved fusion weight (currently unused).
            tax_year: Optional tax-year payload filter.
            jurisdiction: Optional jurisdiction payload filter.

        Returns:
            Dict with ``chunks``, ``citations``, ``kg_hints`` and
            ``calibrated_confidence`` keys.
        """
        # Build payload filter (always restricts to PII-free content).
        filter_conditions = self._build_filter(tax_year, jurisdiction)

        # Search each collection, tagging chunks with their source collection.
        all_chunks = []
        for collection in collections:
            chunks = await self.collection_manager.hybrid_search(
                collection_name=collection,
                dense_vector=dense_vector,
                sparse_vector=sparse_vector,
                limit=k,
                alpha=alpha,
                filter_conditions=filter_conditions,
            )
            for chunk in chunks:
                chunk["collection"] = collection
            all_chunks.extend(chunks)

        # Re-rank only when a reranker is configured and there are more
        # candidates than requested (otherwise order is already final).
        if self.reranker_model and len(all_chunks) > k:
            all_chunks = await self._rerank_chunks(query, all_chunks, k)

        # Sort by score (descending) and take the top k.
        all_chunks.sort(key=lambda x: x["score"], reverse=True)
        top_chunks = all_chunks[:k]

        # Get KG hints if a Neo4j client is available.
        kg_hints = []
        if self.neo4j_client:
            kg_hints = await self._get_kg_hints(query, top_chunks)

        citations = self._extract_citations(top_chunks)
        calibrated_confidence = self._calculate_confidence(top_chunks)

        return {
            "chunks": top_chunks,
            "citations": citations,
            "kg_hints": kg_hints,
            "calibrated_confidence": calibrated_confidence,
        }

    def _build_filter(
        self, tax_year: str | None = None, jurisdiction: str | None = None
    ) -> Filter | None:
        """Build Qdrant filter conditions.

        A ``pii_free == True`` condition is always included so only
        PII-scrubbed content can be returned; the optional jurisdiction and
        tax-year conditions narrow the result set further.
        """
        # pii_free is mandatory, so the condition list is never empty and a
        # Filter is always returned (the `Filter | None` annotation is kept
        # for caller compatibility).
        conditions = [
            FieldCondition(key="pii_free", match=MatchValue(value=True))
        ]
        if jurisdiction:
            conditions.append(
                FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction))
            )
        if tax_year:
            conditions.append(
                FieldCondition(key="tax_years", match=MatchValue(value=tax_year))
            )
        return Filter(must=conditions)  # type: ignore

    async def _rerank_chunks(  # pylint: disable=unused-argument
        self, query: str, chunks: list[dict[str, Any]], k: int
    ) -> list[dict[str, Any]]:
        """Rerank chunks using a cross-encoder model.

        NOTE: reranking is not yet implemented; the original order is
        returned. Failures degrade gracefully to the unranked order.
        """
        try:
            # This would integrate with a reranking service.
            logger.debug("Reranking not implemented, returning original order")
            return chunks
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning("Reranking failed, using original order", error=str(e))
            return chunks

    async def _get_kg_hints(  # pylint: disable=unused-argument
        self, query: str, chunks: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Get knowledge graph hints related to the query.

        Looks up non-retracted tax rules whose topics match the chunks'
        ``topic_tags``. Best-effort: any failure is logged and an empty
        list is returned rather than failing the search.
        """
        try:
            hints = []
            for chunk in chunks:
                payload = chunk.get("payload", {})
                topic_tags = payload.get("topic_tags", [])
                # Look for tax rules related to the chunk's topics.
                if topic_tags and self.neo4j_client:
                    kg_query = """
                    MATCH (r:Rule)-[:APPLIES_TO]->(topic)
                    WHERE topic.name IN $topics
                    AND r.retracted_at IS NULL
                    RETURN r.rule_id as rule_id,
                    r.formula as formula_id,
                    collect(id(topic)) as node_ids
                    LIMIT 5
                    """
                    kg_results = await self.neo4j_client.run_query(
                        kg_query, {"topics": topic_tags}
                    )
                    for result in kg_results:
                        hints.append(
                            {
                                "rule_id": result["rule_id"],
                                "formula_id": result["formula_id"],
                                "node_ids": result["node_ids"],
                            }
                        )
            return hints[:5]  # Limit to top 5 hints
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning("Failed to get KG hints", error=str(e))
            return []

    def _extract_citations(self, chunks: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Extract de-duplicated citation information from chunks.

        Chunks are keyed by ``doc_id`` (falling back to ``url``); chunks
        with neither identifier yield no citation.
        """
        citations = []
        seen_docs = set()
        for chunk in chunks:
            payload = chunk.get("payload", {})
            doc_id = payload.get("doc_id")
            url = payload.get("url")
            section_id = payload.get("section_id")
            page = payload.get("page")
            bbox = payload.get("bbox")
            # Citation key avoids duplicate entries for the same document.
            citation_key = doc_id or url
            if citation_key and citation_key not in seen_docs:
                citation = {}
                if doc_id:
                    citation["doc_id"] = doc_id
                if url:
                    citation["url"] = url
                if section_id:
                    citation["section_id"] = section_id
                # Use explicit None checks: page 0 and zero-area bboxes are
                # valid values that a truthiness test would silently drop.
                if page is not None:
                    citation["page"] = page
                if bbox is not None:
                    citation["bbox"] = bbox
                citations.append(citation)
                seen_docs.add(citation_key)
        return citations

    def _calculate_confidence(self, chunks: list[dict[str, Any]]) -> float:
        """Calculate a calibrated confidence score in [0, 1].

        Uses a weighted average of the top-3 chunk scores with diminishing
        weights, then applies simple temperature scaling. In production this
        would use learned calibration parameters.
        """
        if not chunks:
            return 0.0
        # chunks is non-empty here, so top_scores is guaranteed non-empty.
        top_scores = [chunk["score"] for chunk in chunks[:3]]
        # Diminishing-returns weights; zip truncates to the shorter sequence
        # when fewer than 3 chunks are available.
        weights = [0.5, 0.3, 0.2]
        weighted_score = sum(
            score * weight for score, weight in zip(top_scores, weights)
        )
        # Apply calibration (simple temperature scaling).
        temperature = 1.2
        calibrated = weighted_score / temperature
        return min(max(calibrated, 0.0), 1.0)  # type: ignore