# FILE: apps/svc-rag-retriever/main.py # mypy: disable-error-code=union-attr # Hybrid search with KG fusion, reranking, and calibrated confidence import os # Import shared libraries import sys from datetime import datetime from typing import Any import structlog from fastapi import Depends, HTTPException, Query, Request from fastapi.responses import JSONResponse from qdrant_client.models import SparseVector sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from libs.app_factory import create_app from libs.calibration import ConfidenceCalibrator from libs.config import ( BaseAppSettings, create_event_bus, create_neo4j_client, create_qdrant_client, ) from libs.events import EventBus from libs.neo import Neo4jClient from libs.observability import get_metrics, get_tracer, setup_observability from libs.rag import RAGRetriever from libs.schemas import ErrorResponse, RAGSearchRequest, RAGSearchResponse from libs.security import get_current_user, get_tenant_id logger = structlog.get_logger() class RAGRetrieverSettings(BaseAppSettings): """Settings for RAG retriever service""" service_name: str = "svc-rag-retriever" # Embedding configuration embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2" embedding_dimension: int = 384 # Search configuration default_k: int = 10 max_k: int = 100 alpha: float = 0.5 # Dense/sparse balance beta: float = 0.3 # Vector/KG balance gamma: float = 0.2 # Reranking weight # Collections to search search_collections: list[str] = ["documents", "tax_rules", "guidance"] # Reranking reranker_model: str | None = None rerank_top_k: int = 50 # Create app and settings app, settings = create_app( service_name="svc-rag-retriever", title="Tax Agent RAG Retriever Service", description="Hybrid search with KG fusion and reranking", settings_class=RAGRetrieverSettings, ) # Global clients qdrant_client = None neo4j_client: Neo4jClient | None = None rag_retriever: RAGRetriever | None = None event_bus: EventBus | None = None 
embedding_model = None
confidence_calibrator: ConfidenceCalibrator | None = None

tracer = get_tracer("svc-rag-retriever")
metrics = get_metrics()

# Number of hash buckets used by the placeholder sparse encoder.
_SPARSE_VOCAB_SIZE = 10000


@app.on_event("startup")
async def startup_event() -> None:
    """Initialize service dependencies.

    Creates the Qdrant and Neo4j clients, the RAG retriever, the embedding
    model (when ``sentence-transformers`` is installed), the confidence
    calibrator and the event bus, and stores them in module-level globals.
    """
    global qdrant_client, neo4j_client, rag_retriever, event_bus
    global embedding_model, confidence_calibrator

    logger.info("Starting RAG retriever service")

    # Setup observability
    setup_observability(settings)

    # Initialize Qdrant client
    qdrant_client = create_qdrant_client(settings)

    # Initialize Neo4j client
    neo4j_driver = create_neo4j_client(settings)
    neo4j_client = Neo4jClient(neo4j_driver)

    # Initialize RAG retriever
    rag_retriever = RAGRetriever(
        qdrant_client=qdrant_client,
        neo4j_client=neo4j_client,
        reranker_model=settings.reranker_model,
    )

    # Initialize embedding model
    try:
        # Imported lazily so the service still starts without the package.
        from sentence_transformers import SentenceTransformer

        embedding_model = SentenceTransformer(settings.embedding_model)
        logger.info("Embedding model loaded", model=settings.embedding_model)
    except ImportError:
        logger.warning("sentence-transformers not available, using mock embeddings")
        embedding_model = None

    # Initialize confidence calibrator
    confidence_calibrator = ConfidenceCalibrator(method="isotonic")

    # Initialize event bus
    event_bus = create_event_bus(settings)
    await event_bus.start()  # fmt: skip # pyright: ignore[reportOptionalMemberAccess]

    logger.info("RAG retriever service started successfully")


