Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
236 lines
7.9 KiB
Python
236 lines
7.9 KiB
Python
"""High-level RAG retrieval with reranking and KG fusion."""
|
|
|
|
from typing import Any
|
|
|
|
import structlog
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.models import (
|
|
FieldCondition,
|
|
Filter,
|
|
MatchValue,
|
|
SparseVector,
|
|
)
|
|
|
|
from .collection_manager import QdrantCollectionManager
|
|
|
|
# Module-level structlog logger; RAGRetriever emits its debug/warning events here.
logger = structlog.get_logger()
|
|
|
|
|
|
class RAGRetriever:  # pylint: disable=too-few-public-methods
    """High-level RAG retrieval with reranking and KG fusion.

    Runs a hybrid (dense + sparse) search over one or more Qdrant collections
    through ``QdrantCollectionManager``. It can optionally rerank the merged
    hits. It then enriches the top hits with knowledge-graph hints, citations
    and a calibrated confidence score.
    """

    # Cypher used by ``_get_kg_hints``: non-retracted rules that apply to any
    # of the supplied topic names. Aggregation groups by (rule_id, formula_id),
    # so each rule appears at most once per query.
    _KG_HINT_QUERY = """
        MATCH (r:Rule)-[:APPLIES_TO]->(topic)
        WHERE topic.name IN $topics
        AND r.retracted_at IS NULL
        RETURN r.rule_id as rule_id,
        r.formula as formula_id,
        collect(id(topic)) as node_ids
        LIMIT 5
        """

    def __init__(
        self,
        qdrant_client: QdrantClient,
        neo4j_client: Any = None,
        reranker_model: str | None = None,
    ) -> None:
        """Create a retriever.

        Args:
            qdrant_client: Client wrapped in a ``QdrantCollectionManager``
                that performs the hybrid searches.
            neo4j_client: Optional graph client. When truthy, ``search``
                fetches KG hints through its ``run_query`` coroutine.
            reranker_model: Optional reranker identifier. When truthy, the
                merged hits are passed through ``_rerank_chunks``.
        """
        self.collection_manager = QdrantCollectionManager(qdrant_client)
        self.neo4j_client = neo4j_client
        self.reranker_model = reranker_model

    async def search(  # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals
        self,
        query: str,
        collections: list[str],
        dense_vector: list[float],
        sparse_vector: SparseVector,
        k: int = 10,
        alpha: float = 0.5,
        beta: float = 0.3,  # pylint: disable=unused-argument
        gamma: float = 0.2,  # pylint: disable=unused-argument
        tax_year: str | None = None,
        jurisdiction: str | None = None,
    ) -> dict[str, Any]:
        """Perform comprehensive RAG search with KG fusion.

        Args:
            query: Raw query text, used for reranking and KG hints.
            collections: Qdrant collection names to search. Up to ``k`` hits
                are requested from each one.
            dense_vector: Dense query embedding.
            sparse_vector: Sparse query embedding.
            k: Number of chunks to return.
            alpha: Dense/sparse blend forwarded to ``hybrid_search``.
            beta: Currently unused.
            gamma: Currently unused.
            tax_year: Optional filter on the ``tax_years`` payload field.
            jurisdiction: Optional filter on the ``jurisdiction`` payload field.

        Returns:
            A dict with these keys:

            - ``chunks``: the top ``k`` hits, each tagged with its source
              ``collection``.
            - ``citations``
            - ``kg_hints``
            - ``calibrated_confidence``
        """
        filter_conditions = self._build_filter(tax_year, jurisdiction)

        all_chunks: list[dict[str, Any]] = []
        for collection in collections:
            chunks = await self.collection_manager.hybrid_search(
                collection_name=collection,
                dense_vector=dense_vector,
                sparse_vector=sparse_vector,
                limit=k,
                alpha=alpha,
                filter_conditions=filter_conditions,
            )
            # Tag each hit with its source collection so merged results stay
            # attributable.
            for chunk in chunks:
                chunk["collection"] = collection
            all_chunks.extend(chunks)

        # Order the merged candidates by retrieval score *before* reranking.
        # Sorting after reranking would overwrite whatever order the reranker
        # produced, making reranking a no-op.
        all_chunks.sort(key=lambda x: x["score"], reverse=True)

        if self.reranker_model and len(all_chunks) > k:
            all_chunks = await self._rerank_chunks(query, all_chunks, k)

        top_chunks = all_chunks[:k]

        kg_hints: list[dict[str, Any]] = []
        if self.neo4j_client:
            kg_hints = await self._get_kg_hints(query, top_chunks)

        citations = self._extract_citations(top_chunks)
        calibrated_confidence = self._calculate_confidence(top_chunks)

        return {
            "chunks": top_chunks,
            "citations": citations,
            "kg_hints": kg_hints,
            "calibrated_confidence": calibrated_confidence,
        }

    def _build_filter(
        self, tax_year: str | None = None, jurisdiction: str | None = None
    ) -> Filter | None:
        """Build the Qdrant payload filter applied to every collection search.

        A ``pii_free == True`` condition is always included, so a ``Filter``
        is always returned. The ``| None`` in the annotation is kept for
        interface compatibility.

        Args:
            tax_year: If truthy, require the ``tax_years`` field to match it.
            jurisdiction: If truthy, require the ``jurisdiction`` field to
                match it.
        """
        conditions: list[FieldCondition] = []

        if jurisdiction:
            conditions.append(
                FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction))
            )

        if tax_year:
            conditions.append(
                FieldCondition(key="tax_years", match=MatchValue(value=tax_year))
            )

        # Always require PII-free content.
        conditions.append(FieldCondition(key="pii_free", match=MatchValue(value=True)))

        return Filter(must=conditions)  # type: ignore

    async def _rerank_chunks(  # pylint: disable=unused-argument
        self, query: str, chunks: list[dict[str, Any]], k: int
    ) -> list[dict[str, Any]]:
        """Rerank chunks using a cross-encoder model (not yet implemented).

        This is currently a placeholder. It logs at debug level and returns
        ``chunks`` unchanged. On any exception it also falls back to the
        original order.
        """
        try:
            # This would integrate with a reranking service.
            logger.debug("Reranking not implemented, returning original order")
            return chunks

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning("Reranking failed, using original order", error=str(e))
            return chunks

    async def _get_kg_hints(  # pylint: disable=unused-argument
        self, query: str, chunks: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Get knowledge-graph rule hints for the topics of the given chunks.

        Collects ``topic_tags`` from every chunk payload, deduplicated and in
        first-seen order. It then issues a single graph query for
        non-retracted ``Rule`` nodes that ``APPLIES_TO`` any of those topics.

        Returns:
            Up to 5 dicts with ``rule_id``, ``formula_id`` and ``node_ids``.
            Returns an empty list when there is no graph client, no topics,
            or the lookup fails (failures are logged as warnings).
        """
        if not self.neo4j_client:
            return []

        try:
            # One query over the union of topics instead of one query per
            # chunk. This avoids N round-trips and the duplicate rule hints
            # that arose when several chunks shared a topic.
            # NOTE(review): dedup assumes tags are hashable (presumably strings).
            topics = list(
                dict.fromkeys(
                    tag
                    for chunk in chunks
                    for tag in ((chunk.get("payload") or {}).get("topic_tags") or [])
                )
            )
            if not topics:
                return []

            kg_results = await self.neo4j_client.run_query(
                self._KG_HINT_QUERY, {"topics": topics}
            )

            hints = [
                {
                    "rule_id": result["rule_id"],
                    "formula_id": result["formula_id"],
                    "node_ids": result["node_ids"],
                }
                for result in kg_results
            ]
            return hints[:5]  # Limit to top 5 hints

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning("Failed to get KG hints", error=str(e))
            return []

    def _extract_citations(self, chunks: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Build one citation per distinct document from the chunk payloads.

        Documents are keyed by ``doc_id``, falling back to ``url``. Only the
        first chunk seen for each key contributes a citation. Chunks with
        neither key are skipped. Each citation carries whichever of
        ``doc_id``, ``url``, ``section_id``, ``page`` and ``bbox`` are present.
        """
        citations: list[dict[str, Any]] = []
        seen_docs: set[Any] = set()

        for chunk in chunks:
            # Tolerate a missing or explicit-None payload.
            payload = chunk.get("payload") or {}

            doc_id = payload.get("doc_id")
            url = payload.get("url")

            # Deduplicate on document identity.
            citation_key = doc_id or url
            if not citation_key or citation_key in seen_docs:
                continue
            seen_docs.add(citation_key)

            citation: dict[str, Any] = {}
            if doc_id:
                citation["doc_id"] = doc_id
            if url:
                citation["url"] = url

            section_id = payload.get("section_id")
            if section_id:
                citation["section_id"] = section_id

            page = payload.get("page")
            # ``is not None`` so that a page number of 0 is not silently dropped.
            if page is not None:
                citation["page"] = page

            bbox = payload.get("bbox")
            if bbox:
                citation["bbox"] = bbox

            citations.append(citation)

        return citations

    def _calculate_confidence(self, chunks: list[dict[str, Any]]) -> float:
        """Calculate a calibrated confidence score clamped to [0, 1].

        Computes a weighted sum of the first three chunk scores with weights
        0.5, 0.3 and 0.2. The weights are not renormalised, so fewer than
        three chunks yields a lower score. The sum is divided by a fixed
        temperature of 1.2 and clamped. Returns 0.0 when ``chunks`` is empty.
        """
        if not chunks:
            return 0.0

        # ``chunks`` is non-empty here, so ``top_scores`` has 1-3 entries.
        top_scores = [chunk["score"] for chunk in chunks[:3]]

        # Diminishing weights for lower-ranked chunks.
        weights = (0.5, 0.3, 0.2)
        weighted_score = sum(
            score * weight
            for score, weight in zip(
                top_scores, weights[: len(top_scores)], strict=True
            )
        )

        # Apply calibration (simple temperature scaling).
        # In production, this would use learned calibration parameters.
        temperature = 1.2
        calibrated = weighted_score / temperature

        return min(max(calibrated, 0.0), 1.0)  # type: ignore