# FILE: retrieval/indexer.py
# De-identify -> embed dense/sparse -> upsert to Qdrant with payload
import hashlib
import json
import logging
import re
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any

import numpy as np
import spacy
import torch
import yaml
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    PointStruct,
    SparseIndexParams,
    SparseVector,
    SparseVectorParams,
    VectorParams,
)
from sentence_transformers import SentenceTransformer

from .chunker import DocumentChunker
from .pii_detector import PIIDetector, PIIRedactor


@dataclass
class IndexingResult:
    collection_name: str
    points_indexed: int
    points_updated: int
    points_failed: int
    processing_time: float
    errors: list[str]


class RAGIndexer:
    def __init__(self, config_path: str, qdrant_url: str = "http://localhost:6333"):
        with open(config_path) as f:
            self.config = yaml.safe_load(f)

        self.qdrant_client = QdrantClient(url=qdrant_url)
        self.chunker = DocumentChunker(config_path)
        self.pii_detector = PIIDetector()
        self.pii_redactor = PIIRedactor()

        # Initialize embedding models (default is the full Hugging Face id)
        self.dense_model = SentenceTransformer(
            self.config.get("embedding_model", "BAAI/bge-small-en-v1.5")
        )

        # Initialize sparse model (BM25/SPLADE)
        self.sparse_model = self._init_sparse_model()

        # Initialize NLP pipeline
        self.nlp = spacy.load("en_core_web_sm")

        self.logger = logging.getLogger(__name__)

    def _init_sparse_model(self):
        """Initialize sparse embedding model (BM25 or SPLADE)"""
        sparse_config = self.config.get("sparse_model", {})
        model_type = sparse_config.get("type", "bm25")

        if model_type == "bm25":
            # Placeholder only: the BM25 path in _generate_sparse_embedding builds
            # term-frequency vectors directly and never instantiates this class.
            from rank_bm25 import BM25Okapi

            return BM25Okapi
        elif model_type == "splade":
            from transformers import AutoModelForMaskedLM, AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(
                "naver/splade-cocondenser-ensembledistil"
            )
            model = AutoModelForMaskedLM.from_pretrained(
                "naver/splade-cocondenser-ensembledistil"
            )
            return {"tokenizer": tokenizer, "model": model}
        else:
            raise ValueError(f"Unsupported sparse model type: {model_type}")

    async def index_document(
        self, document_path: str, collection_name: str, metadata: dict[str, Any]
    ) -> IndexingResult:
        """Index a single document into the specified collection"""
        start_time = datetime.now()
        errors = []
        points_indexed = 0
        points_updated = 0
        points_failed = 0

        try:
            # Step 1: Chunk the document
            chunks = await self.chunker.chunk_document(document_path, metadata)

            # Step 2: Process each chunk
            points = []
            for chunk in chunks:
                try:
                    point = await self._process_chunk(chunk, collection_name, metadata)
                    if point:
                        points.append(point)
                except Exception as e:
                    self.logger.error(
                        f"Failed to process chunk {chunk.get('id', 'unknown')}: {str(e)}"
                    )
                    errors.append(f"Chunk processing error: {str(e)}")
                    points_failed += 1

            # Step 3: Upsert to Qdrant
            if points:
                try:
                    self.qdrant_client.upsert(
                        collection_name=collection_name, points=points, wait=True
                    )
                    points_indexed = len(points)
                    self.logger.info(
                        f"Indexed {points_indexed} points to {collection_name}"
                    )
                except Exception as e:
                    self.logger.error(f"Failed to upsert to Qdrant: {str(e)}")
                    errors.append(f"Qdrant upsert error: {str(e)}")
                    points_failed += len(points)
                    points_indexed = 0

        except Exception as e:
            self.logger.error(f"Document indexing failed: {str(e)}")
            errors.append(f"Document indexing error: {str(e)}")

        processing_time = (datetime.now() - start_time).total_seconds()

        return IndexingResult(
            collection_name=collection_name,
            points_indexed=points_indexed,
            points_updated=points_updated,
            points_failed=points_failed,
            processing_time=processing_time,
            errors=errors,
        )

    async def _process_chunk(
        self, chunk: dict[str, Any], collection_name: str, base_metadata: dict[str, Any]
    ) -> PointStruct | None:
        """Process a single chunk: de-identify, embed, create point"""
        # Step 1: De-identify PII
        content = chunk["content"]
        pii_detected = self.pii_detector.detect(content)

        if pii_detected:
            # Redact PII and create mapping
            redacted_content, pii_mapping = self.pii_redactor.redact(
                content, pii_detected
            )

            # Store PII mapping securely (not in vector DB)
            await self._store_pii_mapping(chunk["id"], pii_mapping)

            # Log PII detection for audit
            self.logger.warning(
                f"PII detected in chunk {chunk['id']}: {[p['type'] for p in pii_detected]}"
            )
        else:
            redacted_content = content

        # Verify no PII remains
        if not self._verify_pii_free(redacted_content):
            self.logger.error(f"PII verification failed for chunk {chunk['id']}")
            return None

        # Step 2: Generate embeddings
        try:
            dense_vector = await self._generate_dense_embedding(redacted_content)
            sparse_vector = await self._generate_sparse_embedding(redacted_content)
        except Exception as e:
            self.logger.error(
                f"Embedding generation failed for chunk {chunk['id']}: {str(e)}"
            )
            return None

        # Step 3: Prepare metadata
        payload = self._prepare_payload(chunk, base_metadata, redacted_content)
        payload["pii_free"] = True  # Verified above

        # Step 4: Create point
        point = PointStruct(
            id=chunk["id"],
            vector={"dense": dense_vector, "sparse": sparse_vector},
            payload=payload,
        )

        return point

    async def _generate_dense_embedding(self, text: str) -> list[float]:
        """Generate dense vector embedding"""
        try:
            # Use sentence transformer for dense embeddings
            embedding = self.dense_model.encode(text, normalize_embeddings=True)
            return embedding.tolist()
        except Exception as e:
            self.logger.error(f"Dense embedding generation failed: {str(e)}")
            raise

    async def _generate_sparse_embedding(self, text: str) -> SparseVector:
        """Generate sparse vector embedding (BM25 or SPLADE)"""
        vector = SparseVector(indices=[], values=[])

        try:
            sparse_config = self.config.get("sparse_model", {})
            model_type = sparse_config.get("type", "bm25")

            if model_type == "bm25":
                # Simple BM25-style sparse representation
                doc = self.nlp(text)
                tokens = [
                    token.lemma_.lower()
                    for token in doc
                    if not token.is_stop and not token.is_punct
                ]

                # Create term frequency vector
                term_freq = {}
                for token in tokens:
                    term_freq[token] = term_freq.get(token, 0) + 1

                # Convert to sparse vector format by hashing terms into a fixed
                # vocabulary. Use a stable hash (built-in hash() is salted per
                # process) and merge colliding terms into a single index.
                vocab_size = sparse_config.get("vocab_size", 30000)
                index_values: dict[int, float] = {}
                for term, freq in term_freq.items():
                    term_id = int(hashlib.md5(term.encode("utf-8")).hexdigest(), 16) % vocab_size
                    index_values[term_id] = index_values.get(term_id, 0.0) + float(freq)

                vector = SparseVector(
                    indices=list(index_values.keys()),
                    values=list(index_values.values()),
                )

            elif model_type == "splade":
                # SPLADE sparse embeddings
                tokenizer = self.sparse_model["tokenizer"]
                model = self.sparse_model["model"]

                inputs = tokenizer(
                    text, return_tensors="pt", truncation=True, max_length=512
                )
                with torch.no_grad():
                    outputs = model(**inputs)

                # SPLADE pooling: max over the token dimension of log(1 + ReLU(logits)),
                # giving one non-negative weight per vocabulary term
                logits = outputs.logits.squeeze(0)  # (seq_len, vocab_size)
                sparse_rep = torch.log1p(torch.relu(logits)).max(dim=0).values.numpy()

                # Convert to sparse format
                indices = np.nonzero(sparse_rep)[0].tolist()
                values = sparse_rep[indices].tolist()

                vector = SparseVector(indices=indices, values=values)

            return vector

        except Exception as e:
            self.logger.error(f"Sparse embedding generation failed: {str(e)}")
            # Return empty sparse vector as fallback
            return vector

    def _prepare_payload(
        self, chunk: dict[str, Any], base_metadata: dict[str, Any], content: str
    ) -> dict[str, Any]:
        """Prepare payload metadata for the chunk"""
        # Start with base metadata
        payload = base_metadata.copy()

        # Add chunk-specific metadata
        payload.update(
            {
                "document_id": chunk.get("document_id"),
                "content": content,  # De-identified content
                "chunk_index": chunk.get("chunk_index", 0),
                "total_chunks": chunk.get("total_chunks", 1),
                "page_numbers": chunk.get("page_numbers", []),
                "section_hierarchy": chunk.get("section_hierarchy", []),
                "has_calculations": self._detect_calculations(content),
                "has_forms": self._detect_form_references(content),
                "confidence_score": chunk.get("confidence_score", 1.0),
                "created_at": datetime.now().isoformat(),
                "version": self.config.get("version", "1.0"),
            }
        )

        # Extract and add topic tags
        topic_tags = self._extract_topic_tags(content)
        if topic_tags:
            payload["topic_tags"] = topic_tags

        # Add content analysis
        payload.update(self._analyze_content(content))

        return payload

    def _detect_calculations(self, text: str) -> bool:
        """Detect if text contains calculations or formulas"""
        calculation_patterns = [
            r"\d+\s*[+\-*/]\s*\d+",
            r"£\d+(?:,\d{3})*(?:\.\d{2})?",
            r"\d+(?:\.\d+)?%",
            r"total|sum|calculate|compute",
            r"rate|threshold|allowance|relief",
        ]

        for pattern in calculation_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                return True
        return False

    def _detect_form_references(self, text: str) -> bool:
        """Detect references to tax forms"""
        form_patterns = [
            r"SA\d{3}",
            r"P\d{2}",
            r"CT\d{3}",
            r"VAT\d{3}",
            r"form\s+\w+",
            r"schedule\s+\w+",
        ]

        for pattern in form_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                return True
        return False

    def _extract_topic_tags(self, text: str) -> list[str]:
        """Extract topic tags from content"""
        topic_keywords = {
            "employment": [
                "PAYE",
                "payslip",
                "P60",
                "employment",
                "salary",
                "wages",
                "employer",
            ],
            "self_employment": [
                "self-employed",
                "business",
                "turnover",
                "expenses",
                "profit",
                "loss",
            ],
            "property": ["rental", "property", "landlord", "FHL", "mortgage", "rent"],
            "dividends": ["dividend", "shares", "distribution", "corporation tax"],
            "capital_gains": ["capital gains", "disposal", "acquisition", "CGT"],
            "pensions": ["pension", "retirement", "SIPP", "occupational"],
            "savings": ["interest", "savings", "ISA", "bonds"],
            "inheritance": ["inheritance", "IHT", "estate", "probate"],
            "vat": ["VAT", "value added tax", "registration", "return"],
        }

        tags = []
        text_lower = text.lower()

        for topic, keywords in topic_keywords.items():
            for keyword in keywords:
                if keyword.lower() in text_lower:
                    tags.append(topic)
                    break

        return list(set(tags))  # Remove duplicates

    def _analyze_content(self, text: str) -> dict[str, Any]:
        """Analyze content for additional metadata"""
        doc = self.nlp(text)

        return {
            "word_count": len([token for token in doc if not token.is_space]),
            "sentence_count": len(list(doc.sents)),
            "entity_count": len(doc.ents),
            "complexity_score": self._calculate_complexity(doc),
            "language": doc.lang_ if hasattr(doc, "lang_") else "en",
        }

    def _calculate_complexity(self, doc: spacy.tokens.Doc) -> float:
        """Calculate text complexity score"""
        sents = list(doc.sents)
        if not doc or not sents:
            return 0.0

        # Simple complexity based on sentence length and vocabulary
        avg_sentence_length = sum(len(sent) for sent in sents) / len(sents)
        unique_words = len(set(token.lemma_.lower() for token in doc if token.is_alpha))
        total_words = len([token for token in doc if token.is_alpha])
        vocabulary_diversity = unique_words / total_words if total_words > 0 else 0

        # Normalize to 0-1 scale
        complexity = min(1.0, (avg_sentence_length / 20.0 + vocabulary_diversity) / 2.0)
        return complexity

    def _verify_pii_free(self, text: str) -> bool:
        """Verify that text contains no PII"""
        # Quick verification using patterns
        pii_patterns = [
            r"\b[A-Z]{2}\d{6}[A-D]\b",  # NI number
            r"\b\d{10}\b",  # UTR
            r"\b[A-Z]{2}\d{2}[A-Z]{4}\d{14}\b",  # IBAN
            r"\b\d{2}-\d{2}-\d{2}\b",  # Sort code
            r"\b[A-Z]{1,2}\d[A-Z\d]?\s*\d[A-Z]{2}\b",  # Postcode
            r"\b[\w\.-]+@[\w\.-]+\.\w+\b",  # Email
            r"\b(?:\+44|0)\d{10,11}\b",  # Phone
        ]

        for pattern in pii_patterns:
            if re.search(pattern, text):
                return False
        return True

    async def _store_pii_mapping(
        self, chunk_id: str, pii_mapping: dict[str, Any]
    ) -> None:
        """Store PII mapping in secure client data store (not in vector DB)"""
        # This would integrate with the secure PostgreSQL client data store
        # For now, just log the mapping securely
        self.logger.info(
            f"PII mapping stored for chunk {chunk_id}: {len(pii_mapping)} items"
        )

    async def create_collections(self) -> None:
        """Create all Qdrant collections based on configuration"""
        collections_config_path = Path(__file__).parent / "qdrant_collections.json"

        with open(collections_config_path) as f:
            collections_config = json.load(f)

        for collection_config in collections_config["collections"]:
            collection_name = collection_config["name"]

            try:
                # Check if collection exists
                try:
                    self.qdrant_client.get_collection(collection_name)
                    self.logger.info(f"Collection {collection_name} already exists")
                    continue
                except Exception:
                    pass  # Collection doesn't exist, create it

                # Create collection
                vectors_config = {}

                # Dense vector configuration
                if "dense" in collection_config:
                    vectors_config["dense"] = VectorParams(
                        size=collection_config["dense"]["size"],
                        distance=Distance.COSINE,
                    )

                # Sparse vector configuration: sparse vectors are declared
                # separately from dense ones and always use dot-product scoring,
                # so no size or distance is specified here
                sparse_vectors_config = None
                if collection_config.get("sparse", False):
                    sparse_vectors_config = {
                        "sparse": SparseVectorParams(
                            index=SparseIndexParams(on_disk=True)
                        )
                    }

                self.qdrant_client.create_collection(
                    collection_name=collection_name,
                    vectors_config=vectors_config,
                    sparse_vectors_config=sparse_vectors_config,
                    **collection_config.get("indexing_config", {}),
                )

                self.logger.info(f"Created collection: {collection_name}")

            except Exception as e:
                self.logger.error(
                    f"Failed to create collection {collection_name}: {str(e)}"
                )
                raise

    async def batch_index(
        self, documents: list[dict[str, Any]], collection_name: str
    ) -> list[IndexingResult]:
        """Index multiple documents in batch"""
        results = []

        for doc_info in documents:
            result = await self.index_document(
                doc_info["path"], collection_name, doc_info["metadata"]
            )
            results.append(result)

        return results

    def get_collection_stats(self, collection_name: str) -> dict[str, Any]:
        """Get statistics for a collection"""
        try:
            collection_info = self.qdrant_client.get_collection(collection_name)

            return {
                "name": collection_name,
                "vectors_count": collection_info.vectors_count,
                "indexed_vectors_count": collection_info.indexed_vectors_count,
                "points_count": collection_info.points_count,
                "segments_count": collection_info.segments_count,
                "status": collection_info.status,
            }
        except Exception as e:
            self.logger.error(f"Failed to get stats for {collection_name}: {str(e)}")
            return {"error": str(e)}
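

# Illustrative usage sketch (not part of the pipeline): shows how the indexer
# might be driven end to end, assuming a running Qdrant instance. The config
# path, collection name, document path, and metadata below are placeholder
# values, not names defined elsewhere in this repository.
if __name__ == "__main__":
    import asyncio

    async def _example() -> None:
        indexer = RAGIndexer(config_path="retrieval/config.yaml")  # placeholder config

        # Ensure the collections described in qdrant_collections.json exist
        await indexer.create_collections()

        # Index a single document with caller-supplied metadata
        result = await indexer.index_document(
            document_path="docs/sample_guidance.pdf",  # placeholder path
            collection_name="tax_guidance",  # placeholder collection
            metadata={"source": "example", "jurisdiction": "UK"},
        )
        print(
            f"indexed={result.points_indexed} failed={result.points_failed} "
            f"time={result.processing_time:.2f}s errors={len(result.errors)}"
        )

    asyncio.run(_example())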