# FILE: apps/svc-rag-indexer/main.py
# mypy: disable-error-code=union-attr
# Vector database indexing with PII protection and de-identification
import os
import sys
from datetime import datetime
from typing import Any

import structlog
import ulid
from fastapi import BackgroundTasks, Depends, HTTPException, Request
from fastapi.responses import JSONResponse

# Import shared libraries
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from libs.app_factory import create_app
from libs.config import BaseAppSettings, create_event_bus, create_qdrant_client
from libs.events import EventBus, EventPayload, EventTopics
from libs.observability import get_metrics, get_tracer, setup_observability
from libs.rag import PIIDetector, QdrantCollectionManager
from libs.schemas import ErrorResponse
from libs.security import get_current_user, get_tenant_id

logger = structlog.get_logger()


class RAGIndexerSettings(BaseAppSettings):
    """Settings for RAG indexer service"""

    service_name: str = "svc-rag-indexer"

    # Embedding configuration
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    embedding_dimension: int = 384

    # Chunking configuration
    chunk_size: int = 512
    chunk_overlap: int = 50

    # Collection configuration
    collections: dict[str, str] = {
        "documents": "Document chunks with metadata",
        "tax_rules": "Tax rules and regulations",
        "case_law": "Tax case law and precedents",
        "guidance": "HMRC guidance and manuals",
    }

    # PII protection
    require_pii_free: bool = True
    auto_deidentify: bool = True


# Create app and settings
app, settings = create_app(
    service_name="svc-rag-indexer",
    title="Tax Agent RAG Indexer Service",
    description="Vector database indexing with PII protection",
    settings_class=RAGIndexerSettings,
)

# Global clients
qdrant_client = None
collection_manager: QdrantCollectionManager | None = None
pii_detector: PIIDetector | None = None
event_bus: EventBus | None = None
embedding_model = None

tracer = get_tracer("svc-rag-indexer")
metrics = get_metrics()


@app.on_event("startup")
async def startup_event() -> None:
    """Initialize service dependencies"""
    global qdrant_client, collection_manager, pii_detector, event_bus, embedding_model

    logger.info("Starting RAG indexer service")

    # Setup observability
    setup_observability(settings)

    # Initialize Qdrant client
    qdrant_client = create_qdrant_client(settings)
    collection_manager = QdrantCollectionManager(qdrant_client)

    # Initialize PII detector
    pii_detector = PIIDetector()

    # Initialize embedding model
    try:
        from sentence_transformers import SentenceTransformer

        embedding_model = SentenceTransformer(settings.embedding_model)
        logger.info("Embedding model loaded", model=settings.embedding_model)
    except ImportError:
        logger.warning("sentence-transformers not available, using mock embeddings")
        embedding_model = None

    # Initialize event bus
    event_bus = create_event_bus(settings)
    await event_bus.start()

    # Subscribe to relevant events
    await event_bus.subscribe(EventTopics.DOC_EXTRACTED, _handle_document_extracted)  # type: ignore
    await event_bus.subscribe(EventTopics.KG_UPSERTED, _handle_kg_upserted)  # type: ignore

    # Ensure collections exist
    for collection_name in settings.collections:
        await collection_manager.ensure_collection(
            collection_name=collection_name, vector_size=settings.embedding_dimension
        )

    logger.info("RAG indexer service started successfully")


@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Cleanup service dependencies"""
    global event_bus

    logger.info("Shutting down RAG indexer service")

    if event_bus:
        await event_bus.stop()

    logger.info("RAG indexer service shutdown complete")
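# NOTE: QdrantCollectionManager.ensure_collection (used in startup_event above) lives
# in libs.rag and is not shown here. A minimal sketch of what such a helper typically
# does with the qdrant-client API (create the collection only if it does not already
# exist, with cosine distance), assuming the shared library behaves the same way:
#
#     from qdrant_client import QdrantClient
#     from qdrant_client.models import Distance, VectorParams
#
#     def ensure_collection(client: QdrantClient, name: str, vector_size: int) -> None:
#         existing = {c.name for c in client.get_collections().collections}
#         if name not in existing:
#             client.create_collection(
#                 collection_name=name,
#                 vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
#             )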
logger.info("RAG indexer service shutdown complete") @app.get("/health") async def health_check() -> dict[str, Any]: """Health check endpoint""" return { "status": "healthy", "service": settings.service_name, "version": settings.service_version, "timestamp": datetime.utcnow().isoformat(), "collections": list(settings.collections.keys()), } @app.post("/index/{collection_name}") async def index_document( collection_name: str, document: dict[str, Any], background_tasks: BackgroundTasks, current_user: dict[str, Any] = Depends(get_current_user), tenant_id: str = Depends(get_tenant_id), ): """Index document in vector database""" with tracer.start_as_current_span("index_document") as span: span.set_attribute("collection_name", collection_name) span.set_attribute("tenant_id", tenant_id) try: # Validate collection if collection_name not in settings.collections: raise HTTPException( status_code=400, detail=f"Unknown collection: {collection_name}" ) # Generate indexing ID indexing_id = str(ulid.new()) span.set_attribute("indexing_id", indexing_id) # Start background indexing background_tasks.add_task( _index_document_async, collection_name, document, tenant_id, indexing_id, current_user.get("sub", "system"), ) logger.info( "Document indexing started", collection=collection_name, indexing_id=indexing_id, ) return { "indexing_id": indexing_id, "collection": collection_name, "status": "indexing", } except HTTPException: raise except Exception as e: logger.error( "Failed to start indexing", collection=collection_name, error=str(e) ) raise HTTPException(status_code=500, detail="Failed to start indexing") @app.get("/collections") async def list_collections( current_user: dict[str, Any] = Depends(get_current_user), tenant_id: str = Depends(get_tenant_id), ): """List available collections""" try: collections_info: list[Any] = [] for collection_name, description in settings.collections.items(): # Get collection info from Qdrant try: collection_info = qdrant_client.get_collection(collection_name) point_count = collection_info.points_count vector_count = collection_info.vectors_count except Exception: point_count = 0 vector_count = 0 collections_info.append( { "name": collection_name, "description": description, "point_count": point_count, "vector_count": vector_count, } ) return { "collections": collections_info, "total_collections": len(collections_info), } except Exception as e: logger.error("Failed to list collections", error=str(e)) raise HTTPException(status_code=500, detail="Failed to list collections") async def _handle_document_extracted(topic: str, payload: EventPayload) -> None: """Handle document extraction completion events""" try: data = payload.data doc_id = data.get("doc_id") tenant_id = data.get("tenant_id") extraction_results = data.get("extraction_results") if not doc_id or not tenant_id or not extraction_results: logger.warning("Invalid document extraction event", data=data) return logger.info("Auto-indexing extracted document", doc_id=doc_id) # Create document for indexing document = { "doc_id": doc_id, "content": _extract_content_from_results(extraction_results), "metadata": { "doc_id": doc_id, "tenant_id": tenant_id, "extraction_id": extraction_results.get("extraction_id"), "confidence": extraction_results.get("confidence", 0.0), "extracted_at": extraction_results.get("extracted_at"), "source": "extraction", }, } await _index_document_async( collection_name="documents", document=document, tenant_id=tenant_id, indexing_id=str(ulid.new()), actor=payload.actor, ) except Exception as e: 
logger.error("Failed to handle document extraction event", error=str(e)) async def _handle_kg_upserted(topic: str, payload: EventPayload) -> None: """Handle knowledge graph upsert events""" try: data = payload.data entities = data.get("entities", []) tenant_id = data.get("tenant_id") if not entities or not tenant_id: logger.warning("Invalid KG upsert event", data=data) return logger.info("Auto-indexing KG entities", count=len(entities)) # Index entities as documents for entity in entities: document = { "entity_id": entity.get("id"), "content": _extract_content_from_entity(entity), "metadata": { "entity_type": entity.get("type"), "entity_id": entity.get("id"), "tenant_id": tenant_id, "source": "knowledge_graph", }, } await _index_document_async( collection_name="documents", document=document, tenant_id=tenant_id, indexing_id=str(ulid.new()), actor=payload.actor, ) except Exception as e: logger.error("Failed to handle KG upsert event", error=str(e)) async def _index_document_async( collection_name: str, document: dict[str, Any], tenant_id: str, indexing_id: str, actor: str, ): """Index document asynchronously""" with tracer.start_as_current_span("index_document_async") as span: span.set_attribute("collection_name", collection_name) span.set_attribute("indexing_id", indexing_id) span.set_attribute("tenant_id", tenant_id) try: content = document.get("content", "") metadata = document.get("metadata", {}) # Check for PII and de-identify if needed if settings.require_pii_free: has_pii = pii_detector.has_pii(content) if has_pii: if settings.auto_deidentify: content, pii_mapping = pii_detector.de_identify_text(content) metadata["pii_removed"] = True metadata["pii_mapping_hash"] = _hash_pii_mapping(pii_mapping) logger.info("PII removed from content", indexing_id=indexing_id) else: logger.warning( "Content contains PII, skipping indexing", indexing_id=indexing_id, ) return # Mark as PII-free metadata["pii_free"] = True metadata["tenant_id"] = tenant_id metadata["indexed_at"] = datetime.utcnow().isoformat() # Chunk content chunks = _chunk_text(content) # Generate embeddings and index chunks indexed_chunks = 0 for i, chunk in enumerate(chunks): try: # Generate embedding embedding = await _generate_embedding(chunk) # Create point point_id = f"{indexing_id}_{i}" from qdrant_client.models import PointStruct point = PointStruct( id=point_id, vector=embedding, payload={ **metadata, "chunk_text": chunk, "chunk_index": i, "total_chunks": len(chunks), }, ) # Index point success = await collection_manager.upsert_points( collection_name, [point] ) if success: indexed_chunks += 1 except Exception as e: logger.error("Failed to index chunk", chunk_index=i, error=str(e)) # Update metrics metrics.counter("documents_indexed_total").labels( tenant_id=tenant_id, collection=collection_name ).inc() metrics.histogram("chunks_per_document").labels( collection=collection_name ).observe(indexed_chunks) # Publish completion event event_payload = EventPayload( data={ "indexing_id": indexing_id, "collection": collection_name, "tenant_id": tenant_id, "chunks_indexed": indexed_chunks, "total_chunks": len(chunks), }, actor=actor, tenant_id=tenant_id, ) await event_bus.publish(EventTopics.RAG_INDEXED, event_payload) logger.info( "Document indexing completed", indexing_id=indexing_id, chunks=indexed_chunks, ) except Exception as e: logger.error( "Document indexing failed", indexing_id=indexing_id, error=str(e) ) # Update error metrics metrics.counter("indexing_errors_total").labels( tenant_id=tenant_id, collection=collection_name, 
def _extract_content_from_results(extraction_results: dict[str, Any]) -> str:
    """Extract text content from extraction results"""
    content_parts: list[Any] = []

    # Add extracted fields
    extracted_fields = extraction_results.get("extracted_fields", {})
    for field_name, field_value in extracted_fields.items():
        content_parts.append(f"{field_name}: {field_value}")

    return "\n".join(content_parts)


def _extract_content_from_entity(entity: dict[str, Any]) -> str:
    """Extract text content from KG entity"""
    content_parts: list[Any] = []

    # Add entity type and ID
    entity_type = entity.get("type", "Unknown")
    entity_id = entity.get("id", "")
    content_parts.append(f"Entity Type: {entity_type}")
    content_parts.append(f"Entity ID: {entity_id}")

    # Add properties
    properties = entity.get("properties", {})
    for prop_name, prop_value in properties.items():
        if prop_name not in ["tenant_id", "asserted_at", "retracted_at"]:
            content_parts.append(f"{prop_name}: {prop_value}")

    return "\n".join(content_parts)


def _chunk_text(text: str) -> list[str]:
    """Chunk text into smaller pieces"""
    if not text:
        return []

    # Simple sentence-based chunking bounded by settings.chunk_size.
    # Note: settings.chunk_overlap is not applied by this simple splitter.
    chunks: list[Any] = []
    current_chunk = ""

    sentences = text.split(". ")

    for sentence in sentences:
        if len(current_chunk) + len(sentence) < settings.chunk_size:
            current_chunk += sentence + ". "
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + ". "

    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks


async def _generate_embedding(text: str) -> list[float]:
    """Generate embedding for text"""
    if embedding_model:
        try:
            embedding = embedding_model.encode(text)
            return embedding.tolist()
        except Exception as e:
            logger.error("Failed to generate embedding", error=str(e))

    # Fallback: random embedding (placeholder only, not semantically meaningful)
    import random

    return [random.random() for _ in range(settings.embedding_dimension)]


def _hash_pii_mapping(pii_mapping: dict[str, str]) -> str:
    """Create hash of PII mapping for audit purposes"""
    import hashlib
    import json

    mapping_json = json.dumps(pii_mapping, sort_keys=True)
    return hashlib.sha256(mapping_json.encode()).hexdigest()


@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Handle HTTP exceptions with RFC7807 format"""
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            type=f"https://httpstatuses.com/{exc.status_code}",
            title=exc.detail,
            status=exc.status_code,
            detail=exc.detail,
            instance=str(request.url),
            trace_id="",
        ).model_dump(),
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8006, reload=True, log_config=None)
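# Example local run (a sketch; assumes a reachable Qdrant instance, the event bus, and
# a valid bearer token from the auth setup in libs.security):
#
#     python apps/svc-rag-indexer/main.py
#     curl http://localhost:8006/health
#     curl -H "Authorization: Bearer <token>" http://localhost:8006/collections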