# FILE: retrieval/indexer.py
# De-identify -> embed dense/sparse -> upsert to Qdrant with payload
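"""De-identification, dense/sparse embedding, and Qdrant upsert for RAG indexing.

Each document is chunked, scrubbed of PII, embedded with a dense sentence
transformer plus a sparse (BM25-style or SPLADE) representation, and upserted
into a Qdrant collection together with a metadata payload.
"""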
import json
import logging
import re
import zlib
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any

import numpy as np
import spacy
import torch
import yaml
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    PointStruct,
    SparseIndexParams,
    SparseVector,
    SparseVectorParams,
    VectorParams,
)
from sentence_transformers import SentenceTransformer
from spacy.tokens import Doc

from .chunker import DocumentChunker
from .pii_detector import PIIDetector, PIIRedactor


@dataclass
class IndexingResult:
    collection_name: str
    points_indexed: int
    points_updated: int
    points_failed: int
    processing_time: float
    errors: list[str]


class RAGIndexer:
    """De-identify document chunks, embed them dense/sparse, and upsert to Qdrant."""

    def __init__(self, config_path: str, qdrant_url: str = "http://localhost:6333"):
        with open(config_path) as f:
            self.config = yaml.safe_load(f)
        self.qdrant_client = QdrantClient(url=qdrant_url)
        self.chunker = DocumentChunker(config_path)
        self.pii_detector = PIIDetector()
        self.pii_redactor = PIIRedactor()
        # Initialize embedding models (full Hugging Face id so the default resolves)
        self.dense_model = SentenceTransformer(
            self.config.get("embedding_model", "BAAI/bge-small-en-v1.5")
        )
        # Initialize sparse model (BM25/SPLADE)
        self.sparse_model = self._init_sparse_model()
        # Initialize NLP pipeline
        self.nlp = spacy.load("en_core_web_sm")
        self.logger = logging.getLogger(__name__)

    def _init_sparse_model(self):
        """Initialize sparse embedding model (BM25 or SPLADE)"""
        sparse_config = self.config.get("sparse_model", {})
        model_type = sparse_config.get("type", "bm25")
        if model_type == "bm25":
            from rank_bm25 import BM25Okapi

            return BM25Okapi
        elif model_type == "splade":
            from transformers import AutoModelForMaskedLM, AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(
                "naver/splade-cocondenser-ensembledistil"
            )
            model = AutoModelForMaskedLM.from_pretrained(
                "naver/splade-cocondenser-ensembledistil"
            )
            return {"tokenizer": tokenizer, "model": model}
        else:
            raise ValueError(f"Unsupported sparse model type: {model_type}")

    async def index_document(
        self, document_path: str, collection_name: str, metadata: dict[str, Any]
    ) -> IndexingResult:
        """Index a single document into the specified collection"""
        start_time = datetime.now()
        errors = []
        points_indexed = 0
        points_updated = 0
        points_failed = 0
        try:
            # Step 1: Chunk the document
            chunks = await self.chunker.chunk_document(document_path, metadata)
            # Step 2: Process each chunk
            points = []
            for chunk in chunks:
                try:
                    point = await self._process_chunk(chunk, collection_name, metadata)
                    if point:
                        points.append(point)
                except Exception as e:
                    self.logger.error(
                        f"Failed to process chunk {chunk.get('id', 'unknown')}: {str(e)}"
                    )
                    errors.append(f"Chunk processing error: {str(e)}")
                    points_failed += 1
            # Step 3: Upsert to Qdrant
            if points:
                try:
                    self.qdrant_client.upsert(
                        collection_name=collection_name, points=points, wait=True
                    )
                    points_indexed = len(points)
                    self.logger.info(
                        f"Indexed {points_indexed} points to {collection_name}"
                    )
                except Exception as e:
                    self.logger.error(f"Failed to upsert to Qdrant: {str(e)}")
                    errors.append(f"Qdrant upsert error: {str(e)}")
                    points_failed += len(points)
                    points_indexed = 0
        except Exception as e:
            self.logger.error(f"Document indexing failed: {str(e)}")
            errors.append(f"Document indexing error: {str(e)}")
        processing_time = (datetime.now() - start_time).total_seconds()
        return IndexingResult(
            collection_name=collection_name,
            points_indexed=points_indexed,
            points_updated=points_updated,
            points_failed=points_failed,
            processing_time=processing_time,
            errors=errors,
        )

    async def _process_chunk(
        self, chunk: dict[str, Any], collection_name: str, base_metadata: dict[str, Any]
    ) -> PointStruct | None:
        """Process a single chunk: de-identify, embed, create point"""
        # Step 1: De-identify PII
        content = chunk["content"]
        pii_detected = self.pii_detector.detect(content)
        if pii_detected:
            # Redact PII and create mapping
            redacted_content, pii_mapping = self.pii_redactor.redact(
                content, pii_detected
            )
            # Store PII mapping securely (not in vector DB)
            await self._store_pii_mapping(chunk["id"], pii_mapping)
            # Log PII detection for audit
            self.logger.warning(
                f"PII detected in chunk {chunk['id']}: {[p['type'] for p in pii_detected]}"
            )
        else:
            redacted_content = content
        # Verify no PII remains
        if not self._verify_pii_free(redacted_content):
            self.logger.error(f"PII verification failed for chunk {chunk['id']}")
            return None
        # Step 2: Generate embeddings
        try:
            dense_vector = await self._generate_dense_embedding(redacted_content)
            sparse_vector = await self._generate_sparse_embedding(redacted_content)
        except Exception as e:
            self.logger.error(
                f"Embedding generation failed for chunk {chunk['id']}: {str(e)}"
            )
            return None
        # Step 3: Prepare metadata
        payload = self._prepare_payload(chunk, base_metadata, redacted_content)
        payload["pii_free"] = True  # Verified above
        # Step 4: Create point
        point = PointStruct(
            id=chunk["id"],
            vector={"dense": dense_vector, "sparse": sparse_vector},
            payload=payload,
        )
        return point

    async def _generate_dense_embedding(self, text: str) -> list[float]:
        """Generate dense vector embedding"""
        try:
            # Use sentence transformer for dense embeddings
            embedding = self.dense_model.encode(text, normalize_embeddings=True)
            return embedding.tolist()
        except Exception as e:
            self.logger.error(f"Dense embedding generation failed: {str(e)}")
            raise

    async def _generate_sparse_embedding(self, text: str) -> SparseVector:
        """Generate sparse vector embedding (BM25 or SPLADE)"""
        vector = SparseVector(indices=[], values=[])
        try:
            sparse_config = self.config.get("sparse_model", {})
            model_type = sparse_config.get("type", "bm25")
            if model_type == "bm25":
                # Simple BM25-style sparse representation
                doc = self.nlp(text)
                tokens = [
                    token.lemma_.lower()
                    for token in doc
                    if not token.is_stop and not token.is_punct
                ]
                # Create term frequency vector
                term_freq = {}
                for token in tokens:
                    term_freq[token] = term_freq.get(token, 0) + 1
                # Convert to sparse vector format
                vocab_size = sparse_config.get("vocab_size", 30000)
                indices = []
                values = []
                for term, freq in term_freq.items():
                    # Stable hash-based vocabulary mapping (built-in hash() is salted
                    # per process, which would break query-time term matching)
                    term_id = zlib.crc32(term.encode("utf-8")) % vocab_size
                    indices.append(term_id)
                    values.append(float(freq))
                vector = SparseVector(indices=indices, values=values)
            elif model_type == "splade":
                # SPLADE sparse embeddings
                tokenizer = self.sparse_model["tokenizer"]
                model = self.sparse_model["model"]
                inputs = tokenizer(
                    text, return_tensors="pt", truncation=True, max_length=512
                )
                with torch.no_grad():
                    outputs = model(**inputs)
                # Extract sparse representation: pool log(1 + ReLU(logits)) over the
                # token dimension so one weight per vocabulary term remains
                logits = outputs.logits.squeeze(0)  # (seq_len, vocab_size)
                sparse_rep = torch.log1p(torch.relu(logits)).amax(dim=0).cpu().numpy()
                # Convert to sparse format
                indices = np.nonzero(sparse_rep)[0].tolist()
                values = sparse_rep[indices].tolist()
                vector = SparseVector(indices=indices, values=values)
            return vector
        except Exception as e:
            self.logger.error(f"Sparse embedding generation failed: {str(e)}")
            # Return empty sparse vector as fallback
            return vector

    def _prepare_payload(
        self, chunk: dict[str, Any], base_metadata: dict[str, Any], content: str
    ) -> dict[str, Any]:
        """Prepare payload metadata for the chunk"""
        # Start with base metadata
        payload = base_metadata.copy()
        # Add chunk-specific metadata
        payload.update(
            {
                "document_id": chunk.get("document_id"),
                "content": content,  # De-identified content
                "chunk_index": chunk.get("chunk_index", 0),
                "total_chunks": chunk.get("total_chunks", 1),
                "page_numbers": chunk.get("page_numbers", []),
                "section_hierarchy": chunk.get("section_hierarchy", []),
                "has_calculations": self._detect_calculations(content),
                "has_forms": self._detect_form_references(content),
                "confidence_score": chunk.get("confidence_score", 1.0),
                "created_at": datetime.now().isoformat(),
                "version": self.config.get("version", "1.0"),
            }
        )
        # Extract and add topic tags
        topic_tags = self._extract_topic_tags(content)
        if topic_tags:
            payload["topic_tags"] = topic_tags
        # Add content analysis
        payload.update(self._analyze_content(content))
        return payload

    def _detect_calculations(self, text: str) -> bool:
        """Detect if text contains calculations or formulas"""
        calculation_patterns = [
            r"\d+\s*[+\-*/]\s*\d+",
            r"£\d+(?:,\d{3})*(?:\.\d{2})?",
            r"\d+(?:\.\d+)?%",
            r"total|sum|calculate|compute",
            r"rate|threshold|allowance|relief",
        ]
        for pattern in calculation_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                return True
        return False

    def _detect_form_references(self, text: str) -> bool:
        """Detect references to tax forms"""
        form_patterns = [
            r"SA\d{3}",
            r"P\d{2}",
            r"CT\d{3}",
            r"VAT\d{3}",
            r"form\s+\w+",
            r"schedule\s+\w+",
        ]
        for pattern in form_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                return True
        return False

    def _extract_topic_tags(self, text: str) -> list[str]:
        """Extract topic tags from content"""
        topic_keywords = {
            "employment": [
                "PAYE",
                "payslip",
                "P60",
                "employment",
                "salary",
                "wages",
                "employer",
            ],
            "self_employment": [
                "self-employed",
                "business",
                "turnover",
                "expenses",
                "profit",
                "loss",
            ],
            "property": ["rental", "property", "landlord", "FHL", "mortgage", "rent"],
            "dividends": ["dividend", "shares", "distribution", "corporation tax"],
            "capital_gains": ["capital gains", "disposal", "acquisition", "CGT"],
            "pensions": ["pension", "retirement", "SIPP", "occupational"],
            "savings": ["interest", "savings", "ISA", "bonds"],
            "inheritance": ["inheritance", "IHT", "estate", "probate"],
            "vat": ["VAT", "value added tax", "registration", "return"],
        }
        tags = []
        text_lower = text.lower()
        for topic, keywords in topic_keywords.items():
            for keyword in keywords:
                if keyword.lower() in text_lower:
                    tags.append(topic)
                    break
        return list(set(tags))  # Remove duplicates

    def _analyze_content(self, text: str) -> dict[str, Any]:
        """Analyze content for additional metadata"""
        doc = self.nlp(text)
        return {
            "word_count": len([token for token in doc if not token.is_space]),
            "sentence_count": len(list(doc.sents)),
            "entity_count": len(doc.ents),
            "complexity_score": self._calculate_complexity(doc),
            "language": doc.lang_ if hasattr(doc, "lang_") else "en",
        }

    def _calculate_complexity(self, doc: Doc) -> float:
        """Calculate text complexity score"""
        sents = list(doc.sents)
        if not doc or not sents:
            return 0.0
        # Simple complexity based on sentence length and vocabulary
        avg_sentence_length = sum(len(sent) for sent in sents) / len(sents)
        unique_words = len(set(token.lemma_.lower() for token in doc if token.is_alpha))
        total_words = len([token for token in doc if token.is_alpha])
        vocabulary_diversity = unique_words / total_words if total_words > 0 else 0
        # Normalize to 0-1 scale
        complexity = min(1.0, (avg_sentence_length / 20.0 + vocabulary_diversity) / 2.0)
        return complexity

    def _verify_pii_free(self, text: str) -> bool:
        """Verify that text contains no PII"""
        # Quick verification using patterns
        pii_patterns = [
            r"\b[A-Z]{2}\d{6}[A-D]\b",  # NI number
            r"\b\d{10}\b",  # UTR
            r"\b[A-Z]{2}\d{2}[A-Z]{4}\d{14}\b",  # IBAN
            r"\b\d{2}-\d{2}-\d{2}\b",  # Sort code
            r"\b[A-Z]{1,2}\d[A-Z\d]?\s*\d[A-Z]{2}\b",  # Postcode
            r"\b[\w\.-]+@[\w\.-]+\.\w+\b",  # Email
            r"\b(?:\+44|0)\d{10,11}\b",  # Phone
        ]
        for pattern in pii_patterns:
            if re.search(pattern, text):
                return False
        return True

    async def _store_pii_mapping(
        self, chunk_id: str, pii_mapping: dict[str, Any]
    ) -> None:
        """Store PII mapping in secure client data store (not in vector DB)"""
        # This would integrate with the secure PostgreSQL client data store
        # For now, just log the mapping securely
        self.logger.info(
            f"PII mapping stored for chunk {chunk_id}: {len(pii_mapping)} items"
        )

    async def create_collections(self) -> None:
        """Create all Qdrant collections based on configuration"""
        collections_config_path = Path(__file__).parent / "qdrant_collections.json"
        with open(collections_config_path) as f:
            collections_config = json.load(f)
        for collection_config in collections_config["collections"]:
            collection_name = collection_config["name"]
            try:
                # Check if collection exists
                try:
                    self.qdrant_client.get_collection(collection_name)
                    self.logger.info(f"Collection {collection_name} already exists")
                    continue
                except Exception:
                    pass  # Collection doesn't exist, create it
                # Create collection
                vectors_config = {}
                # Dense vector configuration
                if "dense" in collection_config:
                    vectors_config["dense"] = VectorParams(
                        size=collection_config["dense"]["size"],
                        distance=Distance.COSINE,
                    )
                # Sparse vector configuration: named sparse vectors are declared via
                # sparse_vectors_config and need no fixed size or distance metric
                sparse_vectors_config = None
                if collection_config.get("sparse", False):
                    sparse_vectors_config = {
                        "sparse": SparseVectorParams(
                            index=SparseIndexParams(on_disk=True)
                        )
                    }
                self.qdrant_client.create_collection(
                    collection_name=collection_name,
                    vectors_config=vectors_config,
                    sparse_vectors_config=sparse_vectors_config,
                    **collection_config.get("indexing_config", {}),
                )
                self.logger.info(f"Created collection: {collection_name}")
            except Exception as e:
                self.logger.error(
                    f"Failed to create collection {collection_name}: {str(e)}"
                )
                raise

    async def batch_index(
        self, documents: list[dict[str, Any]], collection_name: str
    ) -> list[IndexingResult]:
        """Index multiple documents in batch"""
        results = []
        for doc_info in documents:
            result = await self.index_document(
                doc_info["path"], collection_name, doc_info["metadata"]
            )
            results.append(result)
        return results

    def get_collection_stats(self, collection_name: str) -> dict[str, Any]:
        """Get statistics for a collection"""
        try:
            collection_info = self.qdrant_client.get_collection(collection_name)
            return {
                "name": collection_name,
                "vectors_count": collection_info.vectors_count,
                "indexed_vectors_count": collection_info.indexed_vectors_count,
                "points_count": collection_info.points_count,
                "segments_count": collection_info.segments_count,
                "status": collection_info.status,
            }
        except Exception as e:
            self.logger.error(f"Failed to get stats for {collection_name}: {str(e)}")
            return {"error": str(e)}