Initial commit
Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
0  libs/__init__.py  Normal file
123  libs/app_factory.py  Normal file
@@ -0,0 +1,123 @@
"""Factory for creating FastAPI applications with consistent setup."""

# FILE: libs/app_factory.py

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse

from libs.config import BaseAppSettings, get_default_settings
from libs.observability import setup_observability
from libs.schemas import ErrorResponse
from libs.security import get_current_user, get_tenant_id
from libs.security.middleware import TrustedProxyMiddleware


def create_trusted_proxy_middleware(
    internal_cidrs: list[str], disable_auth: bool = False
) -> TrustedProxyMiddleware:
    """Create a TrustedProxyMiddleware factory bound to the given internal CIDRs."""

    # This is a factory function that will be called by FastAPI's add_middleware.
    # We return a closure that constructs the middleware around the ASGI app.
    def middleware_factory(app: Any) -> TrustedProxyMiddleware:
        return TrustedProxyMiddleware(app, internal_cidrs, disable_auth)

    return middleware_factory  # type: ignore


def create_app(  # pylint: disable=too-many-arguments,too-many-positional-arguments
    service_name: str,
    title: str,
    description: str,
    version: str = "1.0.0",
    settings_class: type[BaseAppSettings] = BaseAppSettings,
    custom_settings: dict[str, Any] | None = None,
) -> tuple[FastAPI, BaseAppSettings]:
    """Create a FastAPI application with standard configuration."""

    # Create settings
    settings_kwargs = {"service_name": service_name}
    if custom_settings:
        settings_kwargs.update(custom_settings)

    settings = get_default_settings(**settings_kwargs)
    if settings_class != BaseAppSettings:
        # Use custom settings class
        settings = settings_class(**settings_kwargs)  # type: ignore

    # Create lifespan context manager
    @asynccontextmanager
    async def lifespan(
        app: FastAPI,
    ) -> AsyncIterator[None]:  # pylint: disable=unused-argument
        # Startup
        setup_observability(settings)
        yield
        # Shutdown

    # Create FastAPI app
    app = FastAPI(
        title=title, description=description, version=version, lifespan=lifespan
    )

    # Add middleware
    app.add_middleware(
        TrustedProxyMiddleware,
        internal_cidrs=settings.internal_cidrs,
        disable_auth=getattr(settings, "disable_auth", False),
    )

    # Add exception handlers
    @app.exception_handler(HTTPException)
    async def http_exception_handler(
        request: Request, exc: HTTPException
    ) -> JSONResponse:
        """Handle HTTP exceptions with RFC7807 format"""
        return JSONResponse(
            status_code=exc.status_code,
            content=ErrorResponse(
                type=f"https://httpstatuses.com/{exc.status_code}",
                title=exc.detail,
                status=exc.status_code,
                detail=exc.detail,
                instance=str(request.url),
                trace_id=getattr(request.state, "trace_id", None),
            ).model_dump(),
        )

    # Add health endpoints
    @app.get("/healthz")
    async def health_check() -> dict[str, str]:
        """Health check endpoint"""
        return {
            "status": "healthy",
            "service": settings.service_name,
            "version": version,
        }

    @app.get("/readyz")
    async def readiness_check() -> dict[str, str]:
        """Readiness check endpoint"""
        return {"status": "ready", "service": settings.service_name, "version": version}

    @app.get("/livez")
    async def liveness_check() -> dict[str, str]:
        """Liveness check endpoint"""
        return {"status": "alive", "service": settings.service_name, "version": version}

    return app, settings


# Dependency factories
def get_user_dependency() -> Any:
    """Get user dependency function"""
    return get_current_user()


def get_tenant_dependency() -> Any:
    """Get tenant dependency function"""
    return get_tenant_id()
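A minimal usage sketch for `create_app`, as it might appear in a service entrypoint; the service name, route, and response body are illustrative, not part of this commit.

```python
# Illustrative sketch of a service entrypoint (e.g. apps/svc-ingestion/main.py)
from libs.app_factory import create_app

app, settings = create_app(
    service_name="svc-ingestion",
    title="Ingestion Service",
    description="Accept uploads, checksum, store to MinIO, emit doc.ingested",
    version="1.0.0",
)


@app.get("/v1/docs/{doc_id}")
async def get_doc(doc_id: str) -> dict[str, str]:
    # A real handler would query Postgres; this only shows how routes attach to the factory app.
    return {"doc_id": doc_id, "service": settings.service_name}
```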
12  libs/calibration/__init__.py  Normal file
@@ -0,0 +1,12 @@
"""Confidence calibration for ML models."""

from .calibrator import ConfidenceCalibrator
from .metrics import DEFAULT_CALIBRATORS, ConfidenceMetrics
from .multi_model import MultiModelCalibrator

__all__ = [
    "ConfidenceCalibrator",
    "MultiModelCalibrator",
    "ConfidenceMetrics",
    "DEFAULT_CALIBRATORS",
]
190  libs/calibration/calibrator.py  Normal file
@@ -0,0 +1,190 @@
"""Confidence calibrator using various methods."""

import pickle

import numpy as np
import structlog
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression

logger = structlog.get_logger()


class ConfidenceCalibrator:
    """Calibrate confidence scores using various methods"""

    def __init__(self, method: str = "temperature"):
        """
        Initialize calibrator

        Args:
            method: Calibration method ('temperature', 'platt', 'isotonic')
        """
        self.method = method
        self.calibrator = None
        self.temperature = 1.0
        self.is_fitted = False

    def fit(self, scores: list[float], labels: list[bool]) -> None:
        """
        Fit calibration model

        Args:
            scores: Raw confidence scores (0-1)
            labels: True labels (True/False for correct/incorrect)
        """
        # Validate inputs
        if len(scores) == 0 or len(labels) == 0:
            raise ValueError("Scores and labels cannot be empty")

        if len(scores) != len(labels):
            raise ValueError("Scores and labels must have the same length")

        scores_array = np.array(scores).reshape(-1, 1)
        labels_array = np.array(labels, dtype=int)

        if self.method == "temperature":
            self._fit_temperature_scaling(scores_array, labels_array)
        elif self.method == "platt":
            self._fit_platt_scaling(scores_array, labels_array)
        elif self.method == "isotonic":
            self._fit_isotonic_regression(scores_array, labels_array)
        else:
            raise ValueError(f"Unknown calibration method: {self.method}")

        self.is_fitted = True
        logger.info("Calibrator fitted", method=self.method)

    def _fit_temperature_scaling(self, scores: np.ndarray, labels: np.ndarray) -> None:
        """Fit temperature scaling parameter"""
        # pylint: disable=import-outside-toplevel
        from scipy.optimize import minimize_scalar

        def negative_log_likelihood(temperature: float) -> float:
            # Convert scores to logits
            epsilon = 1e-7
            scores_clipped = np.clip(scores.flatten(), epsilon, 1 - epsilon)
            logits = np.log(scores_clipped / (1 - scores_clipped))

            # Apply temperature scaling
            calibrated_logits = logits / temperature
            calibrated_probs = 1 / (1 + np.exp(-calibrated_logits))

            # Calculate negative log likelihood
            nll = -np.mean(
                labels * np.log(calibrated_probs + epsilon)
                + (1 - labels) * np.log(1 - calibrated_probs + epsilon)
            )
            return float(nll)

        # Find optimal temperature
        result = minimize_scalar(  # type: ignore
            negative_log_likelihood,
            bounds=(0.1, 10.0),
            method="bounded",  # fmt: skip # pyright: ignore[reportArgumentType]
        )
        self.temperature = result.x

        logger.debug("Temperature scaling fitted", temperature=self.temperature)

    def _fit_platt_scaling(self, scores: np.ndarray, labels: np.ndarray) -> None:
        """Fit Platt scaling (logistic regression)"""
        # Convert scores to logits
        epsilon = 1e-7
        scores_clipped = np.clip(scores.flatten(), epsilon, 1 - epsilon)
        logits = np.log(scores_clipped / (1 - scores_clipped)).reshape(-1, 1)

        # Fit logistic regression
        self.calibrator = LogisticRegression()
        self.calibrator.fit(logits, labels)  # type: ignore

        logger.debug("Platt scaling fitted")

    def _fit_isotonic_regression(self, scores: np.ndarray, labels: np.ndarray) -> None:
        """Fit isotonic regression"""
        self.calibrator = IsotonicRegression(out_of_bounds="clip")
        self.calibrator.fit(scores.flatten(), labels)  # type: ignore

        logger.debug("Isotonic regression fitted")

    def calibrate(self, scores: list[float]) -> list[float]:
        """
        Calibrate confidence scores

        Args:
            scores: Raw confidence scores

        Returns:
            Calibrated confidence scores
        """
        if not self.is_fitted:
            logger.warning("Calibrator not fitted, returning original scores")
            return scores

        scores_array = np.array(scores)

        if self.method == "temperature":
            return self._calibrate_temperature(scores_array)
        if self.method == "platt":
            return self._calibrate_platt(scores_array)
        if self.method == "isotonic":
            return self._calibrate_isotonic(scores_array)
        return scores

    def _calibrate_temperature(self, scores: np.ndarray) -> list[float]:
        """Apply temperature scaling"""
        epsilon = 1e-7
        scores_clipped = np.clip(scores, epsilon, 1 - epsilon)

        # Convert to logits
        logits = np.log(scores_clipped / (1 - scores_clipped))

        # Apply temperature scaling
        calibrated_logits = logits / self.temperature
        calibrated_probs = 1 / (1 + np.exp(-calibrated_logits))

        return calibrated_probs.tolist()  # type: ignore

    def _calibrate_platt(self, scores: np.ndarray) -> list[float]:
        """Apply Platt scaling"""
        epsilon = 1e-7
        scores_clipped = np.clip(scores, epsilon, 1 - epsilon)

        # Convert to logits
        logits = np.log(scores_clipped / (1 - scores_clipped)).reshape(-1, 1)

        # Apply Platt scaling
        calibrated_probs = self.calibrator.predict_proba(logits)[:, 1]  # type: ignore

        return calibrated_probs.tolist()  # type: ignore

    def _calibrate_isotonic(self, scores: np.ndarray) -> list[float]:
        """Apply isotonic regression"""
        calibrated_probs = self.calibrator.predict(scores)  # type: ignore
        return calibrated_probs.tolist()  # type: ignore

    def save_model(self, filepath: str) -> None:
        """Save calibration model"""
        model_data = {
            "method": self.method,
            "temperature": self.temperature,
            "calibrator": self.calibrator,
            "is_fitted": self.is_fitted,
        }

        with open(filepath, "wb") as f:
            pickle.dump(model_data, f)

        logger.info("Calibration model saved", filepath=filepath)

    def load_model(self, filepath: str) -> None:
        """Load calibration model"""
        with open(filepath, "rb") as f:
            model_data = pickle.load(f)

        self.method = model_data["method"]
        self.temperature = model_data["temperature"]
        self.calibrator = model_data["calibrator"]
        self.is_fitted = model_data["is_fitted"]

        logger.info("Calibration model loaded", filepath=filepath, method=self.method)
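A short usage sketch for `ConfidenceCalibrator`; the hold-out scores, labels, and file path below are made-up illustration data.

```python
from libs.calibration import ConfidenceCalibrator

# Raw model confidences and whether each prediction was actually correct (hold-out data).
raw_scores = [0.95, 0.90, 0.80, 0.70, 0.60, 0.55]
was_correct = [True, True, False, True, False, False]

calibrator = ConfidenceCalibrator(method="temperature")
calibrator.fit(raw_scores, was_correct)

# Over-confident scores are pulled toward the observed accuracy.
print(calibrator.calibrate([0.95, 0.60]))

calibrator.save_model("/tmp/ocr_confidence_calibrator.pkl")  # pickle round-trip
```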
144  libs/calibration/metrics.py  Normal file
@@ -0,0 +1,144 @@
"""Calibration metrics for evaluating confidence calibration."""

import numpy as np


class ConfidenceMetrics:
    """Calculate calibration metrics"""

    @staticmethod
    def expected_calibration_error(
        scores: list[float], labels: list[bool], n_bins: int = 10
    ) -> float:
        """
        Calculate Expected Calibration Error (ECE)

        Args:
            scores: Predicted confidence scores
            labels: True labels
            n_bins: Number of bins for calibration

        Returns:
            ECE value
        """
        scores_array = np.array(scores)
        labels_array = np.array(labels, dtype=int)

        bin_boundaries = np.linspace(0, 1, n_bins + 1)
        bin_lowers = bin_boundaries[:-1]
        bin_uppers = bin_boundaries[1:]

        ece = 0

        for bin_lower, bin_upper in zip(bin_lowers, bin_uppers, strict=False):
            # Find samples in this bin
            in_bin = (scores_array > bin_lower) & (scores_array <= bin_upper)
            prop_in_bin = in_bin.mean()

            if prop_in_bin > 0:
                # Calculate accuracy and confidence in this bin
                accuracy_in_bin = labels_array[in_bin].mean()
                avg_confidence_in_bin = scores_array[in_bin].mean()

                # Add to ECE
                ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece

    @staticmethod
    def maximum_calibration_error(
        scores: list[float], labels: list[bool], n_bins: int = 10
    ) -> float:
        """
        Calculate Maximum Calibration Error (MCE)

        Args:
            scores: Predicted confidence scores
            labels: True labels
            n_bins: Number of bins for calibration

        Returns:
            MCE value
        """
        scores_array = np.array(scores)
        labels_array = np.array(labels, dtype=int)

        bin_boundaries = np.linspace(0, 1, n_bins + 1)
        bin_lowers = bin_boundaries[:-1]
        bin_uppers = bin_boundaries[1:]

        max_error = 0

        for bin_lower, bin_upper in zip(bin_lowers, bin_uppers, strict=False):
            # Find samples in this bin
            in_bin = (scores_array > bin_lower) & (scores_array <= bin_upper)

            if in_bin.sum() > 0:
                # Calculate accuracy and confidence in this bin
                accuracy_in_bin = labels_array[in_bin].mean()
                avg_confidence_in_bin = scores_array[in_bin].mean()

                # Update maximum error
                error = np.abs(avg_confidence_in_bin - accuracy_in_bin)
                max_error = max(max_error, error)

        return max_error

    @staticmethod
    def reliability_diagram_data(  # pylint: disable=too-many-locals
        scores: list[float], labels: list[bool], n_bins: int = 10
    ) -> dict[str, list[float]]:
        """
        Generate data for reliability diagram

        Args:
            scores: Predicted confidence scores
            labels: True labels
            n_bins: Number of bins

        Returns:
            Dictionary with bin data for plotting
        """
        scores_array = np.array(scores)
        labels_array = np.array(labels, dtype=int)

        bin_boundaries = np.linspace(0, 1, n_bins + 1)
        bin_lowers = bin_boundaries[:-1]
        bin_uppers = bin_boundaries[1:]

        bin_centers = []
        bin_accuracies = []
        bin_confidences = []
        bin_counts = []

        for bin_lower, bin_upper in zip(bin_lowers, bin_uppers, strict=False):
            # Find samples in this bin
            in_bin = (scores_array > bin_lower) & (scores_array <= bin_upper)
            bin_count = in_bin.sum()

            if bin_count > 0:
                bin_center = (bin_lower + bin_upper) / 2
                accuracy_in_bin = labels_array[in_bin].mean()
                avg_confidence_in_bin = scores_array[in_bin].mean()

                bin_centers.append(bin_center)
                bin_accuracies.append(accuracy_in_bin)
                bin_confidences.append(avg_confidence_in_bin)
                bin_counts.append(bin_count)

        return {
            "bin_centers": bin_centers,
            "bin_accuracies": bin_accuracies,
            "bin_confidences": bin_confidences,
            "bin_counts": bin_counts,
        }


# Default calibrators for common tasks
DEFAULT_CALIBRATORS = {
    "ocr_confidence": {"method": "temperature"},
    "extraction_confidence": {"method": "platt"},
    "rag_confidence": {"method": "isotonic"},
    "calculation_confidence": {"method": "temperature"},
    "overall_confidence": {"method": "platt"},
}
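A small worked example of ECE with made-up numbers: with ten bins, the four scores below fall into two bins, and ECE is the sample-weighted average of each bin's |confidence − accuracy| gap.

```python
from libs.calibration import ConfidenceMetrics

scores = [0.9, 0.9, 0.6, 0.6]
labels = [True, False, True, False]

# Bin (0.8, 0.9]: mean confidence 0.9, accuracy 0.5 -> gap 0.4, holds half the samples
# Bin (0.5, 0.6]: mean confidence 0.6, accuracy 0.5 -> gap 0.1, holds the other half
# ECE = 0.5 * 0.4 + 0.5 * 0.1 = 0.25
print(ConfidenceMetrics.expected_calibration_error(scores, labels, n_bins=10))
```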
85  libs/calibration/multi_model.py  Normal file
@@ -0,0 +1,85 @@
"""Multi-model calibrator for handling multiple models/tasks."""

import glob
import os

import structlog

from .calibrator import ConfidenceCalibrator

logger = structlog.get_logger()


class MultiModelCalibrator:
    """Calibrate confidence scores for multiple models/tasks"""

    def __init__(self) -> None:
        self.calibrators: dict[str, ConfidenceCalibrator] = {}

    def add_calibrator(self, model_name: str, method: str = "temperature") -> None:
        """Add calibrator for a specific model"""
        self.calibrators[model_name] = ConfidenceCalibrator(method)
        logger.info("Added calibrator", model=model_name, method=method)

    def fit(self, model_name: str, scores: list[float], labels: list[bool]) -> None:
        """Fit calibrator for specific model"""
        if model_name not in self.calibrators:
            self.add_calibrator(model_name)

        self.calibrators[model_name].fit(scores, labels)

    def calibrate(self, model_name: str, scores: list[float]) -> list[float]:
        """Calibrate scores for specific model"""
        if model_name not in self.calibrators:
            logger.warning("No calibrator for model", model=model_name)
            return scores

        return self.calibrators[model_name].calibrate(scores)

    def save_all(self, directory: str) -> None:
        """Save all calibrators"""
        os.makedirs(directory, exist_ok=True)

        for model_name, calibrator in self.calibrators.items():
            filepath = os.path.join(directory, f"{model_name}_calibrator.pkl")
            calibrator.save_model(filepath)

    def load_all(self, directory: str) -> None:
        """Load all calibrators from directory"""
        pattern = os.path.join(directory, "*_calibrator.pkl")
        for filepath in glob.glob(pattern):
            filename = os.path.basename(filepath)
            model_name = filename.replace("_calibrator.pkl", "")

            calibrator = ConfidenceCalibrator()
            calibrator.load_model(filepath)
            self.calibrators[model_name] = calibrator

    def save_models(self, directory: str) -> None:
        """Save all calibrators (alias for save_all)"""
        self.save_all(directory)

    def load_models(self, directory: str) -> None:
        """Load all calibrators from directory (alias for load_all)"""
        self.load_all(directory)

    def get_model_names(self) -> list[str]:
        """Get list of model names"""
        return list(self.calibrators.keys())

    def has_model(self, model_name: str) -> bool:
        """Check if model exists"""
        return model_name in self.calibrators

    def is_fitted(self, model_name: str) -> bool:
        """Check if model is fitted"""
        if model_name not in self.calibrators:
            raise ValueError(f"Model '{model_name}' not found")
        return self.calibrators[model_name].is_fitted

    def remove_calibrator(self, model_name: str) -> None:
        """Remove calibrator for specific model"""
        if model_name not in self.calibrators:
            raise ValueError(f"Model '{model_name}' not found")
        del self.calibrators[model_name]
        logger.info("Removed calibrator", model=model_name)
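A short sketch wiring `MultiModelCalibrator` to `DEFAULT_CALIBRATORS`; the fitting data and directory path are placeholders.

```python
from libs.calibration import DEFAULT_CALIBRATORS, MultiModelCalibrator

multi = MultiModelCalibrator()
for task, cfg in DEFAULT_CALIBRATORS.items():
    multi.add_calibrator(task, method=cfg["method"])

# Per-task hold-out data would come from evaluation runs; these lists are placeholders.
multi.fit("ocr_confidence", scores=[0.9, 0.8, 0.7, 0.6], labels=[True, True, False, False])

print(multi.calibrate("ocr_confidence", [0.85]))
multi.save_all("/tmp/calibrators")  # writes one *_calibrator.pkl per task
```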
555  libs/config.py  Normal file
@@ -0,0 +1,555 @@
# ROLE

You are a **Senior Platform Engineer + Backend Lead** generating **production code** and **ops assets** for a microservice suite that powers an accounting Knowledge Graph + Vector RAG platform. Authentication/authorization are centralized at the **edge via Traefik + Authentik** (ForwardAuth). **Services are trust-bound** to Traefik and consume user/role claims via forwarded headers/JWT.

# MISSION

Produce fully working code for **all application services** (FastAPI + Python 3.12) with:

- Solid domain models, Pydantic v2 schemas, type hints, strict mypy, ruff lint.
- OpenTelemetry tracing, Prometheus metrics, structured logging.
- Vault-backed secrets, MinIO S3 client, Qdrant client, Neo4j driver, Postgres (SQLAlchemy), Redis.
- Eventing (Kafka or SQS/SNS behind an interface).
- Deterministic data contracts, end-to-end tests, Dockerfiles, Compose, CI for Gitea.
- Traefik labels + Authentik Outpost integration for every exposed route.
- Zero PII in vectors (Qdrant), evidence-based lineage in KG, and bitemporal writes.

# GLOBAL CONSTRAINTS (APPLY TO ALL SERVICES)

- **Language & Runtime:** Python **3.12**.
- **Frameworks:** FastAPI, Pydantic v2, SQLAlchemy 2, httpx, aiokafka or boto3 (pluggable), redis-py, opentelemetry-instrumentation-fastapi, prometheus-fastapi-instrumentator.
- **Config:** `pydantic-settings` with `.env` overlay. Provide a `Settings` class per service.
- **Secrets:** HashiCorp **Vault** (AppRole/JWT). Use Vault Transit to **envelope-encrypt** sensitive fields before persistence (helpers provided in `libs/security.py`).
- **Auth:** No OIDC in services. Add `TrustedProxyMiddleware` (see the sketch after this list):

  - Reject if request not from internal network (configurable CIDR).
  - Require headers set by Traefik+Authentik (`X-Authenticated-User`, `X-Authenticated-Email`, `X-Authenticated-Groups`, `Authorization: Bearer …`).
  - Parse groups → `roles` list on `request.state`.

- **Observability:**

  - OpenTelemetry (traceparent propagation), span attrs (service, route, user, tenant).
  - Prometheus metrics endpoint `/metrics` protected by internal network check.
  - Structured JSON logs (timestamp, level, svc, trace_id, msg) via `structlog`.

- **Errors:** Global exception handler → RFC7807 Problem+JSON (`type`, `title`, `status`, `detail`, `instance`, `trace_id`).
- **Testing:** `pytest`, `pytest-asyncio`, `hypothesis` (property tests for calculators), `coverage ≥ 90%` per service.
- **Static:** `ruff`, `mypy --strict`, `bandit`, `safety`, `licensecheck`.
- **Perf:** Each service exposes `/healthz`, `/readyz`, `/livez`; cold start < 500ms; p95 endpoint < 250ms (local).
- **Containers:** Distroless or slim images; non-root user; read-only FS; `/tmp` mounted for OCR where needed.
- **Docs:** OpenAPI JSON + ReDoc; MkDocs site with service READMEs.
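A minimal sketch of the `TrustedProxyMiddleware` described in the Auth bullet above, assuming Starlette's `BaseHTTPMiddleware`. The exempt-path handling is simplified (no separate CIDR check for `/metrics`), and the class shape is illustrative rather than the definitive `libs/security/middleware.py` implementation.

```python
import ipaddress

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

EXEMPT_PATHS = {"/healthz", "/readyz", "/livez", "/metrics"}
REQUIRED_HEADERS = ("X-Authenticated-User", "X-Authenticated-Email", "X-Authenticated-Groups")


class TrustedProxyMiddleware(BaseHTTPMiddleware):
    """Reject traffic that does not arrive from the internal network with Authentik headers."""

    def __init__(self, app, internal_cidrs: list[str], disable_auth: bool = False) -> None:
        super().__init__(app)
        self.networks = [ipaddress.ip_network(cidr) for cidr in internal_cidrs]
        self.disable_auth = disable_auth

    async def dispatch(self, request: Request, call_next) -> Response:
        if self.disable_auth or request.url.path in EXEMPT_PATHS:
            return await call_next(request)

        client_ip = ipaddress.ip_address(request.client.host if request.client else "0.0.0.0")
        if not any(client_ip in net for net in self.networks):
            return JSONResponse(status_code=403, content={"title": "Forbidden", "status": 403})

        if any(h not in request.headers for h in REQUIRED_HEADERS):
            return JSONResponse(status_code=401, content={"title": "Unauthorized", "status": 401})

        # Expose identity and roles to route handlers via request.state
        request.state.user = request.headers["X-Authenticated-User"]
        request.state.roles = [
            g.strip() for g in request.headers["X-Authenticated-Groups"].split(",") if g.strip()
        ]
        return await call_next(request)
```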
# SHARED LIBS (GENERATE ONCE, REUSE)

Create `libs/` used by all services:

- `libs/config.py` – base `Settings`, env parsing, Vault client factory, MinIO client factory, Qdrant client factory, Neo4j driver factory, Redis factory, Kafka/SQS client factory.
- `libs/security.py` – Vault Transit helpers (`encrypt_field`, `decrypt_field`), header parsing, internal-CIDR validator.
- `libs/observability.py` – otel init, prometheus instrumentor, logging config.
- `libs/events.py` – abstract `EventBus` with `publish(topic, payload: dict)`, `subscribe(topic, handler)`. Two impls: Kafka (`aiokafka`) and SQS/SNS (`boto3`).
- `libs/schemas.py` – **canonical Pydantic models** shared across services (Document, Evidence, IncomeItem, etc.) mirroring the ontology schemas. Include JSONSchema exports.
- `libs/storage.py` – S3/MinIO helpers (bucket ensure, put/get, presigned).
- `libs/neo.py` – Neo4j session helpers, Cypher runner with retry, SHACL validator invoker (pySHACL on exported RDF).
- `libs/rag.py` – Qdrant collections CRUD, hybrid search (dense+sparse), rerank wrapper, de-identification utilities (regex + NER; hash placeholders).
- `libs/forms.py` – PDF AcroForm fill via `pdfrw` with overlay fallback via `reportlab`.
- `libs/calibration.py` – `calibrated_confidence(raw_score, method="temperature_scaling", params=...)`.

# EVENT TOPICS (STANDARDIZE)

- `doc.ingested`, `doc.ocr_ready`, `doc.extracted`, `kg.upserted`, `rag.indexed`, `calc.schedule_ready`, `form.filled`, `hmrc.submitted`, `review.requested`, `review.completed`, `firm.sync.completed`

Each payload MUST include: `event_id (ulid)`, `occurred_at (iso)`, `actor`, `tenant_id`, `trace_id`, `schema_version`, and a `data` object (service-specific).
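A minimal Pydantic v2 sketch of that envelope; the model name `EventEnvelope` and its defaults are illustrative, not part of the spec above.

```python
from datetime import datetime, timezone
from typing import Any

from pydantic import BaseModel, Field


class EventEnvelope(BaseModel):
    """Common envelope carried by every topic; `data` stays service-specific."""

    event_id: str  # ULID, generated by the publisher
    occurred_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    actor: str  # e.g. "svc-ingestion"
    tenant_id: str
    trace_id: str
    schema_version: str = "1.0"
    data: dict[str, Any]


# Publishing then becomes: await bus.publish("doc.ingested", envelope.model_dump(mode="json"))
```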
# TRUST HEADERS FROM TRAEFIK + AUTHENTIK (USE EXACT KEYS)

- `X-Authenticated-User` (string)
- `X-Authenticated-Email` (string)
- `X-Authenticated-Groups` (comma-separated)
- `Authorization` (`Bearer <jwt>` from Authentik)

Reject any request missing these (except `/healthz|/readyz|/livez|/metrics` from internal CIDR).
---
## SERVICES TO IMPLEMENT (CODE FOR EACH)

### 1) `svc-ingestion`

**Purpose:** Accept uploads or URLs, checksum, store to MinIO, emit `doc.ingested`.

**Endpoints:**

- `POST /v1/ingest/upload` (multipart file, metadata: `tenant_id`, `kind`, `source`) → `{doc_id, s3_url, checksum}`
- `POST /v1/ingest/url` (json: `{url, kind, tenant_id}`) → downloads to MinIO
- `GET /v1/docs/{doc_id}` → metadata

**Logic** (a checksum/dedupe sketch follows this section):

- Compute SHA256, dedupe by checksum; MinIO path `tenants/{tenant_id}/raw/{doc_id}.pdf`.
- Store metadata in Postgres table `ingest_documents` (alembic migrations).
- Publish `doc.ingested` with `{doc_id, bucket, key, pages?, mime}`.

**Env:** `S3_BUCKET_RAW`, `MINIO_*`, `DB_URL`.

**Traefik labels:** route `/ingest/*`.
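A minimal sketch of the checksum/dedupe step above, assuming a repository object with lookup/ID helpers and a storage helper are available; those interfaces (`repo`, `storage`) are hypothetical names used only for illustration.

```python
import hashlib


def sha256_checksum(payload: bytes) -> str:
    """Return the checksum in the `sha256:<hex>` form used by the doc.ingested event."""
    return f"sha256:{hashlib.sha256(payload).hexdigest()}"


def raw_object_key(tenant_id: str, doc_id: str) -> str:
    """Build the MinIO key for the raw upload."""
    return f"tenants/{tenant_id}/raw/{doc_id}.pdf"


async def ingest_bytes(payload: bytes, tenant_id: str, repo, storage) -> dict[str, str]:
    """Dedupe by checksum before writing to MinIO (repo/storage are illustrative interfaces)."""
    checksum = sha256_checksum(payload)
    existing = await repo.find_by_checksum(tenant_id, checksum)  # hypothetical repository call
    if existing:
        return {"doc_id": existing.doc_id, "checksum": checksum}
    doc_id = await repo.new_doc_id()  # hypothetical ULID generator
    key = raw_object_key(tenant_id, doc_id)
    await storage.put_object(bucket="raw", key=key, data=payload)  # hypothetical storage helper
    return {"doc_id": doc_id, "checksum": checksum, "key": key}
```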
---
### 2) `svc-rpa`

**Purpose:** Scheduled RPA pulls from firm/client portals via Playwright.

**Tasks:**

- Playwright login flows (credentials from Vault), 2FA via Authentik OAuth device or OTP secret in Vault.
- Download statements/invoices; hand off to `svc-ingestion` via internal POST.
- Prefect flows: `pull_portal_X()`, `pull_portal_Y()` with schedules.

**Endpoints:**

- `POST /v1/rpa/run/{connector}` (manual trigger)
- `GET /v1/rpa/status/{run_id}`

**Env:** `VAULT_ADDR`, `VAULT_ROLE_ID`, `VAULT_SECRET_ID`.

---

### 3) `svc-ocr`

**Purpose:** OCR & layout extraction.

**Pipeline:**

- Pull object from MinIO, detect rotation/de-skew (`opencv-python`), split pages (`pymupdf`), OCR (`pytesseract`) or bypass if text layer present (`pdfplumber`).
- Output per-page text + **bbox** for lines/words.
- Write JSON to MinIO `tenants/{tenant_id}/ocr/{doc_id}.json` and emit `doc.ocr_ready`.

**Endpoints:**

- `POST /v1/ocr/{doc_id}` (idempotent trigger)
- `GET /v1/ocr/{doc_id}` (fetch OCR JSON)

**Env:** `TESSERACT_LANGS`, `S3_BUCKET_EVIDENCE`.

---

### 4) `svc-extract`

**Purpose:** Classify docs and extract KV + tables into **schema-constrained JSON** (with bbox/page).

**Endpoints:**

- `POST /v1/extract/{doc_id}` body: `{strategy: "llm|rules|hybrid"}`
- `GET /v1/extract/{doc_id}` → structured JSON

**Implementation:**

- Use prompt files in `prompts/`: `doc_classify.txt`, `kv_extract.txt`, `table_extract.txt`.
- **Validator loop**: run LLM → validate JSONSchema → retry with error messages up to N times (a sketch follows this section).
- Return Pydantic models from `libs/schemas.py`.
- Emit `doc.extracted`.

**Env:** `LLM_ENGINE`, `TEMPERATURE`, `MAX_TOKENS`.
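A sketch of that validator loop, assuming an async `call_llm` callable and the `jsonschema` package; the names and retry policy are illustrative.

```python
import json

from jsonschema import Draft202012Validator  # JSON Schema validation


async def extract_with_validator_loop(
    call_llm,  # hypothetical async callable: (prompt: str) -> str
    prompt: str,
    schema: dict,
    max_retries: int = 3,
) -> dict:
    """Run the LLM, validate against the JSON Schema, and feed errors back on retry."""
    validator = Draft202012Validator(schema)
    last_errors: list[str] = []
    for _ in range(max_retries):
        feedback = f"\n\nFix these validation errors: {last_errors}" if last_errors else ""
        raw = await call_llm(prompt + feedback)
        try:
            candidate = json.loads(raw)
        except json.JSONDecodeError as exc:
            last_errors = [f"invalid JSON: {exc}"]
            continue
        errors = [e.message for e in validator.iter_errors(candidate)]
        if not errors:
            return candidate
        last_errors = errors
    raise ValueError(f"Extraction failed schema validation after {max_retries} attempts: {last_errors}")
```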
---
### 5) `svc-normalize-map`

**Purpose:** Normalize & map extracted data to KG.

**Logic:**

- Currency normalization (ECB or static fx table), dates, UK tax year/basis period inference.
- Entity resolution (blocking + fuzzy).
- Generate nodes/edges (+ `Evidence` with doc_id/page/bbox/text_hash).
- Use `libs/neo.py` to write with **bitemporal** fields; run **SHACL** validator; on violation, queue `review.requested`.
- Emit `kg.upserted`.

**Endpoints:**

- `POST /v1/map/{doc_id}`
- `GET /v1/map/{doc_id}/preview` (diff view, to be used by UI)

**Env:** `NEO4J_*`.

---

### 6) `svc-kg`

**Purpose:** Graph façade + RDF/SHACL utility.

**Endpoints:**

- `GET /v1/kg/nodes/{label}/{id}`
- `POST /v1/kg/cypher` (admin-gated inline query; must check `admin` role)
- `POST /v1/kg/export/rdf` (returns RDF for SHACL)
- `POST /v1/kg/validate` (run pySHACL against `schemas/shapes.ttl`)
- `GET /v1/kg/lineage/{node_id}` (traverse `DERIVED_FROM` → Evidence)

**Env:** `NEO4J_*`.

---

### 7) `svc-rag-indexer`

**Purpose:** Build Qdrant indices (firm knowledge, legislation, best practices, glossary).

**Workflow:**

- Load sources (filesystem, URLs, Firm DMS via `svc-firm-connectors`).
- **De-identify PII** (regex + NER), replace with placeholders; store mapping only in Postgres (a sketch follows this section).
- Chunk (layout-aware) per `retrieval/chunking.yaml`.
- Compute **dense** embeddings (e.g., `bge-small-en-v1.5`) and **sparse** (Qdrant sparse).
- Upsert to Qdrant with payload `{jurisdiction, tax_years[], topic_tags[], version, pii_free: true, doc_id/section_id/url}`.
- Emit `rag.indexed`.

**Endpoints:**

- `POST /v1/index/run`
- `GET /v1/index/status/{run_id}`

**Env:** `QDRANT_URL`, `RAG_EMBEDDING_MODEL`, `RAG_RERANKER_MODEL`.
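A simplified de-identification sketch for the workflow above. The regexes are deliberately loose illustrations (a real pipeline would add an NER pass and validated patterns), and only hashes, never raw values, are kept for the Postgres mapping.

```python
import hashlib
import re

# Simplified patterns for illustration only.
PII_PATTERNS = {
    "EMAIL": re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+"),
    "NINO": re.compile(r"\b[A-Z]{2}\s?\d{2}\s?\d{2}\s?\d{2}\s?[A-D]\b"),  # UK NI number (loose)
    "UTR": re.compile(r"\b\d{10}\b"),  # UK Unique Taxpayer Reference (loose)
}


def deidentify(text: str) -> tuple[str, dict[str, str]]:
    """Replace PII with stable hash placeholders; return the text and placeholder->hash map."""
    mapping: dict[str, str] = {}

    def _replace(kind: str):
        def inner(match: re.Match[str]) -> str:
            digest = hashlib.sha256(match.group(0).encode()).hexdigest()[:12]
            placeholder = f"[{kind}:{digest}]"
            mapping[placeholder] = digest  # only the hash is stored; the raw value never is
            return placeholder

        return inner

    for kind, pattern in PII_PATTERNS.items():
        text = pattern.sub(_replace(kind), text)
    return text, mapping
```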
---
### 8) `svc-rag-retriever`

**Purpose:** Hybrid search + KG fusion with rerank and calibrated confidence.

**Endpoint:**

- `POST /v1/rag/search` `{query, tax_year?, jurisdiction?, k?}` →

```
{
  "chunks": [...],
  "citations": [{doc_id|url, section_id?, page?, bbox?}],
  "kg_hints": [{rule_id, formula_id, node_ids[]}],
  "calibrated_confidence": 0.0-1.0
}
```

**Implementation:**

- Hybrid score: `alpha * dense + beta * sparse`; rerank top-K via cross-encoder; **KG fusion** (boost chunks citing Rules/Calculations relevant to the schedule); a scoring sketch follows this section.
- Use `libs/calibration.py` to expose calibrated confidence.
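A small sketch of the hybrid scoring and KG-fusion step; the weights (`alpha`, `beta`), the flat `0.1` boost, and the candidate dict shape are placeholder assumptions for illustration.

```python
def hybrid_score(
    dense: float,
    sparse: float,
    kg_boost: float = 0.0,
    alpha: float = 0.7,
    beta: float = 0.3,
) -> float:
    """Blend dense and sparse similarities, then apply an additive KG-fusion boost."""
    return alpha * dense + beta * sparse + kg_boost


def fuse_and_rank(candidates: list[dict], kg_hint_ids: set[str], k: int = 10) -> list[dict]:
    """Score candidates of shape {"id", "dense", "sparse", "cites": set[str]} and keep top-k."""
    for c in candidates:
        boost = 0.1 if c["cites"] & kg_hint_ids else 0.0  # boost chunks citing relevant Rules
        c["score"] = hybrid_score(c["dense"], c["sparse"], kg_boost=boost)
    return sorted(candidates, key=lambda c: c["score"], reverse=True)[:k]
```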
---
### 9) `svc-reason`

**Purpose:** Deterministic calculators + materializers (UK SA).

**Endpoints:**

- `POST /v1/reason/compute_schedule` `{tax_year, taxpayer_id, schedule_id}`
- `GET /v1/reason/explain/{schedule_id}` → rationale & lineage paths

**Implementation:**

- Pure functions for: employment, self-employment, property (FHL, 20% interest credit), dividends/interest, allowances, NIC (Class 2/4), HICBC, student loans (Plans 1/2/4/5, PGL); an allowance-taper sketch follows this section.
- **Deterministic order** as defined; rounding per `FormBox.rounding_rule`.
- Use Cypher from `kg/reasoning/schedule_queries.cypher` to materialize box values; attach `DERIVED_FROM` evidence.
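An illustrative pure-function sketch for one of these calculators, the personal-allowance taper. The £12,570 allowance and £100,000 threshold are the commonly published 2024-25 figures and are passed as parameters precisely because real values must come from versioned rule data, not hard-coded constants; rounding rules are omitted.

```python
from decimal import Decimal


def tapered_personal_allowance(
    adjusted_net_income: Decimal,
    standard_allowance: Decimal = Decimal("12570"),  # assumed 2024-25 figure
    taper_threshold: Decimal = Decimal("100000"),    # assumed 2024-25 figure
) -> Decimal:
    """Reduce the allowance by £1 for every £2 of income above the taper threshold."""
    if adjusted_net_income <= taper_threshold:
        return standard_allowance
    reduction = (adjusted_net_income - taper_threshold) / 2
    return max(standard_allowance - reduction, Decimal("0"))
```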
---
### 10) `svc-forms`

**Purpose:** Fill PDFs and assemble evidence bundles.

**Endpoints:**

- `POST /v1/forms/fill` `{tax_year, taxpayer_id, form_id}` → returns PDF (binary)
- `POST /v1/forms/evidence_pack` `{scope}` → ZIP + manifest + signed hashes (sha256)

**Implementation:**

- `pdfrw` for AcroForm; overlay with ReportLab if needed.
- Manifest includes `doc_id/page/bbox/text_hash` for every numeric field.

---

### 11) `svc-hmrc`

**Purpose:** HMRC submitter (stub|sandbox|live).

**Endpoints:**

- `POST /v1/hmrc/submit` `{tax_year, taxpayer_id, dry_run}` → `{status, submission_id?, errors[]}`
- `GET /v1/hmrc/submissions/{id}`

**Implementation:**

- Rate limits, retries/backoff, signed audit log; environment toggle.

---

### 12) `svc-firm-connectors`

**Purpose:** Read-only connectors to Firm Databases (Practice Mgmt, DMS).

**Endpoints:**

- `POST /v1/firm/sync` `{since?}` → `{objects_synced, errors[]}`
- `GET /v1/firm/objects` (paged)

**Implementation:**

- Data contracts in `config/firm_contracts/`; mappers → Secure Client Data Store (Postgres) with lineage columns (`source`, `source_id`, `synced_at`).

---

### 13) `ui-review` (outline only)

- Next.js (SSO handled by Traefik+Authentik), shows extracted fields + evidence snippets; POST overrides to `svc-extract`/`svc-normalize-map`.

---
## DATA CONTRACTS (ESSENTIAL EXAMPLES)
**Event: `doc.ingested`**

```json
{
  "event_id": "01J...ULID",
  "occurred_at": "2025-09-13T08:00:00Z",
  "actor": "svc-ingestion",
  "tenant_id": "t_123",
  "trace_id": "abc-123",
  "schema_version": "1.0",
  "data": {
    "doc_id": "d_abc",
    "bucket": "raw",
    "key": "tenants/t_123/raw/d_abc.pdf",
    "checksum": "sha256:...",
    "kind": "bank_statement",
    "mime": "application/pdf",
    "pages": 12
  }
}
```

**RAG search response shape**

```json
{
  "chunks": [
    {
      "id": "c1",
      "text": "...",
      "score": 0.78,
      "payload": {
        "jurisdiction": "UK",
        "tax_years": ["2024-25"],
        "topic_tags": ["FHL"],
        "pii_free": true
      }
    }
  ],
  "citations": [
    { "doc_id": "leg-ITA2007", "section_id": "s272A", "url": "https://..." }
  ],
  "kg_hints": [
    {
      "rule_id": "UK.FHL.Qual",
      "formula_id": "FHL_Test_v1",
      "node_ids": ["n123", "n456"]
    }
  ],
  "calibrated_confidence": 0.81
}
```

---
## PERSISTENCE SCHEMAS (POSTGRES; ALEMBIC)
- `ingest_documents(id pk, tenant_id, doc_id, kind, checksum, bucket, key, mime, pages, created_at)`
- `firm_objects(id pk, tenant_id, source, source_id, type, payload jsonb, synced_at)`
- Qdrant PII mapping table (if absolutely needed): `pii_links(id pk, placeholder_hash, client_id, created_at)` — **encrypt with Vault Transit**; do NOT store raw values.

---

## TRAEFIK + AUTHENTIK (COMPOSE LABELS PER SERVICE)

For every service container in `infra/compose/docker-compose.local.yml`, add labels:

```
- "traefik.enable=true"
- "traefik.http.routers.svc-extract.rule=Host(`api.local`) && PathPrefix(`/extract`)"
- "traefik.http.routers.svc-extract.entrypoints=websecure"
- "traefik.http.routers.svc-extract.tls=true"
- "traefik.http.routers.svc-extract.middlewares=authentik-forwardauth,rate-limit"
- "traefik.http.services.svc-extract.loadbalancer.server.port=8000"
```

Use the shared dynamic file `traefik-dynamic.yml` with `authentik-forwardauth` and `rate-limit` middlewares.

---

## OUTPUT FORMAT (STRICT)

Implement a **multi-file codebase** as fenced blocks, EXACTLY in this order:

```txt
# FILE: libs/config.py
# factories for Vault/MinIO/Qdrant/Neo4j/Redis/EventBus, Settings base
...
```

```txt
# FILE: libs/security.py
# Vault Transit helpers, header parsing, internal CIDR checks, middleware
...
```

```txt
# FILE: libs/observability.py
# otel init, prometheus, structlog
...
```

```txt
# FILE: libs/events.py
# EventBus abstraction with Kafka and SQS/SNS impls
...
```

```txt
# FILE: libs/schemas.py
# Shared Pydantic models mirroring ontology entities
...
```

```txt
# FILE: apps/svc-ingestion/main.py
# FastAPI app, endpoints, MinIO write, Postgres, publish doc.ingested
...
```

```txt
# FILE: apps/svc-rpa/main.py
# Playwright flows, Prefect tasks, triggers
...
```

```txt
# FILE: apps/svc-ocr/main.py
# OCR pipeline, endpoints
...
```

```txt
# FILE: apps/svc-extract/main.py
# Classifier + extractors with validator loop
...
```

```txt
# FILE: apps/svc-normalize-map/main.py
# normalization, entity resolution, KG mapping, SHACL validation call
...
```

```txt
# FILE: apps/svc-kg/main.py
# KG façade, RDF export, SHACL validate, lineage traversal
...
```

```txt
# FILE: apps/svc-rag-indexer/main.py
# chunk/de-id/embed/upsert to Qdrant
...
```

```txt
# FILE: apps/svc-rag-retriever/main.py
# hybrid retrieval + rerank + KG fusion
...
```

```txt
# FILE: apps/svc-reason/main.py
# deterministic calculators, schedule compute/explain
...
```

```txt
# FILE: apps/svc-forms/main.py
# PDF fill + evidence pack
...
```

```txt
# FILE: apps/svc-hmrc/main.py
# submit stub|sandbox|live with audit + retries
...
```

```txt
# FILE: apps/svc-firm-connectors/main.py
# connectors to practice mgmt & DMS, sync to Postgres
...
```

```txt
# FILE: infra/compose/docker-compose.local.yml
# Traefik, Authentik, Vault, MinIO, Qdrant, Neo4j, Postgres, Redis, Prom+Grafana, Loki, Unleash, all services
...
```

```txt
# FILE: infra/compose/traefik.yml
# static Traefik config
...
```

```txt
# FILE: infra/compose/traefik-dynamic.yml
# forwardAuth middleware + routers/services
...
```

```txt
# FILE: .gitea/workflows/ci.yml
# lint->test->build->scan->push->deploy
...
```

```txt
# FILE: Makefile
# bootstrap, run, test, lint, build, deploy, format, seed
...
```

```txt
# FILE: tests/e2e/test_happy_path.py
# end-to-end: ingest -> ocr -> extract -> map -> compute -> fill -> (stub) submit
...
```

```txt
# FILE: tests/unit/test_calculators.py
# boundary tests for UK SA logic (NIC, HICBC, PA taper, FHL)
...
```

```txt
# FILE: README.md
# how to run locally with docker-compose, Authentik setup, Traefik certs
...
```

## DEFINITION OF DONE

- `docker compose up` brings the full stack up; SSO via Authentik; routes secured via Traefik ForwardAuth.
- Running `pytest` yields ≥ 90% coverage; `make e2e` passes the ingest→…→submit stub flow.
- All services expose `/healthz|/readyz|/livez|/metrics`; OpenAPI at `/docs`.
- No PII stored in Qdrant; vectors carry `pii_free=true`.
- KG writes are SHACL-validated; violations produce `review.requested` events.
- Evidence lineage is present for every numeric box value.
- Gitea pipeline passes: lint, test, build, scan, push, deploy.

# START

Generate the full codebase and configs in the **exact file blocks and order** specified above.
41  libs/config/__init__.py  Normal file
@@ -0,0 +1,41 @@
"""Configuration management and client factories."""

from .factories import (
    EventBusFactory,
    MinIOClientFactory,
    Neo4jDriverFactory,
    QdrantClientFactory,
    RedisClientFactory,
    VaultClientFactory,
)
from .settings import BaseAppSettings
from .utils import (
    create_event_bus,
    create_minio_client,
    create_neo4j_client,
    create_qdrant_client,
    create_redis_client,
    create_vault_client,
    get_default_settings,
    get_settings,
    init_settings,
)

__all__ = [
    "BaseAppSettings",
    "VaultClientFactory",
    "MinIOClientFactory",
    "QdrantClientFactory",
    "Neo4jDriverFactory",
    "RedisClientFactory",
    "EventBusFactory",
    "get_settings",
    "init_settings",
    "create_vault_client",
    "create_minio_client",
    "create_qdrant_client",
    "create_neo4j_client",
    "create_redis_client",
    "create_event_bus",
    "get_default_settings",
]
122  libs/config/factories.py  Normal file
@@ -0,0 +1,122 @@
"""Client factories for various services."""

from typing import Any

import boto3  # type: ignore
import hvac
import redis.asyncio as redis
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer  # type: ignore
from minio import Minio
from neo4j import GraphDatabase
from qdrant_client import QdrantClient

from .settings import BaseAppSettings


class VaultClientFactory:  # pylint: disable=too-few-public-methods
    """Factory for creating Vault clients"""

    @staticmethod
    def create_client(settings: BaseAppSettings) -> hvac.Client:
        """Create authenticated Vault client"""
        client = hvac.Client(url=settings.vault_addr)

        if settings.vault_token:
            # Development mode with token
            client.token = settings.vault_token
        elif settings.vault_role_id and settings.vault_secret_id:
            # Production mode with AppRole
            try:
                auth_response = client.auth.approle.login(
                    role_id=settings.vault_role_id, secret_id=settings.vault_secret_id
                )
                client.token = auth_response["auth"]["client_token"]
            except Exception as e:
                raise ValueError("Failed to authenticate with Vault") from e
        else:
            raise ValueError(
                "Either vault_token or vault_role_id/vault_secret_id must be provided"
            )

        if not client.is_authenticated():
            raise ValueError("Failed to authenticate with Vault")

        return client


class MinIOClientFactory:  # pylint: disable=too-few-public-methods
    """Factory for creating MinIO clients"""

    @staticmethod
    def create_client(settings: BaseAppSettings) -> Minio:
        """Create MinIO client"""
        return Minio(
            endpoint=settings.minio_endpoint,
            access_key=settings.minio_access_key,
            secret_key=settings.minio_secret_key,
            secure=settings.minio_secure,
        )


class QdrantClientFactory:  # pylint: disable=too-few-public-methods
    """Factory for creating Qdrant clients"""

    @staticmethod
    def create_client(settings: BaseAppSettings) -> QdrantClient:
        """Create Qdrant client"""
        return QdrantClient(url=settings.qdrant_url, api_key=settings.qdrant_api_key)


class Neo4jDriverFactory:  # pylint: disable=too-few-public-methods
    """Factory for creating Neo4j drivers"""

    @staticmethod
    def create_driver(settings: BaseAppSettings) -> Any:
        """Create Neo4j driver"""
        return GraphDatabase.driver(
            settings.neo4j_uri, auth=(settings.neo4j_user, settings.neo4j_password)
        )


class RedisClientFactory:  # pylint: disable=too-few-public-methods
    """Factory for creating Redis clients"""

    @staticmethod
    async def create_client(settings: BaseAppSettings) -> "redis.Redis[str]":
        """Create Redis client"""
        return redis.from_url(
            settings.redis_url, encoding="utf-8", decode_responses=True
        )


class EventBusFactory:
    """Factory for creating event bus clients"""

    @staticmethod
    def create_kafka_producer(settings: BaseAppSettings) -> AIOKafkaProducer:
        """Create Kafka producer"""
        return AIOKafkaProducer(
            bootstrap_servers=settings.kafka_bootstrap_servers,
            value_serializer=lambda v: v.encode("utf-8") if isinstance(v, str) else v,
        )

    @staticmethod
    def create_kafka_consumer(
        settings: BaseAppSettings, topics: list[str]
    ) -> AIOKafkaConsumer:
        """Create Kafka consumer"""
        return AIOKafkaConsumer(
            *topics,
            bootstrap_servers=settings.kafka_bootstrap_servers,
            value_deserializer=lambda m: m.decode("utf-8") if m else None,
        )

    @staticmethod
    def create_sqs_client(settings: BaseAppSettings) -> Any:
        """Create SQS client"""
        return boto3.client("sqs", region_name=settings.aws_region)

    @staticmethod
    def create_sns_client(settings: BaseAppSettings) -> Any:
        """Create SNS client"""
        return boto3.client("sns", region_name=settings.aws_region)
113  libs/config/settings.py  Normal file
@@ -0,0 +1,113 @@
"""Base settings class for all services."""

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class BaseAppSettings(BaseSettings):
    """Base settings class for all services"""

    model_config = SettingsConfigDict(
        env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"
    )

    # Service identification
    service_name: str = Field(default="default-service", description="Service name")
    service_version: str = Field(default="1.0.0", description="Service version")

    # Network and security
    host: str = Field(default="0.0.0.0", description="Service host")
    port: int = Field(default=8000, description="Service port")
    internal_cidrs: list[str] = Field(
        default=["10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"],
        description="Internal network CIDRs",
    )

    # Development settings
    dev_mode: bool = Field(
        default=False,
        description="Enable development mode (disables auth)",
        validation_alias="DEV_MODE",
    )
    disable_auth: bool = Field(
        default=False,
        description="Disable authentication middleware",
        validation_alias="DISABLE_AUTH",
    )

    # Vault configuration
    vault_addr: str = Field(
        default="http://vault:8200", description="Vault server address"
    )
    vault_role_id: str | None = Field(default=None, description="Vault AppRole role ID")
    vault_secret_id: str | None = Field(
        default=None, description="Vault AppRole secret ID"
    )
    vault_token: str | None = Field(default=None, description="Vault token (dev only)")
    vault_mount_point: str = Field(
        default="transit", description="Vault transit mount point"
    )

    # Database URLs
    postgres_url: str = Field(
        default="postgresql://user:pass@postgres:5432/taxagent",
        description="PostgreSQL connection URL",
    )
    neo4j_uri: str = Field(
        default="bolt://neo4j:7687", description="Neo4j connection URI"
    )
    neo4j_user: str = Field(default="neo4j", description="Neo4j username")
    neo4j_password: str = Field(default="password", description="Neo4j password")
    redis_url: str = Field(
        default="redis://redis:6379", description="Redis connection URL"
    )

    # Object storage
    minio_endpoint: str = Field(default="minio:9000", description="MinIO endpoint")
    minio_access_key: str = Field(default="minioadmin", description="MinIO access key")
    minio_secret_key: str = Field(default="minioadmin", description="MinIO secret key")
    minio_secure: bool = Field(default=False, description="Use HTTPS for MinIO")

    # Vector database
    qdrant_url: str = Field(
        default="http://qdrant:6333", description="Qdrant server URL"
    )
    qdrant_api_key: str | None = Field(default=None, description="Qdrant API key")

    # Event bus configuration
    event_bus_type: str = Field(
        default="nats", description="Event bus type: nats, kafka, sqs, or memory"
    )

    # NATS configuration
    nats_servers: str = Field(
        default="nats://localhost:4222",
        description="NATS server URLs (comma-separated)",
    )
    nats_stream_name: str = Field(
        default="TAX_AGENT_EVENTS", description="NATS JetStream stream name"
    )
    nats_consumer_group: str = Field(
        default="tax-agent", description="NATS consumer group name"
    )

    # Kafka configuration (legacy)
    kafka_bootstrap_servers: str = Field(
        default="localhost:9092", description="Kafka bootstrap servers"
    )

    # AWS configuration
    aws_region: str = Field(default="us-east-1", description="AWS region for SQS/SNS")

    # Observability
    otel_service_name: str | None = Field(
        default=None, description="OpenTelemetry service name"
    )
    otel_exporter_endpoint: str | None = Field(
        default=None, description="OTEL exporter endpoint"
    )
    log_level: str = Field(default="INFO", description="Log level")

    # Performance
    max_workers: int = Field(default=4, description="Maximum worker threads")
    request_timeout: int = Field(default=30, description="Request timeout in seconds")
108  libs/config/utils.py  Normal file
@@ -0,0 +1,108 @@
|
||||
"""Configuration utility functions and global settings management."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import hvac
|
||||
import redis.asyncio as redis
|
||||
from minio import Minio
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
from libs.events.base import EventBus
|
||||
|
||||
from .factories import (
|
||||
MinIOClientFactory,
|
||||
Neo4jDriverFactory,
|
||||
QdrantClientFactory,
|
||||
RedisClientFactory,
|
||||
VaultClientFactory,
|
||||
)
|
||||
from .settings import BaseAppSettings
|
||||
|
||||
# Global settings instance
|
||||
_settings: BaseAppSettings | None = None
|
||||
|
||||
|
||||
def get_settings() -> BaseAppSettings:
|
||||
"""Get global settings instance"""
|
||||
global _settings # pylint: disable=global-variable-not-assigned
|
||||
if _settings is None:
|
||||
raise RuntimeError("Settings not initialized. Call init_settings() first.")
|
||||
return _settings
|
||||
|
||||
|
||||
def init_settings(
|
||||
settings_class: type[BaseAppSettings] = BaseAppSettings, **kwargs: Any
|
||||
) -> BaseAppSettings:
|
||||
"""Initialize settings with custom class"""
|
||||
global _settings # pylint: disable=global-statement
|
||||
_settings = settings_class(**kwargs)
|
||||
return _settings
|
||||
|
||||
|
||||
# Convenience functions for backward compatibility
|
||||
def create_vault_client(settings: BaseAppSettings) -> hvac.Client:
|
||||
"""Create Vault client"""
|
||||
return VaultClientFactory.create_client(settings)
|
||||
|
||||
|
||||
def create_minio_client(settings: BaseAppSettings) -> Minio:
|
||||
"""Create MinIO client"""
|
||||
return MinIOClientFactory.create_client(settings)
|
||||
|
||||
|
||||
def create_qdrant_client(settings: BaseAppSettings) -> QdrantClient:
|
||||
"""Create Qdrant client"""
|
||||
return QdrantClientFactory.create_client(settings)
|
||||
|
||||
|
||||
def create_neo4j_client(settings: BaseAppSettings) -> Any:
|
||||
"""Create Neo4j driver"""
|
||||
return Neo4jDriverFactory.create_driver(settings)
|
||||
|
||||
|
||||
async def create_redis_client(settings: BaseAppSettings) -> "redis.Redis[str]":
|
||||
"""Create Redis client"""
|
||||
return await RedisClientFactory.create_client(settings)
|
||||
|
||||
|
||||
def create_event_bus(settings: BaseAppSettings) -> EventBus:
|
||||
"""Create event bus"""
|
||||
if settings.event_bus_type.lower() == "kafka":
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from ..events import KafkaEventBus
|
||||
|
||||
return KafkaEventBus(settings.kafka_bootstrap_servers)
|
||||
if settings.event_bus_type.lower() == "sqs":
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from ..events import SQSEventBus
|
||||
|
||||
return SQSEventBus(settings.aws_region)
|
||||
if settings.event_bus_type.lower() == "memory":
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from ..events import MemoryEventBus
|
||||
|
||||
return MemoryEventBus()
|
||||
|
||||
# Default to memory bus for unknown types
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from ..events import MemoryEventBus
|
||||
|
||||
return MemoryEventBus()
|
||||
|
||||
|
||||
def get_default_settings(**overrides: Any) -> BaseAppSettings:
|
||||
"""Get default settings with optional overrides"""
|
||||
defaults = {
|
||||
"service_name": "default-service",
|
||||
"vault_addr": "http://vault:8200",
|
||||
"postgres_url": "postgresql://user:pass@postgres:5432/taxagent",
|
||||
"neo4j_uri": "bolt://neo4j:7687",
|
||||
"neo4j_password": "password",
|
||||
"redis_url": "redis://redis:6379",
|
||||
"minio_endpoint": "minio:9000",
|
||||
"minio_access_key": "minioadmin",
|
||||
"minio_secret_key": "minioadmin",
|
||||
"qdrant_url": "http://qdrant:6333",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return BaseAppSettings(**defaults) # type: ignore
9
libs/coverage/__init__.py
Normal file
@@ -0,0 +1,9 @@
"""Coverage evaluation engine for tax document requirements."""
|
||||
|
||||
from .evaluator import CoverageEvaluator
|
||||
from .utils import check_document_coverage
|
||||
|
||||
__all__ = [
|
||||
"CoverageEvaluator",
|
||||
"check_document_coverage",
|
||||
]
418
libs/coverage/evaluator.py
Normal file
@@ -0,0 +1,418 @@
"""Core coverage evaluation engine."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
|
||||
from ..schemas import (
|
||||
BlockingItem,
|
||||
Citation,
|
||||
CompiledCoveragePolicy,
|
||||
CoverageItem,
|
||||
CoverageReport,
|
||||
FoundEvidence,
|
||||
OverallStatus,
|
||||
Role,
|
||||
ScheduleCoverage,
|
||||
Status,
|
||||
)
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class CoverageEvaluator:
|
||||
"""Core coverage evaluation engine"""
|
||||
|
||||
def __init__(self, kg_client: Any = None, rag_client: Any = None):
|
||||
self.kg_client = kg_client
|
||||
self.rag_client = rag_client
|
||||
|
||||
async def check_document_coverage(
|
||||
self,
|
||||
taxpayer_id: str,
|
||||
tax_year: str,
|
||||
policy: CompiledCoveragePolicy,
|
||||
) -> CoverageReport:
|
||||
"""Main coverage evaluation workflow"""
|
||||
|
||||
logger.info(
|
||||
"Starting coverage evaluation",
|
||||
taxpayer_id=taxpayer_id,
|
||||
tax_year=tax_year,
|
||||
policy_version=policy.policy.version,
|
||||
)
|
||||
|
||||
# Step A: Infer required schedules
|
||||
required_schedules = await self.infer_required_schedules(
|
||||
taxpayer_id, tax_year, policy
|
||||
)
|
||||
|
||||
# Step B: Evaluate each schedule
|
||||
schedule_coverage = []
|
||||
all_blocking_items = []
|
||||
|
||||
for schedule_id in required_schedules:
|
||||
coverage = await self._evaluate_schedule_coverage(
|
||||
schedule_id, taxpayer_id, tax_year, policy
|
||||
)
|
||||
schedule_coverage.append(coverage)
|
||||
|
||||
# Collect blocking items
|
||||
for evidence in coverage.evidence:
|
||||
if evidence.role == Role.REQUIRED and evidence.status == Status.MISSING:
|
||||
all_blocking_items.append(
|
||||
BlockingItem(schedule_id=schedule_id, evidence_id=evidence.id)
|
||||
)
|
||||
|
||||
# Step C: Determine overall status
|
||||
overall_status = self._determine_overall_status(
|
||||
schedule_coverage, all_blocking_items
|
||||
)
|
||||
|
||||
return CoverageReport(
|
||||
tax_year=tax_year,
|
||||
taxpayer_id=taxpayer_id,
|
||||
schedules_required=required_schedules,
|
||||
overall_status=overall_status,
|
||||
coverage=schedule_coverage,
|
||||
blocking_items=all_blocking_items,
|
||||
policy_version=policy.policy.version,
|
||||
)
|
||||
|
||||
async def infer_required_schedules(
|
||||
self,
|
||||
taxpayer_id: str,
|
||||
tax_year: str,
|
||||
policy: CompiledCoveragePolicy,
|
||||
) -> list[str]:
|
||||
"""Determine which schedules are required for this taxpayer"""
|
||||
|
||||
required = []
|
||||
|
||||
for schedule_id, trigger in policy.policy.triggers.items():
|
||||
is_required = False
|
||||
|
||||
# Check any_of conditions
|
||||
if trigger.any_of:
|
||||
for condition in trigger.any_of:
|
||||
predicate = policy.compiled_predicates.get(condition)
|
||||
if predicate and predicate(taxpayer_id, tax_year):
|
||||
is_required = True
|
||||
break
|
||||
|
||||
# Check all_of conditions
|
||||
if trigger.all_of and not is_required:
|
||||
all_match = True
|
||||
for condition in trigger.all_of:
|
||||
predicate = policy.compiled_predicates.get(condition)
|
||||
if not predicate or not predicate(taxpayer_id, tax_year):
|
||||
all_match = False
|
||||
break
|
||||
if all_match:
|
||||
is_required = True
|
||||
|
||||
if is_required:
|
||||
required.append(schedule_id)
|
||||
logger.debug(
|
||||
"Schedule required",
|
||||
schedule_id=schedule_id,
|
||||
taxpayer_id=taxpayer_id,
|
||||
)
|
||||
|
||||
return required
|
||||
|
||||
async def find_evidence_docs(
|
||||
self,
|
||||
taxpayer_id: str,
|
||||
tax_year: str,
|
||||
evidence_ids: list[str],
|
||||
policy: CompiledCoveragePolicy,
|
||||
) -> dict[str, list[FoundEvidence]]:
|
||||
"""Find evidence documents in the knowledge graph"""
|
||||
|
||||
if not self.kg_client:
|
||||
logger.warning("No KG client available, returning empty evidence")
|
||||
empty_evidence_list: list[FoundEvidence] = []
|
||||
return dict.fromkeys(evidence_ids, empty_evidence_list)
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from ..neo import kg_find_evidence
|
||||
|
||||
evidence_map: dict[str, list[FoundEvidence]] = {}
|
||||
thresholds = policy.policy.defaults.confidence_thresholds
|
||||
|
||||
for evidence_id in evidence_ids:
|
||||
try:
|
||||
found = await kg_find_evidence(
|
||||
self.kg_client,
|
||||
taxpayer_id=taxpayer_id,
|
||||
tax_year=tax_year,
|
||||
kinds=[evidence_id],
|
||||
min_ocr=thresholds.get("ocr", 0.6),
|
||||
date_window=policy.policy.defaults.date_tolerance_days,
|
||||
)
|
||||
evidence_map[evidence_id] = found
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to find evidence",
|
||||
evidence_id=evidence_id,
|
||||
error=str(e),
|
||||
)
|
||||
empty_list: list[FoundEvidence] = []
|
||||
evidence_map[evidence_id] = empty_list
|
||||
|
||||
return evidence_map
|
||||
|
||||
def classify_status(
|
||||
self,
|
||||
found: list[FoundEvidence],
|
||||
policy: CompiledCoveragePolicy,
|
||||
tax_year: str,
|
||||
) -> Status:
|
||||
"""Classify evidence status based on what was found"""
|
||||
|
||||
if not found:
|
||||
return Status.MISSING
|
||||
|
||||
classifier = policy.policy.status_classifier
|
||||
tax_year_start, tax_year_end = self._parse_tax_year_bounds(
|
||||
policy.policy.tax_year_boundary.start,
|
||||
policy.policy.tax_year_boundary.end,
|
||||
)
|
||||
|
||||
# Check for conflicts first
|
||||
if len(found) > 1:
|
||||
# Simple conflict detection: different totals for same period
|
||||
# In production, this would be more sophisticated
|
||||
return Status.CONFLICTING
|
||||
|
||||
evidence = found[0]
|
||||
|
||||
# Check if evidence meets verified criteria
|
||||
if (
|
||||
evidence.ocr_confidence >= classifier.present_verified.min_ocr
|
||||
and evidence.extract_confidence >= classifier.present_verified.min_extract
|
||||
):
|
||||
# Check date validity
|
||||
if evidence.date:
|
||||
# Handle both date-only and datetime strings consistently
|
||||
if "T" not in evidence.date:
|
||||
# Date-only string, add time and timezone (middle of day)
|
||||
evidence_date = datetime.fromisoformat(
|
||||
evidence.date + "T12:00:00+00:00"
|
||||
)
|
||||
else:
|
||||
# Full datetime string, ensure timezone-aware
|
||||
evidence_date = datetime.fromisoformat(
|
||||
evidence.date.replace("Z", "+00:00")
|
||||
)
|
||||
if tax_year_start <= evidence_date <= tax_year_end:
|
||||
return Status.PRESENT_VERIFIED
|
||||
|
||||
# Check if evidence meets unverified criteria
|
||||
if (
|
||||
evidence.ocr_confidence >= classifier.present_unverified.min_ocr
|
||||
and evidence.extract_confidence >= classifier.present_unverified.min_extract
|
||||
):
|
||||
return Status.PRESENT_UNVERIFIED
|
||||
|
||||
# Default to missing if confidence too low
|
||||
return Status.MISSING
|
||||
|
||||
async def build_reason_and_citations(
|
||||
self,
|
||||
schedule_id: str,
|
||||
evidence_item: Any,
|
||||
status: Status,
|
||||
taxpayer_id: str,
|
||||
tax_year: str,
|
||||
policy: CompiledCoveragePolicy,
|
||||
) -> tuple[str, list[Citation]]:
|
||||
"""Build human-readable reason and citations"""
|
||||
|
||||
# Build reason text
|
||||
reason = self._build_reason_text(evidence_item, status, policy)
|
||||
|
||||
# Get citations from KG
|
||||
citations = []
|
||||
if self.kg_client:
|
||||
try:
|
||||
from ..neo import kg_rule_citations
|
||||
|
||||
kg_citations = await kg_rule_citations(
|
||||
self.kg_client, schedule_id, evidence_item.boxes
|
||||
)
|
||||
citations.extend(kg_citations)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to get KG citations", error=str(e))
|
||||
|
||||
# Fallback to RAG citations if needed
|
||||
if not citations and self.rag_client:
|
||||
try:
|
||||
from ..rag import rag_search_for_citations
|
||||
|
||||
query = f"{schedule_id} {evidence_item.id} requirements"
|
||||
filters = {
|
||||
"jurisdiction": policy.policy.jurisdiction,
|
||||
"tax_year": tax_year,
|
||||
"pii_free": True,
|
||||
}
|
||||
rag_citations = await rag_search_for_citations(
|
||||
self.rag_client, query, filters
|
||||
)
|
||||
citations.extend(rag_citations)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to get RAG citations", error=str(e))
|
||||
|
||||
return reason, citations
|
||||
|
||||
async def _evaluate_schedule_coverage(
|
||||
self,
|
||||
schedule_id: str,
|
||||
taxpayer_id: str,
|
||||
tax_year: str,
|
||||
policy: CompiledCoveragePolicy,
|
||||
) -> ScheduleCoverage:
|
||||
"""Evaluate coverage for a single schedule"""
|
||||
|
||||
schedule_policy = policy.policy.schedules[schedule_id]
|
||||
evidence_items = []
|
||||
|
||||
# Get all evidence IDs for this schedule
|
||||
evidence_ids = [e.id for e in schedule_policy.evidence]
|
||||
|
||||
# Find evidence in KG
|
||||
evidence_map = await self.find_evidence_docs(
|
||||
taxpayer_id, tax_year, evidence_ids, policy
|
||||
)
|
||||
|
||||
# Evaluate each evidence requirement
|
||||
for evidence_req in schedule_policy.evidence:
|
||||
# Check if conditionally required evidence applies
|
||||
if (
|
||||
evidence_req.role == Role.CONDITIONALLY_REQUIRED
|
||||
and evidence_req.condition
|
||||
):
|
||||
predicate = policy.compiled_predicates.get(evidence_req.condition)
|
||||
if not predicate or not predicate(taxpayer_id, tax_year):
|
||||
continue # Skip this evidence as condition not met
|
||||
|
||||
found = evidence_map.get(evidence_req.id, [])
|
||||
status = self.classify_status(found, policy, tax_year)
|
||||
|
||||
reason, citations = await self.build_reason_and_citations(
|
||||
schedule_id, evidence_req, status, taxpayer_id, tax_year, policy
|
||||
)
|
||||
|
||||
evidence_item = CoverageItem(
|
||||
id=evidence_req.id,
|
||||
role=evidence_req.role,
|
||||
status=status,
|
||||
boxes=evidence_req.boxes,
|
||||
found=found,
|
||||
acceptable_alternatives=evidence_req.acceptable_alternatives,
|
||||
reason=reason,
|
||||
citations=citations,
|
||||
)
|
||||
evidence_items.append(evidence_item)
|
||||
|
||||
# Determine schedule status
|
||||
schedule_status = self._determine_schedule_status(evidence_items)
|
||||
|
||||
return ScheduleCoverage(
|
||||
schedule_id=schedule_id,
|
||||
status=schedule_status,
|
||||
evidence=evidence_items,
|
||||
)
|
||||
|
||||
def _determine_overall_status(
|
||||
self,
|
||||
schedule_coverage: list[ScheduleCoverage],
|
||||
blocking_items: list[BlockingItem],
|
||||
) -> OverallStatus:
|
||||
"""Determine overall coverage status"""
|
||||
|
||||
if blocking_items:
|
||||
return OverallStatus.BLOCKING
|
||||
|
||||
# Check if all schedules are OK
|
||||
all_ok = all(s.status == OverallStatus.OK for s in schedule_coverage)
|
||||
if all_ok:
|
||||
return OverallStatus.OK
|
||||
|
||||
return OverallStatus.PARTIAL
|
||||
|
||||
def _determine_schedule_status(
|
||||
self, evidence_items: list[CoverageItem]
|
||||
) -> OverallStatus:
|
||||
"""Determine status for a single schedule"""
|
||||
|
||||
# Check for blocking issues
|
||||
has_missing_required = any(
|
||||
e.role == Role.REQUIRED and e.status == Status.MISSING
|
||||
for e in evidence_items
|
||||
)
|
||||
|
||||
if has_missing_required:
|
||||
return OverallStatus.BLOCKING
|
||||
|
||||
# Check for partial issues
|
||||
has_unverified = any(
|
||||
e.status == Status.PRESENT_UNVERIFIED for e in evidence_items
|
||||
)
|
||||
|
||||
if has_unverified:
|
||||
return OverallStatus.PARTIAL
|
||||
|
||||
return OverallStatus.OK
|
||||
|
||||
def _build_reason_text(
|
||||
self,
|
||||
evidence_item: Any,
|
||||
status: Status,
|
||||
policy: CompiledCoveragePolicy,
|
||||
) -> str:
|
||||
"""Build human-readable reason text"""
|
||||
|
||||
evidence_id = evidence_item.id
|
||||
|
||||
# Get reason from policy if available
|
||||
if evidence_item.reasons and "short" in evidence_item.reasons:
|
||||
base_reason = evidence_item.reasons["short"]
|
||||
else:
|
||||
base_reason = f"{evidence_id} is required for this schedule."
|
||||
|
||||
# Add status-specific details
|
||||
if status == Status.MISSING:
|
||||
return f"No {evidence_id} found. {base_reason}"
|
||||
elif status == Status.PRESENT_UNVERIFIED:
|
||||
return (
|
||||
f"{evidence_id} present but confidence below threshold. {base_reason}"
|
||||
)
|
||||
elif status == Status.CONFLICTING:
|
||||
return f"Conflicting {evidence_id} documents found. {base_reason}"
|
||||
else:
|
||||
return f"{evidence_id} verified. {base_reason}"
|
||||
|
||||
def _parse_tax_year_bounds(
|
||||
self, start_str: str, end_str: str
|
||||
) -> tuple[datetime, datetime]:
|
||||
"""Parse tax year boundary strings to datetime objects"""
|
||||
# Handle both date-only and datetime strings
|
||||
if "T" not in start_str:
|
||||
# Date-only string, add time and timezone
|
||||
start = datetime.fromisoformat(start_str + "T00:00:00+00:00")
|
||||
else:
|
||||
# Full datetime string, ensure timezone-aware
|
||||
start = datetime.fromisoformat(start_str.replace("Z", "+00:00"))
|
||||
|
||||
if "T" not in end_str:
|
||||
# Date-only string, add time and timezone (end of day)
|
||||
end = datetime.fromisoformat(end_str + "T23:59:59+00:00")
|
||||
else:
|
||||
# Full datetime string, ensure timezone-aware
|
||||
end = datetime.fromisoformat(end_str.replace("Z", "+00:00"))
|
||||
|
||||
return start, end
18
libs/coverage/utils.py
Normal file
@@ -0,0 +1,18 @@
"""Utility functions for coverage evaluation."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ..schemas import CompiledCoveragePolicy, CoverageReport
|
||||
from .evaluator import CoverageEvaluator
|
||||
|
||||
|
||||
async def check_document_coverage(
|
||||
taxpayer_id: str,
|
||||
tax_year: str,
|
||||
policy: CompiledCoveragePolicy,
|
||||
kg_client: Any = None,
|
||||
rag_client: Any = None,
|
||||
) -> CoverageReport:
|
||||
"""Check document coverage for taxpayer"""
|
||||
evaluator = CoverageEvaluator(kg_client, rag_client)
|
||||
return await evaluator.check_document_coverage(taxpayer_id, tax_year, policy)
336
libs/coverage_schema.json
Normal file
@@ -0,0 +1,336 @@
{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"title": "Coverage Policy Schema",
|
||||
"type": "object",
|
||||
"required": [
|
||||
"version",
|
||||
"jurisdiction",
|
||||
"tax_year",
|
||||
"tax_year_boundary",
|
||||
"defaults",
|
||||
"document_kinds",
|
||||
"triggers",
|
||||
"schedules",
|
||||
"status_classifier",
|
||||
"conflict_resolution",
|
||||
"question_templates"
|
||||
],
|
||||
"properties": {
|
||||
"version": {
|
||||
"type": "string",
|
||||
"pattern": "^\\d+\\.\\d+$"
|
||||
},
|
||||
"jurisdiction": {
|
||||
"type": "string",
|
||||
"enum": ["UK", "US", "CA", "AU"]
|
||||
},
|
||||
"tax_year": {
|
||||
"type": "string",
|
||||
"pattern": "^\\d{4}-\\d{2}$"
|
||||
},
|
||||
"tax_year_boundary": {
|
||||
"type": "object",
|
||||
"required": ["start", "end"],
|
||||
"properties": {
|
||||
"start": {
|
||||
"type": "string",
|
||||
"format": "date"
|
||||
},
|
||||
"end": {
|
||||
"type": "string",
|
||||
"format": "date"
|
||||
}
|
||||
}
|
||||
},
|
||||
"defaults": {
|
||||
"type": "object",
|
||||
"required": ["confidence_thresholds"],
|
||||
"properties": {
|
||||
"confidence_thresholds": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ocr": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1
|
||||
},
|
||||
"extract": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
"date_tolerance_days": {
|
||||
"type": "integer",
|
||||
"minimum": 0
|
||||
},
|
||||
"require_lineage_bbox": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"allow_bank_substantiation": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
},
|
||||
"document_kinds": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"minItems": 1,
|
||||
"uniqueItems": true
|
||||
},
|
||||
"guidance_refs": {
|
||||
"type": "object",
|
||||
"patternProperties": {
|
||||
"^[A-Z0-9_]+$": {
|
||||
"type": "object",
|
||||
"required": ["doc_id", "kind"],
|
||||
"properties": {
|
||||
"doc_id": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"kind": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"triggers": {
|
||||
"type": "object",
|
||||
"patternProperties": {
|
||||
"^SA\\d+[A-Z]*$": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"any_of": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
}
|
||||
},
|
||||
"all_of": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
"anyOf": [
|
||||
{"required": ["any_of"]},
|
||||
{"required": ["all_of"]}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"schedules": {
|
||||
"type": "object",
|
||||
"patternProperties": {
|
||||
"^SA\\d+[A-Z]*$": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"guidance_hint": {
|
||||
"type": "string"
|
||||
},
|
||||
"evidence": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["id", "role"],
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"role": {
|
||||
"type": "string",
|
||||
"enum": ["REQUIRED", "CONDITIONALLY_REQUIRED", "OPTIONAL"]
|
||||
},
|
||||
"condition": {
|
||||
"type": "string"
|
||||
},
|
||||
"boxes": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"pattern": "^SA\\d+[A-Z]*_b\\d+(_\\d+)?$"
|
||||
},
|
||||
"minItems": 0
|
||||
},
|
||||
"acceptable_alternatives": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
}
|
||||
},
|
||||
"validity": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"within_tax_year": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"available_by": {
|
||||
"type": "string",
|
||||
"format": "date"
|
||||
}
|
||||
}
|
||||
},
|
||||
"reasons": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"short": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"cross_checks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["name", "logic"],
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"logic": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"selection_rule": {
|
||||
"type": "object"
|
||||
},
|
||||
"notes": {
|
||||
"type": "object"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"status_classifier": {
|
||||
"type": "object",
|
||||
"required": ["present_verified", "present_unverified", "conflicting", "missing"],
|
||||
"properties": {
|
||||
"present_verified": {
|
||||
"$ref": "#/definitions/statusClassifier"
|
||||
},
|
||||
"present_unverified": {
|
||||
"$ref": "#/definitions/statusClassifier"
|
||||
},
|
||||
"conflicting": {
|
||||
"$ref": "#/definitions/statusClassifier"
|
||||
},
|
||||
"missing": {
|
||||
"$ref": "#/definitions/statusClassifier"
|
||||
}
|
||||
}
|
||||
},
|
||||
"conflict_resolution": {
|
||||
"type": "object",
|
||||
"required": ["precedence"],
|
||||
"properties": {
|
||||
"precedence": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"minItems": 1
|
||||
},
|
||||
"escalation": {
|
||||
"type": "object"
|
||||
}
|
||||
}
|
||||
},
|
||||
"question_templates": {
|
||||
"type": "object",
|
||||
"required": ["default"],
|
||||
"properties": {
|
||||
"default": {
|
||||
"type": "object",
|
||||
"required": ["text", "why"],
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
"why": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
"reasons": {
|
||||
"type": "object",
|
||||
"patternProperties": {
|
||||
"^[A-Za-z0-9_]+$": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"privacy": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"vector_pii_free": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"redact_patterns": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"definitions": {
|
||||
"statusClassifier": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"min_ocr": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1
|
||||
},
|
||||
"min_extract": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1
|
||||
},
|
||||
"date_in_year": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"date_in_year_or_tolerance": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"conflict_rules": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
}
|
||||
},
|
||||
"default": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
282
libs/events/NATS_README.md
Normal file
@@ -0,0 +1,282 @@
# NATS.io Event Bus with JetStream
|
||||
|
||||
This document describes the NATS.io event bus implementation with JetStream support for the AI Tax Agent project.
|
||||
|
||||
## Overview
|
||||
|
||||
The `NATSEventBus` class provides a robust, scalable event streaming solution using NATS.io with JetStream for persistent messaging. It implements the same `EventBus` interface as other event bus implementations (Kafka, SQS, Memory) for consistency.
|
||||
|
||||
## Features
|
||||
|
||||
- **JetStream Integration**: Uses NATS JetStream for persistent, reliable message delivery
|
||||
- **Automatic Stream Management**: Creates the JetStream stream on startup if it does not already exist
|
||||
- **Pull-based Consumers**: Uses pull-based consumers for better flow control
|
||||
- **Cluster Support**: Supports NATS cluster configurations for high availability
|
||||
- **Error Handling**: Comprehensive error handling with automatic retries
|
||||
- **Message Acknowledgment**: Explicit message acknowledgment with configurable retry policies
|
||||
- **Durable Consumers**: Creates durable consumers for guaranteed message processing
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```python
|
||||
from libs.events import NATSEventBus
|
||||
|
||||
# Single server
|
||||
bus = NATSEventBus(
|
||||
servers="nats://localhost:4222",
|
||||
stream_name="TAX_AGENT_EVENTS",
|
||||
consumer_group="tax-agent"
|
||||
)
|
||||
|
||||
# Multiple servers (cluster)
|
||||
bus = NATSEventBus(
|
||||
servers=[
|
||||
"nats://nats1.example.com:4222",
|
||||
"nats://nats2.example.com:4222",
|
||||
"nats://nats3.example.com:4222"
|
||||
],
|
||||
stream_name="PRODUCTION_EVENTS",
|
||||
consumer_group="tax-agent-prod"
|
||||
)
|
||||
```
|
||||
|
||||
### Factory Configuration
|
||||
|
||||
```python
|
||||
from libs.events import create_event_bus
|
||||
|
||||
bus = create_event_bus(
|
||||
"nats",
|
||||
servers="nats://localhost:4222",
|
||||
stream_name="TAX_AGENT_EVENTS",
|
||||
consumer_group="tax-agent"
|
||||
)
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Publishing Events
|
||||
|
||||
```python
|
||||
from libs.events import EventPayload
|
||||
|
||||
# Create event payload
|
||||
payload = EventPayload(
|
||||
data={"user_id": "123", "action": "login"},
|
||||
actor="user-service",
|
||||
tenant_id="tenant-456",
|
||||
trace_id="trace-789"
|
||||
)
|
||||
|
||||
# Publish event
|
||||
success = await bus.publish("user.login", payload)
|
||||
if success:
|
||||
print("Event published successfully")
|
||||
```
|
||||
|
||||
### Subscribing to Events
|
||||
|
||||
```python
|
||||
async def handle_user_login(topic: str, payload: EventPayload) -> None:
|
||||
print(f"User {payload.data['user_id']} logged in")
|
||||
# Process the event...
|
||||
|
||||
# Subscribe to topic
|
||||
await bus.subscribe("user.login", handle_user_login)
|
||||
```
|
||||
|
||||
### Complete Example
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from libs.events import NATSEventBus, EventPayload
|
||||
|
||||
async def main():
|
||||
bus = NATSEventBus()
|
||||
|
||||
try:
|
||||
# Start the bus
|
||||
await bus.start()
|
||||
|
||||
# Subscribe to events
|
||||
await bus.subscribe("user.created", handle_user_created)
|
||||
|
||||
# Publish an event
|
||||
payload = EventPayload(
|
||||
data={"user_id": "123", "email": "user@example.com"},
|
||||
actor="registration-service",
|
||||
tenant_id="tenant-456"
|
||||
)
|
||||
await bus.publish("user.created", payload)
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(1)
|
||||
|
||||
finally:
|
||||
await bus.stop()
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
## JetStream Configuration
|
||||
|
||||
The NATS event bus automatically creates and configures JetStream streams with the following settings:
|
||||
|
||||
- **Retention Policy**: Work Queue (messages are removed after acknowledgment)
|
||||
- **Max Age**: 7 days (messages older than 7 days are automatically deleted)
|
||||
- **Storage**: File-based storage for persistence
|
||||
- **Subject Pattern**: `{stream_name}.*` (e.g., `TAX_AGENT_EVENTS.*`)
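
For reference, the stream the bus creates on startup can be reproduced directly with `nats-py`. This is a minimal sketch that mirrors the bus's internal stream setup and assumes the default `TAX_AGENT_EVENTS` stream name:

```python
import asyncio

import nats
from nats.js.api import RetentionPolicy, StorageType


async def ensure_stream(servers: str = "nats://localhost:4222") -> None:
    nc = await nats.connect(servers=servers)
    js = nc.jetstream()

    # Same settings as listed above: work-queue retention, 7-day max age,
    # file storage, and one subject per topic under the stream name.
    await js.add_stream(
        name="TAX_AGENT_EVENTS",
        subjects=["TAX_AGENT_EVENTS.*"],
        retention=RetentionPolicy.WORK_QUEUE,
        max_age=7 * 24 * 60 * 60,  # seconds
        storage=StorageType.FILE,
    )
    await nc.close()


asyncio.run(ensure_stream())
```
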
### Consumer Configuration
|
||||
|
||||
- **Durable Consumers**: Each topic subscription creates a durable consumer
|
||||
- **Ack Policy**: Explicit acknowledgment required
|
||||
- **Deliver Policy**: New messages only (doesn't replay old messages)
|
||||
- **Max Deliver**: 3 attempts before message is considered failed
|
||||
- **Ack Wait**: 30 seconds timeout for acknowledgment
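
Roughly, each `subscribe()` call creates a durable pull consumer equivalent to the sketch below (names assume the default stream and `tax-agent` consumer group; the helper function itself is illustrative):

```python
from nats.js import JetStreamContext
from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy


async def make_topic_consumer(js: JetStreamContext, topic: str):
    """Durable pull consumer matching the settings listed above."""
    consumer_name = f"tax-agent-{topic}"  # "{consumer_group}-{topic}"
    return await js.pull_subscribe(
        subject=f"TAX_AGENT_EVENTS.{topic}",
        durable=consumer_name,
        config=ConsumerConfig(
            durable_name=consumer_name,
            ack_policy=AckPolicy.EXPLICIT,     # explicit ack required
            deliver_policy=DeliverPolicy.NEW,  # new messages only
            max_deliver=3,                     # give up after 3 attempts
            ack_wait=30,                       # seconds before redelivery
        ),
    )
```
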
## Error Handling
|
||||
|
||||
The NATS event bus includes comprehensive error handling:
|
||||
|
||||
### Publishing Errors
|
||||
- Network failures are logged and return `False`
|
||||
- Automatic retry logic can be implemented at the application level
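
As one possible approach (not part of the library), a small application-level wrapper can retry around the boolean result of `publish()`; the helper name and backoff policy below are illustrative only:

```python
import asyncio

from libs.events import EventPayload, NATSEventBus


async def publish_with_retry(
    bus: NATSEventBus,
    topic: str,
    payload: EventPayload,
    attempts: int = 3,
    base_delay: float = 1.0,
) -> bool:
    """Retry publishing with linear backoff; returns True on first success."""
    for attempt in range(1, attempts + 1):
        if await bus.publish(topic, payload):
            return True
        if attempt < attempts:
            await asyncio.sleep(base_delay * attempt)
    return False
```
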
### Consumer Errors
|
||||
- Handler exceptions are caught and logged
|
||||
- Failed messages are negatively acknowledged (NAK) for retry
|
||||
- Messages that fail multiple times are moved to a dead letter queue (if configured)
|
||||
|
||||
### Connection Errors
|
||||
- Automatic reconnection is handled by the NATS client
|
||||
- Consumer tasks are gracefully shut down on connection loss
|
||||
|
||||
## Monitoring and Observability
|
||||
|
||||
The implementation includes structured logging with the following information:
|
||||
|
||||
- Event publishing success/failure
|
||||
- Consumer subscription status
|
||||
- Message processing metrics
|
||||
- Error details and stack traces
|
||||
|
||||
### Log Examples
|
||||
|
||||
```
|
||||
INFO: Event published topic=user.created event_id=01HK... stream_seq=123
|
||||
INFO: Subscribed to topic topic=user.login consumer=tax-agent-user.login
|
||||
ERROR: Handler failed topic=user.created event_id=01HK... error=...
|
||||
```
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Throughput
|
||||
- Pull-based consumers allow for controlled message processing
|
||||
- Batch fetching (up to 10 messages per fetch) improves throughput
|
||||
- Async processing enables high concurrency
|
||||
|
||||
### Memory Usage
|
||||
- File-based storage keeps memory usage low
|
||||
- Configurable message retention prevents unbounded growth
|
||||
|
||||
### Network Efficiency
|
||||
- Binary protocol with minimal overhead
|
||||
- Connection pooling and reuse
|
||||
- Efficient subject-based routing
|
||||
|
||||
## Deployment
|
||||
|
||||
### Docker Compose Example
|
||||
|
||||
```yaml
|
||||
services:
|
||||
nats:
|
||||
image: nats:2.10-alpine
|
||||
ports:
|
||||
- "4222:4222"
|
||||
- "8222:8222"
|
||||
command:
|
||||
- "--jetstream"
|
||||
- "--store_dir=/data"
|
||||
- "--http_port=8222"
|
||||
volumes:
|
||||
- nats_data:/data
|
||||
|
||||
volumes:
|
||||
nats_data:
|
||||
```
|
||||
|
||||
### Kubernetes Example
|
||||
|
||||
```yaml
|
||||
apiVersion: apps/v1
|
||||
kind: StatefulSet
|
||||
metadata:
|
||||
name: nats
|
||||
spec:
|
||||
serviceName: nats
|
||||
replicas: 3
|
||||
selector:
|
||||
matchLabels:
|
||||
app: nats
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: nats
|
||||
spec:
|
||||
containers:
|
||||
- name: nats
|
||||
image: nats:2.10-alpine
|
||||
args:
|
||||
- "--cluster_name=nats-cluster"
|
||||
- "--jetstream"
|
||||
- "--store_dir=/data"
|
||||
ports:
|
||||
- containerPort: 4222
|
||||
- containerPort: 6222
|
||||
- containerPort: 8222
|
||||
volumeMounts:
|
||||
- name: nats-storage
|
||||
mountPath: /data
|
||||
volumeClaimTemplates:
|
||||
- metadata:
|
||||
name: nats-storage
|
||||
spec:
|
||||
accessModes: ["ReadWriteOnce"]
|
||||
resources:
|
||||
requests:
|
||||
storage: 10Gi
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
The NATS event bus requires the following Python package:
|
||||
|
||||
```
|
||||
nats-py>=2.6.0
|
||||
```
|
||||
|
||||
This is automatically included in `libs/requirements.txt`.
|
||||
|
||||
## Comparison with Other Event Buses
|
||||
|
||||
| Feature | NATS | Kafka | SQS |
|
||||
|---------|------|-------|-----|
|
||||
| Setup Complexity | Low | Medium | Low |
|
||||
| Throughput | High | Very High | Medium |
|
||||
| Latency | Very Low | Low | Medium |
|
||||
| Persistence | Yes (JetStream) | Yes | Yes |
|
||||
| Ordering | Per Subject | Per Partition | FIFO Queues |
|
||||
| Clustering | Built-in | Built-in | Managed |
|
||||
| Operational Overhead | Low | High | Minimal (managed) |
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use meaningful subject names**: Follow a hierarchical naming convention (e.g., `service.entity.action`)
|
||||
2. **Handle failures gracefully**: Implement proper error handling in event handlers
|
||||
3. **Monitor consumer lag**: Track message processing delays
|
||||
4. **Use appropriate retention**: Configure message retention based on business requirements
|
||||
5. **Test failure scenarios**: Verify behavior during network partitions and service failures
20
libs/events/__init__.py
Normal file
@@ -0,0 +1,20 @@
"""Event-driven architecture with Kafka, SQS, NATS, and Memory support."""
|
||||
|
||||
from .base import EventBus, EventPayload
|
||||
from .factory import create_event_bus
|
||||
from .kafka_bus import KafkaEventBus
|
||||
from .memory_bus import MemoryEventBus
|
||||
from .nats_bus import NATSEventBus
|
||||
from .sqs_bus import SQSEventBus
|
||||
from .topics import EventTopics
|
||||
|
||||
__all__ = [
|
||||
"EventPayload",
|
||||
"EventBus",
|
||||
"KafkaEventBus",
|
||||
"MemoryEventBus",
|
||||
"NATSEventBus",
|
||||
"SQSEventBus",
|
||||
"create_event_bus",
|
||||
"EventTopics",
|
||||
]
68
libs/events/base.py
Normal file
@@ -0,0 +1,68 @@
"""Base event classes and interfaces."""
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import ulid
|
||||
|
||||
|
||||
# Each payload MUST include: `event_id (ulid)`, `occurred_at (iso)`, `actor`, `tenant_id`, `trace_id`, `schema_version`, and a `data` object (service-specific).
|
||||
class EventPayload:
|
||||
"""Standard event payload structure"""
|
||||
|
||||
def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
actor: str,
|
||||
tenant_id: str,
|
||||
trace_id: str | None = None,
|
||||
schema_version: str = "1.0",
|
||||
):
|
||||
self.event_id = str(ulid.new())
|
||||
self.occurred_at = datetime.utcnow().isoformat() + "Z"
|
||||
self.actor = actor
|
||||
self.tenant_id = tenant_id
|
||||
self.trace_id = trace_id
|
||||
self.schema_version = schema_version
|
||||
self.data = data
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for serialization"""
|
||||
return {
|
||||
"event_id": self.event_id,
|
||||
"occurred_at": self.occurred_at,
|
||||
"actor": self.actor,
|
||||
"tenant_id": self.tenant_id,
|
||||
"trace_id": self.trace_id,
|
||||
"schema_version": self.schema_version,
|
||||
"data": self.data,
|
||||
}
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""Convert to JSON string"""
|
||||
return json.dumps(self.to_dict())
|
||||
|
||||
|
||||
class EventBus(ABC):
|
||||
"""Abstract event bus interface"""
|
||||
|
||||
@abstractmethod
|
||||
async def publish(self, topic: str, payload: EventPayload) -> bool:
|
||||
"""Publish event to topic"""
|
||||
|
||||
@abstractmethod
|
||||
async def subscribe(
|
||||
self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
|
||||
) -> None:
|
||||
"""Subscribe to topic with handler"""
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""Start the event bus"""
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Stop the event bus"""
163
libs/events/examples/nats_example.py
Normal file
@@ -0,0 +1,163 @@
"""Example usage of NATS.io event bus with JetStream."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from libs.events import EventPayload, NATSEventBus, create_event_bus
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def example_handler(topic: str, payload: EventPayload) -> None:
|
||||
"""Example event handler."""
|
||||
logger.info(
|
||||
f"Received event on topic '{topic}': "
|
||||
f"ID={payload.event_id}, "
|
||||
f"Actor={payload.actor}, "
|
||||
f"Data={payload.data}"
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main example function."""
|
||||
# Method 1: Direct instantiation
|
||||
nats_bus = NATSEventBus(
|
||||
servers="nats://localhost:4222", # Can be a list for cluster
|
||||
stream_name="TAX_AGENT_EVENTS",
|
||||
consumer_group="tax-agent",
|
||||
)
|
||||
|
||||
# Method 2: Using factory
|
||||
# nats_bus = create_event_bus(
|
||||
# "nats",
|
||||
# servers="nats://localhost:4222",
|
||||
# stream_name="TAX_AGENT_EVENTS",
|
||||
# consumer_group="tax-agent",
|
||||
# )
|
||||
|
||||
try:
|
||||
# Start the event bus
|
||||
await nats_bus.start()
|
||||
logger.info("NATS event bus started")
|
||||
|
||||
# Subscribe to a topic
|
||||
await nats_bus.subscribe("user.created", example_handler)
|
||||
await nats_bus.subscribe("user.updated", example_handler)
|
||||
logger.info("Subscribed to topics")
|
||||
|
||||
# Publish some events
|
||||
for i in range(5):
|
||||
payload = EventPayload(
|
||||
data={"user_id": f"user-{i}", "name": f"User {i}"},
|
||||
actor="system",
|
||||
tenant_id="tenant-123",
|
||||
trace_id=f"trace-{i}",
|
||||
)
|
||||
|
||||
success = await nats_bus.publish("user.created", payload)
|
||||
if success:
|
||||
logger.info(f"Published event {i}")
|
||||
else:
|
||||
logger.error(f"Failed to publish event {i}")
|
||||
|
||||
# Wait a bit for messages to be processed
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Publish an update event
|
||||
update_payload = EventPayload(
|
||||
data={"user_id": "user-1", "name": "Updated User 1", "email": "user1@example.com"},
|
||||
actor="admin",
|
||||
tenant_id="tenant-123",
|
||||
)
|
||||
|
||||
await nats_bus.publish("user.updated", update_payload)
|
||||
logger.info("Published update event")
|
||||
|
||||
# Wait for processing
|
||||
await asyncio.sleep(2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in example: {e}")
|
||||
finally:
|
||||
# Stop the event bus
|
||||
await nats_bus.stop()
|
||||
logger.info("NATS event bus stopped")
|
||||
|
||||
|
||||
async def cluster_example():
|
||||
"""Example with NATS cluster configuration."""
|
||||
# Connect to a NATS cluster
|
||||
cluster_bus = NATSEventBus(
|
||||
servers=[
|
||||
"nats://nats1.example.com:4222",
|
||||
"nats://nats2.example.com:4222",
|
||||
"nats://nats3.example.com:4222",
|
||||
],
|
||||
stream_name="PRODUCTION_EVENTS",
|
||||
consumer_group="tax-agent-prod",
|
||||
)
|
||||
|
||||
try:
|
||||
await cluster_bus.start()
|
||||
logger.info("Connected to NATS cluster")
|
||||
|
||||
# Subscribe to multiple topics
|
||||
topics = ["document.uploaded", "document.processed", "tax.calculated"]
|
||||
for topic in topics:
|
||||
await cluster_bus.subscribe(topic, example_handler)
|
||||
|
||||
logger.info(f"Subscribed to {len(topics)} topics")
|
||||
|
||||
# Keep running for a while
|
||||
await asyncio.sleep(10)
|
||||
|
||||
finally:
|
||||
await cluster_bus.stop()
|
||||
|
||||
|
||||
async def error_handling_example():
|
||||
"""Example showing error handling."""
|
||||
|
||||
async def failing_handler(topic: str, payload: EventPayload) -> None:
|
||||
"""Handler that sometimes fails."""
|
||||
if payload.data.get("should_fail"):
|
||||
raise ValueError("Simulated handler failure")
|
||||
logger.info(f"Successfully processed event {payload.event_id}")
|
||||
|
||||
bus = NATSEventBus()
|
||||
|
||||
try:
|
||||
await bus.start()
|
||||
await bus.subscribe("test.events", failing_handler)
|
||||
|
||||
# Publish a good event
|
||||
good_payload = EventPayload(
|
||||
data={"message": "This will succeed"},
|
||||
actor="test",
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
await bus.publish("test.events", good_payload)
|
||||
|
||||
# Publish a bad event
|
||||
bad_payload = EventPayload(
|
||||
data={"message": "This will fail", "should_fail": True},
|
||||
actor="test",
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
await bus.publish("test.events", bad_payload)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
finally:
|
||||
await bus.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the basic example
|
||||
asyncio.run(main())
|
||||
|
||||
# Uncomment to run other examples:
|
||||
# asyncio.run(cluster_example())
|
||||
# asyncio.run(error_handling_example())
23
libs/events/factory.py
Normal file
@@ -0,0 +1,23 @@
"""Factory function for creating event bus instances."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import EventBus
|
||||
from .kafka_bus import KafkaEventBus
|
||||
from .nats_bus import NATSEventBus
|
||||
from .sqs_bus import SQSEventBus
|
||||
|
||||
|
||||
def create_event_bus(bus_type: str, **kwargs: Any) -> EventBus:
|
||||
"""Factory function to create event bus"""
|
||||
if bus_type.lower() == "kafka":
|
||||
return KafkaEventBus(kwargs.get("bootstrap_servers", "localhost:9092"))
|
||||
if bus_type.lower() == "sqs":
|
||||
return SQSEventBus(kwargs.get("region_name", "us-east-1"))
|
||||
if bus_type.lower() == "nats":
|
||||
return NATSEventBus(
|
||||
servers=kwargs.get("servers", "nats://localhost:4222"),
|
||||
stream_name=kwargs.get("stream_name", "TAX_AGENT_EVENTS"),
|
||||
consumer_group=kwargs.get("consumer_group", "tax-agent"),
|
||||
)
|
||||
raise ValueError(f"Unsupported event bus type: {bus_type}")
140
libs/events/kafka_bus.py
Normal file
@@ -0,0 +1,140 @@
"""Kafka implementation of EventBus."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
import structlog
|
||||
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer # type: ignore
|
||||
|
||||
from .base import EventBus, EventPayload
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class KafkaEventBus(EventBus):
|
||||
"""Kafka implementation of EventBus"""
|
||||
|
||||
def __init__(self, bootstrap_servers: str):
|
||||
self.bootstrap_servers = bootstrap_servers.split(",")
|
||||
self.producer: AIOKafkaProducer | None = None
|
||||
self.consumers: dict[str, AIOKafkaConsumer] = {}
|
||||
self.handlers: dict[
|
||||
str, list[Callable[[str, EventPayload], Awaitable[None]]]
|
||||
] = {}
|
||||
self.running = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start Kafka producer"""
|
||||
if self.running:
|
||||
return
|
||||
|
||||
self.producer = AIOKafkaProducer(
|
||||
bootstrap_servers=",".join(self.bootstrap_servers),
|
||||
value_serializer=lambda v: v.encode("utf-8"),
|
||||
)
|
||||
await self.producer.start()
|
||||
self.running = True
|
||||
logger.info("Kafka event bus started", bootstrap_servers=self.bootstrap_servers)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop Kafka producer and consumers"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
if self.producer:
|
||||
await self.producer.stop()
|
||||
|
||||
for consumer in self.consumers.values():
|
||||
await consumer.stop()
|
||||
|
||||
self.running = False
|
||||
logger.info("Kafka event bus stopped")
|
||||
|
||||
async def publish(self, topic: str, payload: EventPayload) -> bool:
|
||||
"""Publish event to Kafka topic"""
|
||||
if not self.producer:
|
||||
raise RuntimeError("Event bus not started")
|
||||
|
||||
try:
|
||||
await self.producer.send_and_wait(topic, payload.to_json())
|
||||
logger.info(
|
||||
"Event published",
|
||||
topic=topic,
|
||||
event_id=payload.event_id,
|
||||
actor=payload.actor,
|
||||
tenant_id=payload.tenant_id,
|
||||
)
|
||||
return True
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error(
|
||||
"Failed to publish event",
|
||||
topic=topic,
|
||||
event_id=payload.event_id,
|
||||
error=str(e),
|
||||
)
|
||||
return False
|
||||
|
||||
async def subscribe(
|
||||
self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
|
||||
) -> None:
|
||||
"""Subscribe to Kafka topic"""
|
||||
if topic not in self.handlers:
|
||||
self.handlers[topic] = []
|
||||
self.handlers[topic].append(handler)
|
||||
|
||||
if topic not in self.consumers:
|
||||
consumer = AIOKafkaConsumer(
|
||||
topic,
|
||||
bootstrap_servers=",".join(self.bootstrap_servers),
|
||||
value_deserializer=lambda m: m.decode("utf-8"),
|
||||
group_id=f"tax-agent-{topic}",
|
||||
auto_offset_reset="latest",
|
||||
)
|
||||
self.consumers[topic] = consumer
|
||||
await consumer.start()
|
||||
|
||||
# Start consumer task
|
||||
asyncio.create_task(self._consume_messages(topic, consumer))
|
||||
|
||||
logger.info("Subscribed to topic", topic=topic)
|
||||
|
||||
async def _consume_messages(self, topic: str, consumer: AIOKafkaConsumer) -> None:
|
||||
"""Consume messages from Kafka topic"""
|
||||
try:
|
||||
async for message in consumer:
|
||||
try:
|
||||
if message.value is not None:
|
||||
payload_dict = json.loads(message.value)
|
||||
else:
|
||||
continue
|
||||
payload = EventPayload(
|
||||
data=payload_dict["data"],
|
||||
actor=payload_dict["actor"],
|
||||
tenant_id=payload_dict["tenant_id"],
|
||||
trace_id=payload_dict.get("trace_id"),
|
||||
schema_version=payload_dict.get("schema_version", "1.0"),
|
||||
)
|
||||
payload.event_id = payload_dict["event_id"]
|
||||
payload.occurred_at = payload_dict["occurred_at"]
|
||||
|
||||
# Call all handlers for this topic
|
||||
for handler in self.handlers.get(topic, []):
|
||||
try:
|
||||
await handler(topic, payload)
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error(
|
||||
"Handler failed",
|
||||
topic=topic,
|
||||
event_id=payload.event_id,
|
||||
handler=handler.__name__,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error("Failed to decode message", topic=topic, error=str(e))
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error("Failed to process message", topic=topic, error=str(e))
|
||||
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error("Consumer error", topic=topic, error=str(e))
64
libs/events/memory_bus.py
Normal file
@@ -0,0 +1,64 @@
"""In-memory event bus for local development and testing."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from .base import EventBus, EventPayload
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryEventBus(EventBus):
|
||||
"""In-memory event bus implementation for local development"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.handlers: dict[
|
||||
str, list[Callable[[str, EventPayload], Awaitable[None]]]
|
||||
] = defaultdict(list)
|
||||
self.running = False
|
||||
|
||||
async def publish(self, topic: str, payload: EventPayload) -> bool:
|
||||
"""Publish event to topic"""
|
||||
try:
|
||||
if not self.running:
|
||||
logger.warning(
|
||||
"Event bus not running, skipping publish to topic: %s", topic
|
||||
)
|
||||
return False
|
||||
|
||||
handlers = self.handlers.get(topic, [])
|
||||
if not handlers:
|
||||
logger.debug("No handlers for topic: %s", topic)
|
||||
return True
|
||||
|
||||
# Execute all handlers concurrently
|
||||
tasks = [handler(topic, payload) for handler in handlers]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
logger.debug(
|
||||
"Published event to topic %s with %d handlers", topic, len(handlers)
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to publish event to topic %s: %s", topic, e)
|
||||
return False
|
||||
|
||||
async def subscribe(
|
||||
self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
|
||||
) -> None:
|
||||
"""Subscribe to topic with handler"""
|
||||
self.handlers[topic].append(handler)
|
||||
logger.debug("Subscribed handler to topic: %s", topic)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the event bus"""
|
||||
self.running = True
|
||||
logger.info("Memory event bus started")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the event bus"""
|
||||
self.running = False
|
||||
self.handlers.clear()
|
||||
logger.info("Memory event bus stopped")
269
libs/events/nats_bus.py
Normal file
@@ -0,0 +1,269 @@
"""NATS.io with JetStream implementation of EventBus."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import nats # type: ignore
|
||||
import structlog
|
||||
from nats.aio.client import Client as NATS # type: ignore
|
||||
from nats.js import JetStreamContext # type: ignore
|
||||
|
||||
from .base import EventBus, EventPayload
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
|
||||
"""NATS.io with JetStream implementation of EventBus"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
servers: str | list[str] = "nats://localhost:4222",
|
||||
stream_name: str = "TAX_AGENT_EVENTS",
|
||||
consumer_group: str = "tax-agent",
|
||||
):
|
||||
if isinstance(servers, str):
|
||||
self.servers = [servers]
|
||||
else:
|
||||
self.servers = servers
|
||||
|
||||
self.stream_name = stream_name
|
||||
self.consumer_group = consumer_group
|
||||
self.nc: NATS | None = None
|
||||
self.js: JetStreamContext | None = None
|
||||
self.handlers: dict[
|
||||
str, list[Callable[[str, EventPayload], Awaitable[None]]]
|
||||
] = {}
|
||||
self.subscriptions: dict[str, Any] = {}
|
||||
self.running = False
|
||||
self.consumer_tasks: list[asyncio.Task[None]] = []
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start NATS connection and JetStream context"""
|
||||
if self.running:
|
||||
return
|
||||
|
||||
try:
|
||||
# Connect to NATS
|
||||
self.nc = await nats.connect(servers=self.servers)
|
||||
|
||||
# Get JetStream context
|
||||
self.js = self.nc.jetstream()
|
||||
|
||||
# Ensure stream exists
|
||||
await self._ensure_stream_exists()
|
||||
|
||||
self.running = True
|
||||
logger.info(
|
||||
"NATS event bus started",
|
||||
servers=self.servers,
|
||||
stream=self.stream_name,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to start NATS event bus", error=str(e))
|
||||
raise
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop NATS connection and consumers"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
# Cancel consumer tasks
|
||||
for task in self.consumer_tasks:
|
||||
task.cancel()
|
||||
|
||||
if self.consumer_tasks:
|
||||
await asyncio.gather(*self.consumer_tasks, return_exceptions=True)
|
||||
|
||||
# Unsubscribe from all subscriptions
|
||||
for subscription in self.subscriptions.values():
|
||||
try:
|
||||
await subscription.unsubscribe()
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.warning("Error unsubscribing", error=str(e))
|
||||
|
||||
# Close NATS connection
|
||||
if self.nc:
|
||||
await self.nc.close()
|
||||
|
||||
self.running = False
|
||||
logger.info("NATS event bus stopped")
|
||||
|
||||
async def publish(self, topic: str, payload: EventPayload) -> bool:
|
||||
"""Publish event to NATS JetStream"""
|
||||
if not self.js:
|
||||
raise RuntimeError("Event bus not started")
|
||||
|
||||
try:
|
||||
# Create subject name from topic
|
||||
subject = f"{self.stream_name}.{topic}"
|
||||
|
||||
# Publish message with headers
|
||||
headers = {
|
||||
"event_id": payload.event_id,
|
||||
"tenant_id": payload.tenant_id,
|
||||
"actor": payload.actor,
|
||||
"trace_id": payload.trace_id or "",
|
||||
"schema_version": payload.schema_version,
|
||||
}
|
||||
|
||||
ack = await self.js.publish(
|
||||
subject=subject,
|
||||
payload=payload.to_json().encode(),
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Event published",
|
||||
topic=topic,
|
||||
subject=subject,
|
||||
event_id=payload.event_id,
|
||||
stream_seq=ack.seq,
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error(
|
||||
"Failed to publish event",
|
||||
topic=topic,
|
||||
event_id=payload.event_id,
|
||||
error=str(e),
|
||||
)
|
||||
return False
|
||||
|
||||
async def subscribe(
|
||||
self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
|
||||
) -> None:
|
||||
"""Subscribe to NATS JetStream topic"""
|
||||
if not self.js:
|
||||
raise RuntimeError("Event bus not started")
|
||||
|
||||
if topic not in self.handlers:
|
||||
self.handlers[topic] = []
|
||||
self.handlers[topic].append(handler)
|
||||
|
||||
if topic not in self.subscriptions:
|
||||
try:
|
||||
# Create subject pattern for topic
|
||||
subject = f"{self.stream_name}.{topic}"
|
||||
|
||||
# Create durable consumer
|
||||
consumer_name = f"{self.consumer_group}-{topic}"
|
||||
|
||||
# Subscribe with pull-based consumer
|
||||
subscription = await self.js.pull_subscribe(
|
||||
subject=subject,
|
||||
durable=consumer_name,
|
||||
config=nats.js.api.ConsumerConfig(
|
||||
durable_name=consumer_name,
|
||||
ack_policy=nats.js.api.AckPolicy.EXPLICIT,
|
||||
deliver_policy=nats.js.api.DeliverPolicy.NEW,
|
||||
max_deliver=3,
|
||||
ack_wait=30, # 30 seconds
|
||||
),
|
||||
)
|
||||
|
||||
self.subscriptions[topic] = subscription
|
||||
|
||||
# Start consumer task
|
||||
task = asyncio.create_task(self._consume_messages(topic, subscription))
|
||||
self.consumer_tasks.append(task)
|
||||
|
||||
logger.info(
|
||||
"Subscribed to topic",
|
||||
topic=topic,
|
||||
subject=subject,
|
||||
consumer=consumer_name,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to subscribe to topic", topic=topic, error=str(e))
|
||||
raise
|
||||
|
||||
async def _ensure_stream_exists(self) -> None:
|
||||
"""Ensure JetStream stream exists"""
|
||||
if not self.js:
|
||||
return
|
||||
|
||||
try:
|
||||
# Try to get stream info
|
||||
await self.js.stream_info(self.stream_name)
|
||||
logger.debug("Stream already exists", stream=self.stream_name)
|
||||
|
||||
except nats.js.errors.NotFoundError:
|
||||
# Stream doesn't exist, create it
|
||||
try:
|
||||
await self.js.add_stream(
|
||||
name=self.stream_name,
|
||||
subjects=[f"{self.stream_name}.*"],
|
||||
retention=nats.js.api.RetentionPolicy.WORK_QUEUE,
|
||||
max_age=7 * 24 * 60 * 60, # 7 days in seconds
|
||||
storage=nats.js.api.StorageType.FILE,
|
||||
)
|
||||
logger.info("Created JetStream stream", stream=self.stream_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to create stream", stream=self.stream_name, error=str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
async def _consume_messages(self, topic: str, subscription: Any) -> None:
|
||||
"""Consume messages from NATS JetStream subscription"""
|
||||
while self.running:
|
||||
try:
|
||||
# Fetch messages in batches
|
||||
messages = await subscription.fetch(batch=10, timeout=20)
|
||||
|
||||
for message in messages:
|
||||
try:
|
||||
# Parse message payload
|
||||
payload_dict = json.loads(message.data.decode())
|
||||
|
||||
payload = EventPayload(
|
||||
data=payload_dict["data"],
|
||||
actor=payload_dict["actor"],
|
||||
tenant_id=payload_dict["tenant_id"],
|
||||
trace_id=payload_dict.get("trace_id"),
|
||||
schema_version=payload_dict.get("schema_version", "1.0"),
|
||||
)
|
||||
payload.event_id = payload_dict["event_id"]
|
||||
payload.occurred_at = payload_dict["occurred_at"]
|
||||
|
||||
# Call all handlers for this topic
|
||||
for handler in self.handlers.get(topic, []):
|
||||
try:
|
||||
await handler(topic, payload)
|
||||
except (
|
||||
Exception
|
||||
) as e: # pylint: disable=broad-exception-caught
|
||||
logger.error(
|
||||
"Handler failed",
|
||||
topic=topic,
|
||||
event_id=payload.event_id,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
# Acknowledge message
|
||||
await message.ack()
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(
|
||||
"Failed to decode message", topic=topic, error=str(e)
|
||||
)
|
||||
await message.nak()
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error(
|
||||
"Failed to process message", topic=topic, error=str(e)
|
||||
)
|
||||
await message.nak()
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# No messages available, continue polling
|
||||
continue
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error("Consumer error", topic=topic, error=str(e))
|
||||
await asyncio.sleep(5) # Wait before retrying
212
libs/events/sqs_bus.py
Normal file
@@ -0,0 +1,212 @@
"""AWS SQS/SNS implementation of EventBus."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import boto3 # type: ignore
|
||||
import structlog
|
||||
from botocore.exceptions import ClientError # type: ignore
|
||||
|
||||
from .base import EventBus, EventPayload
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SQSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
|
||||
"""AWS SQS/SNS implementation of EventBus"""
|
||||
|
||||
def __init__(self, region_name: str = "us-east-1"):
|
||||
self.region_name = region_name
|
||||
self.sns_client: Any = None
|
||||
self.sqs_client: Any = None
|
||||
self.topic_arns: dict[str, str] = {}
|
||||
self.queue_urls: dict[str, str] = {}
|
||||
self.handlers: dict[
|
||||
str, list[Callable[[str, EventPayload], Awaitable[None]]]
|
||||
] = {}
|
||||
self.running = False
|
||||
self.consumer_tasks: list[asyncio.Task[None]] = []
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start SQS/SNS clients"""
|
||||
if self.running:
|
||||
return
|
||||
|
||||
self.sns_client = boto3.client("sns", region_name=self.region_name)
|
||||
self.sqs_client = boto3.client("sqs", region_name=self.region_name)
|
||||
self.running = True
|
||||
logger.info("SQS event bus started", region=self.region_name)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop SQS/SNS clients and consumers"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
# Cancel consumer tasks
|
||||
for task in self.consumer_tasks:
|
||||
task.cancel()
|
||||
|
||||
if self.consumer_tasks:
|
||||
await asyncio.gather(*self.consumer_tasks, return_exceptions=True)
|
||||
|
||||
self.running = False
|
||||
logger.info("SQS event bus stopped")
|
||||
|
||||
async def publish(self, topic: str, payload: EventPayload) -> bool:
|
||||
"""Publish event to SNS topic"""
|
||||
if not self.sns_client:
|
||||
raise RuntimeError("Event bus not started")
|
||||
|
||||
try:
|
||||
# Ensure topic exists
|
||||
topic_arn = await self._ensure_topic_exists(topic)
|
||||
|
||||
# Publish message
|
||||
response = self.sns_client.publish(
|
||||
TopicArn=topic_arn,
|
||||
Message=payload.to_json(),
|
||||
MessageAttributes={
|
||||
"event_id": {"DataType": "String", "StringValue": payload.event_id},
|
||||
"tenant_id": {
|
||||
"DataType": "String",
|
||||
"StringValue": payload.tenant_id,
|
||||
},
|
||||
"actor": {"DataType": "String", "StringValue": payload.actor},
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Event published",
|
||||
topic=topic,
|
||||
event_id=payload.event_id,
|
||||
message_id=response["MessageId"],
|
||||
)
|
||||
return True
|
||||
|
||||
except ClientError as e:
|
||||
logger.error(
|
||||
"Failed to publish event",
|
||||
topic=topic,
|
||||
event_id=payload.event_id,
|
||||
error=str(e),
|
||||
)
|
||||
return False
|
||||
|
||||
async def subscribe(
|
||||
self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
|
||||
) -> None:
|
||||
"""Subscribe to SNS topic via SQS queue"""
|
||||
if topic not in self.handlers:
|
||||
self.handlers[topic] = []
|
||||
self.handlers[topic].append(handler)
|
||||
|
||||
if topic not in self.queue_urls:
|
||||
# Create SQS queue for this topic
|
||||
queue_name = f"tax-agent-{topic}"
|
||||
queue_url = await self._ensure_queue_exists(queue_name)
|
||||
self.queue_urls[topic] = queue_url
|
||||
|
||||
# Subscribe queue to SNS topic
|
||||
topic_arn = await self._ensure_topic_exists(topic)
|
||||
await self._subscribe_queue_to_topic(queue_url, topic_arn)
|
||||
|
||||
# Start consumer task
|
||||
task = asyncio.create_task(self._consume_messages(topic, queue_url))
|
||||
self.consumer_tasks.append(task)
|
||||
|
||||
logger.info("Subscribed to topic", topic=topic, queue_name=queue_name)
|
||||
|
||||
async def _ensure_topic_exists(self, topic: str) -> str:
|
||||
"""Ensure SNS topic exists and return ARN"""
|
||||
if topic in self.topic_arns:
|
||||
return self.topic_arns[topic]
|
||||
|
||||
try:
|
||||
response = self.sns_client.create_topic(Name=topic)
|
||||
topic_arn = response["TopicArn"]
|
||||
self.topic_arns[topic] = topic_arn
|
||||
return str(topic_arn)
|
||||
except ClientError as e:
|
||||
logger.error("Failed to create topic", topic=topic, error=str(e))
|
||||
raise
|
||||
|
||||
async def _ensure_queue_exists(self, queue_name: str) -> str:
|
||||
"""Ensure SQS queue exists and return URL"""
|
||||
try:
|
||||
response = self.sqs_client.create_queue(QueueName=queue_name)
|
||||
return str(response["QueueUrl"])
|
||||
except ClientError as e:
|
||||
logger.error("Failed to create queue", queue_name=queue_name, error=str(e))
|
||||
raise
|
||||
|
||||
async def _subscribe_queue_to_topic(self, queue_url: str, topic_arn: str) -> None:
|
||||
"""Subscribe SQS queue to SNS topic"""
|
||||
try:
|
||||
# Get queue attributes
|
||||
queue_attrs = self.sqs_client.get_queue_attributes(
|
||||
QueueUrl=queue_url, AttributeNames=["QueueArn"]
|
||||
)
|
||||
queue_arn = queue_attrs["Attributes"]["QueueArn"]
|
||||
|
||||
# Subscribe queue to topic
|
||||
self.sns_client.subscribe(
|
||||
TopicArn=topic_arn, Protocol="sqs", Endpoint=queue_arn
|
||||
)
|
||||
except ClientError as e:
|
||||
logger.error("Failed to subscribe queue to topic", error=str(e))
|
||||
raise
|
||||
|
||||
async def _consume_messages(self, topic: str, queue_url: str) -> None:
|
||||
"""Consume messages from SQS queue"""
|
||||
# pylint: disable=too-many-nested-blocks
|
||||
while self.running:
|
||||
try:
|
||||
response = self.sqs_client.receive_message(
|
||||
QueueUrl=queue_url, MaxNumberOfMessages=10, WaitTimeSeconds=20
|
||||
)
|
||||
|
||||
messages = response.get("Messages", [])
|
||||
for message in messages:
|
||||
try:
|
||||
# Parse SNS message
|
||||
sns_message = json.loads(message["Body"])
|
||||
payload_dict = json.loads(sns_message["Message"])
|
||||
|
||||
payload = EventPayload(
|
||||
data=payload_dict["data"],
|
||||
actor=payload_dict["actor"],
|
||||
tenant_id=payload_dict["tenant_id"],
|
||||
trace_id=payload_dict.get("trace_id"),
|
||||
schema_version=payload_dict.get("schema_version", "1.0"),
|
||||
)
|
||||
payload.event_id = payload_dict["event_id"]
|
||||
payload.occurred_at = payload_dict["occurred_at"]
|
||||
|
||||
# Call all handlers for this topic
|
||||
for handler in self.handlers.get(topic, []):
|
||||
try:
|
||||
await handler(topic, payload)
|
||||
# pylint: disable=broad-exception-caught
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Handler failed",
|
||||
topic=topic,
|
||||
event_id=payload.event_id,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
# Delete message from queue
|
||||
self.sqs_client.delete_message(
|
||||
QueueUrl=queue_url, ReceiptHandle=message["ReceiptHandle"]
|
||||
)
|
||||
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error(
|
||||
"Failed to process message", topic=topic, error=str(e)
|
||||
)
|
||||
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error("Consumer error", topic=topic, error=str(e))
|
||||
await asyncio.sleep(5) # Wait before retrying
|
||||
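A rough usage sketch for the SQS/SNS bus above. AWS credentials are assumed to come from the environment; the region, topic name, and payload contents are placeholders.

import asyncio

from libs.events.base import EventPayload
from libs.events.sqs_bus import SQSEventBus


async def main() -> None:
    bus = SQSEventBus(region_name="eu-west-2")
    await bus.start()

    async def handler(topic: str, payload: EventPayload) -> None:
        print("received", topic, payload.event_id)

    # subscribe() derives the queue name as f"tax-agent-{topic}" and wires it to the SNS topic.
    await bus.subscribe("doc.ingested", handler)

    ok = await bus.publish(
        "doc.ingested",
        EventPayload(data={"doc_id": "d1"}, actor="svc-ingestion", tenant_id="t1"),
    )
    print("published:", ok)

    await asyncio.sleep(5)
    await bus.stop()


asyncio.run(main())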
libs/events/topics.py (new file, 17 lines)
@@ -0,0 +1,17 @@
"""Standard event topic names."""


class EventTopics:  # pylint: disable=too-few-public-methods
    """Standard event topic names"""

    DOC_INGESTED = "doc.ingested"
    DOC_OCR_READY = "doc.ocr_ready"
    DOC_EXTRACTED = "doc.extracted"
    KG_UPSERTED = "kg.upserted"
    RAG_INDEXED = "rag.indexed"
    CALC_SCHEDULE_READY = "calc.schedule_ready"
    FORM_FILLED = "form.filled"
    HMRC_SUBMITTED = "hmrc.submitted"
    REVIEW_REQUESTED = "review.requested"
    REVIEW_COMPLETED = "review.completed"
    FIRM_SYNC_COMPLETED = "firm.sync.completed"
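These constants are meant to replace raw topic strings at publish/subscribe call sites, e.g. (bus construction omitted; any EventBus implementation from this library should work):

from libs.events.base import EventPayload
from libs.events.topics import EventTopics

payload = EventPayload(
    data={"doc_id": "d1", "pages": 3},
    actor="svc-ocr",
    tenant_id="tenant-42",
)
# Inside an async context, with `bus` an EventBus instance:
# await bus.publish(EventTopics.DOC_OCR_READY, payload)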
libs/forms/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
"""PDF form filling and evidence pack generation."""

from .evidence_pack import UK_TAX_FORMS, EvidencePackGenerator
from .pdf_filler import PDFFormFiller

__all__ = [
    "PDFFormFiller",
    "EvidencePackGenerator",
    "UK_TAX_FORMS",
]
libs/forms/evidence_pack.py (new file, 185 lines)
@@ -0,0 +1,185 @@
|
||||
"""Evidence pack generation with manifests and signatures."""
|
||||
|
||||
import io
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class EvidencePackGenerator: # pylint: disable=too-few-public-methods
|
||||
"""Generate evidence packs with manifests and signatures"""
|
||||
|
||||
def __init__(self, storage_client: Any) -> None:
|
||||
self.storage = storage_client
|
||||
|
||||
async def create_evidence_pack( # pylint: disable=too-many-locals
|
||||
self,
|
||||
taxpayer_id: str,
|
||||
tax_year: str,
|
||||
scope: str,
|
||||
evidence_items: list[dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
"""Create evidence pack with manifest and signatures"""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
import hashlib
|
||||
import json
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
# Create ZIP buffer
|
||||
zip_buffer = io.BytesIO()
|
||||
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
|
||||
manifest: dict[str, Any] = {
|
||||
"taxpayer_id": taxpayer_id,
|
||||
"tax_year": tax_year,
|
||||
"scope": scope,
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"evidence_items": [],
|
||||
"signatures": {},
|
||||
}
|
||||
|
||||
# Add evidence files to ZIP
|
||||
for item in evidence_items:
|
||||
doc_id = item["doc_id"]
|
||||
page = item.get("page")
|
||||
bbox = item.get("bbox")
|
||||
text_hash = item.get("text_hash")
|
||||
|
||||
# Get document content
|
||||
doc_content = await self.storage.get_object(
|
||||
bucket_name="raw-documents",
|
||||
object_name=f"tenants/{taxpayer_id}/raw/{doc_id}.pdf",
|
||||
)
|
||||
|
||||
if doc_content:
|
||||
# Add to ZIP
|
||||
zip_filename = f"documents/{doc_id}.pdf"
|
||||
zip_file.writestr(zip_filename, doc_content)
|
||||
|
||||
# Calculate file hash
|
||||
file_hash = hashlib.sha256(doc_content).hexdigest()
|
||||
|
||||
# Add to manifest
|
||||
manifest["evidence_items"].append(
|
||||
{
|
||||
"doc_id": doc_id,
|
||||
"filename": zip_filename,
|
||||
"page": page,
|
||||
"bbox": bbox,
|
||||
"text_hash": text_hash,
|
||||
"file_hash": file_hash,
|
||||
"file_size": len(doc_content),
|
||||
}
|
||||
)
|
||||
|
||||
# Sign manifest
|
||||
manifest_json = json.dumps(manifest, indent=2, sort_keys=True)
|
||||
manifest_hash = hashlib.sha256(manifest_json.encode()).hexdigest()
|
||||
|
||||
manifest["signatures"]["manifest_hash"] = manifest_hash
|
||||
manifest["signatures"]["algorithm"] = "SHA-256"
|
||||
|
||||
# Add manifest to ZIP
|
||||
zip_file.writestr("manifest.json", json.dumps(manifest, indent=2))
|
||||
|
||||
# Get ZIP content
|
||||
zip_content = zip_buffer.getvalue()
|
||||
|
||||
# Store evidence pack
|
||||
pack_filename = f"evidence_pack_{taxpayer_id}_{tax_year}_{scope}.zip"
|
||||
pack_key = f"tenants/{taxpayer_id}/evidence_packs/{pack_filename}"
|
||||
|
||||
success = await self.storage.put_object(
|
||||
bucket_name="evidence-packs",
|
||||
object_name=pack_key,
|
||||
data=io.BytesIO(zip_content),
|
||||
length=len(zip_content),
|
||||
content_type="application/zip",
|
||||
)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"pack_filename": pack_filename,
|
||||
"pack_key": pack_key,
|
||||
"pack_size": len(zip_content),
|
||||
"evidence_count": len(evidence_items),
|
||||
"manifest_hash": manifest_hash,
|
||||
"s3_url": f"s3://evidence-packs/{pack_key}",
|
||||
}
|
||||
raise RuntimeError("Failed to store evidence pack")
|
||||
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error("Failed to create evidence pack", error=str(e))
|
||||
raise
|
||||
|
||||
|
||||
# Form configuration for UK tax forms
|
||||
UK_TAX_FORMS = {
|
||||
"SA100": {
|
||||
"name": "Self Assessment Tax Return",
|
||||
"template_path": "forms/templates/SA100.pdf",
|
||||
"boxes": {
|
||||
"1": {"description": "Your name", "type": "text"},
|
||||
"2": {"description": "Your address", "type": "text"},
|
||||
"3": {"description": "Your UTR", "type": "text"},
|
||||
"4": {"description": "Your NI number", "type": "text"},
|
||||
},
|
||||
},
|
||||
"SA103": {
|
||||
"name": "Self-employment (full)",
|
||||
"template_path": "forms/templates/SA103.pdf",
|
||||
"boxes": {
|
||||
"1": {"description": "Business name", "type": "text"},
|
||||
"2": {"description": "Business description", "type": "text"},
|
||||
"3": {"description": "Accounting period start", "type": "date"},
|
||||
"4": {"description": "Accounting period end", "type": "date"},
|
||||
"20": {"description": "Total turnover", "type": "currency"},
|
||||
"31": {
|
||||
"description": "Total allowable business expenses",
|
||||
"type": "currency",
|
||||
},
|
||||
"32": {"description": "Net profit", "type": "currency"},
|
||||
"33": {"description": "Balancing charges", "type": "currency"},
|
||||
"34": {"description": "Goods/services for own use", "type": "currency"},
|
||||
"35": {"description": "Total taxable profits", "type": "currency"},
|
||||
},
|
||||
},
|
||||
"SA105": {
|
||||
"name": "Property income",
|
||||
"template_path": "forms/templates/SA105.pdf",
|
||||
"boxes": {
|
||||
"20": {"description": "Total rents and other income", "type": "currency"},
|
||||
"29": {
|
||||
"description": "Premiums for the grant of a lease",
|
||||
"type": "currency",
|
||||
},
|
||||
"31": {
|
||||
"description": "Rent, rates, insurance, ground rents etc",
|
||||
"type": "currency",
|
||||
},
|
||||
"32": {"description": "Property management", "type": "currency"},
|
||||
"33": {
|
||||
"description": "Services provided, including wages",
|
||||
"type": "currency",
|
||||
},
|
||||
"34": {
|
||||
"description": "Repairs, maintenance and renewals",
|
||||
"type": "currency",
|
||||
},
|
||||
"35": {
|
||||
"description": "Finance costs, including interest",
|
||||
"type": "currency",
|
||||
},
|
||||
"36": {"description": "Professional fees", "type": "currency"},
|
||||
"37": {"description": "Costs of services provided", "type": "currency"},
|
||||
"38": {
|
||||
"description": "Other allowable property expenses",
|
||||
"type": "currency",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
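A sketch of driving the generator above. The storage client can be anything exposing the async get_object/put_object methods used in create_evidence_pack; the IDs, tax year, and bbox values are placeholders.

import asyncio
from typing import Any

from libs.forms import UK_TAX_FORMS, EvidencePackGenerator


async def build_pack(storage_client: Any) -> None:
    generator = EvidencePackGenerator(storage_client)
    result = await generator.create_evidence_pack(
        taxpayer_id="tp-123",
        tax_year="2023-24",
        scope="SA103",
        evidence_items=[
            {"doc_id": "doc-1", "page": 2, "bbox": [10, 20, 200, 60], "text_hash": "abc123"},
        ],
    )
    print(result["pack_key"], result["manifest_hash"])
    print(UK_TAX_FORMS["SA103"]["boxes"]["20"]["description"])  # "Total turnover"


# asyncio.run(build_pack(storage_client))  # supply a concrete storage client here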
libs/forms/pdf_filler.py (new file, 246 lines)
@@ -0,0 +1,246 @@
|
||||
"""PDF form filling using pdfrw with reportlab fallback."""
|
||||
|
||||
import io
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class PDFFormFiller:
|
||||
"""PDF form filling using pdfrw with reportlab fallback"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.form_templates: dict[str, Any] = {}
|
||||
|
||||
def load_template(self, form_id: str, template_path: str) -> bool:
|
||||
"""Load PDF form template"""
|
||||
try:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from pdfrw import PdfReader # type: ignore
|
||||
|
||||
template = PdfReader(template_path)
|
||||
if template is None:
|
||||
logger.error(
|
||||
"Failed to load PDF template", form_id=form_id, path=template_path
|
||||
)
|
||||
return False
|
||||
|
||||
self.form_templates[form_id] = {"template": template, "path": template_path}
|
||||
|
||||
logger.info("Loaded PDF template", form_id=form_id, path=template_path)
|
||||
return True
|
||||
|
||||
except ImportError:
|
||||
logger.error("pdfrw not available for PDF form filling")
|
||||
return False
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error("Failed to load PDF template", form_id=form_id, error=str(e))
|
||||
return False
|
||||
|
||||
def fill_form(
|
||||
self,
|
||||
form_id: str,
|
||||
field_values: dict[str, str | int | float | bool],
|
||||
output_path: str | None = None,
|
||||
) -> bytes | None:
|
||||
"""Fill PDF form with values"""
|
||||
|
||||
if form_id not in self.form_templates:
|
||||
logger.error("Form template not loaded", form_id=form_id)
|
||||
return None
|
||||
|
||||
try:
|
||||
return self._fill_with_pdfrw(form_id, field_values, output_path)
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.warning(
|
||||
"pdfrw filling failed, trying reportlab overlay", error=str(e)
|
||||
)
|
||||
return self._fill_with_overlay(form_id, field_values, output_path)
|
||||
|
||||
def _fill_with_pdfrw(
|
||||
self,
|
||||
form_id: str,
|
||||
field_values: dict[str, Any],
|
||||
output_path: str | None = None,
|
||||
) -> bytes | None:
|
||||
"""Fill form using pdfrw"""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from pdfrw import PdfDict, PdfReader, PdfWriter
|
||||
|
||||
template_info = self.form_templates[form_id]
|
||||
template = PdfReader(template_info["path"])
|
||||
|
||||
# Get form fields
|
||||
if template.Root.AcroForm is None: # pyright: ignore[reportOptionalMemberAccess, reportAttributeAccessIssue] # fmt: skip
|
||||
logger.warning("PDF has no AcroForm fields", form_id=form_id)
|
||||
return self._fill_with_overlay(form_id, field_values, output_path)
|
||||
|
||||
# Fill form fields
|
||||
for field in template.Root.AcroForm.Fields: # pyright: ignore[reportOptionalMemberAccess, reportAttributeAccessIssue] # fmt: skip
|
||||
field_name = field.T
|
||||
if field_name and field_name[1:-1] in field_values: # Remove parentheses
|
||||
field_value = field_values[field_name[1:-1]]
|
||||
|
||||
# Set field value
|
||||
if isinstance(field_value, bool):
|
||||
# Checkbox field
|
||||
if field_value:
|
||||
field.V = PdfDict.Yes # fmt: skip # pyright: ignore[reportAttributeAccessIssue]
|
||||
field.AS = PdfDict.Yes # fmt: skip # pyright: ignore[reportAttributeAccessIssue]
|
||||
else:
|
||||
field.V = PdfDict.Off # fmt: skip # pyright: ignore[reportAttributeAccessIssue]
|
||||
field.AS = PdfDict.Off # fmt: skip # pyright: ignore[reportAttributeAccessIssue]
|
||||
else:
|
||||
# Text field
|
||||
field.V = str(field_value)
|
||||
|
||||
# Make field read-only
|
||||
field.Ff = 1 # Read-only flag
|
||||
|
||||
# Flatten form (make fields non-editable)
|
||||
if template.Root.AcroForm: # pyright: ignore[reportOptionalMemberAccess, reportAttributeAccessIssue] # fmt: skip
|
||||
template.Root.AcroForm.NeedAppearances = True # pyright: ignore[reportOptionalMemberAccess, reportAttributeAccessIssue] # fmt: skip
|
||||
|
||||
        # Write to output (pdfrw expects the document to be passed via the trailer argument)
        if output_path:
            PdfWriter(output_path, trailer=template).write()
            with open(output_path, "rb") as f:
                return f.read()

        # Write to an in-memory buffer instead of a file path
        output_buffer = io.BytesIO()
        PdfWriter(trailer=template).write(output_buffer)
        return output_buffer.getvalue()
|
||||
|
||||
def _fill_with_overlay( # pylint: disable=too-many-locals
|
||||
self,
|
||||
form_id: str,
|
||||
field_values: dict[str, Any],
|
||||
output_path: str | None = None,
|
||||
) -> bytes | None:
|
||||
"""Fill form using reportlab overlay method"""
|
||||
try:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from PyPDF2 import PdfReader, PdfWriter
|
||||
from reportlab.lib.pagesizes import A4
|
||||
from reportlab.pdfgen import canvas
|
||||
|
||||
template_info = self.form_templates[form_id]
|
||||
|
||||
# Read original PDF
|
||||
original_pdf = PdfReader(template_info["path"])
|
||||
|
||||
# Create overlay with form data
|
||||
overlay_buffer = io.BytesIO()
|
||||
overlay_canvas = canvas.Canvas(overlay_buffer, pagesize=A4)
|
||||
|
||||
# Get field positions (this would be configured per form)
|
||||
field_positions = self._get_field_positions(form_id)
|
||||
|
||||
# Add text to overlay
|
||||
for field_name, value in field_values.items():
|
||||
if field_name in field_positions:
|
||||
pos = field_positions[field_name]
|
||||
overlay_canvas.drawString(pos["x"], pos["y"], str(value))
|
||||
|
||||
overlay_canvas.save()
|
||||
overlay_buffer.seek(0)
|
||||
|
||||
# Read overlay PDF
|
||||
overlay_pdf = PdfReader(overlay_buffer)
|
||||
|
||||
# Merge original and overlay
|
||||
writer = PdfWriter()
|
||||
for page_num, _ in enumerate(original_pdf.pages):
|
||||
original_page = original_pdf.pages[page_num]
|
||||
|
||||
if page_num < len(overlay_pdf.pages):
|
||||
overlay_page = overlay_pdf.pages[page_num]
|
||||
original_page.merge_page(overlay_page)
|
||||
|
||||
writer.add_page(original_page)
|
||||
|
||||
# Write result
|
||||
if output_path:
|
||||
with open(output_path, "wb") as output_file:
|
||||
writer.write(output_file)
|
||||
with open(output_path, "rb") as f:
|
||||
return f.read()
|
||||
else:
|
||||
output_buffer = io.BytesIO()
|
||||
writer.write(output_buffer)
|
||||
return output_buffer.getvalue()
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
"Required libraries not available for overlay method", error=str(e)
|
||||
)
|
||||
return None
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error("Overlay filling failed", form_id=form_id, error=str(e))
|
||||
return None
|
||||
|
||||
def _get_field_positions(self, form_id: str) -> dict[str, dict[str, float]]:
|
||||
"""Get field positions for overlay method"""
|
||||
# This would be configured per form type
|
||||
# For now, return sample positions for SA103
|
||||
if form_id == "SA103":
|
||||
return {
|
||||
"box_1": {"x": 100, "y": 750}, # Business name
|
||||
"box_2": {"x": 100, "y": 720}, # Business description
|
||||
"box_20": {"x": 400, "y": 600}, # Total turnover
|
||||
"box_31": {"x": 400, "y": 570}, # Total expenses
|
||||
"box_32": {"x": 400, "y": 540}, # Net profit
|
||||
}
|
||||
return {}
|
||||
|
||||
def get_form_fields(self, form_id: str) -> list[dict[str, Any]]:
|
||||
"""Get list of available form fields"""
|
||||
if form_id not in self.form_templates:
|
||||
return []
|
||||
|
||||
try:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from pdfrw import PdfReader
|
||||
|
||||
template_info = self.form_templates[form_id]
|
||||
template = PdfReader(template_info["path"])
|
||||
|
||||
if template.Root.AcroForm is None: # fmt: skip # pyright: ignore[reportOptionalMemberAccess, reportAttributeAccessIssue]
|
||||
return []
|
||||
|
||||
fields = []
|
||||
for field in template.Root.AcroForm.Fields: # fmt: skip # pyright: ignore[reportOptionalMemberAccess, reportAttributeAccessIssue]
|
||||
field_info = {
|
||||
"name": field.T[1:-1] if field.T else None, # Remove parentheses
|
||||
"type": self._get_field_type(field),
|
||||
"required": bool(field.Ff and int(field.Ff) & 2), # Required flag
|
||||
"readonly": bool(field.Ff and int(field.Ff) & 1), # Read-only flag
|
||||
}
|
||||
|
||||
if field.V:
|
||||
field_info["default_value"] = str(field.V)
|
||||
|
||||
fields.append(field_info)
|
||||
|
||||
return fields
|
||||
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error("Failed to get form fields", form_id=form_id, error=str(e))
|
||||
return []
|
||||
|
||||
def _get_field_type(self, field: Any) -> str:
|
||||
"""Determine field type from PDF field"""
|
||||
if hasattr(field, "FT"):
|
||||
field_type = str(field.FT)
|
||||
if "Tx" in field_type:
|
||||
return "text"
|
||||
if "Btn" in field_type:
|
||||
return "checkbox" if field.Ff and int(field.Ff) & 32768 else "button"
|
||||
if "Ch" in field_type:
|
||||
return "choice"
|
||||
return "unknown"
|
||||
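Typical use of the filler, assuming an SA103 template exists at the path configured in UK_TAX_FORMS. The field names passed to fill_form are illustrative; real templates expose their own AcroForm names via get_form_fields.

from libs.forms import PDFFormFiller

filler = PDFFormFiller()
if filler.load_template("SA103", "forms/templates/SA103.pdf"):
    # Inspect what the template actually exposes before filling.
    for field in filler.get_form_fields("SA103"):
        print(field["name"], field["type"])

    pdf_bytes = filler.fill_form(
        "SA103",
        {"box_1": "Acme Plumbing Ltd", "box_20": 54321.00},  # field names are assumptions
    )
    if pdf_bytes:
        with open("sa103_filled.pdf", "wb") as f:
            f.write(pdf_bytes)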
libs/neo/__init__.py (new file, 140 lines)
@@ -0,0 +1,140 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import structlog
|
||||
|
||||
from .client import Neo4jClient
|
||||
from .queries import TemporalQueries
|
||||
from .validator import SHACLValidator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libs.schemas.coverage.evaluation import Citation, FoundEvidence
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
async def kg_boxes_exist(client: Neo4jClient, box_ids: list[str]) -> dict[str, bool]:
|
||||
"""Check if form boxes exist in the knowledge graph"""
|
||||
query = """
|
||||
UNWIND $box_ids AS bid
|
||||
OPTIONAL MATCH (fb:FormBox {box_id: bid})
|
||||
RETURN bid, fb IS NOT NULL AS exists
|
||||
"""
|
||||
|
||||
try:
|
||||
results = await client.run_query(query, {"box_ids": box_ids})
|
||||
return {result["bid"]: result["exists"] for result in results}
|
||||
except Exception as e:
|
||||
logger.error("Failed to check box existence", box_ids=box_ids, error=str(e))
|
||||
return dict.fromkeys(box_ids, False)
|
||||
|
||||
|
||||
async def kg_find_evidence(
|
||||
client: Neo4jClient,
|
||||
taxpayer_id: str,
|
||||
tax_year: str,
|
||||
kinds: list[str],
|
||||
min_ocr: float = 0.6,
|
||||
date_window: int = 30,
|
||||
) -> list["FoundEvidence"]:
|
||||
"""Find evidence documents for taxpayer in tax year"""
|
||||
    query = """
    MATCH (p:TaxpayerProfile {taxpayer_id: $tid})-[:OF_TAX_YEAR]->(y:TaxYear {label: $tax_year})
    MATCH (ev:Evidence)-[:DERIVED_FROM]->(d:Document)
    WHERE ((ev)-[:SUPPORTS]->(p) OR (d)-[:BELONGS_TO]->(p))
      AND d.kind IN $kinds
      AND date(d.date) >= date(y.start_date) AND date(d.date) <= date(y.end_date)
      AND coalesce(ev.ocr_confidence, 0.0) >= $min_ocr
    RETURN d.doc_id AS doc_id,
           d.kind AS kind,
           ev.page AS page,
           ev.bbox AS bbox,
           ev.ocr_confidence AS ocr_confidence,
           ev.extract_confidence AS extract_confidence,
           d.date AS date
    ORDER BY ev.ocr_confidence DESC
    LIMIT 100
    """
|
||||
|
||||
try:
|
||||
results = await client.run_query(
|
||||
query,
|
||||
{
|
||||
"tid": taxpayer_id,
|
||||
"tax_year": tax_year,
|
||||
"kinds": kinds,
|
||||
"min_ocr": min_ocr,
|
||||
},
|
||||
)
|
||||
|
||||
# Convert to FoundEvidence format
|
||||
from libs.schemas.coverage.evaluation import FoundEvidence
|
||||
|
||||
evidence_list = []
|
||||
|
||||
for result in results:
|
||||
evidence = FoundEvidence(
|
||||
doc_id=result["doc_id"],
|
||||
kind=result["kind"],
|
||||
pages=[result["page"]] if result["page"] else [],
|
||||
bbox=result["bbox"],
|
||||
ocr_confidence=result["ocr_confidence"] or 0.0,
|
||||
extract_confidence=result["extract_confidence"] or 0.0,
|
||||
date=result["date"],
|
||||
)
|
||||
evidence_list.append(evidence)
|
||||
|
||||
return evidence_list
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to find evidence",
|
||||
taxpayer_id=taxpayer_id,
|
||||
tax_year=tax_year,
|
||||
kinds=kinds,
|
||||
error=str(e),
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
async def kg_rule_citations(
|
||||
client: Neo4jClient, schedule_id: str, box_ids: list[str]
|
||||
) -> list["Citation"]:
|
||||
"""Get rule citations for schedule and form boxes"""
|
||||
query = """
|
||||
MATCH (fb:FormBox)-[:GOVERNED_BY]->(r:Rule)-[:CITES]->(doc:Document)
|
||||
WHERE fb.box_id IN $box_ids
|
||||
RETURN r.rule_id AS rule_id,
|
||||
doc.doc_id AS doc_id,
|
||||
doc.locator AS locator
|
||||
LIMIT 10
|
||||
"""
|
||||
|
||||
try:
|
||||
results = await client.run_query(query, {"box_ids": box_ids})
|
||||
|
||||
# Convert to Citation format
|
||||
from libs.schemas.coverage.evaluation import Citation
|
||||
|
||||
citations = []
|
||||
|
||||
for result in results:
|
||||
citation = Citation(
|
||||
rule_id=result["rule_id"],
|
||||
doc_id=result["doc_id"],
|
||||
locator=result["locator"],
|
||||
)
|
||||
citations.append(citation)
|
||||
|
||||
return citations
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get rule citations",
|
||||
schedule_id=schedule_id,
|
||||
box_ids=box_ids,
|
||||
error=str(e),
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
__all__ = ["Neo4jClient", "TemporalQueries", "SHACLValidator"]
|
||||
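The helpers above are thin wrappers over Neo4jClient.run_query. A minimal call pattern, with placeholder connection details, box IDs, and document kinds:

import asyncio

from neo4j import GraphDatabase

from libs.neo import Neo4jClient
from libs.neo import kg_boxes_exist, kg_find_evidence  # defined in this __init__


async def main() -> None:
    driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
    async with Neo4jClient(driver) as client:
        exists = await kg_boxes_exist(client, ["SA103.box_20", "SA103.box_31"])
        evidence = await kg_find_evidence(
            client,
            taxpayer_id="tp-123",
            tax_year="2023-24",
            kinds=["invoice", "bank_statement"],
        )
        print(exists, len(evidence))


asyncio.run(main())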
libs/neo/client.py (new file, 350 lines)
@@ -0,0 +1,350 @@
|
||||
"""Neo4j session helpers, Cypher runner with retry, SHACL validator invoker."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from neo4j import Transaction
|
||||
from neo4j.exceptions import ServiceUnavailable, TransientError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class Neo4jClient:
|
||||
"""Neo4j client with session management and retry logic"""
|
||||
|
||||
def __init__(self, driver: Any) -> None:
|
||||
self.driver = driver
|
||||
|
||||
async def __aenter__(self) -> "Neo4jClient":
|
||||
"""Async context manager entry"""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""Async context manager exit"""
|
||||
await self.close()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the driver"""
|
||||
await asyncio.get_event_loop().run_in_executor(None, self.driver.close)
|
||||
|
||||
async def run_query(
|
||||
self,
|
||||
query: str,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
database: str = "neo4j",
|
||||
max_retries: int = 3,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Run Cypher query with retry logic"""
|
||||
|
||||
def _run_query() -> list[dict[str, Any]]:
|
||||
with self.driver.session(database=database) as session:
|
||||
result = session.run(query, parameters or {})
|
||||
return [record.data() for record in result]
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return await asyncio.get_event_loop().run_in_executor(None, _run_query)
|
||||
|
||||
except (TransientError, ServiceUnavailable) as e:
|
||||
if attempt == max_retries - 1:
|
||||
logger.error(
|
||||
"Query failed after retries",
|
||||
query=query[:100],
|
||||
attempt=attempt + 1,
|
||||
error=str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
wait_time = 2**attempt # Exponential backoff
|
||||
logger.warning(
|
||||
"Query failed, retrying",
|
||||
query=query[:100],
|
||||
attempt=attempt + 1,
|
||||
wait_time=wait_time,
|
||||
error=str(e),
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Query failed with non-retryable error",
|
||||
query=query[:100],
|
||||
error=str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
# This should never be reached due to the raise statements above
|
||||
return []
|
||||
|
||||
async def run_transaction(
|
||||
self, transaction_func: Any, database: str = "neo4j", max_retries: int = 3
|
||||
) -> Any:
|
||||
"""Run transaction with retry logic"""
|
||||
|
||||
def _run_transaction() -> Any:
|
||||
with self.driver.session(database=database) as session:
|
||||
return session.execute_write(transaction_func)
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None, _run_transaction
|
||||
)
|
||||
|
||||
except (TransientError, ServiceUnavailable) as e:
|
||||
if attempt == max_retries - 1:
|
||||
logger.error(
|
||||
"Transaction failed after retries",
|
||||
attempt=attempt + 1,
|
||||
error=str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
wait_time = 2**attempt
|
||||
logger.warning(
|
||||
"Transaction failed, retrying",
|
||||
attempt=attempt + 1,
|
||||
wait_time=wait_time,
|
||||
error=str(e),
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Transaction failed with non-retryable error", error=str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
async def create_node(
|
||||
self, label: str, properties: dict[str, Any], database: str = "neo4j"
|
||||
) -> dict[str, Any]:
|
||||
"""Create a node with temporal properties"""
|
||||
|
||||
# Add temporal properties if not present
|
||||
if "asserted_at" not in properties:
|
||||
properties["asserted_at"] = datetime.utcnow()
|
||||
|
||||
query = f"""
|
||||
CREATE (n:{label} $properties)
|
||||
RETURN n
|
||||
"""
|
||||
|
||||
result = await self.run_query(query, {"properties": properties}, database)
|
||||
node = result[0]["n"] if result else {}
|
||||
# Return node ID if available, otherwise return the full node
|
||||
return node.get("id", node)
|
||||
|
||||
async def update_node(
|
||||
self,
|
||||
label: str,
|
||||
node_id: str,
|
||||
properties: dict[str, Any],
|
||||
database: str = "neo4j",
|
||||
) -> dict[str, Any]:
|
||||
"""Update node with bitemporal versioning"""
|
||||
|
||||
def _update_transaction(tx: Transaction) -> Any:
|
||||
# First, retract the current version
|
||||
retract_query = f"""
|
||||
MATCH (n:{label} {{id: $node_id}})
|
||||
WHERE n.retracted_at IS NULL
|
||||
SET n.retracted_at = datetime()
|
||||
RETURN n
|
||||
"""
|
||||
tx.run(retract_query, {"node_id": node_id}) # fmt: skip # pyright: ignore[reportArgumentType]
|
||||
|
||||
# Create new version
|
||||
new_properties = properties.copy()
|
||||
new_properties["id"] = node_id
|
||||
new_properties["asserted_at"] = datetime.utcnow()
|
||||
|
||||
create_query = f"""
|
||||
CREATE (n:{label} $properties)
|
||||
RETURN n
|
||||
"""
|
||||
result = tx.run(create_query, {"properties": new_properties}) # fmt: skip # pyright: ignore[reportArgumentType]
|
||||
record = result.single()
|
||||
return record["n"] if record else None
|
||||
|
||||
result = await self.run_transaction(_update_transaction, database)
|
||||
return result if isinstance(result, dict) else {}
|
||||
|
||||
async def create_relationship( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
||||
self,
|
||||
from_label: str | None = None,
|
||||
from_id: str | None = None,
|
||||
to_label: str | None = None,
|
||||
to_id: str | None = None,
|
||||
relationship_type: str | None = None,
|
||||
properties: dict[str, Any] | None = None,
|
||||
database: str = "neo4j",
|
||||
# Alternative signature for tests
|
||||
from_node_id: int | None = None,
|
||||
to_node_id: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create relationship between nodes"""
|
||||
|
||||
# Handle alternative signature for tests (using node IDs)
|
||||
if from_node_id is not None and to_node_id is not None:
|
||||
rel_properties = properties or {}
|
||||
if "asserted_at" not in rel_properties:
|
||||
rel_properties["asserted_at"] = datetime.utcnow()
|
||||
|
||||
query = f"""
|
||||
MATCH (from) WHERE id(from) = $from_id
|
||||
MATCH (to) WHERE id(to) = $to_id
|
||||
CREATE (from)-[r:{relationship_type} $properties]->(to)
|
||||
RETURN r
|
||||
"""
|
||||
|
||||
result = await self.run_query(
|
||||
query,
|
||||
{
|
||||
"from_id": from_node_id,
|
||||
"to_id": to_node_id,
|
||||
"properties": rel_properties,
|
||||
},
|
||||
database,
|
||||
)
|
||||
rel = result[0]["r"] if result else {}
|
||||
return rel.get("id", rel)
|
||||
|
||||
# Original signature (using labels and IDs)
|
||||
rel_properties = properties or {}
|
||||
if "asserted_at" not in rel_properties:
|
||||
rel_properties["asserted_at"] = datetime.utcnow()
|
||||
|
||||
query = f"""
|
||||
MATCH (from:{from_label} {{id: $from_id}})
|
||||
MATCH (to:{to_label} {{id: $to_id}})
|
||||
WHERE from.retracted_at IS NULL AND to.retracted_at IS NULL
|
||||
CREATE (from)-[r:{relationship_type} $properties]->(to)
|
||||
RETURN r
|
||||
"""
|
||||
|
||||
result = await self.run_query(
|
||||
query,
|
||||
{"from_id": from_id, "to_id": to_id, "properties": rel_properties},
|
||||
database,
|
||||
)
|
||||
rel = result[0]["r"] if result else {}
|
||||
# Return relationship ID if available, otherwise return the full relationship
|
||||
return rel.get("id", rel)
|
||||
|
||||
async def get_node_lineage(
|
||||
self, node_id: str, max_depth: int = 10, database: str = "neo4j"
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get complete lineage for a node"""
|
||||
|
||||
query = """
|
||||
MATCH path = (n {id: $node_id})-[:DERIVED_FROM*1..10]->(evidence:Evidence)
|
||||
WHERE n.retracted_at IS NULL
|
||||
RETURN path, evidence
|
||||
ORDER BY length(path) DESC
|
||||
LIMIT 100
|
||||
"""
|
||||
|
||||
return await self.run_query(
|
||||
query, {"node_id": node_id, "max_depth": max_depth}, database
|
||||
)
|
||||
|
||||
async def export_to_rdf( # pylint: disable=redefined-builtin
|
||||
self,
|
||||
format: str = "turtle",
|
||||
database: str = "neo4j",
|
||||
) -> dict[str, Any]:
|
||||
"""Export graph data to RDF format"""
|
||||
|
||||
query = """
|
||||
CALL n10s.rdf.export.cypher(
|
||||
'MATCH (n) WHERE n.retracted_at IS NULL RETURN n',
|
||||
$format,
|
||||
{}
|
||||
) YIELD triplesCount, format
|
||||
RETURN triplesCount, format
|
||||
"""
|
||||
|
||||
try:
|
||||
result = await self.run_query(query, {"format": format}, database)
|
||||
return result[0] if result else {}
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.warning("RDF export failed, using fallback", error=str(e))
|
||||
fallback_result = await self._export_rdf_fallback(database)
|
||||
return {"rdf_data": fallback_result, "format": format}
|
||||
|
||||
async def _export_rdf_fallback(self, database: str = "neo4j") -> str:
|
||||
"""Fallback RDF export without n10s plugin"""
|
||||
|
||||
# Get all nodes and relationships
|
||||
nodes_query = """
|
||||
MATCH (n) WHERE n.retracted_at IS NULL
|
||||
RETURN labels(n) as labels, properties(n) as props, id(n) as neo_id
|
||||
"""
|
||||
|
||||
rels_query = """
|
||||
MATCH (a)-[r]->(b)
|
||||
WHERE a.retracted_at IS NULL AND b.retracted_at IS NULL
|
||||
RETURN type(r) as type, properties(r) as props,
|
||||
id(a) as from_id, id(b) as to_id
|
||||
"""
|
||||
|
||||
nodes = await self.run_query(nodes_query, database=database)
|
||||
relationships = await self.run_query(rels_query, database=database)
|
||||
|
||||
# Convert to simple Turtle format
|
||||
rdf_lines = ["@prefix tax: <https://tax-kg.example.com/> ."]
|
||||
|
||||
for node in nodes:
|
||||
node_uri = f"tax:node_{node['neo_id']}"
|
||||
for label in node["labels"]:
|
||||
rdf_lines.append(f"{node_uri} a tax:{label} .")
|
||||
|
||||
for prop, value in node["props"].items():
|
||||
if isinstance(value, str):
|
||||
rdf_lines.append(f'{node_uri} tax:{prop} "{value}" .')
|
||||
else:
|
||||
rdf_lines.append(f"{node_uri} tax:{prop} {value} .")
|
||||
|
||||
for rel in relationships:
|
||||
from_uri = f"tax:node_{rel['from_id']}"
|
||||
to_uri = f"tax:node_{rel['to_id']}"
|
||||
rdf_lines.append(f"{from_uri} tax:{rel['type']} {to_uri} .")
|
||||
|
||||
return "\n".join(rdf_lines)
|
||||
|
||||
async def find_nodes(
|
||||
self, label: str, properties: dict[str, Any], database: str = "neo4j"
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Find nodes matching label and properties"""
|
||||
where_clause, params = self._build_properties_clause(properties)
|
||||
query = f"MATCH (n:{label}) WHERE {where_clause} RETURN n"
|
||||
|
||||
result = await self.run_query(query, params, database)
|
||||
return [record["n"] for record in result]
|
||||
|
||||
async def execute_query(
|
||||
self,
|
||||
query: str,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
database: str = "neo4j",
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Execute a custom Cypher query"""
|
||||
return await self.run_query(query, parameters, database)
|
||||
|
||||
def _build_properties_clause(
|
||||
self, properties: dict[str, Any]
|
||||
) -> tuple[str, dict[str, Any]]:
|
||||
"""Build WHERE clause and parameters for properties"""
|
||||
if not properties:
|
||||
return "true", {}
|
||||
|
||||
clauses = []
|
||||
params = {}
|
||||
for i, (key, value) in enumerate(properties.items()):
|
||||
param_name = f"prop_{i}"
|
||||
clauses.append(f"n.{key} = ${param_name}")
|
||||
params[param_name] = value
|
||||
|
||||
return " AND ".join(clauses), params
|
||||
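Direct use of the client, e.g. creating a node and reading it back; the bolt URI and credentials are placeholders.

import asyncio

from neo4j import GraphDatabase

from libs.neo.client import Neo4jClient


async def main() -> None:
    driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
    client = Neo4jClient(driver)
    try:
        await client.run_query(
            "MERGE (d:Document {id: $id}) SET d.kind = $kind",
            {"id": "doc-1", "kind": "invoice"},
        )
        docs = await client.find_nodes("Document", {"id": "doc-1"})
        print(docs)
    finally:
        await client.close()


asyncio.run(main())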
libs/neo/queries.py (new file, 78 lines)
@@ -0,0 +1,78 @@
"""Neo4j Cypher queries for coverage policy system"""

from datetime import datetime
from typing import Any

import structlog

logger = structlog.get_logger()


class TemporalQueries:
    """Helper class for temporal queries"""

    @staticmethod
    def get_current_state_query(
        label: str, filters: dict[str, Any] | None = None
    ) -> str:
        """Get query for current state of nodes"""
        where_clause = "n.retracted_at IS NULL"

        if filters:
            filter_conditions = []
            for key, value in filters.items():
                if isinstance(value, str):
                    filter_conditions.append(f"n.{key} = '{value}'")
                else:
                    filter_conditions.append(f"n.{key} = {value}")

            if filter_conditions:
                where_clause += " AND " + " AND ".join(filter_conditions)

        return f"""
        MATCH (n:{label})
        WHERE {where_clause}
        RETURN n
        ORDER BY n.asserted_at DESC
        """

    @staticmethod
    def get_historical_state_query(
        label: str, as_of_time: datetime, filters: dict[str, Any] | None = None
    ) -> str:
        """Get query for historical state at specific time"""
        where_clause = f"""
        n.asserted_at <= datetime('{as_of_time.isoformat()}')
        AND (n.retracted_at IS NULL OR n.retracted_at > datetime('{as_of_time.isoformat()}'))
        """

        if filters:
            filter_conditions = []
            for key, value in filters.items():
                if isinstance(value, str):
                    filter_conditions.append(f"n.{key} = '{value}'")
                else:
                    filter_conditions.append(f"n.{key} = {value}")

            if filter_conditions:
                where_clause += " AND " + " AND ".join(filter_conditions)

        return f"""
        MATCH (n:{label})
        WHERE {where_clause}
        RETURN n
        ORDER BY n.asserted_at DESC
        """

    @staticmethod
    def get_audit_trail_query(node_id: str) -> str:
        """Get complete audit trail for a node"""
        return f"""
        MATCH (n {{id: '{node_id}'}})
        RETURN n.asserted_at as asserted_at,
               n.retracted_at as retracted_at,
               n.source as source,
               n.extractor_version as extractor_version,
               properties(n) as properties
        ORDER BY n.asserted_at ASC
        """
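These helpers only build Cypher strings; they are executed through the Neo4j client, e.g.:

from datetime import datetime

from libs.neo.queries import TemporalQueries

# Current, non-retracted Evidence nodes for one taxpayer.
current_q = TemporalQueries.get_current_state_query(
    "Evidence", filters={"taxpayer_id": "tp-123"}
)

# Graph state as asserted at a given point in time (date is illustrative).
historical_q = TemporalQueries.get_historical_state_query(
    "Evidence", as_of_time=datetime(2025, 1, 31), filters={"taxpayer_id": "tp-123"}
)

# results = await client.run_query(current_q)  # client: Neo4jClient from libs/neo/client.py

Note that filter values are interpolated directly into the generated Cypher rather than passed as parameters, so they should come from trusted code, not user input.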
libs/neo/validator.py (new file, 70 lines)
@@ -0,0 +1,70 @@
"""SHACL validation using pySHACL"""

import asyncio
from typing import Any

import structlog

logger = structlog.get_logger()


# pyright: ignore[reportAttributeAccessIssue]
class SHACLValidator:  # pylint: disable=too-few-public-methods
    """SHACL validation using pySHACL"""

    def __init__(self, shapes_file: str) -> None:
        self.shapes_file = shapes_file

    async def validate_graph(self, rdf_data: str) -> dict[str, Any]:
        """Validate RDF data against SHACL shapes"""

        def _validate() -> dict[str, Any]:
            try:
                # pylint: disable=import-outside-toplevel
                from pyshacl import validate
                from rdflib import Graph

                # Load data graph
                data_graph = Graph()
                data_graph.parse(data=rdf_data, format="turtle")

                # Load shapes graph
                shapes_graph = Graph()
                shapes_graph.parse(self.shapes_file, format="turtle")

                # Run validation
                conforms, results_graph, results_text = validate(
                    data_graph=data_graph,
                    shacl_graph=shapes_graph,
                    inference="rdfs",
                    abort_on_first=False,
                    allow_infos=True,
                    allow_warnings=True,
                )

                return {
                    "conforms": conforms,
                    "results_text": results_text,
                    "violations_count": len(list(results_graph.subjects())),  # pyright: ignore[reportAttributeAccessIssue]
                }

            except ImportError:
                logger.warning("pySHACL not available, skipping validation")
                return {
                    "conforms": True,
                    "results_text": "SHACL validation skipped (pySHACL not installed)",
                    "violations_count": 0,
                }
            except Exception as e:  # pylint: disable=broad-exception-caught
                logger.error("SHACL validation failed", error=str(e))
                return {
                    "conforms": False,
                    "results_text": f"Validation error: {str(e)}",
                    "violations_count": -1,
                }

        return await asyncio.get_event_loop().run_in_executor(None, _validate)
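A small end-to-end check of the validator with inline Turtle data; the shapes file path is a placeholder.

import asyncio

from libs.neo import SHACLValidator

RDF_DATA = """
@prefix tax: <https://tax-kg.example.com/> .
tax:node_1 a tax:Document ;
    tax:doc_id "doc-1" .
"""


async def main() -> None:
    validator = SHACLValidator("schemas/shapes.ttl")  # path is an assumption
    report = await validator.validate_graph(RDF_DATA)
    print(report["conforms"], report["violations_count"])
    if not report["conforms"]:
        print(report["results_text"])


asyncio.run(main())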
libs/observability/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
"""Observability setup with OpenTelemetry, Prometheus, and structured logging."""

from .logging import configure_logging
from .opentelemetry_setup import init_opentelemetry
from .prometheus import BusinessMetrics, get_business_metrics, init_prometheus_metrics
from .setup import setup_observability
from .utils import get_metrics, get_tracer

__all__ = [
    "configure_logging",
    "init_opentelemetry",
    "init_prometheus_metrics",
    "BusinessMetrics",
    "get_business_metrics",
    "setup_observability",
    "get_tracer",
    "get_metrics",
]
libs/observability/logging.py (new file, 75 lines)
@@ -0,0 +1,75 @@
"""Structured logging configuration with OpenTelemetry integration."""

import logging
import sys
import time
from typing import Any

import structlog
from opentelemetry import trace


def configure_logging(service_name: str, log_level: str = "INFO") -> None:
    """Configure structured logging with structlog"""

    def add_service_name(  # pylint: disable=unused-argument
        logger: Any,
        method_name: str,
        event_dict: dict[str, Any],  # noqa: ARG001
    ) -> dict[str, Any]:
        event_dict["service"] = service_name
        return event_dict

    def add_trace_id(  # pylint: disable=unused-argument
        logger: Any,
        method_name: str,
        event_dict: dict[str, Any],  # noqa: ARG001
    ) -> dict[str, Any]:
        """Add trace ID to log entries"""
        span = trace.get_current_span()
        if span and span.get_span_context().is_valid:
            event_dict["trace_id"] = format(span.get_span_context().trace_id, "032x")
            event_dict["span_id"] = format(span.get_span_context().span_id, "016x")
        return event_dict

    def add_timestamp(  # pylint: disable=unused-argument
        logger: Any,
        method_name: str,
        event_dict: dict[str, Any],  # noqa: ARG001
    ) -> dict[str, Any]:
        event_dict["timestamp"] = time.time()
        return event_dict

    # Configure structlog
    structlog.configure(
        processors=[
            structlog.stdlib.filter_by_level,
            structlog.stdlib.add_logger_name,
            structlog.stdlib.add_log_level,
            add_service_name,  # type: ignore
            add_trace_id,  # type: ignore
            add_timestamp,  # type: ignore
            structlog.stdlib.PositionalArgumentsFormatter(),
            structlog.processors.StackInfoRenderer(),
            structlog.processors.format_exc_info,
            structlog.processors.UnicodeDecoder(),
            structlog.processors.JSONRenderer(),
        ],
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )

    # Configure standard library logging
    logging.basicConfig(
        format="%(message)s",
        stream=sys.stdout,
        level=getattr(logging, log_level.upper()),
    )

    # Reduce noise from some libraries
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("httpcore").setLevel(logging.WARNING)
    logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
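Services call this once at startup; log lines then come out as JSON with the service, trace_id, and timestamp fields attached by the processors above. Service name and fields are illustrative.

import structlog

from libs.observability import configure_logging

configure_logging("svc-forms", log_level="DEBUG")
logger = structlog.get_logger()
logger.info("form filled", form_id="SA103", tenant_id="tenant-42")
# -> {"event": "form filled", "service": "svc-forms", "timestamp": ..., "form_id": "SA103", ...}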
libs/observability/opentelemetry_setup.py (new file, 99 lines)
@@ -0,0 +1,99 @@
|
||||
"""OpenTelemetry tracing and metrics initialization."""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
|
||||
from opentelemetry.instrumentation.psycopg2 import Psycopg2Instrumentor
|
||||
from opentelemetry.instrumentation.redis import RedisInstrumentor
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import (
|
||||
MetricExporter,
|
||||
PeriodicExportingMetricReader,
|
||||
)
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter
|
||||
|
||||
|
||||
def init_opentelemetry(
|
||||
service_name: str,
|
||||
service_version: str = "1.0.0",
|
||||
otlp_endpoint: str | None = None,
|
||||
) -> tuple[Any, Any]:
|
||||
"""Initialize OpenTelemetry tracing and metrics"""
|
||||
|
||||
# Create resource
|
||||
resource = Resource.create(
|
||||
{
|
||||
"service.name": service_name,
|
||||
"service.version": service_version,
|
||||
"service.instance.id": os.getenv("HOSTNAME", "unknown"),
|
||||
}
|
||||
)
|
||||
|
||||
# Configure tracing
|
||||
span_exporter: SpanExporter
|
||||
if otlp_endpoint:
|
||||
span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint)
|
||||
span_processor = BatchSpanProcessor(span_exporter)
|
||||
else:
|
||||
# Use console exporter for development
|
||||
try:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from opentelemetry.sdk.trace.export import ConsoleSpanExporter
|
||||
|
||||
span_exporter = ConsoleSpanExporter()
|
||||
except ImportError:
|
||||
# Fallback to logging exporter
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from opentelemetry.sdk.trace.export import ConsoleSpanExporter
|
||||
|
||||
span_exporter = ConsoleSpanExporter()
|
||||
span_processor = BatchSpanProcessor(span_exporter)
|
||||
|
||||
tracer_provider = TracerProvider(resource=resource)
|
||||
tracer_provider.add_span_processor(span_processor)
|
||||
trace.set_tracer_provider(tracer_provider)
|
||||
|
||||
# Configure metrics
|
||||
metric_exporter: MetricExporter
|
||||
if otlp_endpoint:
|
||||
metric_exporter = OTLPMetricExporter(endpoint=otlp_endpoint)
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
metric_exporter, export_interval_millis=30000
|
||||
)
|
||||
else:
|
||||
# Use console exporter for development
|
||||
try:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from opentelemetry.sdk.metrics.export import ConsoleMetricExporter
|
||||
|
||||
metric_exporter = ConsoleMetricExporter()
|
||||
except ImportError:
|
||||
# Fallback to logging exporter
|
||||
from opentelemetry.sdk.metrics.export import ConsoleMetricExporter
|
||||
|
||||
metric_exporter = ConsoleMetricExporter()
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
metric_exporter, export_interval_millis=30000
|
||||
)
|
||||
|
||||
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
|
||||
metrics.set_meter_provider(meter_provider)
|
||||
|
||||
# Auto-instrument common libraries
|
||||
try:
|
||||
FastAPIInstrumentor().instrument()
|
||||
HTTPXClientInstrumentor().instrument()
|
||||
Psycopg2Instrumentor().instrument()
|
||||
RedisInstrumentor().instrument()
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
# Ignore instrumentation errors in tests
|
||||
pass
|
||||
|
||||
return trace.get_tracer(service_name), metrics.get_meter(service_name)
|
||||
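Typical startup wiring for a service; the service name and metric/span names are illustrative, and the collector endpoint is read from the standard OTLP environment variable (console exporters are used when it is unset, as above).

import os

from libs.observability import init_opentelemetry

tracer, meter = init_opentelemetry(
    service_name="svc-forms",
    service_version="1.0.0",
    otlp_endpoint=os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT"),
)

with tracer.start_as_current_span("fill_form") as span:
    span.set_attribute("form_id", "SA103")

request_counter = meter.create_counter("form_fill_requests")
request_counter.add(1, {"form_id": "SA103"})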
libs/observability/prometheus.py (new file, 235 lines)
@@ -0,0 +1,235 @@
|
||||
"""Prometheus metrics setup and business metrics."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Info
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
|
||||
def init_prometheus_metrics( # pylint: disable=unused-argument
|
||||
app: Any, service_name: str
|
||||
) -> Any:
|
||||
"""Initialize Prometheus metrics for FastAPI app"""
|
||||
|
||||
# Create instrumentator
|
||||
instrumentator = Instrumentator(
|
||||
should_group_status_codes=False,
|
||||
should_ignore_untemplated=True,
|
||||
should_respect_env_var=True,
|
||||
should_instrument_requests_inprogress=True,
|
||||
excluded_handlers=["/metrics", "/healthz", "/readyz", "/livez"],
|
||||
env_var_name="ENABLE_METRICS",
|
||||
inprogress_name="http_requests_inprogress",
|
||||
inprogress_labels=True,
|
||||
)
|
||||
|
||||
# Add custom metrics
|
||||
instrumentator.add(
|
||||
lambda info: info.modified_duration < 0.1, # type: ignore
|
||||
lambda info: Counter(
|
||||
"http_requests_fast_total",
|
||||
"Number of fast HTTP requests (< 100ms)",
|
||||
["method", "endpoint"],
|
||||
)
|
||||
.labels(method=info.method, endpoint=info.modified_handler)
|
||||
.inc(),
|
||||
)
|
||||
|
||||
instrumentator.add(
|
||||
lambda info: info.modified_duration > 1.0, # type: ignore
|
||||
lambda info: Counter(
|
||||
"http_requests_slow_total",
|
||||
"Number of slow HTTP requests (> 1s)",
|
||||
["method", "endpoint"],
|
||||
)
|
||||
.labels(method=info.method, endpoint=info.modified_handler)
|
||||
.inc(),
|
||||
)
|
||||
|
||||
# Instrument the app
|
||||
instrumentator.instrument(app)
|
||||
instrumentator.expose(app, endpoint="/metrics")
|
||||
|
||||
return instrumentator
|
||||
|
||||
|
||||
# Global registry for business metrics to avoid duplicates
|
||||
_business_metrics_registry: dict[str, Any] = {}
|
||||
|
||||
|
||||
# Custom metrics for business logic
|
||||
class BusinessMetrics: # pylint: disable=too-many-instance-attributes
|
||||
"""Custom business metrics for the application"""
|
||||
|
||||
def __init__(self, service_name: str):
|
||||
self.service_name = service_name
|
||||
# Sanitize service name for Prometheus metrics (replace hyphens with underscores)
|
||||
self.sanitized_name = service_name.replace("-", "_")
|
||||
|
||||
# Create a custom registry for this service to avoid conflicts
|
||||
self.registry = CollectorRegistry()
|
||||
|
||||
# Document processing metrics
|
||||
self.documents_processed = Counter(
|
||||
"documents_processed_total",
|
||||
"Total number of documents processed",
|
||||
["service", "document_type", "status"],
|
||||
registry=self.registry,
|
||||
)
|
||||
|
||||
# Add active connections metric for tests
|
||||
self.active_connections = Gauge(
|
||||
"active_connections",
|
||||
"Number of active connections",
|
||||
["service"],
|
||||
registry=self.registry,
|
||||
)
|
||||
|
||||
# Dynamic counters for forms service
|
||||
self._dynamic_counters: dict[str, Any] = {}
|
||||
|
||||
self.document_processing_duration = Histogram(
|
||||
f"document_processing_duration_seconds_{self.sanitized_name}",
|
||||
"Time spent processing documents",
|
||||
["service", "document_type"],
|
||||
buckets=[0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0, 120.0, 300.0],
|
||||
registry=self.registry,
|
||||
)
|
||||
|
||||
# Field extraction metrics
|
||||
self.field_extractions = Counter(
|
||||
f"field_extractions_total_{self.sanitized_name}",
|
||||
"Total number of field extractions",
|
||||
["service", "field_type", "status"],
|
||||
registry=self.registry,
|
||||
)
|
||||
|
||||
self.extraction_confidence = Histogram(
|
||||
f"extraction_confidence_score_{self.sanitized_name}",
|
||||
"Confidence scores for extractions",
|
||||
["service", "extraction_type"],
|
||||
buckets=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
|
||||
registry=self.registry,
|
||||
)
|
||||
|
||||
# Tax calculation metrics
|
||||
self.tax_calculations = Counter(
|
||||
f"tax_calculations_total_{self.sanitized_name}",
|
||||
"Total number of tax calculations",
|
||||
["service", "calculation_type", "status"],
|
||||
registry=self.registry,
|
||||
)
|
||||
|
||||
self.calculation_confidence = Histogram(
|
||||
f"calculation_confidence_score_{self.sanitized_name}",
|
||||
"Confidence scores for tax calculations",
|
||||
["service", "calculation_type"],
|
||||
buckets=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
|
||||
registry=self.registry,
|
||||
)
|
||||
|
||||
# RAG metrics
|
||||
self.rag_searches = Counter(
|
||||
f"rag_searches_total_{self.sanitized_name}",
|
||||
"Total number of RAG searches",
|
||||
["service", "collection", "status"],
|
||||
registry=self.registry,
|
||||
)
|
||||
|
||||
self.rag_search_duration = Histogram(
|
||||
f"rag_search_duration_seconds_{self.sanitized_name}",
|
||||
"Time spent on RAG searches",
|
||||
["service", "collection"],
|
||||
buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0],
|
||||
registry=self.registry,
|
||||
)
|
||||
|
||||
self.rag_relevance_score = Histogram(
|
||||
f"rag_relevance_score_{self.sanitized_name}",
|
||||
"RAG search relevance scores",
|
||||
["service", "collection"],
|
||||
buckets=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
|
||||
registry=self.registry,
|
||||
)
|
||||
|
||||
# Knowledge graph metrics
|
||||
self.kg_operations = Counter(
|
||||
f"kg_operations_total_{self.sanitized_name}",
|
||||
"Total number of KG operations",
|
||||
["service", "operation", "status"],
|
||||
registry=self.registry,
|
||||
)
|
||||
|
||||
self.kg_query_duration = Histogram(
|
||||
f"kg_query_duration_seconds_{self.sanitized_name}",
|
||||
"Time spent on KG queries",
|
||||
["service", "query_type"],
|
||||
buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0],
|
||||
registry=self.registry,
|
||||
)
|
||||
|
||||
# HMRC submission metrics
|
||||
self.hmrc_submissions = Counter(
|
||||
f"hmrc_submissions_total_{self.sanitized_name}",
|
||||
"Total number of HMRC submissions",
|
||||
["service", "submission_type", "status"],
|
||||
registry=self.registry,
|
||||
)
|
||||
|
||||
# Service health metrics
|
||||
self.service_info = Info(
|
||||
f"service_info_{self.sanitized_name}",
|
||||
"Service information",
|
||||
registry=self.registry,
|
||||
)
|
||||
try:
|
||||
self.service_info.info({"service": service_name, "version": "1.0.0"})
|
||||
except (AttributeError, ValueError):
|
||||
# Handle prometheus_client version compatibility or registry conflicts
|
||||
pass
|
||||
|
||||
def counter(self, name: str, labelnames: list[str] | None = None) -> Any:
|
||||
"""Get or create a counter metric with dynamic labels"""
|
||||
# Use provided labelnames or default ones
|
||||
if labelnames is None:
|
||||
labelnames = ["tenant_id", "form_id", "scope", "error_type"]
|
||||
|
||||
# Create a unique key based on name and labelnames
|
||||
label_key = f"{name}_{','.join(sorted(labelnames))}"
|
||||
|
||||
if label_key not in self._dynamic_counters:
|
||||
self._dynamic_counters[label_key] = Counter(
|
||||
name,
|
||||
f"Dynamic counter: {name}",
|
||||
labelnames=labelnames,
|
||||
registry=self.registry,
|
||||
)
|
||||
return self._dynamic_counters[label_key]
|
||||
|
||||
def histogram(self, name: str, labelnames: list[str] | None = None) -> Any:
|
||||
"""Get or create a histogram metric with dynamic labels"""
|
||||
# Use provided labelnames or default ones
|
||||
if labelnames is None:
|
||||
labelnames = ["tenant_id", "kind"]
|
||||
|
||||
# Create a unique key based on name and labelnames
|
||||
label_key = f"{name}_{','.join(sorted(labelnames))}"
|
||||
histogram_key = f"_histogram_{label_key}"
|
||||
|
||||
if not hasattr(self, histogram_key):
|
||||
histogram = Histogram(
|
||||
name,
|
||||
f"Dynamic histogram: {name}",
|
||||
labelnames=labelnames,
|
||||
registry=self.registry,
|
||||
)
|
||||
setattr(self, histogram_key, histogram)
|
||||
return getattr(self, histogram_key)
|
||||
|
||||
|
||||
def get_business_metrics(service_name: str) -> BusinessMetrics:
|
||||
"""Get business metrics instance for service"""
|
||||
# Use singleton pattern to avoid registry conflicts
|
||||
if service_name not in _business_metrics_registry:
|
||||
_business_metrics_registry[service_name] = BusinessMetrics(service_name)
|
||||
return _business_metrics_registry[service_name] # type: ignore
|
||||
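For orientation, a minimal usage sketch of these business metrics. The module path libs.observability.prometheus, the service name, and every label value below are assumptions for illustration only.

from libs.observability.prometheus import get_business_metrics

metrics = get_business_metrics("svc-extract")  # cached per service name

# Count a processed document and record how long it took
metrics.documents_processed.labels(
    service="svc-extract", document_type="P60", status="success"
).inc()
metrics.document_processing_duration.labels(
    service="svc-extract", document_type="P60"
).observe(2.3)

# Dynamic counters use the default label set unless labelnames is passed explicitly.
# Note these metrics live in the instance's own CollectorRegistry, not the default one.
metrics.counter("forms_validation_errors_total").labels(
    tenant_id="t-123", form_id="SA100", scope="field", error_type="missing"
).inc()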
libs/observability/setup.py (new file, 64 lines)
@@ -0,0 +1,64 @@
|
||||
"""Complete observability setup orchestration."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .logging import configure_logging
|
||||
from .opentelemetry_setup import init_opentelemetry
|
||||
from .prometheus import get_business_metrics, init_prometheus_metrics
|
||||
|
||||
|
||||
def setup_observability(
|
||||
settings_or_app: Any,
|
||||
service_name: str | None = None,
|
||||
service_version: str = "1.0.0",
|
||||
log_level: str = "INFO",
|
||||
otlp_endpoint: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Setup complete observability stack for a service"""
|
||||
|
||||
# Handle both settings object and individual parameters
|
||||
if hasattr(settings_or_app, "service_name"):
|
||||
# Called with settings object
|
||||
settings = settings_or_app
|
||||
service_name = settings.service_name
|
||||
service_version = getattr(settings, "service_version", "1.0.0")
|
||||
log_level = getattr(settings, "log_level", "INFO")
|
||||
otlp_endpoint = getattr(settings, "otel_exporter_endpoint", None)
|
||||
app = None
|
||||
else:
|
||||
# Called with app object
|
||||
app = settings_or_app
|
||||
if not service_name:
|
||||
raise ValueError("service_name is required when passing app object")
|
||||
|
||||
# Configure logging
|
||||
configure_logging(service_name or "unknown", log_level)
|
||||
|
||||
# Initialize OpenTelemetry
|
||||
tracer, meter = init_opentelemetry(
|
||||
service_name or "unknown", service_version, otlp_endpoint
|
||||
)
|
||||
|
||||
# Get business metrics
|
||||
business_metrics = get_business_metrics(service_name or "unknown")
|
||||
|
||||
# If app is provided, set up Prometheus and add to app state
|
||||
if app:
|
||||
# Initialize Prometheus metrics
|
||||
instrumentator = init_prometheus_metrics(app, service_name or "unknown")
|
||||
|
||||
# Add to app state
|
||||
app.state.tracer = tracer
|
||||
app.state.meter = meter
|
||||
app.state.metrics = business_metrics
|
||||
app.state.instrumentator = instrumentator
|
||||
|
||||
return {
|
||||
"tracer": tracer,
|
||||
"meter": meter,
|
||||
"metrics": business_metrics,
|
||||
"instrumentator": instrumentator,
|
||||
}
|
||||
|
||||
# Just return the observability components
|
||||
return {"tracer": tracer, "meter": meter, "metrics": business_metrics}
|
||||
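A sketch of wiring the stack into a FastAPI service via the app-object path, which also instruments the app and exposes /metrics. The service name and OTLP endpoint are illustrative assumptions.

from fastapi import FastAPI

from libs.observability.setup import setup_observability

app = FastAPI()

# Instruments the app, exposes /metrics, and stashes tracer/meter/metrics on app.state
components = setup_observability(
    app,
    service_name="svc-extract",
    service_version="1.0.0",
    log_level="INFO",
    otlp_endpoint="http://otel-collector:4317",
)

tracer = components["tracer"]
metrics = components["metrics"]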
libs/observability/utils.py (new file, 17 lines)
@@ -0,0 +1,17 @@
|
||||
"""Utility functions for observability components."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry import trace
|
||||
|
||||
from .prometheus import BusinessMetrics, get_business_metrics
|
||||
|
||||
|
||||
def get_tracer(service_name: str = "default") -> Any:
|
||||
"""Get OpenTelemetry tracer"""
|
||||
return trace.get_tracer(service_name)
|
||||
|
||||
|
||||
def get_metrics(service_name: str = "default") -> BusinessMetrics:
|
||||
"""Get business metrics instance"""
|
||||
return get_business_metrics(service_name)
|
||||
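A short sketch of these helpers in use; the service and span names are illustrative.

from libs.observability.utils import get_metrics, get_tracer

tracer = get_tracer("svc-reason")
metrics = get_metrics("svc-reason")

# Trace a unit of work and record a gauge value while it runs
with tracer.start_as_current_span("score-document"):
    metrics.active_connections.labels(service="svc-reason").set(1)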
libs/policy/__init__.py (new file, 21 lines)
@@ -0,0 +1,21 @@
|
||||
"""Coverage policy loading and management with overlays and hot reload."""
|
||||
|
||||
from .loader import PolicyLoader
|
||||
from .utils import (
|
||||
apply_feature_flags,
|
||||
compile_predicates,
|
||||
get_policy_loader,
|
||||
load_policy,
|
||||
merge_overlays,
|
||||
validate_policy,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PolicyLoader",
|
||||
"get_policy_loader",
|
||||
"load_policy",
|
||||
"merge_overlays",
|
||||
"apply_feature_flags",
|
||||
"compile_predicates",
|
||||
"validate_policy",
|
||||
]
|
||||
libs/policy/loader.py (new file, 386 lines)
@@ -0,0 +1,386 @@
|
||||
"""Policy loading and management with overlays and hot reload."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
import yaml
|
||||
from jsonschema import ValidationError, validate
|
||||
|
||||
from ..schemas import (
|
||||
CompiledCoveragePolicy,
|
||||
CoveragePolicy,
|
||||
PolicyError,
|
||||
ValidationResult,
|
||||
)
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class PolicyLoader:
|
||||
"""Loads and manages coverage policies with overlays and hot reload"""
|
||||
|
||||
def __init__(self, config_dir: str = "config"):
|
||||
self.config_dir = Path(config_dir)
|
||||
self.schema_path = Path(__file__).parent.parent / "coverage_schema.json"
|
||||
self._schema_cache: dict[str, Any] | None = None
|
||||
|
||||
def load_policy(
|
||||
self,
|
||||
baseline_path: str | None = None,
|
||||
jurisdiction: str = "UK",
|
||||
tax_year: str = "2024-25",
|
||||
tenant_id: str | None = None,
|
||||
) -> CoveragePolicy:
|
||||
"""Load policy with overlays applied"""
|
||||
|
||||
# Default baseline path
|
||||
if baseline_path is None:
|
||||
baseline_path = str(self.config_dir / "coverage.yaml")
|
||||
|
||||
# Load baseline policy
|
||||
baseline = self._load_yaml_file(baseline_path)
|
||||
|
||||
# Collect overlay files
|
||||
overlay_files = []
|
||||
|
||||
# Jurisdiction-specific overlay
|
||||
jurisdiction_file = self.config_dir / f"coverage.{jurisdiction}.{tax_year}.yaml"
|
||||
if jurisdiction_file.exists():
|
||||
overlay_files.append(str(jurisdiction_file))
|
||||
|
||||
# Tenant-specific overlay
|
||||
if tenant_id:
|
||||
tenant_file = self.config_dir / "overrides" / f"{tenant_id}.yaml"
|
||||
if tenant_file.exists():
|
||||
overlay_files.append(str(tenant_file))
|
||||
|
||||
# Load overlays
|
||||
overlays = [self._load_yaml_file(path) for path in overlay_files]
|
||||
|
||||
# Merge all policies
|
||||
merged = self.merge_overlays(baseline, *overlays)
|
||||
|
||||
# Apply feature flags if available
|
||||
merged = self.apply_feature_flags(merged)
|
||||
|
||||
# Validate against schema
|
||||
self._validate_policy(merged)
|
||||
|
||||
# Convert to Pydantic model
|
||||
try:
|
||||
policy = CoveragePolicy(**merged)
|
||||
logger.info(
|
||||
"Policy loaded successfully",
|
||||
jurisdiction=jurisdiction,
|
||||
tax_year=tax_year,
|
||||
tenant_id=tenant_id,
|
||||
overlays=len(overlays),
|
||||
)
|
||||
return policy
|
||||
except Exception as e:
|
||||
raise PolicyError(f"Failed to parse policy: {str(e)}") from e
|
||||
|
||||
def merge_overlays(
|
||||
self, base: dict[str, Any], *overlays: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Merge base policy with overlays using deep merge"""
|
||||
result = base.copy()
|
||||
|
||||
for overlay in overlays:
|
||||
result = self._deep_merge(result, overlay)
|
||||
|
||||
return result
|
||||
|
||||
def apply_feature_flags(self, policy: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Apply feature flags to policy (placeholder for Unleash integration)"""
|
||||
# TODO: Integrate with Unleash feature flags
|
||||
# For now, just return the policy unchanged
|
||||
logger.debug("Feature flags not implemented, returning policy unchanged")
|
||||
return policy
|
||||
|
||||
def compile_predicates(self, policy: CoveragePolicy) -> CompiledCoveragePolicy:
|
||||
"""Compile condition strings into callable predicates"""
|
||||
compiled_predicates: dict[str, Callable[[str, str], bool]] = {}
|
||||
|
||||
# Compile trigger conditions
|
||||
for schedule_id, trigger in policy.triggers.items():
|
||||
for condition in trigger.any_of + trigger.all_of:
|
||||
if condition not in compiled_predicates:
|
||||
compiled_predicates[condition] = self._compile_condition(condition)
|
||||
|
||||
# Compile evidence conditions
|
||||
for schedule in policy.schedules.values():
|
||||
for evidence in schedule.evidence:
|
||||
if evidence.condition and evidence.condition not in compiled_predicates:
|
||||
compiled_predicates[evidence.condition] = self._compile_condition(
|
||||
evidence.condition
|
||||
)
|
||||
|
||||
# Calculate hash of source files
|
||||
source_files = [str(self.config_dir / "coverage.yaml")]
|
||||
policy_hash = self._calculate_hash(source_files)
|
||||
|
||||
return CompiledCoveragePolicy(
|
||||
policy=policy,
|
||||
compiled_predicates=compiled_predicates,
|
||||
compiled_at=datetime.utcnow(),
|
||||
hash=policy_hash,
|
||||
source_files=source_files,
|
||||
)
|
||||
|
||||
def validate_policy(self, policy_dict: dict[str, Any]) -> ValidationResult:
|
||||
"""Validate policy against schema and business rules"""
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
try:
|
||||
# JSON Schema validation
|
||||
self._validate_policy(policy_dict)
|
||||
|
||||
# Business rule validation
|
||||
business_errors, business_warnings = self._validate_business_rules(
|
||||
policy_dict
|
||||
)
|
||||
errors.extend(business_errors)
|
||||
warnings.extend(business_warnings)
|
||||
|
||||
except ValidationError as e:
|
||||
errors.append(f"Schema validation failed: {e.message}")
|
||||
except Exception as e:
|
||||
errors.append(f"Validation error: {str(e)}")
|
||||
|
||||
return ValidationResult(ok=len(errors) == 0, errors=errors, warnings=warnings)
|
||||
|
||||
def _load_yaml_file(self, path: str) -> dict[str, Any]:
|
||||
"""Load YAML file with error handling"""
|
||||
try:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return yaml.safe_load(f) or {}
|
||||
        except FileNotFoundError as e:
            raise PolicyError(f"Policy file not found: {path}") from e
        except yaml.YAMLError as e:
            raise PolicyError(f"Invalid YAML in {path}: {str(e)}") from e
|
||||
|
||||
def _deep_merge(
|
||||
self, base: dict[str, Any], overlay: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Deep merge two dictionaries"""
|
||||
result = base.copy()
|
||||
|
||||
for key, value in overlay.items():
|
||||
if (
|
||||
key in result
|
||||
and isinstance(result[key], dict)
|
||||
and isinstance(value, dict)
|
||||
):
|
||||
result[key] = self._deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
|
||||
return result
|
||||
|
||||
def _validate_policy(self, policy_dict: dict[str, Any]) -> None:
|
||||
"""Validate policy against JSON schema"""
|
||||
if self._schema_cache is None:
|
||||
with open(self.schema_path, encoding="utf-8") as f:
|
||||
self._schema_cache = json.load(f)
|
||||
|
||||
validate(instance=policy_dict, schema=self._schema_cache) # fmt: skip # pyright: ignore[reportArgumentType]
|
||||
|
||||
def _validate_business_rules(
|
||||
self, policy_dict: dict[str, Any]
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Validate business rules beyond schema"""
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
# Check that all evidence IDs are in document_kinds
|
||||
document_kinds = set(policy_dict.get("document_kinds", []))
|
||||
|
||||
for schedule_id, schedule in policy_dict.get("schedules", {}).items():
|
||||
for evidence in schedule.get("evidence", []):
|
||||
evidence_id = evidence.get("id")
|
||||
if evidence_id not in document_kinds:
|
||||
# Check if it's in acceptable_alternatives of any evidence
|
||||
found_in_alternatives = False
|
||||
for other_schedule in policy_dict.get("schedules", {}).values():
|
||||
for other_evidence in other_schedule.get("evidence", []):
|
||||
if evidence_id in other_evidence.get(
|
||||
"acceptable_alternatives", []
|
||||
):
|
||||
found_in_alternatives = True
|
||||
break
|
||||
if found_in_alternatives:
|
||||
break
|
||||
|
||||
if not found_in_alternatives:
|
||||
errors.append(
|
||||
f"Evidence ID '{evidence_id}' in schedule '{schedule_id}' "
|
||||
f"not found in document_kinds or acceptable_alternatives"
|
||||
)
|
||||
|
||||
# Check acceptable alternatives
|
||||
for alt in evidence.get("acceptable_alternatives", []):
|
||||
if alt not in document_kinds:
|
||||
warnings.append(
|
||||
f"Alternative '{alt}' for evidence '{evidence_id}' "
|
||||
f"not found in document_kinds"
|
||||
)
|
||||
|
||||
# Check that all schedules referenced in triggers exist
|
||||
triggers = policy_dict.get("triggers", {})
|
||||
schedules = policy_dict.get("schedules", {})
|
||||
|
||||
for schedule_id in triggers:
|
||||
if schedule_id not in schedules:
|
||||
errors.append(
|
||||
f"Trigger for '{schedule_id}' but no schedule definition found"
|
||||
)
|
||||
|
||||
return errors, warnings
|
||||
|
||||
def _compile_condition(self, condition: str) -> Callable[[str, str], bool]:
|
||||
"""Compile a condition string into a callable predicate"""
|
||||
|
||||
# Simple condition parser for the DSL
|
||||
condition = condition.strip()
|
||||
|
||||
# Handle exists() conditions
|
||||
exists_match = re.match(r"exists\((\w+)\[([^\]]+)\]\)", condition)
|
||||
if exists_match:
|
||||
entity_type = exists_match.group(1)
|
||||
filters = exists_match.group(2)
|
||||
return self._create_exists_predicate(entity_type, filters)
|
||||
|
||||
# Handle simple property conditions
|
||||
if condition in [
|
||||
"property_joint_ownership",
|
||||
"candidate_FHL",
|
||||
"claims_FTCR",
|
||||
"claims_remittance_basis",
|
||||
"received_estate_income",
|
||||
]:
|
||||
return self._create_property_predicate(condition)
|
||||
|
||||
# Handle computed conditions
|
||||
if condition in [
|
||||
"turnover_lt_vat_threshold",
|
||||
"turnover_ge_vat_threshold",
|
||||
]:
|
||||
return self._create_computed_predicate(condition)
|
||||
|
||||
# Handle taxpayer flags
|
||||
if condition.startswith("taxpayer_flag:"):
|
||||
flag_name = condition.split(":", 1)[1].strip()
|
||||
return self._create_flag_predicate(flag_name)
|
||||
|
||||
# Handle filing mode
|
||||
if condition.startswith("filing_mode:"):
|
||||
mode = condition.split(":", 1)[1].strip()
|
||||
return self._create_filing_mode_predicate(mode)
|
||||
|
||||
# Default: always false for unknown conditions
|
||||
logger.warning("Unknown condition, defaulting to False", condition=condition)
|
||||
return lambda taxpayer_id, tax_year: False
|
||||
|
||||
def _create_exists_predicate(
|
||||
self, entity_type: str, filters: str
|
||||
) -> Callable[[str, str], bool]:
|
||||
"""Create predicate for exists() conditions"""
|
||||
|
||||
def predicate(taxpayer_id: str, tax_year: str) -> bool:
|
||||
# This would query the KG for the entity with filters
|
||||
# For now, return False as placeholder
|
||||
logger.debug(
|
||||
"Exists predicate called",
|
||||
entity_type=entity_type,
|
||||
filters=filters,
|
||||
taxpayer_id=taxpayer_id,
|
||||
tax_year=tax_year,
|
||||
)
|
||||
return False
|
||||
|
||||
return predicate
|
||||
|
||||
def _create_property_predicate(
|
||||
self, property_name: str
|
||||
) -> Callable[[str, str], bool]:
|
||||
"""Create predicate for property conditions"""
|
||||
|
||||
def predicate(taxpayer_id: str, tax_year: str) -> bool:
|
||||
# This would query the KG for the property
|
||||
logger.debug(
|
||||
"Property predicate called",
|
||||
property_name=property_name,
|
||||
taxpayer_id=taxpayer_id,
|
||||
tax_year=tax_year,
|
||||
)
|
||||
return False
|
||||
|
||||
return predicate
|
||||
|
||||
def _create_computed_predicate(
|
||||
self, computation: str
|
||||
) -> Callable[[str, str], bool]:
|
||||
"""Create predicate for computed conditions"""
|
||||
|
||||
def predicate(taxpayer_id: str, tax_year: str) -> bool:
|
||||
# This would perform the computation
|
||||
logger.debug(
|
||||
"Computed predicate called",
|
||||
computation=computation,
|
||||
taxpayer_id=taxpayer_id,
|
||||
tax_year=tax_year,
|
||||
)
|
||||
return False
|
||||
|
||||
return predicate
|
||||
|
||||
def _create_flag_predicate(self, flag_name: str) -> Callable[[str, str], bool]:
|
||||
"""Create predicate for taxpayer flags"""
|
||||
|
||||
def predicate(taxpayer_id: str, tax_year: str) -> bool:
|
||||
# This would check taxpayer flags
|
||||
logger.debug(
|
||||
"Flag predicate called",
|
||||
flag_name=flag_name,
|
||||
taxpayer_id=taxpayer_id,
|
||||
tax_year=tax_year,
|
||||
)
|
||||
return False
|
||||
|
||||
return predicate
|
||||
|
||||
def _create_filing_mode_predicate(self, mode: str) -> Callable[[str, str], bool]:
|
||||
"""Create predicate for filing mode"""
|
||||
|
||||
def predicate(taxpayer_id: str, tax_year: str) -> bool:
|
||||
# This would check filing mode preference
|
||||
logger.debug(
|
||||
"Filing mode predicate called",
|
||||
mode=mode,
|
||||
taxpayer_id=taxpayer_id,
|
||||
tax_year=tax_year,
|
||||
)
|
||||
return False
|
||||
|
||||
return predicate
|
||||
|
||||
def _calculate_hash(self, file_paths: list[str]) -> str:
|
||||
"""Calculate hash of policy files"""
|
||||
hasher = hashlib.sha256()
|
||||
|
||||
for path in sorted(file_paths):
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
hasher.update(f.read())
|
||||
except FileNotFoundError:
|
||||
logger.warning("File not found for hashing", path=path)
|
||||
|
||||
return hasher.hexdigest()
|
||||
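A usage sketch for the loader. It assumes config/coverage.yaml exists (the loader's default baseline); the tenant id is illustrative, and the compiled predicates are the placeholder implementations above, so they currently return False.

from libs.policy.loader import PolicyLoader

loader = PolicyLoader(config_dir="config")

# Load the UK 2024-25 policy with jurisdiction and tenant overlays applied
policy = loader.load_policy(jurisdiction="UK", tax_year="2024-25", tenant_id="tenant-abc")

# Compile trigger/evidence condition strings into callables keyed by the condition text
compiled = loader.compile_predicates(policy)
predicate = compiled.compiled_predicates.get("property_joint_ownership")
if predicate is not None:
    print(predicate("taxpayer-123", "2024-25"))  # placeholder predicate, returns False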
libs/policy/utils.py (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
"""Utility functions for policy management."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ..schemas import CompiledCoveragePolicy, CoveragePolicy, ValidationResult
|
||||
from .loader import PolicyLoader
|
||||
|
||||
# Global policy loader instance
|
||||
_policy_loader: PolicyLoader | None = None
|
||||
|
||||
|
||||
def get_policy_loader(config_dir: str = "config") -> PolicyLoader:
|
||||
"""Get global policy loader instance"""
|
||||
global _policy_loader
|
||||
if _policy_loader is None:
|
||||
_policy_loader = PolicyLoader(config_dir)
|
||||
return _policy_loader
|
||||
|
||||
|
||||
# Convenience functions
|
||||
def load_policy(
|
||||
baseline_path: str | None = None,
|
||||
jurisdiction: str = "UK",
|
||||
tax_year: str = "2024-25",
|
||||
tenant_id: str | None = None,
|
||||
) -> CoveragePolicy:
|
||||
"""Load coverage policy with overlays"""
|
||||
return get_policy_loader().load_policy(
|
||||
baseline_path, jurisdiction, tax_year, tenant_id
|
||||
)
|
||||
|
||||
|
||||
def merge_overlays(base: dict[str, Any], *overlays: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Merge base policy with overlays"""
|
||||
return get_policy_loader().merge_overlays(base, *overlays)
|
||||
|
||||
|
||||
def apply_feature_flags(policy: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Apply feature flags to policy"""
|
||||
return get_policy_loader().apply_feature_flags(policy)
|
||||
|
||||
|
||||
def compile_predicates(policy: CoveragePolicy) -> CompiledCoveragePolicy:
|
||||
"""Compile policy predicates"""
|
||||
return get_policy_loader().compile_predicates(policy)
|
||||
|
||||
|
||||
def validate_policy(policy_dict: dict[str, Any]) -> ValidationResult:
|
||||
"""Validate policy"""
|
||||
return get_policy_loader().validate_policy(policy_dict)
|
||||
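The deep-merge behaviour of the convenience helpers in a small sketch; the keys are illustrative, not taken from the shipped coverage policy.

from libs.policy import merge_overlays

base = {"schedules": {"SA105": {"evidence": [{"id": "rental_statement"}]}}}
overlay = {"schedules": {"SA105": {"threshold": 1000}}}

# Nested dicts are merged recursively; scalar and list values from the overlay win
merged = merge_overlays(base, overlay)
# -> {"schedules": {"SA105": {"evidence": [{"id": "rental_statement"}], "threshold": 1000}}}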
libs/rag/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
|
||||
"""Qdrant collections CRUD, hybrid search, rerank wrapper, de-identification utilities."""
|
||||
|
||||
from .collection_manager import QdrantCollectionManager
|
||||
from .pii_detector import PIIDetector
|
||||
from .retriever import RAGRetriever
|
||||
from .utils import rag_search_for_citations
|
||||
|
||||
__all__ = [
|
||||
"PIIDetector",
|
||||
"QdrantCollectionManager",
|
||||
"RAGRetriever",
|
||||
"rag_search_for_citations",
|
||||
]
|
||||
libs/rag/collection_manager.py (new file, 233 lines)
@@ -0,0 +1,233 @@
|
||||
"""Manage Qdrant collections for RAG."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import (
|
||||
Distance,
|
||||
Filter,
|
||||
PointStruct,
|
||||
SparseVector,
|
||||
VectorParams,
|
||||
)
|
||||
|
||||
from .pii_detector import PIIDetector
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class QdrantCollectionManager:
|
||||
"""Manage Qdrant collections for RAG"""
|
||||
|
||||
def __init__(self, client: QdrantClient):
|
||||
self.client = client
|
||||
self.pii_detector = PIIDetector()
|
||||
|
||||
async def ensure_collection(
|
||||
self,
|
||||
collection_name: str,
|
||||
vector_size: int = 384,
|
||||
distance: Distance = Distance.COSINE,
|
||||
sparse_vector_config: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""Ensure collection exists with proper configuration"""
|
||||
try:
|
||||
# Check if collection exists
|
||||
collections = self.client.get_collections().collections
|
||||
if any(c.name == collection_name for c in collections):
|
||||
logger.debug("Collection already exists", collection=collection_name)
|
||||
return True
|
||||
|
||||
# Create collection with dense vectors
|
||||
vector_config = VectorParams(size=vector_size, distance=distance)
|
||||
|
||||
# Add sparse vector configuration if provided
|
||||
sparse_vectors_config = None
|
||||
if sparse_vector_config:
|
||||
sparse_vectors_config = {"sparse": sparse_vector_config}
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=vector_config,
|
||||
sparse_vectors_config=sparse_vectors_config, # type: ignore
|
||||
)
|
||||
|
||||
logger.info("Created collection", collection=collection_name)
|
||||
return True
|
||||
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error(
|
||||
"Failed to create collection", collection=collection_name, error=str(e)
|
||||
)
|
||||
return False
|
||||
|
||||
async def upsert_points(
|
||||
self, collection_name: str, points: list[PointStruct]
|
||||
) -> bool:
|
||||
"""Upsert points to collection"""
|
||||
try:
|
||||
# Validate all points are PII-free
|
||||
for point in points:
|
||||
if point.payload and not point.payload.get("pii_free", False):
|
||||
logger.warning("Point not marked as PII-free", point_id=point.id)
|
||||
return False
|
||||
|
||||
self.client.upsert(collection_name=collection_name, points=points)
|
||||
|
||||
logger.info(
|
||||
"Upserted points", collection=collection_name, count=len(points)
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error(
|
||||
"Failed to upsert points", collection=collection_name, error=str(e)
|
||||
)
|
||||
return False
|
||||
|
||||
async def search_dense( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
||||
self,
|
||||
collection_name: str,
|
||||
query_vector: list[float],
|
||||
limit: int = 10,
|
||||
filter_conditions: Filter | None = None,
|
||||
score_threshold: float | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search using dense vectors"""
|
||||
try:
|
||||
search_result = self.client.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=query_vector,
|
||||
query_filter=filter_conditions,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
with_payload=True,
|
||||
with_vectors=False,
|
||||
)
|
||||
|
||||
return [
|
||||
{"id": hit.id, "score": hit.score, "payload": hit.payload}
|
||||
for hit in search_result
|
||||
]
|
||||
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error(
|
||||
"Dense search failed", collection=collection_name, error=str(e)
|
||||
)
|
||||
return []
|
||||
|
||||
async def search_sparse(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_vector: SparseVector,
|
||||
limit: int = 10,
|
||||
filter_conditions: Filter | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search using sparse vectors"""
|
||||
try:
|
||||
search_result = self.client.search(
|
||||
collection_name=collection_name,
|
||||
query_vector=query_vector, # type: ignore
|
||||
query_filter=filter_conditions,
|
||||
limit=limit,
|
||||
using="sparse",
|
||||
with_payload=True,
|
||||
with_vectors=False,
|
||||
)
|
||||
|
||||
return [
|
||||
{"id": hit.id, "score": hit.score, "payload": hit.payload}
|
||||
for hit in search_result
|
||||
]
|
||||
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error(
|
||||
"Sparse search failed", collection=collection_name, error=str(e)
|
||||
)
|
||||
return []
|
||||
|
||||
async def hybrid_search( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
||||
self,
|
||||
collection_name: str,
|
||||
dense_vector: list[float],
|
||||
sparse_vector: SparseVector,
|
||||
limit: int = 10,
|
||||
alpha: float = 0.5,
|
||||
filter_conditions: Filter | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Perform hybrid search combining dense and sparse results"""
|
||||
|
||||
# Get dense results
|
||||
dense_results = await self.search_dense(
|
||||
collection_name=collection_name,
|
||||
query_vector=dense_vector,
|
||||
limit=limit * 2, # Get more results for fusion
|
||||
filter_conditions=filter_conditions,
|
||||
)
|
||||
|
||||
# Get sparse results
|
||||
sparse_results = await self.search_sparse(
|
||||
collection_name=collection_name,
|
||||
query_vector=sparse_vector,
|
||||
limit=limit * 2,
|
||||
filter_conditions=filter_conditions,
|
||||
)
|
||||
|
||||
# Combine and re-rank results
|
||||
return self._fuse_results(dense_results, sparse_results, alpha, limit)
|
||||
|
||||
def _fuse_results( # pylint: disable=too-many-locals
|
||||
self,
|
||||
dense_results: list[dict[str, Any]],
|
||||
sparse_results: list[dict[str, Any]],
|
||||
alpha: float,
|
||||
limit: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fuse dense and sparse search results"""
|
||||
|
||||
# Create score maps
|
||||
dense_scores = {result["id"]: result["score"] for result in dense_results}
|
||||
sparse_scores = {result["id"]: result["score"] for result in sparse_results}
|
||||
|
||||
# Get all unique IDs
|
||||
all_ids = set(dense_scores.keys()) | set(sparse_scores.keys())
|
||||
|
||||
# Calculate hybrid scores
|
||||
hybrid_results = []
|
||||
for doc_id in all_ids:
|
||||
dense_score = dense_scores.get(doc_id, 0.0)
|
||||
sparse_score = sparse_scores.get(doc_id, 0.0)
|
||||
|
||||
# Normalize each score by the maximum score in its result set
|
||||
if dense_results:
|
||||
max_dense = max(dense_scores.values())
|
||||
dense_score = dense_score / max_dense if max_dense > 0 else 0
|
||||
|
||||
if sparse_results:
|
||||
max_sparse = max(sparse_scores.values())
|
||||
sparse_score = sparse_score / max_sparse if max_sparse > 0 else 0
|
||||
|
||||
# Combine scores
|
||||
hybrid_score = alpha * dense_score + (1 - alpha) * sparse_score
|
||||
|
||||
# Get payload from either result
|
||||
payload = None
|
||||
for result in dense_results + sparse_results:
|
||||
if result["id"] == doc_id:
|
||||
payload = result["payload"]
|
||||
break
|
||||
|
||||
hybrid_results.append(
|
||||
{
|
||||
"id": doc_id,
|
||||
"score": hybrid_score,
|
||||
"dense_score": dense_score,
|
||||
"sparse_score": sparse_score,
|
||||
"payload": payload,
|
||||
}
|
||||
)
|
||||
|
||||
# Sort by hybrid score and return top results
|
||||
hybrid_results.sort(key=lambda x: x["score"], reverse=True)
|
||||
return hybrid_results[:limit]
|
||||
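The fusion above max-normalises each modality's scores and blends them as alpha * dense + (1 - alpha) * sparse. A minimal sketch of calling it follows; the collection name, URL, and vectors are stand-ins, and sparse results come back empty unless the collection was created with a sparse vector configuration.

from qdrant_client import QdrantClient
from qdrant_client.models import SparseVector

from libs.rag.collection_manager import QdrantCollectionManager

async def demo_hybrid_search() -> None:
    manager = QdrantCollectionManager(QdrantClient(url="http://localhost:6333"))

    hits = await manager.hybrid_search(
        collection_name="guidance_chunks",
        dense_vector=[0.1] * 384,  # stand-in for a real embedding
        sparse_vector=SparseVector(indices=[17, 942], values=[1.0, 2.0]),
        limit=5,
        alpha=0.7,  # weight dense scores 70/30 against sparse
    )
    for hit in hits:
        print(hit["id"], round(hit["score"], 3), hit["dense_score"], hit["sparse_score"])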
libs/rag/indexer.py (new file, 507 lines)
@@ -0,0 +1,507 @@
|
||||
# FILE: retrieval/indexer.py
|
||||
# De-identify -> embed dense/sparse -> upsert to Qdrant with payload
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import spacy
|
||||
import torch
|
||||
import yaml
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import Distance, PointStruct, SparseVector, VectorParams
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from .chunker import DocumentChunker
|
||||
from .pii_detector import PIIDetector
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexingResult:
|
||||
collection_name: str
|
||||
points_indexed: int
|
||||
points_updated: int
|
||||
points_failed: int
|
||||
processing_time: float
|
||||
errors: list[str]
|
||||
|
||||
|
||||
class RAGIndexer:
|
||||
def __init__(self, config_path: str, qdrant_url: str = "http://localhost:6333"):
|
||||
with open(config_path) as f:
|
||||
self.config = yaml.safe_load(f)
|
||||
|
||||
self.qdrant_client = QdrantClient(url=qdrant_url)
|
||||
self.chunker = DocumentChunker(config_path)
|
||||
self.pii_detector = PIIDetector()
|
||||
|
||||
# Initialize embedding models
|
||||
self.dense_model = SentenceTransformer(
|
||||
self.config.get("embedding_model", "bge-small-en-v1.5")
|
||||
)
|
||||
|
||||
# Initialize sparse model (BM25/SPLADE)
|
||||
self.sparse_model = self._init_sparse_model()
|
||||
|
||||
# Initialize NLP pipeline
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def _init_sparse_model(self):
|
||||
"""Initialize sparse embedding model (BM25 or SPLADE)"""
|
||||
sparse_config = self.config.get("sparse_model", {})
|
||||
model_type = sparse_config.get("type", "bm25")
|
||||
|
||||
if model_type == "bm25":
|
||||
from rank_bm25 import BM25Okapi
|
||||
|
||||
return BM25Okapi
|
||||
elif model_type == "splade":
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"naver/splade-cocondenser-ensembledistil"
|
||||
)
|
||||
model = AutoModelForMaskedLM.from_pretrained(
|
||||
"naver/splade-cocondenser-ensembledistil"
|
||||
)
|
||||
return {"tokenizer": tokenizer, "model": model}
|
||||
else:
|
||||
raise ValueError(f"Unsupported sparse model type: {model_type}")
|
||||
|
||||
async def index_document(
|
||||
self, document_path: str, collection_name: str, metadata: dict[str, Any]
|
||||
) -> IndexingResult:
|
||||
"""Index a single document into the specified collection"""
|
||||
start_time = datetime.now()
|
||||
errors = []
|
||||
points_indexed = 0
|
||||
points_updated = 0
|
||||
points_failed = 0
|
||||
|
||||
try:
|
||||
# Step 1: Chunk the document
|
||||
chunks = await self.chunker.chunk_document(document_path, metadata)
|
||||
|
||||
# Step 2: Process each chunk
|
||||
points = []
|
||||
for chunk in chunks:
|
||||
try:
|
||||
point = await self._process_chunk(chunk, collection_name, metadata)
|
||||
if point:
|
||||
points.append(point)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"Failed to process chunk {chunk.get('id', 'unknown')}: {str(e)}"
|
||||
)
|
||||
errors.append(f"Chunk processing error: {str(e)}")
|
||||
points_failed += 1
|
||||
|
||||
# Step 3: Upsert to Qdrant
|
||||
if points:
|
||||
try:
|
||||
self.qdrant_client.upsert(
|
||||
collection_name=collection_name, points=points, wait=True
|
||||
)
|
||||
points_indexed = len(points)
|
||||
self.logger.info(
|
||||
f"Indexed {points_indexed} points to {collection_name}"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to upsert to Qdrant: {str(e)}")
|
||||
errors.append(f"Qdrant upsert error: {str(e)}")
|
||||
points_failed += len(points)
|
||||
points_indexed = 0
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Document indexing failed: {str(e)}")
|
||||
errors.append(f"Document indexing error: {str(e)}")
|
||||
|
||||
processing_time = (datetime.now() - start_time).total_seconds()
|
||||
|
||||
return IndexingResult(
|
||||
collection_name=collection_name,
|
||||
points_indexed=points_indexed,
|
||||
points_updated=points_updated,
|
||||
points_failed=points_failed,
|
||||
processing_time=processing_time,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def _process_chunk(
|
||||
self, chunk: dict[str, Any], collection_name: str, base_metadata: dict[str, Any]
|
||||
) -> PointStruct | None:
|
||||
"""Process a single chunk: de-identify, embed, create point"""
|
||||
|
||||
# Step 1: De-identify PII
|
||||
content = chunk["content"]
|
||||
pii_detected = self.pii_detector.detect_pii(content)
|
||||
|
||||
if pii_detected:
|
||||
            # Redact PII and keep a placeholder-to-value mapping
            redacted_content, pii_mapping = self.pii_detector.de_identify_text(content)
|
||||
|
||||
# Store PII mapping securely (not in vector DB)
|
||||
await self._store_pii_mapping(chunk["id"], pii_mapping)
|
||||
|
||||
# Log PII detection for audit
|
||||
self.logger.warning(
|
||||
f"PII detected in chunk {chunk['id']}: {[p['type'] for p in pii_detected]}"
|
||||
)
|
||||
else:
|
||||
redacted_content = content
|
||||
|
||||
# Verify no PII remains
|
||||
if not self._verify_pii_free(redacted_content):
|
||||
self.logger.error(f"PII verification failed for chunk {chunk['id']}")
|
||||
return None
|
||||
|
||||
# Step 2: Generate embeddings
|
||||
try:
|
||||
dense_vector = await self._generate_dense_embedding(redacted_content)
|
||||
sparse_vector = await self._generate_sparse_embedding(redacted_content)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"Embedding generation failed for chunk {chunk['id']}: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Step 3: Prepare metadata
|
||||
payload = self._prepare_payload(chunk, base_metadata, redacted_content)
|
||||
payload["pii_free"] = True # Verified above
|
||||
|
||||
# Step 4: Create point
|
||||
point = PointStruct(
|
||||
id=chunk["id"],
|
||||
vector={"dense": dense_vector, "sparse": sparse_vector},
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
return point
|
||||
|
||||
async def _generate_dense_embedding(self, text: str) -> list[float]:
|
||||
"""Generate dense vector embedding"""
|
||||
try:
|
||||
# Use sentence transformer for dense embeddings
|
||||
embedding = self.dense_model.encode(text, normalize_embeddings=True)
|
||||
return embedding.tolist()
|
||||
except Exception as e:
|
||||
self.logger.error(f"Dense embedding generation failed: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _generate_sparse_embedding(self, text: str) -> SparseVector:
|
||||
"""Generate sparse vector embedding (BM25 or SPLADE)"""
|
||||
vector = SparseVector(indices=[], values=[])
|
||||
|
||||
try:
|
||||
sparse_config = self.config.get("sparse_model", {})
|
||||
model_type = sparse_config.get("type", "bm25")
|
||||
|
||||
if model_type == "bm25":
|
||||
# Simple BM25-style sparse representation
|
||||
doc = self.nlp(text)
|
||||
tokens = [
|
||||
token.lemma_.lower()
|
||||
for token in doc
|
||||
if not token.is_stop and not token.is_punct
|
||||
]
|
||||
|
||||
# Create term frequency vector
|
||||
term_freq = {}
|
||||
for token in tokens:
|
||||
term_freq[token] = term_freq.get(token, 0) + 1
|
||||
|
||||
# Convert to sparse vector format
|
||||
vocab_size = sparse_config.get("vocab_size", 30000)
|
||||
indices = []
|
||||
values = []
|
||||
|
||||
for term, freq in term_freq.items():
|
||||
# Simple hash-based vocabulary mapping
|
||||
term_id = hash(term) % vocab_size
|
||||
indices.append(term_id)
|
||||
values.append(float(freq))
|
||||
|
||||
vector = SparseVector(indices=indices, values=values)
|
||||
|
||||
elif model_type == "splade":
|
||||
# SPLADE sparse embeddings
|
||||
tokenizer = self.sparse_model["tokenizer"]
|
||||
model = self.sparse_model["model"]
|
||||
|
||||
inputs = tokenizer(
|
||||
text, return_tensors="pt", truncation=True, max_length=512
|
||||
)
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Extract sparse representation
|
||||
logits = outputs.logits.squeeze()
|
||||
sparse_rep = torch.relu(logits).detach().numpy()
|
||||
|
||||
# Convert to sparse format
|
||||
indices = np.nonzero(sparse_rep)[0].tolist()
|
||||
values = sparse_rep[indices].tolist()
|
||||
|
||||
vector = SparseVector(indices=indices, values=values)
|
||||
|
||||
return vector
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Sparse embedding generation failed: {str(e)}")
|
||||
# Return empty sparse vector as fallback
|
||||
return vector
|
||||
|
||||
def _prepare_payload(
|
||||
self, chunk: dict[str, Any], base_metadata: dict[str, Any], content: str
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare payload metadata for the chunk"""
|
||||
|
||||
# Start with base metadata
|
||||
payload = base_metadata.copy()
|
||||
|
||||
# Add chunk-specific metadata
|
||||
payload.update(
|
||||
{
|
||||
"document_id": chunk.get("document_id"),
|
||||
"content": content, # De-identified content
|
||||
"chunk_index": chunk.get("chunk_index", 0),
|
||||
"total_chunks": chunk.get("total_chunks", 1),
|
||||
"page_numbers": chunk.get("page_numbers", []),
|
||||
"section_hierarchy": chunk.get("section_hierarchy", []),
|
||||
"has_calculations": self._detect_calculations(content),
|
||||
"has_forms": self._detect_form_references(content),
|
||||
"confidence_score": chunk.get("confidence_score", 1.0),
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"version": self.config.get("version", "1.0"),
|
||||
}
|
||||
)
|
||||
|
||||
# Extract and add topic tags
|
||||
topic_tags = self._extract_topic_tags(content)
|
||||
if topic_tags:
|
||||
payload["topic_tags"] = topic_tags
|
||||
|
||||
# Add content analysis
|
||||
payload.update(self._analyze_content(content))
|
||||
|
||||
return payload
|
||||
|
||||
def _detect_calculations(self, text: str) -> bool:
|
||||
"""Detect if text contains calculations or formulas"""
|
||||
calculation_patterns = [
|
||||
r"\d+\s*[+\-*/]\s*\d+",
|
||||
r"£\d+(?:,\d{3})*(?:\.\d{2})?",
|
||||
r"\d+(?:\.\d+)?%",
|
||||
r"total|sum|calculate|compute",
|
||||
r"rate|threshold|allowance|relief",
|
||||
]
|
||||
|
||||
for pattern in calculation_patterns:
|
||||
if re.search(pattern, text, re.IGNORECASE):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _detect_form_references(self, text: str) -> bool:
|
||||
"""Detect references to tax forms"""
|
||||
form_patterns = [
|
||||
r"SA\d{3}",
|
||||
r"P\d{2}",
|
||||
r"CT\d{3}",
|
||||
r"VAT\d{3}",
|
||||
r"form\s+\w+",
|
||||
r"schedule\s+\w+",
|
||||
]
|
||||
|
||||
for pattern in form_patterns:
|
||||
if re.search(pattern, text, re.IGNORECASE):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _extract_topic_tags(self, text: str) -> list[str]:
|
||||
"""Extract topic tags from content"""
|
||||
topic_keywords = {
|
||||
"employment": [
|
||||
"PAYE",
|
||||
"payslip",
|
||||
"P60",
|
||||
"employment",
|
||||
"salary",
|
||||
"wages",
|
||||
"employer",
|
||||
],
|
||||
"self_employment": [
|
||||
"self-employed",
|
||||
"business",
|
||||
"turnover",
|
||||
"expenses",
|
||||
"profit",
|
||||
"loss",
|
||||
],
|
||||
"property": ["rental", "property", "landlord", "FHL", "mortgage", "rent"],
|
||||
"dividends": ["dividend", "shares", "distribution", "corporation tax"],
|
||||
"capital_gains": ["capital gains", "disposal", "acquisition", "CGT"],
|
||||
"pensions": ["pension", "retirement", "SIPP", "occupational"],
|
||||
"savings": ["interest", "savings", "ISA", "bonds"],
|
||||
"inheritance": ["inheritance", "IHT", "estate", "probate"],
|
||||
"vat": ["VAT", "value added tax", "registration", "return"],
|
||||
}
|
||||
|
||||
tags = []
|
||||
text_lower = text.lower()
|
||||
|
||||
for topic, keywords in topic_keywords.items():
|
||||
for keyword in keywords:
|
||||
if keyword.lower() in text_lower:
|
||||
tags.append(topic)
|
||||
break
|
||||
|
||||
return list(set(tags)) # Remove duplicates
|
||||
|
||||
def _analyze_content(self, text: str) -> dict[str, Any]:
|
||||
"""Analyze content for additional metadata"""
|
||||
doc = self.nlp(text)
|
||||
|
||||
return {
|
||||
"word_count": len([token for token in doc if not token.is_space]),
|
||||
"sentence_count": len(list(doc.sents)),
|
||||
"entity_count": len(doc.ents),
|
||||
"complexity_score": self._calculate_complexity(doc),
|
||||
"language": doc.lang_ if hasattr(doc, "lang_") else "en",
|
||||
}
|
||||
|
||||
def _calculate_complexity(self, doc: Any) -> float:  # doc is a spaCy Doc
|
||||
"""Calculate text complexity score"""
|
||||
if not doc:
|
||||
return 0.0
|
||||
|
||||
# Simple complexity based on sentence length and vocabulary
|
||||
avg_sentence_length = sum(len(sent) for sent in doc.sents) / len(
|
||||
list(doc.sents)
|
||||
)
|
||||
unique_words = len(set(token.lemma_.lower() for token in doc if token.is_alpha))
|
||||
total_words = len([token for token in doc if token.is_alpha])
|
||||
|
||||
vocabulary_diversity = unique_words / total_words if total_words > 0 else 0
|
||||
|
||||
# Normalize to 0-1 scale
|
||||
complexity = min(1.0, (avg_sentence_length / 20.0 + vocabulary_diversity) / 2.0)
|
||||
return complexity
|
||||
|
||||
def _verify_pii_free(self, text: str) -> bool:
|
||||
"""Verify that text contains no PII"""
|
||||
# Quick verification using patterns
|
||||
pii_patterns = [
|
||||
r"\b[A-Z]{2}\d{6}[A-D]\b", # NI number
|
||||
r"\b\d{10}\b", # UTR
|
||||
r"\b[A-Z]{2}\d{2}[A-Z]{4}\d{14}\b", # IBAN
|
||||
r"\b\d{2}-\d{2}-\d{2}\b", # Sort code
|
||||
r"\b[A-Z]{1,2}\d[A-Z\d]?\s*\d[A-Z]{2}\b", # Postcode
|
||||
r"\b[\w\.-]+@[\w\.-]+\.\w+\b", # Email
|
||||
r"\b(?:\+44|0)\d{10,11}\b", # Phone
|
||||
]
|
||||
|
||||
for pattern in pii_patterns:
|
||||
if re.search(pattern, text):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _store_pii_mapping(
|
||||
self, chunk_id: str, pii_mapping: dict[str, Any]
|
||||
) -> None:
|
||||
"""Store PII mapping in secure client data store (not in vector DB)"""
|
||||
# This would integrate with the secure PostgreSQL client data store
|
||||
# For now, just log the mapping securely
|
||||
self.logger.info(
|
||||
f"PII mapping stored for chunk {chunk_id}: {len(pii_mapping)} items"
|
||||
)
|
||||
|
||||
async def create_collections(self) -> None:
|
||||
"""Create all Qdrant collections based on configuration"""
|
||||
collections_config_path = Path(__file__).parent / "qdrant_collections.json"
|
||||
|
||||
with open(collections_config_path) as f:
|
||||
collections_config = json.load(f)
|
||||
|
||||
for collection_config in collections_config["collections"]:
|
||||
collection_name = collection_config["name"]
|
||||
|
||||
try:
|
||||
# Check if collection exists
|
||||
try:
|
||||
self.qdrant_client.get_collection(collection_name)
|
||||
self.logger.info(f"Collection {collection_name} already exists")
|
||||
continue
|
||||
                except Exception:  # pylint: disable=broad-exception-caught
                    pass  # Collection doesn't exist yet; create it below
|
||||
|
||||
# Create collection
|
||||
vectors_config = {}
|
||||
|
||||
# Dense vector configuration
|
||||
if "dense" in collection_config:
|
||||
vectors_config["dense"] = VectorParams(
|
||||
size=collection_config["dense"]["size"],
|
||||
distance=Distance.COSINE,
|
||||
)
|
||||
|
||||
# Sparse vector configuration
|
||||
if collection_config.get("sparse", False):
|
||||
vectors_config["sparse"] = VectorParams(
|
||||
size=30000, # Vocabulary size for sparse vectors
|
||||
distance=Distance.DOT,
|
||||
on_disk=True,
|
||||
)
|
||||
|
||||
self.qdrant_client.create_collection(
|
||||
collection_name=collection_name,
|
||||
vectors_config=vectors_config,
|
||||
**collection_config.get("indexing_config", {}),
|
||||
)
|
||||
|
||||
self.logger.info(f"Created collection: {collection_name}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
f"Failed to create collection {collection_name}: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def batch_index(
|
||||
self, documents: list[dict[str, Any]], collection_name: str
|
||||
) -> list[IndexingResult]:
|
||||
"""Index multiple documents in batch"""
|
||||
results = []
|
||||
|
||||
for doc_info in documents:
|
||||
result = await self.index_document(
|
||||
doc_info["path"], collection_name, doc_info["metadata"]
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
def get_collection_stats(self, collection_name: str) -> dict[str, Any]:
|
||||
"""Get statistics for a collection"""
|
||||
try:
|
||||
collection_info = self.qdrant_client.get_collection(collection_name)
|
||||
return {
|
||||
"name": collection_name,
|
||||
"vectors_count": collection_info.vectors_count,
|
||||
"indexed_vectors_count": collection_info.indexed_vectors_count,
|
||||
"points_count": collection_info.points_count,
|
||||
"segments_count": collection_info.segments_count,
|
||||
"status": collection_info.status,
|
||||
}
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get stats for {collection_name}: {str(e)}")
|
||||
return {"error": str(e)}
|
||||
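A sketch of indexing one document end to end. The config path, collection name, document path, and metadata keys are assumptions; create_collections expects qdrant_collections.json to sit next to the module, as referenced above.

import asyncio

from libs.rag.indexer import RAGIndexer

async def main() -> None:
    indexer = RAGIndexer("config/rag.yaml", qdrant_url="http://localhost:6333")
    await indexer.create_collections()

    result = await indexer.index_document(
        document_path="docs/hmrc/sa105-notes.pdf",
        collection_name="guidance_chunks",
        metadata={"jurisdiction": "UK", "tax_years": ["2024-25"], "doc_id": "SA105-notes"},
    )
    print(result.points_indexed, result.points_failed, result.errors)

asyncio.run(main())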
libs/rag/pii_detector.py (new file, 77 lines)
@@ -0,0 +1,77 @@
|
||||
"""PII detection and de-identification utilities."""
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
class PIIDetector:
|
||||
"""PII detection and de-identification utilities"""
|
||||
|
||||
# Regex patterns for common PII
|
||||
PII_PATTERNS = {
|
||||
"uk_ni_number": r"\b[A-CEGHJ-PR-TW-Z]{2}\d{6}[A-D]\b",
|
||||
"uk_utr": r"\b\d{10}\b",
|
||||
"uk_postcode": r"\b[A-Z]{1,2}\d[A-Z0-9]?\s*\d[A-Z]{2}\b",
|
||||
"uk_sort_code": r"\b\d{2}-\d{2}-\d{2}\b",
|
||||
"uk_account_number": r"\b\d{8}\b",
|
||||
"email": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
|
||||
"phone": r"\b(?:\+44|0)\d{10,11}\b",
|
||||
"iban": r"\bGB\d{2}[A-Z]{4}\d{14}\b",
|
||||
"amount": r"£\d{1,3}(?:,\d{3})*(?:\.\d{2})?",
|
||||
"date": r"\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b",
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.compiled_patterns = {
|
||||
name: re.compile(pattern, re.IGNORECASE)
|
||||
for name, pattern in self.PII_PATTERNS.items()
|
||||
}
|
||||
|
||||
def detect_pii(self, text: str) -> list[dict[str, Any]]:
|
||||
"""Detect PII in text and return matches with positions"""
|
||||
matches = []
|
||||
|
||||
for pii_type, pattern in self.compiled_patterns.items():
|
||||
for match in pattern.finditer(text):
|
||||
matches.append(
|
||||
{
|
||||
"type": pii_type,
|
||||
"value": match.group(),
|
||||
"start": match.start(),
|
||||
"end": match.end(),
|
||||
"placeholder": self._generate_placeholder(
|
||||
pii_type, match.group()
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return sorted(matches, key=lambda x: x["start"])
|
||||
|
||||
def de_identify_text(self, text: str) -> tuple[str, dict[str, str]]:
|
||||
"""De-identify text by replacing PII with placeholders"""
|
||||
pii_matches = self.detect_pii(text)
|
||||
pii_mapping = {}
|
||||
|
||||
# Replace PII from end to start to maintain positions
|
||||
de_identified = text
|
||||
for match in reversed(pii_matches):
|
||||
placeholder = match["placeholder"]
|
||||
pii_mapping[placeholder] = match["value"]
|
||||
de_identified = (
|
||||
de_identified[: match["start"]]
|
||||
+ placeholder
|
||||
+ de_identified[match["end"] :]
|
||||
)
|
||||
|
||||
return de_identified, pii_mapping
|
||||
|
||||
def _generate_placeholder(self, pii_type: str, value: str) -> str:
|
||||
"""Generate consistent placeholder for PII value"""
|
||||
# Create hash of the value for consistent placeholders
|
||||
value_hash = hashlib.md5(value.encode()).hexdigest()[:8]
|
||||
return f"[{pii_type.upper()}_{value_hash}]"
|
||||
|
||||
def has_pii(self, text: str) -> bool:
|
||||
"""Check if text contains any PII"""
|
||||
return len(self.detect_pii(text)) > 0
|
||||
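A quick sketch of the detector on its own; the input string is made up, and the hash suffix in each placeholder will differ per value.

from libs.rag.pii_detector import PIIDetector

detector = PIIDetector()
text = "Employee NI number AB123456C, contact jane.doe@example.com for payslips."

matches = detector.detect_pii(text)            # [{'type': 'uk_ni_number', ...}, {'type': 'email', ...}]
redacted, mapping = detector.de_identify_text(text)

print(redacted)      # PII replaced with placeholders like [UK_NI_NUMBER_1a2b3c4d]
print(len(mapping))  # placeholder -> original value, kept out of the vector store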
libs/rag/retriever.py (new file, 235 lines)
@@ -0,0 +1,235 @@
|
||||
"""High-level RAG retrieval with reranking and KG fusion."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import (
|
||||
FieldCondition,
|
||||
Filter,
|
||||
MatchValue,
|
||||
SparseVector,
|
||||
)
|
||||
|
||||
from .collection_manager import QdrantCollectionManager
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class RAGRetriever: # pylint: disable=too-few-public-methods
|
||||
"""High-level RAG retrieval with reranking and KG fusion"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
qdrant_client: QdrantClient,
|
||||
neo4j_client: Any = None,
|
||||
reranker_model: str | None = None,
|
||||
) -> None:
|
||||
self.collection_manager = QdrantCollectionManager(qdrant_client)
|
||||
self.neo4j_client = neo4j_client
|
||||
self.reranker_model = reranker_model
|
||||
|
||||
async def search( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals
|
||||
self,
|
||||
query: str,
|
||||
collections: list[str],
|
||||
dense_vector: list[float],
|
||||
sparse_vector: SparseVector,
|
||||
k: int = 10,
|
||||
alpha: float = 0.5,
|
||||
beta: float = 0.3, # pylint: disable=unused-argument
|
||||
gamma: float = 0.2, # pylint: disable=unused-argument
|
||||
tax_year: str | None = None,
|
||||
jurisdiction: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Perform comprehensive RAG search with KG fusion"""
|
||||
|
||||
# Build filter conditions
|
||||
filter_conditions = self._build_filter(tax_year, jurisdiction)
|
||||
|
||||
# Search each collection
|
||||
all_chunks = []
|
||||
for collection in collections:
|
||||
chunks = await self.collection_manager.hybrid_search(
|
||||
collection_name=collection,
|
||||
dense_vector=dense_vector,
|
||||
sparse_vector=sparse_vector,
|
||||
limit=k,
|
||||
alpha=alpha,
|
||||
filter_conditions=filter_conditions,
|
||||
)
|
||||
|
||||
# Add collection info to chunks
|
||||
for chunk in chunks:
|
||||
chunk["collection"] = collection
|
||||
|
||||
all_chunks.extend(chunks)
|
||||
|
||||
# Re-rank if reranker is available
|
||||
if self.reranker_model and len(all_chunks) > k:
|
||||
all_chunks = await self._rerank_chunks(query, all_chunks, k)
|
||||
|
||||
# Sort by score and take top k
|
||||
all_chunks.sort(key=lambda x: x["score"], reverse=True)
|
||||
top_chunks = all_chunks[:k]
|
||||
|
||||
# Get KG hints if Neo4j client is available
|
||||
kg_hints = []
|
||||
if self.neo4j_client:
|
||||
        kg_hints = await self._get_kg_hints(query, top_chunks)

        # Extract citations
        citations = self._extract_citations(top_chunks)

        # Calculate calibrated confidence
        calibrated_confidence = self._calculate_confidence(top_chunks)

        return {
            "chunks": top_chunks,
            "citations": citations,
            "kg_hints": kg_hints,
            "calibrated_confidence": calibrated_confidence,
        }

    def _build_filter(
        self, tax_year: str | None = None, jurisdiction: str | None = None
    ) -> Filter | None:
        """Build Qdrant filter conditions"""
        conditions = []

        if jurisdiction:
            conditions.append(
                FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction))
            )

        if tax_year:
            conditions.append(
                FieldCondition(key="tax_years", match=MatchValue(value=tax_year))
            )

        # Always require PII-free content
        conditions.append(FieldCondition(key="pii_free", match=MatchValue(value=True)))

        if conditions:
            return Filter(must=conditions)  # type: ignore
        return None

    async def _rerank_chunks(  # pylint: disable=unused-argument
        self, query: str, chunks: list[dict[str, Any]], k: int
    ) -> list[dict[str, Any]]:
        """Rerank chunks using cross-encoder model"""
        try:
            # This would integrate with a reranking service
            # For now, return original chunks
            logger.debug("Reranking not implemented, returning original order")
            return chunks

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning("Reranking failed, using original order", error=str(e))
            return chunks

    async def _get_kg_hints(  # pylint: disable=unused-argument
        self, query: str, chunks: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Get knowledge graph hints related to the query"""
        try:
            # Extract potential rule/formula references from chunks
            hints = []

            for chunk in chunks:
                payload = chunk.get("payload", {})
                topic_tags = payload.get("topic_tags", [])

                # Look for tax rules related to the topics
                if topic_tags and self.neo4j_client:
                    kg_query = """
                        MATCH (r:Rule)-[:APPLIES_TO]->(topic)
                        WHERE topic.name IN $topics
                          AND r.retracted_at IS NULL
                        RETURN r.rule_id as rule_id,
                               r.formula as formula_id,
                               collect(id(topic)) as node_ids
                        LIMIT 5
                    """

                    kg_results = await self.neo4j_client.run_query(
                        kg_query, {"topics": topic_tags}
                    )

                    for result in kg_results:
                        hints.append(
                            {
                                "rule_id": result["rule_id"],
                                "formula_id": result["formula_id"],
                                "node_ids": result["node_ids"],
                            }
                        )

            return hints[:5]  # Limit to top 5 hints

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning("Failed to get KG hints", error=str(e))
            return []

    def _extract_citations(self, chunks: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Extract citation information from chunks"""
        citations = []
        seen_docs = set()

        for chunk in chunks:
            payload = chunk.get("payload", {})

            # Extract document reference
            doc_id = payload.get("doc_id")
            url = payload.get("url")
            section_id = payload.get("section_id")
            page = payload.get("page")
            bbox = payload.get("bbox")

            # Create citation key to avoid duplicates
            citation_key = doc_id or url
            if citation_key and citation_key not in seen_docs:
                citation = {}

                if doc_id:
                    citation["doc_id"] = doc_id
                if url:
                    citation["url"] = url
                if section_id:
                    citation["section_id"] = section_id
                if page:
                    citation["page"] = page
                if bbox:
                    citation["bbox"] = bbox

                citations.append(citation)
                seen_docs.add(citation_key)

        return citations

    def _calculate_confidence(self, chunks: list[dict[str, Any]]) -> float:
        """Calculate calibrated confidence score"""
        if not chunks:
            return 0.0

        # Simple confidence calculation based on top scores
        top_scores = [chunk["score"] for chunk in chunks[:3]]

        if not top_scores:
            return 0.0

        # Average of top 3 scores with diminishing returns
        weights = [0.5, 0.3, 0.2]
        weighted_score = sum(
            score * weight
            for score, weight in zip(
                top_scores, weights[: len(top_scores)], strict=False
            )
        )

        # Apply calibration (simple temperature scaling)
        # In production, this would use learned calibration parameters
        temperature = 1.2
        calibrated = weighted_score / temperature

        return min(max(calibrated, 0.0), 1.0)  # type: ignore
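The comment in _calculate_confidence notes that production would replace the fixed temperature = 1.2 with learned calibration parameters. A minimal sketch of fitting such a temperature offline is shown below; the fit_temperature helper, the validation pairs, and the grid-search approach are illustrative assumptions, not part of this repository.

# Illustrative sketch only: fit a single temperature on held-out (score, correct)
# pairs by minimising negative log-likelihood over a coarse grid.
import math


def fit_temperature(scores: list[float], labels: list[int]) -> float:
    """Grid-search a temperature T that best calibrates raw retrieval scores."""

    def nll(temperature: float) -> float:
        total = 0.0
        for score, label in zip(scores, labels):
            p = min(max(score / temperature, 1e-6), 1 - 1e-6)  # clamp to (0, 1)
            total -= label * math.log(p) + (1 - label) * math.log(1 - p)
        return total

    candidates = [round(0.5 + 0.05 * i, 2) for i in range(51)]  # T in 0.5 .. 3.0
    return min(candidates, key=nll)


# Example: scores from past retrievals, labels = 1 where the retrieved answer was correct
print(fit_temperature([0.9, 0.7, 0.4, 0.8], [1, 1, 0, 1]))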
44
libs/rag/utils.py
Normal file
@@ -0,0 +1,44 @@
"""Coverage-specific RAG utility functions."""

from typing import Any

import structlog

from libs.schemas.coverage.evaluation import Citation

logger = structlog.get_logger()


async def rag_search_for_citations(
    rag_client: Any, query: str, filters: dict[str, Any] | None = None
) -> list["Citation"]:
    """Search for citations using RAG with PII-free filtering"""

    try:
        # Ensure PII-free filter is always applied
        search_filters = filters or {}
        search_filters["pii_free"] = True

        # This would integrate with the actual RAG retrieval system
        # For now, return a placeholder implementation
        logger.debug(
            "RAG citation search called",
            query=query,
            filters=search_filters,
            rag_client_available=rag_client is not None,
        )

        # Placeholder citations - in production this would call the RAG system
        citations = [
            Citation(
                doc_id=f"RAG-{query.replace(' ', '-')[:20]}",
                locator="Retrieved via RAG search",
                url=f"https://guidance.example.com/search?q={query}",
            )
        ]

        return citations

    except (ConnectionError, TimeoutError) as e:
        logger.error("RAG citation search failed", query=query, error=str(e))
        return []
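Because rag_search_for_citations is still a placeholder, the sketch below only shows how a coverage service might call it; the None rag_client and the filter keys are assumptions made for the example, not part of this module.

# Hypothetical caller of the placeholder helper above.
import asyncio

from libs.rag.utils import rag_search_for_citations


async def main() -> None:
    citations = await rag_search_for_citations(
        rag_client=None,  # a real retriever client would be injected here
        query="allowable property expenses",
        filters={"jurisdiction": "UK", "tax_year": "2023-24"},
    )
    for citation in citations:
        print(citation.doc_id, citation.url)


asyncio.run(main())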
38
libs/requirements-base.txt
Normal file
@@ -0,0 +1,38 @@
# Core framework dependencies (Required by all services)
fastapi>=0.118.0
uvicorn[standard]>=0.37.0
pydantic>=2.11.9
pydantic-settings>=2.11.0

# Database drivers (lightweight)
sqlalchemy>=2.0.43
asyncpg>=0.30.0
psycopg2-binary>=2.9.10
neo4j>=6.0.2
redis[hiredis]>=6.4.0

# Object storage and vector database
minio>=7.2.18
qdrant-client>=1.15.1

# Event streaming (NATS only - removed Kafka)
nats-py>=2.11.0

# Security and secrets management
hvac>=2.3.0
cryptography>=46.0.2

# Observability and monitoring (minimal)
prometheus-client>=0.23.1
prometheus-fastapi-instrumentator>=7.1.0
structlog>=25.4.0

# HTTP client
httpx>=0.28.1

# Utilities
ulid-py>=1.1.0
python-multipart>=0.0.20
python-dateutil>=2.9.0
python-dotenv>=1.1.1
orjson>=3.11.3
30
libs/requirements-dev.txt
Normal file
@@ -0,0 +1,30 @@
# Development dependencies (NOT included in Docker images)

# Type checking
mypy>=1.7.0
types-redis>=4.6.0
types-requests>=2.31.0

# Testing utilities
pytest>=7.4.0
pytest-asyncio>=0.21.0
pytest-minio-mock>=0.4
pytest-cov>=4.1.0
hypothesis>=6.88.0

# Code quality
ruff>=0.1.0
black>=23.11.0
isort>=5.12.0
bandit>=1.7.0
safety>=2.3.0

# OpenTelemetry instrumentation (development only)
opentelemetry-api>=1.21.0
opentelemetry-sdk>=1.21.0
opentelemetry-exporter-otlp-proto-grpc>=1.21.0
opentelemetry-instrumentation-fastapi>=0.42b0
opentelemetry-instrumentation-httpx>=0.42b0
opentelemetry-instrumentation-psycopg2>=0.42b0
opentelemetry-instrumentation-redis>=0.42b0
20
libs/requirements-ml.txt
Normal file
@@ -0,0 +1,20 @@
# ML and AI libraries (ONLY for services that need them)
# WARNING: These are HEAVY dependencies - only include in services that absolutely need them

# Sentence transformers (includes PyTorch - ~2GB)
sentence-transformers>=5.1.1

# Transformers library (includes PyTorch - ~1GB)
transformers>=4.57.0

# Traditional ML (lighter than deep learning)
scikit-learn>=1.7.2
numpy>=2.3.3

# NLP libraries
spacy>=3.8.7
nltk>=3.9.2

# Text processing
fuzzywuzzy>=0.18.0
python-Levenshtein>=0.27.1
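The warning above is easy to violate silently when shared code imports ML libraries at module level. One possible guard (an assumption for illustration, not something this repository mandates) is to defer the heavy import and fail with a pointed message in services built without the ML extras:

# Sketch of a deferred import guard for heavy ML dependencies; the helper name,
# default model, and error message are illustrative only.
from typing import Any


def load_sentence_encoder(model_name: str = "all-MiniLM-L6-v2") -> Any:
    """Import sentence-transformers only when an ML service actually needs it."""
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError as exc:  # base images intentionally omit requirements-ml.txt
        raise RuntimeError(
            "sentence-transformers is not installed; add libs/requirements-ml.txt "
            "to this service's image if it really needs embeddings"
        ) from exc
    return SentenceTransformer(model_name)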
5
libs/requirements-pdf.txt
Normal file
@@ -0,0 +1,5 @@
# PDF processing libraries (only for services that need them)
pdfrw>=0.4
reportlab>=4.4.4
PyPDF2>=3.0.1
pdfplumber>=0.11.7
3
libs/requirements-rdf.txt
Normal file
@@ -0,0 +1,3 @@
# RDF and semantic web libraries (only for KG service)
pyshacl>=0.30.1
rdflib>=7.2.1
10
libs/requirements.txt
Normal file
@@ -0,0 +1,10 @@
# DEPRECATED: This file is kept for backward compatibility
# Use the split requirements files instead:
# - requirements-base.txt: Core dependencies (use in all services)
# - requirements-ml.txt: ML/AI dependencies (use only in ML services)
# - requirements-pdf.txt: PDF processing (use only in services that process PDFs)
# - requirements-rdf.txt: RDF/semantic web (use only in KG service)
# - requirements-dev.txt: Development dependencies (NOT in Docker images)

# For backward compatibility, include base requirements
-r requirements-base.txt
175
libs/schemas/__init__.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Shared Pydantic models mirroring ontology entities."""
|
||||
|
||||
# Import all enums
|
||||
# Import coverage models
|
||||
from .coverage.core import (
|
||||
CompiledCoveragePolicy,
|
||||
ConflictRules,
|
||||
CoveragePolicy,
|
||||
CrossCheck,
|
||||
Defaults,
|
||||
EvidenceItem,
|
||||
GuidanceRef,
|
||||
Privacy,
|
||||
QuestionTemplates,
|
||||
SchedulePolicy,
|
||||
StatusClassifier,
|
||||
StatusClassifierConfig,
|
||||
TaxYearBoundary,
|
||||
Trigger,
|
||||
Validity,
|
||||
)
|
||||
from .coverage.evaluation import (
|
||||
BlockingItem,
|
||||
Citation,
|
||||
ClarifyContext,
|
||||
ClarifyResponse,
|
||||
CoverageGap,
|
||||
CoverageItem,
|
||||
CoverageReport,
|
||||
FoundEvidence,
|
||||
ScheduleCoverage,
|
||||
UploadOption,
|
||||
)
|
||||
from .coverage.utils import CoverageAudit, PolicyError, PolicyVersion, ValidationResult
|
||||
|
||||
# Import all entities
|
||||
from .entities import (
|
||||
Account,
|
||||
BaseEntity,
|
||||
Calculation,
|
||||
Document,
|
||||
Evidence,
|
||||
ExpenseItem,
|
||||
FormBox,
|
||||
IncomeItem,
|
||||
Party,
|
||||
Payment,
|
||||
PropertyAsset,
|
||||
Rule,
|
||||
TaxpayerProfile,
|
||||
)
|
||||
from .enums import (
|
||||
DocumentKind,
|
||||
ExpenseType,
|
||||
HealthStatus,
|
||||
IncomeType,
|
||||
OverallStatus,
|
||||
PartySubtype,
|
||||
PropertyUsage,
|
||||
Role,
|
||||
Status,
|
||||
TaxpayerType,
|
||||
)
|
||||
|
||||
# Import error models
|
||||
from .errors import ErrorResponse, ValidationError, ValidationErrorResponse
|
||||
|
||||
# Import health models
|
||||
from .health import HealthCheck, ServiceHealth
|
||||
|
||||
# Import request models
|
||||
from .requests import (
|
||||
DocumentUploadRequest,
|
||||
ExtractionRequest,
|
||||
FirmSyncRequest,
|
||||
HMRCSubmissionRequest,
|
||||
RAGSearchRequest,
|
||||
ScheduleComputeRequest,
|
||||
)
|
||||
|
||||
# Import response models
|
||||
from .responses import (
|
||||
DocumentUploadResponse,
|
||||
ExtractionResponse,
|
||||
FirmSyncResponse,
|
||||
HMRCSubmissionResponse,
|
||||
RAGSearchResponse,
|
||||
ScheduleComputeResponse,
|
||||
)
|
||||
|
||||
# Import utility functions
|
||||
from .utils import get_entity_schemas
|
||||
|
||||
__all__ = [
|
||||
# Enums
|
||||
"DocumentKind",
|
||||
"ExpenseType",
|
||||
"HealthStatus",
|
||||
"IncomeType",
|
||||
"OverallStatus",
|
||||
"PartySubtype",
|
||||
"PropertyUsage",
|
||||
"Role",
|
||||
"Status",
|
||||
"TaxpayerType",
|
||||
# Entities
|
||||
"Account",
|
||||
"BaseEntity",
|
||||
"Calculation",
|
||||
"Document",
|
||||
"Evidence",
|
||||
"ExpenseItem",
|
||||
"FormBox",
|
||||
"IncomeItem",
|
||||
"Party",
|
||||
"Payment",
|
||||
"PropertyAsset",
|
||||
"Rule",
|
||||
"TaxpayerProfile",
|
||||
# Errors
|
||||
"ErrorResponse",
|
||||
"ValidationError",
|
||||
"ValidationErrorResponse",
|
||||
# Health
|
||||
"HealthCheck",
|
||||
"ServiceHealth",
|
||||
# Requests
|
||||
"DocumentUploadRequest",
|
||||
"ExtractionRequest",
|
||||
"FirmSyncRequest",
|
||||
"HMRCSubmissionRequest",
|
||||
"RAGSearchRequest",
|
||||
"ScheduleComputeRequest",
|
||||
# Responses
|
||||
"DocumentUploadResponse",
|
||||
"ExtractionResponse",
|
||||
"FirmSyncResponse",
|
||||
"HMRCSubmissionResponse",
|
||||
"RAGSearchResponse",
|
||||
"ScheduleComputeResponse",
|
||||
# Utils
|
||||
"get_entity_schemas",
|
||||
# Coverage core models
|
||||
"Validity",
|
||||
"StatusClassifier",
|
||||
"StatusClassifierConfig",
|
||||
"EvidenceItem",
|
||||
"CrossCheck",
|
||||
"SchedulePolicy",
|
||||
"Trigger",
|
||||
"GuidanceRef",
|
||||
"QuestionTemplates",
|
||||
"ConflictRules",
|
||||
"TaxYearBoundary",
|
||||
"Defaults",
|
||||
"Privacy",
|
||||
"CoveragePolicy",
|
||||
"CompiledCoveragePolicy",
|
||||
# Coverage evaluation models
|
||||
"FoundEvidence",
|
||||
"Citation",
|
||||
"CoverageItem",
|
||||
"ScheduleCoverage",
|
||||
"BlockingItem",
|
||||
"CoverageReport",
|
||||
"CoverageGap",
|
||||
"ClarifyContext",
|
||||
"UploadOption",
|
||||
"ClarifyResponse",
|
||||
# Coverage utility models
|
||||
"PolicyError",
|
||||
"ValidationResult",
|
||||
"PolicyVersion",
|
||||
"CoverageAudit",
|
||||
]
|
||||
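A quick consumer-side check that the package re-exports work as intended might look like the following sketch; the query and filter values are invented for the example.

# Illustrative use of the re-exported models; the literal values are examples only.
from libs.schemas import RAGSearchRequest, Role, Status

request = RAGSearchRequest(query="mortgage interest relief", tax_year="2023-24", k=5)
print(request.model_dump())

print(Role.REQUIRED.value, Status.MISSING.value)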
0
libs/schemas/coverage/__init__.py
Normal file
146
libs/schemas/coverage/core.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Core coverage policy models."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..enums import Role
|
||||
|
||||
|
||||
class Validity(BaseModel):
|
||||
"""Validity constraints for evidence"""
|
||||
|
||||
within_tax_year: bool = False
|
||||
available_by: str | None = None
|
||||
date_tolerance_days: int = 30
|
||||
|
||||
|
||||
class StatusClassifier(BaseModel):
|
||||
"""Rules for classifying evidence status"""
|
||||
|
||||
min_ocr: float = 0.82
|
||||
min_extract: float = 0.85
|
||||
date_in_year: bool = True
|
||||
date_in_year_or_tolerance: bool = True
|
||||
conflict_rules: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StatusClassifierConfig(BaseModel):
|
||||
"""Complete status classifier configuration"""
|
||||
|
||||
present_verified: StatusClassifier
|
||||
present_unverified: StatusClassifier
|
||||
conflicting: StatusClassifier
|
||||
missing: StatusClassifier = Field(default_factory=lambda: StatusClassifier())
|
||||
|
||||
|
||||
class EvidenceItem(BaseModel):
|
||||
"""Evidence requirement definition"""
|
||||
|
||||
id: str
|
||||
role: Role
|
||||
condition: str | None = None
|
||||
boxes: list[str] = Field(default_factory=list)
|
||||
acceptable_alternatives: list[str] = Field(default_factory=list)
|
||||
validity: Validity = Field(default_factory=Validity)
|
||||
reasons: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class CrossCheck(BaseModel):
|
||||
"""Cross-validation rule"""
|
||||
|
||||
name: str
|
||||
logic: str
|
||||
|
||||
|
||||
class SchedulePolicy(BaseModel):
|
||||
"""Policy for a specific tax schedule"""
|
||||
|
||||
guidance_hint: str | None = None
|
||||
evidence: list[EvidenceItem] = Field(default_factory=list)
|
||||
cross_checks: list[CrossCheck] = Field(default_factory=list)
|
||||
selection_rule: dict[str, str] = Field(default_factory=dict)
|
||||
notes: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Trigger(BaseModel):
|
||||
"""Schedule trigger condition"""
|
||||
|
||||
any_of: list[str] = Field(default_factory=list)
|
||||
all_of: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class GuidanceRef(BaseModel):
|
||||
"""Reference to guidance document"""
|
||||
|
||||
doc_id: str
|
||||
kind: str
|
||||
|
||||
|
||||
class QuestionTemplates(BaseModel):
|
||||
"""Templates for generating clarifying questions"""
|
||||
|
||||
default: dict[str, str] = Field(default_factory=dict)
|
||||
reasons: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ConflictRules(BaseModel):
|
||||
"""Rules for handling conflicting evidence"""
|
||||
|
||||
precedence: list[str] = Field(default_factory=list)
|
||||
escalation: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TaxYearBoundary(BaseModel):
|
||||
"""Tax year date boundaries"""
|
||||
|
||||
start: str
|
||||
end: str
|
||||
|
||||
|
||||
class Defaults(BaseModel):
|
||||
"""Default configuration values"""
|
||||
|
||||
confidence_thresholds: dict[str, float] = Field(default_factory=dict)
|
||||
date_tolerance_days: int = 30
|
||||
require_lineage_bbox: bool = True
|
||||
allow_bank_substantiation: bool = True
|
||||
|
||||
|
||||
class Privacy(BaseModel):
|
||||
"""Privacy and PII handling configuration"""
|
||||
|
||||
vector_pii_free: bool = True
|
||||
redact_patterns: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CoveragePolicy(BaseModel):
|
||||
"""Complete coverage policy definition"""
|
||||
|
||||
version: str
|
||||
jurisdiction: str
|
||||
tax_year: str
|
||||
tax_year_boundary: TaxYearBoundary
|
||||
defaults: Defaults
|
||||
document_kinds: list[str] = Field(default_factory=list)
|
||||
guidance_refs: dict[str, GuidanceRef] = Field(default_factory=dict)
|
||||
triggers: dict[str, Trigger] = Field(default_factory=dict)
|
||||
schedules: dict[str, SchedulePolicy] = Field(default_factory=dict)
|
||||
status_classifier: StatusClassifierConfig
|
||||
conflict_resolution: ConflictRules
|
||||
question_templates: QuestionTemplates
|
||||
privacy: Privacy
|
||||
|
||||
|
||||
class CompiledCoveragePolicy(BaseModel):
|
||||
"""Coverage policy with compiled predicates"""
|
||||
|
||||
policy: CoveragePolicy
|
||||
compiled_predicates: dict[str, Callable[[str, str], bool]] = Field(
|
||||
default_factory=dict
|
||||
)
|
||||
compiled_at: datetime
|
||||
hash: str
|
||||
source_files: list[str] = Field(default_factory=list)
|
||||
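Because most sub-models above carry defaults, only a handful of fields are strictly required to instantiate CoveragePolicy. A minimal construction might look like the sketch below; every literal value is an invented example, not shipped policy data.

# Minimal CoveragePolicy construction; all literal values are examples only.
from libs.schemas.coverage.core import (
    ConflictRules,
    CoveragePolicy,
    Defaults,
    Privacy,
    QuestionTemplates,
    StatusClassifier,
    StatusClassifierConfig,
    TaxYearBoundary,
)

policy = CoveragePolicy(
    version="2023-24.v1",
    jurisdiction="UK",
    tax_year="2023-24",
    tax_year_boundary=TaxYearBoundary(start="2023-04-06", end="2024-04-05"),
    defaults=Defaults(),
    status_classifier=StatusClassifierConfig(
        present_verified=StatusClassifier(min_ocr=0.9, min_extract=0.9),
        present_unverified=StatusClassifier(),
        conflicting=StatusClassifier(),
    ),
    conflict_resolution=ConflictRules(),
    question_templates=QuestionTemplates(),
    privacy=Privacy(),
)
print(policy.schedules)  # empty until schedule policies are added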
112
libs/schemas/coverage/evaluation.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Coverage evaluation models."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..enums import OverallStatus, Role, Status
|
||||
|
||||
|
||||
class FoundEvidence(BaseModel):
|
||||
"""Evidence found in the knowledge graph"""
|
||||
|
||||
doc_id: str
|
||||
kind: str
|
||||
confidence: float = 0.0
|
||||
pages: list[int] = Field(default_factory=list)
|
||||
bbox: dict[str, float] | None = None
|
||||
ocr_confidence: float = 0.0
|
||||
extract_confidence: float = 0.0
|
||||
date: str | None = None
|
||||
|
||||
|
||||
class Citation(BaseModel):
|
||||
"""Citation reference"""
|
||||
|
||||
rule_id: str | None = None
|
||||
doc_id: str | None = None
|
||||
url: str | None = None
|
||||
locator: str | None = None
|
||||
section_id: str | None = None
|
||||
page: int | None = None
|
||||
bbox: dict[str, float] | None = None
|
||||
|
||||
|
||||
class CoverageItem(BaseModel):
|
||||
"""Coverage evaluation for a single evidence item"""
|
||||
|
||||
id: str
|
||||
role: Role
|
||||
status: Status
|
||||
boxes: list[str] = Field(default_factory=list)
|
||||
found: list[FoundEvidence] = Field(default_factory=list)
|
||||
acceptable_alternatives: list[str] = Field(default_factory=list)
|
||||
reason: str = ""
|
||||
citations: list[Citation] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ScheduleCoverage(BaseModel):
|
||||
"""Coverage evaluation for a schedule"""
|
||||
|
||||
schedule_id: str
|
||||
status: OverallStatus
|
||||
evidence: list[CoverageItem] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BlockingItem(BaseModel):
|
||||
"""Item that blocks completion"""
|
||||
|
||||
schedule_id: str
|
||||
evidence_id: str
|
||||
|
||||
|
||||
class CoverageReport(BaseModel):
|
||||
"""Complete coverage evaluation report"""
|
||||
|
||||
tax_year: str
|
||||
taxpayer_id: str
|
||||
schedules_required: list[str] = Field(default_factory=list)
|
||||
overall_status: OverallStatus
|
||||
coverage: list[ScheduleCoverage] = Field(default_factory=list)
|
||||
blocking_items: list[BlockingItem] = Field(default_factory=list)
|
||||
evaluated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
policy_version: str = ""
|
||||
|
||||
|
||||
class CoverageGap(BaseModel):
|
||||
"""Gap in coverage requiring clarification"""
|
||||
|
||||
schedule_id: str
|
||||
evidence_id: str
|
||||
role: Role
|
||||
reason: str
|
||||
boxes: list[str] = Field(default_factory=list)
|
||||
citations: list[Citation] = Field(default_factory=list)
|
||||
acceptable_alternatives: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ClarifyContext(BaseModel):
|
||||
"""Context for clarifying question"""
|
||||
|
||||
tax_year: str
|
||||
taxpayer_id: str
|
||||
jurisdiction: str
|
||||
|
||||
|
||||
class UploadOption(BaseModel):
|
||||
"""Upload option for user"""
|
||||
|
||||
label: str
|
||||
accepted_formats: list[str] = Field(default_factory=list)
|
||||
upload_endpoint: str
|
||||
|
||||
|
||||
class ClarifyResponse(BaseModel):
|
||||
"""Response to clarifying question request"""
|
||||
|
||||
question_text: str
|
||||
why_it_is_needed: str
|
||||
citations: list[Citation] = Field(default_factory=list)
|
||||
options_to_provide: list[UploadOption] = Field(default_factory=list)
|
||||
blocking: bool = False
|
||||
boxes_affected: list[str] = Field(default_factory=list)
|
||||
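As an illustration of how the evaluation models compose, a single coverage item could be built as below; the identifiers, box numbers, and confidence scores are invented for the example.

# Example composition of the evaluation models; identifiers and scores are invented.
from libs.schemas.coverage.evaluation import Citation, CoverageItem, FoundEvidence
from libs.schemas.enums import Role, Status

item = CoverageItem(
    id="sa105_rent_totals",
    role=Role.REQUIRED,
    status=Status.PRESENT_UNVERIFIED,
    boxes=["SA105:20"],
    found=[
        FoundEvidence(
            doc_id="doc_0123456789abcdef",
            kind="property_statement",
            confidence=0.78,
            ocr_confidence=0.80,
            extract_confidence=0.76,
        )
    ],
    reason="Statement found but below verification thresholds",
    citations=[Citation(doc_id="HMRC-PIM1030", locator="para 2")],
)
print(item.status)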
48
libs/schemas/coverage/utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Utility models for coverage system."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..enums import OverallStatus
|
||||
|
||||
|
||||
class PolicyError(Exception):
|
||||
"""Policy loading or validation error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ValidationResult(BaseModel):
|
||||
"""Policy validation result"""
|
||||
|
||||
ok: bool
|
||||
errors: list[str] = Field(default_factory=list)
|
||||
warnings: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PolicyVersion(BaseModel):
|
||||
"""Policy version record"""
|
||||
|
||||
id: int | None = None
|
||||
version: str
|
||||
jurisdiction: str
|
||||
tax_year: str
|
||||
tenant_id: str | None = None
|
||||
source_files: list[str] = Field(default_factory=list)
|
||||
compiled_at: datetime
|
||||
hash: str
|
||||
|
||||
|
||||
class CoverageAudit(BaseModel):
|
||||
"""Coverage audit record"""
|
||||
|
||||
id: int | None = None
|
||||
taxpayer_id: str
|
||||
tax_year: str
|
||||
policy_version: str
|
||||
overall_status: OverallStatus
|
||||
blocking_items: list[dict[str, Any]] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
trace_id: str | None = None
|
||||
230
libs/schemas/entities.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""Core business entities with temporal modeling."""
|
||||
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from .enums import (
|
||||
DocumentKind,
|
||||
ExpenseType,
|
||||
IncomeType,
|
||||
PartySubtype,
|
||||
PropertyUsage,
|
||||
TaxpayerType,
|
||||
)
|
||||
|
||||
|
||||
class BaseEntity(BaseModel):
|
||||
"""Base entity with temporal fields"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
str_strip_whitespace=True, validate_assignment=True, use_enum_values=True
|
||||
)
|
||||
|
||||
# Temporal fields (bitemporal modeling)
|
||||
valid_from: datetime = Field(
|
||||
..., description="When the fact became valid in reality"
|
||||
)
|
||||
valid_to: datetime | None = Field(
|
||||
None, description="When the fact ceased to be valid"
|
||||
)
|
||||
asserted_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, description="When recorded in system"
|
||||
)
|
||||
retracted_at: datetime | None = Field(
|
||||
None, description="When retracted from system"
|
||||
)
|
||||
source: str = Field(..., description="Source of the information")
|
||||
extractor_version: str = Field(..., description="Version of extraction system")
|
||||
|
||||
|
||||
class TaxpayerProfile(BaseEntity):
|
||||
"""Taxpayer profile entity"""
|
||||
|
||||
taxpayer_id: str = Field(..., description="Unique taxpayer identifier")
|
||||
type: TaxpayerType = Field(..., description="Type of taxpayer")
|
||||
utr: str | None = Field(
|
||||
None, pattern=r"^\d{10}$", description="Unique Taxpayer Reference"
|
||||
)
|
||||
ni_number: str | None = Field(
|
||||
None,
|
||||
pattern=r"^[A-CEGHJ-PR-TW-Z]{2}\d{6}[A-D]$",
|
||||
description="National Insurance Number",
|
||||
)
|
||||
residence: str | None = Field(None, description="Tax residence")
|
||||
|
||||
|
||||
class Document(BaseEntity):
|
||||
"""Document entity"""
|
||||
|
||||
doc_id: str = Field(
|
||||
..., pattern=r"^doc_[a-f0-9]{16}$", description="Document identifier"
|
||||
)
|
||||
kind: DocumentKind = Field(..., description="Type of document")
|
||||
source: str = Field(..., description="Source of document")
|
||||
mime: str = Field(..., description="MIME type")
|
||||
checksum: str = Field(
|
||||
..., pattern=r"^[a-f0-9]{64}$", description="SHA-256 checksum"
|
||||
)
|
||||
file_size: int | None = Field(None, ge=0, description="File size in bytes")
|
||||
pages: int | None = Field(None, ge=1, description="Number of pages")
|
||||
date_range: dict[str, date] | None = Field(None, description="Document date range")
|
||||
|
||||
|
||||
class Evidence(BaseEntity):
|
||||
"""Evidence entity linking to document snippets"""
|
||||
|
||||
snippet_id: str = Field(..., description="Evidence snippet identifier")
|
||||
doc_ref: str = Field(..., description="Reference to source document")
|
||||
page: int = Field(..., ge=1, description="Page number")
|
||||
bbox: list[float] | None = Field(
|
||||
None, description="Bounding box coordinates [x1, y1, x2, y2]"
|
||||
)
|
||||
text_hash: str = Field(
|
||||
..., pattern=r"^[a-f0-9]{64}$", description="SHA-256 hash of extracted text"
|
||||
)
|
||||
ocr_confidence: float | None = Field(
|
||||
None, ge=0.0, le=1.0, description="OCR confidence score"
|
||||
)
|
||||
|
||||
|
||||
class IncomeItem(BaseEntity):
|
||||
"""Income item entity"""
|
||||
|
||||
income_id: str = Field(..., description="Income item identifier")
|
||||
type: IncomeType = Field(..., description="Type of income")
|
||||
gross: Decimal = Field(..., ge=0, description="Gross amount")
|
||||
net: Decimal | None = Field(None, ge=0, description="Net amount")
|
||||
tax_withheld: Decimal | None = Field(None, ge=0, description="Tax withheld")
|
||||
currency: str = Field(..., pattern=r"^[A-Z]{3}$", description="Currency code")
|
||||
period_start: date | None = Field(None, description="Income period start")
|
||||
period_end: date | None = Field(None, description="Income period end")
|
||||
description: str | None = Field(None, description="Income description")
|
||||
|
||||
|
||||
class ExpenseItem(BaseEntity):
|
||||
"""Expense item entity"""
|
||||
|
||||
expense_id: str = Field(..., description="Expense item identifier")
|
||||
type: ExpenseType = Field(..., description="Type of expense")
|
||||
amount: Decimal = Field(..., ge=0, description="Expense amount")
|
||||
currency: str = Field(..., pattern=r"^[A-Z]{3}$", description="Currency code")
|
||||
description: str | None = Field(None, description="Expense description")
|
||||
category: str | None = Field(None, description="Expense category")
|
||||
allowable: bool | None = Field(None, description="Whether expense is allowable")
|
||||
capitalizable_flag: bool | None = Field(
|
||||
None, description="Whether expense should be capitalized"
|
||||
)
|
||||
vat_amount: Decimal | None = Field(None, ge=0, description="VAT amount")
|
||||
net_amount: Decimal | None = Field(
|
||||
None, ge=0, description="Net amount excluding VAT"
|
||||
)
|
||||
|
||||
|
||||
class Party(BaseEntity):
|
||||
"""Party entity (person or organization)"""
|
||||
|
||||
party_id: str = Field(..., description="Party identifier")
|
||||
name: str = Field(..., min_length=1, description="Party name")
|
||||
subtype: PartySubtype | None = Field(None, description="Party subtype")
|
||||
address: str | None = Field(None, description="Party address")
|
||||
vat_number: str | None = Field(
|
||||
None, pattern=r"^GB\d{9}$|^GB\d{12}$", description="UK VAT number"
|
||||
)
|
||||
utr: str | None = Field(
|
||||
None, pattern=r"^\d{10}$", description="Unique Taxpayer Reference"
|
||||
)
|
||||
reg_no: str | None = Field(None, description="Registration number")
|
||||
paye_reference: str | None = Field(None, description="PAYE reference")
|
||||
|
||||
|
||||
class Account(BaseEntity):
|
||||
"""Bank account entity"""
|
||||
|
||||
account_id: str = Field(..., description="Account identifier")
|
||||
iban: str | None = Field(
|
||||
None, pattern=r"^GB\d{2}[A-Z]{4}\d{14}$", description="UK IBAN"
|
||||
)
|
||||
sort_code: str | None = Field(
|
||||
None, pattern=r"^\d{2}-\d{2}-\d{2}$", description="Sort code"
|
||||
)
|
||||
account_no: str | None = Field(
|
||||
None, pattern=r"^\d{8}$", description="Account number"
|
||||
)
|
||||
institution: str | None = Field(None, description="Financial institution")
|
||||
account_type: str | None = Field(None, description="Account type")
|
||||
currency: str = Field(default="GBP", description="Account currency")
|
||||
|
||||
|
||||
class PropertyAsset(BaseEntity):
|
||||
"""Property asset entity"""
|
||||
|
||||
property_id: str = Field(..., description="Property identifier")
|
||||
address: str = Field(..., min_length=10, description="Property address")
|
||||
postcode: str | None = Field(
|
||||
None, pattern=r"^[A-Z]{1,2}\d[A-Z0-9]?\s*\d[A-Z]{2}$", description="UK postcode"
|
||||
)
|
||||
tenure: str | None = Field(None, description="Property tenure")
|
||||
ownership_share: float | None = Field(
|
||||
None, ge=0.0, le=1.0, description="Ownership share"
|
||||
)
|
||||
usage: PropertyUsage | None = Field(None, description="Property usage type")
|
||||
|
||||
|
||||
class Payment(BaseEntity):
|
||||
"""Payment transaction entity"""
|
||||
|
||||
payment_id: str = Field(..., description="Payment identifier")
|
||||
payment_date: date = Field(..., description="Payment date")
|
||||
amount: Decimal = Field(
|
||||
..., description="Payment amount (positive for credit, negative for debit)"
|
||||
)
|
||||
currency: str = Field(..., pattern=r"^[A-Z]{3}$", description="Currency code")
|
||||
direction: str = Field(..., description="Payment direction (credit/debit)")
|
||||
description: str | None = Field(None, description="Payment description")
|
||||
reference: str | None = Field(None, description="Payment reference")
|
||||
balance_after: Decimal | None = Field(
|
||||
None, description="Account balance after payment"
|
||||
)
|
||||
|
||||
|
||||
class Calculation(BaseEntity):
|
||||
"""Tax calculation entity"""
|
||||
|
||||
calculation_id: str = Field(..., description="Calculation identifier")
|
||||
schedule: str = Field(..., description="Tax schedule (SA100, SA103, etc.)")
|
||||
tax_year: str = Field(
|
||||
..., pattern=r"^\d{4}-\d{2}$", description="Tax year (e.g., 2023-24)"
|
||||
)
|
||||
total_income: Decimal | None = Field(None, ge=0, description="Total income")
|
||||
total_expenses: Decimal | None = Field(None, ge=0, description="Total expenses")
|
||||
net_profit: Decimal | None = Field(None, description="Net profit/loss")
|
||||
calculated_at: datetime = Field(
|
||||
default_factory=datetime.utcnow, description="Calculation timestamp"
|
||||
)
|
||||
|
||||
|
||||
class FormBox(BaseEntity):
|
||||
"""Form box entity"""
|
||||
|
||||
form: str = Field(..., description="Form identifier (SA100, SA103, etc.)")
|
||||
box: str = Field(..., description="Box identifier")
|
||||
value: Decimal | str | bool = Field(..., description="Box value")
|
||||
description: str | None = Field(None, description="Box description")
|
||||
confidence: float | None = Field(
|
||||
None, ge=0.0, le=1.0, description="Confidence score"
|
||||
)
|
||||
|
||||
|
||||
class Rule(BaseEntity):
|
||||
"""Tax rule entity"""
|
||||
|
||||
rule_id: str = Field(..., description="Rule identifier")
|
||||
name: str = Field(..., description="Rule name")
|
||||
description: str | None = Field(None, description="Rule description")
|
||||
jurisdiction: str = Field(default="UK", description="Tax jurisdiction")
|
||||
tax_years: list[str] = Field(..., description="Applicable tax years")
|
||||
formula: str | None = Field(None, description="Rule formula")
|
||||
conditions: dict[str, Any] | None = Field(None, description="Rule conditions")
|
||||
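The pattern-constrained fields mean invalid identifiers are rejected at construction time. A small sketch with invented example data:

# Invented example; the regex constraints on doc_id and checksum are enforced by pydantic.
from datetime import datetime

from libs.schemas.entities import Document
from libs.schemas.enums import DocumentKind

doc = Document(
    doc_id="doc_" + "ab12" * 4,        # must match ^doc_[a-f0-9]{16}$
    kind=DocumentKind.INVOICE,
    source="upload",
    mime="application/pdf",
    checksum="a" * 64,                 # SHA-256 hex digest
    valid_from=datetime(2024, 1, 15),
    extractor_version="extract-svc@1.0.0",
)
print(doc.kind, doc.pages)  # pages defaults to None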
102
libs/schemas/enums.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Enumeration types for the tax system."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TaxpayerType(str, Enum):
|
||||
"""Taxpayer types"""
|
||||
|
||||
INDIVIDUAL = "Individual"
|
||||
PARTNERSHIP = "Partnership"
|
||||
COMPANY = "Company"
|
||||
|
||||
|
||||
class DocumentKind(str, Enum):
|
||||
"""Document types"""
|
||||
|
||||
BANK_STATEMENT = "bank_statement"
|
||||
INVOICE = "invoice"
|
||||
RECEIPT = "receipt"
|
||||
P_AND_L = "p_and_l"
|
||||
BALANCE_SHEET = "balance_sheet"
|
||||
PAYSLIP = "payslip"
|
||||
DIVIDEND_VOUCHER = "dividend_voucher"
|
||||
PROPERTY_STATEMENT = "property_statement"
|
||||
PRIOR_RETURN = "prior_return"
|
||||
LETTER = "letter"
|
||||
CERTIFICATE = "certificate"
|
||||
|
||||
|
||||
class IncomeType(str, Enum):
|
||||
"""Income types"""
|
||||
|
||||
EMPLOYMENT = "employment"
|
||||
SELF_EMPLOYMENT = "self_employment"
|
||||
PROPERTY = "property"
|
||||
DIVIDEND = "dividend"
|
||||
INTEREST = "interest"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class ExpenseType(str, Enum):
|
||||
"""Expense types"""
|
||||
|
||||
BUSINESS = "business"
|
||||
PROPERTY = "property"
|
||||
CAPITAL = "capital"
|
||||
PERSONAL = "personal"
|
||||
|
||||
|
||||
class PartySubtype(str, Enum):
|
||||
"""Party subtypes"""
|
||||
|
||||
EMPLOYER = "Employer"
|
||||
PAYER = "Payer"
|
||||
BANK = "Bank"
|
||||
LANDLORD = "Landlord"
|
||||
TENANT = "Tenant"
|
||||
SUPPLIER = "Supplier"
|
||||
CLIENT = "Client"
|
||||
|
||||
|
||||
class PropertyUsage(str, Enum):
|
||||
"""Property usage types"""
|
||||
|
||||
RESIDENTIAL = "residential"
|
||||
FURNISHED_HOLIDAY_LETTING = "furnished_holiday_letting"
|
||||
COMMERCIAL = "commercial"
|
||||
MIXED = "mixed"
|
||||
|
||||
|
||||
class HealthStatus(str, Enum):
|
||||
"""Health status values"""
|
||||
|
||||
HEALTHY = "healthy"
|
||||
UNHEALTHY = "unhealthy"
|
||||
DEGRADED = "degraded"
|
||||
|
||||
|
||||
# Coverage evaluation enums
|
||||
class Role(str, Enum):
|
||||
"""Evidence role in coverage evaluation"""
|
||||
|
||||
REQUIRED = "REQUIRED"
|
||||
CONDITIONALLY_REQUIRED = "CONDITIONALLY_REQUIRED"
|
||||
OPTIONAL = "OPTIONAL"
|
||||
|
||||
|
||||
class Status(str, Enum):
|
||||
"""Evidence status classification"""
|
||||
|
||||
PRESENT_VERIFIED = "present_verified"
|
||||
PRESENT_UNVERIFIED = "present_unverified"
|
||||
MISSING = "missing"
|
||||
CONFLICTING = "conflicting"
|
||||
|
||||
|
||||
class OverallStatus(str, Enum):
|
||||
"""Overall coverage status"""
|
||||
|
||||
OK = "ok"
|
||||
PARTIAL = "partial"
|
||||
BLOCKING = "blocking"
|
||||
30
libs/schemas/errors.py
Normal file
@@ -0,0 +1,30 @@
"""Error response models."""

from typing import Any

from pydantic import BaseModel, Field


class ErrorResponse(BaseModel):
    """RFC7807 Problem+JSON error response"""

    type: str = Field(..., description="Error type URI")
    title: str = Field(..., description="Error title")
    status: int = Field(..., description="HTTP status code")
    detail: str = Field(..., description="Error detail")
    instance: str = Field(..., description="Error instance URI")
    trace_id: str | None = Field(None, description="Trace identifier")


class ValidationError(BaseModel):
    """Validation error details"""

    field: str = Field(..., description="Field name")
    message: str = Field(..., description="Error message")
    value: Any = Field(..., description="Invalid value")


class ValidationErrorResponse(ErrorResponse):
    """Validation error response with field details"""

    errors: list[ValidationError] = Field(..., description="Validation errors")
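A handler might serialise a validation failure with these models as in the sketch below; the error-type URI and instance path are invented for the example.

# Invented example of an RFC 7807 payload built from these models.
from libs.schemas.errors import ValidationError, ValidationErrorResponse

problem = ValidationErrorResponse(
    type="https://errors.example.com/validation",
    title="Request validation failed",
    status=422,
    detail="1 field failed validation",
    instance="/v1/documents",
    errors=[
        ValidationError(
            field="tax_year", message="must match ^\\d{4}-\\d{2}$", value="2023"
        )
    ],
)
print(problem.model_dump_json(indent=2))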
32
libs/schemas/health.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Health check models."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .enums import HealthStatus
|
||||
|
||||
|
||||
class HealthCheck(BaseModel):
|
||||
"""Health check response"""
|
||||
|
||||
status: HealthStatus = Field(..., description="Overall health status")
|
||||
timestamp: datetime = Field(
|
||||
default_factory=datetime.utcnow, description="Check timestamp"
|
||||
)
|
||||
version: str = Field(..., description="Service version")
|
||||
checks: dict[str, dict[str, Any]] = Field(
|
||||
default_factory=dict, description="Individual checks"
|
||||
)
|
||||
|
||||
|
||||
class ServiceHealth(BaseModel):
|
||||
"""Individual service health status"""
|
||||
|
||||
name: str = Field(..., description="Service name")
|
||||
status: HealthStatus = Field(..., description="Service health status")
|
||||
response_time_ms: float | None = Field(
|
||||
None, description="Response time in milliseconds"
|
||||
)
|
||||
error: str | None = Field(None, description="Error message if unhealthy")
|
||||
65
libs/schemas/requests.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""API request models."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from .enums import DocumentKind
|
||||
|
||||
|
||||
class DocumentUploadRequest(BaseModel):
|
||||
"""Request model for document upload"""
|
||||
|
||||
tenant_id: str = Field(..., description="Tenant identifier")
|
||||
kind: DocumentKind = Field(..., description="Document type")
|
||||
source: str = Field(..., description="Document source")
|
||||
|
||||
|
||||
class ExtractionRequest(BaseModel):
|
||||
"""Request model for document extraction"""
|
||||
|
||||
strategy: str = Field(default="hybrid", description="Extraction strategy")
|
||||
|
||||
|
||||
class RAGSearchRequest(BaseModel):
|
||||
"""Request model for RAG search"""
|
||||
|
||||
query: str = Field(..., min_length=1, description="Search query")
|
||||
tax_year: str | None = Field(None, description="Tax year filter")
|
||||
jurisdiction: str | None = Field(None, description="Jurisdiction filter")
|
||||
k: int = Field(default=10, ge=1, le=100, description="Number of results")
|
||||
|
||||
|
||||
class ScheduleComputeRequest(BaseModel):
|
||||
"""Request model for schedule computation"""
|
||||
|
||||
tax_year: str = Field(..., pattern=r"^\d{4}-\d{2}$", description="Tax year")
|
||||
taxpayer_id: str = Field(..., description="Taxpayer identifier")
|
||||
schedule_id: str = Field(..., description="Schedule identifier")
|
||||
|
||||
|
||||
class HMRCSubmissionRequest(BaseModel):
|
||||
"""Request model for HMRC submission"""
|
||||
|
||||
tax_year: str = Field(..., pattern=r"^\d{4}-\d{2}$", description="Tax year")
|
||||
taxpayer_id: str = Field(..., description="Taxpayer identifier")
|
||||
dry_run: bool = Field(default=True, description="Dry run flag")
|
||||
|
||||
|
||||
class FirmSyncRequest(BaseModel):
|
||||
"""Request to sync firm data"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
firm_id: str = Field(..., description="Firm identifier")
|
||||
system: str = Field(..., description="Practice management system to sync with")
|
||||
sync_type: str = Field(
|
||||
default="full", description="Type of sync: full, incremental"
|
||||
)
|
||||
force_refresh: bool = Field(
|
||||
default=False, description="Force refresh of cached data"
|
||||
)
|
||||
connection_config: dict[str, Any] = Field(
|
||||
...,
|
||||
description="Configuration for connecting to the practice management system",
|
||||
)
|
||||
69
libs/schemas/responses.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""API response models."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class DocumentUploadResponse(BaseModel):
|
||||
"""Response model for document upload"""
|
||||
|
||||
doc_id: str = Field(..., description="Document identifier")
|
||||
s3_url: str = Field(..., description="S3 URL")
|
||||
checksum: str = Field(..., description="Document checksum")
|
||||
|
||||
|
||||
class ExtractionResponse(BaseModel):
|
||||
"""Response model for document extraction"""
|
||||
|
||||
extraction_id: str = Field(..., description="Extraction identifier")
|
||||
confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
|
||||
extracted_fields: dict[str, Any] = Field(..., description="Extracted fields")
|
||||
provenance: list[dict[str, Any]] = Field(..., description="Provenance information")
|
||||
|
||||
|
||||
class RAGSearchResponse(BaseModel):
|
||||
"""Response model for RAG search"""
|
||||
|
||||
chunks: list[dict[str, Any]] = Field(..., description="Retrieved chunks")
|
||||
citations: list[dict[str, Any]] = Field(..., description="Source citations")
|
||||
kg_hints: list[dict[str, Any]] = Field(..., description="Knowledge graph hints")
|
||||
calibrated_confidence: float = Field(
|
||||
..., ge=0.0, le=1.0, description="Calibrated confidence"
|
||||
)
|
||||
|
||||
|
||||
class ScheduleComputeResponse(BaseModel):
|
||||
"""Response model for schedule computation"""
|
||||
|
||||
calculation_id: str = Field(..., description="Calculation identifier")
|
||||
schedule: str = Field(..., description="Schedule identifier")
|
||||
form_boxes: dict[str, dict[str, Any]] = Field(
|
||||
..., description="Computed form boxes"
|
||||
)
|
||||
evidence_trail: list[dict[str, Any]] = Field(..., description="Evidence trail")
|
||||
|
||||
|
||||
class HMRCSubmissionResponse(BaseModel):
|
||||
"""Response model for HMRC submission"""
|
||||
|
||||
submission_id: str = Field(..., description="Submission identifier")
|
||||
status: str = Field(..., description="Submission status")
|
||||
hmrc_reference: str | None = Field(None, description="HMRC reference")
|
||||
submission_timestamp: datetime = Field(..., description="Submission timestamp")
|
||||
validation_results: dict[str, Any] = Field(..., description="Validation results")
|
||||
|
||||
|
||||
class FirmSyncResponse(BaseModel):
|
||||
"""Response from firm sync operation"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
firm_id: str = Field(..., description="Firm identifier")
|
||||
status: str = Field(..., description="Sync status: success, error, partial")
|
||||
message: str = Field(..., description="Status message")
|
||||
synced_entities: int = Field(default=0, description="Number of entities synced")
|
||||
errors: list[str] = Field(
|
||||
default_factory=list, description="List of errors encountered"
|
||||
)
|
||||
69
libs/schemas/utils.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Utility functions for schema export."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .entities import (
|
||||
Account,
|
||||
Calculation,
|
||||
Document,
|
||||
Evidence,
|
||||
ExpenseItem,
|
||||
FormBox,
|
||||
IncomeItem,
|
||||
Party,
|
||||
Payment,
|
||||
PropertyAsset,
|
||||
Rule,
|
||||
TaxpayerProfile,
|
||||
)
|
||||
from .requests import (
|
||||
DocumentUploadRequest,
|
||||
ExtractionRequest,
|
||||
FirmSyncRequest,
|
||||
HMRCSubmissionRequest,
|
||||
RAGSearchRequest,
|
||||
ScheduleComputeRequest,
|
||||
)
|
||||
from .responses import (
|
||||
DocumentUploadResponse,
|
||||
ExtractionResponse,
|
||||
FirmSyncResponse,
|
||||
HMRCSubmissionResponse,
|
||||
RAGSearchResponse,
|
||||
ScheduleComputeResponse,
|
||||
)
|
||||
|
||||
|
||||
def get_entity_schemas() -> dict[str, dict[str, Any]]:
|
||||
"""Export JSON schemas for all models"""
|
||||
schemas = {}
|
||||
|
||||
# Core entities
|
||||
schemas["TaxpayerProfile"] = TaxpayerProfile.model_json_schema()
|
||||
schemas["Document"] = Document.model_json_schema()
|
||||
schemas["Evidence"] = Evidence.model_json_schema()
|
||||
schemas["IncomeItem"] = IncomeItem.model_json_schema()
|
||||
schemas["ExpenseItem"] = ExpenseItem.model_json_schema()
|
||||
schemas["Party"] = Party.model_json_schema()
|
||||
schemas["Account"] = Account.model_json_schema()
|
||||
schemas["PropertyAsset"] = PropertyAsset.model_json_schema()
|
||||
schemas["Payment"] = Payment.model_json_schema()
|
||||
schemas["Calculation"] = Calculation.model_json_schema()
|
||||
schemas["FormBox"] = FormBox.model_json_schema()
|
||||
schemas["Rule"] = Rule.model_json_schema()
|
||||
|
||||
# Request/Response models
|
||||
schemas["DocumentUploadRequest"] = DocumentUploadRequest.model_json_schema()
|
||||
schemas["DocumentUploadResponse"] = DocumentUploadResponse.model_json_schema()
|
||||
schemas["ExtractionRequest"] = ExtractionRequest.model_json_schema()
|
||||
schemas["ExtractionResponse"] = ExtractionResponse.model_json_schema()
|
||||
schemas["RAGSearchRequest"] = RAGSearchRequest.model_json_schema()
|
||||
schemas["RAGSearchResponse"] = RAGSearchResponse.model_json_schema()
|
||||
schemas["ScheduleComputeRequest"] = ScheduleComputeRequest.model_json_schema()
|
||||
schemas["ScheduleComputeResponse"] = ScheduleComputeResponse.model_json_schema()
|
||||
schemas["HMRCSubmissionRequest"] = HMRCSubmissionRequest.model_json_schema()
|
||||
schemas["HMRCSubmissionResponse"] = HMRCSubmissionResponse.model_json_schema()
|
||||
schemas["FirmSyncRequest"] = FirmSyncRequest.model_json_schema()
|
||||
schemas["FirmSyncResponse"] = FirmSyncResponse.model_json_schema()
|
||||
|
||||
return schemas
|
||||
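One way get_entity_schemas could be used is to write one JSON Schema file per exported model, as sketched below; the docs/schemas output directory is an assumption, not a path defined by this repository.

# Sketch: dump each exported schema to docs/schemas/<Name>.json (path is assumed).
import json
from pathlib import Path

from libs.schemas.utils import get_entity_schemas

out_dir = Path("docs/schemas")
out_dir.mkdir(parents=True, exist_ok=True)

for name, schema in get_entity_schemas().items():
    (out_dir / f"{name}.json").write_text(json.dumps(schema, indent=2))
    print(f"wrote {name}.json")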
26
libs/security/__init__.py
Normal file
@@ -0,0 +1,26 @@
"""Security utilities for authentication, authorization, and encryption."""

from .auth import AuthenticationHeaders
from .dependencies import (
    get_current_tenant,
    get_current_user,
    get_tenant_id,
    require_admin_role,
    require_reviewer_role,
)
from .middleware import TrustedProxyMiddleware, create_trusted_proxy_middleware
from .utils import is_internal_request
from .vault import VaultTransitHelper

__all__ = [
    "VaultTransitHelper",
    "AuthenticationHeaders",
    "TrustedProxyMiddleware",
    "is_internal_request",
    "require_admin_role",
    "require_reviewer_role",
    "get_current_tenant",
    "get_current_user",
    "get_tenant_id",
    "create_trusted_proxy_middleware",
]
61
libs/security/auth.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Authentication headers parsing and validation."""
|
||||
|
||||
import structlog
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class AuthenticationHeaders:
|
||||
"""Parse and validate authentication headers from Traefik + Authentik"""
|
||||
|
||||
def __init__(self, request: Request):
|
||||
self.request = request
|
||||
self.headers = request.headers
|
||||
|
||||
@property
|
||||
def authenticated_user(self) -> str | None:
|
||||
"""Get authenticated user from headers"""
|
||||
return self.headers.get("X-Authenticated-User")
|
||||
|
||||
@property
|
||||
def authenticated_email(self) -> str | None:
|
||||
"""Get authenticated email from headers"""
|
||||
return self.headers.get("X-Authenticated-Email")
|
||||
|
||||
@property
|
||||
def authenticated_groups(self) -> list[str]:
|
||||
"""Get authenticated groups from headers"""
|
||||
groups_header = self.headers.get("X-Authenticated-Groups", "")
|
||||
return [g.strip() for g in groups_header.split(",") if g.strip()]
|
||||
|
||||
@property
|
||||
def authorization_token(self) -> str | None:
|
||||
"""Get JWT token from Authorization header"""
|
||||
auth_header = self.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
return auth_header[7:]
|
||||
return None
|
||||
|
||||
def has_role(self, role: str) -> bool:
|
||||
"""Check if user has specific role"""
|
||||
return role in self.authenticated_groups
|
||||
|
||||
def has_any_role(self, roles: list[str]) -> bool:
|
||||
"""Check if user has any of the specified roles"""
|
||||
return any(role in self.authenticated_groups for role in roles)
|
||||
|
||||
def require_role(self, role: str) -> None:
|
||||
"""Require specific role or raise HTTPException"""
|
||||
if not self.has_role(role):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=f"Role '{role}' required"
|
||||
)
|
||||
|
||||
def require_any_role(self, roles: list[str]) -> None:
|
||||
"""Require any of the specified roles or raise HTTPException"""
|
||||
if not self.has_any_role(roles):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"One of roles {roles} required",
|
||||
)
|
||||
79
libs/security/dependencies.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""FastAPI dependency functions for authentication and authorization."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
|
||||
def require_admin_role(request: Request) -> None:
|
||||
"""Dependency to require admin role"""
|
||||
auth = getattr(request.state, "auth", None)
|
||||
if not auth:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required"
|
||||
)
|
||||
auth.require_role("admin")
|
||||
|
||||
|
||||
def require_reviewer_role(request: Request) -> None:
|
||||
"""Dependency to require reviewer role"""
|
||||
auth = getattr(request.state, "auth", None)
|
||||
if not auth:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required"
|
||||
)
|
||||
auth.require_any_role(["admin", "reviewer"])
|
||||
|
||||
|
||||
def get_current_tenant(request: Request) -> str | None:
|
||||
"""Extract tenant ID from user context or headers"""
|
||||
# This could be extracted from JWT claims, user groups, or custom headers
|
||||
# For now, we'll use a simple mapping from user to tenant
|
||||
user = getattr(request.state, "user", None)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
# Simple tenant extraction - in production this would be more sophisticated
|
||||
# Could be from JWT claims, database lookup, or group membership
|
||||
roles = getattr(request.state, "roles", [])
|
||||
for role in roles:
|
||||
if role.startswith("tenant:"):
|
||||
return str(role.split(":", 1)[1])
|
||||
|
||||
# Default tenant for development
|
||||
return "default"
|
||||
|
||||
|
||||
# Dependency functions for FastAPI
|
||||
def get_current_user() -> Callable[[Request], dict[str, Any]]:
|
||||
"""FastAPI dependency to get current user"""
|
||||
|
||||
def _get_current_user(request: Request) -> dict[str, Any]:
|
||||
user = getattr(request.state, "user", None)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication required",
|
||||
)
|
||||
return {
|
||||
"sub": user,
|
||||
"email": getattr(request.state, "email", ""),
|
||||
"roles": getattr(request.state, "roles", []),
|
||||
}
|
||||
|
||||
return _get_current_user
|
||||
|
||||
|
||||
def get_tenant_id() -> Callable[[Request], str]:
|
||||
"""FastAPI dependency to get tenant ID"""
|
||||
|
||||
def _get_tenant_id(request: Request) -> str:
|
||||
tenant_id = get_current_tenant(request)
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="Tenant ID required"
|
||||
)
|
||||
return tenant_id
|
||||
|
||||
return _get_tenant_id
|
||||
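Note that get_current_user and get_tenant_id are factories that return the actual dependency callables, so routes wire them as Depends(get_current_user()). A minimal sketch of that wiring, with an invented route path:

# Assumed wiring of the dependency factories into a FastAPI route.
from typing import Any

from fastapi import Depends, FastAPI

from libs.security.dependencies import get_current_user, get_tenant_id

app = FastAPI()


@app.get("/v1/whoami")
async def whoami(
    user: dict[str, Any] = Depends(get_current_user()),
    tenant_id: str = Depends(get_tenant_id()),
) -> dict[str, Any]:
    """Echo back the identity resolved by the auth middleware."""
    return {"user": user["sub"], "email": user["email"], "tenant": tenant_id}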
134
libs/security/middleware.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Trusted proxy middleware for authentication validation."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from fastapi import HTTPException, Request, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from .auth import AuthenticationHeaders
|
||||
from .utils import is_internal_request
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TrustedProxyMiddleware(
|
||||
BaseHTTPMiddleware
|
||||
): # pylint: disable=too-few-public-methods
|
||||
"""Middleware to validate requests from trusted proxy (Traefik)"""
|
||||
|
||||
def __init__(self, app: Any, internal_cidrs: list[str], disable_auth: bool = False):
|
||||
super().__init__(app)
|
||||
self.internal_cidrs = internal_cidrs
|
||||
self.disable_auth = disable_auth
|
||||
self.public_endpoints = {
|
||||
"/healthz",
|
||||
"/readyz",
|
||||
"/livez",
|
||||
"/metrics",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
}
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable[..., Any]) -> Any:
|
||||
"""Process request through middleware"""
|
||||
# Get client IP (considering proxy headers)
|
||||
client_ip = self._get_client_ip(request)
|
||||
|
||||
# Check if authentication is disabled (development mode)
|
||||
if self.disable_auth:
|
||||
# Set development state
|
||||
request.state.user = "dev-user"
|
||||
request.state.email = "dev@example.com"
|
||||
request.state.roles = ["developers"]
|
||||
request.state.auth_token = "dev-token"
|
||||
logger.info(
|
||||
"Development mode: authentication disabled", path=request.url.path
|
||||
)
|
||||
return await call_next(request)
|
||||
|
||||
# Check if this is a public endpoint
|
||||
if request.url.path in self.public_endpoints:
|
||||
# For metrics endpoint, still require internal network
|
||||
if request.url.path == "/metrics":
|
||||
if not is_internal_request(client_ip, self.internal_cidrs):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Metrics endpoint only accessible from internal network",
|
||||
)
|
||||
# Set minimal state for public endpoints
|
||||
request.state.user = None
|
||||
request.state.email = None
|
||||
request.state.roles = []
|
||||
return await call_next(request)
|
||||
|
||||
# For protected endpoints, validate authentication headers
|
||||
auth_headers = AuthenticationHeaders(request)
|
||||
|
||||
# Require authentication headers
|
||||
if not auth_headers.authenticated_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing X-Authenticated-User header",
|
||||
)
|
||||
|
||||
if not auth_headers.authenticated_email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing X-Authenticated-Email header",
|
||||
)
|
||||
|
||||
if not auth_headers.authorization_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing Authorization header",
|
||||
)
|
||||
|
||||
# Set request state
|
||||
request.state.user = auth_headers.authenticated_user
|
||||
request.state.email = auth_headers.authenticated_email
|
||||
request.state.roles = auth_headers.authenticated_groups
|
||||
request.state.auth_token = auth_headers.authorization_token
|
||||
|
||||
# Add authentication helper to request
|
||||
request.state.auth = auth_headers
|
||||
|
||||
logger.info(
|
||||
"Authenticated request",
|
||||
user=auth_headers.authenticated_user,
|
||||
email=auth_headers.authenticated_email,
|
||||
roles=auth_headers.authenticated_groups,
|
||||
path=request.url.path,
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Get client IP considering proxy headers"""
|
||||
# Check X-Forwarded-For header first (set by Traefik)
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
# Take the first IP in the chain
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
# Check X-Real-IP header
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# Fall back to direct client IP
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
def create_trusted_proxy_middleware(
|
||||
internal_cidrs: list[str],
|
||||
) -> Callable[[Any], TrustedProxyMiddleware]:
|
||||
"""Factory function to create TrustedProxyMiddleware"""
|
||||
|
||||
def middleware_factory( # pylint: disable=unused-argument
|
||||
app: Any,
|
||||
) -> TrustedProxyMiddleware:
|
||||
return TrustedProxyMiddleware(app, internal_cidrs)
|
||||
|
||||
return middleware_factory
|
||||
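A sketch of attaching the middleware directly to a service app is shown below; the CIDR list and the disable_auth flag are placeholder values that would normally come from service settings.

# Assumed application wiring; internal_cidrs and disable_auth come from settings.
from fastapi import FastAPI

from libs.security.middleware import TrustedProxyMiddleware

app = FastAPI()
app.add_middleware(
    TrustedProxyMiddleware,
    internal_cidrs=["10.0.0.0/8", "172.16.0.0/12"],
    disable_auth=False,  # set True only for local development
)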
20
libs/security/utils.py
Normal file
@@ -0,0 +1,20 @@
"""Security utility functions."""

import ipaddress

import structlog

logger = structlog.get_logger()


def is_internal_request(client_ip: str, internal_cidrs: list[str]) -> bool:
    """Check if request comes from internal network"""
    try:
        client_addr = ipaddress.ip_address(client_ip)
        for cidr in internal_cidrs:
            if client_addr in ipaddress.ip_network(cidr):
                return True
        return False
    except ValueError:
        logger.warning("Invalid client IP address", client_ip=client_ip)
        return False
58
libs/security/vault.py
Normal file
@@ -0,0 +1,58 @@
"""Vault Transit encryption/decryption helpers."""

import base64

import hvac
import structlog

logger = structlog.get_logger()


class VaultTransitHelper:
    """Helper for Vault Transit encryption/decryption"""

    def __init__(self, vault_client: hvac.Client, mount_point: str = "transit"):
        self.vault_client = vault_client
        self.mount_point = mount_point

    def encrypt_field(self, key_name: str, plaintext: str) -> str:
        """Encrypt a field using Vault Transit"""
        try:
            # Ensure key exists
            self._ensure_key_exists(key_name)

            # Encrypt the data
            response = self.vault_client.secrets.transit.encrypt_data(
                mount_point=self.mount_point,
                name=key_name,
                plaintext=base64.b64encode(plaintext.encode()).decode(),
            )
            return str(response["data"]["ciphertext"])
        except Exception as e:
            logger.error("Failed to encrypt field", key_name=key_name, error=str(e))
            raise

    def decrypt_field(self, key_name: str, ciphertext: str) -> str:
        """Decrypt a field using Vault Transit"""
        try:
            response = self.vault_client.secrets.transit.decrypt_data(
                mount_point=self.mount_point, name=key_name, ciphertext=ciphertext
            )
            return base64.b64decode(response["data"]["plaintext"]).decode()
        except Exception as e:
            logger.error("Failed to decrypt field", key_name=key_name, error=str(e))
            raise

    def _ensure_key_exists(self, key_name: str) -> None:
        """Ensure encryption key exists in Vault"""
        try:
            self.vault_client.secrets.transit.read_key(
                mount_point=self.mount_point, name=key_name
            )
        # pylint: disable-next=broad-exception-caught
        except Exception:  # hvac.exceptions.InvalidPath
            # Key doesn't exist, create it
            self.vault_client.secrets.transit.create_key(
                mount_point=self.mount_point, name=key_name, key_type="aes256-gcm96"
            )
            logger.info("Created new encryption key", key_name=key_name)
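A usage sketch (not part of this commit), assuming a reachable Vault server with the Transit engine enabled at the default "transit" mount and credentials in the standard environment variables; the key name and sample value are hypothetical:

# Illustrative only: encrypt and decrypt a single PII field via Vault Transit.
import os

import hvac

client = hvac.Client(url=os.environ["VAULT_ADDR"], token=os.environ["VAULT_TOKEN"])
helper = VaultTransitHelper(client)

ciphertext = helper.encrypt_field("tenant-123-pii", "AB123456C")  # hypothetical key name and value
assert helper.decrypt_field("tenant-123-pii", ciphertext) == "AB123456C"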
libs/storage/__init__.py (new file, 9 lines)
@@ -0,0 +1,9 @@
"""Storage client and document management for MinIO/S3."""

from .client import StorageClient
from .document import DocumentStorage

__all__ = [
    "StorageClient",
    "DocumentStorage",
]
libs/storage/client.py (new file, 231 lines)
@@ -0,0 +1,231 @@
"""MinIO/S3 storage client wrapper."""

from datetime import timedelta
from typing import Any, BinaryIO

import structlog
from minio import Minio
from minio.error import S3Error

logger = structlog.get_logger()


class StorageClient:
    """MinIO/S3 storage client wrapper"""

    def __init__(self, minio_client: Minio):
        self.client = minio_client

    async def ensure_bucket(self, bucket_name: str, region: str = "us-east-1") -> bool:
        """Ensure bucket exists, create if not"""
        try:
            # Check if bucket exists
            if self.client.bucket_exists(bucket_name):
                logger.debug("Bucket already exists", bucket=bucket_name)
                return True

            # Create bucket
            self.client.make_bucket(bucket_name, location=region)
            logger.info("Created bucket", bucket=bucket_name, region=region)
            return True

        except S3Error as e:
            logger.error("Failed to ensure bucket", bucket=bucket_name, error=str(e))
            return False

    async def put_object(  # pylint: disable=too-many-arguments,too-many-positional-arguments
        self,
        bucket_name: str,
        object_name: str,
        data: BinaryIO,
        length: int,
        content_type: str = "application/octet-stream",
        metadata: dict[str, str] | None = None,
    ) -> bool:
        """Upload object to bucket"""
        try:
            # Ensure bucket exists
            await self.ensure_bucket(bucket_name)

            # Upload object
            result = self.client.put_object(
                bucket_name=bucket_name,
                object_name=object_name,
                data=data,
                length=length,
                content_type=content_type,
                metadata=metadata or {},  # fmt: skip # pyright: ignore[reportArgumentType]
            )

            logger.info(
                "Object uploaded",
                bucket=bucket_name,
                object=object_name,
                etag=result.etag,
                size=length,
            )
            return True

        except S3Error as e:
            logger.error(
                "Failed to upload object",
                bucket=bucket_name,
                object=object_name,
                error=str(e),
            )
            return False

    async def get_object(self, bucket_name: str, object_name: str) -> bytes | None:
        """Download object from bucket"""
        try:
            response = self.client.get_object(bucket_name, object_name)
            data = response.read()
            response.close()
            response.release_conn()

            logger.debug(
                "Object downloaded",
                bucket=bucket_name,
                object=object_name,
                size=len(data),
            )
            return data  # type: ignore

        except S3Error as e:
            logger.error(
                "Failed to download object",
                bucket=bucket_name,
                object=object_name,
                error=str(e),
            )
            return None

    async def get_object_stream(self, bucket_name: str, object_name: str) -> Any:
        """Get object as stream"""
        try:
            response = self.client.get_object(bucket_name, object_name)
            return response

        except S3Error as e:
            logger.error(
                "Failed to get object stream",
                bucket=bucket_name,
                object=object_name,
                error=str(e),
            )
            return None

    async def object_exists(self, bucket_name: str, object_name: str) -> bool:
        """Check if object exists"""
        try:
            self.client.stat_object(bucket_name, object_name)
            return True
        except S3Error:
            return False

    async def delete_object(self, bucket_name: str, object_name: str) -> bool:
        """Delete object from bucket"""
        try:
            self.client.remove_object(bucket_name, object_name)
            logger.info("Object deleted", bucket=bucket_name, object=object_name)
            return True

        except S3Error as e:
            logger.error(
                "Failed to delete object",
                bucket=bucket_name,
                object=object_name,
                error=str(e),
            )
            return False

    async def list_objects(
        self, bucket_name: str, prefix: str | None = None, recursive: bool = True
    ) -> list[str]:
        """List objects in bucket"""
        try:
            objects = self.client.list_objects(
                bucket_name, prefix=prefix, recursive=recursive
            )
            return [obj.object_name for obj in objects if obj.object_name is not None]

        except S3Error as e:
            logger.error(
                "Failed to list objects",
                bucket=bucket_name,
                prefix=prefix,
                error=str(e),
            )
            return []

    async def get_presigned_url(
        self,
        bucket_name: str,
        object_name: str,
        expires: timedelta = timedelta(hours=1),
        method: str = "GET",
    ) -> str | None:
        """Generate presigned URL for object access"""
        try:
            url = self.client.get_presigned_url(
                method=method,
                bucket_name=bucket_name,
                object_name=object_name,
                expires=expires,
            )

            logger.debug(
                "Generated presigned URL",
                bucket=bucket_name,
                object=object_name,
                method=method,
                expires=expires,
            )
            return str(url)

        except S3Error as e:
            logger.error(
                "Failed to generate presigned URL",
                bucket=bucket_name,
                object=object_name,
                error=str(e),
            )
            return None

    async def copy_object(
        self, source_bucket: str, source_object: str, dest_bucket: str, dest_object: str
    ) -> bool:
        """Copy object between buckets/locations"""
        try:
            # pylint: disable=import-outside-toplevel
            from minio.commonconfig import CopySource

            # Ensure destination bucket exists
            await self.ensure_bucket(dest_bucket)

            # Copy object
            self.client.copy_object(
                bucket_name=dest_bucket,
                object_name=dest_object,
                source=CopySource(source_bucket, source_object),
            )

            logger.info(
                "Object copied",
                source_bucket=source_bucket,
                source_object=source_object,
                dest_bucket=dest_bucket,
                dest_object=dest_object,
            )
            return True

        except S3Error as e:
            logger.error(
                "Failed to copy object",
                source_bucket=source_bucket,
                source_object=source_object,
                dest_bucket=dest_bucket,
                dest_object=dest_object,
                error=str(e),
            )
            return False
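A usage sketch (not part of this commit); the MinIO endpoint, credentials, and bucket/object names are placeholders for a local development instance:

# Illustrative only: round-trip a small object through StorageClient.
import asyncio
from io import BytesIO

from minio import Minio


async def demo() -> None:
    client = StorageClient(
        Minio("localhost:9000", access_key="minioadmin", secret_key="minioadmin", secure=False)
    )
    payload = b"hello world"
    await client.put_object("scratch", "demo/hello.txt", BytesIO(payload), len(payload))
    assert await client.get_object("scratch", "demo/hello.txt") == payload


asyncio.run(demo())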
libs/storage/document.py (new file, 145 lines)
@@ -0,0 +1,145 @@
"""High-level document storage operations."""

import hashlib
import json
from io import BytesIO
from typing import Any

import structlog

from .client import StorageClient

logger = structlog.get_logger()


class DocumentStorage:
    """High-level document storage operations"""

    def __init__(self, storage_client: StorageClient):
        self.storage = storage_client

    async def store_document(  # pylint: disable=too-many-arguments,too-many-positional-arguments
        self,
        tenant_id: str,
        doc_id: str,
        content: bytes,
        content_type: str = "application/pdf",
        metadata: dict[str, str] | None = None,
        bucket_name: str = "raw-documents",
    ) -> dict[str, Any]:
        """Store document with metadata"""

        # Calculate checksum
        checksum = hashlib.sha256(content).hexdigest()

        # Prepare metadata
        doc_metadata = {
            "tenant_id": tenant_id,
            "doc_id": doc_id,
            "checksum": checksum,
            "size": str(len(content)),
            **(metadata or {}),
        }

        # Determine bucket and key
        object_key = f"tenants/{tenant_id}/raw/{doc_id}.pdf"

        # Upload to storage
        success = await self.storage.put_object(
            bucket_name=bucket_name,
            object_name=object_key,
            data=BytesIO(content),
            length=len(content),
            content_type=content_type,
            metadata=doc_metadata,
        )

        if success:
            return {
                "bucket": bucket_name,
                "key": object_key,
                "checksum": checksum,
                "size": len(content),
                "s3_url": f"s3://{bucket_name}/{object_key}",
            }

        raise RuntimeError("Failed to store document")

    async def store_ocr_result(
        self, tenant_id: str, doc_id: str, ocr_data: dict[str, Any]
    ) -> str:
        """Store OCR results as JSON"""
        bucket_name = "evidence"
        object_key = f"tenants/{tenant_id}/ocr/{doc_id}.json"

        # Convert to JSON bytes
        json_data = json.dumps(ocr_data, indent=2).encode("utf-8")

        # Upload to storage
        success = await self.storage.put_object(
            bucket_name=bucket_name,
            object_name=object_key,
            data=BytesIO(json_data),
            length=len(json_data),
            content_type="application/json",
        )

        if success:
            return f"s3://{bucket_name}/{object_key}"

        raise RuntimeError("Failed to store OCR result")

    async def store_extraction_result(
        self, tenant_id: str, doc_id: str, extraction_data: dict[str, Any]
    ) -> str:
        """Store extraction results as JSON"""
        bucket_name = "evidence"
        object_key = f"tenants/{tenant_id}/extractions/{doc_id}.json"

        # Convert to JSON bytes
        json_data = json.dumps(extraction_data, indent=2).encode("utf-8")

        # Upload to storage
        success = await self.storage.put_object(
            bucket_name=bucket_name,
            object_name=object_key,
            data=BytesIO(json_data),
            length=len(json_data),
            content_type="application/json",
        )

        if success:
            return f"s3://{bucket_name}/{object_key}"

        raise RuntimeError("Failed to store extraction result")

    async def get_document(self, tenant_id: str, doc_id: str) -> bytes | None:
        """Retrieve document content"""
        bucket_name = "raw-documents"
        object_key = f"tenants/{tenant_id}/raw/{doc_id}.pdf"

        return await self.storage.get_object(bucket_name, object_key)

    async def get_ocr_result(
        self, tenant_id: str, doc_id: str
    ) -> dict[str, Any] | None:
        """Retrieve OCR results"""
        bucket_name = "evidence"
        object_key = f"tenants/{tenant_id}/ocr/{doc_id}.json"

        data = await self.storage.get_object(bucket_name, object_key)
        if data:
            return json.loads(data.decode("utf-8"))  # type: ignore
        return None

    async def get_extraction_result(
        self, tenant_id: str, doc_id: str
    ) -> dict[str, Any] | None:
        """Retrieve extraction results"""
        bucket_name = "evidence"
        object_key = f"tenants/{tenant_id}/extractions/{doc_id}.json"

        data = await self.storage.get_object(bucket_name, object_key)
        if data:
            return json.loads(data.decode("utf-8"))  # type: ignore
        return None
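An end-to-end sketch (not part of this commit) showing how DocumentStorage composes with StorageClient; the tenant ID, document ID, and payloads are hypothetical:

# Illustrative only: store a raw PDF plus its OCR output, then read the PDF back.
async def demo(storage_client: StorageClient) -> None:
    docs = DocumentStorage(storage_client)

    info = await docs.store_document("tenant-123", "doc-001", b"%PDF-1.7 ...")
    print(info["s3_url"], info["checksum"])

    await docs.store_ocr_result("tenant-123", "doc-001", {"pages": [], "text": ""})
    assert await docs.get_document("tenant-123", "doc-001") is not None

# Run with a configured client, e.g.: asyncio.run(demo(StorageClient(minio_client)))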