Initial commit

Author: harkon
Date: 2025-10-11 08:41:36 +01:00
Commit: b324ff09ef
276 changed files with 55220 additions and 0 deletions

libs/rag/__init__.py (new file, 13 lines)

@@ -0,0 +1,13 @@
"""Qdrant collections CRUD, hybrid search, rerank wrapper, de-identification utilities."""
from .collection_manager import QdrantCollectionManager
from .pii_detector import PIIDetector
from .retriever import RAGRetriever
from .utils import rag_search_for_citations
__all__ = [
"PIIDetector",
"QdrantCollectionManager",
"RAGRetriever",
"rag_search_for_citations",
]

libs/rag/collection_manager.py (new file, 233 lines)

@@ -0,0 +1,233 @@
"""Manage Qdrant collections for RAG."""
from typing import Any
import structlog
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
Filter,
PointStruct,
SparseVector,
VectorParams,
)
from .pii_detector import PIIDetector
logger = structlog.get_logger()
class QdrantCollectionManager:
"""Manage Qdrant collections for RAG"""
def __init__(self, client: QdrantClient):
self.client = client
self.pii_detector = PIIDetector()
async def ensure_collection(
self,
collection_name: str,
vector_size: int = 384,
distance: Distance = Distance.COSINE,
sparse_vector_config: dict[str, Any] | None = None,
) -> bool:
"""Ensure collection exists with proper configuration"""
try:
# Check if collection exists
collections = self.client.get_collections().collections
if any(c.name == collection_name for c in collections):
logger.debug("Collection already exists", collection=collection_name)
return True
# Create collection with dense vectors
vector_config = VectorParams(size=vector_size, distance=distance)
# Add sparse vector configuration if provided
sparse_vectors_config = None
if sparse_vector_config:
sparse_vectors_config = {"sparse": sparse_vector_config}
self.client.create_collection(
collection_name=collection_name,
vectors_config=vector_config,
sparse_vectors_config=sparse_vectors_config, # type: ignore
)
logger.info("Created collection", collection=collection_name)
return True
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Failed to create collection", collection=collection_name, error=str(e)
)
return False
async def upsert_points(
self, collection_name: str, points: list[PointStruct]
) -> bool:
"""Upsert points to collection"""
try:
# Validate all points are PII-free
for point in points:
if point.payload and not point.payload.get("pii_free", False):
logger.warning("Point not marked as PII-free", point_id=point.id)
return False
self.client.upsert(collection_name=collection_name, points=points)
logger.info(
"Upserted points", collection=collection_name, count=len(points)
)
return True
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Failed to upsert points", collection=collection_name, error=str(e)
)
return False
async def search_dense( # pylint: disable=too-many-arguments,too-many-positional-arguments
self,
collection_name: str,
query_vector: list[float],
limit: int = 10,
filter_conditions: Filter | None = None,
score_threshold: float | None = None,
) -> list[dict[str, Any]]:
"""Search using dense vectors"""
try:
search_result = self.client.search(
collection_name=collection_name,
query_vector=query_vector,
query_filter=filter_conditions,
limit=limit,
score_threshold=score_threshold,
with_payload=True,
with_vectors=False,
)
return [
{"id": hit.id, "score": hit.score, "payload": hit.payload}
for hit in search_result
]
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Dense search failed", collection=collection_name, error=str(e)
)
return []
async def search_sparse(
self,
collection_name: str,
query_vector: SparseVector,
limit: int = 10,
filter_conditions: Filter | None = None,
) -> list[dict[str, Any]]:
"""Search using sparse vectors"""
try:
search_result = self.client.search(
collection_name=collection_name,
query_vector=query_vector, # type: ignore
query_filter=filter_conditions,
limit=limit,
using="sparse",
with_payload=True,
with_vectors=False,
)
return [
{"id": hit.id, "score": hit.score, "payload": hit.payload}
for hit in search_result
]
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Sparse search failed", collection=collection_name, error=str(e)
)
return []
async def hybrid_search( # pylint: disable=too-many-arguments,too-many-positional-arguments
self,
collection_name: str,
dense_vector: list[float],
sparse_vector: SparseVector,
limit: int = 10,
alpha: float = 0.5,
filter_conditions: Filter | None = None,
) -> list[dict[str, Any]]:
"""Perform hybrid search combining dense and sparse results"""
# Get dense results
dense_results = await self.search_dense(
collection_name=collection_name,
query_vector=dense_vector,
limit=limit * 2, # Get more results for fusion
filter_conditions=filter_conditions,
)
# Get sparse results
sparse_results = await self.search_sparse(
collection_name=collection_name,
query_vector=sparse_vector,
limit=limit * 2,
filter_conditions=filter_conditions,
)
# Combine and re-rank results
return self._fuse_results(dense_results, sparse_results, alpha, limit)
def _fuse_results( # pylint: disable=too-many-locals
self,
dense_results: list[dict[str, Any]],
sparse_results: list[dict[str, Any]],
alpha: float,
limit: int,
) -> list[dict[str, Any]]:
"""Fuse dense and sparse search results"""
# Create score maps
dense_scores = {result["id"]: result["score"] for result in dense_results}
sparse_scores = {result["id"]: result["score"] for result in sparse_results}
# Get all unique IDs
all_ids = set(dense_scores.keys()) | set(sparse_scores.keys())
# Calculate hybrid scores
hybrid_results = []
for doc_id in all_ids:
dense_score = dense_scores.get(doc_id, 0.0)
sparse_score = sparse_scores.get(doc_id, 0.0)
# Normalize scores (simple min-max normalization)
if dense_results:
max_dense = max(dense_scores.values())
dense_score = dense_score / max_dense if max_dense > 0 else 0
if sparse_results:
max_sparse = max(sparse_scores.values())
sparse_score = sparse_score / max_sparse if max_sparse > 0 else 0
# Combine scores
hybrid_score = alpha * dense_score + (1 - alpha) * sparse_score
# Get payload from either result
payload = None
for result in dense_results + sparse_results:
if result["id"] == doc_id:
payload = result["payload"]
break
hybrid_results.append(
{
"id": doc_id,
"score": hybrid_score,
"dense_score": dense_score,
"sparse_score": sparse_score,
"payload": payload,
}
)
# Sort by hybrid score and return top results
hybrid_results.sort(key=lambda x: x["score"], reverse=True)
return hybrid_results[:limit]
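
A quick worked illustration of the `_fuse_results` scoring above (illustrative numbers only, not part of the file): each side is normalised against its own best hit, then blended with `alpha`.

    # Made-up scores; alpha = 0.5 as in the default signature above.
    alpha = 0.5
    dense_scores = {"doc-a": 0.90, "doc-b": 0.45}   # best dense hit = 0.90
    sparse_scores = {"doc-b": 12.0, "doc-c": 6.0}   # best sparse hit = 12.0

    for doc_id in sorted(set(dense_scores) | set(sparse_scores)):
        d = dense_scores.get(doc_id, 0.0) / max(dense_scores.values())
        s = sparse_scores.get(doc_id, 0.0) / max(sparse_scores.values())
        print(doc_id, round(alpha * d + (1 - alpha) * s, 3))
    # doc-a 0.5, doc-b 0.75, doc-c 0.25 -> doc-b wins despite its weaker dense score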

libs/rag/indexer.py (new file, 507 lines)

@@ -0,0 +1,507 @@
# FILE: retrieval/indexer.py
# De-identify -> embed dense/sparse -> upsert to Qdrant with payload

import json
import logging
import re
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any

import numpy as np
import spacy
import torch
import yaml
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, SparseVector, VectorParams
from sentence_transformers import SentenceTransformer

from .chunker import DocumentChunker
from .pii_detector import PIIDetector, PIIRedactor


@dataclass
class IndexingResult:
    collection_name: str
    points_indexed: int
    points_updated: int
    points_failed: int
    processing_time: float
    errors: list[str]


class RAGIndexer:
    def __init__(self, config_path: str, qdrant_url: str = "http://localhost:6333"):
        with open(config_path) as f:
            self.config = yaml.safe_load(f)

        self.qdrant_client = QdrantClient(url=qdrant_url)
        self.chunker = DocumentChunker(config_path)
        self.pii_detector = PIIDetector()
        self.pii_redactor = PIIRedactor()

        # Initialize embedding models
        self.dense_model = SentenceTransformer(
            self.config.get("embedding_model", "bge-small-en-v1.5")
        )

        # Initialize sparse model (BM25/SPLADE)
        self.sparse_model = self._init_sparse_model()

        # Initialize NLP pipeline
        self.nlp = spacy.load("en_core_web_sm")

        self.logger = logging.getLogger(__name__)

    def _init_sparse_model(self):
        """Initialize sparse embedding model (BM25 or SPLADE)"""
        sparse_config = self.config.get("sparse_model", {})
        model_type = sparse_config.get("type", "bm25")

        if model_type == "bm25":
            from rank_bm25 import BM25Okapi

            return BM25Okapi
        elif model_type == "splade":
            from transformers import AutoModelForMaskedLM, AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(
                "naver/splade-cocondenser-ensembledistil"
            )
            model = AutoModelForMaskedLM.from_pretrained(
                "naver/splade-cocondenser-ensembledistil"
            )
            return {"tokenizer": tokenizer, "model": model}
        else:
            raise ValueError(f"Unsupported sparse model type: {model_type}")

    async def index_document(
        self, document_path: str, collection_name: str, metadata: dict[str, Any]
    ) -> IndexingResult:
        """Index a single document into the specified collection"""
        start_time = datetime.now()
        errors = []
        points_indexed = 0
        points_updated = 0
        points_failed = 0

        try:
            # Step 1: Chunk the document
            chunks = await self.chunker.chunk_document(document_path, metadata)

            # Step 2: Process each chunk
            points = []
            for chunk in chunks:
                try:
                    point = await self._process_chunk(chunk, collection_name, metadata)
                    if point:
                        points.append(point)
                except Exception as e:
                    self.logger.error(
                        f"Failed to process chunk {chunk.get('id', 'unknown')}: {str(e)}"
                    )
                    errors.append(f"Chunk processing error: {str(e)}")
                    points_failed += 1

            # Step 3: Upsert to Qdrant
            if points:
                try:
                    operation_info = self.qdrant_client.upsert(
                        collection_name=collection_name, points=points, wait=True
                    )
                    points_indexed = len(points)
                    self.logger.info(
                        f"Indexed {points_indexed} points to {collection_name}"
                    )
                except Exception as e:
                    self.logger.error(f"Failed to upsert to Qdrant: {str(e)}")
                    errors.append(f"Qdrant upsert error: {str(e)}")
                    points_failed += len(points)
                    points_indexed = 0

        except Exception as e:
            self.logger.error(f"Document indexing failed: {str(e)}")
            errors.append(f"Document indexing error: {str(e)}")

        processing_time = (datetime.now() - start_time).total_seconds()

        return IndexingResult(
            collection_name=collection_name,
            points_indexed=points_indexed,
            points_updated=points_updated,
            points_failed=points_failed,
            processing_time=processing_time,
            errors=errors,
        )

    async def _process_chunk(
        self, chunk: dict[str, Any], collection_name: str, base_metadata: dict[str, Any]
    ) -> PointStruct | None:
        """Process a single chunk: de-identify, embed, create point"""
        # Step 1: De-identify PII
        content = chunk["content"]
        pii_detected = self.pii_detector.detect_pii(content)

        if pii_detected:
            # Redact PII and create mapping
            redacted_content, pii_mapping = self.pii_redactor.redact(
                content, pii_detected
            )

            # Store PII mapping securely (not in vector DB)
            await self._store_pii_mapping(chunk["id"], pii_mapping)

            # Log PII detection for audit
            self.logger.warning(
                f"PII detected in chunk {chunk['id']}: {[p['type'] for p in pii_detected]}"
            )
        else:
            redacted_content = content

        # Verify no PII remains
        if not self._verify_pii_free(redacted_content):
            self.logger.error(f"PII verification failed for chunk {chunk['id']}")
            return None

        # Step 2: Generate embeddings
        try:
            dense_vector = await self._generate_dense_embedding(redacted_content)
            sparse_vector = await self._generate_sparse_embedding(redacted_content)
        except Exception as e:
            self.logger.error(
                f"Embedding generation failed for chunk {chunk['id']}: {str(e)}"
            )
            return None

        # Step 3: Prepare metadata
        payload = self._prepare_payload(chunk, base_metadata, redacted_content)
        payload["pii_free"] = True  # Verified above

        # Step 4: Create point
        point = PointStruct(
            id=chunk["id"],
            vector={"dense": dense_vector, "sparse": sparse_vector},
            payload=payload,
        )

        return point

    async def _generate_dense_embedding(self, text: str) -> list[float]:
        """Generate dense vector embedding"""
        try:
            # Use sentence transformer for dense embeddings
            embedding = self.dense_model.encode(text, normalize_embeddings=True)
            return embedding.tolist()
        except Exception as e:
            self.logger.error(f"Dense embedding generation failed: {str(e)}")
            raise

    async def _generate_sparse_embedding(self, text: str) -> SparseVector:
        """Generate sparse vector embedding (BM25 or SPLADE)"""
        vector = SparseVector(indices=[], values=[])
        try:
            sparse_config = self.config.get("sparse_model", {})
            model_type = sparse_config.get("type", "bm25")

            if model_type == "bm25":
                # Simple BM25-style sparse representation
                doc = self.nlp(text)
                tokens = [
                    token.lemma_.lower()
                    for token in doc
                    if not token.is_stop and not token.is_punct
                ]

                # Create term frequency vector
                term_freq = {}
                for token in tokens:
                    term_freq[token] = term_freq.get(token, 0) + 1

                # Convert to sparse vector format
                vocab_size = sparse_config.get("vocab_size", 30000)
                indices = []
                values = []
                for term, freq in term_freq.items():
                    # Simple hash-based vocabulary mapping
                    term_id = hash(term) % vocab_size
                    indices.append(term_id)
                    values.append(float(freq))

                vector = SparseVector(indices=indices, values=values)

            elif model_type == "splade":
                # SPLADE sparse embeddings
                tokenizer = self.sparse_model["tokenizer"]
                model = self.sparse_model["model"]

                inputs = tokenizer(
                    text, return_tensors="pt", truncation=True, max_length=512
                )
                outputs = model(**inputs)

                # Extract sparse representation
                logits = outputs.logits.squeeze()
                sparse_rep = torch.relu(logits).detach().numpy()

                # Convert to sparse format
                indices = np.nonzero(sparse_rep)[0].tolist()
                values = sparse_rep[indices].tolist()
                vector = SparseVector(indices=indices, values=values)

            return vector
        except Exception as e:
            self.logger.error(f"Sparse embedding generation failed: {str(e)}")
            # Return empty sparse vector as fallback
            return vector

    def _prepare_payload(
        self, chunk: dict[str, Any], base_metadata: dict[str, Any], content: str
    ) -> dict[str, Any]:
        """Prepare payload metadata for the chunk"""
        # Start with base metadata
        payload = base_metadata.copy()

        # Add chunk-specific metadata
        payload.update(
            {
                "document_id": chunk.get("document_id"),
                "content": content,  # De-identified content
                "chunk_index": chunk.get("chunk_index", 0),
                "total_chunks": chunk.get("total_chunks", 1),
                "page_numbers": chunk.get("page_numbers", []),
                "section_hierarchy": chunk.get("section_hierarchy", []),
                "has_calculations": self._detect_calculations(content),
                "has_forms": self._detect_form_references(content),
                "confidence_score": chunk.get("confidence_score", 1.0),
                "created_at": datetime.now().isoformat(),
                "version": self.config.get("version", "1.0"),
            }
        )

        # Extract and add topic tags
        topic_tags = self._extract_topic_tags(content)
        if topic_tags:
            payload["topic_tags"] = topic_tags

        # Add content analysis
        payload.update(self._analyze_content(content))

        return payload

    def _detect_calculations(self, text: str) -> bool:
        """Detect if text contains calculations or formulas"""
        calculation_patterns = [
            r"\d+\s*[+\-*/]\s*\d+",
            r"£\d+(?:,\d{3})*(?:\.\d{2})?",
            r"\d+(?:\.\d+)?%",
            r"total|sum|calculate|compute",
            r"rate|threshold|allowance|relief",
        ]

        for pattern in calculation_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                return True
        return False

    def _detect_form_references(self, text: str) -> bool:
        """Detect references to tax forms"""
        form_patterns = [
            r"SA\d{3}",
            r"P\d{2}",
            r"CT\d{3}",
            r"VAT\d{3}",
            r"form\s+\w+",
            r"schedule\s+\w+",
        ]

        for pattern in form_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                return True
        return False

    def _extract_topic_tags(self, text: str) -> list[str]:
        """Extract topic tags from content"""
        topic_keywords = {
            "employment": [
                "PAYE",
                "payslip",
                "P60",
                "employment",
                "salary",
                "wages",
                "employer",
            ],
            "self_employment": [
                "self-employed",
                "business",
                "turnover",
                "expenses",
                "profit",
                "loss",
            ],
            "property": ["rental", "property", "landlord", "FHL", "mortgage", "rent"],
            "dividends": ["dividend", "shares", "distribution", "corporation tax"],
            "capital_gains": ["capital gains", "disposal", "acquisition", "CGT"],
            "pensions": ["pension", "retirement", "SIPP", "occupational"],
            "savings": ["interest", "savings", "ISA", "bonds"],
            "inheritance": ["inheritance", "IHT", "estate", "probate"],
            "vat": ["VAT", "value added tax", "registration", "return"],
        }

        tags = []
        text_lower = text.lower()
        for topic, keywords in topic_keywords.items():
            for keyword in keywords:
                if keyword.lower() in text_lower:
                    tags.append(topic)
                    break

        return list(set(tags))  # Remove duplicates

    def _analyze_content(self, text: str) -> dict[str, Any]:
        """Analyze content for additional metadata"""
        doc = self.nlp(text)

        return {
            "word_count": len([token for token in doc if not token.is_space]),
            "sentence_count": len(list(doc.sents)),
            "entity_count": len(doc.ents),
            "complexity_score": self._calculate_complexity(doc),
            "language": doc.lang_ if hasattr(doc, "lang_") else "en",
        }

    def _calculate_complexity(self, doc: Any) -> float:
        """Calculate text complexity score for a spaCy Doc"""
        if not doc:
            return 0.0

        # Simple complexity based on sentence length and vocabulary
        avg_sentence_length = sum(len(sent) for sent in doc.sents) / len(
            list(doc.sents)
        )
        unique_words = len(set(token.lemma_.lower() for token in doc if token.is_alpha))
        total_words = len([token for token in doc if token.is_alpha])
        vocabulary_diversity = unique_words / total_words if total_words > 0 else 0

        # Normalize to 0-1 scale
        complexity = min(1.0, (avg_sentence_length / 20.0 + vocabulary_diversity) / 2.0)
        return complexity

    def _verify_pii_free(self, text: str) -> bool:
        """Verify that text contains no PII"""
        # Quick verification using patterns
        pii_patterns = [
            r"\b[A-Z]{2}\d{6}[A-D]\b",  # NI number
            r"\b\d{10}\b",  # UTR
            r"\b[A-Z]{2}\d{2}[A-Z]{4}\d{14}\b",  # IBAN
            r"\b\d{2}-\d{2}-\d{2}\b",  # Sort code
            r"\b[A-Z]{1,2}\d[A-Z\d]?\s*\d[A-Z]{2}\b",  # Postcode
            r"\b[\w\.-]+@[\w\.-]+\.\w+\b",  # Email
            r"\b(?:\+44|0)\d{10,11}\b",  # Phone
        ]

        for pattern in pii_patterns:
            if re.search(pattern, text):
                return False
        return True

    async def _store_pii_mapping(
        self, chunk_id: str, pii_mapping: dict[str, Any]
    ) -> None:
        """Store PII mapping in secure client data store (not in vector DB)"""
        # This would integrate with the secure PostgreSQL client data store
        # For now, just log the mapping securely
        self.logger.info(
            f"PII mapping stored for chunk {chunk_id}: {len(pii_mapping)} items"
        )

    async def create_collections(self) -> None:
        """Create all Qdrant collections based on configuration"""
        collections_config_path = Path(__file__).parent / "qdrant_collections.json"
        with open(collections_config_path) as f:
            collections_config = json.load(f)

        for collection_config in collections_config["collections"]:
            collection_name = collection_config["name"]

            try:
                # Check if collection exists
                try:
                    self.qdrant_client.get_collection(collection_name)
                    self.logger.info(f"Collection {collection_name} already exists")
                    continue
                except Exception:  # Collection doesn't exist, create it below
                    pass

                # Create collection
                vectors_config = {}

                # Dense vector configuration
                if "dense" in collection_config:
                    vectors_config["dense"] = VectorParams(
                        size=collection_config["dense"]["size"],
                        distance=Distance.COSINE,
                    )

                # Sparse vector configuration
                if collection_config.get("sparse", False):
                    vectors_config["sparse"] = VectorParams(
                        size=30000,  # Vocabulary size for sparse vectors
                        distance=Distance.DOT,
                        on_disk=True,
                    )

                self.qdrant_client.create_collection(
                    collection_name=collection_name,
                    vectors_config=vectors_config,
                    **collection_config.get("indexing_config", {}),
                )
                self.logger.info(f"Created collection: {collection_name}")

            except Exception as e:
                self.logger.error(
                    f"Failed to create collection {collection_name}: {str(e)}"
                )
                raise

    async def batch_index(
        self, documents: list[dict[str, Any]], collection_name: str
    ) -> list[IndexingResult]:
        """Index multiple documents in batch"""
        results = []
        for doc_info in documents:
            result = await self.index_document(
                doc_info["path"], collection_name, doc_info["metadata"]
            )
            results.append(result)
        return results

    def get_collection_stats(self, collection_name: str) -> dict[str, Any]:
        """Get statistics for a collection"""
        try:
            collection_info = self.qdrant_client.get_collection(collection_name)
            return {
                "name": collection_name,
                "vectors_count": collection_info.vectors_count,
                "indexed_vectors_count": collection_info.indexed_vectors_count,
                "points_count": collection_info.points_count,
                "segments_count": collection_info.segments_count,
                "status": collection_info.status,
            }
        except Exception as e:
            self.logger.error(f"Failed to get stats for {collection_name}: {str(e)}")
            return {"error": str(e)}

libs/rag/pii_detector.py (new file, 77 lines)

@@ -0,0 +1,77 @@
"""PII detection and de-identification utilities."""
import hashlib
import re
from typing import Any
class PIIDetector:
"""PII detection and de-identification utilities"""
# Regex patterns for common PII
PII_PATTERNS = {
"uk_ni_number": r"\b[A-CEGHJ-PR-TW-Z]{2}\d{6}[A-D]\b",
"uk_utr": r"\b\d{10}\b",
"uk_postcode": r"\b[A-Z]{1,2}\d[A-Z0-9]?\s*\d[A-Z]{2}\b",
"uk_sort_code": r"\b\d{2}-\d{2}-\d{2}\b",
"uk_account_number": r"\b\d{8}\b",
"email": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
"phone": r"\b(?:\+44|0)\d{10,11}\b",
"iban": r"\bGB\d{2}[A-Z]{4}\d{14}\b",
"amount": r"£\d{1,3}(?:,\d{3})*(?:\.\d{2})?",
"date": r"\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b",
}
def __init__(self) -> None:
self.compiled_patterns = {
name: re.compile(pattern, re.IGNORECASE)
for name, pattern in self.PII_PATTERNS.items()
}
def detect_pii(self, text: str) -> list[dict[str, Any]]:
"""Detect PII in text and return matches with positions"""
matches = []
for pii_type, pattern in self.compiled_patterns.items():
for match in pattern.finditer(text):
matches.append(
{
"type": pii_type,
"value": match.group(),
"start": match.start(),
"end": match.end(),
"placeholder": self._generate_placeholder(
pii_type, match.group()
),
}
)
return sorted(matches, key=lambda x: x["start"])
def de_identify_text(self, text: str) -> tuple[str, dict[str, str]]:
"""De-identify text by replacing PII with placeholders"""
pii_matches = self.detect_pii(text)
pii_mapping = {}
# Replace PII from end to start to maintain positions
de_identified = text
for match in reversed(pii_matches):
placeholder = match["placeholder"]
pii_mapping[placeholder] = match["value"]
de_identified = (
de_identified[: match["start"]]
+ placeholder
+ de_identified[match["end"] :]
)
return de_identified, pii_mapping
def _generate_placeholder(self, pii_type: str, value: str) -> str:
"""Generate consistent placeholder for PII value"""
# Create hash of the value for consistent placeholders
value_hash = hashlib.md5(value.encode()).hexdigest()[:8]
return f"[{pii_type.upper()}_{value_hash}]"
def has_pii(self, text: str) -> bool:
"""Check if text contains any PII"""
return len(self.detect_pii(text)) > 0
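
A short usage sketch for `PIIDetector` (illustrative; the sample text is made up):

    from libs.rag.pii_detector import PIIDetector

    detector = PIIDetector()
    text = "Contact jane@example.com about NI AB123456C and UTR 1234567890."

    print([m["type"] for m in detector.detect_pii(text)])
    # ['email', 'uk_ni_number', 'uk_utr']  (sorted by position in the text)

    redacted, mapping = detector.de_identify_text(text)
    print(redacted)  # placeholders such as [EMAIL_<hash>] replace each match
    print(mapping)   # {placeholder: original value}, kept out of the vector store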

libs/rag/retriever.py (new file, 235 lines)

@@ -0,0 +1,235 @@
"""High-level RAG retrieval with reranking and KG fusion."""
from typing import Any
import structlog
from qdrant_client import QdrantClient
from qdrant_client.models import (
FieldCondition,
Filter,
MatchValue,
SparseVector,
)
from .collection_manager import QdrantCollectionManager
logger = structlog.get_logger()
class RAGRetriever: # pylint: disable=too-few-public-methods
"""High-level RAG retrieval with reranking and KG fusion"""
def __init__(
self,
qdrant_client: QdrantClient,
neo4j_client: Any = None,
reranker_model: str | None = None,
) -> None:
self.collection_manager = QdrantCollectionManager(qdrant_client)
self.neo4j_client = neo4j_client
self.reranker_model = reranker_model
async def search( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals
self,
query: str,
collections: list[str],
dense_vector: list[float],
sparse_vector: SparseVector,
k: int = 10,
alpha: float = 0.5,
beta: float = 0.3, # pylint: disable=unused-argument
gamma: float = 0.2, # pylint: disable=unused-argument
tax_year: str | None = None,
jurisdiction: str | None = None,
) -> dict[str, Any]:
"""Perform comprehensive RAG search with KG fusion"""
# Build filter conditions
filter_conditions = self._build_filter(tax_year, jurisdiction)
# Search each collection
all_chunks = []
for collection in collections:
chunks = await self.collection_manager.hybrid_search(
collection_name=collection,
dense_vector=dense_vector,
sparse_vector=sparse_vector,
limit=k,
alpha=alpha,
filter_conditions=filter_conditions,
)
# Add collection info to chunks
for chunk in chunks:
chunk["collection"] = collection
all_chunks.extend(chunks)
# Re-rank if reranker is available
if self.reranker_model and len(all_chunks) > k:
all_chunks = await self._rerank_chunks(query, all_chunks, k)
# Sort by score and take top k
all_chunks.sort(key=lambda x: x["score"], reverse=True)
top_chunks = all_chunks[:k]
# Get KG hints if Neo4j client is available
kg_hints = []
if self.neo4j_client:
kg_hints = await self._get_kg_hints(query, top_chunks)
# Extract citations
citations = self._extract_citations(top_chunks)
# Calculate calibrated confidence
calibrated_confidence = self._calculate_confidence(top_chunks)
return {
"chunks": top_chunks,
"citations": citations,
"kg_hints": kg_hints,
"calibrated_confidence": calibrated_confidence,
}
def _build_filter(
self, tax_year: str | None = None, jurisdiction: str | None = None
) -> Filter | None:
"""Build Qdrant filter conditions"""
conditions = []
if jurisdiction:
conditions.append(
FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction))
)
if tax_year:
conditions.append(
FieldCondition(key="tax_years", match=MatchValue(value=tax_year))
)
# Always require PII-free content
conditions.append(FieldCondition(key="pii_free", match=MatchValue(value=True)))
if conditions:
return Filter(must=conditions) # type: ignore
return None
async def _rerank_chunks( # pylint: disable=unused-argument
self, query: str, chunks: list[dict[str, Any]], k: int
) -> list[dict[str, Any]]:
"""Rerank chunks using cross-encoder model"""
try:
# This would integrate with a reranking service
# For now, return original chunks
logger.debug("Reranking not implemented, returning original order")
return chunks
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Reranking failed, using original order", error=str(e))
return chunks
async def _get_kg_hints( # pylint: disable=unused-argument
self, query: str, chunks: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Get knowledge graph hints related to the query"""
try:
# Extract potential rule/formula references from chunks
hints = []
for chunk in chunks:
payload = chunk.get("payload", {})
topic_tags = payload.get("topic_tags", [])
# Look for tax rules related to the topics
if topic_tags and self.neo4j_client:
kg_query = """
MATCH (r:Rule)-[:APPLIES_TO]->(topic)
WHERE topic.name IN $topics
AND r.retracted_at IS NULL
RETURN r.rule_id as rule_id,
r.formula as formula_id,
collect(id(topic)) as node_ids
LIMIT 5
"""
kg_results = await self.neo4j_client.run_query(
kg_query, {"topics": topic_tags}
)
for result in kg_results:
hints.append(
{
"rule_id": result["rule_id"],
"formula_id": result["formula_id"],
"node_ids": result["node_ids"],
}
)
return hints[:5] # Limit to top 5 hints
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Failed to get KG hints", error=str(e))
return []
def _extract_citations(self, chunks: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Extract citation information from chunks"""
citations = []
seen_docs = set()
for chunk in chunks:
payload = chunk.get("payload", {})
# Extract document reference
doc_id = payload.get("doc_id")
url = payload.get("url")
section_id = payload.get("section_id")
page = payload.get("page")
bbox = payload.get("bbox")
# Create citation key to avoid duplicates
citation_key = doc_id or url
if citation_key and citation_key not in seen_docs:
citation = {}
if doc_id:
citation["doc_id"] = doc_id
if url:
citation["url"] = url
if section_id:
citation["section_id"] = section_id
if page:
citation["page"] = page
if bbox:
citation["bbox"] = bbox
citations.append(citation)
seen_docs.add(citation_key)
return citations
def _calculate_confidence(self, chunks: list[dict[str, Any]]) -> float:
"""Calculate calibrated confidence score"""
if not chunks:
return 0.0
# Simple confidence calculation based on top scores
top_scores = [chunk["score"] for chunk in chunks[:3]]
if not top_scores:
return 0.0
# Average of top 3 scores with diminishing returns
weights = [0.5, 0.3, 0.2]
weighted_score = sum(
score * weight
for score, weight in zip(
top_scores, weights[: len(top_scores)], strict=False
)
)
# Apply calibration (simple temperature scaling)
# In production, this would use learned calibration parameters
temperature = 1.2
calibrated = weighted_score / temperature
return min(max(calibrated, 0.0), 1.0) # type: ignore
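
The confidence calibration above reduces to a weighted top-3 average divided by a fixed temperature; a worked example with made-up chunk scores:

    top_scores = [0.82, 0.74, 0.61]   # three best hybrid scores
    weights = [0.5, 0.3, 0.2]

    weighted = sum(s * w for s, w in zip(top_scores, weights))   # 0.754
    calibrated = min(max(weighted / 1.2, 0.0), 1.0)              # 0.628 with temperature 1.2
    print(round(weighted, 3), round(calibrated, 3))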

libs/rag/utils.py (new file, 44 lines)

@@ -0,0 +1,44 @@
"""Coverage-specific RAG utility functions."""
from typing import Any
import structlog
from libs.schemas.coverage.evaluation import Citation
logger = structlog.get_logger()
async def rag_search_for_citations(
rag_client: Any, query: str, filters: dict[str, Any] | None = None
) -> list["Citation"]:
"""Search for citations using RAG with PII-free filtering"""
try:
# Ensure PII-free filter is always applied
search_filters = filters or {}
search_filters["pii_free"] = True
# This would integrate with the actual RAG retrieval system
# For now, return a placeholder implementation
logger.debug(
"RAG citation search called",
query=query,
filters=search_filters,
rag_client_available=rag_client is not None,
)
# Placeholder citations - in production this would call the RAG system
citations = [
Citation(
doc_id=f"RAG-{query.replace(' ', '-')[:20]}",
locator="Retrieved via RAG search",
url=f"https://guidance.example.com/search?q={query}",
)
]
return citations
except (ConnectionError, TimeoutError) as e:
logger.error("RAG citation search failed", query=query, error=str(e))
return []
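
A minimal call sketch for `rag_search_for_citations` (illustrative; the query and filter values are made up, and with the current placeholder implementation it returns a single synthetic citation):

    import asyncio

    from libs.rag.utils import rag_search_for_citations

    async def main() -> None:
        citations = await rag_search_for_citations(
            rag_client=None,                    # no retriever wired up yet
            query="mileage allowance relief",
            filters={"jurisdiction": "UK"},
        )
        for citation in citations:
            print(citation.doc_id, citation.url)

    asyncio.run(main())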