@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Cleanup service dependencies.

    Closes the Neo4j client and stops the event bus if they were created.
    The event bus is stopped even when closing Neo4j raises.
    """
    logger.info("Shutting down RAG retriever service")

    try:
        if neo4j_client:
            await neo4j_client.close()
    finally:
        # Runs regardless of the Neo4j outcome so the bus is never left running.
        if event_bus:
            await event_bus.stop()

    logger.info("RAG retriever service shutdown complete")


@app.get("/health")
async def health_check() -> dict[str, Any]:
    """Health check endpoint.

    Returns:
        Static status plus service name, version, the current UTC time as a
        timezone-aware ISO-8601 string, and the configured search collections.
    """
    from datetime import timezone

    return {
        "status": "healthy",
        "service": settings.service_name,
        "version": settings.service_version,
        # datetime.utcnow() is deprecated and naive; emit an aware timestamp.
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "search_collections": settings.search_collections,
    }


@app.post("/search", response_model=RAGSearchResponse)
async def search(
    request_data: RAGSearchRequest,
    current_user: dict[str, Any] = Depends(get_current_user),
    tenant_id: str = Depends(get_tenant_id),
) -> RAGSearchResponse:
    """Perform hybrid RAG search.

    Builds a dense and a sparse vector for the query, delegates to
    ``rag_retriever.search`` with the configured fusion weights, records
    metrics, and returns the retriever's chunks, citations, KG hints and
    calibrated confidence.

    Raises:
        HTTPException: 500 for any failure during embedding or retrieval.
    """
    with tracer.start_as_current_span("rag_search") as span:
        span.set_attribute("query", request_data.query[:100])
        span.set_attribute("tenant_id", tenant_id)
        span.set_attribute("k", request_data.k)

        try:
            # Generate embeddings for query
            dense_vector = await _generate_embedding(request_data.query)
            sparse_vector = await _generate_sparse_vector(request_data.query)

            # Perform search
            search_results = await rag_retriever.search(  # fmt: skip # pyright: ignore[reportOptionalMemberAccess]
                query=request_data.query,
                collections=settings.search_collections,
                dense_vector=dense_vector,
                sparse_vector=sparse_vector,
                k=request_data.k,
                alpha=settings.alpha,
                beta=settings.beta,
                gamma=settings.gamma,
                tax_year=request_data.tax_year,
                jurisdiction=request_data.jurisdiction,
            )

            # Update metrics
            metrics.counter("searches_total").labels(tenant_id=tenant_id).inc()
            metrics.histogram("search_results_count").labels(
                tenant_id=tenant_id
            ).observe(len(search_results["chunks"]))
            metrics.histogram("search_confidence").labels(tenant_id=tenant_id).observe(
                search_results["calibrated_confidence"]
            )

            logger.info(
                "RAG search completed",
                query=request_data.query[:50],
                results=len(search_results["chunks"]),
                confidence=search_results["calibrated_confidence"],
            )

            return RAGSearchResponse(
                chunks=search_results["chunks"],
                citations=search_results["citations"],
                kg_hints=search_results["kg_hints"],
                calibrated_confidence=search_results["calibrated_confidence"],
            )

        except Exception as e:
            logger.error(
                "RAG search failed", query=request_data.query[:50], error=str(e)
            )

            # Update error metrics
            metrics.counter("search_errors_total").labels(
                tenant_id=tenant_id, error_type=type(e).__name__
            ).inc()

            # Chain the cause so the original traceback is preserved.
            raise HTTPException(
                status_code=500, detail=f"Search failed: {str(e)}"
            ) from e


@app.get("/similar/{doc_id}")
async def find_similar_documents(
    doc_id: str,
    k: int = Query(default=10, ge=1, le=settings.max_k),
    current_user: dict[str, Any] = Depends(get_current_user),
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """Find documents similar to given document.

    Looks the document up in the ``documents`` collection by ``doc_id`` and
    tenant, then runs a dense search with its stored vector, excluding the
    document itself.

    Args:
        doc_id: Identifier matched against the ``doc_id`` payload field.
        k: Number of similar documents to return (1..``settings.max_k``).

    Returns:
        Dict with ``doc_id``, ``similar_documents`` (at most ``k`` hits) and
        ``count``.

    Raises:
        HTTPException: 404 if the document is not found, 400 if it has no
            vector in its payload, 500 for any other failure.
    """
    with tracer.start_as_current_span("find_similar") as span:
        span.set_attribute("doc_id", doc_id)
        span.set_attribute("tenant_id", tenant_id)
        span.set_attribute("k", k)

        try:
            # Get document content from vector database
            # This would search for the document by doc_id in metadata
            from qdrant_client.models import FieldCondition, Filter, MatchValue

            filter_conditions = Filter(
                must=[
                    FieldCondition(key="doc_id", match=MatchValue(value=doc_id)),
                    FieldCondition(key="tenant_id", match=MatchValue(value=tenant_id)),
                ]
            )

            # Search for the document
            doc_results = await rag_retriever.collection_manager.search_dense(  # fmt: skip # pyright: ignore[reportOptionalMemberAccess]
                collection_name="documents",
                query_vector=[0.0] * settings.embedding_dimension,  # Dummy vector
                limit=1,
                filter_conditions=filter_conditions,
            )

            if not doc_results:
                raise HTTPException(status_code=404, detail="Document not found")

            # Get the document's vector and use it for similarity search
            doc_vector = doc_results[0]["payload"].get("vector")
            if not doc_vector:
                raise HTTPException(status_code=400, detail="Document has no vector")

            # Find similar documents
            similar_results = await rag_retriever.collection_manager.search_dense(  # fmt: skip # pyright: ignore[reportOptionalMemberAccess]
                collection_name="documents",
                query_vector=doc_vector,
                limit=k + 1,  # +1 to exclude the original document
                filter_conditions=Filter(
                    must=[
                        FieldCondition(
                            key="tenant_id", match=MatchValue(value=tenant_id)
                        )
                    ],
                    must_not=[
                        FieldCondition(key="doc_id", match=MatchValue(value=doc_id))
                    ],
                ),
            )

            top_matches = similar_results[:k]
            return {
                "doc_id": doc_id,
                "similar_documents": top_matches,
                "count": len(top_matches),
            }

        except HTTPException:
            # 404/400 raised above must reach the client unchanged.
            raise
        except Exception as e:
            logger.error("Similar document search failed", doc_id=doc_id, error=str(e))
            raise HTTPException(
                status_code=500, detail=f"Similar search failed: {str(e)}"
            ) from e


@app.post("/explain")
async def explain_search(
    query: str,
    search_results: list[dict[str, Any]],
    current_user: dict[str, Any] = Depends(get_current_user),
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """Explain search results and ranking.

    Args:
        query: The query the results were produced for.
        search_results: Result dicts, in rank order, as supplied by the caller.

    Returns:
        Dict with the query, one explanation entry per result (1-based rank,
        ids, scores, collection, readable text) and the configured
        ``alpha``/``beta``/``gamma`` ranking factors.

    Raises:
        HTTPException: 500 if building the explanations fails.
    """
    with tracer.start_as_current_span("explain_search") as span:
        span.set_attribute("query", query[:100])
        span.set_attribute("tenant_id", tenant_id)
        span.set_attribute("results_count", len(search_results))

        try:
            explanations = [
                {
                    "rank": rank,
                    "chunk_id": result.get("id"),
                    "score": result.get("score", 0.0),
                    "dense_score": result.get("dense_score", 0.0),
                    "sparse_score": result.get("sparse_score", 0.0),
                    "collection": result.get("collection"),
                    "explanation": _generate_explanation(query, result),
                }
                for rank, result in enumerate(search_results, start=1)
            ]

            return {
                "query": query,
                "explanations": explanations,
                "ranking_factors": {
                    "alpha": settings.alpha,
                    "beta": settings.beta,
                    "gamma": settings.gamma,
                },
            }

        except Exception as e:
            logger.error("Search explanation failed", error=str(e))
            raise HTTPException(
                status_code=500, detail=f"Explanation failed: {str(e)}"
            ) from e


async def _generate_embedding(text: str) -> list[float]:
    """Generate dense embedding for text.

    Uses the loaded embedding model when available. If no model is loaded, or
    encoding fails (the error is logged), a random vector of
    ``settings.embedding_dimension`` floats is returned instead.
    """
    if embedding_model:
        try:
            embedding = embedding_model.encode(text)
            return embedding.tolist()
        except Exception as e:
            logger.error("Failed to generate embedding", error=str(e))

    # Fallback: random embedding
    # NOTE(review): placeholder only -- results from this path are meaningless.
    import random

    return [random.random() for _ in range(settings.embedding_dimension)]


async def _generate_sparse_vector(text: str) -> SparseVector:
    """Generate sparse vector for text (BM25-style).

    Lower-cases and whitespace-splits the text, hashes every word into one of
    ``_SPARSE_VOCAB_SIZE`` buckets and uses the per-bucket word count as the
    value.

    Returns:
        A ``SparseVector`` with unique, ascending indices; an empty one if
        construction fails (the error is logged).
    """
    import zlib
    from collections import Counter

    try:
        # This would use a proper sparse encoder like SPLADE
        # For now, create a simple sparse representation
        bucket_counts: Counter[int] = Counter()
        for word in text.lower().split():
            # Built-in hash() is salted per process for str, so it would give
            # different indices after every restart. CRC32 is stable.
            # NOTE(review): whatever indexes documents must use the same
            # hashing scheme for sparse matches to line up -- confirm.
            bucket = zlib.crc32(word.encode("utf-8")) % _SPARSE_VOCAB_SIZE
            # Colliding words are merged so each index appears only once
            # (Qdrant expects unique sparse indices).
            bucket_counts[bucket] += 1

        indices = sorted(bucket_counts)
        values = [float(bucket_counts[index]) for index in indices]
        return SparseVector(indices=indices, values=values)

    except Exception as e:
        logger.error("Failed to generate sparse vector", error=str(e))
        # Return empty sparse vector
        return SparseVector(indices=[], values=[])


def _generate_explanation(query: str, result: dict[str, Any]) -> str:
    """Generate human-readable explanation for search result.

    Args:
        query: The search query (currently not used in the text).
        result: Result dict; reads ``score``, ``dense_score``,
            ``sparse_score``, ``collection`` and ``payload`` (``doc_id``,
            ``confidence``). Missing or ``None`` scores count as 0.0 and a
            missing or ``None`` payload as empty.

    Returns:
        The explanation fragments joined with ``"; "``.
    """
    # `or 0.0` also covers explicit nulls in client-supplied results, which
    # would otherwise break the comparisons and format specs below.
    score = result.get("score") or 0.0
    dense_score = result.get("dense_score") or 0.0
    sparse_score = result.get("sparse_score") or 0.0

    # Score explanation
    explanations = [f"Overall score: {score:.3f}"]
    if dense_score > 0:
        explanations.append(f"Semantic similarity: {dense_score:.3f}")
    if sparse_score > 0:
        explanations.append(f"Keyword match: {sparse_score:.3f}")

    # Collection explanation
    collection = result.get("collection")
    if collection:
        explanations.append(f"Source: {collection}")

    # Metadata explanation
    payload = result.get("payload") or {}
    doc_id = payload.get("doc_id")
    if doc_id:
        explanations.append(f"Document: {doc_id}")

    confidence = payload.get("confidence")
    if confidence:
        explanations.append(f"Extraction confidence: {confidence:.3f}")

    return "; ".join(explanations)


@app.get("/stats")
async def get_search_stats(
    current_user: dict[str, Any] = Depends(get_current_user),
    tenant_id: str = Depends(get_tenant_id),
) -> dict[str, Any]:
    """Get search statistics.

    Returns:
        Hard-coded mock statistics; nothing is read from a metrics backend.
    """
    # This would aggregate metrics from Prometheus
    # For now, return mock stats
    return {
        "total_searches": 1000,
        "avg_results_per_search": 8.5,
        "avg_confidence": 0.75,
        "collections": {
            "documents": {"searches": 800, "avg_confidence": 0.78},
            "tax_rules": {"searches": 150, "avg_confidence": 0.85},
            "guidance": {"searches": 50, "avg_confidence": 0.70},
        },
        "top_queries": [
            {"query": "capital gains tax", "count": 45},
            {"query": "business expenses", "count": 38},
            {"query": "property income", "count": 32},
        ],
    }


@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Handle HTTP exceptions with RFC7807 format.

    Serializes the exception as an ``ErrorResponse`` carrying the status code,
    the detail (also used as title), the request URL and the request's
    ``trace_id`` when one is set on ``request.state``.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            type=f"https://httpstatuses.com/{exc.status_code}",
            title=exc.detail,
            status=exc.status_code,
            detail=exc.detail,
            instance=str(request.url),
            trace_id=getattr(request.state, "trace_id", None),
        ).dict(),
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8007, reload=True, log_config=None)