# FILE: apps/svc-rag-retriever/main.py
# mypy: disable-error-code=union-attr
# Hybrid search with KG fusion, reranking, and calibrated confidence
import os
import sys
from datetime import datetime
from typing import Any
import structlog
from fastapi import Depends, HTTPException, Query, Request
from fastapi.responses import JSONResponse
from qdrant_client.models import SparseVector
# Import shared libraries
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from libs.app_factory import create_app
from libs.calibration import ConfidenceCalibrator
from libs.config import (
    BaseAppSettings,
    create_event_bus,
    create_neo4j_client,
    create_qdrant_client,
)
from libs.events import EventBus
from libs.neo import Neo4jClient
from libs.observability import get_metrics, get_tracer, setup_observability
from libs.rag import RAGRetriever
from libs.schemas import ErrorResponse, RAGSearchRequest, RAGSearchResponse
from libs.security import get_current_user, get_tenant_id

logger = structlog.get_logger()

class RAGRetrieverSettings(BaseAppSettings):
    """Settings for RAG retriever service"""

    service_name: str = "svc-rag-retriever"

    # Embedding configuration
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    embedding_dimension: int = 384

    # Search configuration
    default_k: int = 10
    max_k: int = 100
    alpha: float = 0.5  # Dense/sparse balance
    beta: float = 0.3  # Vector/KG balance
    gamma: float = 0.2  # Reranking weight

    # Collections to search
    search_collections: list[str] = ["documents", "tax_rules", "guidance"]

    # Reranking
    reranker_model: str | None = None
    rerank_top_k: int = 50

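# NOTE: the fusion weights above sum to 1.0 (0.5 + 0.3 + 0.2). The exact
# scoring formula lives in libs.rag.RAGRetriever, so the inline comments are
# only a rough guide to how dense, sparse, KG, and reranking signals trade off.
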
# Create app and settings
app, settings = create_app(
    service_name="svc-rag-retriever",
    title="Tax Agent RAG Retriever Service",
    description="Hybrid search with KG fusion and reranking",
    settings_class=RAGRetrieverSettings,
)

# Global clients
qdrant_client = None
neo4j_client: Neo4jClient | None = None
rag_retriever: RAGRetriever | None = None
event_bus: EventBus | None = None
embedding_model = None
confidence_calibrator: ConfidenceCalibrator | None = None
tracer = get_tracer("svc-rag-retriever")
metrics = get_metrics()
@app.on_event("startup")
async def startup_event() -> None:
"""Initialize service dependencies"""
global qdrant_client, neo4j_client, rag_retriever, event_bus, embedding_model, confidence_calibrator
logger.info("Starting RAG retriever service")
# Setup observability
setup_observability(settings)
# Initialize Qdrant client
qdrant_client = create_qdrant_client(settings)
# Initialize Neo4j client
neo4j_driver = create_neo4j_client(settings)
neo4j_client = Neo4jClient(neo4j_driver)
# Initialize RAG retriever
rag_retriever = RAGRetriever(
qdrant_client=qdrant_client,
neo4j_client=neo4j_client,
reranker_model=settings.reranker_model,
)
# Initialize embedding model
try:
from sentence_transformers import SentenceTransformer
embedding_model = SentenceTransformer(settings.embedding_model)
logger.info("Embedding model loaded", model=settings.embedding_model)
except ImportError:
logger.warning("sentence-transformers not available, using mock embeddings")
embedding_model = None
# Initialize confidence calibrator
confidence_calibrator = ConfidenceCalibrator(method="isotonic")
# Initialize event bus
event_bus = create_event_bus(settings)
await event_bus.start() # fmt: skip # pyright: ignore[reportOptionalMemberAccess]
logger.info("RAG retriever service started successfully")
@app.on_event("shutdown")
async def shutdown_event() -> None:
"""Cleanup service dependencies"""
global neo4j_client, event_bus
logger.info("Shutting down RAG retriever service")
if neo4j_client:
await neo4j_client.close()
if event_bus:
await event_bus.stop()
logger.info("RAG retriever service shutdown complete")
@app.get("/health")
async def health_check() -> dict[str, Any]:
"""Health check endpoint"""
return {
"status": "healthy",
"service": settings.service_name,
"version": settings.service_version,
"timestamp": datetime.utcnow().isoformat(),
"search_collections": settings.search_collections,
}
@app.post("/search", response_model=RAGSearchResponse)
async def search(
request_data: RAGSearchRequest,
current_user: dict[str, Any] = Depends(get_current_user),
tenant_id: str = Depends(get_tenant_id),
) -> RAGSearchResponse:
"""Perform hybrid RAG search"""
with tracer.start_as_current_span("rag_search") as span:
span.set_attribute("query", request_data.query[:100])
span.set_attribute("tenant_id", tenant_id)
span.set_attribute("k", request_data.k)
try:
# Generate embeddings for query
dense_vector = await _generate_embedding(request_data.query)
sparse_vector = await _generate_sparse_vector(request_data.query)
# Perform search
search_results = await rag_retriever.search( # fmt: skip # pyright: ignore[reportOptionalMemberAccess]
query=request_data.query,
collections=settings.search_collections,
dense_vector=dense_vector,
sparse_vector=sparse_vector,
k=request_data.k,
alpha=settings.alpha,
beta=settings.beta,
gamma=settings.gamma,
tax_year=request_data.tax_year,
jurisdiction=request_data.jurisdiction,
)
# Update metrics
metrics.counter("searches_total").labels(tenant_id=tenant_id).inc()
metrics.histogram("search_results_count").labels(
tenant_id=tenant_id
).observe(len(search_results["chunks"]))
metrics.histogram("search_confidence").labels(tenant_id=tenant_id).observe(
search_results["calibrated_confidence"]
)
logger.info(
"RAG search completed",
query=request_data.query[:50],
results=len(search_results["chunks"]),
confidence=search_results["calibrated_confidence"],
)
return RAGSearchResponse(
chunks=search_results["chunks"],
citations=search_results["citations"],
kg_hints=search_results["kg_hints"],
calibrated_confidence=search_results["calibrated_confidence"],
)
except Exception as e:
logger.error(
"RAG search failed", query=request_data.query[:50], error=str(e)
)
# Update error metrics
metrics.counter("search_errors_total").labels(
tenant_id=tenant_id, error_type=type(e).__name__
).inc()
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
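# Example request (illustrative values; assumes a bearer token accepted by
# libs.security.get_current_user and the RAGSearchRequest fields used above):
#   curl -X POST http://localhost:8007/search \
#     -H "Authorization: Bearer <token>" \
#     -H "Content-Type: application/json" \
#     -d '{"query": "capital gains tax", "k": 10, "tax_year": "2024-25", "jurisdiction": "UK"}'
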
@app.get("/similar/{doc_id}")
async def find_similar_documents(
doc_id: str,
k: int = Query(default=10, le=settings.max_k),
current_user: dict[str, Any] = Depends(get_current_user),
tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
"""Find documents similar to given document"""
with tracer.start_as_current_span("find_similar") as span:
span.set_attribute("doc_id", doc_id)
span.set_attribute("tenant_id", tenant_id)
span.set_attribute("k", k)
try:
# Get document content from vector database
# This would search for the document by doc_id in metadata
from qdrant_client.models import FieldCondition, Filter, MatchValue
filter_conditions = Filter(
must=[
FieldCondition(key="doc_id", match=MatchValue(value=doc_id)),
FieldCondition(key="tenant_id", match=MatchValue(value=tenant_id)),
]
)
# Search for the document
doc_results = await rag_retriever.collection_manager.search_dense( # fmt: skip # pyright: ignore[reportOptionalMemberAccess]
collection_name="documents",
query_vector=[0.0] * settings.embedding_dimension, # Dummy vector
limit=1,
filter_conditions=filter_conditions,
)
if not doc_results:
raise HTTPException(status_code=404, detail="Document not found")
# Get the document's vector and use it for similarity search
doc_vector = doc_results[0]["payload"].get("vector")
if not doc_vector:
raise HTTPException(status_code=400, detail="Document has no vector")
# Find similar documents
similar_results = await rag_retriever.collection_manager.search_dense( # fmt: skip # pyright: ignore[reportOptionalMemberAccess]
collection_name="documents",
query_vector=doc_vector,
limit=k + 1, # +1 to exclude the original document
filter_conditions=Filter(
must=[
FieldCondition(
key="tenant_id", match=MatchValue(value=tenant_id)
)
],
must_not=[
FieldCondition(key="doc_id", match=MatchValue(value=doc_id))
],
),
)
return {
"doc_id": doc_id,
"similar_documents": similar_results[:k],
"count": len(similar_results[:k]),
}
except HTTPException:
raise
except Exception as e:
logger.error("Similar document search failed", doc_id=doc_id, error=str(e))
raise HTTPException(
status_code=500, detail=f"Similar search failed: {str(e)}"
)
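# Example request (illustrative; doc_id must exist in the "documents"
# collection for the caller's tenant):
#   curl -H "Authorization: Bearer <token>" \
#     "http://localhost:8007/similar/<doc_id>?k=5"
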
@app.post("/explain")
async def explain_search(
query: str,
search_results: list[dict[str, Any]],
current_user: dict[str, Any] = Depends(get_current_user),
tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
"""Explain search results and ranking"""
with tracer.start_as_current_span("explain_search") as span:
span.set_attribute("query", query[:100])
span.set_attribute("tenant_id", tenant_id)
span.set_attribute("results_count", len(search_results))
try:
explanations = []
for i, result in enumerate(search_results):
explanation = {
"rank": i + 1,
"chunk_id": result.get("id"),
"score": result.get("score", 0.0),
"dense_score": result.get("dense_score", 0.0),
"sparse_score": result.get("sparse_score", 0.0),
"collection": result.get("collection"),
"explanation": _generate_explanation(query, result),
}
explanations.append(explanation)
return {
"query": query,
"explanations": explanations,
"ranking_factors": {
"alpha": settings.alpha,
"beta": settings.beta,
"gamma": settings.gamma,
},
}
except Exception as e:
logger.error("Search explanation failed", error=str(e))
raise HTTPException(status_code=500, detail=f"Explanation failed: {str(e)}")
async def _generate_embedding(text: str) -> list[float]:
    """Generate dense embedding for text"""
    if embedding_model:
        try:
            embedding = embedding_model.encode(text)
            return embedding.tolist()
        except Exception as e:
            logger.error("Failed to generate embedding", error=str(e))

    # Fallback for development only: a random mock embedding
    import random

    return [random.random() for _ in range(settings.embedding_dimension)]

async def _generate_sparse_vector(text: str) -> SparseVector:
    """Generate sparse vector for text (BM25-style)"""
    try:
        # This would use a proper sparse encoder like SPLADE
        # For now, create a simple word-count sparse representation
        import zlib

        words = text.lower().split()
        word_counts: dict[str, int] = {}
        for word in words:
            word_counts[word] = word_counts.get(word, 0) + 1

        # Convert to sparse vector format
        indices = []
        values = []
        for word, count in word_counts.items():
            # Use a stable hash of the word as the index (the built-in hash()
            # is salted per process, so its indices would not match those
            # produced at indexing time)
            word_hash = zlib.crc32(word.encode("utf-8")) % 10000  # Limit vocabulary size
            indices.append(word_hash)
            values.append(float(count))

        return SparseVector(indices=indices, values=values)
    except Exception as e:
        logger.error("Failed to generate sparse vector", error=str(e))
        # Return empty sparse vector
        return SparseVector(indices=[], values=[])

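# Example (illustrative): "income tax on income" yields word_counts
# {"income": 2, "tax": 1, "on": 1} and therefore a SparseVector whose three
# indices are crc32-derived term ids and whose values are [2.0, 1.0, 1.0].
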
def _generate_explanation(query: str, result: dict[str, Any]) -> str:
    """Generate human-readable explanation for a search result"""
    explanations = []

    # Score explanation
    score = result.get("score", 0.0)
    dense_score = result.get("dense_score", 0.0)
    sparse_score = result.get("sparse_score", 0.0)

    explanations.append(f"Overall score: {score:.3f}")
    if dense_score > 0:
        explanations.append(f"Semantic similarity: {dense_score:.3f}")
    if sparse_score > 0:
        explanations.append(f"Keyword match: {sparse_score:.3f}")

    # Collection explanation
    collection = result.get("collection")
    if collection:
        explanations.append(f"Source: {collection}")

    # Metadata explanation
    payload = result.get("payload", {})
    doc_id = payload.get("doc_id")
    if doc_id:
        explanations.append(f"Document: {doc_id}")

    confidence = payload.get("confidence")
    if confidence:
        explanations.append(f"Extraction confidence: {confidence:.3f}")

    return "; ".join(explanations)

@app.get("/stats")
async def get_search_stats(
current_user: dict[str, Any] = Depends(get_current_user),
tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
"""Get search statistics"""
try:
# This would aggregate metrics from Prometheus
# For now, return mock stats
stats = {
"total_searches": 1000,
"avg_results_per_search": 8.5,
"avg_confidence": 0.75,
"collections": {
"documents": {"searches": 800, "avg_confidence": 0.78},
"tax_rules": {"searches": 150, "avg_confidence": 0.85},
"guidance": {"searches": 50, "avg_confidence": 0.70},
},
"top_queries": [
{"query": "capital gains tax", "count": 45},
{"query": "business expenses", "count": 38},
{"query": "property income", "count": 32},
],
}
return stats
except Exception as e:
logger.error("Failed to get search stats", error=str(e))
raise HTTPException(status_code=500, detail="Failed to get stats")
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Handle HTTP exceptions with RFC 7807 problem-details format"""
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            type=f"https://httpstatuses.com/{exc.status_code}",
            title=exc.detail,
            status=exc.status_code,
            detail=exc.detail,
            instance=str(request.url),
            trace_id=getattr(request.state, "trace_id", None),
        ).dict(),
    )

if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8007, reload=True, log_config=None)
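
# Local run (a sketch; assumes the repo's shared libs and dependencies are
# installed and the Qdrant/Neo4j endpoints configured via BaseAppSettings are
# reachable):
#   python main.py            # or: uvicorn main:app --host 0.0.0.0 --port 8007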