# FILE: apps/svc-rag-indexer/main.py
# mypy: disable-error-code=union-attr
# Vector database indexing with PII protection and de-identification
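# Flow: documents arrive via POST /index/{collection_name} or via DOC_EXTRACTED / KG_UPSERTED
# events, are de-identified when PII is detected, chunked, embedded, upserted into Qdrant,
# and a RAG_INDEXED event is published when indexing completes.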
import os
# Import shared libraries
import sys
from datetime import datetime
from typing import Any
import structlog
import ulid
from fastapi import BackgroundTasks, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from libs.app_factory import create_app
from libs.config import BaseAppSettings, create_event_bus, create_qdrant_client
from libs.events import EventBus, EventPayload, EventTopics
from libs.observability import get_metrics, get_tracer, setup_observability
from libs.rag import PIIDetector, QdrantCollectionManager
from libs.schemas import ErrorResponse
from libs.security import get_current_user, get_tenant_id
logger = structlog.get_logger()
class RAGIndexerSettings(BaseAppSettings):
"""Settings for RAG indexer service"""
service_name: str = "svc-rag-indexer"
# Embedding configuration
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
embedding_dimension: int = 384
# Chunking configuration
chunk_size: int = 512
chunk_overlap: int = 50
# Collection configuration
collections: dict[str, str] = {
"documents": "Document chunks with metadata",
"tax_rules": "Tax rules and regulations",
"case_law": "Tax case law and precedents",
"guidance": "HMRC guidance and manuals",
}
# PII protection
require_pii_free: bool = True
auto_deidentify: bool = True
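# Note: if BaseAppSettings builds on pydantic's BaseSettings (an assumption about the shared
# libs), these fields can be overridden via environment variables such as CHUNK_SIZE or
# EMBEDDING_MODEL (hypothetical variable names; the exact prefix depends on BaseAppSettings).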
# Create app and settings
app, settings = create_app(
service_name="svc-rag-indexer",
title="Tax Agent RAG Indexer Service",
description="Vector database indexing with PII protection",
settings_class=RAGIndexerSettings,
)
# Global clients
qdrant_client = None
collection_manager: QdrantCollectionManager | None = None
pii_detector: PIIDetector | None = None
event_bus: EventBus | None = None
embedding_model = None
tracer = get_tracer("svc-rag-indexer")
metrics = get_metrics()
@app.on_event("startup")
async def startup_event() -> None:
"""Initialize service dependencies"""
global qdrant_client, collection_manager, pii_detector, event_bus, embedding_model
logger.info("Starting RAG indexer service")
# Setup observability
setup_observability(settings)
# Initialize Qdrant client
qdrant_client = create_qdrant_client(settings)
collection_manager = QdrantCollectionManager(qdrant_client)
# Initialize PII detector
pii_detector = PIIDetector()
# Initialize embedding model
try:
from sentence_transformers import SentenceTransformer
embedding_model = SentenceTransformer(settings.embedding_model)
logger.info("Embedding model loaded", model=settings.embedding_model)
except ImportError:
logger.warning("sentence-transformers not available, using mock embeddings")
embedding_model = None
# Initialize event bus
event_bus = create_event_bus(settings)
await event_bus.start()
# Subscribe to relevant events
await event_bus.subscribe(EventTopics.DOC_EXTRACTED, _handle_document_extracted) # type: ignore
await event_bus.subscribe(EventTopics.KG_UPSERTED, _handle_kg_upserted) # type: ignore
# Ensure collections exist
for collection_name in settings.collections:
await collection_manager.ensure_collection(
collection_name=collection_name, vector_size=settings.embedding_dimension
)
logger.info("RAG indexer service started successfully")
@app.on_event("shutdown")
async def shutdown_event() -> None:
"""Cleanup service dependencies"""
global event_bus
logger.info("Shutting down RAG indexer service")
if event_bus:
await event_bus.stop()
logger.info("RAG indexer service shutdown complete")
@app.get("/health")
async def health_check() -> dict[str, Any]:
"""Health check endpoint"""
return {
"status": "healthy",
"service": settings.service_name,
"version": settings.service_version,
"timestamp": datetime.utcnow().isoformat(),
"collections": list(settings.collections.keys()),
}
@app.post("/index/{collection_name}")
async def index_document(
collection_name: str,
document: dict[str, Any],
background_tasks: BackgroundTasks,
current_user: dict[str, Any] = Depends(get_current_user),
tenant_id: str = Depends(get_tenant_id),
):
"""Index document in vector database"""
with tracer.start_as_current_span("index_document") as span:
span.set_attribute("collection_name", collection_name)
span.set_attribute("tenant_id", tenant_id)
try:
# Validate collection
if collection_name not in settings.collections:
raise HTTPException(
status_code=400, detail=f"Unknown collection: {collection_name}"
)
# Generate indexing ID
indexing_id = str(ulid.new())
span.set_attribute("indexing_id", indexing_id)
# Start background indexing
background_tasks.add_task(
_index_document_async,
collection_name,
document,
tenant_id,
indexing_id,
current_user.get("sub", "system"),
)
logger.info(
"Document indexing started",
collection=collection_name,
indexing_id=indexing_id,
)
return {
"indexing_id": indexing_id,
"collection": collection_name,
"status": "indexing",
}
except HTTPException:
raise
except Exception as e:
logger.error(
"Failed to start indexing", collection=collection_name, error=str(e)
)
raise HTTPException(status_code=500, detail="Failed to start indexing")
@app.get("/collections")
async def list_collections(
current_user: dict[str, Any] = Depends(get_current_user),
tenant_id: str = Depends(get_tenant_id),
):
"""List available collections"""
try:
collections_info: list[Any] = []
for collection_name, description in settings.collections.items():
# Get collection info from Qdrant
try:
collection_info = qdrant_client.get_collection(collection_name)
point_count = collection_info.points_count
vector_count = collection_info.vectors_count
except Exception:
point_count = 0
vector_count = 0
collections_info.append(
{
"name": collection_name,
"description": description,
"point_count": point_count,
"vector_count": vector_count,
}
)
return {
"collections": collections_info,
"total_collections": len(collections_info),
}
except Exception as e:
logger.error("Failed to list collections", error=str(e))
raise HTTPException(status_code=500, detail="Failed to list collections")
async def _handle_document_extracted(topic: str, payload: EventPayload) -> None:
"""Handle document extraction completion events"""
try:
data = payload.data
doc_id = data.get("doc_id")
tenant_id = data.get("tenant_id")
extraction_results = data.get("extraction_results")
if not doc_id or not tenant_id or not extraction_results:
logger.warning("Invalid document extraction event", data=data)
return
logger.info("Auto-indexing extracted document", doc_id=doc_id)
# Create document for indexing
document = {
"doc_id": doc_id,
"content": _extract_content_from_results(extraction_results),
"metadata": {
"doc_id": doc_id,
"tenant_id": tenant_id,
"extraction_id": extraction_results.get("extraction_id"),
"confidence": extraction_results.get("confidence", 0.0),
"extracted_at": extraction_results.get("extracted_at"),
"source": "extraction",
},
}
await _index_document_async(
collection_name="documents",
document=document,
tenant_id=tenant_id,
indexing_id=str(ulid.new()),
actor=payload.actor,
)
except Exception as e:
logger.error("Failed to handle document extraction event", error=str(e))
async def _handle_kg_upserted(topic: str, payload: EventPayload) -> None:
"""Handle knowledge graph upsert events"""
try:
data = payload.data
entities = data.get("entities", [])
tenant_id = data.get("tenant_id")
if not entities or not tenant_id:
logger.warning("Invalid KG upsert event", data=data)
return
logger.info("Auto-indexing KG entities", count=len(entities))
# Index entities as documents
for entity in entities:
document = {
"entity_id": entity.get("id"),
"content": _extract_content_from_entity(entity),
"metadata": {
"entity_type": entity.get("type"),
"entity_id": entity.get("id"),
"tenant_id": tenant_id,
"source": "knowledge_graph",
},
}
await _index_document_async(
collection_name="documents",
document=document,
tenant_id=tenant_id,
indexing_id=str(ulid.new()),
actor=payload.actor,
)
except Exception as e:
logger.error("Failed to handle KG upsert event", error=str(e))
async def _index_document_async(
collection_name: str,
document: dict[str, Any],
tenant_id: str,
indexing_id: str,
actor: str,
):
"""Index document asynchronously"""
with tracer.start_as_current_span("index_document_async") as span:
span.set_attribute("collection_name", collection_name)
span.set_attribute("indexing_id", indexing_id)
span.set_attribute("tenant_id", tenant_id)
try:
content = document.get("content", "")
metadata = document.get("metadata", {})
# Check for PII and de-identify if needed
if settings.require_pii_free:
has_pii = pii_detector.has_pii(content)
if has_pii:
if settings.auto_deidentify:
content, pii_mapping = pii_detector.de_identify_text(content)
metadata["pii_removed"] = True
metadata["pii_mapping_hash"] = _hash_pii_mapping(pii_mapping)
logger.info("PII removed from content", indexing_id=indexing_id)
else:
logger.warning(
"Content contains PII, skipping indexing",
indexing_id=indexing_id,
)
return
# Mark as PII-free
metadata["pii_free"] = True
metadata["tenant_id"] = tenant_id
metadata["indexed_at"] = datetime.utcnow().isoformat()
# Chunk content
chunks = _chunk_text(content)
# Generate embeddings and index chunks
indexed_chunks = 0
for i, chunk in enumerate(chunks):
try:
# Generate embedding
embedding = await _generate_embedding(chunk)
                    # Create point (Qdrant point IDs must be unsigned integers or UUIDs,
                    # so derive a deterministic UUID from the indexing ID and chunk index)
                    import uuid
                    point_id = str(uuid.uuid5(uuid.NAMESPACE_URL, f"{indexing_id}_{i}"))
from qdrant_client.models import PointStruct
point = PointStruct(
id=point_id,
vector=embedding,
payload={
**metadata,
"chunk_text": chunk,
"chunk_index": i,
"total_chunks": len(chunks),
},
)
# Index point
success = await collection_manager.upsert_points(
collection_name, [point]
)
if success:
indexed_chunks += 1
except Exception as e:
logger.error("Failed to index chunk", chunk_index=i, error=str(e))
# Update metrics
metrics.counter("documents_indexed_total").labels(
tenant_id=tenant_id, collection=collection_name
).inc()
metrics.histogram("chunks_per_document").labels(
collection=collection_name
).observe(indexed_chunks)
# Publish completion event
event_payload = EventPayload(
data={
"indexing_id": indexing_id,
"collection": collection_name,
"tenant_id": tenant_id,
"chunks_indexed": indexed_chunks,
"total_chunks": len(chunks),
},
actor=actor,
tenant_id=tenant_id,
)
await event_bus.publish(EventTopics.RAG_INDEXED, event_payload)
logger.info(
"Document indexing completed",
indexing_id=indexing_id,
chunks=indexed_chunks,
)
except Exception as e:
logger.error(
"Document indexing failed", indexing_id=indexing_id, error=str(e)
)
# Update error metrics
metrics.counter("indexing_errors_total").labels(
tenant_id=tenant_id,
collection=collection_name,
error_type=type(e).__name__,
).inc()
def _extract_content_from_results(extraction_results: dict[str, Any]) -> str:
"""Extract text content from extraction results"""
content_parts: list[Any] = []
# Add extracted fields
extracted_fields = extraction_results.get("extracted_fields", {})
for field_name, field_value in extracted_fields.items():
content_parts.append(f"{field_name}: {field_value}")
return "\n".join(content_parts)
def _extract_content_from_entity(entity: dict[str, Any]) -> str:
"""Extract text content from KG entity"""
content_parts: list[Any] = []
# Add entity type and ID
entity_type = entity.get("type", "Unknown")
entity_id = entity.get("id", "")
content_parts.append(f"Entity Type: {entity_type}")
content_parts.append(f"Entity ID: {entity_id}")
# Add properties
properties = entity.get("properties", {})
for prop_name, prop_value in properties.items():
if prop_name not in ["tenant_id", "asserted_at", "retracted_at"]:
content_parts.append(f"{prop_name}: {prop_value}")
return "\n".join(content_parts)
def _chunk_text(text: str) -> list[str]:
"""Chunk text into smaller pieces"""
if not text:
return []
# Simple chunking by sentences/paragraphs
chunks: list[Any] = []
current_chunk = ""
sentences = text.split(". ")
for sentence in sentences:
if len(current_chunk) + len(sentence) < settings.chunk_size:
current_chunk += sentence + ". "
else:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = sentence + ". "
if current_chunk:
chunks.append(current_chunk.strip())
return chunks
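# Note: this splitter packs ". "-delimited sentences up to chunk_size characters; the
# chunk_overlap setting is defined but not applied here, and a single sentence longer than
# chunk_size still becomes its own oversized chunk.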
async def _generate_embedding(text: str) -> list[float]:
"""Generate embedding for text"""
if embedding_model:
try:
embedding = embedding_model.encode(text)
return embedding.tolist()
except Exception as e:
logger.error("Failed to generate embedding", error=str(e))
# Fallback: random embedding
import random
return [random.random() for _ in range(settings.embedding_dimension)]
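# Note: when sentence-transformers is unavailable or encoding fails, the random fallback
# keeps the pipeline running, but such vectors carry no semantic meaning and will not
# retrieve sensibly at query time.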
def _hash_pii_mapping(pii_mapping: dict[str, str]) -> str:
"""Create hash of PII mapping for audit purposes"""
import hashlib
import json
mapping_json = json.dumps(pii_mapping, sort_keys=True)
return hashlib.sha256(mapping_json.encode()).hexdigest()
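# Illustrative only: assuming de_identify_text returns a placeholder-to-value mapping such
# as {"[PERSON_1]": "John Smith"}, only the SHA-256 of its canonical JSON is kept in the
# point payload (pii_mapping_hash); the raw mapping itself is never written to Qdrant.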
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
"""Handle HTTP exceptions with RFC7807 format"""
return JSONResponse(
status_code=exc.status_code,
content=ErrorResponse(
type=f"https://httpstatuses.com/{exc.status_code}",
title=exc.detail,
status=exc.status_code,
detail=exc.detail,
instance=str(request.url),
trace_id="",
).model_dump(),
)
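# Illustrative problem-details body for a 400 from /index/{collection_name}:
#   {"type": "https://httpstatuses.com/400", "title": "Unknown collection: foo",
#    "status": 400, "detail": "Unknown collection: foo",
#    "instance": "http://localhost:8006/index/foo", "trace_id": ""}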
if __name__ == "__main__":
import uvicorn
uvicorn.run("main:app", host="0.0.0.0", port=8006, reload=True, log_config=None)