# FILE: apps/svc-rag-indexer/main.py
# mypy: disable-error-code=union-attr
# Vector database indexing with PII protection and de-identification

import os
import sys
import uuid
from datetime import datetime
from typing import Any

import structlog
import ulid
from fastapi import BackgroundTasks, Depends, HTTPException, Request
from fastapi.responses import JSONResponse

# Import shared libraries
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from libs.app_factory import create_app
from libs.config import BaseAppSettings, create_event_bus, create_qdrant_client
from libs.events import EventBus, EventPayload, EventTopics
from libs.observability import get_metrics, get_tracer, setup_observability
from libs.rag import PIIDetector, QdrantCollectionManager
from libs.schemas import ErrorResponse
from libs.security import get_current_user, get_tenant_id

logger = structlog.get_logger()


class RAGIndexerSettings(BaseAppSettings):
    """Settings for RAG indexer service"""

    service_name: str = "svc-rag-indexer"

    # Embedding configuration
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    embedding_dimension: int = 384

    # Chunking configuration
    chunk_size: int = 512
    chunk_overlap: int = 50

    # Collection configuration
    collections: dict[str, str] = {
        "documents": "Document chunks with metadata",
        "tax_rules": "Tax rules and regulations",
        "case_law": "Tax case law and precedents",
        "guidance": "HMRC guidance and manuals",
    }

    # PII protection
    require_pii_free: bool = True
    auto_deidentify: bool = True
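    # When require_pii_free is enabled, _index_document_async either de-identifies
    # content before indexing (auto_deidentify=True) or skips indexing it entirely
    # (auto_deidentify=False).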


# Create app and settings
app, settings = create_app(
    service_name="svc-rag-indexer",
    title="Tax Agent RAG Indexer Service",
    description="Vector database indexing with PII protection",
    settings_class=RAGIndexerSettings,
)

# Global clients
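# (populated in startup_event below; request and event handlers assume startup
# has completed before they run)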
qdrant_client = None
collection_manager: QdrantCollectionManager | None = None
pii_detector: PIIDetector | None = None
event_bus: EventBus | None = None
embedding_model = None
tracer = get_tracer("svc-rag-indexer")
metrics = get_metrics()


@app.on_event("startup")
async def startup_event() -> None:
    """Initialize service dependencies"""
    global qdrant_client, collection_manager, pii_detector, event_bus, embedding_model

    logger.info("Starting RAG indexer service")

    # Setup observability
    setup_observability(settings)

    # Initialize Qdrant client
    qdrant_client = create_qdrant_client(settings)
    collection_manager = QdrantCollectionManager(qdrant_client)

    # Initialize PII detector
    pii_detector = PIIDetector()

    # Initialize embedding model
    try:
        from sentence_transformers import SentenceTransformer

        embedding_model = SentenceTransformer(settings.embedding_model)
        logger.info("Embedding model loaded", model=settings.embedding_model)
    except ImportError:
        logger.warning("sentence-transformers not available, using mock embeddings")
        embedding_model = None

    # Initialize event bus
    event_bus = create_event_bus(settings)
    await event_bus.start()

    # Subscribe to relevant events
    await event_bus.subscribe(EventTopics.DOC_EXTRACTED, _handle_document_extracted)  # type: ignore
    await event_bus.subscribe(EventTopics.KG_UPSERTED, _handle_kg_upserted)  # type: ignore
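    # These subscriptions drive automatic indexing: extraction and KG-upsert events
    # are converted into documents by the _handle_* callbacks defined below.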

    # Ensure collections exist
    for collection_name in settings.collections:
        await collection_manager.ensure_collection(
            collection_name=collection_name, vector_size=settings.embedding_dimension
        )

    logger.info("RAG indexer service started successfully")


@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Cleanup service dependencies"""
    global event_bus

    logger.info("Shutting down RAG indexer service")

    if event_bus:
        await event_bus.stop()

    logger.info("RAG indexer service shutdown complete")


@app.get("/health")
async def health_check() -> dict[str, Any]:
    """Health check endpoint"""
    return {
        "status": "healthy",
        "service": settings.service_name,
        "version": settings.service_version,
        "timestamp": datetime.utcnow().isoformat(),
        "collections": list(settings.collections.keys()),
    }


@app.post("/index/{collection_name}")
async def index_document(
    collection_name: str,
    document: dict[str, Any],
    background_tasks: BackgroundTasks,
    current_user: dict[str, Any] = Depends(get_current_user),
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """Index document in vector database"""

    with tracer.start_as_current_span("index_document") as span:
        span.set_attribute("collection_name", collection_name)
        span.set_attribute("tenant_id", tenant_id)

        try:
            # Validate collection
            if collection_name not in settings.collections:
                raise HTTPException(
                    status_code=400, detail=f"Unknown collection: {collection_name}"
                )

            # Generate indexing ID
            indexing_id = str(ulid.new())
            span.set_attribute("indexing_id", indexing_id)

            # Start background indexing
            background_tasks.add_task(
                _index_document_async,
                collection_name,
                document,
                tenant_id,
                indexing_id,
                current_user.get("sub", "system"),
            )

            logger.info(
                "Document indexing started",
                collection=collection_name,
                indexing_id=indexing_id,
            )

            return {
                "indexing_id": indexing_id,
                "collection": collection_name,
                "status": "indexing",
            }

        except HTTPException:
            raise
        except Exception as e:
            logger.error(
                "Failed to start indexing", collection=collection_name, error=str(e)
            )
            raise HTTPException(status_code=500, detail="Failed to start indexing")
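

# Illustrative request to the endpoint above (not an authoritative contract):
# assumes the service listens on port 8006 as in the __main__ block below, and
# that get_current_user / get_tenant_id accept a bearer token carrying the
# tenant context.
#
#   curl -X POST http://localhost:8006/index/documents \
#     -H "Authorization: Bearer <token>" \
#     -H "Content-Type: application/json" \
#     -d '{"content": "Example text to index", "metadata": {"source": "manual"}}'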


@app.get("/collections")
async def list_collections(
    current_user: dict[str, Any] = Depends(get_current_user),
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """List available collections"""

    try:
        collections_info: list[Any] = []

        for collection_name, description in settings.collections.items():
            # Get collection info from Qdrant
            try:
                collection_info = qdrant_client.get_collection(collection_name)
                point_count = collection_info.points_count
                vector_count = collection_info.vectors_count
            except Exception:
                point_count = 0
                vector_count = 0

            collections_info.append(
                {
                    "name": collection_name,
                    "description": description,
                    "point_count": point_count,
                    "vector_count": vector_count,
                }
            )

        return {
            "collections": collections_info,
            "total_collections": len(collections_info),
        }

    except Exception as e:
        logger.error("Failed to list collections", error=str(e))
        raise HTTPException(status_code=500, detail="Failed to list collections")


async def _handle_document_extracted(topic: str, payload: EventPayload) -> None:
    """Handle document extraction completion events"""
    try:
        data = payload.data
        doc_id = data.get("doc_id")
        tenant_id = data.get("tenant_id")
        extraction_results = data.get("extraction_results")

        if not doc_id or not tenant_id or not extraction_results:
            logger.warning("Invalid document extraction event", data=data)
            return

        logger.info("Auto-indexing extracted document", doc_id=doc_id)

        # Create document for indexing
        document = {
            "doc_id": doc_id,
            "content": _extract_content_from_results(extraction_results),
            "metadata": {
                "doc_id": doc_id,
                "tenant_id": tenant_id,
                "extraction_id": extraction_results.get("extraction_id"),
                "confidence": extraction_results.get("confidence", 0.0),
                "extracted_at": extraction_results.get("extracted_at"),
                "source": "extraction",
            },
        }

        await _index_document_async(
            collection_name="documents",
            document=document,
            tenant_id=tenant_id,
            indexing_id=str(ulid.new()),
            actor=payload.actor,
        )

    except Exception as e:
        logger.error("Failed to handle document extraction event", error=str(e))


async def _handle_kg_upserted(topic: str, payload: EventPayload) -> None:
    """Handle knowledge graph upsert events"""
    try:
        data = payload.data
        entities = data.get("entities", [])
        tenant_id = data.get("tenant_id")

        if not entities or not tenant_id:
            logger.warning("Invalid KG upsert event", data=data)
            return

        logger.info("Auto-indexing KG entities", count=len(entities))

        # Index entities as documents
        for entity in entities:
            document = {
                "entity_id": entity.get("id"),
                "content": _extract_content_from_entity(entity),
                "metadata": {
                    "entity_type": entity.get("type"),
                    "entity_id": entity.get("id"),
                    "tenant_id": tenant_id,
                    "source": "knowledge_graph",
                },
            }

            await _index_document_async(
                collection_name="documents",
                document=document,
                tenant_id=tenant_id,
                indexing_id=str(ulid.new()),
                actor=payload.actor,
            )

    except Exception as e:
        logger.error("Failed to handle KG upsert event", error=str(e))


async def _index_document_async(
    collection_name: str,
    document: dict[str, Any],
    tenant_id: str,
    indexing_id: str,
    actor: str,
) -> None:
    """Index document asynchronously"""

    with tracer.start_as_current_span("index_document_async") as span:
        span.set_attribute("collection_name", collection_name)
        span.set_attribute("indexing_id", indexing_id)
        span.set_attribute("tenant_id", tenant_id)

        try:
            content = document.get("content", "")
            metadata = document.get("metadata", {})

            # Check for PII and de-identify if needed
            if settings.require_pii_free:
                has_pii = pii_detector.has_pii(content)

                if has_pii:
                    if settings.auto_deidentify:
                        content, pii_mapping = pii_detector.de_identify_text(content)
                        metadata["pii_removed"] = True
                        metadata["pii_mapping_hash"] = _hash_pii_mapping(pii_mapping)
                        logger.info("PII removed from content", indexing_id=indexing_id)
                    else:
                        logger.warning(
                            "Content contains PII, skipping indexing",
                            indexing_id=indexing_id,
                        )
                        return

            # Mark as PII-free
            metadata["pii_free"] = True
            metadata["tenant_id"] = tenant_id
            metadata["indexed_at"] = datetime.utcnow().isoformat()

            # Chunk content
            chunks = _chunk_text(content)

            # Generate embeddings and index chunks
            indexed_chunks = 0
            for i, chunk in enumerate(chunks):
                try:
                    # Generate embedding
                    embedding = await _generate_embedding(chunk)

                    # Create point (Qdrant point IDs must be unsigned integers or
                    # UUIDs, so derive a deterministic UUID from the indexing ID
                    # and chunk index)
                    point_id = str(uuid.uuid5(uuid.NAMESPACE_OID, f"{indexing_id}_{i}"))

                    from qdrant_client.models import PointStruct

                    point = PointStruct(
                        id=point_id,
                        vector=embedding,
                        payload={
                            **metadata,
                            "chunk_text": chunk,
                            "chunk_index": i,
                            "total_chunks": len(chunks),
                        },
                    )

                    # Index point
                    success = await collection_manager.upsert_points(
                        collection_name, [point]
                    )

                    if success:
                        indexed_chunks += 1

                except Exception as e:
                    logger.error("Failed to index chunk", chunk_index=i, error=str(e))

            # Update metrics
            metrics.counter("documents_indexed_total").labels(
                tenant_id=tenant_id, collection=collection_name
            ).inc()

            metrics.histogram("chunks_per_document").labels(
                collection=collection_name
            ).observe(indexed_chunks)

            # Publish completion event
            event_payload = EventPayload(
                data={
                    "indexing_id": indexing_id,
                    "collection": collection_name,
                    "tenant_id": tenant_id,
                    "chunks_indexed": indexed_chunks,
                    "total_chunks": len(chunks),
                },
                actor=actor,
                tenant_id=tenant_id,
            )

            await event_bus.publish(EventTopics.RAG_INDEXED, event_payload)

            logger.info(
                "Document indexing completed",
                indexing_id=indexing_id,
                chunks=indexed_chunks,
            )

        except Exception as e:
            logger.error(
                "Document indexing failed", indexing_id=indexing_id, error=str(e)
            )

            # Update error metrics
            metrics.counter("indexing_errors_total").labels(
                tenant_id=tenant_id,
                collection=collection_name,
                error_type=type(e).__name__,
            ).inc()


def _extract_content_from_results(extraction_results: dict[str, Any]) -> str:
    """Extract text content from extraction results"""
    content_parts: list[Any] = []

    # Add extracted fields
    extracted_fields = extraction_results.get("extracted_fields", {})
    for field_name, field_value in extracted_fields.items():
        content_parts.append(f"{field_name}: {field_value}")

    return "\n".join(content_parts)


def _extract_content_from_entity(entity: dict[str, Any]) -> str:
    """Extract text content from KG entity"""
    content_parts: list[Any] = []

    # Add entity type and ID
    entity_type = entity.get("type", "Unknown")
    entity_id = entity.get("id", "")
    content_parts.append(f"Entity Type: {entity_type}")
    content_parts.append(f"Entity ID: {entity_id}")

    # Add properties
    properties = entity.get("properties", {})
    for prop_name, prop_value in properties.items():
        if prop_name not in ["tenant_id", "asserted_at", "retracted_at"]:
            content_parts.append(f"{prop_name}: {prop_value}")

    return "\n".join(content_parts)


def _chunk_text(text: str) -> list[str]:
    """Chunk text into smaller pieces"""
    if not text:
        return []

    # Simple chunking by sentences/paragraphs
    chunks: list[Any] = []
    current_chunk = ""

    sentences = text.split(". ")

    for sentence in sentences:
        if len(current_chunk) + len(sentence) < settings.chunk_size:
            current_chunk += sentence + ". "
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + ". "

    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks


async def _generate_embedding(text: str) -> list[float]:
    """Generate embedding for text"""
    if embedding_model:
        try:
            embedding = embedding_model.encode(text)
            return embedding.tolist()
        except Exception as e:
            logger.error("Failed to generate embedding", error=str(e))

    # Fallback: random embedding
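    # (placeholder only: random vectors keep the pipeline running when
    # sentence-transformers is unavailable, but similarity search against them is
    # meaningless)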
    import random

    return [random.random() for _ in range(settings.embedding_dimension)]


def _hash_pii_mapping(pii_mapping: dict[str, str]) -> str:
    """Create hash of PII mapping for audit purposes"""
    import hashlib
    import json

    mapping_json = json.dumps(pii_mapping, sort_keys=True)
    return hashlib.sha256(mapping_json.encode()).hexdigest()


@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Handle HTTP exceptions with RFC7807 format"""
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            type=f"https://httpstatuses.com/{exc.status_code}",
            title=exc.detail,
            status=exc.status_code,
            detail=exc.detail,
            instance=str(request.url),
            trace_id="",
        ).model_dump(),
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8006, reload=True, log_config=None)