# FILE: apps/svc-rag-retriever/main.py
# mypy: disable-error-code=union-attr
# Hybrid search with KG fusion, reranking, and calibrated confidence

import os
import sys
from datetime import datetime
from typing import Any

import structlog
from fastapi import Depends, HTTPException, Query, Request
from fastapi.responses import JSONResponse
from qdrant_client.models import SparseVector

# Make the repo root importable so the shared libs package resolves
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

# Import shared libraries
from libs.app_factory import create_app
from libs.calibration import ConfidenceCalibrator
from libs.config import (
    BaseAppSettings,
    create_event_bus,
    create_neo4j_client,
    create_qdrant_client,
)
from libs.events import EventBus
from libs.neo import Neo4jClient
from libs.observability import get_metrics, get_tracer, setup_observability
from libs.rag import RAGRetriever
from libs.schemas import ErrorResponse, RAGSearchRequest, RAGSearchResponse
from libs.security import get_current_user, get_tenant_id

logger = structlog.get_logger()


class RAGRetrieverSettings(BaseAppSettings):
    """Settings for RAG retriever service"""

    service_name: str = "svc-rag-retriever"

    # Embedding configuration
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    embedding_dimension: int = 384

    # Search configuration
    default_k: int = 10
    max_k: int = 100
    alpha: float = 0.5  # Dense/sparse balance
    beta: float = 0.3  # Vector/KG balance
    gamma: float = 0.2  # Reranking weight

    # Collections to search
    search_collections: list[str] = ["documents", "tax_rules", "guidance"]

    # Reranking
    reranker_model: str | None = None
    rerank_top_k: int = 50


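# Illustrative only: a minimal sketch of how the alpha/beta/gamma weights above
# are typically combined in hybrid retrieval. The actual fusion logic lives in
# libs.rag.RAGRetriever (not shown here), so the exact formula is an assumption.
def _example_fuse_scores(
    dense: float,
    sparse: float,
    kg: float,
    rerank: float,
    alpha: float = 0.5,
    beta: float = 0.3,
    gamma: float = 0.2,
) -> float:
    """Blend dense/sparse scores, fold in KG evidence, then apply the reranker."""
    # alpha balances semantic (dense) against keyword (sparse) retrieval
    vector_score = alpha * dense + (1 - alpha) * sparse
    # beta balances the fused vector score against knowledge-graph hints
    fused = (1 - beta) * vector_score + beta * kg
    # gamma controls how strongly the reranker adjusts the final result
    return (1 - gamma) * fused + gamma * rerank

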
# Create app and settings
app, settings = create_app(
    service_name="svc-rag-retriever",
    title="Tax Agent RAG Retriever Service",
    description="Hybrid search with KG fusion and reranking",
    settings_class=RAGRetrieverSettings,
)

# Global clients
qdrant_client = None
neo4j_client: Neo4jClient | None = None
rag_retriever: RAGRetriever | None = None
event_bus: EventBus | None = None
embedding_model = None
confidence_calibrator: ConfidenceCalibrator | None = None
tracer = get_tracer("svc-rag-retriever")
metrics = get_metrics()


@app.on_event("startup")
|
|
async def startup_event() -> None:
|
|
"""Initialize service dependencies"""
|
|
global qdrant_client, neo4j_client, rag_retriever, event_bus, embedding_model, confidence_calibrator
|
|
|
|
logger.info("Starting RAG retriever service")
|
|
|
|
# Setup observability
|
|
setup_observability(settings)
|
|
|
|
# Initialize Qdrant client
|
|
qdrant_client = create_qdrant_client(settings)
|
|
|
|
# Initialize Neo4j client
|
|
neo4j_driver = create_neo4j_client(settings)
|
|
neo4j_client = Neo4jClient(neo4j_driver)
|
|
|
|
# Initialize RAG retriever
|
|
rag_retriever = RAGRetriever(
|
|
qdrant_client=qdrant_client,
|
|
neo4j_client=neo4j_client,
|
|
reranker_model=settings.reranker_model,
|
|
)
|
|
|
|
# Initialize embedding model
|
|
try:
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
embedding_model = SentenceTransformer(settings.embedding_model)
|
|
logger.info("Embedding model loaded", model=settings.embedding_model)
|
|
except ImportError:
|
|
logger.warning("sentence-transformers not available, using mock embeddings")
|
|
embedding_model = None
|
|
|
|
# Initialize confidence calibrator
|
|
confidence_calibrator = ConfidenceCalibrator(method="isotonic")
|
|
|
|
# Initialize event bus
|
|
event_bus = create_event_bus(settings)
|
|
await event_bus.start() # fmt: skip # pyright: ignore[reportOptionalMemberAccess]
|
|
|
|
logger.info("RAG retriever service started successfully")
|
|
|
|
|
|
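# Illustrative only: ConfidenceCalibrator(method="isotonic") comes from
# libs.calibration, whose internals are not shown here. The sketch below shows
# one common way to implement isotonic calibration (scikit-learn's
# IsotonicRegression); it is an assumption about the approach, not the
# service's actual implementation.
class _ExampleIsotonicCalibrator:
    """Map raw retrieval scores to calibrated relevance probabilities."""

    def __init__(self) -> None:
        # Lazy import so this illustrative class adds no import-time dependency
        from sklearn.isotonic import IsotonicRegression

        self._iso = IsotonicRegression(y_min=0.0, y_max=1.0, out_of_bounds="clip")

    def fit(self, raw_scores: list[float], relevant: list[int]) -> None:
        # Fit on historical (score, was-relevant) pairs, e.g. review feedback
        self._iso.fit(raw_scores, relevant)

    def calibrate(self, raw_score: float) -> float:
        # Predict a monotone, [0, 1]-clipped probability for a raw score
        return float(self._iso.predict([raw_score])[0])

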
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Cleanup service dependencies"""
    global neo4j_client, event_bus

    logger.info("Shutting down RAG retriever service")

    if neo4j_client:
        await neo4j_client.close()

    if event_bus:
        await event_bus.stop()

    logger.info("RAG retriever service shutdown complete")


@app.get("/health")
|
|
async def health_check() -> dict[str, Any]:
|
|
"""Health check endpoint"""
|
|
return {
|
|
"status": "healthy",
|
|
"service": settings.service_name,
|
|
"version": settings.service_version,
|
|
"timestamp": datetime.utcnow().isoformat(),
|
|
"search_collections": settings.search_collections,
|
|
}
|
|
|
|
|
|
@app.post("/search", response_model=RAGSearchResponse)
|
|
async def search(
|
|
request_data: RAGSearchRequest,
|
|
current_user: dict[str, Any] = Depends(get_current_user),
|
|
tenant_id: str = Depends(get_tenant_id),
|
|
) -> RAGSearchResponse:
|
|
"""Perform hybrid RAG search"""
|
|
|
|
with tracer.start_as_current_span("rag_search") as span:
|
|
span.set_attribute("query", request_data.query[:100])
|
|
span.set_attribute("tenant_id", tenant_id)
|
|
span.set_attribute("k", request_data.k)
|
|
|
|
try:
|
|
# Generate embeddings for query
|
|
dense_vector = await _generate_embedding(request_data.query)
|
|
sparse_vector = await _generate_sparse_vector(request_data.query)
|
|
|
|
# Perform search
|
|
search_results = await rag_retriever.search( # fmt: skip # pyright: ignore[reportOptionalMemberAccess]
|
|
query=request_data.query,
|
|
collections=settings.search_collections,
|
|
dense_vector=dense_vector,
|
|
sparse_vector=sparse_vector,
|
|
k=request_data.k,
|
|
alpha=settings.alpha,
|
|
beta=settings.beta,
|
|
gamma=settings.gamma,
|
|
tax_year=request_data.tax_year,
|
|
jurisdiction=request_data.jurisdiction,
|
|
)
|
|
|
|
# Update metrics
|
|
metrics.counter("searches_total").labels(tenant_id=tenant_id).inc()
|
|
|
|
metrics.histogram("search_results_count").labels(
|
|
tenant_id=tenant_id
|
|
).observe(len(search_results["chunks"]))
|
|
|
|
metrics.histogram("search_confidence").labels(tenant_id=tenant_id).observe(
|
|
search_results["calibrated_confidence"]
|
|
)
|
|
|
|
logger.info(
|
|
"RAG search completed",
|
|
query=request_data.query[:50],
|
|
results=len(search_results["chunks"]),
|
|
confidence=search_results["calibrated_confidence"],
|
|
)
|
|
|
|
return RAGSearchResponse(
|
|
chunks=search_results["chunks"],
|
|
citations=search_results["citations"],
|
|
kg_hints=search_results["kg_hints"],
|
|
calibrated_confidence=search_results["calibrated_confidence"],
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error(
|
|
"RAG search failed", query=request_data.query[:50], error=str(e)
|
|
)
|
|
|
|
# Update error metrics
|
|
metrics.counter("search_errors_total").labels(
|
|
tenant_id=tenant_id, error_type=type(e).__name__
|
|
).inc()
|
|
|
|
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
|
|
|
|
|
|
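# Illustrative only: an example client call against the /search endpoint above.
# The request body mirrors the RAGSearchRequest fields referenced in the
# handler (query, k, tax_year, jurisdiction); the full schema lives in
# libs.schemas and may include more fields. URL, token, and field values are
# placeholders; port 8007 matches the uvicorn entrypoint at the bottom of this
# file.
async def _example_search_request() -> None:
    import httpx

    async with httpx.AsyncClient() as client:
        resp = await client.post(
            "http://localhost:8007/search",
            json={
                "query": "capital gains tax on property disposal",
                "k": 10,
                "tax_year": "2023-24",  # assumed format
                "jurisdiction": "UK",  # assumed value
            },
            headers={"Authorization": "Bearer <token>"},
        )
        resp.raise_for_status()
        body = resp.json()
        # Response mirrors RAGSearchResponse: chunks, citations, kg_hints,
        # calibrated_confidence
        print(body["calibrated_confidence"], len(body["chunks"]))

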
@app.get("/similar/{doc_id}")
|
|
async def find_similar_documents(
|
|
doc_id: str,
|
|
k: int = Query(default=10, le=settings.max_k),
|
|
current_user: dict[str, Any] = Depends(get_current_user),
|
|
tenant_id: str = Depends(get_tenant_id),
|
|
) -> dict[str, Any]:
|
|
"""Find documents similar to given document"""
|
|
|
|
with tracer.start_as_current_span("find_similar") as span:
|
|
span.set_attribute("doc_id", doc_id)
|
|
span.set_attribute("tenant_id", tenant_id)
|
|
span.set_attribute("k", k)
|
|
|
|
try:
|
|
# Get document content from vector database
|
|
# This would search for the document by doc_id in metadata
|
|
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
|
|
|
filter_conditions = Filter(
|
|
must=[
|
|
FieldCondition(key="doc_id", match=MatchValue(value=doc_id)),
|
|
FieldCondition(key="tenant_id", match=MatchValue(value=tenant_id)),
|
|
]
|
|
)
|
|
|
|
# Search for the document
|
|
doc_results = await rag_retriever.collection_manager.search_dense( # fmt: skip # pyright: ignore[reportOptionalMemberAccess]
|
|
collection_name="documents",
|
|
query_vector=[0.0] * settings.embedding_dimension, # Dummy vector
|
|
limit=1,
|
|
filter_conditions=filter_conditions,
|
|
)
|
|
|
|
if not doc_results:
|
|
raise HTTPException(status_code=404, detail="Document not found")
|
|
|
|
# Get the document's vector and use it for similarity search
|
|
doc_vector = doc_results[0]["payload"].get("vector")
|
|
if not doc_vector:
|
|
raise HTTPException(status_code=400, detail="Document has no vector")
|
|
|
|
# Find similar documents
|
|
similar_results = await rag_retriever.collection_manager.search_dense( # fmt: skip # pyright: ignore[reportOptionalMemberAccess]
|
|
collection_name="documents",
|
|
query_vector=doc_vector,
|
|
limit=k + 1, # +1 to exclude the original document
|
|
filter_conditions=Filter(
|
|
must=[
|
|
FieldCondition(
|
|
key="tenant_id", match=MatchValue(value=tenant_id)
|
|
)
|
|
],
|
|
must_not=[
|
|
FieldCondition(key="doc_id", match=MatchValue(value=doc_id))
|
|
],
|
|
),
|
|
)
|
|
|
|
return {
|
|
"doc_id": doc_id,
|
|
"similar_documents": similar_results[:k],
|
|
"count": len(similar_results[:k]),
|
|
}
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error("Similar document search failed", doc_id=doc_id, error=str(e))
|
|
raise HTTPException(
|
|
status_code=500, detail=f"Similar search failed: {str(e)}"
|
|
)
|
|
|
|
|
|
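# Note (illustrative, not the service's implementation): the dummy-vector
# lookup above works, but qdrant-client can also fetch points directly. If the
# chunk point IDs were derivable from doc_id, the round trip could be avoided,
# e.g. with the raw client (sketch; assumes qdrant_client is initialized and a
# point_id is known):
#
#   points = qdrant_client.retrieve(
#       collection_name="documents",
#       ids=[point_id],
#       with_vectors=True,
#   )
#   doc_vector = points[0].vector
#
# qdrant-client additionally exposes a recommend API that searches by example
# point IDs directly, folding both steps into one call.

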
@app.post("/explain")
|
|
async def explain_search(
|
|
query: str,
|
|
search_results: list[dict[str, Any]],
|
|
current_user: dict[str, Any] = Depends(get_current_user),
|
|
tenant_id: str = Depends(get_tenant_id),
|
|
) -> dict[str, Any]:
|
|
"""Explain search results and ranking"""
|
|
|
|
with tracer.start_as_current_span("explain_search") as span:
|
|
span.set_attribute("query", query[:100])
|
|
span.set_attribute("tenant_id", tenant_id)
|
|
span.set_attribute("results_count", len(search_results))
|
|
|
|
try:
|
|
explanations = []
|
|
|
|
for i, result in enumerate(search_results):
|
|
explanation = {
|
|
"rank": i + 1,
|
|
"chunk_id": result.get("id"),
|
|
"score": result.get("score", 0.0),
|
|
"dense_score": result.get("dense_score", 0.0),
|
|
"sparse_score": result.get("sparse_score", 0.0),
|
|
"collection": result.get("collection"),
|
|
"explanation": _generate_explanation(query, result),
|
|
}
|
|
explanations.append(explanation)
|
|
|
|
return {
|
|
"query": query,
|
|
"explanations": explanations,
|
|
"ranking_factors": {
|
|
"alpha": settings.alpha,
|
|
"beta": settings.beta,
|
|
"gamma": settings.gamma,
|
|
},
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error("Search explanation failed", error=str(e))
|
|
raise HTTPException(status_code=500, detail=f"Explanation failed: {str(e)}")
|
|
|
|
|
|
async def _generate_embedding(text: str) -> list[float]:
    """Generate dense embedding for text"""
    if embedding_model:
        try:
            embedding = embedding_model.encode(text)
            return embedding.tolist()
        except Exception as e:
            logger.error("Failed to generate embedding", error=str(e))

    # Fallback: random embedding (model unavailable or encoding failed)
    import random

    return [random.random() for _ in range(settings.embedding_dimension)]


async def _generate_sparse_vector(text: str) -> SparseVector:
    """Generate sparse vector for text (BM25-style)"""
    try:
        # This would use a proper sparse encoder like SPLADE
        # For now, create a simple word-based sparse representation
        words = text.lower().split()
        word_counts: dict[str, int] = {}
        for word in words:
            word_counts[word] = word_counts.get(word, 0) + 1

        # Convert to sparse vector format, using a hash of each word as its
        # index. Note: Python's str hash is randomized per process unless
        # PYTHONHASHSEED is fixed, so indices are not stable across restarts,
        # and distinct words can collide within the 10k-slot vocabulary.
        indices = []
        values = []
        for word, count in word_counts.items():
            word_hash = hash(word) % 10000  # Limit vocabulary size
            indices.append(word_hash)
            values.append(float(count))

        return SparseVector(indices=indices, values=values)

    except Exception as e:
        logger.error("Failed to generate sparse vector", error=str(e))
        # Return empty sparse vector
        return SparseVector(indices=[], values=[])


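# Illustrative only: a slightly less naive sparse encoding than the word-count
# scheme above, adding TF-IDF-style weighting and collision-safe index merging.
# A production setup would use a learned sparse encoder (e.g. SPLADE), as the
# comment above notes; the doc_freq/num_docs inputs here are hypothetical and
# would come from corpus statistics the service does not currently track.
def _example_tfidf_sparse_vector(
    text: str,
    doc_freq: dict[str, int],
    num_docs: int,
    vocab_size: int = 10000,
) -> SparseVector:
    import math

    counts: dict[str, int] = {}
    for word in text.lower().split():
        counts[word] = counts.get(word, 0) + 1

    # Sum weights that hash to the same index instead of emitting duplicates
    weights: dict[int, float] = {}
    for word, tf in counts.items():
        # Smoothed IDF: rare words contribute more than common ones
        idf = math.log((num_docs + 1) / (doc_freq.get(word, 0) + 1)) + 1.0
        idx = hash(word) % vocab_size
        weights[idx] = weights.get(idx, 0.0) + tf * idf

    return SparseVector(indices=list(weights), values=list(weights.values()))

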
def _generate_explanation(query: str, result: dict[str, Any]) -> str:
    """Generate human-readable explanation for search result"""
    explanations = []

    # Score explanation
    score = result.get("score", 0.0)
    dense_score = result.get("dense_score", 0.0)
    sparse_score = result.get("sparse_score", 0.0)

    explanations.append(f"Overall score: {score:.3f}")

    if dense_score > 0:
        explanations.append(f"Semantic similarity: {dense_score:.3f}")

    if sparse_score > 0:
        explanations.append(f"Keyword match: {sparse_score:.3f}")

    # Collection explanation
    collection = result.get("collection")
    if collection:
        explanations.append(f"Source: {collection}")

    # Metadata explanation
    payload = result.get("payload", {})
    doc_id = payload.get("doc_id")
    if doc_id:
        explanations.append(f"Document: {doc_id}")

    confidence = payload.get("confidence")
    if confidence is not None:  # report zero confidence too; skip only if absent
        explanations.append(f"Extraction confidence: {confidence:.3f}")

    return "; ".join(explanations)


@app.get("/stats")
|
|
async def get_search_stats(
|
|
current_user: dict[str, Any] = Depends(get_current_user),
|
|
tenant_id: str = Depends(get_tenant_id),
|
|
) -> dict[str, Any]:
|
|
"""Get search statistics"""
|
|
|
|
try:
|
|
# This would aggregate metrics from Prometheus
|
|
# For now, return mock stats
|
|
stats = {
|
|
"total_searches": 1000,
|
|
"avg_results_per_search": 8.5,
|
|
"avg_confidence": 0.75,
|
|
"collections": {
|
|
"documents": {"searches": 800, "avg_confidence": 0.78},
|
|
"tax_rules": {"searches": 150, "avg_confidence": 0.85},
|
|
"guidance": {"searches": 50, "avg_confidence": 0.70},
|
|
},
|
|
"top_queries": [
|
|
{"query": "capital gains tax", "count": 45},
|
|
{"query": "business expenses", "count": 38},
|
|
{"query": "property income", "count": 32},
|
|
],
|
|
}
|
|
|
|
return stats
|
|
|
|
except Exception as e:
|
|
logger.error("Failed to get search stats", error=str(e))
|
|
raise HTTPException(status_code=500, detail="Failed to get stats")
|
|
|
|
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Handle HTTP exceptions with RFC7807 format"""
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            type=f"https://httpstatuses.com/{exc.status_code}",
            title=exc.detail,
            status=exc.status_code,
            detail=exc.detail,
            instance=str(request.url),
            trace_id=getattr(request.state, "trace_id", None),
        ).dict(),
    )


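# Example problem+json body produced by the handler above for a 404 from
# /similar/{doc_id} (doc id and host are illustrative; trace_id comes from
# request state and is null when tracing is off):
#
#   {
#       "type": "https://httpstatuses.com/404",
#       "title": "Document not found",
#       "status": 404,
#       "detail": "Document not found",
#       "instance": "http://localhost:8007/similar/doc-123",
#       "trace_id": null
#   }

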
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8007, reload=True, log_config=None)