Initial commit
Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
libs/rag/__init__.py (new file, +13 lines)
@@ -0,0 +1,13 @@
"""Qdrant collections CRUD, hybrid search, rerank wrapper, de-identification utilities."""

from .collection_manager import QdrantCollectionManager
from .pii_detector import PIIDetector
from .retriever import RAGRetriever
from .utils import rag_search_for_citations

__all__ = [
    "PIIDetector",
    "QdrantCollectionManager",
    "RAGRetriever",
    "rag_search_for_citations",
]
libs/rag/collection_manager.py (new file, +233 lines)
@@ -0,0 +1,233 @@
"""Manage Qdrant collections for RAG."""

from typing import Any

import structlog
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    Filter,
    PointStruct,
    SparseVector,
    VectorParams,
)

from .pii_detector import PIIDetector

logger = structlog.get_logger()


class QdrantCollectionManager:
    """Manage Qdrant collections for RAG"""

    def __init__(self, client: QdrantClient):
        self.client = client
        self.pii_detector = PIIDetector()

    async def ensure_collection(
        self,
        collection_name: str,
        vector_size: int = 384,
        distance: Distance = Distance.COSINE,
        sparse_vector_config: dict[str, Any] | None = None,
    ) -> bool:
        """Ensure collection exists with proper configuration"""
        try:
            # Check if collection exists
            collections = self.client.get_collections().collections
            if any(c.name == collection_name for c in collections):
                logger.debug("Collection already exists", collection=collection_name)
                return True

            # Create collection with dense vectors
            vector_config = VectorParams(size=vector_size, distance=distance)

            # Add sparse vector configuration if provided
            sparse_vectors_config = None
            if sparse_vector_config:
                sparse_vectors_config = {"sparse": sparse_vector_config}

            self.client.create_collection(
                collection_name=collection_name,
                vectors_config=vector_config,
                sparse_vectors_config=sparse_vectors_config,  # type: ignore
            )

            logger.info("Created collection", collection=collection_name)
            return True

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Failed to create collection", collection=collection_name, error=str(e)
            )
            return False

    async def upsert_points(
        self, collection_name: str, points: list[PointStruct]
    ) -> bool:
        """Upsert points to collection"""
        try:
            # Validate all points are PII-free
            for point in points:
                if point.payload and not point.payload.get("pii_free", False):
                    logger.warning("Point not marked as PII-free", point_id=point.id)
                    return False

            self.client.upsert(collection_name=collection_name, points=points)

            logger.info(
                "Upserted points", collection=collection_name, count=len(points)
            )
            return True

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Failed to upsert points", collection=collection_name, error=str(e)
            )
            return False

    async def search_dense(  # pylint: disable=too-many-arguments,too-many-positional-arguments
        self,
        collection_name: str,
        query_vector: list[float],
        limit: int = 10,
        filter_conditions: Filter | None = None,
        score_threshold: float | None = None,
    ) -> list[dict[str, Any]]:
        """Search using dense vectors"""
        try:
            search_result = self.client.search(
                collection_name=collection_name,
                query_vector=query_vector,
                query_filter=filter_conditions,
                limit=limit,
                score_threshold=score_threshold,
                with_payload=True,
                with_vectors=False,
            )

            return [
                {"id": hit.id, "score": hit.score, "payload": hit.payload}
                for hit in search_result
            ]

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Dense search failed", collection=collection_name, error=str(e)
            )
            return []

    async def search_sparse(
        self,
        collection_name: str,
        query_vector: SparseVector,
        limit: int = 10,
        filter_conditions: Filter | None = None,
    ) -> list[dict[str, Any]]:
        """Search using sparse vectors"""
        try:
            search_result = self.client.search(
                collection_name=collection_name,
                query_vector=query_vector,  # type: ignore
                query_filter=filter_conditions,
                limit=limit,
                using="sparse",
                with_payload=True,
                with_vectors=False,
            )

            return [
                {"id": hit.id, "score": hit.score, "payload": hit.payload}
                for hit in search_result
            ]

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Sparse search failed", collection=collection_name, error=str(e)
            )
            return []

    async def hybrid_search(  # pylint: disable=too-many-arguments,too-many-positional-arguments
        self,
        collection_name: str,
        dense_vector: list[float],
        sparse_vector: SparseVector,
        limit: int = 10,
        alpha: float = 0.5,
        filter_conditions: Filter | None = None,
    ) -> list[dict[str, Any]]:
        """Perform hybrid search combining dense and sparse results"""

        # Get dense results
        dense_results = await self.search_dense(
            collection_name=collection_name,
            query_vector=dense_vector,
            limit=limit * 2,  # Get more results for fusion
            filter_conditions=filter_conditions,
        )

        # Get sparse results
        sparse_results = await self.search_sparse(
            collection_name=collection_name,
            query_vector=sparse_vector,
            limit=limit * 2,
            filter_conditions=filter_conditions,
        )

        # Combine and re-rank results
        return self._fuse_results(dense_results, sparse_results, alpha, limit)

    def _fuse_results(  # pylint: disable=too-many-locals
        self,
        dense_results: list[dict[str, Any]],
        sparse_results: list[dict[str, Any]],
        alpha: float,
        limit: int,
    ) -> list[dict[str, Any]]:
        """Fuse dense and sparse search results"""

        # Create score maps
        dense_scores = {result["id"]: result["score"] for result in dense_results}
        sparse_scores = {result["id"]: result["score"] for result in sparse_results}

        # Get all unique IDs
        all_ids = set(dense_scores.keys()) | set(sparse_scores.keys())

        # Calculate hybrid scores
        hybrid_results = []
        for doc_id in all_ids:
            dense_score = dense_scores.get(doc_id, 0.0)
            sparse_score = sparse_scores.get(doc_id, 0.0)

            # Normalize each score against the max score in its result set
            if dense_results:
                max_dense = max(dense_scores.values())
                dense_score = dense_score / max_dense if max_dense > 0 else 0

            if sparse_results:
                max_sparse = max(sparse_scores.values())
                sparse_score = sparse_score / max_sparse if max_sparse > 0 else 0

            # Combine scores
            hybrid_score = alpha * dense_score + (1 - alpha) * sparse_score

            # Get payload from either result
            payload = None
            for result in dense_results + sparse_results:
                if result["id"] == doc_id:
                    payload = result["payload"]
                    break

            hybrid_results.append(
                {
                    "id": doc_id,
                    "score": hybrid_score,
                    "dense_score": dense_score,
                    "sparse_score": sparse_score,
                    "payload": payload,
                }
            )

        # Sort by hybrid score and return top results
        hybrid_results.sort(key=lambda x: x["score"], reverse=True)
        return hybrid_results[:limit]
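Example (not part of this commit): a minimal usage sketch of the hybrid search path. The Qdrant URL, the "guidance" collection name, and the stand-in query vectors are illustrative assumptions; in practice the dense and sparse encodings come from the indexer. With alpha=0.7, a chunk whose max-normalised dense score is 0.9 and sparse score is 0.5 is fused to 0.7 * 0.9 + 0.3 * 0.5 = 0.78.

import asyncio

from qdrant_client import QdrantClient
from qdrant_client.models import SparseVector

from libs.rag import QdrantCollectionManager


async def demo() -> None:
    # Assumes a local Qdrant instance and an existing "guidance" collection
    # with 384-dim dense vectors and a named "sparse" vector.
    manager = QdrantCollectionManager(QdrantClient(url="http://localhost:6333"))

    dense = [0.0] * 384  # stand-in for a real query embedding
    sparse = SparseVector(indices=[101, 2048], values=[1.0, 2.0])  # stand-in term weights

    hits = await manager.hybrid_search(
        collection_name="guidance",
        dense_vector=dense,
        sparse_vector=sparse,
        limit=5,
        alpha=0.7,  # weight normalised dense scores 0.7, sparse scores 0.3
    )
    for hit in hits:
        print(hit["id"], round(hit["score"], 3))


asyncio.run(demo())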
libs/rag/indexer.py (new file, +507 lines)
@@ -0,0 +1,507 @@
# FILE: libs/rag/indexer.py
# De-identify -> embed dense/sparse -> upsert to Qdrant with payload

import json
import logging
import re
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any

import numpy as np
import spacy
import torch
import yaml
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, SparseVector, VectorParams
from sentence_transformers import SentenceTransformer

from .chunker import DocumentChunker
from .pii_detector import PIIDetector


@dataclass
class IndexingResult:
    collection_name: str
    points_indexed: int
    points_updated: int
    points_failed: int
    processing_time: float
    errors: list[str]


class RAGIndexer:
    def __init__(self, config_path: str, qdrant_url: str = "http://localhost:6333"):
        with open(config_path) as f:
            self.config = yaml.safe_load(f)

        self.qdrant_client = QdrantClient(url=qdrant_url)
        self.chunker = DocumentChunker(config_path)
        self.pii_detector = PIIDetector()

        # Initialize embedding models
        self.dense_model = SentenceTransformer(
            self.config.get("embedding_model", "bge-small-en-v1.5")
        )

        # Initialize sparse model (BM25/SPLADE)
        self.sparse_model = self._init_sparse_model()

        # Initialize NLP pipeline
        self.nlp = spacy.load("en_core_web_sm")

        self.logger = logging.getLogger(__name__)

    def _init_sparse_model(self):
        """Initialize sparse embedding model (BM25 or SPLADE)"""
        sparse_config = self.config.get("sparse_model", {})
        model_type = sparse_config.get("type", "bm25")

        if model_type == "bm25":
            from rank_bm25 import BM25Okapi

            return BM25Okapi
        elif model_type == "splade":
            from transformers import AutoModelForMaskedLM, AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(
                "naver/splade-cocondenser-ensembledistil"
            )
            model = AutoModelForMaskedLM.from_pretrained(
                "naver/splade-cocondenser-ensembledistil"
            )
            return {"tokenizer": tokenizer, "model": model}
        else:
            raise ValueError(f"Unsupported sparse model type: {model_type}")

    async def index_document(
        self, document_path: str, collection_name: str, metadata: dict[str, Any]
    ) -> IndexingResult:
        """Index a single document into the specified collection"""
        start_time = datetime.now()
        errors = []
        points_indexed = 0
        points_updated = 0
        points_failed = 0

        try:
            # Step 1: Chunk the document
            chunks = await self.chunker.chunk_document(document_path, metadata)

            # Step 2: Process each chunk
            points = []
            for chunk in chunks:
                try:
                    point = await self._process_chunk(chunk, collection_name, metadata)
                    if point:
                        points.append(point)
                except Exception as e:
                    self.logger.error(
                        f"Failed to process chunk {chunk.get('id', 'unknown')}: {str(e)}"
                    )
                    errors.append(f"Chunk processing error: {str(e)}")
                    points_failed += 1

            # Step 3: Upsert to Qdrant
            if points:
                try:
                    self.qdrant_client.upsert(
                        collection_name=collection_name, points=points, wait=True
                    )
                    points_indexed = len(points)
                    self.logger.info(
                        f"Indexed {points_indexed} points to {collection_name}"
                    )
                except Exception as e:
                    self.logger.error(f"Failed to upsert to Qdrant: {str(e)}")
                    errors.append(f"Qdrant upsert error: {str(e)}")
                    points_failed += len(points)
                    points_indexed = 0

        except Exception as e:
            self.logger.error(f"Document indexing failed: {str(e)}")
            errors.append(f"Document indexing error: {str(e)}")

        processing_time = (datetime.now() - start_time).total_seconds()

        return IndexingResult(
            collection_name=collection_name,
            points_indexed=points_indexed,
            points_updated=points_updated,
            points_failed=points_failed,
            processing_time=processing_time,
            errors=errors,
        )

    async def _process_chunk(
        self, chunk: dict[str, Any], collection_name: str, base_metadata: dict[str, Any]
    ) -> PointStruct | None:
        """Process a single chunk: de-identify, embed, create point"""

        # Step 1: De-identify PII
        content = chunk["content"]
        pii_detected = self.pii_detector.detect_pii(content)

        if pii_detected:
            # Replace PII with placeholders and keep the placeholder -> value mapping
            redacted_content, pii_mapping = self.pii_detector.de_identify_text(content)

            # Store PII mapping securely (not in vector DB)
            await self._store_pii_mapping(chunk["id"], pii_mapping)

            # Log PII detection for audit
            self.logger.warning(
                f"PII detected in chunk {chunk['id']}: {[p['type'] for p in pii_detected]}"
            )
        else:
            redacted_content = content

        # Verify no PII remains
        if not self._verify_pii_free(redacted_content):
            self.logger.error(f"PII verification failed for chunk {chunk['id']}")
            return None

        # Step 2: Generate embeddings
        try:
            dense_vector = await self._generate_dense_embedding(redacted_content)
            sparse_vector = await self._generate_sparse_embedding(redacted_content)
        except Exception as e:
            self.logger.error(
                f"Embedding generation failed for chunk {chunk['id']}: {str(e)}"
            )
            return None

        # Step 3: Prepare metadata
        payload = self._prepare_payload(chunk, base_metadata, redacted_content)
        payload["pii_free"] = True  # Verified above

        # Step 4: Create point
        point = PointStruct(
            id=chunk["id"],
            vector={"dense": dense_vector, "sparse": sparse_vector},
            payload=payload,
        )

        return point

    async def _generate_dense_embedding(self, text: str) -> list[float]:
        """Generate dense vector embedding"""
        try:
            # Use sentence transformer for dense embeddings
            embedding = self.dense_model.encode(text, normalize_embeddings=True)
            return embedding.tolist()
        except Exception as e:
            self.logger.error(f"Dense embedding generation failed: {str(e)}")
            raise

    async def _generate_sparse_embedding(self, text: str) -> SparseVector:
        """Generate sparse vector embedding (BM25 or SPLADE)"""
        vector = SparseVector(indices=[], values=[])

        try:
            sparse_config = self.config.get("sparse_model", {})
            model_type = sparse_config.get("type", "bm25")

            if model_type == "bm25":
                # Simple BM25-style sparse representation
                doc = self.nlp(text)
                tokens = [
                    token.lemma_.lower()
                    for token in doc
                    if not token.is_stop and not token.is_punct
                ]

                # Create term frequency vector
                term_freq = {}
                for token in tokens:
                    term_freq[token] = term_freq.get(token, 0) + 1

                # Convert to sparse vector format
                vocab_size = sparse_config.get("vocab_size", 30000)
                indices = []
                values = []

                for term, freq in term_freq.items():
                    # Simple hash-based vocabulary mapping
                    term_id = hash(term) % vocab_size
                    indices.append(term_id)
                    values.append(float(freq))

                vector = SparseVector(indices=indices, values=values)

            elif model_type == "splade":
                # SPLADE sparse embeddings
                tokenizer = self.sparse_model["tokenizer"]
                model = self.sparse_model["model"]

                inputs = tokenizer(
                    text, return_tensors="pt", truncation=True, max_length=512
                )
                outputs = model(**inputs)

                # Extract sparse representation
                logits = outputs.logits.squeeze()
                sparse_rep = torch.relu(logits).detach().numpy()

                # Convert to sparse format
                indices = np.nonzero(sparse_rep)[0].tolist()
                values = sparse_rep[indices].tolist()

                vector = SparseVector(indices=indices, values=values)

            return vector

        except Exception as e:
            self.logger.error(f"Sparse embedding generation failed: {str(e)}")
            # Return empty sparse vector as fallback
            return vector

    def _prepare_payload(
        self, chunk: dict[str, Any], base_metadata: dict[str, Any], content: str
    ) -> dict[str, Any]:
        """Prepare payload metadata for the chunk"""

        # Start with base metadata
        payload = base_metadata.copy()

        # Add chunk-specific metadata
        payload.update(
            {
                "document_id": chunk.get("document_id"),
                "content": content,  # De-identified content
                "chunk_index": chunk.get("chunk_index", 0),
                "total_chunks": chunk.get("total_chunks", 1),
                "page_numbers": chunk.get("page_numbers", []),
                "section_hierarchy": chunk.get("section_hierarchy", []),
                "has_calculations": self._detect_calculations(content),
                "has_forms": self._detect_form_references(content),
                "confidence_score": chunk.get("confidence_score", 1.0),
                "created_at": datetime.now().isoformat(),
                "version": self.config.get("version", "1.0"),
            }
        )

        # Extract and add topic tags
        topic_tags = self._extract_topic_tags(content)
        if topic_tags:
            payload["topic_tags"] = topic_tags

        # Add content analysis
        payload.update(self._analyze_content(content))

        return payload

    def _detect_calculations(self, text: str) -> bool:
        """Detect if text contains calculations or formulas"""
        calculation_patterns = [
            r"\d+\s*[+\-*/]\s*\d+",
            r"£\d+(?:,\d{3})*(?:\.\d{2})?",
            r"\d+(?:\.\d+)?%",
            r"total|sum|calculate|compute",
            r"rate|threshold|allowance|relief",
        ]

        for pattern in calculation_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                return True
        return False

    def _detect_form_references(self, text: str) -> bool:
        """Detect references to tax forms"""
        form_patterns = [
            r"SA\d{3}",
            r"P\d{2}",
            r"CT\d{3}",
            r"VAT\d{3}",
            r"form\s+\w+",
            r"schedule\s+\w+",
        ]

        for pattern in form_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                return True
        return False

    def _extract_topic_tags(self, text: str) -> list[str]:
        """Extract topic tags from content"""
        topic_keywords = {
            "employment": [
                "PAYE",
                "payslip",
                "P60",
                "employment",
                "salary",
                "wages",
                "employer",
            ],
            "self_employment": [
                "self-employed",
                "business",
                "turnover",
                "expenses",
                "profit",
                "loss",
            ],
            "property": ["rental", "property", "landlord", "FHL", "mortgage", "rent"],
            "dividends": ["dividend", "shares", "distribution", "corporation tax"],
            "capital_gains": ["capital gains", "disposal", "acquisition", "CGT"],
            "pensions": ["pension", "retirement", "SIPP", "occupational"],
            "savings": ["interest", "savings", "ISA", "bonds"],
            "inheritance": ["inheritance", "IHT", "estate", "probate"],
            "vat": ["VAT", "value added tax", "registration", "return"],
        }

        tags = []
        text_lower = text.lower()

        for topic, keywords in topic_keywords.items():
            for keyword in keywords:
                if keyword.lower() in text_lower:
                    tags.append(topic)
                    break

        return list(set(tags))  # Remove duplicates

    def _analyze_content(self, text: str) -> dict[str, Any]:
        """Analyze content for additional metadata"""
        doc = self.nlp(text)

        return {
            "word_count": len([token for token in doc if not token.is_space]),
            "sentence_count": len(list(doc.sents)),
            "entity_count": len(doc.ents),
            "complexity_score": self._calculate_complexity(doc),
            "language": doc.lang_ if hasattr(doc, "lang_") else "en",
        }

    def _calculate_complexity(self, doc: Any) -> float:
        """Calculate text complexity score for a spaCy Doc"""
        if not doc:
            return 0.0

        # Simple complexity based on sentence length and vocabulary
        avg_sentence_length = sum(len(sent) for sent in doc.sents) / len(
            list(doc.sents)
        )
        unique_words = len(set(token.lemma_.lower() for token in doc if token.is_alpha))
        total_words = len([token for token in doc if token.is_alpha])

        vocabulary_diversity = unique_words / total_words if total_words > 0 else 0

        # Normalize to 0-1 scale
        complexity = min(1.0, (avg_sentence_length / 20.0 + vocabulary_diversity) / 2.0)
        return complexity

    def _verify_pii_free(self, text: str) -> bool:
        """Verify that text contains no PII"""
        # Quick verification using patterns
        pii_patterns = [
            r"\b[A-Z]{2}\d{6}[A-D]\b",  # NI number
            r"\b\d{10}\b",  # UTR
            r"\b[A-Z]{2}\d{2}[A-Z]{4}\d{14}\b",  # IBAN
            r"\b\d{2}-\d{2}-\d{2}\b",  # Sort code
            r"\b[A-Z]{1,2}\d[A-Z\d]?\s*\d[A-Z]{2}\b",  # Postcode
            r"\b[\w\.-]+@[\w\.-]+\.\w+\b",  # Email
            r"\b(?:\+44|0)\d{10,11}\b",  # Phone
        ]

        for pattern in pii_patterns:
            if re.search(pattern, text):
                return False

        return True

    async def _store_pii_mapping(
        self, chunk_id: str, pii_mapping: dict[str, Any]
    ) -> None:
        """Store PII mapping in secure client data store (not in vector DB)"""
        # This would integrate with the secure PostgreSQL client data store
        # For now, just log the mapping securely
        self.logger.info(
            f"PII mapping stored for chunk {chunk_id}: {len(pii_mapping)} items"
        )

    async def create_collections(self) -> None:
        """Create all Qdrant collections based on configuration"""
        collections_config_path = Path(__file__).parent / "qdrant_collections.json"

        with open(collections_config_path) as f:
            collections_config = json.load(f)

        for collection_config in collections_config["collections"]:
            collection_name = collection_config["name"]

            try:
                # Check if collection exists
                try:
                    self.qdrant_client.get_collection(collection_name)
                    self.logger.info(f"Collection {collection_name} already exists")
                    continue
                except Exception:
                    pass  # Collection doesn't exist, create it

                # Create collection
                vectors_config = {}

                # Dense vector configuration
                if "dense" in collection_config:
                    vectors_config["dense"] = VectorParams(
                        size=collection_config["dense"]["size"],
                        distance=Distance.COSINE,
                    )

                # Sparse vector configuration
                if collection_config.get("sparse", False):
                    vectors_config["sparse"] = VectorParams(
                        size=30000,  # Vocabulary size for sparse vectors
                        distance=Distance.DOT,
                        on_disk=True,
                    )

                self.qdrant_client.create_collection(
                    collection_name=collection_name,
                    vectors_config=vectors_config,
                    **collection_config.get("indexing_config", {}),
                )

                self.logger.info(f"Created collection: {collection_name}")

            except Exception as e:
                self.logger.error(
                    f"Failed to create collection {collection_name}: {str(e)}"
                )
                raise

    async def batch_index(
        self, documents: list[dict[str, Any]], collection_name: str
    ) -> list[IndexingResult]:
        """Index multiple documents in batch"""
        results = []

        for doc_info in documents:
            result = await self.index_document(
                doc_info["path"], collection_name, doc_info["metadata"]
            )
            results.append(result)

        return results

    def get_collection_stats(self, collection_name: str) -> dict[str, Any]:
        """Get statistics for a collection"""
        try:
            collection_info = self.qdrant_client.get_collection(collection_name)
            return {
                "name": collection_name,
                "vectors_count": collection_info.vectors_count,
                "indexed_vectors_count": collection_info.indexed_vectors_count,
                "points_count": collection_info.points_count,
                "segments_count": collection_info.segments_count,
                "status": collection_info.status,
            }
        except Exception as e:
            self.logger.error(f"Failed to get stats for {collection_name}: {str(e)}")
            return {"error": str(e)}
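Example (not part of this commit): a standalone sketch of the hash-bucketed term-frequency encoding used in the bm25 branch of _generate_sparse_embedding. It simplifies to whitespace tokenisation instead of the spaCy lemmas used above, and buckets terms with md5 rather than the built-in hash(), since hash() is salted per process and would give different term ids to the indexer and the query encoder unless PYTHONHASHSEED is pinned.

import hashlib

from qdrant_client.models import SparseVector


def term_frequency_sparse_vector(text: str, vocab_size: int = 30000) -> SparseVector:
    """Hash each term into a fixed-size vocabulary and use its frequency as the weight."""
    term_freq: dict[str, int] = {}
    for term in text.lower().split():  # the indexer uses spaCy lemmas instead
        term_freq[term] = term_freq.get(term, 0) + 1

    indices, values = [], []
    for term, freq in term_freq.items():
        # md5 gives a process-stable bucket id, unlike the salted built-in hash()
        bucket = int(hashlib.md5(term.encode()).hexdigest(), 16) % vocab_size
        indices.append(bucket)
        values.append(float(freq))

    return SparseVector(indices=indices, values=values)


print(term_frequency_sparse_vector("personal allowance reduces above the income threshold"))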
libs/rag/pii_detector.py (new file, +77 lines)
@@ -0,0 +1,77 @@
"""PII detection and de-identification utilities."""

import hashlib
import re
from typing import Any


class PIIDetector:
    """PII detection and de-identification utilities"""

    # Regex patterns for common PII
    PII_PATTERNS = {
        "uk_ni_number": r"\b[A-CEGHJ-PR-TW-Z]{2}\d{6}[A-D]\b",
        "uk_utr": r"\b\d{10}\b",
        "uk_postcode": r"\b[A-Z]{1,2}\d[A-Z0-9]?\s*\d[A-Z]{2}\b",
        "uk_sort_code": r"\b\d{2}-\d{2}-\d{2}\b",
        "uk_account_number": r"\b\d{8}\b",
        "email": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",
        "phone": r"\b(?:\+44|0)\d{10,11}\b",
        "iban": r"\bGB\d{2}[A-Z]{4}\d{14}\b",
        "amount": r"£\d{1,3}(?:,\d{3})*(?:\.\d{2})?",
        "date": r"\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b",
    }

    def __init__(self) -> None:
        self.compiled_patterns = {
            name: re.compile(pattern, re.IGNORECASE)
            for name, pattern in self.PII_PATTERNS.items()
        }

    def detect_pii(self, text: str) -> list[dict[str, Any]]:
        """Detect PII in text and return matches with positions"""
        matches = []

        for pii_type, pattern in self.compiled_patterns.items():
            for match in pattern.finditer(text):
                matches.append(
                    {
                        "type": pii_type,
                        "value": match.group(),
                        "start": match.start(),
                        "end": match.end(),
                        "placeholder": self._generate_placeholder(
                            pii_type, match.group()
                        ),
                    }
                )

        return sorted(matches, key=lambda x: x["start"])

    def de_identify_text(self, text: str) -> tuple[str, dict[str, str]]:
        """De-identify text by replacing PII with placeholders"""
        pii_matches = self.detect_pii(text)
        pii_mapping = {}

        # Replace PII from end to start to maintain positions
        de_identified = text
        for match in reversed(pii_matches):
            placeholder = match["placeholder"]
            pii_mapping[placeholder] = match["value"]
            de_identified = (
                de_identified[: match["start"]]
                + placeholder
                + de_identified[match["end"] :]
            )

        return de_identified, pii_mapping

    def _generate_placeholder(self, pii_type: str, value: str) -> str:
        """Generate consistent placeholder for PII value"""
        # Create hash of the value for consistent placeholders
        value_hash = hashlib.md5(value.encode()).hexdigest()[:8]
        return f"[{pii_type.upper()}_{value_hash}]"

    def has_pii(self, text: str) -> bool:
        """Check if text contains any PII"""
        return len(self.detect_pii(text)) > 0
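Example (not part of this commit): a quick usage sketch of the de-identification flow on an invented sample string. The email, UTR, and date below match the patterns defined above, so each is replaced by a hashed placeholder and the mapping is returned for storage outside the vector DB.

from libs.rag import PIIDetector

detector = PIIDetector()
text = "Contact jane@example.com about UTR 1234567890 before 31/01/2025."

clean, mapping = detector.de_identify_text(text)
print(clean)    # e.g. "Contact [EMAIL_...] about UTR [UK_UTR_...] before [DATE_...]."
print(mapping)  # placeholder -> original value, kept out of the vector store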
libs/rag/retriever.py (new file, +235 lines)
@@ -0,0 +1,235 @@
"""High-level RAG retrieval with reranking and KG fusion."""

from typing import Any

import structlog
from qdrant_client import QdrantClient
from qdrant_client.models import (
    FieldCondition,
    Filter,
    MatchValue,
    SparseVector,
)

from .collection_manager import QdrantCollectionManager

logger = structlog.get_logger()


class RAGRetriever:  # pylint: disable=too-few-public-methods
    """High-level RAG retrieval with reranking and KG fusion"""

    def __init__(
        self,
        qdrant_client: QdrantClient,
        neo4j_client: Any = None,
        reranker_model: str | None = None,
    ) -> None:
        self.collection_manager = QdrantCollectionManager(qdrant_client)
        self.neo4j_client = neo4j_client
        self.reranker_model = reranker_model

    async def search(  # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals
        self,
        query: str,
        collections: list[str],
        dense_vector: list[float],
        sparse_vector: SparseVector,
        k: int = 10,
        alpha: float = 0.5,
        beta: float = 0.3,  # pylint: disable=unused-argument
        gamma: float = 0.2,  # pylint: disable=unused-argument
        tax_year: str | None = None,
        jurisdiction: str | None = None,
    ) -> dict[str, Any]:
        """Perform comprehensive RAG search with KG fusion"""

        # Build filter conditions
        filter_conditions = self._build_filter(tax_year, jurisdiction)

        # Search each collection
        all_chunks = []
        for collection in collections:
            chunks = await self.collection_manager.hybrid_search(
                collection_name=collection,
                dense_vector=dense_vector,
                sparse_vector=sparse_vector,
                limit=k,
                alpha=alpha,
                filter_conditions=filter_conditions,
            )

            # Add collection info to chunks
            for chunk in chunks:
                chunk["collection"] = collection

            all_chunks.extend(chunks)

        # Re-rank if reranker is available
        if self.reranker_model and len(all_chunks) > k:
            all_chunks = await self._rerank_chunks(query, all_chunks, k)

        # Sort by score and take top k
        all_chunks.sort(key=lambda x: x["score"], reverse=True)
        top_chunks = all_chunks[:k]

        # Get KG hints if Neo4j client is available
        kg_hints = []
        if self.neo4j_client:
            kg_hints = await self._get_kg_hints(query, top_chunks)

        # Extract citations
        citations = self._extract_citations(top_chunks)

        # Calculate calibrated confidence
        calibrated_confidence = self._calculate_confidence(top_chunks)

        return {
            "chunks": top_chunks,
            "citations": citations,
            "kg_hints": kg_hints,
            "calibrated_confidence": calibrated_confidence,
        }

    def _build_filter(
        self, tax_year: str | None = None, jurisdiction: str | None = None
    ) -> Filter | None:
        """Build Qdrant filter conditions"""
        conditions = []

        if jurisdiction:
            conditions.append(
                FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction))
            )

        if tax_year:
            conditions.append(
                FieldCondition(key="tax_years", match=MatchValue(value=tax_year))
            )

        # Always require PII-free content
        conditions.append(FieldCondition(key="pii_free", match=MatchValue(value=True)))

        if conditions:
            return Filter(must=conditions)  # type: ignore
        return None

    async def _rerank_chunks(  # pylint: disable=unused-argument
        self, query: str, chunks: list[dict[str, Any]], k: int
    ) -> list[dict[str, Any]]:
        """Rerank chunks using cross-encoder model"""
        try:
            # This would integrate with a reranking service
            # For now, return original chunks
            logger.debug("Reranking not implemented, returning original order")
            return chunks

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning("Reranking failed, using original order", error=str(e))
            return chunks

    async def _get_kg_hints(  # pylint: disable=unused-argument
        self, query: str, chunks: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Get knowledge graph hints related to the query"""
        try:
            # Extract potential rule/formula references from chunks
            hints = []

            for chunk in chunks:
                payload = chunk.get("payload", {})
                topic_tags = payload.get("topic_tags", [])

                # Look for tax rules related to the topics
                if topic_tags and self.neo4j_client:
                    kg_query = """
                    MATCH (r:Rule)-[:APPLIES_TO]->(topic)
                    WHERE topic.name IN $topics
                    AND r.retracted_at IS NULL
                    RETURN r.rule_id as rule_id,
                           r.formula as formula_id,
                           collect(id(topic)) as node_ids
                    LIMIT 5
                    """

                    kg_results = await self.neo4j_client.run_query(
                        kg_query, {"topics": topic_tags}
                    )

                    for result in kg_results:
                        hints.append(
                            {
                                "rule_id": result["rule_id"],
                                "formula_id": result["formula_id"],
                                "node_ids": result["node_ids"],
                            }
                        )

            return hints[:5]  # Limit to top 5 hints

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning("Failed to get KG hints", error=str(e))
            return []

    def _extract_citations(self, chunks: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Extract citation information from chunks"""
        citations = []
        seen_docs = set()

        for chunk in chunks:
            payload = chunk.get("payload", {})

            # Extract document reference
            doc_id = payload.get("doc_id")
            url = payload.get("url")
            section_id = payload.get("section_id")
            page = payload.get("page")
            bbox = payload.get("bbox")

            # Create citation key to avoid duplicates
            citation_key = doc_id or url
            if citation_key and citation_key not in seen_docs:
                citation = {}

                if doc_id:
                    citation["doc_id"] = doc_id
                if url:
                    citation["url"] = url
                if section_id:
                    citation["section_id"] = section_id
                if page:
                    citation["page"] = page
                if bbox:
                    citation["bbox"] = bbox

                citations.append(citation)
                seen_docs.add(citation_key)

        return citations

    def _calculate_confidence(self, chunks: list[dict[str, Any]]) -> float:
        """Calculate calibrated confidence score"""
        if not chunks:
            return 0.0

        # Simple confidence calculation based on top scores
        top_scores = [chunk["score"] for chunk in chunks[:3]]

        if not top_scores:
            return 0.0

        # Average of top 3 scores with diminishing returns
        weights = [0.5, 0.3, 0.2]
        weighted_score = sum(
            score * weight
            for score, weight in zip(
                top_scores, weights[: len(top_scores)], strict=False
            )
        )

        # Apply calibration (simple temperature scaling)
        # In production, this would use learned calibration parameters
        temperature = 1.2
        calibrated = weighted_score / temperature

        return min(max(calibrated, 0.0), 1.0)  # type: ignore
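Example (not part of this commit): a minimal sketch of calling the public search entry point. The Qdrant URL, collection name, tax year, and stand-in query vectors are illustrative assumptions; without a Neo4j client or reranker configured, kg_hints is empty and reranking is skipped. For the confidence step, top-3 chunk scores of 0.9, 0.8 and 0.7 yield (0.5*0.9 + 0.3*0.8 + 0.2*0.7) / 1.2, roughly 0.69.

import asyncio

from qdrant_client import QdrantClient
from qdrant_client.models import SparseVector

from libs.rag import RAGRetriever


async def demo() -> None:
    retriever = RAGRetriever(QdrantClient(url="http://localhost:6333"))

    result = await retriever.search(
        query="personal allowance taper",
        collections=["guidance"],                 # assumed collection name
        dense_vector=[0.0] * 384,                 # stand-in query embedding
        sparse_vector=SparseVector(indices=[7], values=[1.0]),
        k=5,
        tax_year="2024-25",
        jurisdiction="UK",
    )
    print(result["calibrated_confidence"], len(result["chunks"]), result["citations"])


asyncio.run(demo())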
libs/rag/utils.py (new file, +44 lines)
@@ -0,0 +1,44 @@
"""Coverage-specific RAG utility functions."""

from typing import Any

import structlog

from libs.schemas.coverage.evaluation import Citation

logger = structlog.get_logger()


async def rag_search_for_citations(
    rag_client: Any, query: str, filters: dict[str, Any] | None = None
) -> list["Citation"]:
    """Search for citations using RAG with PII-free filtering"""

    try:
        # Ensure PII-free filter is always applied
        search_filters = filters or {}
        search_filters["pii_free"] = True

        # This would integrate with the actual RAG retrieval system
        # For now, return a placeholder implementation
        logger.debug(
            "RAG citation search called",
            query=query,
            filters=search_filters,
            rag_client_available=rag_client is not None,
        )

        # Placeholder citations - in production this would call the RAG system
        citations = [
            Citation(
                doc_id=f"RAG-{query.replace(' ', '-')[:20]}",
                locator="Retrieved via RAG search",
                url=f"https://guidance.example.com/search?q={query}",
            )
        ]

        return citations

    except (ConnectionError, TimeoutError) as e:
        logger.error("RAG citation search failed", query=query, error=str(e))
        return []
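Example (not part of this commit): a small usage sketch of the placeholder helper. Since the current implementation does not call a real client, passing rag_client=None is enough; given the query below, the single returned Citation carries doc_id "RAG-rent-a-room-relief".

import asyncio

from libs.rag import rag_search_for_citations

citations = asyncio.run(
    rag_search_for_citations(
        rag_client=None,  # the placeholder implementation never touches the client
        query="rent a room relief",
        filters={"jurisdiction": "UK"},  # pii_free=True is added automatically
    )
)
print(citations)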