Initial commit
Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
This commit is contained in:
235
libs/rag/retriever.py
Normal file
235
libs/rag/retriever.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""High-level RAG retrieval with reranking and KG fusion."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import (
|
||||
FieldCondition,
|
||||
Filter,
|
||||
MatchValue,
|
||||
SparseVector,
|
||||
)
|
||||
|
||||
from .collection_manager import QdrantCollectionManager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class RAGRetriever:  # pylint: disable=too-few-public-methods
    """High-level RAG retrieval with reranking and KG fusion.

    Runs hybrid (dense + sparse) vector search across one or more Qdrant
    collections, optionally reranks the merged candidates, enriches the
    top results with knowledge-graph hints from Neo4j, extracts
    de-duplicated citations, and computes a calibrated confidence score.
    """

    def __init__(
        self,
        qdrant_client: QdrantClient,
        neo4j_client: Any = None,
        reranker_model: str | None = None,
    ) -> None:
        """
        Args:
            qdrant_client: Connected Qdrant client used for all searches.
            neo4j_client: Optional Neo4j client; when provided it must
                expose an async ``run_query(query, params)`` method. KG
                hint lookup is skipped when ``None``.
            reranker_model: Optional cross-encoder model identifier;
                reranking is skipped when ``None``.
        """
        self.collection_manager = QdrantCollectionManager(qdrant_client)
        self.neo4j_client = neo4j_client
        self.reranker_model = reranker_model

    async def search(  # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals
        self,
        query: str,
        collections: list[str],
        dense_vector: list[float],
        sparse_vector: SparseVector,
        k: int = 10,
        alpha: float = 0.5,
        beta: float = 0.3,  # pylint: disable=unused-argument
        gamma: float = 0.2,  # pylint: disable=unused-argument
        tax_year: str | None = None,
        jurisdiction: str | None = None,
    ) -> dict[str, Any]:
        """Perform comprehensive RAG search with KG fusion.

        Args:
            query: Natural-language query text (used for reranking and as
                context for KG hint lookup).
            collections: Names of the Qdrant collections to search.
            dense_vector: Pre-computed dense embedding of ``query``.
            sparse_vector: Pre-computed sparse embedding of ``query``.
            k: Number of chunks to return; also the per-collection limit.
            alpha: Dense/sparse fusion weight forwarded to hybrid search.
            beta: Reserved for future fusion weighting (currently unused).
            gamma: Reserved for future fusion weighting (currently unused).
            tax_year: Optional filter on the ``tax_years`` payload field.
            jurisdiction: Optional filter on the ``jurisdiction`` field.

        Returns:
            Dict with ``chunks`` (top-k scored chunks), ``citations``,
            ``kg_hints`` and ``calibrated_confidence``.
        """
        filter_conditions = self._build_filter(tax_year, jurisdiction)

        # Search every collection, tagging each hit with its origin so the
        # caller can still tell collections apart after merging.
        all_chunks: list[dict[str, Any]] = []
        for collection in collections:
            chunks = await self.collection_manager.hybrid_search(
                collection_name=collection,
                dense_vector=dense_vector,
                sparse_vector=sparse_vector,
                limit=k,
                alpha=alpha,
                filter_conditions=filter_conditions,
            )
            for chunk in chunks:
                chunk["collection"] = collection
            all_chunks.extend(chunks)

        # Re-rank only when a model is configured and there are more
        # candidates than the final page size.
        if self.reranker_model and len(all_chunks) > k:
            all_chunks = await self._rerank_chunks(query, all_chunks, k)

        # Merge results from all collections by score and keep the top k.
        all_chunks.sort(key=lambda chunk: chunk["score"], reverse=True)
        top_chunks = all_chunks[:k]

        kg_hints: list[dict[str, Any]] = []
        if self.neo4j_client:
            kg_hints = await self._get_kg_hints(query, top_chunks)

        return {
            "chunks": top_chunks,
            "citations": self._extract_citations(top_chunks),
            "kg_hints": kg_hints,
            "calibrated_confidence": self._calculate_confidence(top_chunks),
        }

    def _build_filter(
        self, tax_year: str | None = None, jurisdiction: str | None = None
    ) -> Filter | None:
        """Build Qdrant filter conditions for a search.

        A ``pii_free == True`` condition is always included, so the result
        is never ``None``; the optional return type is retained for
        interface stability with existing callers.
        """
        conditions = []

        if jurisdiction:
            conditions.append(
                FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction))
            )

        if tax_year:
            conditions.append(
                FieldCondition(key="tax_years", match=MatchValue(value=tax_year))
            )

        # Always require PII-free content.
        conditions.append(FieldCondition(key="pii_free", match=MatchValue(value=True)))

        # conditions is guaranteed non-empty here (pii_free above), so the
        # former "return None" branch was unreachable and has been removed.
        return Filter(must=conditions)  # type: ignore

    async def _rerank_chunks(  # pylint: disable=unused-argument
        self, query: str, chunks: list[dict[str, Any]], k: int
    ) -> list[dict[str, Any]]:
        """Rerank chunks using cross-encoder model.

        Placeholder: reranking is not implemented yet, so chunks are
        returned in their original order. Any failure degrades gracefully
        to the unranked order instead of aborting the search.
        """
        try:
            # This would integrate with a reranking service.
            # For now, return original chunks.
            logger.debug("Reranking not implemented, returning original order")
            return chunks

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning("Reranking failed, using original order", error=str(e))
            return chunks

    async def _get_kg_hints(  # pylint: disable=unused-argument
        self, query: str, chunks: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Get knowledge graph hints related to the query.

        For each chunk carrying ``topic_tags`` in its payload, queries
        Neo4j for non-retracted rules applying to those topics.
        Best-effort: any failure is logged and an empty list is returned.
        """
        try:
            # Extract potential rule/formula references from chunks.
            hints: list[dict[str, Any]] = []

            for chunk in chunks:
                payload = chunk.get("payload", {})
                topic_tags = payload.get("topic_tags", [])

                # Look for tax rules related to the topics.
                if topic_tags and self.neo4j_client:
                    kg_query = """
                    MATCH (r:Rule)-[:APPLIES_TO]->(topic)
                    WHERE topic.name IN $topics
                    AND r.retracted_at IS NULL
                    RETURN r.rule_id as rule_id,
                           r.formula as formula_id,
                           collect(id(topic)) as node_ids
                    LIMIT 5
                    """

                    kg_results = await self.neo4j_client.run_query(
                        kg_query, {"topics": topic_tags}
                    )

                    for result in kg_results:
                        hints.append(
                            {
                                "rule_id": result["rule_id"],
                                "formula_id": result["formula_id"],
                                "node_ids": result["node_ids"],
                            }
                        )

            return hints[:5]  # Limit to top 5 hints overall

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning("Failed to get KG hints", error=str(e))
            return []

    def _extract_citations(self, chunks: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Extract de-duplicated citation information from chunk payloads.

        Citations are keyed by ``doc_id`` (falling back to ``url``); only
        the first chunk seen per document contributes a citation. Chunks
        with neither identifier are skipped.
        """
        citations = []
        seen_docs = set()

        for chunk in chunks:
            payload = chunk.get("payload", {})

            # Extract document reference fields.
            doc_id = payload.get("doc_id")
            url = payload.get("url")
            section_id = payload.get("section_id")
            page = payload.get("page")
            bbox = payload.get("bbox")

            # Citation key avoids duplicate citations per document.
            citation_key = doc_id or url
            if citation_key and citation_key not in seen_docs:
                citation = {}

                if doc_id:
                    citation["doc_id"] = doc_id
                if url:
                    citation["url"] = url
                if section_id:
                    citation["section_id"] = section_id
                # Explicit None checks: page 0 or an empty-but-present bbox
                # are legitimate values and must not be dropped as falsy.
                if page is not None:
                    citation["page"] = page
                if bbox is not None:
                    citation["bbox"] = bbox

                citations.append(citation)
                seen_docs.add(citation_key)

        return citations

    def _calculate_confidence(self, chunks: list[dict[str, Any]]) -> float:
        """Calculate a calibrated confidence score in [0.0, 1.0].

        Uses a weighted average of the top-3 chunk scores (weights
        0.5/0.3/0.2) followed by simple temperature scaling. In production
        this would use learned calibration parameters.
        """
        if not chunks:
            return 0.0

        # chunks is non-empty here, so top_scores is too (the former
        # redundant emptiness re-check has been removed).
        top_scores = [chunk["score"] for chunk in chunks[:3]]

        # Weighted average with diminishing returns on lower-ranked hits;
        # with fewer than 3 chunks only the leading weights apply.
        weights = [0.5, 0.3, 0.2]
        weighted_score = sum(
            score * weight
            for score, weight in zip(
                top_scores, weights[: len(top_scores)], strict=False
            )
        )

        # Apply calibration (simple temperature scaling).
        temperature = 1.2
        calibrated = weighted_score / temperature

        return min(max(calibrated, 0.0), 1.0)  # type: ignore
|
||||
Reference in New Issue
Block a user