Initial commit
Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
476  apps/svc_rag_retriever/main.py  Normal file
@@ -0,0 +1,476 @@
# FILE: apps/svc_rag_retriever/main.py
# mypy: disable-error-code=union-attr
# Hybrid search with KG fusion, reranking, and calibrated confidence

import os

# Import shared libraries
import sys
from datetime import datetime
from typing import Any

import structlog
from fastapi import Depends, HTTPException, Query, Request
from fastapi.responses import JSONResponse
from qdrant_client.models import SparseVector

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from libs.app_factory import create_app
from libs.calibration import ConfidenceCalibrator
from libs.config import (
    BaseAppSettings,
    create_event_bus,
    create_neo4j_client,
    create_qdrant_client,
)
from libs.events import EventBus
from libs.neo import Neo4jClient
from libs.observability import get_metrics, get_tracer, setup_observability
from libs.rag import RAGRetriever
from libs.schemas import ErrorResponse, RAGSearchRequest, RAGSearchResponse
from libs.security import get_current_user, get_tenant_id

logger = structlog.get_logger()


class RAGRetrieverSettings(BaseAppSettings):
    """Settings for RAG retriever service"""

    service_name: str = "svc-rag-retriever"

    # Embedding configuration
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    embedding_dimension: int = 384

    # Search configuration
    default_k: int = 10
    max_k: int = 100
    alpha: float = 0.5  # Dense/sparse balance
    beta: float = 0.3  # Vector/KG balance
    gamma: float = 0.2  # Reranking weight

    # Collections to search
    search_collections: list[str] = ["documents", "tax_rules", "guidance"]

    # Reranking
    reranker_model: str | None = None
    rerank_top_k: int = 50

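# NOTE: a sketch of how alpha, beta and gamma presumably combine during
# ranking. The actual fusion lives in libs.rag.RAGRetriever, so treat the
# formula as an illustrative assumption, not the implementation:
#
#   hybrid = alpha * dense_score + (1 - alpha) * sparse_score
#   fused  = (1 - beta) * hybrid + beta * kg_score
#   final  = (1 - gamma) * fused + gamma * rerank_score
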
# Create app and settings
app, settings = create_app(
    service_name="svc-rag-retriever",
    title="Tax Agent RAG Retriever Service",
    description="Hybrid search with KG fusion and reranking",
    settings_class=RAGRetrieverSettings,
)

# Global clients
qdrant_client = None
neo4j_client: Neo4jClient | None = None
rag_retriever: RAGRetriever | None = None
event_bus: EventBus | None = None
embedding_model = None
confidence_calibrator: ConfidenceCalibrator | None = None
tracer = get_tracer("svc-rag-retriever")
metrics = get_metrics()


@app.on_event("startup")
async def startup_event() -> None:
    """Initialize service dependencies"""
    global qdrant_client, neo4j_client, rag_retriever, event_bus, embedding_model, confidence_calibrator

    logger.info("Starting RAG retriever service")

    # Setup observability
    setup_observability(settings)

    # Initialize Qdrant client
    qdrant_client = create_qdrant_client(settings)

    # Initialize Neo4j client
    neo4j_driver = create_neo4j_client(settings)
    neo4j_client = Neo4jClient(neo4j_driver)

    # Initialize RAG retriever
    rag_retriever = RAGRetriever(
        qdrant_client=qdrant_client,
        neo4j_client=neo4j_client,
        reranker_model=settings.reranker_model,
    )

    # Initialize embedding model
    try:
        from sentence_transformers import SentenceTransformer

        embedding_model = SentenceTransformer(settings.embedding_model)
        logger.info("Embedding model loaded", model=settings.embedding_model)
    except ImportError:
        logger.warning("sentence-transformers not available, using mock embeddings")
        embedding_model = None

    # Initialize confidence calibrator
    confidence_calibrator = ConfidenceCalibrator(method="isotonic")

    # Initialize event bus
    event_bus = create_event_bus(settings)
    await event_bus.start()  # fmt: skip # pyright: ignore[reportOptionalMemberAccess]

    logger.info("RAG retriever service started successfully")


@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Cleanup service dependencies"""
    global neo4j_client, event_bus

    logger.info("Shutting down RAG retriever service")

    if neo4j_client:
        await neo4j_client.close()

    if event_bus:
        await event_bus.stop()

    logger.info("RAG retriever service shutdown complete")

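# NOTE: FastAPI has deprecated @app.on_event("startup"/"shutdown") in favour
# of lifespan handlers; the hooks above still run, but newer FastAPI releases
# emit a DeprecationWarning for them.
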
@app.get("/health")
|
||||
async def health_check() -> dict[str, Any]:
|
||||
"""Health check endpoint"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": settings.service_name,
|
||||
"version": settings.service_version,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"search_collections": settings.search_collections,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/search", response_model=RAGSearchResponse)
|
||||
async def search(
|
||||
request_data: RAGSearchRequest,
|
||||
current_user: dict[str, Any] = Depends(get_current_user),
|
||||
tenant_id: str = Depends(get_tenant_id),
|
||||
) -> RAGSearchResponse:
|
||||
"""Perform hybrid RAG search"""
|
||||
|
||||
with tracer.start_as_current_span("rag_search") as span:
|
||||
span.set_attribute("query", request_data.query[:100])
|
||||
span.set_attribute("tenant_id", tenant_id)
|
||||
span.set_attribute("k", request_data.k)
|
||||
|
||||
try:
|
||||
# Generate embeddings for query
|
||||
dense_vector = await _generate_embedding(request_data.query)
|
||||
sparse_vector = await _generate_sparse_vector(request_data.query)
|
||||
|
||||
# Perform search
|
||||
search_results = await rag_retriever.search( # fmt: skip # pyright: ignore[reportOptionalMemberAccess]
|
||||
query=request_data.query,
|
||||
collections=settings.search_collections,
|
||||
dense_vector=dense_vector,
|
||||
sparse_vector=sparse_vector,
|
||||
k=request_data.k,
|
||||
alpha=settings.alpha,
|
||||
beta=settings.beta,
|
||||
gamma=settings.gamma,
|
||||
tax_year=request_data.tax_year,
|
||||
jurisdiction=request_data.jurisdiction,
|
||||
)
|
||||
|
||||
# Update metrics
|
||||
metrics.counter("searches_total").labels(tenant_id=tenant_id).inc()
|
||||
|
||||
metrics.histogram("search_results_count").labels(
|
||||
tenant_id=tenant_id
|
||||
).observe(len(search_results["chunks"]))
|
||||
|
||||
metrics.histogram("search_confidence").labels(tenant_id=tenant_id).observe(
|
||||
search_results["calibrated_confidence"]
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"RAG search completed",
|
||||
query=request_data.query[:50],
|
||||
results=len(search_results["chunks"]),
|
||||
confidence=search_results["calibrated_confidence"],
|
||||
)
|
||||
|
||||
return RAGSearchResponse(
|
||||
chunks=search_results["chunks"],
|
||||
citations=search_results["citations"],
|
||||
kg_hints=search_results["kg_hints"],
|
||||
calibrated_confidence=search_results["calibrated_confidence"],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"RAG search failed", query=request_data.query[:50], error=str(e)
|
||||
)
|
||||
|
||||
# Update error metrics
|
||||
metrics.counter("search_errors_total").labels(
|
||||
tenant_id=tenant_id, error_type=type(e).__name__
|
||||
).inc()
|
||||
|
||||
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
|
||||
|
||||
|
||||
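# Example /search request body (field names inferred from the handler above;
# the authoritative schema is libs.schemas.RAGSearchRequest, and the values
# are illustrative):
#   {
#     "query": "capital gains tax on property disposal",
#     "k": 10,
#     "tax_year": "2023-24",
#     "jurisdiction": "UK"
#   }
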
@app.get("/similar/{doc_id}")
|
||||
async def find_similar_documents(
|
||||
doc_id: str,
|
||||
k: int = Query(default=10, le=settings.max_k),
|
||||
current_user: dict[str, Any] = Depends(get_current_user),
|
||||
tenant_id: str = Depends(get_tenant_id),
|
||||
) -> dict[str, Any]:
|
||||
"""Find documents similar to given document"""
|
||||
|
||||
with tracer.start_as_current_span("find_similar") as span:
|
||||
span.set_attribute("doc_id", doc_id)
|
||||
span.set_attribute("tenant_id", tenant_id)
|
||||
span.set_attribute("k", k)
|
||||
|
||||
try:
|
||||
# Get document content from vector database
|
||||
# This would search for the document by doc_id in metadata
|
||||
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
||||
|
||||
filter_conditions = Filter(
|
||||
must=[
|
||||
FieldCondition(key="doc_id", match=MatchValue(value=doc_id)),
|
||||
FieldCondition(key="tenant_id", match=MatchValue(value=tenant_id)),
|
||||
]
|
||||
)
|
||||
|
||||
# Search for the document
|
||||
doc_results = await rag_retriever.collection_manager.search_dense( # fmt: skip # pyright: ignore[reportOptionalMemberAccess]
|
||||
collection_name="documents",
|
||||
query_vector=[0.0] * settings.embedding_dimension, # Dummy vector
|
||||
limit=1,
|
||||
filter_conditions=filter_conditions,
|
||||
)
|
||||
|
||||
if not doc_results:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
# Get the document's vector and use it for similarity search
|
||||
doc_vector = doc_results[0]["payload"].get("vector")
|
||||
if not doc_vector:
|
||||
raise HTTPException(status_code=400, detail="Document has no vector")
|
||||
|
||||
# Find similar documents
|
||||
similar_results = await rag_retriever.collection_manager.search_dense( # fmt: skip # pyright: ignore[reportOptionalMemberAccess]
|
||||
collection_name="documents",
|
||||
query_vector=doc_vector,
|
||||
limit=k + 1, # +1 to exclude the original document
|
||||
filter_conditions=Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
key="tenant_id", match=MatchValue(value=tenant_id)
|
||||
)
|
||||
],
|
||||
must_not=[
|
||||
FieldCondition(key="doc_id", match=MatchValue(value=doc_id))
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
return {
|
||||
"doc_id": doc_id,
|
||||
"similar_documents": similar_results[:k],
|
||||
"count": len(similar_results[:k]),
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Similar document search failed", doc_id=doc_id, error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Similar search failed: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
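# NOTE: the two-step lookup above assumes the stored payload carries a raw
# "vector" field. If it does not, Qdrant's recommend API (passing the point id
# as a positive example) lets the server resolve the vector itself; whether
# that fits collection_manager's abstraction here is an assumption.
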
@app.post("/explain")
|
||||
async def explain_search(
|
||||
query: str,
|
||||
search_results: list[dict[str, Any]],
|
||||
current_user: dict[str, Any] = Depends(get_current_user),
|
||||
tenant_id: str = Depends(get_tenant_id),
|
||||
) -> dict[str, Any]:
|
||||
"""Explain search results and ranking"""
|
||||
|
||||
with tracer.start_as_current_span("explain_search") as span:
|
||||
span.set_attribute("query", query[:100])
|
||||
span.set_attribute("tenant_id", tenant_id)
|
||||
span.set_attribute("results_count", len(search_results))
|
||||
|
||||
try:
|
||||
explanations = []
|
||||
|
||||
for i, result in enumerate(search_results):
|
||||
explanation = {
|
||||
"rank": i + 1,
|
||||
"chunk_id": result.get("id"),
|
||||
"score": result.get("score", 0.0),
|
||||
"dense_score": result.get("dense_score", 0.0),
|
||||
"sparse_score": result.get("sparse_score", 0.0),
|
||||
"collection": result.get("collection"),
|
||||
"explanation": _generate_explanation(query, result),
|
||||
}
|
||||
explanations.append(explanation)
|
||||
|
||||
return {
|
||||
"query": query,
|
||||
"explanations": explanations,
|
||||
"ranking_factors": {
|
||||
"alpha": settings.alpha,
|
||||
"beta": settings.beta,
|
||||
"gamma": settings.gamma,
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Search explanation failed", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Explanation failed: {str(e)}")
|
||||
|
||||
|
||||
async def _generate_embedding(text: str) -> list[float]:
    """Generate dense embedding for text"""
    if embedding_model:
        try:
            embedding = embedding_model.encode(text)
            return embedding.tolist()
        except Exception as e:
            logger.error("Failed to generate embedding", error=str(e))

    # Fallback: random embedding (dev/mock only; useless for real retrieval)
    import random

    return [random.random() for _ in range(settings.embedding_dimension)]


async def _generate_sparse_vector(text: str) -> SparseVector:
    """Generate sparse vector for text (BM25-style)"""
    try:
        # This would use a proper sparse encoder like SPLADE
        # For now, create a simple word-based sparse representation
        words = text.lower().split()
        word_counts: dict[str, int] = {}
        for word in words:
            word_counts[word] = word_counts.get(word, 0) + 1

        # Convert to sparse vector format. Use a stable hash for the index:
        # the built-in hash() is salted per process, so its indices would not
        # match vectors built during indexing in another process.
        import zlib

        indices = []
        values = []

        for word, count in word_counts.items():
            word_hash = zlib.crc32(word.encode("utf-8")) % 10000  # Limit vocabulary size
            indices.append(word_hash)
            values.append(float(count))

        return SparseVector(indices=indices, values=values)

    except Exception as e:
        logger.error("Failed to generate sparse vector", error=str(e))
        # Return empty sparse vector
        return SparseVector(indices=[], values=[])

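# Worked example for the sparse encoder above (hash indices shown
# symbolically, not computed):
#   "capital gains tax capital" -> counts {"capital": 2, "gains": 1, "tax": 1}
#   -> indices = [crc32("capital") % 10000, crc32("gains") % 10000, crc32("tax") % 10000]
#   -> values  = [2.0, 1.0, 1.0]
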
def _generate_explanation(query: str, result: dict[str, Any]) -> str:
    """Generate human-readable explanation for search result"""

    explanations = []

    # Score explanation
    score = result.get("score", 0.0)
    dense_score = result.get("dense_score", 0.0)
    sparse_score = result.get("sparse_score", 0.0)

    explanations.append(f"Overall score: {score:.3f}")

    if dense_score > 0:
        explanations.append(f"Semantic similarity: {dense_score:.3f}")

    if sparse_score > 0:
        explanations.append(f"Keyword match: {sparse_score:.3f}")

    # Collection explanation
    collection = result.get("collection")
    if collection:
        explanations.append(f"Source: {collection}")

    # Metadata explanation
    payload = result.get("payload", {})
    doc_id = payload.get("doc_id")
    if doc_id:
        explanations.append(f"Document: {doc_id}")

    confidence = payload.get("confidence")
    if confidence:
        explanations.append(f"Extraction confidence: {confidence:.3f}")

    return "; ".join(explanations)

@app.get("/stats")
|
||||
async def get_search_stats(
|
||||
current_user: dict[str, Any] = Depends(get_current_user),
|
||||
tenant_id: str = Depends(get_tenant_id),
|
||||
) -> dict[str, Any]:
|
||||
"""Get search statistics"""
|
||||
|
||||
try:
|
||||
# This would aggregate metrics from Prometheus
|
||||
# For now, return mock stats
|
||||
stats = {
|
||||
"total_searches": 1000,
|
||||
"avg_results_per_search": 8.5,
|
||||
"avg_confidence": 0.75,
|
||||
"collections": {
|
||||
"documents": {"searches": 800, "avg_confidence": 0.78},
|
||||
"tax_rules": {"searches": 150, "avg_confidence": 0.85},
|
||||
"guidance": {"searches": 50, "avg_confidence": 0.70},
|
||||
},
|
||||
"top_queries": [
|
||||
{"query": "capital gains tax", "count": 45},
|
||||
{"query": "business expenses", "count": 38},
|
||||
{"query": "property income", "count": 32},
|
||||
],
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get search stats", error=str(e))
|
||||
raise HTTPException(status_code=500, detail="Failed to get stats")
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Handle HTTP exceptions with RFC7807 format"""
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            type=f"https://httpstatuses.com/{exc.status_code}",
            title=exc.detail,
            status=exc.status_code,
            detail=exc.detail,
            instance=str(request.url),
            trace_id=getattr(request.state, "trace_id", None),
        ).dict(),
    )

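# Example RFC7807-style body produced by the handler above for a 404 (the
# instance URL is illustrative):
#   {
#     "type": "https://httpstatuses.com/404",
#     "title": "Document not found",
#     "status": 404,
#     "detail": "Document not found",
#     "instance": "http://localhost:8007/similar/doc-123",
#     "trace_id": null
#   }
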
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8007, reload=True, log_config=None)
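A minimal client sketch for exercising the service above. It assumes the service is running locally on port 8007 (as in the uvicorn call) and that get_current_user/get_tenant_id accept a bearer token; the header name and token are placeholders:

    import httpx

    # POST a hybrid search query and read back the calibrated result
    response = httpx.post(
        "http://localhost:8007/search",
        json={"query": "capital gains tax", "k": 5},
        headers={"Authorization": "Bearer <token>"},  # auth scheme assumed
    )
    response.raise_for_status()
    result = response.json()
    print(len(result["chunks"]), result["calibrated_confidence"])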