Initial commit
Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled

Author: harkon
Date: 2025-10-11 08:41:36 +01:00
Commit: b324ff09ef
276 changed files with 55220 additions and 0 deletions

0
libs/__init__.py Normal file

123
libs/app_factory.py Normal file

@@ -0,0 +1,123 @@
"""Factory for creating FastAPI applications with consistent setup."""
# FILE: libs/app_factory.py
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from libs.config import BaseAppSettings, get_default_settings
from libs.observability import setup_observability
from libs.schemas import ErrorResponse
from libs.security import get_current_user, get_tenant_id
from libs.security.middleware import TrustedProxyMiddleware
def create_trusted_proxy_middleware(
internal_cidrs: list[str], disable_auth: bool = False
) -> TrustedProxyMiddleware:
"""Create a TrustedProxyMiddleware instance with the given internal CIDRs."""
# This is a factory function that will be called by FastAPI's add_middleware
# We return a partial function that creates the middleware
def middleware_factory(app: Any) -> TrustedProxyMiddleware:
return TrustedProxyMiddleware(app, internal_cidrs, disable_auth)
return middleware_factory # type: ignore
def create_app( # pylint: disable=too-many-arguments,too-many-positional-arguments
service_name: str,
title: str,
description: str,
version: str = "1.0.0",
settings_class: type[BaseAppSettings] = BaseAppSettings,
custom_settings: dict[str, Any] | None = None,
) -> tuple[FastAPI, BaseAppSettings]:
"""Create a FastAPI application with standard configuration"""
# Create settings
settings_kwargs = {"service_name": service_name}
if custom_settings:
settings_kwargs.update(custom_settings)
settings = get_default_settings(**settings_kwargs)
if settings_class != BaseAppSettings:
# Use custom settings class
settings = settings_class(**settings_kwargs) # type: ignore
# Create lifespan context manager
@asynccontextmanager
async def lifespan(
app: FastAPI,
) -> AsyncIterator[None]: # pylint: disable=unused-argument
# Startup
setup_observability(settings)
yield
# Shutdown
# Create FastAPI app
app = FastAPI(
title=title, description=description, version=version, lifespan=lifespan
)
# Add middleware
app.add_middleware(
TrustedProxyMiddleware,
internal_cidrs=settings.internal_cidrs,
disable_auth=getattr(settings, "disable_auth", False),
)
# Add exception handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(
request: Request, exc: HTTPException
) -> JSONResponse:
"""Handle HTTP exceptions with RFC7807 format"""
return JSONResponse(
status_code=exc.status_code,
content=ErrorResponse(
type=f"https://httpstatuses.com/{exc.status_code}",
title=exc.detail,
status=exc.status_code,
detail=exc.detail,
instance=str(request.url),
trace_id=getattr(request.state, "trace_id", None),
).model_dump(),
)
# Add health endpoints
@app.get("/healthz")
async def health_check() -> dict[str, str]:
"""Health check endpoint"""
return {
"status": "healthy",
"service": settings.service_name,
"version": version,
}
@app.get("/readyz")
async def readiness_check() -> dict[str, str]:
"""Readiness check endpoint"""
return {"status": "ready", "service": settings.service_name, "version": version}
@app.get("/livez")
async def liveness_check() -> dict[str, str]:
"""Liveness check endpoint"""
return {"status": "alive", "service": settings.service_name, "version": version}
return app, settings
# Dependency factories
def get_user_dependency() -> Any:
"""Get user dependency function"""
return get_current_user()
def get_tenant_dependency() -> Any:
"""Get tenant dependency function"""
return get_tenant_id()
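
For illustration, a minimal usage sketch of this factory, assuming a hypothetical service module; the service name, title, and route below are placeholders, not services from this commit:

```python
# Hypothetical entrypoint built on libs.app_factory.create_app (names are illustrative).
from libs.app_factory import create_app

app, settings = create_app(
    service_name="svc-example",  # placeholder, not a real service in this repo
    title="Example Service",
    description="Demonstrates the shared FastAPI factory.",
    version="0.1.0",
)

@app.get("/v1/example/ping")
async def ping() -> dict[str, str]:
    # Health endpoints, middleware and RFC7807 handlers come from the factory.
    return {"service": settings.service_name, "pong": "ok"}
```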

12
libs/calibration/__init__.py Normal file

@@ -0,0 +1,12 @@
"""Confidence calibration for ML models."""
from .calibrator import ConfidenceCalibrator
from .metrics import DEFAULT_CALIBRATORS, ConfidenceMetrics
from .multi_model import MultiModelCalibrator
__all__ = [
"ConfidenceCalibrator",
"MultiModelCalibrator",
"ConfidenceMetrics",
"DEFAULT_CALIBRATORS",
]

190
libs/calibration/calibrator.py Normal file

@@ -0,0 +1,190 @@
"""Confidence calibrator using various methods."""
import pickle
import numpy as np
import structlog
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression
logger = structlog.get_logger()
class ConfidenceCalibrator:
"""Calibrate confidence scores using various methods"""
def __init__(self, method: str = "temperature"):
"""
Initialize calibrator
Args:
method: Calibration method ('temperature', 'platt', 'isotonic')
"""
self.method = method
self.calibrator = None
self.temperature = 1.0
self.is_fitted = False
def fit(self, scores: list[float], labels: list[bool]) -> None:
"""
Fit calibration model
Args:
scores: Raw confidence scores (0-1)
labels: True labels (True/False for correct/incorrect)
"""
# Validate inputs
if len(scores) == 0 or len(labels) == 0:
raise ValueError("Scores and labels cannot be empty")
if len(scores) != len(labels):
raise ValueError("Scores and labels must have the same length")
scores_array = np.array(scores).reshape(-1, 1)
labels_array = np.array(labels, dtype=int)
if self.method == "temperature":
self._fit_temperature_scaling(scores_array, labels_array)
elif self.method == "platt":
self._fit_platt_scaling(scores_array, labels_array)
elif self.method == "isotonic":
self._fit_isotonic_regression(scores_array, labels_array)
else:
raise ValueError(f"Unknown calibration method: {self.method}")
self.is_fitted = True
logger.info("Calibrator fitted", method=self.method)
def _fit_temperature_scaling(self, scores: np.ndarray, labels: np.ndarray) -> None:
"""Fit temperature scaling parameter"""
# pylint: disable=import-outside-toplevel
from scipy.optimize import minimize_scalar
def negative_log_likelihood(temperature: float) -> float:
# Convert scores to logits
epsilon = 1e-7
scores_clipped = np.clip(scores.flatten(), epsilon, 1 - epsilon)
logits = np.log(scores_clipped / (1 - scores_clipped))
# Apply temperature scaling
calibrated_logits = logits / temperature
calibrated_probs = 1 / (1 + np.exp(-calibrated_logits))
# Calculate negative log likelihood
nll = -np.mean(
labels * np.log(calibrated_probs + epsilon)
+ (1 - labels) * np.log(1 - calibrated_probs + epsilon)
)
return float(nll)
# Find optimal temperature
result = minimize_scalar( # type: ignore
negative_log_likelihood,
bounds=(0.1, 10.0),
method="bounded", # fmt: skip # pyright: ignore[reportArgumentType]
)
self.temperature = result.x
logger.debug("Temperature scaling fitted", temperature=self.temperature)
def _fit_platt_scaling(self, scores: np.ndarray, labels: np.ndarray) -> None:
"""Fit Platt scaling (logistic regression)"""
# Convert scores to logits
epsilon = 1e-7
scores_clipped = np.clip(scores.flatten(), epsilon, 1 - epsilon)
logits = np.log(scores_clipped / (1 - scores_clipped)).reshape(-1, 1)
# Fit logistic regression
self.calibrator = LogisticRegression()
self.calibrator.fit(logits, labels) # type: ignore
logger.debug("Platt scaling fitted")
def _fit_isotonic_regression(self, scores: np.ndarray, labels: np.ndarray) -> None:
"""Fit isotonic regression"""
self.calibrator = IsotonicRegression(out_of_bounds="clip")
self.calibrator.fit(scores.flatten(), labels) # type: ignore
logger.debug("Isotonic regression fitted")
def calibrate(self, scores: list[float]) -> list[float]:
"""
Calibrate confidence scores
Args:
scores: Raw confidence scores
Returns:
Calibrated confidence scores
"""
if not self.is_fitted:
logger.warning("Calibrator not fitted, returning original scores")
return scores
scores_array = np.array(scores)
if self.method == "temperature":
return self._calibrate_temperature(scores_array)
if self.method == "platt":
return self._calibrate_platt(scores_array)
if self.method == "isotonic":
return self._calibrate_isotonic(scores_array)
return scores
def _calibrate_temperature(self, scores: np.ndarray) -> list[float]:
"""Apply temperature scaling"""
epsilon = 1e-7
scores_clipped = np.clip(scores, epsilon, 1 - epsilon)
# Convert to logits
logits = np.log(scores_clipped / (1 - scores_clipped))
# Apply temperature scaling
calibrated_logits = logits / self.temperature
calibrated_probs = 1 / (1 + np.exp(-calibrated_logits))
return calibrated_probs.tolist() # type: ignore
def _calibrate_platt(self, scores: np.ndarray) -> list[float]:
"""Apply Platt scaling"""
epsilon = 1e-7
scores_clipped = np.clip(scores, epsilon, 1 - epsilon)
# Convert to logits
logits = np.log(scores_clipped / (1 - scores_clipped)).reshape(-1, 1)
# Apply Platt scaling
calibrated_probs = self.calibrator.predict_proba(logits)[:, 1] # type: ignore
return calibrated_probs.tolist() # type: ignore
def _calibrate_isotonic(self, scores: np.ndarray) -> list[float]:
"""Apply isotonic regression"""
calibrated_probs = self.calibrator.predict(scores) # type: ignore
return calibrated_probs.tolist() # type: ignore
def save_model(self, filepath: str) -> None:
"""Save calibration model"""
model_data = {
"method": self.method,
"temperature": self.temperature,
"calibrator": self.calibrator,
"is_fitted": self.is_fitted,
}
with open(filepath, "wb") as f:
pickle.dump(model_data, f)
logger.info("Calibration model saved", filepath=filepath)
def load_model(self, filepath: str) -> None:
"""Load calibration model"""
with open(filepath, "rb") as f:
model_data = pickle.load(f)
self.method = model_data["method"]
self.temperature = model_data["temperature"]
self.calibrator = model_data["calibrator"]
self.is_fitted = model_data["is_fitted"]
logger.info("Calibration model loaded", filepath=filepath, method=self.method)

144
libs/calibration/metrics.py Normal file

@@ -0,0 +1,144 @@
"""Calibration metrics for evaluating confidence calibration."""
import numpy as np
class ConfidenceMetrics:
"""Calculate calibration metrics"""
@staticmethod
def expected_calibration_error(
scores: list[float], labels: list[bool], n_bins: int = 10
) -> float:
"""
Calculate Expected Calibration Error (ECE)
Args:
scores: Predicted confidence scores
labels: True labels
n_bins: Number of bins for calibration
Returns:
ECE value
"""
scores_array = np.array(scores)
labels_array = np.array(labels, dtype=int)
bin_boundaries = np.linspace(0, 1, n_bins + 1)
bin_lowers = bin_boundaries[:-1]
bin_uppers = bin_boundaries[1:]
ece = 0
for bin_lower, bin_upper in zip(bin_lowers, bin_uppers, strict=False):
# Find samples in this bin
in_bin = (scores_array > bin_lower) & (scores_array <= bin_upper)
prop_in_bin = in_bin.mean()
if prop_in_bin > 0:
# Calculate accuracy and confidence in this bin
accuracy_in_bin = labels_array[in_bin].mean()
avg_confidence_in_bin = scores_array[in_bin].mean()
# Add to ECE
ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
return ece
@staticmethod
def maximum_calibration_error(
scores: list[float], labels: list[bool], n_bins: int = 10
) -> float:
"""
Calculate Maximum Calibration Error (MCE)
Args:
scores: Predicted confidence scores
labels: True labels
n_bins: Number of bins for calibration
Returns:
MCE value
"""
scores_array = np.array(scores)
labels_array = np.array(labels, dtype=int)
bin_boundaries = np.linspace(0, 1, n_bins + 1)
bin_lowers = bin_boundaries[:-1]
bin_uppers = bin_boundaries[1:]
max_error = 0
for bin_lower, bin_upper in zip(bin_lowers, bin_uppers, strict=False):
# Find samples in this bin
in_bin = (scores_array > bin_lower) & (scores_array <= bin_upper)
if in_bin.sum() > 0:
# Calculate accuracy and confidence in this bin
accuracy_in_bin = labels_array[in_bin].mean()
avg_confidence_in_bin = scores_array[in_bin].mean()
# Update maximum error
error = np.abs(avg_confidence_in_bin - accuracy_in_bin)
max_error = max(max_error, error)
return max_error
@staticmethod
def reliability_diagram_data( # pylint: disable=too-many-locals
scores: list[float], labels: list[bool], n_bins: int = 10
) -> dict[str, list[float]]:
"""
Generate data for reliability diagram
Args:
scores: Predicted confidence scores
labels: True labels
n_bins: Number of bins
Returns:
Dictionary with bin data for plotting
"""
scores_array = np.array(scores)
labels_array = np.array(labels, dtype=int)
bin_boundaries = np.linspace(0, 1, n_bins + 1)
bin_lowers = bin_boundaries[:-1]
bin_uppers = bin_boundaries[1:]
bin_centers = []
bin_accuracies = []
bin_confidences = []
bin_counts = []
for bin_lower, bin_upper in zip(bin_lowers, bin_uppers, strict=False):
# Find samples in this bin
in_bin = (scores_array > bin_lower) & (scores_array <= bin_upper)
bin_count = in_bin.sum()
if bin_count > 0:
bin_center = (bin_lower + bin_upper) / 2
accuracy_in_bin = labels_array[in_bin].mean()
avg_confidence_in_bin = scores_array[in_bin].mean()
bin_centers.append(bin_center)
bin_accuracies.append(accuracy_in_bin)
bin_confidences.append(avg_confidence_in_bin)
bin_counts.append(bin_count)
return {
"bin_centers": bin_centers,
"bin_accuracies": bin_accuracies,
"bin_confidences": bin_confidences,
"bin_counts": bin_counts,
}
# Default calibrators for common tasks
DEFAULT_CALIBRATORS = {
"ocr_confidence": {"method": "temperature"},
"extraction_confidence": {"method": "platt"},
"rag_confidence": {"method": "isotonic"},
"calculation_confidence": {"method": "temperature"},
"overall_confidence": {"method": "platt"},
}
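
A small worked example of the metrics above on synthetic predictions (numbers are illustrative):

```python
# Illustrative only: ECE/MCE on a handful of synthetic predictions.
from libs.calibration import ConfidenceMetrics

scores = [0.9, 0.8, 0.75, 0.6, 0.55, 0.3]
labels = [True, True, False, True, False, False]

ece = ConfidenceMetrics.expected_calibration_error(scores, labels, n_bins=5)
mce = ConfidenceMetrics.maximum_calibration_error(scores, labels, n_bins=5)
print(f"ECE={ece:.3f}  MCE={mce:.3f}")
```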

85
libs/calibration/multi_model.py Normal file

@@ -0,0 +1,85 @@
"""Multi-model calibrator for handling multiple models/tasks."""
import glob
import os
import structlog
from .calibrator import ConfidenceCalibrator
logger = structlog.get_logger()
class MultiModelCalibrator:
"""Calibrate confidence scores for multiple models/tasks"""
def __init__(self) -> None:
self.calibrators: dict[str, ConfidenceCalibrator] = {}
def add_calibrator(self, model_name: str, method: str = "temperature") -> None:
"""Add calibrator for a specific model"""
self.calibrators[model_name] = ConfidenceCalibrator(method)
logger.info("Added calibrator", model=model_name, method=method)
def fit(self, model_name: str, scores: list[float], labels: list[bool]) -> None:
"""Fit calibrator for specific model"""
if model_name not in self.calibrators:
self.add_calibrator(model_name)
self.calibrators[model_name].fit(scores, labels)
def calibrate(self, model_name: str, scores: list[float]) -> list[float]:
"""Calibrate scores for specific model"""
if model_name not in self.calibrators:
logger.warning("No calibrator for model", model=model_name)
return scores
return self.calibrators[model_name].calibrate(scores)
def save_all(self, directory: str) -> None:
"""Save all calibrators"""
os.makedirs(directory, exist_ok=True)
for model_name, calibrator in self.calibrators.items():
filepath = os.path.join(directory, f"{model_name}_calibrator.pkl")
calibrator.save_model(filepath)
def load_all(self, directory: str) -> None:
"""Load all calibrators from directory"""
pattern = os.path.join(directory, "*_calibrator.pkl")
for filepath in glob.glob(pattern):
filename = os.path.basename(filepath)
model_name = filename.replace("_calibrator.pkl", "")
calibrator = ConfidenceCalibrator()
calibrator.load_model(filepath)
self.calibrators[model_name] = calibrator
def save_models(self, directory: str) -> None:
"""Save all calibrators (alias for save_all)"""
self.save_all(directory)
def load_models(self, directory: str) -> None:
"""Load all calibrators from directory (alias for load_all)"""
self.load_all(directory)
def get_model_names(self) -> list[str]:
"""Get list of model names"""
return list(self.calibrators.keys())
def has_model(self, model_name: str) -> bool:
"""Check if model exists"""
return model_name in self.calibrators
def is_fitted(self, model_name: str) -> bool:
"""Check if model is fitted"""
if model_name not in self.calibrators:
raise ValueError(f"Model '{model_name}' not found")
return self.calibrators[model_name].is_fitted
def remove_calibrator(self, model_name: str) -> None:
"""Remove calibrator for specific model"""
if model_name not in self.calibrators:
raise ValueError(f"Model '{model_name}' not found")
del self.calibrators[model_name]
logger.info("Removed calibrator", model=model_name)

555
libs/config.py Normal file

@@ -0,0 +1,555 @@
# ROLE
You are a **Senior Platform Engineer + Backend Lead** generating **production code** and **ops assets** for a microservice suite that powers an accounting Knowledge Graph + Vector RAG platform. Authentication/authorization are centralized at the **edge via Traefik + Authentik** (ForwardAuth). **Services are trust-bound** to Traefik and consume user/role claims via forwarded headers/JWT.
# MISSION
Produce fully working code for **all application services** (FastAPI + Python 3.12) with:
- Solid domain models, Pydantic v2 schemas, type hints, strict mypy, ruff lint.
- Opentelemetry tracing, Prometheus metrics, structured logging.
- Vault-backed secrets, MinIO S3 client, Qdrant client, Neo4j driver, Postgres (SQLAlchemy), Redis.
- Eventing (Kafka or SQS/SNS behind an interface).
- Deterministic data contracts, end-to-end tests, Dockerfiles, Compose, CI for Gitea.
- Traefik labels + Authentik Outpost integration for every exposed route.
- Zero PII in vectors (Qdrant), evidence-based lineage in KG, and bitemporal writes.
# GLOBAL CONSTRAINTS (APPLY TO ALL SERVICES)
- **Language & Runtime:** Python **3.12**.
- **Frameworks:** FastAPI, Pydantic v2, SQLAlchemy 2, httpx, aiokafka or boto3 (pluggable), redis-py, opentelemetry-instrumentation-fastapi, prometheus-fastapi-instrumentator.
- **Config:** `pydantic-settings` with `.env` overlay. Provide `Settings` class per service.
- **Secrets:** HashiCorp **Vault** (AppRole/JWT). Use Vault Transit to **envelope-encrypt** sensitive fields before persistence (helpers provided in `libs/security.py`).
- **Auth:** No OIDC in services. Add `TrustedProxyMiddleware`:
- Reject if request not from internal network (configurable CIDR).
- Require headers set by Traefik+Authentik (`X-Authenticated-User`, `X-Authenticated-Email`, `X-Authenticated-Groups`, `Authorization: Bearer <jwt>`).
- Parse groups → `roles` list on `request.state`.
- **Observability:**
- OpenTelemetry (traceparent propagation), span attrs (service, route, user, tenant).
- Prometheus metrics endpoint `/metrics` protected by internal network check.
- Structured JSON logs (timestamp, level, svc, trace_id, msg) via `structlog`.
- **Errors:** Global exception handler RFC7807 Problem+JSON (`type`, `title`, `status`, `detail`, `instance`, `trace_id`).
- **Testing:** `pytest`, `pytest-asyncio`, `hypothesis` (property tests for calculators), coverage ≥ 90% per service.
- **Static:** `ruff`, `mypy --strict`, `bandit`, `safety`, `licensecheck`.
- **Perf:** Each service exposes `/healthz`, `/readyz`, `/livez`; cold start < 500ms; p95 endpoint < 250ms (local).
- **Containers:** Distroless or slim images; non-root user; read-only FS; `/tmp` mounted for OCR where needed.
- **Docs:** OpenAPI JSON + ReDoc; MkDocs site with service READMEs.
# SHARED LIBS (GENERATE ONCE, REUSE)
Create `libs/` used by all services:
- `libs/config.py` base `Settings`, env parsing, Vault client factory, MinIO client factory, Qdrant client factory, Neo4j driver factory, Redis factory, Kafka/SQS client factory.
- `libs/security.py` Vault Transit helpers (`encrypt_field`, `decrypt_field`), header parsing, internal-CIDR validator.
- `libs/observability.py` otel init, prometheus instrumentor, logging config.
- `libs/events.py` abstract `EventBus` with `publish(topic, payload: dict)`, `subscribe(topic, handler)`. Two impls: Kafka (`aiokafka`) and SQS/SNS (`boto3`).
- `libs/schemas.py` **canonical Pydantic models** shared across services (Document, Evidence, IncomeItem, etc.) mirroring the ontology schemas. Include JSONSchema exports.
- `libs/storage.py` S3/MinIO helpers (bucket ensure, put/get, presigned).
- `libs/neo.py` Neo4j session helpers, Cypher runner with retry, SHACL validator invoker (pySHACL on exported RDF).
- `libs/rag.py` Qdrant collections CRUD, hybrid search (dense+sparse), rerank wrapper, de-identification utilities (regex + NER; hash placeholders).
- `libs/forms.py` PDF AcroForm fill via `pdfrw` with overlay fallback via `reportlab`.
- `libs/calibration.py` `calibrated_confidence(raw_score, method="temperature_scaling", params=...)`.
# EVENT TOPICS (STANDARDIZE)
- `doc.ingested`, `doc.ocr_ready`, `doc.extracted`, `kg.upserted`, `rag.indexed`, `calc.schedule_ready`, `form.filled`, `hmrc.submitted`, `review.requested`, `review.completed`, `firm.sync.completed`
Each payload MUST include: `event_id (ulid)`, `occurred_at (iso)`, `actor`, `tenant_id`, `trace_id`, `schema_version`, and a `data` object (service-specific).
# TRUST HEADERS FROM TRAEFIK + AUTHENTIK (USE EXACT KEYS)
- `X-Authenticated-User` (string)
- `X-Authenticated-Email` (string)
- `X-Authenticated-Groups` (comma-separated)
- `Authorization` (`Bearer <jwt>` from Authentik)
Reject any request missing these (except `/healthz|/readyz|/livez|/metrics` from internal CIDR).
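For illustration, a hedged sketch of the check this implies; the real middleware lives in `libs/security/middleware.py` (not shown in this diff), and everything below other than the header names and exempt paths is a placeholder:

```python
# Sketch only: the production check is TrustedProxyMiddleware in libs/security/middleware.py.
from fastapi import HTTPException, Request

REQUIRED_HEADERS = (
    "X-Authenticated-User",
    "X-Authenticated-Email",
    "X-Authenticated-Groups",
    "Authorization",
)
EXEMPT_PATHS = {"/healthz", "/readyz", "/livez", "/metrics"}

def assert_trusted_request(request: Request) -> list[str]:
    """Reject requests missing the Traefik+Authentik headers; return parsed roles."""
    if request.url.path in EXEMPT_PATHS:
        return []
    missing = [h for h in REQUIRED_HEADERS if h not in request.headers]
    if missing:
        raise HTTPException(status_code=401, detail=f"Missing headers: {missing}")
    groups = request.headers["X-Authenticated-Groups"]  # comma-separated per the contract
    return [g.strip() for g in groups.split(",") if g.strip()]
```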
---
## SERVICES TO IMPLEMENT (CODE FOR EACH)
### 1) `svc-ingestion`
**Purpose:** Accept uploads or URLs, checksum, store to MinIO, emit `doc.ingested`.
**Endpoints:**
- `POST /v1/ingest/upload` (multipart file, metadata: `tenant_id`, `kind`, `source`) → `{doc_id, s3_url, checksum}`
- `POST /v1/ingest/url` (json: `{url, kind, tenant_id}`) → downloads to MinIO
- `GET /v1/docs/{doc_id}` → metadata
**Logic:**
- Compute SHA256, dedupe by checksum; MinIO path `tenants/{tenant_id}/raw/{doc_id}.pdf`.
- Store metadata in Postgres table `ingest_documents` (alembic migrations).
- Publish `doc.ingested` with `{doc_id, bucket, key, pages?, mime}`.
**Env:** `S3_BUCKET_RAW`, `MINIO_*`, `DB_URL`.
**Traefik labels:** route `/ingest/*`.
---
### 2) `svc-rpa`
**Purpose:** Scheduled RPA pulls from firm/client portals via Playwright.
**Tasks:**
- Playwright login flows (credentials from Vault), 2FA via Authentik OAuth device or OTP secret in Vault.
- Download statements/invoices; hand off to `svc-ingestion` via internal POST.
- Prefect flows: `pull_portal_X()`, `pull_portal_Y()` with schedules.
**Endpoints:**
- `POST /v1/rpa/run/{connector}` (manual trigger)
- `GET /v1/rpa/status/{run_id}`
**Env:** `VAULT_ADDR`, `VAULT_ROLE_ID`, `VAULT_SECRET_ID`.
---
### 3) `svc-ocr`
**Purpose:** OCR & layout extraction.
**Pipeline:**
- Pull object from MinIO, detect rotation/de-skew (`opencv-python`), split pages (`pymupdf`), OCR (`pytesseract`) or bypass if text layer present (`pdfplumber`).
- Output per-page text + **bbox** for lines/words.
- Write JSON to MinIO `tenants/{tenant_id}/ocr/{doc_id}.json` and emit `doc.ocr_ready`.
**Endpoints:**
- `POST /v1/ocr/{doc_id}` (idempotent trigger)
- `GET /v1/ocr/{doc_id}` (fetch OCR JSON)
**Env:** `TESSERACT_LANGS`, `S3_BUCKET_EVIDENCE`.
---
### 4) `svc-extract`
**Purpose:** Classify docs and extract KV + tables into **schema-constrained JSON** (with bbox/page).
**Endpoints:**
- `POST /v1/extract/{doc_id}` body: `{strategy: "llm|rules|hybrid"}`
- `GET /v1/extract/{doc_id}` structured JSON
**Implementation:**
- Use prompt files in `prompts/`: `doc_classify.txt`, `kv_extract.txt`, `table_extract.txt`.
- **Validator loop**: run LLM → validate against JSONSchema → retry with error messages up to N times.
- Return Pydantic models from `libs/schemas.py`.
- Emit `doc.extracted`.
**Env:** `LLM_ENGINE`, `TEMPERATURE`, `MAX_TOKENS`.
---
### 5) `svc-normalize-map`
**Purpose:** Normalize & map extracted data to KG.
**Logic:**
- Currency normalization (ECB or static fx table), dates, UK tax year/basis period inference.
- Entity resolution (blocking + fuzzy).
- Generate nodes/edges (+ `Evidence` with doc_id/page/bbox/text_hash).
- Use `libs/neo.py` to write with **bitemporal** fields; run **SHACL** validator; on violation, queue `review.requested`.
- Emit `kg.upserted`.
**Endpoints:**
- `POST /v1/map/{doc_id}`
- `GET /v1/map/{doc_id}/preview` (diff view, to be used by UI)
**Env:** `NEO4J_*`.
---
### 6) `svc-kg`
**Purpose:** Graph façade + RDF/SHACL utility.
**Endpoints:**
- `GET /v1/kg/nodes/{label}/{id}`
- `POST /v1/kg/cypher` (admin-gated inline query; must check `admin` role)
- `POST /v1/kg/export/rdf` (returns RDF for SHACL)
- `POST /v1/kg/validate` (run pySHACL against `schemas/shapes.ttl`)
- `GET /v1/kg/lineage/{node_id}` (traverse `DERIVED_FROM` Evidence)
**Env:** `NEO4J_*`.
---
### 7) `svc-rag-indexer`
**Purpose:** Build Qdrant indices (firm knowledge, legislation, best practices, glossary).
**Workflow:**
- Load sources (filesystem, URLs, Firm DMS via `svc-firm-connectors`).
- **De-identify PII** (regex + NER), replace with placeholders; store mapping only in Postgres.
- Chunk (layout-aware) per `retrieval/chunking.yaml`.
- Compute **dense** embeddings (e.g., `bge-small-en-v1.5`) and **sparse** (Qdrant sparse).
- Upsert to Qdrant with payload `{jurisdiction, tax_years[], topic_tags[], version, pii_free: true, doc_id/section_id/url}`.
- Emit `rag.indexed`.
**Endpoints:**
- `POST /v1/index/run`
- `GET /v1/index/status/{run_id}`
**Env:** `QDRANT_URL`, `RAG_EMBEDDING_MODEL`, `RAG_RERANKER_MODEL`.
---
### 8) `svc-rag-retriever`
**Purpose:** Hybrid search + KG fusion with rerank and calibrated confidence.
**Endpoint:**
- `POST /v1/rag/search` `{query, tax_year?, jurisdiction?, k?}`
```
{
"chunks": [...],
"citations": [{doc_id|url, section_id?, page?, bbox?}],
"kg_hints": [{rule_id, formula_id, node_ids[]}],
"calibrated_confidence": 0.0-1.0
}
```
**Implementation:**
- Hybrid score: `alpha * dense + beta * sparse`; rerank top-K via cross-encoder; **KG fusion** (boost chunks citing Rules/Calculations relevant to schedule).
- Use `libs/calibration.py` to expose calibrated confidence.
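
A minimal sketch of the hybrid score defined above; the `alpha`/`beta` weights and the candidate structure are illustrative, and reranking / KG fusion are deliberately left out:

```python
# Sketch only: dense/sparse fusion before cross-encoder rerank and KG boosting.
# alpha and beta are illustrative weights, not values taken from this repo.
def hybrid_score(dense: float, sparse: float, alpha: float = 0.7, beta: float = 0.3) -> float:
    return alpha * dense + beta * sparse

candidates = [
    {"id": "c1", "dense": 0.82, "sparse": 0.40},
    {"id": "c2", "dense": 0.65, "sparse": 0.90},
]
ranked = sorted(candidates, key=lambda c: hybrid_score(c["dense"], c["sparse"]), reverse=True)
print([c["id"] for c in ranked])  # top-K then goes to the reranker
```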
---
### 9) `svc-reason`
**Purpose:** Deterministic calculators + materializers (UK SA).
**Endpoints:**
- `POST /v1/reason/compute_schedule` `{tax_year, taxpayer_id, schedule_id}`
- `GET /v1/reason/explain/{schedule_id}` rationale & lineage paths
**Implementation:**
- Pure functions for: employment, self-employment, property (FHL, 20% interest credit), dividends/interest, allowances, NIC (Class 2/4), HICBC, student loans (Plans 1/2/4/5, PGL).
- **Deterministic order** as defined; rounding per `FormBox.rounding_rule`.
- Use Cypher from `kg/reasoning/schedule_queries.cypher` to materialize box values; attach `DERIVED_FROM` evidence.
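
For illustration, one of the pure calculator functions listed above, sketched for the personal allowance taper only; the £12,570 allowance and £100,000 threshold are the published 2024-25 figures but should be treated here as assumptions, and the real calculators would source such constants from the KG:

```python
# Sketch only: personal allowance taper (£1 of allowance lost per £2 of income over £100,000).
# Constants are assumed 2024-25 values; production code should source them from the KG.
from decimal import Decimal

PERSONAL_ALLOWANCE = Decimal("12570")
TAPER_THRESHOLD = Decimal("100000")

def tapered_personal_allowance(adjusted_net_income: Decimal) -> Decimal:
    if adjusted_net_income <= TAPER_THRESHOLD:
        return PERSONAL_ALLOWANCE
    reduction = (adjusted_net_income - TAPER_THRESHOLD) / 2
    return max(Decimal("0"), PERSONAL_ALLOWANCE - reduction)

assert tapered_personal_allowance(Decimal("50000")) == Decimal("12570")
assert tapered_personal_allowance(Decimal("125140")) == Decimal("0")  # fully tapered
```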
---
### 10) `svc-forms`
**Purpose:** Fill PDFs and assemble evidence bundles.
**Endpoints:**
- `POST /v1/forms/fill` `{tax_year, taxpayer_id, form_id}` returns PDF (binary)
- `POST /v1/forms/evidence_pack` `{scope}` ZIP + manifest + signed hashes (sha256)
**Implementation:**
- `pdfrw` for AcroForm; overlay with ReportLab if needed.
- Manifest includes `doc_id/page/bbox/text_hash` for every numeric field.
---
### 11) `svc-hmrc`
**Purpose:** HMRC submitter (stub|sandbox|live).
**Endpoints:**
- `POST /v1/hmrc/submit` `{tax_year, taxpayer_id, dry_run}` → `{status, submission_id?, errors[]}`
- `GET /v1/hmrc/submissions/{id}`
**Implementation:**
- Rate limits, retries/backoff, signed audit log; environment toggle.
---
### 12) `svc-firm-connectors`
**Purpose:** Read-only connectors to Firm Databases (Practice Mgmt, DMS).
**Endpoints:**
- `POST /v1/firm/sync` `{since?}` → `{objects_synced, errors[]}`
- `GET /v1/firm/objects` (paged)
**Implementation:**
- Data contracts in `config/firm_contracts/`; mappers → Secure Client Data Store (Postgres) with lineage columns (`source`, `source_id`, `synced_at`).
---
### 13) `ui-review` (outline only)
- Next.js (SSO handled by Traefik+Authentik), shows extracted fields + evidence snippets; POST overrides to `svc-extract`/`svc-normalize-map`.
---
## DATA CONTRACTS (ESSENTIAL EXAMPLES)
**Event: `doc.ingested`**
```json
{
"event_id": "01J...ULID",
"occurred_at": "2025-09-13T08:00:00Z",
"actor": "svc-ingestion",
"tenant_id": "t_123",
"trace_id": "abc-123",
"schema_version": "1.0",
"data": {
"doc_id": "d_abc",
"bucket": "raw",
"key": "tenants/t_123/raw/d_abc.pdf",
"checksum": "sha256:...",
"kind": "bank_statement",
"mime": "application/pdf",
"pages": 12
}
}
```
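For illustration, how a producer might emit this payload through the shared `EventBus` interface (`publish(topic, payload: dict)` from `libs/events.py`); the async call and the helper below are assumptions, and the field values mirror the example above:

```python
# Sketch only: publish the doc.ingested contract via the EventBus abstraction.
# Assumes an async publish(topic, payload) as the aiokafka-backed implementation suggests.
from datetime import datetime, timezone
from typing import Any

def build_doc_ingested(doc_id: str, tenant_id: str, key: str) -> dict[str, Any]:
    return {
        "event_id": "01JEXAMPLEULID0000000000",  # illustrative placeholder, not a real ULID
        "occurred_at": datetime.now(timezone.utc).isoformat(),
        "actor": "svc-ingestion",
        "tenant_id": tenant_id,
        "trace_id": "abc-123",  # would come from the active span in practice
        "schema_version": "1.0",
        "data": {"doc_id": doc_id, "bucket": "raw", "key": key},
    }

async def emit(event_bus: Any, payload: dict[str, Any]) -> None:
    await event_bus.publish("doc.ingested", payload)
```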
**RAG search response shape**
```json
{
"chunks": [
{
"id": "c1",
"text": "...",
"score": 0.78,
"payload": {
"jurisdiction": "UK",
"tax_years": ["2024-25"],
"topic_tags": ["FHL"],
"pii_free": true
}
}
],
"citations": [
{ "doc_id": "leg-ITA2007", "section_id": "s272A", "url": "https://..." }
],
"kg_hints": [
{
"rule_id": "UK.FHL.Qual",
"formula_id": "FHL_Test_v1",
"node_ids": ["n123", "n456"]
}
],
"calibrated_confidence": 0.81
}
```
---
## PERSISTENCE SCHEMAS (POSTGRES; ALEMBIC)
- `ingest_documents(id pk, tenant_id, doc_id, kind, checksum, bucket, key, mime, pages, created_at)`
- `firm_objects(id pk, tenant_id, source, source_id, type, payload jsonb, synced_at)`
- Qdrant PII mapping table (if absolutely needed): `pii_links(id pk, placeholder_hash, client_id, created_at)` **encrypt with Vault Transit**; do NOT store raw values.
---
## TRAEFIK + AUTHENTIK (COMPOSE LABELS PER SERVICE)
For every service container in `infra/compose/docker-compose.local.yml`, add labels:
```
- "traefik.enable=true"
- "traefik.http.routers.svc-extract.rule=Host(`api.local`) && PathPrefix(`/extract`)"
- "traefik.http.routers.svc-extract.entrypoints=websecure"
- "traefik.http.routers.svc-extract.tls=true"
- "traefik.http.routers.svc-extract.middlewares=authentik-forwardauth,rate-limit"
- "traefik.http.services.svc-extract.loadbalancer.server.port=8000"
```
Use the shared dynamic file `traefik-dynamic.yml` with `authentik-forwardauth` and `rate-limit` middlewares.
---
## OUTPUT FORMAT (STRICT)
Implement a **multi-file codebase** as fenced blocks, EXACTLY in this order:
```txt
# FILE: libs/config.py
# factories for Vault/MinIO/Qdrant/Neo4j/Redis/EventBus, Settings base
...
```
```txt
# FILE: libs/security.py
# Vault Transit helpers, header parsing, internal CIDR checks, middleware
...
```
```txt
# FILE: libs/observability.py
# otel init, prometheus, structlog
...
```
```txt
# FILE: libs/events.py
# EventBus abstraction with Kafka and SQS/SNS impls
...
```
```txt
# FILE: libs/schemas.py
# Shared Pydantic models mirroring ontology entities
...
```
```txt
# FILE: apps/svc-ingestion/main.py
# FastAPI app, endpoints, MinIO write, Postgres, publish doc.ingested
...
```
```txt
# FILE: apps/svc-rpa/main.py
# Playwright flows, Prefect tasks, triggers
...
```
```txt
# FILE: apps/svc-ocr/main.py
# OCR pipeline, endpoints
...
```
```txt
# FILE: apps/svc-extract/main.py
# Classifier + extractors with validator loop
...
```
```txt
# FILE: apps/svc-normalize-map/main.py
# normalization, entity resolution, KG mapping, SHACL validation call
...
```
```txt
# FILE: apps/svc-kg/main.py
# KG façade, RDF export, SHACL validate, lineage traversal
...
```
```txt
# FILE: apps/svc-rag-indexer/main.py
# chunk/de-id/embed/upsert to Qdrant
...
```
```txt
# FILE: apps/svc-rag-retriever/main.py
# hybrid retrieval + rerank + KG fusion
...
```
```txt
# FILE: apps/svc-reason/main.py
# deterministic calculators, schedule compute/explain
...
```
```txt
# FILE: apps/svc-forms/main.py
# PDF fill + evidence pack
...
```
```txt
# FILE: apps/svc-hmrc/main.py
# submit stub|sandbox|live with audit + retries
...
```
```txt
# FILE: apps/svc-firm-connectors/main.py
# connectors to practice mgmt & DMS, sync to Postgres
...
```
```txt
# FILE: infra/compose/docker-compose.local.yml
# Traefik, Authentik, Vault, MinIO, Qdrant, Neo4j, Postgres, Redis, Prom+Grafana, Loki, Unleash, all services
...
```
```txt
# FILE: infra/compose/traefik.yml
# static Traefik config
...
```
```txt
# FILE: infra/compose/traefik-dynamic.yml
# forwardAuth middleware + routers/services
...
```
```txt
# FILE: .gitea/workflows/ci.yml
# lint->test->build->scan->push->deploy
...
```
```txt
# FILE: Makefile
# bootstrap, run, test, lint, build, deploy, format, seed
...
```
```txt
# FILE: tests/e2e/test_happy_path.py
# end-to-end: ingest -> ocr -> extract -> map -> compute -> fill -> (stub) submit
...
```
```txt
# FILE: tests/unit/test_calculators.py
# boundary tests for UK SA logic (NIC, HICBC, PA taper, FHL)
...
```
```txt
# FILE: README.md
# how to run locally with docker-compose, Authentik setup, Traefik certs
...
```
## DEFINITION OF DONE
- `docker compose up` brings the full stack up; SSO via Authentik; routes secured via Traefik ForwardAuth.
- Running `pytest` yields ≥ 90% coverage; `make e2e` passes the ingest → submit stub flow.
- All services expose `/healthz|/readyz|/livez|/metrics`; OpenAPI at `/docs`.
- No PII stored in Qdrant; vectors carry `pii_free=true`.
- KG writes are SHACL-validated; violations produce `review.requested` events.
- Evidence lineage is present for every numeric box value.
- Gitea pipeline passes: lint, test, build, scan, push, deploy.
# START
Generate the full codebase and configs in the **exact file blocks and order** specified above.

41
libs/config/__init__.py Normal file

@@ -0,0 +1,41 @@
"""Configuration management and client factories."""
from .factories import (
EventBusFactory,
MinIOClientFactory,
Neo4jDriverFactory,
QdrantClientFactory,
RedisClientFactory,
VaultClientFactory,
)
from .settings import BaseAppSettings
from .utils import (
create_event_bus,
create_minio_client,
create_neo4j_client,
create_qdrant_client,
create_redis_client,
create_vault_client,
get_default_settings,
get_settings,
init_settings,
)
__all__ = [
"BaseAppSettings",
"VaultClientFactory",
"MinIOClientFactory",
"QdrantClientFactory",
"Neo4jDriverFactory",
"RedisClientFactory",
"EventBusFactory",
"get_settings",
"init_settings",
"create_vault_client",
"create_minio_client",
"create_qdrant_client",
"create_neo4j_client",
"create_redis_client",
"create_event_bus",
"get_default_settings",
]

122
libs/config/factories.py Normal file

@@ -0,0 +1,122 @@
"""Client factories for various services."""
from typing import Any
import boto3 # type: ignore
import hvac
import redis.asyncio as redis
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer # type: ignore
from minio import Minio
from neo4j import GraphDatabase
from qdrant_client import QdrantClient
from .settings import BaseAppSettings
class VaultClientFactory: # pylint: disable=too-few-public-methods
"""Factory for creating Vault clients"""
@staticmethod
def create_client(settings: BaseAppSettings) -> hvac.Client:
"""Create authenticated Vault client"""
client = hvac.Client(url=settings.vault_addr)
if settings.vault_token:
# Development mode with token
client.token = settings.vault_token
elif settings.vault_role_id and settings.vault_secret_id:
# Production mode with AppRole
try:
auth_response = client.auth.approle.login(
role_id=settings.vault_role_id, secret_id=settings.vault_secret_id
)
client.token = auth_response["auth"]["client_token"]
except Exception as e:
raise ValueError("Failed to authenticate with Vault") from e
else:
raise ValueError(
"Either vault_token or vault_role_id/vault_secret_id must be provided"
)
if not client.is_authenticated():
raise ValueError("Failed to authenticate with Vault")
return client
class MinIOClientFactory: # pylint: disable=too-few-public-methods
"""Factory for creating MinIO clients"""
@staticmethod
def create_client(settings: BaseAppSettings) -> Minio:
"""Create MinIO client"""
return Minio(
endpoint=settings.minio_endpoint,
access_key=settings.minio_access_key,
secret_key=settings.minio_secret_key,
secure=settings.minio_secure,
)
class QdrantClientFactory: # pylint: disable=too-few-public-methods
"""Factory for creating Qdrant clients"""
@staticmethod
def create_client(settings: BaseAppSettings) -> QdrantClient:
"""Create Qdrant client"""
return QdrantClient(url=settings.qdrant_url, api_key=settings.qdrant_api_key)
class Neo4jDriverFactory: # pylint: disable=too-few-public-methods
"""Factory for creating Neo4j drivers"""
@staticmethod
def create_driver(settings: BaseAppSettings) -> Any:
"""Create Neo4j driver"""
return GraphDatabase.driver(
settings.neo4j_uri, auth=(settings.neo4j_user, settings.neo4j_password)
)
class RedisClientFactory: # pylint: disable=too-few-public-methods
"""Factory for creating Redis clients"""
@staticmethod
async def create_client(settings: BaseAppSettings) -> "redis.Redis[str]":
"""Create Redis client"""
return redis.from_url(
settings.redis_url, encoding="utf-8", decode_responses=True
)
class EventBusFactory:
"""Factory for creating event bus clients"""
@staticmethod
def create_kafka_producer(settings: BaseAppSettings) -> AIOKafkaProducer:
"""Create Kafka producer"""
return AIOKafkaProducer(
bootstrap_servers=settings.kafka_bootstrap_servers,
value_serializer=lambda v: v.encode("utf-8") if isinstance(v, str) else v,
)
@staticmethod
def create_kafka_consumer(
settings: BaseAppSettings, topics: list[str]
) -> AIOKafkaConsumer:
"""Create Kafka consumer"""
return AIOKafkaConsumer(
*topics,
bootstrap_servers=settings.kafka_bootstrap_servers,
value_deserializer=lambda m: m.decode("utf-8") if m else None,
)
@staticmethod
def create_sqs_client(settings: BaseAppSettings) -> Any:
"""Create SQS client"""
return boto3.client("sqs", region_name=settings.aws_region)
@staticmethod
def create_sns_client(settings: BaseAppSettings) -> Any:
"""Create SNS client"""
return boto3.client("sns", region_name=settings.aws_region)

113
libs/config/settings.py Normal file

@@ -0,0 +1,113 @@
"""Base settings class for all services."""
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class BaseAppSettings(BaseSettings):
"""Base settings class for all services"""
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"
)
# Service identification
service_name: str = Field(default="default-service", description="Service name")
service_version: str = Field(default="1.0.0", description="Service version")
# Network and security
host: str = Field(default="0.0.0.0", description="Service host")
port: int = Field(default=8000, description="Service port")
internal_cidrs: list[str] = Field(
default=["10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"],
description="Internal network CIDRs",
)
# Development settings
dev_mode: bool = Field(
default=False,
description="Enable development mode (disables auth)",
validation_alias="DEV_MODE",
)
disable_auth: bool = Field(
default=False,
description="Disable authentication middleware",
validation_alias="DISABLE_AUTH",
)
# Vault configuration
vault_addr: str = Field(
default="http://vault:8200", description="Vault server address"
)
vault_role_id: str | None = Field(default=None, description="Vault AppRole role ID")
vault_secret_id: str | None = Field(
default=None, description="Vault AppRole secret ID"
)
vault_token: str | None = Field(default=None, description="Vault token (dev only)")
vault_mount_point: str = Field(
default="transit", description="Vault transit mount point"
)
# Database URLs
postgres_url: str = Field(
default="postgresql://user:pass@postgres:5432/taxagent",
description="PostgreSQL connection URL",
)
neo4j_uri: str = Field(
default="bolt://neo4j:7687", description="Neo4j connection URI"
)
neo4j_user: str = Field(default="neo4j", description="Neo4j username")
neo4j_password: str = Field(default="password", description="Neo4j password")
redis_url: str = Field(
default="redis://redis:6379", description="Redis connection URL"
)
# Object storage
minio_endpoint: str = Field(default="minio:9000", description="MinIO endpoint")
minio_access_key: str = Field(default="minioadmin", description="MinIO access key")
minio_secret_key: str = Field(default="minioadmin", description="MinIO secret key")
minio_secure: bool = Field(default=False, description="Use HTTPS for MinIO")
# Vector database
qdrant_url: str = Field(
default="http://qdrant:6333", description="Qdrant server URL"
)
qdrant_api_key: str | None = Field(default=None, description="Qdrant API key")
# Event bus configuration
event_bus_type: str = Field(
default="nats", description="Event bus type: nats, kafka, sqs, or memory"
)
# NATS configuration
nats_servers: str = Field(
default="nats://localhost:4222",
description="NATS server URLs (comma-separated)",
)
nats_stream_name: str = Field(
default="TAX_AGENT_EVENTS", description="NATS JetStream stream name"
)
nats_consumer_group: str = Field(
default="tax-agent", description="NATS consumer group name"
)
# Kafka configuration (legacy)
kafka_bootstrap_servers: str = Field(
default="localhost:9092", description="Kafka bootstrap servers"
)
# AWS configuration
aws_region: str = Field(default="us-east-1", description="AWS region for SQS/SNS")
# Observability
otel_service_name: str | None = Field(
default=None, description="OpenTelemetry service name"
)
otel_exporter_endpoint: str | None = Field(
default=None, description="OTEL exporter endpoint"
)
log_level: str = Field(default="INFO", description="Log level")
# Performance
max_workers: int = Field(default=4, description="Maximum worker threads")
request_timeout: int = Field(default=30, description="Request timeout in seconds")
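
A small sketch of how these settings resolve, assuming standard pydantic-settings behaviour with the `.env` overlay configured above; the service name is a placeholder:

```python
# Sketch only: environment variables (and .env) override the defaults declared above.
import os

from libs.config.settings import BaseAppSettings

os.environ["DISABLE_AUTH"] = "true"  # matched via validation_alias="DISABLE_AUTH"

settings = BaseAppSettings(service_name="svc-example")  # placeholder name
assert settings.disable_auth is True
print(settings.vault_addr, settings.redis_url)  # defaults unless overridden
```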

108
libs/config/utils.py Normal file

@@ -0,0 +1,108 @@
"""Configuration utility functions and global settings management."""
from typing import Any
import hvac
import redis.asyncio as redis
from minio import Minio
from qdrant_client import QdrantClient
from libs.events.base import EventBus
from .factories import (
MinIOClientFactory,
Neo4jDriverFactory,
QdrantClientFactory,
RedisClientFactory,
VaultClientFactory,
)
from .settings import BaseAppSettings
# Global settings instance
_settings: BaseAppSettings | None = None
def get_settings() -> BaseAppSettings:
"""Get global settings instance"""
global _settings # pylint: disable=global-variable-not-assigned
if _settings is None:
raise RuntimeError("Settings not initialized. Call init_settings() first.")
return _settings
def init_settings(
settings_class: type[BaseAppSettings] = BaseAppSettings, **kwargs: Any
) -> BaseAppSettings:
"""Initialize settings with custom class"""
global _settings # pylint: disable=global-statement
_settings = settings_class(**kwargs)
return _settings
# Convenience functions for backward compatibility
def create_vault_client(settings: BaseAppSettings) -> hvac.Client:
"""Create Vault client"""
return VaultClientFactory.create_client(settings)
def create_minio_client(settings: BaseAppSettings) -> Minio:
"""Create MinIO client"""
return MinIOClientFactory.create_client(settings)
def create_qdrant_client(settings: BaseAppSettings) -> QdrantClient:
"""Create Qdrant client"""
return QdrantClientFactory.create_client(settings)
def create_neo4j_client(settings: BaseAppSettings) -> Any:
"""Create Neo4j driver"""
return Neo4jDriverFactory.create_driver(settings)
async def create_redis_client(settings: BaseAppSettings) -> "redis.Redis[str]":
"""Create Redis client"""
return await RedisClientFactory.create_client(settings)
def create_event_bus(settings: BaseAppSettings) -> EventBus:
"""Create event bus"""
if settings.event_bus_type.lower() == "kafka":
# pylint: disable=import-outside-toplevel
from ..events import KafkaEventBus
return KafkaEventBus(settings.kafka_bootstrap_servers)
if settings.event_bus_type.lower() == "sqs":
# pylint: disable=import-outside-toplevel
from ..events import SQSEventBus
return SQSEventBus(settings.aws_region)
if settings.event_bus_type.lower() == "memory":
# pylint: disable=import-outside-toplevel
from ..events import MemoryEventBus
return MemoryEventBus()
# Default to memory bus for unknown types
# pylint: disable=import-outside-toplevel
from ..events import MemoryEventBus
return MemoryEventBus()
def get_default_settings(**overrides: Any) -> BaseAppSettings:
"""Get default settings with optional overrides"""
defaults = {
"service_name": "default-service",
"vault_addr": "http://vault:8200",
"postgres_url": "postgresql://user:pass@postgres:5432/taxagent",
"neo4j_uri": "bolt://neo4j:7687",
"neo4j_password": "password",
"redis_url": "redis://redis:6379",
"minio_endpoint": "minio:9000",
"minio_access_key": "minioadmin",
"minio_secret_key": "minioadmin",
"qdrant_url": "http://qdrant:6333",
}
defaults.update(overrides)
return BaseAppSettings(**defaults) # type: ignore
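
A usage sketch for the helpers above; the memory event bus is chosen only to keep the example self-contained:

```python
# Sketch only: initialise global settings once, then build clients from them.
from libs.config.utils import create_event_bus, get_settings, init_settings

settings = init_settings(service_name="svc-example", event_bus_type="memory")  # illustrative
bus = create_event_bus(settings)  # "memory" selects MemoryEventBus

assert get_settings() is settings  # later callers see the same instance
```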

9
libs/coverage/__init__.py Normal file

@@ -0,0 +1,9 @@
"""Coverage evaluation engine for tax document requirements."""
from .evaluator import CoverageEvaluator
from .utils import check_document_coverage
__all__ = [
"CoverageEvaluator",
"check_document_coverage",
]

418
libs/coverage/evaluator.py Normal file

@@ -0,0 +1,418 @@
"""Core coverage evaluation engine."""
from datetime import datetime
from typing import Any
import structlog
from ..schemas import (
BlockingItem,
Citation,
CompiledCoveragePolicy,
CoverageItem,
CoverageReport,
FoundEvidence,
OverallStatus,
Role,
ScheduleCoverage,
Status,
)
logger = structlog.get_logger()
class CoverageEvaluator:
"""Core coverage evaluation engine"""
def __init__(self, kg_client: Any = None, rag_client: Any = None):
self.kg_client = kg_client
self.rag_client = rag_client
async def check_document_coverage(
self,
taxpayer_id: str,
tax_year: str,
policy: CompiledCoveragePolicy,
) -> CoverageReport:
"""Main coverage evaluation workflow"""
logger.info(
"Starting coverage evaluation",
taxpayer_id=taxpayer_id,
tax_year=tax_year,
policy_version=policy.policy.version,
)
# Step A: Infer required schedules
required_schedules = await self.infer_required_schedules(
taxpayer_id, tax_year, policy
)
# Step B: Evaluate each schedule
schedule_coverage = []
all_blocking_items = []
for schedule_id in required_schedules:
coverage = await self._evaluate_schedule_coverage(
schedule_id, taxpayer_id, tax_year, policy
)
schedule_coverage.append(coverage)
# Collect blocking items
for evidence in coverage.evidence:
if evidence.role == Role.REQUIRED and evidence.status == Status.MISSING:
all_blocking_items.append(
BlockingItem(schedule_id=schedule_id, evidence_id=evidence.id)
)
# Step C: Determine overall status
overall_status = self._determine_overall_status(
schedule_coverage, all_blocking_items
)
return CoverageReport(
tax_year=tax_year,
taxpayer_id=taxpayer_id,
schedules_required=required_schedules,
overall_status=overall_status,
coverage=schedule_coverage,
blocking_items=all_blocking_items,
policy_version=policy.policy.version,
)
async def infer_required_schedules(
self,
taxpayer_id: str,
tax_year: str,
policy: CompiledCoveragePolicy,
) -> list[str]:
"""Determine which schedules are required for this taxpayer"""
required = []
for schedule_id, trigger in policy.policy.triggers.items():
is_required = False
# Check any_of conditions
if trigger.any_of:
for condition in trigger.any_of:
predicate = policy.compiled_predicates.get(condition)
if predicate and predicate(taxpayer_id, tax_year):
is_required = True
break
# Check all_of conditions
if trigger.all_of and not is_required:
all_match = True
for condition in trigger.all_of:
predicate = policy.compiled_predicates.get(condition)
if not predicate or not predicate(taxpayer_id, tax_year):
all_match = False
break
if all_match:
is_required = True
if is_required:
required.append(schedule_id)
logger.debug(
"Schedule required",
schedule_id=schedule_id,
taxpayer_id=taxpayer_id,
)
return required
async def find_evidence_docs(
self,
taxpayer_id: str,
tax_year: str,
evidence_ids: list[str],
policy: CompiledCoveragePolicy,
) -> dict[str, list[FoundEvidence]]:
"""Find evidence documents in the knowledge graph"""
if not self.kg_client:
logger.warning("No KG client available, returning empty evidence")
            # Build a fresh empty list per evidence id so callers never share one mutable list
            empty_map: dict[str, list[FoundEvidence]] = {eid: [] for eid in evidence_ids}
            return empty_map
# Import here to avoid circular imports
from ..neo import kg_find_evidence
evidence_map: dict[str, list[FoundEvidence]] = {}
thresholds = policy.policy.defaults.confidence_thresholds
for evidence_id in evidence_ids:
try:
found = await kg_find_evidence(
self.kg_client,
taxpayer_id=taxpayer_id,
tax_year=tax_year,
kinds=[evidence_id],
min_ocr=thresholds.get("ocr", 0.6),
date_window=policy.policy.defaults.date_tolerance_days,
)
evidence_map[evidence_id] = found
except Exception as e:
logger.error(
"Failed to find evidence",
evidence_id=evidence_id,
error=str(e),
)
empty_list: list[FoundEvidence] = []
evidence_map[evidence_id] = empty_list
return evidence_map
def classify_status(
self,
found: list[FoundEvidence],
policy: CompiledCoveragePolicy,
tax_year: str,
) -> Status:
"""Classify evidence status based on what was found"""
if not found:
return Status.MISSING
classifier = policy.policy.status_classifier
tax_year_start, tax_year_end = self._parse_tax_year_bounds(
policy.policy.tax_year_boundary.start,
policy.policy.tax_year_boundary.end,
)
# Check for conflicts first
if len(found) > 1:
# Simple conflict detection: different totals for same period
# In production, this would be more sophisticated
return Status.CONFLICTING
evidence = found[0]
# Check if evidence meets verified criteria
if (
evidence.ocr_confidence >= classifier.present_verified.min_ocr
and evidence.extract_confidence >= classifier.present_verified.min_extract
):
# Check date validity
if evidence.date:
# Handle both date-only and datetime strings consistently
if "T" not in evidence.date:
# Date-only string, add time and timezone (middle of day)
evidence_date = datetime.fromisoformat(
evidence.date + "T12:00:00+00:00"
)
else:
# Full datetime string, ensure timezone-aware
evidence_date = datetime.fromisoformat(
evidence.date.replace("Z", "+00:00")
)
if tax_year_start <= evidence_date <= tax_year_end:
return Status.PRESENT_VERIFIED
# Check if evidence meets unverified criteria
if (
evidence.ocr_confidence >= classifier.present_unverified.min_ocr
and evidence.extract_confidence >= classifier.present_unverified.min_extract
):
return Status.PRESENT_UNVERIFIED
# Default to missing if confidence too low
return Status.MISSING
async def build_reason_and_citations(
self,
schedule_id: str,
evidence_item: Any,
status: Status,
taxpayer_id: str,
tax_year: str,
policy: CompiledCoveragePolicy,
) -> tuple[str, list[Citation]]:
"""Build human-readable reason and citations"""
# Build reason text
reason = self._build_reason_text(evidence_item, status, policy)
# Get citations from KG
citations = []
if self.kg_client:
try:
from ..neo import kg_rule_citations
kg_citations = await kg_rule_citations(
self.kg_client, schedule_id, evidence_item.boxes
)
citations.extend(kg_citations)
except Exception as e:
logger.warning("Failed to get KG citations", error=str(e))
# Fallback to RAG citations if needed
if not citations and self.rag_client:
try:
from ..rag import rag_search_for_citations
query = f"{schedule_id} {evidence_item.id} requirements"
filters = {
"jurisdiction": policy.policy.jurisdiction,
"tax_year": tax_year,
"pii_free": True,
}
rag_citations = await rag_search_for_citations(
self.rag_client, query, filters
)
citations.extend(rag_citations)
except Exception as e:
logger.warning("Failed to get RAG citations", error=str(e))
return reason, citations
async def _evaluate_schedule_coverage(
self,
schedule_id: str,
taxpayer_id: str,
tax_year: str,
policy: CompiledCoveragePolicy,
) -> ScheduleCoverage:
"""Evaluate coverage for a single schedule"""
schedule_policy = policy.policy.schedules[schedule_id]
evidence_items = []
# Get all evidence IDs for this schedule
evidence_ids = [e.id for e in schedule_policy.evidence]
# Find evidence in KG
evidence_map = await self.find_evidence_docs(
taxpayer_id, tax_year, evidence_ids, policy
)
# Evaluate each evidence requirement
for evidence_req in schedule_policy.evidence:
# Check if conditionally required evidence applies
if (
evidence_req.role == Role.CONDITIONALLY_REQUIRED
and evidence_req.condition
):
predicate = policy.compiled_predicates.get(evidence_req.condition)
if not predicate or not predicate(taxpayer_id, tax_year):
continue # Skip this evidence as condition not met
found = evidence_map.get(evidence_req.id, [])
status = self.classify_status(found, policy, tax_year)
reason, citations = await self.build_reason_and_citations(
schedule_id, evidence_req, status, taxpayer_id, tax_year, policy
)
evidence_item = CoverageItem(
id=evidence_req.id,
role=evidence_req.role,
status=status,
boxes=evidence_req.boxes,
found=found,
acceptable_alternatives=evidence_req.acceptable_alternatives,
reason=reason,
citations=citations,
)
evidence_items.append(evidence_item)
# Determine schedule status
schedule_status = self._determine_schedule_status(evidence_items)
return ScheduleCoverage(
schedule_id=schedule_id,
status=schedule_status,
evidence=evidence_items,
)
def _determine_overall_status(
self,
schedule_coverage: list[ScheduleCoverage],
blocking_items: list[BlockingItem],
) -> OverallStatus:
"""Determine overall coverage status"""
if blocking_items:
return OverallStatus.BLOCKING
# Check if all schedules are OK
all_ok = all(s.status == OverallStatus.OK for s in schedule_coverage)
if all_ok:
return OverallStatus.OK
return OverallStatus.PARTIAL
def _determine_schedule_status(
self, evidence_items: list[CoverageItem]
) -> OverallStatus:
"""Determine status for a single schedule"""
# Check for blocking issues
has_missing_required = any(
e.role == Role.REQUIRED and e.status == Status.MISSING
for e in evidence_items
)
if has_missing_required:
return OverallStatus.BLOCKING
# Check for partial issues
has_unverified = any(
e.status == Status.PRESENT_UNVERIFIED for e in evidence_items
)
if has_unverified:
return OverallStatus.PARTIAL
return OverallStatus.OK
def _build_reason_text(
self,
evidence_item: Any,
status: Status,
policy: CompiledCoveragePolicy,
) -> str:
"""Build human-readable reason text"""
evidence_id = evidence_item.id
# Get reason from policy if available
if evidence_item.reasons and "short" in evidence_item.reasons:
base_reason = evidence_item.reasons["short"]
else:
base_reason = f"{evidence_id} is required for this schedule."
# Add status-specific details
if status == Status.MISSING:
return f"No {evidence_id} found. {base_reason}"
elif status == Status.PRESENT_UNVERIFIED:
return (
f"{evidence_id} present but confidence below threshold. {base_reason}"
)
elif status == Status.CONFLICTING:
return f"Conflicting {evidence_id} documents found. {base_reason}"
else:
return f"{evidence_id} verified. {base_reason}"
def _parse_tax_year_bounds(
self, start_str: str, end_str: str
) -> tuple[datetime, datetime]:
"""Parse tax year boundary strings to datetime objects"""
# Handle both date-only and datetime strings
if "T" not in start_str:
# Date-only string, add time and timezone
start = datetime.fromisoformat(start_str + "T00:00:00+00:00")
else:
# Full datetime string, ensure timezone-aware
start = datetime.fromisoformat(start_str.replace("Z", "+00:00"))
if "T" not in end_str:
# Date-only string, add time and timezone (end of day)
end = datetime.fromisoformat(end_str + "T23:59:59+00:00")
else:
# Full datetime string, ensure timezone-aware
end = datetime.fromisoformat(end_str.replace("Z", "+00:00"))
return start, end

18
libs/coverage/utils.py Normal file

@@ -0,0 +1,18 @@
"""Utility functions for coverage evaluation."""
from typing import Any
from ..schemas import CompiledCoveragePolicy, CoverageReport
from .evaluator import CoverageEvaluator
async def check_document_coverage(
taxpayer_id: str,
tax_year: str,
policy: CompiledCoveragePolicy,
kg_client: Any = None,
rag_client: Any = None,
) -> CoverageReport:
"""Check document coverage for taxpayer"""
evaluator = CoverageEvaluator(kg_client, rag_client)
return await evaluator.check_document_coverage(taxpayer_id, tax_year, policy)
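
A hedged sketch of calling this helper, assuming a `CompiledCoveragePolicy` has already been loaded elsewhere (no loader appears in this diff) and that no KG/RAG clients are wired up, so evidence comes back MISSING:

```python
# Sketch only: run a coverage check with no KG/RAG clients attached.
# `load_policy` is a hypothetical stand-in for however CompiledCoveragePolicy is built.
import asyncio

from libs.coverage import check_document_coverage

async def main(policy) -> None:
    report = await check_document_coverage(
        taxpayer_id="tp_123",  # illustrative identifiers
        tax_year="2024-25",
        policy=policy,
        kg_client=None,
        rag_client=None,
    )
    print(report.overall_status, [b.evidence_id for b in report.blocking_items])

# asyncio.run(main(load_policy("coverage_policy.yaml")))  # hypothetical loader call
```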

336
libs/coverage_schema.json Normal file
View File

@@ -0,0 +1,336 @@
{
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "Coverage Policy Schema",
"type": "object",
"required": [
"version",
"jurisdiction",
"tax_year",
"tax_year_boundary",
"defaults",
"document_kinds",
"triggers",
"schedules",
"status_classifier",
"conflict_resolution",
"question_templates"
],
"properties": {
"version": {
"type": "string",
"pattern": "^\\d+\\.\\d+$"
},
"jurisdiction": {
"type": "string",
"enum": ["UK", "US", "CA", "AU"]
},
"tax_year": {
"type": "string",
"pattern": "^\\d{4}-\\d{2}$"
},
"tax_year_boundary": {
"type": "object",
"required": ["start", "end"],
"properties": {
"start": {
"type": "string",
"format": "date"
},
"end": {
"type": "string",
"format": "date"
}
}
},
"defaults": {
"type": "object",
"required": ["confidence_thresholds"],
"properties": {
"confidence_thresholds": {
"type": "object",
"properties": {
"ocr": {
"type": "number",
"minimum": 0,
"maximum": 1
},
"extract": {
"type": "number",
"minimum": 0,
"maximum": 1
}
}
},
"date_tolerance_days": {
"type": "integer",
"minimum": 0
},
"require_lineage_bbox": {
"type": "boolean"
},
"allow_bank_substantiation": {
"type": "boolean"
}
}
},
"document_kinds": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
},
"minItems": 1,
"uniqueItems": true
},
"guidance_refs": {
"type": "object",
"patternProperties": {
"^[A-Z0-9_]+$": {
"type": "object",
"required": ["doc_id", "kind"],
"properties": {
"doc_id": {
"type": "string",
"minLength": 1
},
"kind": {
"type": "string",
"minLength": 1
}
}
}
}
},
"triggers": {
"type": "object",
"patternProperties": {
"^SA\\d+[A-Z]*$": {
"type": "object",
"properties": {
"any_of": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
}
},
"all_of": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
}
}
},
"anyOf": [
{"required": ["any_of"]},
{"required": ["all_of"]}
]
}
}
},
"schedules": {
"type": "object",
"patternProperties": {
"^SA\\d+[A-Z]*$": {
"type": "object",
"properties": {
"guidance_hint": {
"type": "string"
},
"evidence": {
"type": "array",
"items": {
"type": "object",
"required": ["id", "role"],
"properties": {
"id": {
"type": "string",
"minLength": 1
},
"role": {
"type": "string",
"enum": ["REQUIRED", "CONDITIONALLY_REQUIRED", "OPTIONAL"]
},
"condition": {
"type": "string"
},
"boxes": {
"type": "array",
"items": {
"type": "string",
"pattern": "^SA\\d+[A-Z]*_b\\d+(_\\d+)?$"
},
"minItems": 0
},
"acceptable_alternatives": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
}
},
"validity": {
"type": "object",
"properties": {
"within_tax_year": {
"type": "boolean"
},
"available_by": {
"type": "string",
"format": "date"
}
}
},
"reasons": {
"type": "object",
"properties": {
"short": {
"type": "string"
}
}
}
}
}
},
"cross_checks": {
"type": "array",
"items": {
"type": "object",
"required": ["name", "logic"],
"properties": {
"name": {
"type": "string",
"minLength": 1
},
"logic": {
"type": "string",
"minLength": 1
}
}
}
},
"selection_rule": {
"type": "object"
},
"notes": {
"type": "object"
}
}
}
}
},
"status_classifier": {
"type": "object",
"required": ["present_verified", "present_unverified", "conflicting", "missing"],
"properties": {
"present_verified": {
"$ref": "#/definitions/statusClassifier"
},
"present_unverified": {
"$ref": "#/definitions/statusClassifier"
},
"conflicting": {
"$ref": "#/definitions/statusClassifier"
},
"missing": {
"$ref": "#/definitions/statusClassifier"
}
}
},
"conflict_resolution": {
"type": "object",
"required": ["precedence"],
"properties": {
"precedence": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
},
"minItems": 1
},
"escalation": {
"type": "object"
}
}
},
"question_templates": {
"type": "object",
"required": ["default"],
"properties": {
"default": {
"type": "object",
"required": ["text", "why"],
"properties": {
"text": {
"type": "string",
"minLength": 1
},
"why": {
"type": "string",
"minLength": 1
}
}
},
"reasons": {
"type": "object",
"patternProperties": {
"^[A-Za-z0-9_]+$": {
"type": "string",
"minLength": 1
}
}
}
}
},
"privacy": {
"type": "object",
"properties": {
"vector_pii_free": {
"type": "boolean"
},
"redact_patterns": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
}
}
}
}
},
"definitions": {
"statusClassifier": {
"type": "object",
"properties": {
"min_ocr": {
"type": "number",
"minimum": 0,
"maximum": 1
},
"min_extract": {
"type": "number",
"minimum": 0,
"maximum": 1
},
"date_in_year": {
"type": "boolean"
},
"date_in_year_or_tolerance": {
"type": "boolean"
},
"conflict_rules": {
"type": "array",
"items": {
"type": "string",
"minLength": 1
}
},
"default": {
"type": "boolean"
}
}
}
}
}

282
libs/events/NATS_README.md Normal file
View File

@@ -0,0 +1,282 @@
# NATS.io Event Bus with JetStream
This document describes the NATS.io event bus implementation with JetStream support for the AI Tax Agent project.
## Overview
The `NATSEventBus` class provides a robust, scalable event streaming solution using NATS.io with JetStream for persistent messaging. It implements the same `EventBus` interface as other event bus implementations (Kafka, SQS, Memory) for consistency.
## Features
- **JetStream Integration**: Uses NATS JetStream for persistent, reliable message delivery
- **Automatic Stream Management**: Creates and manages JetStream streams automatically
- **Pull-based Consumers**: Uses pull-based consumers for better flow control
- **Cluster Support**: Supports NATS cluster configurations for high availability
- **Error Handling**: Comprehensive error handling with automatic retries
- **Message Acknowledgment**: Explicit message acknowledgment with configurable retry policies
- **Durable Consumers**: Creates durable consumers for guaranteed message processing
## Configuration
### Basic Configuration
```python
from libs.events import NATSEventBus
# Single server
bus = NATSEventBus(
servers="nats://localhost:4222",
stream_name="TAX_AGENT_EVENTS",
consumer_group="tax-agent"
)
# Multiple servers (cluster)
bus = NATSEventBus(
servers=[
"nats://nats1.example.com:4222",
"nats://nats2.example.com:4222",
"nats://nats3.example.com:4222"
],
stream_name="PRODUCTION_EVENTS",
consumer_group="tax-agent-prod"
)
```
### Factory Configuration
```python
from libs.events import create_event_bus
bus = create_event_bus(
"nats",
servers="nats://localhost:4222",
stream_name="TAX_AGENT_EVENTS",
consumer_group="tax-agent"
)
```
## Usage
### Publishing Events
```python
from libs.events import EventPayload
# Create event payload
payload = EventPayload(
data={"user_id": "123", "action": "login"},
actor="user-service",
tenant_id="tenant-456",
trace_id="trace-789"
)
# Publish event
success = await bus.publish("user.login", payload)
if success:
print("Event published successfully")
```
### Subscribing to Events
```python
async def handle_user_login(topic: str, payload: EventPayload) -> None:
print(f"User {payload.data['user_id']} logged in")
# Process the event...
# Subscribe to topic
await bus.subscribe("user.login", handle_user_login)
```
### Complete Example
```python
import asyncio
from libs.events import NATSEventBus, EventPayload
async def main():
bus = NATSEventBus()
try:
# Start the bus
await bus.start()
# Subscribe to events
await bus.subscribe("user.created", handle_user_created)
# Publish an event
payload = EventPayload(
data={"user_id": "123", "email": "user@example.com"},
actor="registration-service",
tenant_id="tenant-456"
)
await bus.publish("user.created", payload)
# Wait for processing
await asyncio.sleep(1)
finally:
await bus.stop()
asyncio.run(main())
```
## JetStream Configuration
The NATS event bus automatically creates and configures JetStream streams with the following settings (see the sketch after this list):
- **Retention Policy**: Work Queue (messages are removed after acknowledgment)
- **Max Age**: 7 days (messages older than 7 days are automatically deleted)
- **Storage**: File-based storage for persistence
- **Subject Pattern**: `{stream_name}.*` (e.g., `TAX_AGENT_EVENTS.*`)
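These defaults mirror what `NATSEventBus._ensure_stream_exists` does when the bus starts. A rough sketch, assuming the default server URL and stream name used elsewhere in this document:
```python
import nats
from nats.js.api import RetentionPolicy, StorageType


async def ensure_stream() -> None:
    """Create the stream the event bus expects (illustrative sketch)."""
    nc = await nats.connect("nats://localhost:4222")
    js = nc.jetstream()
    await js.add_stream(
        name="TAX_AGENT_EVENTS",                  # stream_name
        subjects=["TAX_AGENT_EVENTS.*"],          # one subject per topic
        retention=RetentionPolicy.WORK_QUEUE,     # removed after acknowledgment
        max_age=7 * 24 * 60 * 60,                 # 7 days, in seconds
        storage=StorageType.FILE,                 # persisted to disk
    )
    await nc.close()
```
You normally never call this yourself; `NATSEventBus.start()` performs the equivalent call if the stream does not already exist.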
### Consumer Configuration
- **Durable Consumers**: Each topic subscription creates a durable consumer (see the sketch after this list)
- **Ack Policy**: Explicit acknowledgment required
- **Deliver Policy**: New messages only (doesn't replay old messages)
- **Max Deliver**: 3 attempts before message is considered failed
- **Ack Wait**: 30 seconds timeout for acknowledgment
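Roughly, each `subscribe()` call builds a pull consumer like the sketch below; the subject and durable names here are illustrative, since the bus derives them from `stream_name`, `consumer_group`, and the topic:
```python
from nats.js import JetStreamContext
from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy


async def make_consumer(js: JetStreamContext) -> None:
    """Create a durable pull consumer for one topic (illustrative sketch)."""
    subscription = await js.pull_subscribe(
        subject="TAX_AGENT_EVENTS.user.created",
        durable="tax-agent-user.created",          # f"{consumer_group}-{topic}"
        config=ConsumerConfig(
            durable_name="tax-agent-user.created",
            ack_policy=AckPolicy.EXPLICIT,         # every message must be acked
            deliver_policy=DeliverPolicy.NEW,      # do not replay old messages
            max_deliver=3,                         # redelivery attempts
            ack_wait=30,                           # seconds before redelivery
        ),
    )
    # The bus polls in batches of up to 10 messages and acks each one.
    for msg in await subscription.fetch(batch=10, timeout=20):
        await msg.ack()
```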
## Error Handling
The NATS event bus includes comprehensive error handling:
### Publishing Errors
- Network failures are logged and return `False`
- Automatic retry logic can be implemented at the application level (see the sketch at the end of this section)
### Consumer Errors
- Handler exceptions are caught and logged
- Failed messages are negatively acknowledged (NAK) for retry
- Messages that fail multiple times are moved to a dead letter queue (if configured)
### Connection Errors
- Automatic reconnection is handled by the NATS client
- Consumer tasks are gracefully shut down on connection loss
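A minimal sketch of the application-level retry mentioned above; `publish_with_retry` is an illustrative helper, not part of the library:
```python
import asyncio

from libs.events import EventPayload, NATSEventBus


async def publish_with_retry(
    bus: NATSEventBus, topic: str, payload: EventPayload, attempts: int = 3
) -> bool:
    """Retry a failed publish with exponential backoff (illustrative helper)."""
    for attempt in range(attempts):
        if await bus.publish(topic, payload):
            return True
        await asyncio.sleep(2**attempt)  # back off: 1s, 2s, 4s
    return False
```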
## Monitoring and Observability
The implementation includes structured logging with the following information:
- Event publishing success/failure
- Consumer subscription status
- Message processing metrics
- Error details and stack traces
### Log Examples
```
INFO: Event published topic=user.created event_id=01HK... stream_seq=123
INFO: Subscribed to topic topic=user.login consumer=tax-agent-user.login
ERROR: Handler failed topic=user.created event_id=01HK... error=...
```
## Performance Considerations
### Throughput
- Pull-based consumers allow for controlled message processing
- Batch fetching (up to 10 messages per fetch) improves throughput
- Async processing enables high concurrency
### Memory Usage
- File-based storage keeps memory usage low
- Configurable message retention prevents unbounded growth
### Network Efficiency
- Binary protocol with minimal overhead
- Connection pooling and reuse
- Efficient subject-based routing
## Deployment
### Docker Compose Example
```yaml
services:
nats:
image: nats:2.10-alpine
ports:
- "4222:4222"
- "8222:8222"
command:
- "--jetstream"
- "--store_dir=/data"
- "--http_port=8222"
volumes:
- nats_data:/data
volumes:
nats_data:
```
### Kubernetes Example
```yaml
apiVersion: apps/v1
kind: StatefulSet
metadata:
name: nats
spec:
serviceName: nats
replicas: 3
selector:
matchLabels:
app: nats
template:
metadata:
labels:
app: nats
spec:
containers:
- name: nats
image: nats:2.10-alpine
args:
- "--cluster_name=nats-cluster"
- "--jetstream"
- "--store_dir=/data"
ports:
- containerPort: 4222
- containerPort: 6222
- containerPort: 8222
volumeMounts:
- name: nats-storage
mountPath: /data
volumeClaimTemplates:
- metadata:
name: nats-storage
spec:
accessModes: ["ReadWriteOnce"]
resources:
requests:
storage: 10Gi
```
## Dependencies
The NATS event bus requires the following Python package:
```
nats-py>=2.6.0
```
This is automatically included in `libs/requirements.txt`.
## Comparison with Other Event Buses
| Feature | NATS | Kafka | SQS |
|---------|------|-------|-----|
| Setup Complexity | Low | Medium | Low |
| Throughput | High | Very High | Medium |
| Latency | Very Low | Low | Medium |
| Persistence | Yes (JetStream) | Yes | Yes |
| Ordering | Per Subject | Per Partition | FIFO Queues |
| Clustering | Built-in | Built-in | Managed |
| Operational Overhead | Low | High | None |
## Best Practices
1. **Use meaningful subject names**: Follow a hierarchical naming convention (e.g., `service.entity.action`); see the sketch after this list
2. **Handle failures gracefully**: Implement proper error handling in event handlers
3. **Monitor consumer lag**: Track message processing delays
4. **Use appropriate retention**: Configure message retention based on business requirements
5. **Test failure scenarios**: Verify behavior during network partitions and service failures
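For point 1, the repo already defines topic constants in `libs/events/topics.py`, which follow a two-level `entity.action` hierarchy (e.g. `doc.ingested`). Reusing them instead of hand-typed strings keeps subjects consistent. A minimal sketch (the actor name is illustrative):
```python
from libs.events import EventPayload, EventTopics, NATSEventBus


async def announce_ingested(bus: NATSEventBus, doc_id: str, tenant_id: str) -> None:
    """Publish on a shared topic constant rather than an ad-hoc string."""
    payload = EventPayload(
        data={"doc_id": doc_id},
        actor="svc-ingestion",  # illustrative actor name
        tenant_id=tenant_id,
    )
    await bus.publish(EventTopics.DOC_INGESTED, payload)  # "doc.ingested"
```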

20
libs/events/__init__.py Normal file
View File

@@ -0,0 +1,20 @@
"""Event-driven architecture with Kafka, SQS, NATS, and Memory support."""
from .base import EventBus, EventPayload
from .factory import create_event_bus
from .kafka_bus import KafkaEventBus
from .memory_bus import MemoryEventBus
from .nats_bus import NATSEventBus
from .sqs_bus import SQSEventBus
from .topics import EventTopics
__all__ = [
"EventPayload",
"EventBus",
"KafkaEventBus",
"MemoryEventBus",
"NATSEventBus",
"SQSEventBus",
"create_event_bus",
"EventTopics",
]

68
libs/events/base.py Normal file
View File

@@ -0,0 +1,68 @@
"""Base event classes and interfaces."""
import json
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from datetime import datetime
from typing import Any
import ulid
# Each payload MUST include: `event_id (ulid)`, `occurred_at (iso)`, `actor`, `tenant_id`, `trace_id`, `schema_version`, and a `data` object (service-specific).
class EventPayload:
"""Standard event payload structure"""
def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments
self,
data: dict[str, Any],
actor: str,
tenant_id: str,
trace_id: str | None = None,
schema_version: str = "1.0",
):
self.event_id = str(ulid.new())
self.occurred_at = datetime.utcnow().isoformat() + "Z"
self.actor = actor
self.tenant_id = tenant_id
self.trace_id = trace_id
self.schema_version = schema_version
self.data = data
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for serialization"""
return {
"event_id": self.event_id,
"occurred_at": self.occurred_at,
"actor": self.actor,
"tenant_id": self.tenant_id,
"trace_id": self.trace_id,
"schema_version": self.schema_version,
"data": self.data,
}
def to_json(self) -> str:
"""Convert to JSON string"""
return json.dumps(self.to_dict())
class EventBus(ABC):
"""Abstract event bus interface"""
@abstractmethod
async def publish(self, topic: str, payload: EventPayload) -> bool:
"""Publish event to topic"""
@abstractmethod
async def subscribe(
self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
) -> None:
"""Subscribe to topic with handler"""
@abstractmethod
async def start(self) -> None:
"""Start the event bus"""
@abstractmethod
async def stop(self) -> None:
"""Stop the event bus"""

View File

@@ -0,0 +1,163 @@
"""Example usage of NATS.io event bus with JetStream."""
import asyncio
import logging
from libs.events import EventPayload, NATSEventBus, create_event_bus
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def example_handler(topic: str, payload: EventPayload) -> None:
"""Example event handler."""
logger.info(
f"Received event on topic '{topic}': "
f"ID={payload.event_id}, "
f"Actor={payload.actor}, "
f"Data={payload.data}"
)
async def main():
"""Main example function."""
# Method 1: Direct instantiation
nats_bus = NATSEventBus(
servers="nats://localhost:4222", # Can be a list for cluster
stream_name="TAX_AGENT_EVENTS",
consumer_group="tax-agent",
)
# Method 2: Using factory
# nats_bus = create_event_bus(
# "nats",
# servers="nats://localhost:4222",
# stream_name="TAX_AGENT_EVENTS",
# consumer_group="tax-agent",
# )
try:
# Start the event bus
await nats_bus.start()
logger.info("NATS event bus started")
# Subscribe to a topic
await nats_bus.subscribe("user.created", example_handler)
await nats_bus.subscribe("user.updated", example_handler)
logger.info("Subscribed to topics")
# Publish some events
for i in range(5):
payload = EventPayload(
data={"user_id": f"user-{i}", "name": f"User {i}"},
actor="system",
tenant_id="tenant-123",
trace_id=f"trace-{i}",
)
success = await nats_bus.publish("user.created", payload)
if success:
logger.info(f"Published event {i}")
else:
logger.error(f"Failed to publish event {i}")
# Wait a bit for messages to be processed
await asyncio.sleep(2)
# Publish an update event
update_payload = EventPayload(
data={"user_id": "user-1", "name": "Updated User 1", "email": "user1@example.com"},
actor="admin",
tenant_id="tenant-123",
)
await nats_bus.publish("user.updated", update_payload)
logger.info("Published update event")
# Wait for processing
await asyncio.sleep(2)
except Exception as e:
logger.error(f"Error in example: {e}")
finally:
# Stop the event bus
await nats_bus.stop()
logger.info("NATS event bus stopped")
async def cluster_example():
"""Example with NATS cluster configuration."""
# Connect to a NATS cluster
cluster_bus = NATSEventBus(
servers=[
"nats://nats1.example.com:4222",
"nats://nats2.example.com:4222",
"nats://nats3.example.com:4222",
],
stream_name="PRODUCTION_EVENTS",
consumer_group="tax-agent-prod",
)
try:
await cluster_bus.start()
logger.info("Connected to NATS cluster")
# Subscribe to multiple topics
topics = ["document.uploaded", "document.processed", "tax.calculated"]
for topic in topics:
await cluster_bus.subscribe(topic, example_handler)
logger.info(f"Subscribed to {len(topics)} topics")
# Keep running for a while
await asyncio.sleep(10)
finally:
await cluster_bus.stop()
async def error_handling_example():
"""Example showing error handling."""
async def failing_handler(topic: str, payload: EventPayload) -> None:
"""Handler that sometimes fails."""
if payload.data.get("should_fail"):
raise ValueError("Simulated handler failure")
logger.info(f"Successfully processed event {payload.event_id}")
bus = NATSEventBus()
try:
await bus.start()
await bus.subscribe("test.events", failing_handler)
# Publish a good event
good_payload = EventPayload(
data={"message": "This will succeed"},
actor="test",
tenant_id="test-tenant",
)
await bus.publish("test.events", good_payload)
# Publish a bad event
bad_payload = EventPayload(
data={"message": "This will fail", "should_fail": True},
actor="test",
tenant_id="test-tenant",
)
await bus.publish("test.events", bad_payload)
await asyncio.sleep(2)
finally:
await bus.stop()
if __name__ == "__main__":
# Run the basic example
asyncio.run(main())
# Uncomment to run other examples:
# asyncio.run(cluster_example())
# asyncio.run(error_handling_example())

23
libs/events/factory.py Normal file
View File

@@ -0,0 +1,23 @@
"""Factory function for creating event bus instances."""
from typing import Any
from .base import EventBus
from .kafka_bus import KafkaEventBus
from .memory_bus import MemoryEventBus
from .nats_bus import NATSEventBus
from .sqs_bus import SQSEventBus
def create_event_bus(bus_type: str, **kwargs: Any) -> EventBus:
    """Factory function to create event bus"""
    if bus_type.lower() == "memory":
        return MemoryEventBus()
    if bus_type.lower() == "kafka":
        return KafkaEventBus(kwargs.get("bootstrap_servers", "localhost:9092"))
if bus_type.lower() == "sqs":
return SQSEventBus(kwargs.get("region_name", "us-east-1"))
if bus_type.lower() == "nats":
return NATSEventBus(
servers=kwargs.get("servers", "nats://localhost:4222"),
stream_name=kwargs.get("stream_name", "TAX_AGENT_EVENTS"),
consumer_group=kwargs.get("consumer_group", "tax-agent"),
)
raise ValueError(f"Unsupported event bus type: {bus_type}")

140
libs/events/kafka_bus.py Normal file
View File

@@ -0,0 +1,140 @@
"""Kafka implementation of EventBus."""
import asyncio
import json
from collections.abc import Awaitable, Callable
import structlog
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer # type: ignore
from .base import EventBus, EventPayload
logger = structlog.get_logger()
class KafkaEventBus(EventBus):
"""Kafka implementation of EventBus"""
def __init__(self, bootstrap_servers: str):
self.bootstrap_servers = bootstrap_servers.split(",")
self.producer: AIOKafkaProducer | None = None
self.consumers: dict[str, AIOKafkaConsumer] = {}
self.handlers: dict[
str, list[Callable[[str, EventPayload], Awaitable[None]]]
] = {}
self.running = False
async def start(self) -> None:
"""Start Kafka producer"""
if self.running:
return
self.producer = AIOKafkaProducer(
bootstrap_servers=",".join(self.bootstrap_servers),
value_serializer=lambda v: v.encode("utf-8"),
)
await self.producer.start()
self.running = True
logger.info("Kafka event bus started", bootstrap_servers=self.bootstrap_servers)
async def stop(self) -> None:
"""Stop Kafka producer and consumers"""
if not self.running:
return
if self.producer:
await self.producer.stop()
for consumer in self.consumers.values():
await consumer.stop()
self.running = False
logger.info("Kafka event bus stopped")
async def publish(self, topic: str, payload: EventPayload) -> bool:
"""Publish event to Kafka topic"""
if not self.producer:
raise RuntimeError("Event bus not started")
try:
await self.producer.send_and_wait(topic, payload.to_json())
logger.info(
"Event published",
topic=topic,
event_id=payload.event_id,
actor=payload.actor,
tenant_id=payload.tenant_id,
)
return True
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Failed to publish event",
topic=topic,
event_id=payload.event_id,
error=str(e),
)
return False
async def subscribe(
self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
) -> None:
"""Subscribe to Kafka topic"""
if topic not in self.handlers:
self.handlers[topic] = []
self.handlers[topic].append(handler)
if topic not in self.consumers:
consumer = AIOKafkaConsumer(
topic,
bootstrap_servers=",".join(self.bootstrap_servers),
value_deserializer=lambda m: m.decode("utf-8"),
group_id=f"tax-agent-{topic}",
auto_offset_reset="latest",
)
self.consumers[topic] = consumer
await consumer.start()
# Start consumer task
asyncio.create_task(self._consume_messages(topic, consumer))
logger.info("Subscribed to topic", topic=topic)
async def _consume_messages(self, topic: str, consumer: AIOKafkaConsumer) -> None:
"""Consume messages from Kafka topic"""
try:
async for message in consumer:
try:
if message.value is not None:
payload_dict = json.loads(message.value)
else:
continue
payload = EventPayload(
data=payload_dict["data"],
actor=payload_dict["actor"],
tenant_id=payload_dict["tenant_id"],
trace_id=payload_dict.get("trace_id"),
schema_version=payload_dict.get("schema_version", "1.0"),
)
payload.event_id = payload_dict["event_id"]
payload.occurred_at = payload_dict["occurred_at"]
# Call all handlers for this topic
for handler in self.handlers.get(topic, []):
try:
await handler(topic, payload)
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Handler failed",
topic=topic,
event_id=payload.event_id,
handler=handler.__name__,
error=str(e),
)
except json.JSONDecodeError as e:
logger.error("Failed to decode message", topic=topic, error=str(e))
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Failed to process message", topic=topic, error=str(e))
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Consumer error", topic=topic, error=str(e))

64
libs/events/memory_bus.py Normal file
View File

@@ -0,0 +1,64 @@
"""In-memory event bus for local development and testing."""
import asyncio
import logging
from collections import defaultdict
from collections.abc import Awaitable, Callable
from .base import EventBus, EventPayload
logger = logging.getLogger(__name__)
class MemoryEventBus(EventBus):
"""In-memory event bus implementation for local development"""
def __init__(self) -> None:
self.handlers: dict[
str, list[Callable[[str, EventPayload], Awaitable[None]]]
] = defaultdict(list)
self.running = False
async def publish(self, topic: str, payload: EventPayload) -> bool:
"""Publish event to topic"""
try:
if not self.running:
logger.warning(
"Event bus not running, skipping publish to topic: %s", topic
)
return False
handlers = self.handlers.get(topic, [])
if not handlers:
logger.debug("No handlers for topic: %s", topic)
return True
# Execute all handlers concurrently
tasks = [handler(topic, payload) for handler in handlers]
await asyncio.gather(*tasks, return_exceptions=True)
logger.debug(
"Published event to topic %s with %d handlers", topic, len(handlers)
)
return True
except Exception as e:
logger.error("Failed to publish event to topic %s: %s", topic, e)
return False
async def subscribe(
self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
) -> None:
"""Subscribe to topic with handler"""
self.handlers[topic].append(handler)
logger.debug("Subscribed handler to topic: %s", topic)
async def start(self) -> None:
"""Start the event bus"""
self.running = True
logger.info("Memory event bus started")
async def stop(self) -> None:
"""Stop the event bus"""
self.running = False
self.handlers.clear()
logger.info("Memory event bus stopped")

269
libs/events/nats_bus.py Normal file
View File

@@ -0,0 +1,269 @@
"""NATS.io with JetStream implementation of EventBus."""
import asyncio
import json
from collections.abc import Awaitable, Callable
from typing import Any
import nats # type: ignore
import structlog
from nats.aio.client import Client as NATS # type: ignore
from nats.js import JetStreamContext # type: ignore
from .base import EventBus, EventPayload
logger = structlog.get_logger()
class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
"""NATS.io with JetStream implementation of EventBus"""
def __init__(
self,
servers: str | list[str] = "nats://localhost:4222",
stream_name: str = "TAX_AGENT_EVENTS",
consumer_group: str = "tax-agent",
):
if isinstance(servers, str):
self.servers = [servers]
else:
self.servers = servers
self.stream_name = stream_name
self.consumer_group = consumer_group
self.nc: NATS | None = None
self.js: JetStreamContext | None = None
self.handlers: dict[
str, list[Callable[[str, EventPayload], Awaitable[None]]]
] = {}
self.subscriptions: dict[str, Any] = {}
self.running = False
self.consumer_tasks: list[asyncio.Task[None]] = []
async def start(self) -> None:
"""Start NATS connection and JetStream context"""
if self.running:
return
try:
# Connect to NATS
self.nc = await nats.connect(servers=self.servers)
# Get JetStream context
self.js = self.nc.jetstream()
# Ensure stream exists
await self._ensure_stream_exists()
self.running = True
logger.info(
"NATS event bus started",
servers=self.servers,
stream=self.stream_name,
)
except Exception as e:
logger.error("Failed to start NATS event bus", error=str(e))
raise
async def stop(self) -> None:
"""Stop NATS connection and consumers"""
if not self.running:
return
# Cancel consumer tasks
for task in self.consumer_tasks:
task.cancel()
if self.consumer_tasks:
await asyncio.gather(*self.consumer_tasks, return_exceptions=True)
# Unsubscribe from all subscriptions
for subscription in self.subscriptions.values():
try:
await subscription.unsubscribe()
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Error unsubscribing", error=str(e))
# Close NATS connection
if self.nc:
await self.nc.close()
self.running = False
logger.info("NATS event bus stopped")
async def publish(self, topic: str, payload: EventPayload) -> bool:
"""Publish event to NATS JetStream"""
if not self.js:
raise RuntimeError("Event bus not started")
try:
# Create subject name from topic
subject = f"{self.stream_name}.{topic}"
# Publish message with headers
headers = {
"event_id": payload.event_id,
"tenant_id": payload.tenant_id,
"actor": payload.actor,
"trace_id": payload.trace_id or "",
"schema_version": payload.schema_version,
}
ack = await self.js.publish(
subject=subject,
payload=payload.to_json().encode(),
headers=headers,
)
logger.info(
"Event published",
topic=topic,
subject=subject,
event_id=payload.event_id,
stream_seq=ack.seq,
)
return True
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Failed to publish event",
topic=topic,
event_id=payload.event_id,
error=str(e),
)
return False
async def subscribe(
self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
) -> None:
"""Subscribe to NATS JetStream topic"""
if not self.js:
raise RuntimeError("Event bus not started")
if topic not in self.handlers:
self.handlers[topic] = []
self.handlers[topic].append(handler)
if topic not in self.subscriptions:
try:
# Create subject pattern for topic
subject = f"{self.stream_name}.{topic}"
# Create durable consumer
consumer_name = f"{self.consumer_group}-{topic}"
# Subscribe with pull-based consumer
subscription = await self.js.pull_subscribe(
subject=subject,
durable=consumer_name,
config=nats.js.api.ConsumerConfig(
durable_name=consumer_name,
ack_policy=nats.js.api.AckPolicy.EXPLICIT,
deliver_policy=nats.js.api.DeliverPolicy.NEW,
max_deliver=3,
ack_wait=30, # 30 seconds
),
)
self.subscriptions[topic] = subscription
# Start consumer task
task = asyncio.create_task(self._consume_messages(topic, subscription))
self.consumer_tasks.append(task)
logger.info(
"Subscribed to topic",
topic=topic,
subject=subject,
consumer=consumer_name,
)
except Exception as e:
logger.error("Failed to subscribe to topic", topic=topic, error=str(e))
raise
async def _ensure_stream_exists(self) -> None:
"""Ensure JetStream stream exists"""
if not self.js:
return
try:
# Try to get stream info
await self.js.stream_info(self.stream_name)
logger.debug("Stream already exists", stream=self.stream_name)
except nats.js.errors.NotFoundError:
# Stream doesn't exist, create it
try:
await self.js.add_stream(
name=self.stream_name,
subjects=[f"{self.stream_name}.*"],
retention=nats.js.api.RetentionPolicy.WORK_QUEUE,
max_age=7 * 24 * 60 * 60, # 7 days in seconds
storage=nats.js.api.StorageType.FILE,
)
logger.info("Created JetStream stream", stream=self.stream_name)
except Exception as e:
logger.error(
"Failed to create stream", stream=self.stream_name, error=str(e)
)
raise
async def _consume_messages(self, topic: str, subscription: Any) -> None:
"""Consume messages from NATS JetStream subscription"""
while self.running:
try:
# Fetch messages in batches
messages = await subscription.fetch(batch=10, timeout=20)
for message in messages:
try:
# Parse message payload
payload_dict = json.loads(message.data.decode())
payload = EventPayload(
data=payload_dict["data"],
actor=payload_dict["actor"],
tenant_id=payload_dict["tenant_id"],
trace_id=payload_dict.get("trace_id"),
schema_version=payload_dict.get("schema_version", "1.0"),
)
payload.event_id = payload_dict["event_id"]
payload.occurred_at = payload_dict["occurred_at"]
# Call all handlers for this topic
for handler in self.handlers.get(topic, []):
try:
await handler(topic, payload)
except (
Exception
) as e: # pylint: disable=broad-exception-caught
logger.error(
"Handler failed",
topic=topic,
event_id=payload.event_id,
error=str(e),
)
# Acknowledge message
await message.ack()
except json.JSONDecodeError as e:
logger.error(
"Failed to decode message", topic=topic, error=str(e)
)
await message.nak()
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Failed to process message", topic=topic, error=str(e)
)
await message.nak()
except asyncio.TimeoutError:
# No messages available, continue polling
continue
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Consumer error", topic=topic, error=str(e))
await asyncio.sleep(5) # Wait before retrying

212
libs/events/sqs_bus.py Normal file
View File

@@ -0,0 +1,212 @@
"""AWS SQS/SNS implementation of EventBus."""
import asyncio
import json
from collections.abc import Awaitable, Callable
from typing import Any
import boto3 # type: ignore
import structlog
from botocore.exceptions import ClientError # type: ignore
from .base import EventBus, EventPayload
logger = structlog.get_logger()
class SQSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
"""AWS SQS/SNS implementation of EventBus"""
def __init__(self, region_name: str = "us-east-1"):
self.region_name = region_name
self.sns_client: Any = None
self.sqs_client: Any = None
self.topic_arns: dict[str, str] = {}
self.queue_urls: dict[str, str] = {}
self.handlers: dict[
str, list[Callable[[str, EventPayload], Awaitable[None]]]
] = {}
self.running = False
self.consumer_tasks: list[asyncio.Task[None]] = []
async def start(self) -> None:
"""Start SQS/SNS clients"""
if self.running:
return
self.sns_client = boto3.client("sns", region_name=self.region_name)
self.sqs_client = boto3.client("sqs", region_name=self.region_name)
self.running = True
logger.info("SQS event bus started", region=self.region_name)
async def stop(self) -> None:
"""Stop SQS/SNS clients and consumers"""
if not self.running:
return
# Cancel consumer tasks
for task in self.consumer_tasks:
task.cancel()
if self.consumer_tasks:
await asyncio.gather(*self.consumer_tasks, return_exceptions=True)
self.running = False
logger.info("SQS event bus stopped")
async def publish(self, topic: str, payload: EventPayload) -> bool:
"""Publish event to SNS topic"""
if not self.sns_client:
raise RuntimeError("Event bus not started")
try:
# Ensure topic exists
topic_arn = await self._ensure_topic_exists(topic)
# Publish message
response = self.sns_client.publish(
TopicArn=topic_arn,
Message=payload.to_json(),
MessageAttributes={
"event_id": {"DataType": "String", "StringValue": payload.event_id},
"tenant_id": {
"DataType": "String",
"StringValue": payload.tenant_id,
},
"actor": {"DataType": "String", "StringValue": payload.actor},
},
)
logger.info(
"Event published",
topic=topic,
event_id=payload.event_id,
message_id=response["MessageId"],
)
return True
except ClientError as e:
logger.error(
"Failed to publish event",
topic=topic,
event_id=payload.event_id,
error=str(e),
)
return False
async def subscribe(
self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
) -> None:
"""Subscribe to SNS topic via SQS queue"""
if topic not in self.handlers:
self.handlers[topic] = []
self.handlers[topic].append(handler)
if topic not in self.queue_urls:
# Create SQS queue for this topic
queue_name = f"tax-agent-{topic}"
queue_url = await self._ensure_queue_exists(queue_name)
self.queue_urls[topic] = queue_url
# Subscribe queue to SNS topic
topic_arn = await self._ensure_topic_exists(topic)
await self._subscribe_queue_to_topic(queue_url, topic_arn)
# Start consumer task
task = asyncio.create_task(self._consume_messages(topic, queue_url))
self.consumer_tasks.append(task)
logger.info("Subscribed to topic", topic=topic, queue_name=queue_name)
async def _ensure_topic_exists(self, topic: str) -> str:
"""Ensure SNS topic exists and return ARN"""
if topic in self.topic_arns:
return self.topic_arns[topic]
try:
response = self.sns_client.create_topic(Name=topic)
topic_arn = response["TopicArn"]
self.topic_arns[topic] = topic_arn
return str(topic_arn)
except ClientError as e:
logger.error("Failed to create topic", topic=topic, error=str(e))
raise
async def _ensure_queue_exists(self, queue_name: str) -> str:
"""Ensure SQS queue exists and return URL"""
try:
response = self.sqs_client.create_queue(QueueName=queue_name)
return str(response["QueueUrl"])
except ClientError as e:
logger.error("Failed to create queue", queue_name=queue_name, error=str(e))
raise
async def _subscribe_queue_to_topic(self, queue_url: str, topic_arn: str) -> None:
"""Subscribe SQS queue to SNS topic"""
try:
# Get queue attributes
queue_attrs = self.sqs_client.get_queue_attributes(
QueueUrl=queue_url, AttributeNames=["QueueArn"]
)
queue_arn = queue_attrs["Attributes"]["QueueArn"]
# Subscribe queue to topic
self.sns_client.subscribe(
TopicArn=topic_arn, Protocol="sqs", Endpoint=queue_arn
)
except ClientError as e:
logger.error("Failed to subscribe queue to topic", error=str(e))
raise
async def _consume_messages(self, topic: str, queue_url: str) -> None:
"""Consume messages from SQS queue"""
# pylint: disable=too-many-nested-blocks
while self.running:
try:
response = self.sqs_client.receive_message(
QueueUrl=queue_url, MaxNumberOfMessages=10, WaitTimeSeconds=20
)
messages = response.get("Messages", [])
for message in messages:
try:
# Parse SNS message
sns_message = json.loads(message["Body"])
payload_dict = json.loads(sns_message["Message"])
payload = EventPayload(
data=payload_dict["data"],
actor=payload_dict["actor"],
tenant_id=payload_dict["tenant_id"],
trace_id=payload_dict.get("trace_id"),
schema_version=payload_dict.get("schema_version", "1.0"),
)
payload.event_id = payload_dict["event_id"]
payload.occurred_at = payload_dict["occurred_at"]
# Call all handlers for this topic
for handler in self.handlers.get(topic, []):
try:
await handler(topic, payload)
# pylint: disable=broad-exception-caught
except Exception as e:
logger.error(
"Handler failed",
topic=topic,
event_id=payload.event_id,
error=str(e),
)
# Delete message from queue
self.sqs_client.delete_message(
QueueUrl=queue_url, ReceiptHandle=message["ReceiptHandle"]
)
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Failed to process message", topic=topic, error=str(e)
)
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Consumer error", topic=topic, error=str(e))
await asyncio.sleep(5) # Wait before retrying

17
libs/events/topics.py Normal file
View File

@@ -0,0 +1,17 @@
"""Standard event topic names."""
class EventTopics: # pylint: disable=too-few-public-methods
"""Standard event topic names"""
DOC_INGESTED = "doc.ingested"
DOC_OCR_READY = "doc.ocr_ready"
DOC_EXTRACTED = "doc.extracted"
KG_UPSERTED = "kg.upserted"
RAG_INDEXED = "rag.indexed"
CALC_SCHEDULE_READY = "calc.schedule_ready"
FORM_FILLED = "form.filled"
HMRC_SUBMITTED = "hmrc.submitted"
REVIEW_REQUESTED = "review.requested"
REVIEW_COMPLETED = "review.completed"
FIRM_SYNC_COMPLETED = "firm.sync.completed"

10
libs/forms/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
"""PDF form filling and evidence pack generation."""
from .evidence_pack import UK_TAX_FORMS, EvidencePackGenerator
from .pdf_filler import PDFFormFiller
__all__ = [
"PDFFormFiller",
"EvidencePackGenerator",
"UK_TAX_FORMS",
]

185
libs/forms/evidence_pack.py Normal file
View File

@@ -0,0 +1,185 @@
"""Evidence pack generation with manifests and signatures."""
import io
from typing import Any
import structlog
logger = structlog.get_logger()
class EvidencePackGenerator: # pylint: disable=too-few-public-methods
"""Generate evidence packs with manifests and signatures"""
def __init__(self, storage_client: Any) -> None:
self.storage = storage_client
async def create_evidence_pack( # pylint: disable=too-many-locals
self,
taxpayer_id: str,
tax_year: str,
scope: str,
evidence_items: list[dict[str, Any]],
) -> dict[str, Any]:
"""Create evidence pack with manifest and signatures"""
# pylint: disable=import-outside-toplevel
import hashlib
import json
import zipfile
from datetime import datetime
try:
# Create ZIP buffer
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
manifest: dict[str, Any] = {
"taxpayer_id": taxpayer_id,
"tax_year": tax_year,
"scope": scope,
"created_at": datetime.utcnow().isoformat(),
"evidence_items": [],
"signatures": {},
}
# Add evidence files to ZIP
for item in evidence_items:
doc_id = item["doc_id"]
page = item.get("page")
bbox = item.get("bbox")
text_hash = item.get("text_hash")
# Get document content
doc_content = await self.storage.get_object(
bucket_name="raw-documents",
object_name=f"tenants/{taxpayer_id}/raw/{doc_id}.pdf",
)
if doc_content:
# Add to ZIP
zip_filename = f"documents/{doc_id}.pdf"
zip_file.writestr(zip_filename, doc_content)
# Calculate file hash
file_hash = hashlib.sha256(doc_content).hexdigest()
# Add to manifest
manifest["evidence_items"].append(
{
"doc_id": doc_id,
"filename": zip_filename,
"page": page,
"bbox": bbox,
"text_hash": text_hash,
"file_hash": file_hash,
"file_size": len(doc_content),
}
)
# Sign manifest
manifest_json = json.dumps(manifest, indent=2, sort_keys=True)
manifest_hash = hashlib.sha256(manifest_json.encode()).hexdigest()
manifest["signatures"]["manifest_hash"] = manifest_hash
manifest["signatures"]["algorithm"] = "SHA-256"
# Add manifest to ZIP
zip_file.writestr("manifest.json", json.dumps(manifest, indent=2))
# Get ZIP content
zip_content = zip_buffer.getvalue()
# Store evidence pack
pack_filename = f"evidence_pack_{taxpayer_id}_{tax_year}_{scope}.zip"
pack_key = f"tenants/{taxpayer_id}/evidence_packs/{pack_filename}"
success = await self.storage.put_object(
bucket_name="evidence-packs",
object_name=pack_key,
data=io.BytesIO(zip_content),
length=len(zip_content),
content_type="application/zip",
)
if success:
return {
"pack_filename": pack_filename,
"pack_key": pack_key,
"pack_size": len(zip_content),
"evidence_count": len(evidence_items),
"manifest_hash": manifest_hash,
"s3_url": f"s3://evidence-packs/{pack_key}",
}
raise RuntimeError("Failed to store evidence pack")
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Failed to create evidence pack", error=str(e))
raise
# Form configuration for UK tax forms
UK_TAX_FORMS = {
"SA100": {
"name": "Self Assessment Tax Return",
"template_path": "forms/templates/SA100.pdf",
"boxes": {
"1": {"description": "Your name", "type": "text"},
"2": {"description": "Your address", "type": "text"},
"3": {"description": "Your UTR", "type": "text"},
"4": {"description": "Your NI number", "type": "text"},
},
},
"SA103": {
"name": "Self-employment (full)",
"template_path": "forms/templates/SA103.pdf",
"boxes": {
"1": {"description": "Business name", "type": "text"},
"2": {"description": "Business description", "type": "text"},
"3": {"description": "Accounting period start", "type": "date"},
"4": {"description": "Accounting period end", "type": "date"},
"20": {"description": "Total turnover", "type": "currency"},
"31": {
"description": "Total allowable business expenses",
"type": "currency",
},
"32": {"description": "Net profit", "type": "currency"},
"33": {"description": "Balancing charges", "type": "currency"},
"34": {"description": "Goods/services for own use", "type": "currency"},
"35": {"description": "Total taxable profits", "type": "currency"},
},
},
"SA105": {
"name": "Property income",
"template_path": "forms/templates/SA105.pdf",
"boxes": {
"20": {"description": "Total rents and other income", "type": "currency"},
"29": {
"description": "Premiums for the grant of a lease",
"type": "currency",
},
"31": {
"description": "Rent, rates, insurance, ground rents etc",
"type": "currency",
},
"32": {"description": "Property management", "type": "currency"},
"33": {
"description": "Services provided, including wages",
"type": "currency",
},
"34": {
"description": "Repairs, maintenance and renewals",
"type": "currency",
},
"35": {
"description": "Finance costs, including interest",
"type": "currency",
},
"36": {"description": "Professional fees", "type": "currency"},
"37": {"description": "Costs of services provided", "type": "currency"},
"38": {
"description": "Other allowable property expenses",
"type": "currency",
},
},
},
}

246
libs/forms/pdf_filler.py Normal file
View File

@@ -0,0 +1,246 @@
"""PDF form filling using pdfrw with reportlab fallback."""
import io
from typing import Any
import structlog
logger = structlog.get_logger()
class PDFFormFiller:
"""PDF form filling using pdfrw with reportlab fallback"""
def __init__(self) -> None:
self.form_templates: dict[str, Any] = {}
def load_template(self, form_id: str, template_path: str) -> bool:
"""Load PDF form template"""
try:
# pylint: disable=import-outside-toplevel
from pdfrw import PdfReader # type: ignore
template = PdfReader(template_path)
if template is None:
logger.error(
"Failed to load PDF template", form_id=form_id, path=template_path
)
return False
self.form_templates[form_id] = {"template": template, "path": template_path}
logger.info("Loaded PDF template", form_id=form_id, path=template_path)
return True
except ImportError:
logger.error("pdfrw not available for PDF form filling")
return False
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Failed to load PDF template", form_id=form_id, error=str(e))
return False
def fill_form(
self,
form_id: str,
field_values: dict[str, str | int | float | bool],
output_path: str | None = None,
) -> bytes | None:
"""Fill PDF form with values"""
if form_id not in self.form_templates:
logger.error("Form template not loaded", form_id=form_id)
return None
try:
return self._fill_with_pdfrw(form_id, field_values, output_path)
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning(
"pdfrw filling failed, trying reportlab overlay", error=str(e)
)
return self._fill_with_overlay(form_id, field_values, output_path)
def _fill_with_pdfrw(
self,
form_id: str,
field_values: dict[str, Any],
output_path: str | None = None,
) -> bytes | None:
"""Fill form using pdfrw"""
# pylint: disable=import-outside-toplevel
        from pdfrw import PdfName, PdfReader, PdfWriter
template_info = self.form_templates[form_id]
template = PdfReader(template_info["path"])
# Get form fields
if template.Root.AcroForm is None: # pyright: ignore[reportOptionalMemberAccess, reportAttributeAccessIssue] # fmt: skip
logger.warning("PDF has no AcroForm fields", form_id=form_id)
return self._fill_with_overlay(form_id, field_values, output_path)
# Fill form fields
for field in template.Root.AcroForm.Fields: # pyright: ignore[reportOptionalMemberAccess, reportAttributeAccessIssue] # fmt: skip
field_name = field.T
if field_name and field_name[1:-1] in field_values: # Remove parentheses
field_value = field_values[field_name[1:-1]]
# Set field value
                if isinstance(field_value, bool):
                    # Checkbox field: pdfrw represents checkbox states as PdfName values
                    if field_value:
                        field.V = PdfName.Yes  # fmt: skip # pyright: ignore[reportAttributeAccessIssue]
                        field.AS = PdfName.Yes  # fmt: skip # pyright: ignore[reportAttributeAccessIssue]
                    else:
                        field.V = PdfName.Off  # fmt: skip # pyright: ignore[reportAttributeAccessIssue]
                        field.AS = PdfName.Off  # fmt: skip # pyright: ignore[reportAttributeAccessIssue]
else:
# Text field
field.V = str(field_value)
# Make field read-only
field.Ff = 1 # Read-only flag
# Flatten form (make fields non-editable)
if template.Root.AcroForm: # pyright: ignore[reportOptionalMemberAccess, reportAttributeAccessIssue] # fmt: skip
template.Root.AcroForm.NeedAppearances = True # pyright: ignore[reportOptionalMemberAccess, reportAttributeAccessIssue] # fmt: skip
        # Write to output: PdfWriter takes the destination and the filled
        # template trailer in a single write() call
        if output_path:
            PdfWriter().write(output_path, template)
            with open(output_path, "rb") as f:
                return f.read()
        # Write to an in-memory buffer
        output_buffer = io.BytesIO()
        PdfWriter().write(output_buffer, template)
        return output_buffer.getvalue()
def _fill_with_overlay( # pylint: disable=too-many-locals
self,
form_id: str,
field_values: dict[str, Any],
output_path: str | None = None,
) -> bytes | None:
"""Fill form using reportlab overlay method"""
try:
# pylint: disable=import-outside-toplevel
from PyPDF2 import PdfReader, PdfWriter
from reportlab.lib.pagesizes import A4
from reportlab.pdfgen import canvas
template_info = self.form_templates[form_id]
# Read original PDF
original_pdf = PdfReader(template_info["path"])
# Create overlay with form data
overlay_buffer = io.BytesIO()
overlay_canvas = canvas.Canvas(overlay_buffer, pagesize=A4)
# Get field positions (this would be configured per form)
field_positions = self._get_field_positions(form_id)
# Add text to overlay
for field_name, value in field_values.items():
if field_name in field_positions:
pos = field_positions[field_name]
overlay_canvas.drawString(pos["x"], pos["y"], str(value))
overlay_canvas.save()
overlay_buffer.seek(0)
# Read overlay PDF
overlay_pdf = PdfReader(overlay_buffer)
# Merge original and overlay
writer = PdfWriter()
for page_num, _ in enumerate(original_pdf.pages):
original_page = original_pdf.pages[page_num]
if page_num < len(overlay_pdf.pages):
overlay_page = overlay_pdf.pages[page_num]
original_page.merge_page(overlay_page)
writer.add_page(original_page)
# Write result
if output_path:
with open(output_path, "wb") as output_file:
writer.write(output_file)
with open(output_path, "rb") as f:
return f.read()
else:
output_buffer = io.BytesIO()
writer.write(output_buffer)
return output_buffer.getvalue()
except ImportError as e:
logger.error(
"Required libraries not available for overlay method", error=str(e)
)
return None
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Overlay filling failed", form_id=form_id, error=str(e))
return None
def _get_field_positions(self, form_id: str) -> dict[str, dict[str, float]]:
"""Get field positions for overlay method"""
# This would be configured per form type
# For now, return sample positions for SA103
if form_id == "SA103":
return {
"box_1": {"x": 100, "y": 750}, # Business name
"box_2": {"x": 100, "y": 720}, # Business description
"box_20": {"x": 400, "y": 600}, # Total turnover
"box_31": {"x": 400, "y": 570}, # Total expenses
"box_32": {"x": 400, "y": 540}, # Net profit
}
return {}
def get_form_fields(self, form_id: str) -> list[dict[str, Any]]:
"""Get list of available form fields"""
if form_id not in self.form_templates:
return []
try:
# pylint: disable=import-outside-toplevel
from pdfrw import PdfReader
template_info = self.form_templates[form_id]
template = PdfReader(template_info["path"])
if template.Root.AcroForm is None: # fmt: skip # pyright: ignore[reportOptionalMemberAccess, reportAttributeAccessIssue]
return []
fields = []
for field in template.Root.AcroForm.Fields: # fmt: skip # pyright: ignore[reportOptionalMemberAccess, reportAttributeAccessIssue]
field_info = {
"name": field.T[1:-1] if field.T else None, # Remove parentheses
"type": self._get_field_type(field),
"required": bool(field.Ff and int(field.Ff) & 2), # Required flag
"readonly": bool(field.Ff and int(field.Ff) & 1), # Read-only flag
}
if field.V:
field_info["default_value"] = str(field.V)
fields.append(field_info)
return fields
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Failed to get form fields", form_id=form_id, error=str(e))
return []
def _get_field_type(self, field: Any) -> str:
"""Determine field type from PDF field"""
if hasattr(field, "FT"):
field_type = str(field.FT)
if "Tx" in field_type:
return "text"
if "Btn" in field_type:
return "checkbox" if field.Ff and int(field.Ff) & 32768 else "button"
if "Ch" in field_type:
return "choice"
return "unknown"

140
libs/neo/__init__.py Normal file
View File

@@ -0,0 +1,140 @@
"""Knowledge graph helpers: Neo4j client, temporal queries, SHACL validation, and evidence lookups."""
from typing import TYPE_CHECKING, Any
import structlog
from .client import Neo4jClient
from .queries import TemporalQueries
from .validator import SHACLValidator
if TYPE_CHECKING:
from libs.schemas.coverage.evaluation import Citation, FoundEvidence
logger = structlog.get_logger()
async def kg_boxes_exist(client: Neo4jClient, box_ids: list[str]) -> dict[str, bool]:
"""Check if form boxes exist in the knowledge graph"""
query = """
UNWIND $box_ids AS bid
OPTIONAL MATCH (fb:FormBox {box_id: bid})
RETURN bid, fb IS NOT NULL AS exists
"""
try:
results = await client.run_query(query, {"box_ids": box_ids})
return {result["bid"]: result["exists"] for result in results}
except Exception as e:
logger.error("Failed to check box existence", box_ids=box_ids, error=str(e))
return dict.fromkeys(box_ids, False)
async def kg_find_evidence(
client: Neo4jClient,
taxpayer_id: str,
tax_year: str,
kinds: list[str],
min_ocr: float = 0.6,
date_window: int = 30,
) -> list["FoundEvidence"]:
"""Find evidence documents for taxpayer in tax year"""
query = """
MATCH (p:TaxpayerProfile {taxpayer_id: $tid})-[:OF_TAX_YEAR]->(y:TaxYear {label: $tax_year})
MATCH (ev:Evidence)-[:DERIVED_FROM]->(d:Document)
    WHERE ((ev)-[:SUPPORTS]->(p) OR (d)-[:BELONGS_TO]->(p))
      AND d.kind IN $kinds
      AND date(d.date) >= date(y.start_date) AND date(d.date) <= date(y.end_date)
      AND coalesce(ev.ocr_confidence, 0.0) >= $min_ocr
RETURN d.doc_id AS doc_id,
d.kind AS kind,
ev.page AS page,
ev.bbox AS bbox,
ev.ocr_confidence AS ocr_confidence,
ev.extract_confidence AS extract_confidence,
d.date AS date
ORDER BY ev.ocr_confidence DESC
LIMIT 100
"""
try:
results = await client.run_query(
query,
{
"tid": taxpayer_id,
"tax_year": tax_year,
"kinds": kinds,
"min_ocr": min_ocr,
},
)
# Convert to FoundEvidence format
from libs.schemas.coverage.evaluation import FoundEvidence
evidence_list = []
for result in results:
evidence = FoundEvidence(
doc_id=result["doc_id"],
kind=result["kind"],
pages=[result["page"]] if result["page"] else [],
bbox=result["bbox"],
ocr_confidence=result["ocr_confidence"] or 0.0,
extract_confidence=result["extract_confidence"] or 0.0,
date=result["date"],
)
evidence_list.append(evidence)
return evidence_list
except Exception as e:
logger.error(
"Failed to find evidence",
taxpayer_id=taxpayer_id,
tax_year=tax_year,
kinds=kinds,
error=str(e),
)
return []
async def kg_rule_citations(
client: Neo4jClient, schedule_id: str, box_ids: list[str]
) -> list["Citation"]:
"""Get rule citations for schedule and form boxes"""
query = """
MATCH (fb:FormBox)-[:GOVERNED_BY]->(r:Rule)-[:CITES]->(doc:Document)
WHERE fb.box_id IN $box_ids
RETURN r.rule_id AS rule_id,
doc.doc_id AS doc_id,
doc.locator AS locator
LIMIT 10
"""
try:
results = await client.run_query(query, {"box_ids": box_ids})
# Convert to Citation format
from libs.schemas.coverage.evaluation import Citation
citations = []
for result in results:
citation = Citation(
rule_id=result["rule_id"],
doc_id=result["doc_id"],
locator=result["locator"],
)
citations.append(citation)
return citations
except Exception as e:
logger.error(
"Failed to get rule citations",
schedule_id=schedule_id,
box_ids=box_ids,
error=str(e),
)
return []
__all__ = ["Neo4jClient", "TemporalQueries", "SHACLValidator"]

350
libs/neo/client.py Normal file
View File

@@ -0,0 +1,350 @@
"""Neo4j session helpers, Cypher runner with retry, SHACL validator invoker."""
import asyncio
from datetime import datetime
from typing import Any
import structlog
from neo4j import Transaction
from neo4j.exceptions import ServiceUnavailable, TransientError
logger = structlog.get_logger()
class Neo4jClient:
"""Neo4j client with session management and retry logic"""
def __init__(self, driver: Any) -> None:
self.driver = driver
async def __aenter__(self) -> "Neo4jClient":
"""Async context manager entry"""
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Async context manager exit"""
await self.close()
async def close(self) -> None:
"""Close the driver"""
await asyncio.get_event_loop().run_in_executor(None, self.driver.close)
async def run_query(
self,
query: str,
parameters: dict[str, Any] | None = None,
database: str = "neo4j",
max_retries: int = 3,
) -> list[dict[str, Any]]:
"""Run Cypher query with retry logic"""
def _run_query() -> list[dict[str, Any]]:
with self.driver.session(database=database) as session:
result = session.run(query, parameters or {})
return [record.data() for record in result]
for attempt in range(max_retries):
try:
return await asyncio.get_event_loop().run_in_executor(None, _run_query)
except (TransientError, ServiceUnavailable) as e:
if attempt == max_retries - 1:
logger.error(
"Query failed after retries",
query=query[:100],
attempt=attempt + 1,
error=str(e),
)
raise
wait_time = 2**attempt # Exponential backoff
logger.warning(
"Query failed, retrying",
query=query[:100],
attempt=attempt + 1,
wait_time=wait_time,
error=str(e),
)
await asyncio.sleep(wait_time)
except Exception as e:
logger.error(
"Query failed with non-retryable error",
query=query[:100],
error=str(e),
)
raise
# This should never be reached due to the raise statements above
return []
async def run_transaction(
self, transaction_func: Any, database: str = "neo4j", max_retries: int = 3
) -> Any:
"""Run transaction with retry logic"""
def _run_transaction() -> Any:
with self.driver.session(database=database) as session:
return session.execute_write(transaction_func)
for attempt in range(max_retries):
try:
return await asyncio.get_event_loop().run_in_executor(
None, _run_transaction
)
except (TransientError, ServiceUnavailable) as e:
if attempt == max_retries - 1:
logger.error(
"Transaction failed after retries",
attempt=attempt + 1,
error=str(e),
)
raise
wait_time = 2**attempt
logger.warning(
"Transaction failed, retrying",
attempt=attempt + 1,
wait_time=wait_time,
error=str(e),
)
await asyncio.sleep(wait_time)
except Exception as e:
logger.error(
"Transaction failed with non-retryable error", error=str(e)
)
raise
async def create_node(
self, label: str, properties: dict[str, Any], database: str = "neo4j"
) -> dict[str, Any]:
"""Create a node with temporal properties"""
# Add temporal properties if not present
if "asserted_at" not in properties:
properties["asserted_at"] = datetime.utcnow()
query = f"""
CREATE (n:{label} $properties)
RETURN n
"""
result = await self.run_query(query, {"properties": properties}, database)
node = result[0]["n"] if result else {}
# Return node ID if available, otherwise return the full node
return node.get("id", node)
async def update_node(
self,
label: str,
node_id: str,
properties: dict[str, Any],
database: str = "neo4j",
) -> dict[str, Any]:
"""Update node with bitemporal versioning"""
def _update_transaction(tx: Transaction) -> Any:
# First, retract the current version
retract_query = f"""
MATCH (n:{label} {{id: $node_id}})
WHERE n.retracted_at IS NULL
SET n.retracted_at = datetime()
RETURN n
"""
tx.run(retract_query, {"node_id": node_id}) # fmt: skip # pyright: ignore[reportArgumentType]
# Create new version
new_properties = properties.copy()
new_properties["id"] = node_id
new_properties["asserted_at"] = datetime.utcnow()
create_query = f"""
CREATE (n:{label} $properties)
RETURN n
"""
result = tx.run(create_query, {"properties": new_properties}) # fmt: skip # pyright: ignore[reportArgumentType]
record = result.single()
return record["n"] if record else None
result = await self.run_transaction(_update_transaction, database)
return result if isinstance(result, dict) else {}
async def create_relationship( # pylint: disable=too-many-arguments,too-many-positional-arguments
self,
from_label: str | None = None,
from_id: str | None = None,
to_label: str | None = None,
to_id: str | None = None,
relationship_type: str | None = None,
properties: dict[str, Any] | None = None,
database: str = "neo4j",
# Alternative signature for tests
from_node_id: int | None = None,
to_node_id: int | None = None,
) -> dict[str, Any]:
"""Create relationship between nodes"""
# Handle alternative signature for tests (using node IDs)
if from_node_id is not None and to_node_id is not None:
rel_properties = properties or {}
if "asserted_at" not in rel_properties:
rel_properties["asserted_at"] = datetime.utcnow()
query = f"""
MATCH (from) WHERE id(from) = $from_id
MATCH (to) WHERE id(to) = $to_id
CREATE (from)-[r:{relationship_type} $properties]->(to)
RETURN r
"""
result = await self.run_query(
query,
{
"from_id": from_node_id,
"to_id": to_node_id,
"properties": rel_properties,
},
database,
)
rel = result[0]["r"] if result else {}
return rel.get("id", rel)
# Original signature (using labels and IDs)
rel_properties = properties or {}
if "asserted_at" not in rel_properties:
rel_properties["asserted_at"] = datetime.utcnow()
query = f"""
MATCH (from:{from_label} {{id: $from_id}})
MATCH (to:{to_label} {{id: $to_id}})
WHERE from.retracted_at IS NULL AND to.retracted_at IS NULL
CREATE (from)-[r:{relationship_type} $properties]->(to)
RETURN r
"""
result = await self.run_query(
query,
{"from_id": from_id, "to_id": to_id, "properties": rel_properties},
database,
)
rel = result[0]["r"] if result else {}
# Return relationship ID if available, otherwise return the full relationship
return rel.get("id", rel)
async def get_node_lineage(
self, node_id: str, max_depth: int = 10, database: str = "neo4j"
) -> list[dict[str, Any]]:
"""Get complete lineage for a node"""
# Cypher cannot parameterise variable-length bounds, so max_depth is
# interpolated into the pattern rather than passed as a query parameter.
query = f"""
MATCH path = (n {{id: $node_id}})-[:DERIVED_FROM*1..{max_depth}]->(evidence:Evidence)
WHERE n.retracted_at IS NULL
RETURN path, evidence
ORDER BY length(path) DESC
LIMIT 100
"""
return await self.run_query(query, {"node_id": node_id}, database)
async def export_to_rdf( # pylint: disable=redefined-builtin
self,
format: str = "turtle",
database: str = "neo4j",
) -> dict[str, Any]:
"""Export graph data to RDF format"""
query = """
CALL n10s.rdf.export.cypher(
'MATCH (n) WHERE n.retracted_at IS NULL RETURN n',
$format,
{}
) YIELD triplesCount, format
RETURN triplesCount, format
"""
try:
result = await self.run_query(query, {"format": format}, database)
return result[0] if result else {}
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("RDF export failed, using fallback", error=str(e))
fallback_result = await self._export_rdf_fallback(database)
return {"rdf_data": fallback_result, "format": format}
async def _export_rdf_fallback(self, database: str = "neo4j") -> str:
"""Fallback RDF export without n10s plugin"""
# Get all nodes and relationships
nodes_query = """
MATCH (n) WHERE n.retracted_at IS NULL
RETURN labels(n) as labels, properties(n) as props, id(n) as neo_id
"""
rels_query = """
MATCH (a)-[r]->(b)
WHERE a.retracted_at IS NULL AND b.retracted_at IS NULL
RETURN type(r) as type, properties(r) as props,
id(a) as from_id, id(b) as to_id
"""
nodes = await self.run_query(nodes_query, database=database)
relationships = await self.run_query(rels_query, database=database)
# Convert to simple Turtle format
rdf_lines = ["@prefix tax: <https://tax-kg.example.com/> ."]
for node in nodes:
node_uri = f"tax:node_{node['neo_id']}"
for label in node["labels"]:
rdf_lines.append(f"{node_uri} a tax:{label} .")
for prop, value in node["props"].items():
if isinstance(value, str):
rdf_lines.append(f'{node_uri} tax:{prop} "{value}" .')
else:
rdf_lines.append(f"{node_uri} tax:{prop} {value} .")
for rel in relationships:
from_uri = f"tax:node_{rel['from_id']}"
to_uri = f"tax:node_{rel['to_id']}"
rdf_lines.append(f"{from_uri} tax:{rel['type']} {to_uri} .")
return "\n".join(rdf_lines)
async def find_nodes(
self, label: str, properties: dict[str, Any], database: str = "neo4j"
) -> list[dict[str, Any]]:
"""Find nodes matching label and properties"""
where_clause, params = self._build_properties_clause(properties)
query = f"MATCH (n:{label}) WHERE {where_clause} RETURN n"
result = await self.run_query(query, params, database)
return [record["n"] for record in result]
async def execute_query(
self,
query: str,
parameters: dict[str, Any] | None = None,
database: str = "neo4j",
) -> list[dict[str, Any]]:
"""Execute a custom Cypher query"""
return await self.run_query(query, parameters, database)
def _build_properties_clause(
self, properties: dict[str, Any]
) -> tuple[str, dict[str, Any]]:
"""Build WHERE clause and parameters for properties"""
if not properties:
return "true", {}
clauses = []
params = {}
for i, (key, value) in enumerate(properties.items()):
param_name = f"prop_{i}"
clauses.append(f"n.{key} = ${param_name}")
params[param_name] = value
return " AND ".join(clauses), params

78
libs/neo/queries.py Normal file
View File

@@ -0,0 +1,78 @@
"""Neo4j Cypher queries for coverage policy system"""
from datetime import datetime
from typing import Any
import structlog
logger = structlog.get_logger()
class TemporalQueries:
"""Helper class for temporal queries"""
@staticmethod
def get_current_state_query(
label: str, filters: dict[str, Any] | None = None
) -> str:
"""Get query for current state of nodes"""
where_clause = "n.retracted_at IS NULL"
if filters:
filter_conditions = []
for key, value in filters.items():
if isinstance(value, str):
filter_conditions.append(f"n.{key} = '{value}'")
else:
filter_conditions.append(f"n.{key} = {value}")
if filter_conditions:
where_clause += " AND " + " AND ".join(filter_conditions)
return f"""
MATCH (n:{label})
WHERE {where_clause}
RETURN n
ORDER BY n.asserted_at DESC
"""
@staticmethod
def get_historical_state_query(
label: str, as_of_time: datetime, filters: dict[str, Any] | None = None
) -> str:
"""Get query for historical state at specific time"""
where_clause = f"""
n.asserted_at <= datetime('{as_of_time.isoformat()}')
AND (n.retracted_at IS NULL OR n.retracted_at > datetime('{as_of_time.isoformat()}'))
"""
if filters:
filter_conditions = []
for key, value in filters.items():
if isinstance(value, str):
filter_conditions.append(f"n.{key} = '{value}'")
else:
filter_conditions.append(f"n.{key} = {value}")
if filter_conditions:
where_clause += " AND " + " AND ".join(filter_conditions)
return f"""
MATCH (n:{label})
WHERE {where_clause}
RETURN n
ORDER BY n.asserted_at DESC
"""
@staticmethod
def get_audit_trail_query(node_id: str) -> str:
"""Get complete audit trail for a node"""
return f"""
MATCH (n {{id: '{node_id}'}})
RETURN n.asserted_at as asserted_at,
n.retracted_at as retracted_at,
n.source as source,
n.extractor_version as extractor_version,
properties(n) as properties
ORDER BY n.asserted_at ASC
"""

70
libs/neo/validator.py Normal file
View File

@@ -0,0 +1,70 @@
"""SHACL validation using pySHACL"""
import asyncio
from typing import Any
import structlog
logger = structlog.get_logger()
# pyright: ignore[reportAttributeAccessIssue]
class SHACLValidator: # pylint: disable=too-few-public-methods
"""SHACL validation using pySHACL"""
def __init__(self, shapes_file: str) -> None:
self.shapes_file = shapes_file
async def validate_graph(self, rdf_data: str) -> dict[str, Any]:
"""Validate RDF data against SHACL shapes"""
def _validate() -> dict[str, Any]:
try:
# pylint: disable=import-outside-toplevel
from pyshacl import validate
from rdflib import Graph
# Load data graph
data_graph = Graph()
data_graph.parse(data=rdf_data, format="turtle")
# Load shapes graph
shapes_graph = Graph()
shapes_graph.parse(self.shapes_file, format="turtle")
# Run validation
conforms, results_graph, results_text = validate(
data_graph=data_graph,
shacl_graph=shapes_graph,
inference="rdfs",
abort_on_first=False,
allow_infos=True,
allow_warnings=True,
)
return {
"conforms": conforms,
"results_text": results_text,
"violations_count": len(
list(
results_graph.subjects() # pyright: ignore[reportAttributeAccessIssue]
) # fmt: skip # pyright: ignore[reportAttributeAccessIssue]
),
}
except ImportError:
logger.warning("pySHACL not available, skipping validation")
return {
"conforms": True,
"results_text": "SHACL validation skipped (pySHACL not installed)",
"violations_count": 0,
}
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("SHACL validation failed", error=str(e))
return {
"conforms": False,
"results_text": f"Validation error: {str(e)}",
"violations_count": -1,
}
return await asyncio.get_event_loop().run_in_executor(None, _validate)
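A minimal driver for the validator above; the shapes path and the inline Turtle snippet are placeholders. When pySHACL is not installed the call returns a "skipped" result rather than raising.
import asyncio

from libs.neo.validator import SHACLValidator

validator = SHACLValidator(shapes_file="config/shapes.ttl")  # placeholder path
rdf_data = "@prefix tax: <https://tax-kg.example.com/> . tax:node_1 a tax:Evidence ."
report = asyncio.run(validator.validate_graph(rdf_data))
print(report["conforms"], report["violations_count"])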

View File

@@ -0,0 +1,18 @@
"""Observability setup with OpenTelemetry, Prometheus, and structured logging."""
from .logging import configure_logging
from .opentelemetry_setup import init_opentelemetry
from .prometheus import BusinessMetrics, get_business_metrics, init_prometheus_metrics
from .setup import setup_observability
from .utils import get_metrics, get_tracer
__all__ = [
"configure_logging",
"init_opentelemetry",
"init_prometheus_metrics",
"BusinessMetrics",
"get_business_metrics",
"setup_observability",
"get_tracer",
"get_metrics",
]

View File

@@ -0,0 +1,75 @@
"""Structured logging configuration with OpenTelemetry integration."""
import logging
import sys
import time
from typing import Any
import structlog
from opentelemetry import trace
def configure_logging(service_name: str, log_level: str = "INFO") -> None:
"""Configure structured logging with structlog"""
def add_service_name( # pylint: disable=unused-argument
logger: Any,
method_name: str,
event_dict: dict[str, Any], # noqa: ARG001
) -> dict[str, Any]:
event_dict["service"] = service_name
return event_dict
def add_trace_id( # pylint: disable=unused-argument
logger: Any,
method_name: str,
event_dict: dict[str, Any], # noqa: ARG001
) -> dict[str, Any]:
"""Add trace ID to log entries"""
span = trace.get_current_span()
if span and span.get_span_context().is_valid:
event_dict["trace_id"] = format(span.get_span_context().trace_id, "032x")
event_dict["span_id"] = format(span.get_span_context().span_id, "016x")
return event_dict
def add_timestamp( # pylint: disable=unused-argument
logger: Any,
method_name: str,
event_dict: dict[str, Any], # noqa: ARG001
) -> dict[str, Any]:
event_dict["timestamp"] = time.time()
return event_dict
# Configure structlog
structlog.configure(
processors=[
structlog.stdlib.filter_by_level,
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
add_service_name, # type: ignore
add_trace_id, # type: ignore
add_timestamp, # type: ignore
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
structlog.processors.JSONRenderer(),
],
context_class=dict,
logger_factory=structlog.stdlib.LoggerFactory(),
wrapper_class=structlog.stdlib.BoundLogger,
cache_logger_on_first_use=True,
)
# Configure standard library logging
logging.basicConfig(
format="%(message)s",
stream=sys.stdout,
level=getattr(logging, log_level.upper()),
)
# Reduce noise from some libraries
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)

View File

@@ -0,0 +1,99 @@
"""OpenTelemetry tracing and metrics initialization."""
import os
from typing import Any
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
from opentelemetry.instrumentation.psycopg2 import Psycopg2Instrumentor
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import (
MetricExporter,
PeriodicExportingMetricReader,
)
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter
def init_opentelemetry(
service_name: str,
service_version: str = "1.0.0",
otlp_endpoint: str | None = None,
) -> tuple[Any, Any]:
"""Initialize OpenTelemetry tracing and metrics"""
# Create resource
resource = Resource.create(
{
"service.name": service_name,
"service.version": service_version,
"service.instance.id": os.getenv("HOSTNAME", "unknown"),
}
)
# Configure tracing
span_exporter: SpanExporter
if otlp_endpoint:
span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint)
span_processor = BatchSpanProcessor(span_exporter)
else:
# Use console exporter for development
# pylint: disable=import-outside-toplevel
from opentelemetry.sdk.trace.export import ConsoleSpanExporter
span_exporter = ConsoleSpanExporter()
span_processor = BatchSpanProcessor(span_exporter)
tracer_provider = TracerProvider(resource=resource)
tracer_provider.add_span_processor(span_processor)
trace.set_tracer_provider(tracer_provider)
# Configure metrics
metric_exporter: MetricExporter
if otlp_endpoint:
metric_exporter = OTLPMetricExporter(endpoint=otlp_endpoint)
metric_reader = PeriodicExportingMetricReader(
metric_exporter, export_interval_millis=30000
)
else:
# Use console exporter for development
# pylint: disable=import-outside-toplevel
from opentelemetry.sdk.metrics.export import ConsoleMetricExporter
metric_exporter = ConsoleMetricExporter()
metric_reader = PeriodicExportingMetricReader(
metric_exporter, export_interval_millis=30000
)
meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])
metrics.set_meter_provider(meter_provider)
# Auto-instrument common libraries
try:
FastAPIInstrumentor().instrument()
HTTPXClientInstrumentor().instrument()
Psycopg2Instrumentor().instrument()
RedisInstrumentor().instrument()
except Exception: # pylint: disable=broad-exception-caught
# Ignore instrumentation errors in tests
pass
return trace.get_tracer(service_name), metrics.get_meter(service_name)
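A minimal sketch of the initialiser above without an OTLP endpoint, so spans and metrics fall back to the console exporters; the service name, span name, and attributes are placeholders.
from libs.observability import init_opentelemetry

tracer, meter = init_opentelemetry("svc-example", service_version="0.1.0")

with tracer.start_as_current_span("extract-document") as span:
    span.set_attribute("document.pages", 4)

documents_extracted = meter.create_counter("documents_extracted_total")
documents_extracted.add(1, {"document_type": "P60"})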

View File

@@ -0,0 +1,235 @@
"""Prometheus metrics setup and business metrics."""
from typing import Any
from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Info
from prometheus_fastapi_instrumentator import Instrumentator
def init_prometheus_metrics( # pylint: disable=unused-argument
app: Any, service_name: str
) -> Any:
"""Initialize Prometheus metrics for FastAPI app"""
# Create instrumentator
instrumentator = Instrumentator(
should_group_status_codes=False,
should_ignore_untemplated=True,
should_respect_env_var=True,
should_instrument_requests_inprogress=True,
excluded_handlers=["/metrics", "/healthz", "/readyz", "/livez"],
env_var_name="ENABLE_METRICS",
inprogress_name="http_requests_inprogress",
inprogress_labels=True,
)
# Add custom metrics: the counters are created once here so that repeated
# requests do not try to re-register the same timeseries in the registry.
fast_requests = Counter(
"http_requests_fast_total",
"Number of fast HTTP requests (< 100ms)",
["method", "endpoint"],
)
slow_requests = Counter(
"http_requests_slow_total",
"Number of slow HTTP requests (> 1s)",
["method", "endpoint"],
)
def track_fast(info: Any) -> None:
if info.modified_duration is not None and info.modified_duration < 0.1:
fast_requests.labels(method=info.method, endpoint=info.modified_handler).inc()
def track_slow(info: Any) -> None:
if info.modified_duration is not None and info.modified_duration > 1.0:
slow_requests.labels(method=info.method, endpoint=info.modified_handler).inc()
instrumentator.add(track_fast)
instrumentator.add(track_slow)
# Instrument the app
instrumentator.instrument(app)
instrumentator.expose(app, endpoint="/metrics")
return instrumentator
# Global registry for business metrics to avoid duplicates
_business_metrics_registry: dict[str, Any] = {}
# Custom metrics for business logic
class BusinessMetrics: # pylint: disable=too-many-instance-attributes
"""Custom business metrics for the application"""
def __init__(self, service_name: str):
self.service_name = service_name
# Sanitize service name for Prometheus metrics (replace hyphens with underscores)
self.sanitized_name = service_name.replace("-", "_")
# Create a custom registry for this service to avoid conflicts
self.registry = CollectorRegistry()
# Document processing metrics
self.documents_processed = Counter(
"documents_processed_total",
"Total number of documents processed",
["service", "document_type", "status"],
registry=self.registry,
)
# Add active connections metric for tests
self.active_connections = Gauge(
"active_connections",
"Number of active connections",
["service"],
registry=self.registry,
)
# Dynamic counters for forms service
self._dynamic_counters: dict[str, Any] = {}
self.document_processing_duration = Histogram(
f"document_processing_duration_seconds_{self.sanitized_name}",
"Time spent processing documents",
["service", "document_type"],
buckets=[0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0, 120.0, 300.0],
registry=self.registry,
)
# Field extraction metrics
self.field_extractions = Counter(
f"field_extractions_total_{self.sanitized_name}",
"Total number of field extractions",
["service", "field_type", "status"],
registry=self.registry,
)
self.extraction_confidence = Histogram(
f"extraction_confidence_score_{self.sanitized_name}",
"Confidence scores for extractions",
["service", "extraction_type"],
buckets=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
registry=self.registry,
)
# Tax calculation metrics
self.tax_calculations = Counter(
f"tax_calculations_total_{self.sanitized_name}",
"Total number of tax calculations",
["service", "calculation_type", "status"],
registry=self.registry,
)
self.calculation_confidence = Histogram(
f"calculation_confidence_score_{self.sanitized_name}",
"Confidence scores for tax calculations",
["service", "calculation_type"],
buckets=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
registry=self.registry,
)
# RAG metrics
self.rag_searches = Counter(
f"rag_searches_total_{self.sanitized_name}",
"Total number of RAG searches",
["service", "collection", "status"],
registry=self.registry,
)
self.rag_search_duration = Histogram(
f"rag_search_duration_seconds_{self.sanitized_name}",
"Time spent on RAG searches",
["service", "collection"],
buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0],
registry=self.registry,
)
self.rag_relevance_score = Histogram(
f"rag_relevance_score_{self.sanitized_name}",
"RAG search relevance scores",
["service", "collection"],
buckets=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
registry=self.registry,
)
# Knowledge graph metrics
self.kg_operations = Counter(
f"kg_operations_total_{self.sanitized_name}",
"Total number of KG operations",
["service", "operation", "status"],
registry=self.registry,
)
self.kg_query_duration = Histogram(
f"kg_query_duration_seconds_{self.sanitized_name}",
"Time spent on KG queries",
["service", "query_type"],
buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0],
registry=self.registry,
)
# HMRC submission metrics
self.hmrc_submissions = Counter(
f"hmrc_submissions_total_{self.sanitized_name}",
"Total number of HMRC submissions",
["service", "submission_type", "status"],
registry=self.registry,
)
# Service health metrics
self.service_info = Info(
f"service_info_{self.sanitized_name}",
"Service information",
registry=self.registry,
)
try:
self.service_info.info({"service": service_name, "version": "1.0.0"})
except (AttributeError, ValueError):
# Handle prometheus_client version compatibility or registry conflicts
pass
def counter(self, name: str, labelnames: list[str] | None = None) -> Any:
"""Get or create a counter metric with dynamic labels"""
# Use provided labelnames or default ones
if labelnames is None:
labelnames = ["tenant_id", "form_id", "scope", "error_type"]
# Create a unique key based on name and labelnames
label_key = f"{name}_{','.join(sorted(labelnames))}"
if label_key not in self._dynamic_counters:
self._dynamic_counters[label_key] = Counter(
name,
f"Dynamic counter: {name}",
labelnames=labelnames,
registry=self.registry,
)
return self._dynamic_counters[label_key]
def histogram(self, name: str, labelnames: list[str] | None = None) -> Any:
"""Get or create a histogram metric with dynamic labels"""
# Use provided labelnames or default ones
if labelnames is None:
labelnames = ["tenant_id", "kind"]
# Create a unique key based on name and labelnames
label_key = f"{name}_{','.join(sorted(labelnames))}"
histogram_key = f"_histogram_{label_key}"
if not hasattr(self, histogram_key):
histogram = Histogram(
name,
f"Dynamic histogram: {name}",
labelnames=labelnames,
registry=self.registry,
)
setattr(self, histogram_key, histogram)
return getattr(self, histogram_key)
def get_business_metrics(service_name: str) -> BusinessMetrics:
"""Get business metrics instance for service"""
# Use singleton pattern to avoid registry conflicts
if service_name not in _business_metrics_registry:
_business_metrics_registry[service_name] = BusinessMetrics(service_name)
return _business_metrics_registry[service_name] # type: ignore
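A minimal sketch of the singleton accessor above; the label values are illustrative and must match the labelnames declared on each metric.
from libs.observability import get_business_metrics

metrics = get_business_metrics("svc-example")
metrics.documents_processed.labels(
    service="svc-example", document_type="P60", status="success"
).inc()
metrics.extraction_confidence.labels(
    service="svc-example", extraction_type="field"
).observe(0.92)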

View File

@@ -0,0 +1,64 @@
"""Complete observability setup orchestration."""
from typing import Any
from .logging import configure_logging
from .opentelemetry_setup import init_opentelemetry
from .prometheus import get_business_metrics, init_prometheus_metrics
def setup_observability(
settings_or_app: Any,
service_name: str | None = None,
service_version: str = "1.0.0",
log_level: str = "INFO",
otlp_endpoint: str | None = None,
) -> dict[str, Any]:
"""Setup complete observability stack for a service"""
# Handle both settings object and individual parameters
if hasattr(settings_or_app, "service_name"):
# Called with settings object
settings = settings_or_app
service_name = settings.service_name
service_version = getattr(settings, "service_version", "1.0.0")
log_level = getattr(settings, "log_level", "INFO")
otlp_endpoint = getattr(settings, "otel_exporter_endpoint", None)
app = None
else:
# Called with app object
app = settings_or_app
if not service_name:
raise ValueError("service_name is required when passing app object")
# Configure logging
configure_logging(service_name or "unknown", log_level)
# Initialize OpenTelemetry
tracer, meter = init_opentelemetry(
service_name or "unknown", service_version, otlp_endpoint
)
# Get business metrics
business_metrics = get_business_metrics(service_name or "unknown")
# If app is provided, set up Prometheus and add to app state
if app:
# Initialize Prometheus metrics
instrumentator = init_prometheus_metrics(app, service_name or "unknown")
# Add to app state
app.state.tracer = tracer
app.state.meter = meter
app.state.metrics = business_metrics
app.state.instrumentator = instrumentator
return {
"tracer": tracer,
"meter": meter,
"metrics": business_metrics,
"instrumentator": instrumentator,
}
# Just return the observability components
return {"tracer": tracer, "meter": meter, "metrics": business_metrics}

View File

@@ -0,0 +1,17 @@
"""Utility functions for observability components."""
from typing import Any
from opentelemetry import trace
from .prometheus import BusinessMetrics, get_business_metrics
def get_tracer(service_name: str = "default") -> Any:
"""Get OpenTelemetry tracer"""
return trace.get_tracer(service_name)
def get_metrics(service_name: str = "default") -> BusinessMetrics:
"""Get business metrics instance"""
return get_business_metrics(service_name)

21
libs/policy/__init__.py Normal file
View File

@@ -0,0 +1,21 @@
"""Coverage policy loading and management with overlays and hot reload."""
from .loader import PolicyLoader
from .utils import (
apply_feature_flags,
compile_predicates,
get_policy_loader,
load_policy,
merge_overlays,
validate_policy,
)
__all__ = [
"PolicyLoader",
"get_policy_loader",
"load_policy",
"merge_overlays",
"apply_feature_flags",
"compile_predicates",
"validate_policy",
]

386
libs/policy/loader.py Normal file
View File

@@ -0,0 +1,386 @@
"""Policy loading and management with overlays and hot reload."""
import hashlib
import json
import re
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
from typing import Any
import structlog
import yaml
from jsonschema import ValidationError, validate
from ..schemas import (
CompiledCoveragePolicy,
CoveragePolicy,
PolicyError,
ValidationResult,
)
logger = structlog.get_logger()
class PolicyLoader:
"""Loads and manages coverage policies with overlays and hot reload"""
def __init__(self, config_dir: str = "config"):
self.config_dir = Path(config_dir)
self.schema_path = Path(__file__).parent.parent / "coverage_schema.json"
self._schema_cache: dict[str, Any] | None = None
def load_policy(
self,
baseline_path: str | None = None,
jurisdiction: str = "UK",
tax_year: str = "2024-25",
tenant_id: str | None = None,
) -> CoveragePolicy:
"""Load policy with overlays applied"""
# Default baseline path
if baseline_path is None:
baseline_path = str(self.config_dir / "coverage.yaml")
# Load baseline policy
baseline = self._load_yaml_file(baseline_path)
# Collect overlay files
overlay_files = []
# Jurisdiction-specific overlay
jurisdiction_file = self.config_dir / f"coverage.{jurisdiction}.{tax_year}.yaml"
if jurisdiction_file.exists():
overlay_files.append(str(jurisdiction_file))
# Tenant-specific overlay
if tenant_id:
tenant_file = self.config_dir / "overrides" / f"{tenant_id}.yaml"
if tenant_file.exists():
overlay_files.append(str(tenant_file))
# Load overlays
overlays = [self._load_yaml_file(path) for path in overlay_files]
# Merge all policies
merged = self.merge_overlays(baseline, *overlays)
# Apply feature flags if available
merged = self.apply_feature_flags(merged)
# Validate against schema
self._validate_policy(merged)
# Convert to Pydantic model
try:
policy = CoveragePolicy(**merged)
logger.info(
"Policy loaded successfully",
jurisdiction=jurisdiction,
tax_year=tax_year,
tenant_id=tenant_id,
overlays=len(overlays),
)
return policy
except Exception as e:
raise PolicyError(f"Failed to parse policy: {str(e)}") from e
def merge_overlays(
self, base: dict[str, Any], *overlays: dict[str, Any]
) -> dict[str, Any]:
"""Merge base policy with overlays using deep merge"""
result = base.copy()
for overlay in overlays:
result = self._deep_merge(result, overlay)
return result
def apply_feature_flags(self, policy: dict[str, Any]) -> dict[str, Any]:
"""Apply feature flags to policy (placeholder for Unleash integration)"""
# TODO: Integrate with Unleash feature flags
# For now, just return the policy unchanged
logger.debug("Feature flags not implemented, returning policy unchanged")
return policy
def compile_predicates(self, policy: CoveragePolicy) -> CompiledCoveragePolicy:
"""Compile condition strings into callable predicates"""
compiled_predicates: dict[str, Callable[[str, str], bool]] = {}
# Compile trigger conditions
for schedule_id, trigger in policy.triggers.items():
for condition in trigger.any_of + trigger.all_of:
if condition not in compiled_predicates:
compiled_predicates[condition] = self._compile_condition(condition)
# Compile evidence conditions
for schedule in policy.schedules.values():
for evidence in schedule.evidence:
if evidence.condition and evidence.condition not in compiled_predicates:
compiled_predicates[evidence.condition] = self._compile_condition(
evidence.condition
)
# Calculate hash of source files
source_files = [str(self.config_dir / "coverage.yaml")]
policy_hash = self._calculate_hash(source_files)
return CompiledCoveragePolicy(
policy=policy,
compiled_predicates=compiled_predicates,
compiled_at=datetime.utcnow(),
hash=policy_hash,
source_files=source_files,
)
def validate_policy(self, policy_dict: dict[str, Any]) -> ValidationResult:
"""Validate policy against schema and business rules"""
errors = []
warnings = []
try:
# JSON Schema validation
self._validate_policy(policy_dict)
# Business rule validation
business_errors, business_warnings = self._validate_business_rules(
policy_dict
)
errors.extend(business_errors)
warnings.extend(business_warnings)
except ValidationError as e:
errors.append(f"Schema validation failed: {e.message}")
except Exception as e:
errors.append(f"Validation error: {str(e)}")
return ValidationResult(ok=len(errors) == 0, errors=errors, warnings=warnings)
def _load_yaml_file(self, path: str) -> dict[str, Any]:
"""Load YAML file with error handling"""
try:
with open(path, encoding="utf-8") as f:
return yaml.safe_load(f) or {}
except FileNotFoundError as e:
raise PolicyError(f"Policy file not found: {path}") from e
except yaml.YAMLError as e:
raise PolicyError(f"Invalid YAML in {path}: {str(e)}") from e
def _deep_merge(
self, base: dict[str, Any], overlay: dict[str, Any]
) -> dict[str, Any]:
"""Deep merge two dictionaries"""
result = base.copy()
for key, value in overlay.items():
if (
key in result
and isinstance(result[key], dict)
and isinstance(value, dict)
):
result[key] = self._deep_merge(result[key], value)
else:
result[key] = value
return result
def _validate_policy(self, policy_dict: dict[str, Any]) -> None:
"""Validate policy against JSON schema"""
if self._schema_cache is None:
with open(self.schema_path, encoding="utf-8") as f:
self._schema_cache = json.load(f)
validate(instance=policy_dict, schema=self._schema_cache) # fmt: skip # pyright: ignore[reportArgumentType]
def _validate_business_rules(
self, policy_dict: dict[str, Any]
) -> tuple[list[str], list[str]]:
"""Validate business rules beyond schema"""
errors = []
warnings = []
# Check that all evidence IDs are in document_kinds
document_kinds = set(policy_dict.get("document_kinds", []))
for schedule_id, schedule in policy_dict.get("schedules", {}).items():
for evidence in schedule.get("evidence", []):
evidence_id = evidence.get("id")
if evidence_id not in document_kinds:
# Check if it's in acceptable_alternatives of any evidence
found_in_alternatives = False
for other_schedule in policy_dict.get("schedules", {}).values():
for other_evidence in other_schedule.get("evidence", []):
if evidence_id in other_evidence.get(
"acceptable_alternatives", []
):
found_in_alternatives = True
break
if found_in_alternatives:
break
if not found_in_alternatives:
errors.append(
f"Evidence ID '{evidence_id}' in schedule '{schedule_id}' "
f"not found in document_kinds or acceptable_alternatives"
)
# Check acceptable alternatives
for alt in evidence.get("acceptable_alternatives", []):
if alt not in document_kinds:
warnings.append(
f"Alternative '{alt}' for evidence '{evidence_id}' "
f"not found in document_kinds"
)
# Check that all schedules referenced in triggers exist
triggers = policy_dict.get("triggers", {})
schedules = policy_dict.get("schedules", {})
for schedule_id in triggers:
if schedule_id not in schedules:
errors.append(
f"Trigger for '{schedule_id}' but no schedule definition found"
)
return errors, warnings
def _compile_condition(self, condition: str) -> Callable[[str, str], bool]:
"""Compile a condition string into a callable predicate"""
# Simple condition parser for the DSL
condition = condition.strip()
# Handle exists() conditions
exists_match = re.match(r"exists\((\w+)\[([^\]]+)\]\)", condition)
if exists_match:
entity_type = exists_match.group(1)
filters = exists_match.group(2)
return self._create_exists_predicate(entity_type, filters)
# Handle simple property conditions
if condition in [
"property_joint_ownership",
"candidate_FHL",
"claims_FTCR",
"claims_remittance_basis",
"received_estate_income",
]:
return self._create_property_predicate(condition)
# Handle computed conditions
if condition in [
"turnover_lt_vat_threshold",
"turnover_ge_vat_threshold",
]:
return self._create_computed_predicate(condition)
# Handle taxpayer flags
if condition.startswith("taxpayer_flag:"):
flag_name = condition.split(":", 1)[1].strip()
return self._create_flag_predicate(flag_name)
# Handle filing mode
if condition.startswith("filing_mode:"):
mode = condition.split(":", 1)[1].strip()
return self._create_filing_mode_predicate(mode)
# Default: always false for unknown conditions
logger.warning("Unknown condition, defaulting to False", condition=condition)
return lambda taxpayer_id, tax_year: False
def _create_exists_predicate(
self, entity_type: str, filters: str
) -> Callable[[str, str], bool]:
"""Create predicate for exists() conditions"""
def predicate(taxpayer_id: str, tax_year: str) -> bool:
# This would query the KG for the entity with filters
# For now, return False as placeholder
logger.debug(
"Exists predicate called",
entity_type=entity_type,
filters=filters,
taxpayer_id=taxpayer_id,
tax_year=tax_year,
)
return False
return predicate
def _create_property_predicate(
self, property_name: str
) -> Callable[[str, str], bool]:
"""Create predicate for property conditions"""
def predicate(taxpayer_id: str, tax_year: str) -> bool:
# This would query the KG for the property
logger.debug(
"Property predicate called",
property_name=property_name,
taxpayer_id=taxpayer_id,
tax_year=tax_year,
)
return False
return predicate
def _create_computed_predicate(
self, computation: str
) -> Callable[[str, str], bool]:
"""Create predicate for computed conditions"""
def predicate(taxpayer_id: str, tax_year: str) -> bool:
# This would perform the computation
logger.debug(
"Computed predicate called",
computation=computation,
taxpayer_id=taxpayer_id,
tax_year=tax_year,
)
return False
return predicate
def _create_flag_predicate(self, flag_name: str) -> Callable[[str, str], bool]:
"""Create predicate for taxpayer flags"""
def predicate(taxpayer_id: str, tax_year: str) -> bool:
# This would check taxpayer flags
logger.debug(
"Flag predicate called",
flag_name=flag_name,
taxpayer_id=taxpayer_id,
tax_year=tax_year,
)
return False
return predicate
def _create_filing_mode_predicate(self, mode: str) -> Callable[[str, str], bool]:
"""Create predicate for filing mode"""
def predicate(taxpayer_id: str, tax_year: str) -> bool:
# This would check filing mode preference
logger.debug(
"Filing mode predicate called",
mode=mode,
taxpayer_id=taxpayer_id,
tax_year=tax_year,
)
return False
return predicate
def _calculate_hash(self, file_paths: list[str]) -> str:
"""Calculate hash of policy files"""
hasher = hashlib.sha256()
for path in sorted(file_paths):
try:
with open(path, "rb") as f:
hasher.update(f.read())
except FileNotFoundError:
logger.warning("File not found for hashing", path=path)
return hasher.hexdigest()
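A short worked example of the overlay merge above (no files are read, so it runs standalone); the schedule and evidence names are illustrative. Nested mappings merge key by key, so an overlay only overrides the keys it mentions.
from libs.policy.loader import PolicyLoader

loader = PolicyLoader(config_dir="config")
base = {"schedules": {"SA105": {"required": True, "evidence": ["rental_statement"]}}}
overlay = {"schedules": {"SA105": {"required": False}}}

merged = loader.merge_overlays(base, overlay)
# {'schedules': {'SA105': {'required': False, 'evidence': ['rental_statement']}}}
print(merged)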

50
libs/policy/utils.py Normal file
View File

@@ -0,0 +1,50 @@
"""Utility functions for policy management."""
from typing import Any
from ..schemas import CompiledCoveragePolicy, CoveragePolicy, ValidationResult
from .loader import PolicyLoader
# Global policy loader instance
_policy_loader: PolicyLoader | None = None
def get_policy_loader(config_dir: str = "config") -> PolicyLoader:
"""Get global policy loader instance"""
global _policy_loader
if _policy_loader is None:
_policy_loader = PolicyLoader(config_dir)
return _policy_loader
# Convenience functions
def load_policy(
baseline_path: str | None = None,
jurisdiction: str = "UK",
tax_year: str = "2024-25",
tenant_id: str | None = None,
) -> CoveragePolicy:
"""Load coverage policy with overlays"""
return get_policy_loader().load_policy(
baseline_path, jurisdiction, tax_year, tenant_id
)
def merge_overlays(base: dict[str, Any], *overlays: dict[str, Any]) -> dict[str, Any]:
"""Merge base policy with overlays"""
return get_policy_loader().merge_overlays(base, *overlays)
def apply_feature_flags(policy: dict[str, Any]) -> dict[str, Any]:
"""Apply feature flags to policy"""
return get_policy_loader().apply_feature_flags(policy)
def compile_predicates(policy: CoveragePolicy) -> CompiledCoveragePolicy:
"""Compile policy predicates"""
return get_policy_loader().compile_predicates(policy)
def validate_policy(policy_dict: dict[str, Any]) -> ValidationResult:
"""Validate policy"""
return get_policy_loader().validate_policy(policy_dict)

13
libs/rag/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""Qdrant collections CRUD, hybrid search, rerank wrapper, de-identification utilities."""
from .collection_manager import QdrantCollectionManager
from .pii_detector import PIIDetector
from .retriever import RAGRetriever
from .utils import rag_search_for_citations
__all__ = [
"PIIDetector",
"QdrantCollectionManager",
"RAGRetriever",
"rag_search_for_citations",
]

View File

@@ -0,0 +1,233 @@
"""Manage Qdrant collections for RAG."""
from typing import Any
import structlog
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
Filter,
PointStruct,
SparseVector,
VectorParams,
)
from .pii_detector import PIIDetector
logger = structlog.get_logger()
class QdrantCollectionManager:
"""Manage Qdrant collections for RAG"""
def __init__(self, client: QdrantClient):
self.client = client
self.pii_detector = PIIDetector()
async def ensure_collection(
self,
collection_name: str,
vector_size: int = 384,
distance: Distance = Distance.COSINE,
sparse_vector_config: dict[str, Any] | None = None,
) -> bool:
"""Ensure collection exists with proper configuration"""
try:
# Check if collection exists
collections = self.client.get_collections().collections
if any(c.name == collection_name for c in collections):
logger.debug("Collection already exists", collection=collection_name)
return True
# Create collection with dense vectors
vector_config = VectorParams(size=vector_size, distance=distance)
# Add sparse vector configuration if provided
sparse_vectors_config = None
if sparse_vector_config:
sparse_vectors_config = {"sparse": sparse_vector_config}
self.client.create_collection(
collection_name=collection_name,
vectors_config=vector_config,
sparse_vectors_config=sparse_vectors_config, # type: ignore
)
logger.info("Created collection", collection=collection_name)
return True
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Failed to create collection", collection=collection_name, error=str(e)
)
return False
async def upsert_points(
self, collection_name: str, points: list[PointStruct]
) -> bool:
"""Upsert points to collection"""
try:
# Validate all points are PII-free
for point in points:
if point.payload and not point.payload.get("pii_free", False):
logger.warning("Point not marked as PII-free", point_id=point.id)
return False
self.client.upsert(collection_name=collection_name, points=points)
logger.info(
"Upserted points", collection=collection_name, count=len(points)
)
return True
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Failed to upsert points", collection=collection_name, error=str(e)
)
return False
async def search_dense( # pylint: disable=too-many-arguments,too-many-positional-arguments
self,
collection_name: str,
query_vector: list[float],
limit: int = 10,
filter_conditions: Filter | None = None,
score_threshold: float | None = None,
) -> list[dict[str, Any]]:
"""Search using dense vectors"""
try:
search_result = self.client.search(
collection_name=collection_name,
query_vector=query_vector,
query_filter=filter_conditions,
limit=limit,
score_threshold=score_threshold,
with_payload=True,
with_vectors=False,
)
return [
{"id": hit.id, "score": hit.score, "payload": hit.payload}
for hit in search_result
]
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Dense search failed", collection=collection_name, error=str(e)
)
return []
async def search_sparse(
self,
collection_name: str,
query_vector: SparseVector,
limit: int = 10,
filter_conditions: Filter | None = None,
) -> list[dict[str, Any]]:
"""Search using sparse vectors"""
try:
search_result = self.client.search(
collection_name=collection_name,
query_vector=query_vector, # type: ignore
query_filter=filter_conditions,
limit=limit,
using="sparse",
with_payload=True,
with_vectors=False,
)
return [
{"id": hit.id, "score": hit.score, "payload": hit.payload}
for hit in search_result
]
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Sparse search failed", collection=collection_name, error=str(e)
)
return []
async def hybrid_search( # pylint: disable=too-many-arguments,too-many-positional-arguments
self,
collection_name: str,
dense_vector: list[float],
sparse_vector: SparseVector,
limit: int = 10,
alpha: float = 0.5,
filter_conditions: Filter | None = None,
) -> list[dict[str, Any]]:
"""Perform hybrid search combining dense and sparse results"""
# Get dense results
dense_results = await self.search_dense(
collection_name=collection_name,
query_vector=dense_vector,
limit=limit * 2, # Get more results for fusion
filter_conditions=filter_conditions,
)
# Get sparse results
sparse_results = await self.search_sparse(
collection_name=collection_name,
query_vector=sparse_vector,
limit=limit * 2,
filter_conditions=filter_conditions,
)
# Combine and re-rank results
return self._fuse_results(dense_results, sparse_results, alpha, limit)
def _fuse_results( # pylint: disable=too-many-locals
self,
dense_results: list[dict[str, Any]],
sparse_results: list[dict[str, Any]],
alpha: float,
limit: int,
) -> list[dict[str, Any]]:
"""Fuse dense and sparse search results"""
# Create score maps
dense_scores = {result["id"]: result["score"] for result in dense_results}
sparse_scores = {result["id"]: result["score"] for result in sparse_results}
# Get all unique IDs
all_ids = set(dense_scores.keys()) | set(sparse_scores.keys())
# Calculate hybrid scores
hybrid_results = []
for doc_id in all_ids:
dense_score = dense_scores.get(doc_id, 0.0)
sparse_score = sparse_scores.get(doc_id, 0.0)
# Normalize scores (simple min-max normalization)
if dense_results:
max_dense = max(dense_scores.values())
dense_score = dense_score / max_dense if max_dense > 0 else 0
if sparse_results:
max_sparse = max(sparse_scores.values())
sparse_score = sparse_score / max_sparse if max_sparse > 0 else 0
# Combine scores
hybrid_score = alpha * dense_score + (1 - alpha) * sparse_score
# Get payload from either result
payload = None
for result in dense_results + sparse_results:
if result["id"] == doc_id:
payload = result["payload"]
break
hybrid_results.append(
{
"id": doc_id,
"score": hybrid_score,
"dense_score": dense_score,
"sparse_score": sparse_score,
"payload": payload,
}
)
# Sort by hybrid score and return top results
hybrid_results.sort(key=lambda x: x["score"], reverse=True)
return hybrid_results[:limit]
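A worked example of the score fusion above with alpha = 0.5 and illustrative scores: each modality's scores are max-normalised, then combined as alpha * dense + (1 - alpha) * sparse.
dense = {"doc-1": 0.90, "doc-2": 0.45}
sparse = {"doc-2": 12.0, "doc-3": 6.0}
alpha = 0.5

fused = {
    doc_id: alpha * (dense.get(doc_id, 0.0) / max(dense.values()))
    + (1 - alpha) * (sparse.get(doc_id, 0.0) / max(sparse.values()))
    for doc_id in set(dense) | set(sparse)
}
# doc-1 -> 0.50, doc-2 -> 0.75, doc-3 -> 0.25, so doc-2 ranks first.
print(sorted(fused.items(), key=lambda item: item[1], reverse=True))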

507
libs/rag/indexer.py Normal file
View File

@@ -0,0 +1,507 @@
# FILE: libs/rag/indexer.py
# De-identify -> embed dense/sparse -> upsert to Qdrant with payload
import json
import logging
import re
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any
import numpy as np
import spacy
import torch
import yaml
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, SparseVector, VectorParams
from sentence_transformers import SentenceTransformer
from .chunker import DocumentChunker
from .pii_detector import PIIDetector, PIIRedactor
@dataclass
class IndexingResult:
collection_name: str
points_indexed: int
points_updated: int
points_failed: int
processing_time: float
errors: list[str]
class RAGIndexer:
def __init__(self, config_path: str, qdrant_url: str = "http://localhost:6333"):
with open(config_path) as f:
self.config = yaml.safe_load(f)
self.qdrant_client = QdrantClient(url=qdrant_url)
self.chunker = DocumentChunker(config_path)
self.pii_detector = PIIDetector()
self.pii_redactor = PIIRedactor()
# Initialize embedding models
self.dense_model = SentenceTransformer(
self.config.get("embedding_model", "bge-small-en-v1.5")
)
# Initialize sparse model (BM25/SPLADE)
self.sparse_model = self._init_sparse_model()
# Initialize NLP pipeline
self.nlp = spacy.load("en_core_web_sm")
self.logger = logging.getLogger(__name__)
def _init_sparse_model(self):
"""Initialize sparse embedding model (BM25 or SPLADE)"""
sparse_config = self.config.get("sparse_model", {})
model_type = sparse_config.get("type", "bm25")
if model_type == "bm25":
from rank_bm25 import BM25Okapi
return BM25Okapi
elif model_type == "splade":
from transformers import AutoModelForMaskedLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
"naver/splade-cocondenser-ensembledistil"
)
model = AutoModelForMaskedLM.from_pretrained(
"naver/splade-cocondenser-ensembledistil"
)
return {"tokenizer": tokenizer, "model": model}
else:
raise ValueError(f"Unsupported sparse model type: {model_type}")
async def index_document(
self, document_path: str, collection_name: str, metadata: dict[str, Any]
) -> IndexingResult:
"""Index a single document into the specified collection"""
start_time = datetime.now()
errors = []
points_indexed = 0
points_updated = 0
points_failed = 0
try:
# Step 1: Chunk the document
chunks = await self.chunker.chunk_document(document_path, metadata)
# Step 2: Process each chunk
points = []
for chunk in chunks:
try:
point = await self._process_chunk(chunk, collection_name, metadata)
if point:
points.append(point)
except Exception as e:
self.logger.error(
f"Failed to process chunk {chunk.get('id', 'unknown')}: {str(e)}"
)
errors.append(f"Chunk processing error: {str(e)}")
points_failed += 1
# Step 3: Upsert to Qdrant
if points:
try:
operation_info = self.qdrant_client.upsert(
collection_name=collection_name, points=points, wait=True
)
points_indexed = len(points)
self.logger.info(
f"Indexed {points_indexed} points to {collection_name}"
)
except Exception as e:
self.logger.error(f"Failed to upsert to Qdrant: {str(e)}")
errors.append(f"Qdrant upsert error: {str(e)}")
points_failed += len(points)
points_indexed = 0
except Exception as e:
self.logger.error(f"Document indexing failed: {str(e)}")
errors.append(f"Document indexing error: {str(e)}")
processing_time = (datetime.now() - start_time).total_seconds()
return IndexingResult(
collection_name=collection_name,
points_indexed=points_indexed,
points_updated=points_updated,
points_failed=points_failed,
processing_time=processing_time,
errors=errors,
)
async def _process_chunk(
self, chunk: dict[str, Any], collection_name: str, base_metadata: dict[str, Any]
) -> PointStruct | None:
"""Process a single chunk: de-identify, embed, create point"""
# Step 1: De-identify PII
content = chunk["content"]
pii_detected = self.pii_detector.detect(content)
if pii_detected:
# Redact PII and create mapping
redacted_content, pii_mapping = self.pii_redactor.redact(
content, pii_detected
)
# Store PII mapping securely (not in vector DB)
await self._store_pii_mapping(chunk["id"], pii_mapping)
# Log PII detection for audit
self.logger.warning(
f"PII detected in chunk {chunk['id']}: {[p['type'] for p in pii_detected]}"
)
else:
redacted_content = content
# Verify no PII remains
if not self._verify_pii_free(redacted_content):
self.logger.error(f"PII verification failed for chunk {chunk['id']}")
return None
# Step 2: Generate embeddings
try:
dense_vector = await self._generate_dense_embedding(redacted_content)
sparse_vector = await self._generate_sparse_embedding(redacted_content)
except Exception as e:
self.logger.error(
f"Embedding generation failed for chunk {chunk['id']}: {str(e)}"
)
return None
# Step 3: Prepare metadata
payload = self._prepare_payload(chunk, base_metadata, redacted_content)
payload["pii_free"] = True # Verified above
# Step 4: Create point
point = PointStruct(
id=chunk["id"],
vector={"dense": dense_vector, "sparse": sparse_vector},
payload=payload,
)
return point
async def _generate_dense_embedding(self, text: str) -> list[float]:
"""Generate dense vector embedding"""
try:
# Use sentence transformer for dense embeddings
embedding = self.dense_model.encode(text, normalize_embeddings=True)
return embedding.tolist()
except Exception as e:
self.logger.error(f"Dense embedding generation failed: {str(e)}")
raise
async def _generate_sparse_embedding(self, text: str) -> SparseVector:
"""Generate sparse vector embedding (BM25 or SPLADE)"""
vector = SparseVector(indices=[], values=[])
try:
sparse_config = self.config.get("sparse_model", {})
model_type = sparse_config.get("type", "bm25")
if model_type == "bm25":
# Simple BM25-style sparse representation
doc = self.nlp(text)
tokens = [
token.lemma_.lower()
for token in doc
if not token.is_stop and not token.is_punct
]
# Create term frequency vector
term_freq = {}
for token in tokens:
term_freq[token] = term_freq.get(token, 0) + 1
# Convert to sparse vector format
vocab_size = sparse_config.get("vocab_size", 30000)
indices = []
values = []
for term, freq in term_freq.items():
# Simple hash-based vocabulary mapping
term_id = hash(term) % vocab_size
indices.append(term_id)
values.append(float(freq))
vector = SparseVector(indices=indices, values=values)
elif model_type == "splade":
# SPLADE sparse embeddings
tokenizer = self.sparse_model["tokenizer"]
model = self.sparse_model["model"]
inputs = tokenizer(
text, return_tensors="pt", truncation=True, max_length=512
)
outputs = model(**inputs)
# Extract sparse representation
logits = outputs.logits.squeeze()
sparse_rep = torch.relu(logits).detach().numpy()
# Convert to sparse format
indices = np.nonzero(sparse_rep)[0].tolist()
values = sparse_rep[indices].tolist()
vector = SparseVector(indices=indices, values=values)
return vector
except Exception as e:
self.logger.error(f"Sparse embedding generation failed: {str(e)}")
# Return empty sparse vector as fallback
return vector
def _prepare_payload(
self, chunk: dict[str, Any], base_metadata: dict[str, Any], content: str
) -> dict[str, Any]:
"""Prepare payload metadata for the chunk"""
# Start with base metadata
payload = base_metadata.copy()
# Add chunk-specific metadata
payload.update(
{
"document_id": chunk.get("document_id"),
"content": content, # De-identified content
"chunk_index": chunk.get("chunk_index", 0),
"total_chunks": chunk.get("total_chunks", 1),
"page_numbers": chunk.get("page_numbers", []),
"section_hierarchy": chunk.get("section_hierarchy", []),
"has_calculations": self._detect_calculations(content),
"has_forms": self._detect_form_references(content),
"confidence_score": chunk.get("confidence_score", 1.0),
"created_at": datetime.now().isoformat(),
"version": self.config.get("version", "1.0"),
}
)
# Extract and add topic tags
topic_tags = self._extract_topic_tags(content)
if topic_tags:
payload["topic_tags"] = topic_tags
# Add content analysis
payload.update(self._analyze_content(content))
return payload
def _detect_calculations(self, text: str) -> bool:
"""Detect if text contains calculations or formulas"""
calculation_patterns = [
r"\d+\s*[+\-*/]\s*\d+",
r"£\d+(?:,\d{3})*(?:\.\d{2})?",
r"\d+(?:\.\d+)?%",
r"total|sum|calculate|compute",
r"rate|threshold|allowance|relief",
]
for pattern in calculation_patterns:
if re.search(pattern, text, re.IGNORECASE):
return True
return False
def _detect_form_references(self, text: str) -> bool:
"""Detect references to tax forms"""
form_patterns = [
r"SA\d{3}",
r"P\d{2}",
r"CT\d{3}",
r"VAT\d{3}",
r"form\s+\w+",
r"schedule\s+\w+",
]
for pattern in form_patterns:
if re.search(pattern, text, re.IGNORECASE):
return True
return False
def _extract_topic_tags(self, text: str) -> list[str]:
"""Extract topic tags from content"""
topic_keywords = {
"employment": [
"PAYE",
"payslip",
"P60",
"employment",
"salary",
"wages",
"employer",
],
"self_employment": [
"self-employed",
"business",
"turnover",
"expenses",
"profit",
"loss",
],
"property": ["rental", "property", "landlord", "FHL", "mortgage", "rent"],
"dividends": ["dividend", "shares", "distribution", "corporation tax"],
"capital_gains": ["capital gains", "disposal", "acquisition", "CGT"],
"pensions": ["pension", "retirement", "SIPP", "occupational"],
"savings": ["interest", "savings", "ISA", "bonds"],
"inheritance": ["inheritance", "IHT", "estate", "probate"],
"vat": ["VAT", "value added tax", "registration", "return"],
}
tags = []
text_lower = text.lower()
for topic, keywords in topic_keywords.items():
for keyword in keywords:
if keyword.lower() in text_lower:
tags.append(topic)
break
return list(set(tags)) # Remove duplicates
def _analyze_content(self, text: str) -> dict[str, Any]:
"""Analyze content for additional metadata"""
doc = self.nlp(text)
return {
"word_count": len([token for token in doc if not token.is_space]),
"sentence_count": len(list(doc.sents)),
"entity_count": len(doc.ents),
"complexity_score": self._calculate_complexity(doc),
"language": doc.lang_ if hasattr(doc, "lang_") else "en",
}
def _calculate_complexity(self, doc: Any) -> float:
"""Calculate text complexity score"""
if not doc:
return 0.0
# Simple complexity based on sentence length and vocabulary
avg_sentence_length = sum(len(sent) for sent in doc.sents) / len(
list(doc.sents)
)
unique_words = len(set(token.lemma_.lower() for token in doc if token.is_alpha))
total_words = len([token for token in doc if token.is_alpha])
vocabulary_diversity = unique_words / total_words if total_words > 0 else 0
# Normalize to 0-1 scale
complexity = min(1.0, (avg_sentence_length / 20.0 + vocabulary_diversity) / 2.0)
return complexity
def _verify_pii_free(self, text: str) -> bool:
"""Verify that text contains no PII"""
# Quick verification using patterns
pii_patterns = [
r"\b[A-Z]{2}\d{6}[A-D]\b", # NI number
r"\b\d{10}\b", # UTR
r"\b[A-Z]{2}\d{2}[A-Z]{4}\d{14}\b", # IBAN
r"\b\d{2}-\d{2}-\d{2}\b", # Sort code
r"\b[A-Z]{1,2}\d[A-Z\d]?\s*\d[A-Z]{2}\b", # Postcode
r"\b[\w\.-]+@[\w\.-]+\.\w+\b", # Email
r"\b(?:\+44|0)\d{10,11}\b", # Phone
]
for pattern in pii_patterns:
if re.search(pattern, text):
return False
return True
async def _store_pii_mapping(
self, chunk_id: str, pii_mapping: dict[str, Any]
) -> None:
"""Store PII mapping in secure client data store (not in vector DB)"""
# This would integrate with the secure PostgreSQL client data store
# For now, just log the mapping securely
self.logger.info(
f"PII mapping stored for chunk {chunk_id}: {len(pii_mapping)} items"
)
async def create_collections(self) -> None:
"""Create all Qdrant collections based on configuration"""
collections_config_path = Path(__file__).parent / "qdrant_collections.json"
with open(collections_config_path) as f:
collections_config = json.load(f)
for collection_config in collections_config["collections"]:
collection_name = collection_config["name"]
try:
# Check if collection exists
try:
self.qdrant_client.get_collection(collection_name)
self.logger.info(f"Collection {collection_name} already exists")
continue
except Exception:
pass # Collection doesn't exist, create it
# Create collection
vectors_config = {}
# Dense vector configuration
if "dense" in collection_config:
vectors_config["dense"] = VectorParams(
size=collection_config["dense"]["size"],
distance=Distance.COSINE,
)
# Sparse vector configuration
if collection_config.get("sparse", False):
vectors_config["sparse"] = VectorParams(
size=30000, # Vocabulary size for sparse vectors
distance=Distance.DOT,
on_disk=True,
)
self.qdrant_client.create_collection(
collection_name=collection_name,
vectors_config=vectors_config,
**collection_config.get("indexing_config", {}),
)
self.logger.info(f"Created collection: {collection_name}")
except Exception as e:
self.logger.error(
f"Failed to create collection {collection_name}: {str(e)}"
)
raise
async def batch_index(
self, documents: list[dict[str, Any]], collection_name: str
) -> list[IndexingResult]:
"""Index multiple documents in batch"""
results = []
for doc_info in documents:
result = await self.index_document(
doc_info["path"], collection_name, doc_info["metadata"]
)
results.append(result)
return results
def get_collection_stats(self, collection_name: str) -> dict[str, Any]:
"""Get statistics for a collection"""
try:
collection_info = self.qdrant_client.get_collection(collection_name)
return {
"name": collection_name,
"vectors_count": collection_info.vectors_count,
"indexed_vectors_count": collection_info.indexed_vectors_count,
"points_count": collection_info.points_count,
"segments_count": collection_info.segments_count,
"status": collection_info.status,
}
except Exception as e:
self.logger.error(f"Failed to get stats for {collection_name}: {str(e)}")
return {"error": str(e)}

77
libs/rag/pii_detector.py Normal file
View File

@@ -0,0 +1,77 @@
"""PII detection and de-identification utilities."""
import hashlib
import re
from typing import Any
class PIIDetector:
"""PII detection and de-identification utilities"""
# Regex patterns for common PII
PII_PATTERNS = {
"uk_ni_number": r"\b[A-CEGHJ-PR-TW-Z]{2}\d{6}[A-D]\b",
"uk_utr": r"\b\d{10}\b",
"uk_postcode": r"\b[A-Z]{1,2}\d[A-Z0-9]?\s*\d[A-Z]{2}\b",
"uk_sort_code": r"\b\d{2}-\d{2}-\d{2}\b",
"uk_account_number": r"\b\d{8}\b",
"email": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
"phone": r"\b(?:\+44|0)\d{10,11}\b",
"iban": r"\bGB\d{2}[A-Z]{4}\d{14}\b",
"amount": r"£\d{1,3}(?:,\d{3})*(?:\.\d{2})?",
"date": r"\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b",
}
def __init__(self) -> None:
self.compiled_patterns = {
name: re.compile(pattern, re.IGNORECASE)
for name, pattern in self.PII_PATTERNS.items()
}
def detect_pii(self, text: str) -> list[dict[str, Any]]:
"""Detect PII in text and return matches with positions"""
matches = []
for pii_type, pattern in self.compiled_patterns.items():
for match in pattern.finditer(text):
matches.append(
{
"type": pii_type,
"value": match.group(),
"start": match.start(),
"end": match.end(),
"placeholder": self._generate_placeholder(
pii_type, match.group()
),
}
)
return sorted(matches, key=lambda x: x["start"])
def de_identify_text(self, text: str) -> tuple[str, dict[str, str]]:
"""De-identify text by replacing PII with placeholders"""
pii_matches = self.detect_pii(text)
pii_mapping = {}
# Replace PII from end to start to maintain positions
de_identified = text
for match in reversed(pii_matches):
placeholder = match["placeholder"]
pii_mapping[placeholder] = match["value"]
de_identified = (
de_identified[: match["start"]]
+ placeholder
+ de_identified[match["end"] :]
)
return de_identified, pii_mapping
def _generate_placeholder(self, pii_type: str, value: str) -> str:
"""Generate consistent placeholder for PII value"""
# Create hash of the value for consistent placeholders
value_hash = hashlib.md5(value.encode()).hexdigest()[:8]
return f"[{pii_type.upper()}_{value_hash}]"
def has_pii(self, text: str) -> bool:
"""Check if text contains any PII"""
return len(self.detect_pii(text)) > 0
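# Illustrative usage sketch: round-trip a snippet through the detector. The
# sample values are made up for demonstration.
if __name__ == "__main__":
    detector = PIIDetector()
    sample = "NI number AB123456C, contact jane.doe@example.co.uk, rent received £1,250.00"
    clean_text, mapping = detector.de_identify_text(sample)
    print(clean_text)                    # placeholders such as [UK_NI_NUMBER_<hash>]
    print(mapping)                       # placeholder -> original value (kept out of the vector store)
    print(detector.has_pii(clean_text))  # expected: False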

235
libs/rag/retriever.py Normal file
View File

@@ -0,0 +1,235 @@
"""High-level RAG retrieval with reranking and KG fusion."""
from typing import Any
import structlog
from qdrant_client import QdrantClient
from qdrant_client.models import (
FieldCondition,
Filter,
MatchValue,
SparseVector,
)
from .collection_manager import QdrantCollectionManager
logger = structlog.get_logger()
class RAGRetriever: # pylint: disable=too-few-public-methods
"""High-level RAG retrieval with reranking and KG fusion"""
def __init__(
self,
qdrant_client: QdrantClient,
neo4j_client: Any = None,
reranker_model: str | None = None,
) -> None:
self.collection_manager = QdrantCollectionManager(qdrant_client)
self.neo4j_client = neo4j_client
self.reranker_model = reranker_model
async def search( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals
self,
query: str,
collections: list[str],
dense_vector: list[float],
sparse_vector: SparseVector,
k: int = 10,
alpha: float = 0.5,
beta: float = 0.3, # pylint: disable=unused-argument
gamma: float = 0.2, # pylint: disable=unused-argument
tax_year: str | None = None,
jurisdiction: str | None = None,
) -> dict[str, Any]:
"""Perform comprehensive RAG search with KG fusion"""
# Build filter conditions
filter_conditions = self._build_filter(tax_year, jurisdiction)
# Search each collection
all_chunks = []
for collection in collections:
chunks = await self.collection_manager.hybrid_search(
collection_name=collection,
dense_vector=dense_vector,
sparse_vector=sparse_vector,
limit=k,
alpha=alpha,
filter_conditions=filter_conditions,
)
# Add collection info to chunks
for chunk in chunks:
chunk["collection"] = collection
all_chunks.extend(chunks)
# Re-rank if reranker is available
if self.reranker_model and len(all_chunks) > k:
all_chunks = await self._rerank_chunks(query, all_chunks, k)
# Sort by score and take top k
all_chunks.sort(key=lambda x: x["score"], reverse=True)
top_chunks = all_chunks[:k]
# Get KG hints if Neo4j client is available
kg_hints = []
if self.neo4j_client:
kg_hints = await self._get_kg_hints(query, top_chunks)
# Extract citations
citations = self._extract_citations(top_chunks)
# Calculate calibrated confidence
calibrated_confidence = self._calculate_confidence(top_chunks)
return {
"chunks": top_chunks,
"citations": citations,
"kg_hints": kg_hints,
"calibrated_confidence": calibrated_confidence,
}
def _build_filter(
self, tax_year: str | None = None, jurisdiction: str | None = None
) -> Filter | None:
"""Build Qdrant filter conditions"""
conditions = []
if jurisdiction:
conditions.append(
FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction))
)
if tax_year:
conditions.append(
FieldCondition(key="tax_years", match=MatchValue(value=tax_year))
)
# Always require PII-free content
conditions.append(FieldCondition(key="pii_free", match=MatchValue(value=True)))
if conditions:
return Filter(must=conditions) # type: ignore
return None
async def _rerank_chunks( # pylint: disable=unused-argument
self, query: str, chunks: list[dict[str, Any]], k: int
) -> list[dict[str, Any]]:
"""Rerank chunks using cross-encoder model"""
try:
# This would integrate with a reranking service
# For now, return original chunks
logger.debug("Reranking not implemented, returning original order")
return chunks
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Reranking failed, using original order", error=str(e))
return chunks
async def _get_kg_hints( # pylint: disable=unused-argument
self, query: str, chunks: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Get knowledge graph hints related to the query"""
try:
# Extract potential rule/formula references from chunks
hints = []
for chunk in chunks:
payload = chunk.get("payload", {})
topic_tags = payload.get("topic_tags", [])
# Look for tax rules related to the topics
if topic_tags and self.neo4j_client:
kg_query = """
MATCH (r:Rule)-[:APPLIES_TO]->(topic)
WHERE topic.name IN $topics
AND r.retracted_at IS NULL
RETURN r.rule_id as rule_id,
r.formula as formula_id,
collect(id(topic)) as node_ids
LIMIT 5
"""
kg_results = await self.neo4j_client.run_query(
kg_query, {"topics": topic_tags}
)
for result in kg_results:
hints.append(
{
"rule_id": result["rule_id"],
"formula_id": result["formula_id"],
"node_ids": result["node_ids"],
}
)
return hints[:5] # Limit to top 5 hints
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Failed to get KG hints", error=str(e))
return []
def _extract_citations(self, chunks: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Extract citation information from chunks"""
citations = []
seen_docs = set()
for chunk in chunks:
payload = chunk.get("payload", {})
# Extract document reference
doc_id = payload.get("doc_id")
url = payload.get("url")
section_id = payload.get("section_id")
page = payload.get("page")
bbox = payload.get("bbox")
# Create citation key to avoid duplicates
citation_key = doc_id or url
if citation_key and citation_key not in seen_docs:
citation = {}
if doc_id:
citation["doc_id"] = doc_id
if url:
citation["url"] = url
if section_id:
citation["section_id"] = section_id
if page:
citation["page"] = page
if bbox:
citation["bbox"] = bbox
citations.append(citation)
seen_docs.add(citation_key)
return citations
def _calculate_confidence(self, chunks: list[dict[str, Any]]) -> float:
"""Calculate calibrated confidence score"""
if not chunks:
return 0.0
# Simple confidence calculation based on top scores
top_scores = [chunk["score"] for chunk in chunks[:3]]
if not top_scores:
return 0.0
# Average of top 3 scores with diminishing returns
weights = [0.5, 0.3, 0.2]
weighted_score = sum(
score * weight
for score, weight in zip(
top_scores, weights[: len(top_scores)], strict=False
)
)
# Apply calibration (simple temperature scaling)
# In production, this would use learned calibration parameters
temperature = 1.2
calibrated = weighted_score / temperature
return min(max(calibrated, 0.0), 1.0) # type: ignore
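# Illustrative usage sketch: the vectors below are dummies (a 768-dimension
# dense embedding is assumed); in practice both come from the embedding
# pipeline, and the collection name is an example only. Requires a reachable
# Qdrant instance.
if __name__ == "__main__":
    import asyncio

    from qdrant_client import QdrantClient

    async def _demo() -> None:
        retriever = RAGRetriever(QdrantClient(url="http://localhost:6333"))
        result = await retriever.search(
            query="mileage allowance for self-employment",
            collections=["guidance_chunks"],
            dense_vector=[0.01] * 768,
            sparse_vector=SparseVector(indices=[12, 4051], values=[0.4, 0.9]),
            k=5,
            tax_year="2023-24",
            jurisdiction="UK",
        )
        print(result["calibrated_confidence"], len(result["chunks"]))

    asyncio.run(_demo())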

44
libs/rag/utils.py Normal file
View File

@@ -0,0 +1,44 @@
"""Coverage-specific RAG utility functions."""
from typing import Any
import structlog
from libs.schemas.coverage.evaluation import Citation
logger = structlog.get_logger()
async def rag_search_for_citations(
rag_client: Any, query: str, filters: dict[str, Any] | None = None
) -> list["Citation"]:
"""Search for citations using RAG with PII-free filtering"""
try:
# Ensure PII-free filter is always applied
search_filters = filters or {}
search_filters["pii_free"] = True
# This would integrate with the actual RAG retrieval system
# For now, return a placeholder implementation
logger.debug(
"RAG citation search called",
query=query,
filters=search_filters,
rag_client_available=rag_client is not None,
)
# Placeholder citations - in production this would call the RAG system
citations = [
Citation(
doc_id=f"RAG-{query.replace(' ', '-')[:20]}",
locator="Retrieved via RAG search",
url=f"https://guidance.example.com/search?q={query}",
)
]
return citations
except (ConnectionError, TimeoutError) as e:
logger.error("RAG citation search failed", query=query, error=str(e))
return []
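# Illustrative usage sketch: the placeholder implementation accepts any client
# object (here None) and returns a single stub citation for the query.
if __name__ == "__main__":
    import asyncio

    results = asyncio.run(
        rag_search_for_citations(None, "SA105 allowable property expenses")
    )
    print([citation.doc_id for citation in results])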

libs/requirements-base.txt Normal file
View File

@@ -0,0 +1,38 @@
# Core framework dependencies (Required by all services)
fastapi>=0.118.0
uvicorn[standard]>=0.37.0
pydantic>=2.11.9
pydantic-settings>=2.11.0
# Database drivers (lightweight)
sqlalchemy>=2.0.43
asyncpg>=0.30.0
psycopg2-binary>=2.9.10
neo4j>=6.0.2
redis[hiredis]>=6.4.0
# Object storage and vector database
minio>=7.2.18
qdrant-client>=1.15.1
# Event streaming (NATS only - removed Kafka)
nats-py>=2.11.0
# Security and secrets management
hvac>=2.3.0
cryptography>=46.0.2
# Observability and monitoring (minimal)
prometheus-client>=0.23.1
prometheus-fastapi-instrumentator>=7.1.0
structlog>=25.4.0
# HTTP client
httpx>=0.28.1
# Utilities
ulid-py>=1.1.0
python-multipart>=0.0.20
python-dateutil>=2.9.0
python-dotenv>=1.1.1
orjson>=3.11.3

30
libs/requirements-dev.txt Normal file
View File

@@ -0,0 +1,30 @@
# Development dependencies (NOT included in Docker images)
# Type checking
mypy>=1.7.0
types-redis>=4.6.0
types-requests>=2.31.0
# Testing utilities
pytest>=7.4.0
pytest-asyncio>=0.21.0
pytest-minio-mock>=0.4
pytest-cov>=4.1.0
hypothesis>=6.88.0
# Code quality
ruff>=0.1.0
black>=23.11.0
isort>=5.12.0
bandit>=1.7.0
safety>=2.3.0
# OpenTelemetry instrumentation (development only)
opentelemetry-api>=1.21.0
opentelemetry-sdk>=1.21.0
opentelemetry-exporter-otlp-proto-grpc>=1.21.0
opentelemetry-instrumentation-fastapi>=0.42b0
opentelemetry-instrumentation-httpx>=0.42b0
opentelemetry-instrumentation-psycopg2>=0.42b0
opentelemetry-instrumentation-redis>=0.42b0

20
libs/requirements-ml.txt Normal file
View File

@@ -0,0 +1,20 @@
# ML and AI libraries (ONLY for services that need them)
# WARNING: These are HEAVY dependencies - only include in services that absolutely need them
# Sentence transformers (includes PyTorch - ~2GB)
sentence-transformers>=5.1.1
# Transformers library (includes PyTorch - ~1GB)
transformers>=4.57.0
# Traditional ML (lighter than deep learning)
scikit-learn>=1.7.2
numpy>=2.3.3
# NLP libraries
spacy>=3.8.7
nltk>=3.9.2
# Text processing
fuzzywuzzy>=0.18.0
python-Levenshtein>=0.27.1

libs/requirements-pdf.txt Normal file
View File

@@ -0,0 +1,5 @@
# PDF processing libraries (only for services that need them)
pdfrw>=0.4
reportlab>=4.4.4
PyPDF2>=3.0.1
pdfplumber>=0.11.7

libs/requirements-rdf.txt Normal file
View File

@@ -0,0 +1,3 @@
# RDF and semantic web libraries (only for KG service)
pyshacl>=0.30.1
rdflib>=7.2.1

10
libs/requirements.txt Normal file
View File

@@ -0,0 +1,10 @@
# DEPRECATED: This file is kept for backward compatibility
# Use the split requirements files instead:
# - requirements-base.txt: Core dependencies (use in all services)
# - requirements-ml.txt: ML/AI dependencies (use only in ML services)
# - requirements-pdf.txt: PDF processing (use only in services that process PDFs)
# - requirements-rdf.txt: RDF/semantic web (use only in KG service)
# - requirements-dev.txt: Development dependencies (NOT in Docker images)
# For backward compatibility, include base requirements
-r requirements-base.txt

175
libs/schemas/__init__.py Normal file
View File

@@ -0,0 +1,175 @@
"""Shared Pydantic models mirroring ontology entities."""
# Import coverage models
from .coverage.core import (
CompiledCoveragePolicy,
ConflictRules,
CoveragePolicy,
CrossCheck,
Defaults,
EvidenceItem,
GuidanceRef,
Privacy,
QuestionTemplates,
SchedulePolicy,
StatusClassifier,
StatusClassifierConfig,
TaxYearBoundary,
Trigger,
Validity,
)
from .coverage.evaluation import (
BlockingItem,
Citation,
ClarifyContext,
ClarifyResponse,
CoverageGap,
CoverageItem,
CoverageReport,
FoundEvidence,
ScheduleCoverage,
UploadOption,
)
from .coverage.utils import CoverageAudit, PolicyError, PolicyVersion, ValidationResult
# Import all entities
from .entities import (
Account,
BaseEntity,
Calculation,
Document,
Evidence,
ExpenseItem,
FormBox,
IncomeItem,
Party,
Payment,
PropertyAsset,
Rule,
TaxpayerProfile,
)
# Import all enums
from .enums import (
DocumentKind,
ExpenseType,
HealthStatus,
IncomeType,
OverallStatus,
PartySubtype,
PropertyUsage,
Role,
Status,
TaxpayerType,
)
# Import error models
from .errors import ErrorResponse, ValidationError, ValidationErrorResponse
# Import health models
from .health import HealthCheck, ServiceHealth
# Import request models
from .requests import (
DocumentUploadRequest,
ExtractionRequest,
FirmSyncRequest,
HMRCSubmissionRequest,
RAGSearchRequest,
ScheduleComputeRequest,
)
# Import response models
from .responses import (
DocumentUploadResponse,
ExtractionResponse,
FirmSyncResponse,
HMRCSubmissionResponse,
RAGSearchResponse,
ScheduleComputeResponse,
)
# Import utility functions
from .utils import get_entity_schemas
__all__ = [
# Enums
"DocumentKind",
"ExpenseType",
"HealthStatus",
"IncomeType",
"OverallStatus",
"PartySubtype",
"PropertyUsage",
"Role",
"Status",
"TaxpayerType",
# Entities
"Account",
"BaseEntity",
"Calculation",
"Document",
"Evidence",
"ExpenseItem",
"FormBox",
"IncomeItem",
"Party",
"Payment",
"PropertyAsset",
"Rule",
"TaxpayerProfile",
# Errors
"ErrorResponse",
"ValidationError",
"ValidationErrorResponse",
# Health
"HealthCheck",
"ServiceHealth",
# Requests
"DocumentUploadRequest",
"ExtractionRequest",
"FirmSyncRequest",
"HMRCSubmissionRequest",
"RAGSearchRequest",
"ScheduleComputeRequest",
# Responses
"DocumentUploadResponse",
"ExtractionResponse",
"FirmSyncResponse",
"HMRCSubmissionResponse",
"RAGSearchResponse",
"ScheduleComputeResponse",
# Utils
"get_entity_schemas",
# Coverage core models
"Validity",
"StatusClassifier",
"StatusClassifierConfig",
"EvidenceItem",
"CrossCheck",
"SchedulePolicy",
"Trigger",
"GuidanceRef",
"QuestionTemplates",
"ConflictRules",
"TaxYearBoundary",
"Defaults",
"Privacy",
"CoveragePolicy",
"CompiledCoveragePolicy",
# Coverage evaluation models
"FoundEvidence",
"Citation",
"CoverageItem",
"ScheduleCoverage",
"BlockingItem",
"CoverageReport",
"CoverageGap",
"ClarifyContext",
"UploadOption",
"ClarifyResponse",
# Coverage utility models
"PolicyError",
"ValidationResult",
"PolicyVersion",
"CoverageAudit",
]

libs/schemas/coverage/__init__.py Normal file
View File

libs/schemas/coverage/core.py Normal file
View File

@@ -0,0 +1,146 @@
"""Core coverage policy models."""
from collections.abc import Callable
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
from ..enums import Role
class Validity(BaseModel):
"""Validity constraints for evidence"""
within_tax_year: bool = False
available_by: str | None = None
date_tolerance_days: int = 30
class StatusClassifier(BaseModel):
"""Rules for classifying evidence status"""
min_ocr: float = 0.82
min_extract: float = 0.85
date_in_year: bool = True
date_in_year_or_tolerance: bool = True
conflict_rules: list[str] = Field(default_factory=list)
class StatusClassifierConfig(BaseModel):
"""Complete status classifier configuration"""
present_verified: StatusClassifier
present_unverified: StatusClassifier
conflicting: StatusClassifier
missing: StatusClassifier = Field(default_factory=lambda: StatusClassifier())
class EvidenceItem(BaseModel):
"""Evidence requirement definition"""
id: str
role: Role
condition: str | None = None
boxes: list[str] = Field(default_factory=list)
acceptable_alternatives: list[str] = Field(default_factory=list)
validity: Validity = Field(default_factory=Validity)
reasons: dict[str, str] = Field(default_factory=dict)
class CrossCheck(BaseModel):
"""Cross-validation rule"""
name: str
logic: str
class SchedulePolicy(BaseModel):
"""Policy for a specific tax schedule"""
guidance_hint: str | None = None
evidence: list[EvidenceItem] = Field(default_factory=list)
cross_checks: list[CrossCheck] = Field(default_factory=list)
selection_rule: dict[str, str] = Field(default_factory=dict)
notes: dict[str, Any] = Field(default_factory=dict)
class Trigger(BaseModel):
"""Schedule trigger condition"""
any_of: list[str] = Field(default_factory=list)
all_of: list[str] = Field(default_factory=list)
class GuidanceRef(BaseModel):
"""Reference to guidance document"""
doc_id: str
kind: str
class QuestionTemplates(BaseModel):
"""Templates for generating clarifying questions"""
default: dict[str, str] = Field(default_factory=dict)
reasons: dict[str, str] = Field(default_factory=dict)
class ConflictRules(BaseModel):
"""Rules for handling conflicting evidence"""
precedence: list[str] = Field(default_factory=list)
escalation: dict[str, Any] = Field(default_factory=dict)
class TaxYearBoundary(BaseModel):
"""Tax year date boundaries"""
start: str
end: str
class Defaults(BaseModel):
"""Default configuration values"""
confidence_thresholds: dict[str, float] = Field(default_factory=dict)
date_tolerance_days: int = 30
require_lineage_bbox: bool = True
allow_bank_substantiation: bool = True
class Privacy(BaseModel):
"""Privacy and PII handling configuration"""
vector_pii_free: bool = True
redact_patterns: list[str] = Field(default_factory=list)
class CoveragePolicy(BaseModel):
"""Complete coverage policy definition"""
version: str
jurisdiction: str
tax_year: str
tax_year_boundary: TaxYearBoundary
defaults: Defaults
document_kinds: list[str] = Field(default_factory=list)
guidance_refs: dict[str, GuidanceRef] = Field(default_factory=dict)
triggers: dict[str, Trigger] = Field(default_factory=dict)
schedules: dict[str, SchedulePolicy] = Field(default_factory=dict)
status_classifier: StatusClassifierConfig
conflict_resolution: ConflictRules
question_templates: QuestionTemplates
privacy: Privacy
class CompiledCoveragePolicy(BaseModel):
"""Coverage policy with compiled predicates"""
policy: CoveragePolicy
compiled_predicates: dict[str, Callable[[str, str], bool]] = Field(
default_factory=dict
)
compiled_at: datetime
hash: str
source_files: list[str] = Field(default_factory=list)
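# Illustrative construction sketch: a minimal policy fragment for a hypothetical
# property schedule. The evidence id, box reference, and reason text are
# examples only, not values defined by this module.
if __name__ == "__main__":
    fragment = SchedulePolicy(
        guidance_hint="SA105-notes",
        evidence=[
            EvidenceItem(
                id="property_income_statement",
                role=Role.REQUIRED,
                boxes=["SA105.20"],
                validity=Validity(within_tax_year=True),
                reasons={"missing": "Rental income cannot be substantiated"},
            )
        ],
    )
    print(fragment.model_dump_json(indent=2))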

libs/schemas/coverage/evaluation.py Normal file
View File

@@ -0,0 +1,112 @@
"""Coverage evaluation models."""
from datetime import datetime
from pydantic import BaseModel, Field
from ..enums import OverallStatus, Role, Status
class FoundEvidence(BaseModel):
"""Evidence found in the knowledge graph"""
doc_id: str
kind: str
confidence: float = 0.0
pages: list[int] = Field(default_factory=list)
bbox: dict[str, float] | None = None
ocr_confidence: float = 0.0
extract_confidence: float = 0.0
date: str | None = None
class Citation(BaseModel):
"""Citation reference"""
rule_id: str | None = None
doc_id: str | None = None
url: str | None = None
locator: str | None = None
section_id: str | None = None
page: int | None = None
bbox: dict[str, float] | None = None
class CoverageItem(BaseModel):
"""Coverage evaluation for a single evidence item"""
id: str
role: Role
status: Status
boxes: list[str] = Field(default_factory=list)
found: list[FoundEvidence] = Field(default_factory=list)
acceptable_alternatives: list[str] = Field(default_factory=list)
reason: str = ""
citations: list[Citation] = Field(default_factory=list)
class ScheduleCoverage(BaseModel):
"""Coverage evaluation for a schedule"""
schedule_id: str
status: OverallStatus
evidence: list[CoverageItem] = Field(default_factory=list)
class BlockingItem(BaseModel):
"""Item that blocks completion"""
schedule_id: str
evidence_id: str
class CoverageReport(BaseModel):
"""Complete coverage evaluation report"""
tax_year: str
taxpayer_id: str
schedules_required: list[str] = Field(default_factory=list)
overall_status: OverallStatus
coverage: list[ScheduleCoverage] = Field(default_factory=list)
blocking_items: list[BlockingItem] = Field(default_factory=list)
evaluated_at: datetime = Field(default_factory=datetime.utcnow)
policy_version: str = ""
class CoverageGap(BaseModel):
"""Gap in coverage requiring clarification"""
schedule_id: str
evidence_id: str
role: Role
reason: str
boxes: list[str] = Field(default_factory=list)
citations: list[Citation] = Field(default_factory=list)
acceptable_alternatives: list[str] = Field(default_factory=list)
class ClarifyContext(BaseModel):
"""Context for clarifying question"""
tax_year: str
taxpayer_id: str
jurisdiction: str
class UploadOption(BaseModel):
"""Upload option for user"""
label: str
accepted_formats: list[str] = Field(default_factory=list)
upload_endpoint: str
class ClarifyResponse(BaseModel):
"""Response to clarifying question request"""
question_text: str
why_it_is_needed: str
citations: list[Citation] = Field(default_factory=list)
options_to_provide: list[UploadOption] = Field(default_factory=list)
blocking: bool = False
boxes_affected: list[str] = Field(default_factory=list)
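# Illustrative construction sketch: one missing evidence item rolled up into a
# blocking schedule-level status. Identifiers and locator text are examples only.
if __name__ == "__main__":
    item = CoverageItem(
        id="p60",
        role=Role.REQUIRED,
        status=Status.MISSING,
        boxes=["SA102.1"],
        reason="No P60 found for the selected tax year",
        citations=[Citation(doc_id="SA102-notes", locator="Box 1")],
    )
    schedule = ScheduleCoverage(
        schedule_id="SA102", status=OverallStatus.BLOCKING, evidence=[item]
    )
    print(schedule.model_dump_json(indent=2))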

libs/schemas/coverage/utils.py Normal file
View File

@@ -0,0 +1,48 @@
"""Utility models for coverage system."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
from ..enums import OverallStatus
class PolicyError(Exception):
"""Policy loading or validation error"""
pass
class ValidationResult(BaseModel):
"""Policy validation result"""
ok: bool
errors: list[str] = Field(default_factory=list)
warnings: list[str] = Field(default_factory=list)
class PolicyVersion(BaseModel):
"""Policy version record"""
id: int | None = None
version: str
jurisdiction: str
tax_year: str
tenant_id: str | None = None
source_files: list[str] = Field(default_factory=list)
compiled_at: datetime
hash: str
class CoverageAudit(BaseModel):
"""Coverage audit record"""
id: int | None = None
taxpayer_id: str
tax_year: str
policy_version: str
overall_status: OverallStatus
blocking_items: list[dict[str, Any]] = Field(default_factory=list)
created_at: datetime = Field(default_factory=datetime.utcnow)
trace_id: str | None = None

230
libs/schemas/entities.py Normal file
View File

@@ -0,0 +1,230 @@
"""Core business entities with temporal modeling."""
from datetime import date, datetime
from decimal import Decimal
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from .enums import (
DocumentKind,
ExpenseType,
IncomeType,
PartySubtype,
PropertyUsage,
TaxpayerType,
)
class BaseEntity(BaseModel):
"""Base entity with temporal fields"""
model_config = ConfigDict(
str_strip_whitespace=True, validate_assignment=True, use_enum_values=True
)
# Temporal fields (bitemporal modeling)
valid_from: datetime = Field(
..., description="When the fact became valid in reality"
)
valid_to: datetime | None = Field(
None, description="When the fact ceased to be valid"
)
asserted_at: datetime = Field(
default_factory=datetime.utcnow, description="When recorded in system"
)
retracted_at: datetime | None = Field(
None, description="When retracted from system"
)
source: str = Field(..., description="Source of the information")
extractor_version: str = Field(..., description="Version of extraction system")
class TaxpayerProfile(BaseEntity):
"""Taxpayer profile entity"""
taxpayer_id: str = Field(..., description="Unique taxpayer identifier")
type: TaxpayerType = Field(..., description="Type of taxpayer")
utr: str | None = Field(
None, pattern=r"^\d{10}$", description="Unique Taxpayer Reference"
)
ni_number: str | None = Field(
None,
pattern=r"^[A-CEGHJ-PR-TW-Z]{2}\d{6}[A-D]$",
description="National Insurance Number",
)
residence: str | None = Field(None, description="Tax residence")
class Document(BaseEntity):
"""Document entity"""
doc_id: str = Field(
..., pattern=r"^doc_[a-f0-9]{16}$", description="Document identifier"
)
kind: DocumentKind = Field(..., description="Type of document")
source: str = Field(..., description="Source of document")
mime: str = Field(..., description="MIME type")
checksum: str = Field(
..., pattern=r"^[a-f0-9]{64}$", description="SHA-256 checksum"
)
file_size: int | None = Field(None, ge=0, description="File size in bytes")
pages: int | None = Field(None, ge=1, description="Number of pages")
date_range: dict[str, date] | None = Field(None, description="Document date range")
class Evidence(BaseEntity):
"""Evidence entity linking to document snippets"""
snippet_id: str = Field(..., description="Evidence snippet identifier")
doc_ref: str = Field(..., description="Reference to source document")
page: int = Field(..., ge=1, description="Page number")
bbox: list[float] | None = Field(
None, description="Bounding box coordinates [x1, y1, x2, y2]"
)
text_hash: str = Field(
..., pattern=r"^[a-f0-9]{64}$", description="SHA-256 hash of extracted text"
)
ocr_confidence: float | None = Field(
None, ge=0.0, le=1.0, description="OCR confidence score"
)
class IncomeItem(BaseEntity):
"""Income item entity"""
income_id: str = Field(..., description="Income item identifier")
type: IncomeType = Field(..., description="Type of income")
gross: Decimal = Field(..., ge=0, description="Gross amount")
net: Decimal | None = Field(None, ge=0, description="Net amount")
tax_withheld: Decimal | None = Field(None, ge=0, description="Tax withheld")
currency: str = Field(..., pattern=r"^[A-Z]{3}$", description="Currency code")
period_start: date | None = Field(None, description="Income period start")
period_end: date | None = Field(None, description="Income period end")
description: str | None = Field(None, description="Income description")
class ExpenseItem(BaseEntity):
"""Expense item entity"""
expense_id: str = Field(..., description="Expense item identifier")
type: ExpenseType = Field(..., description="Type of expense")
amount: Decimal = Field(..., ge=0, description="Expense amount")
currency: str = Field(..., pattern=r"^[A-Z]{3}$", description="Currency code")
description: str | None = Field(None, description="Expense description")
category: str | None = Field(None, description="Expense category")
allowable: bool | None = Field(None, description="Whether expense is allowable")
capitalizable_flag: bool | None = Field(
None, description="Whether expense should be capitalized"
)
vat_amount: Decimal | None = Field(None, ge=0, description="VAT amount")
net_amount: Decimal | None = Field(
None, ge=0, description="Net amount excluding VAT"
)
class Party(BaseEntity):
"""Party entity (person or organization)"""
party_id: str = Field(..., description="Party identifier")
name: str = Field(..., min_length=1, description="Party name")
subtype: PartySubtype | None = Field(None, description="Party subtype")
address: str | None = Field(None, description="Party address")
vat_number: str | None = Field(
None, pattern=r"^GB\d{9}$|^GB\d{12}$", description="UK VAT number"
)
utr: str | None = Field(
None, pattern=r"^\d{10}$", description="Unique Taxpayer Reference"
)
reg_no: str | None = Field(None, description="Registration number")
paye_reference: str | None = Field(None, description="PAYE reference")
class Account(BaseEntity):
"""Bank account entity"""
account_id: str = Field(..., description="Account identifier")
iban: str | None = Field(
None, pattern=r"^GB\d{2}[A-Z]{4}\d{14}$", description="UK IBAN"
)
sort_code: str | None = Field(
None, pattern=r"^\d{2}-\d{2}-\d{2}$", description="Sort code"
)
account_no: str | None = Field(
None, pattern=r"^\d{8}$", description="Account number"
)
institution: str | None = Field(None, description="Financial institution")
account_type: str | None = Field(None, description="Account type")
currency: str = Field(default="GBP", description="Account currency")
class PropertyAsset(BaseEntity):
"""Property asset entity"""
property_id: str = Field(..., description="Property identifier")
address: str = Field(..., min_length=10, description="Property address")
postcode: str | None = Field(
None, pattern=r"^[A-Z]{1,2}\d[A-Z0-9]?\s*\d[A-Z]{2}$", description="UK postcode"
)
tenure: str | None = Field(None, description="Property tenure")
ownership_share: float | None = Field(
None, ge=0.0, le=1.0, description="Ownership share"
)
usage: PropertyUsage | None = Field(None, description="Property usage type")
class Payment(BaseEntity):
"""Payment transaction entity"""
payment_id: str = Field(..., description="Payment identifier")
payment_date: date = Field(..., description="Payment date")
amount: Decimal = Field(
..., description="Payment amount (positive for credit, negative for debit)"
)
currency: str = Field(..., pattern=r"^[A-Z]{3}$", description="Currency code")
direction: str = Field(..., description="Payment direction (credit/debit)")
description: str | None = Field(None, description="Payment description")
reference: str | None = Field(None, description="Payment reference")
balance_after: Decimal | None = Field(
None, description="Account balance after payment"
)
class Calculation(BaseEntity):
"""Tax calculation entity"""
calculation_id: str = Field(..., description="Calculation identifier")
schedule: str = Field(..., description="Tax schedule (SA100, SA103, etc.)")
tax_year: str = Field(
..., pattern=r"^\d{4}-\d{2}$", description="Tax year (e.g., 2023-24)"
)
total_income: Decimal | None = Field(None, ge=0, description="Total income")
total_expenses: Decimal | None = Field(None, ge=0, description="Total expenses")
net_profit: Decimal | None = Field(None, description="Net profit/loss")
calculated_at: datetime = Field(
default_factory=datetime.utcnow, description="Calculation timestamp"
)
class FormBox(BaseEntity):
"""Form box entity"""
form: str = Field(..., description="Form identifier (SA100, SA103, etc.)")
box: str = Field(..., description="Box identifier")
value: Decimal | str | bool = Field(..., description="Box value")
description: str | None = Field(None, description="Box description")
confidence: float | None = Field(
None, ge=0.0, le=1.0, description="Confidence score"
)
class Rule(BaseEntity):
"""Tax rule entity"""
rule_id: str = Field(..., description="Rule identifier")
name: str = Field(..., description="Rule name")
description: str | None = Field(None, description="Rule description")
jurisdiction: str = Field(default="UK", description="Tax jurisdiction")
tax_years: list[str] = Field(..., description="Applicable tax years")
formula: str | None = Field(None, description="Rule formula")
conditions: dict[str, Any] | None = Field(None, description="Rule conditions")
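# Illustrative construction sketch: a single employment income fact with the
# bitemporal fields populated. All values are examples only.
if __name__ == "__main__":
    income = IncomeItem(
        income_id="inc_0001",
        type=IncomeType.EMPLOYMENT,
        gross=Decimal("42000.00"),
        tax_withheld=Decimal("6300.00"),
        currency="GBP",
        period_start=date(2023, 4, 6),
        period_end=date(2024, 4, 5),
        valid_from=datetime(2023, 4, 6),
        source="svc-extract",
        extractor_version="1.0.0",
    )
    print(income.model_dump_json(indent=2))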

102
libs/schemas/enums.py Normal file
View File

@@ -0,0 +1,102 @@
"""Enumeration types for the tax system."""
from enum import Enum
class TaxpayerType(str, Enum):
"""Taxpayer types"""
INDIVIDUAL = "Individual"
PARTNERSHIP = "Partnership"
COMPANY = "Company"
class DocumentKind(str, Enum):
"""Document types"""
BANK_STATEMENT = "bank_statement"
INVOICE = "invoice"
RECEIPT = "receipt"
P_AND_L = "p_and_l"
BALANCE_SHEET = "balance_sheet"
PAYSLIP = "payslip"
DIVIDEND_VOUCHER = "dividend_voucher"
PROPERTY_STATEMENT = "property_statement"
PRIOR_RETURN = "prior_return"
LETTER = "letter"
CERTIFICATE = "certificate"
class IncomeType(str, Enum):
"""Income types"""
EMPLOYMENT = "employment"
SELF_EMPLOYMENT = "self_employment"
PROPERTY = "property"
DIVIDEND = "dividend"
INTEREST = "interest"
OTHER = "other"
class ExpenseType(str, Enum):
"""Expense types"""
BUSINESS = "business"
PROPERTY = "property"
CAPITAL = "capital"
PERSONAL = "personal"
class PartySubtype(str, Enum):
"""Party subtypes"""
EMPLOYER = "Employer"
PAYER = "Payer"
BANK = "Bank"
LANDLORD = "Landlord"
TENANT = "Tenant"
SUPPLIER = "Supplier"
CLIENT = "Client"
class PropertyUsage(str, Enum):
"""Property usage types"""
RESIDENTIAL = "residential"
FURNISHED_HOLIDAY_LETTING = "furnished_holiday_letting"
COMMERCIAL = "commercial"
MIXED = "mixed"
class HealthStatus(str, Enum):
"""Health status values"""
HEALTHY = "healthy"
UNHEALTHY = "unhealthy"
DEGRADED = "degraded"
# Coverage evaluation enums
class Role(str, Enum):
"""Evidence role in coverage evaluation"""
REQUIRED = "REQUIRED"
CONDITIONALLY_REQUIRED = "CONDITIONALLY_REQUIRED"
OPTIONAL = "OPTIONAL"
class Status(str, Enum):
"""Evidence status classification"""
PRESENT_VERIFIED = "present_verified"
PRESENT_UNVERIFIED = "present_unverified"
MISSING = "missing"
CONFLICTING = "conflicting"
class OverallStatus(str, Enum):
"""Overall coverage status"""
OK = "ok"
PARTIAL = "partial"
BLOCKING = "blocking"

30
libs/schemas/errors.py Normal file
View File

@@ -0,0 +1,30 @@
"""Error response models."""
from typing import Any
from pydantic import BaseModel, Field
class ErrorResponse(BaseModel):
"""RFC7807 Problem+JSON error response"""
type: str = Field(..., description="Error type URI")
title: str = Field(..., description="Error title")
status: int = Field(..., description="HTTP status code")
detail: str = Field(..., description="Error detail")
instance: str = Field(..., description="Error instance URI")
trace_id: str | None = Field(None, description="Trace identifier")
class ValidationError(BaseModel):
"""Validation error details"""
field: str = Field(..., description="Field name")
message: str = Field(..., description="Error message")
value: Any = Field(..., description="Invalid value")
class ValidationErrorResponse(ErrorResponse):
"""Validation error response with field details"""
errors: list[ValidationError] = Field(..., description="Validation errors")
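# Illustrative construction sketch: an RFC 7807 payload for a rejected upload.
# The URIs and trace identifier are examples only.
if __name__ == "__main__":
    problem = ErrorResponse(
        type="https://example.com/problems/unsupported-media-type",
        title="Unsupported media type",
        status=415,
        detail="Only PDF uploads are accepted",
        instance="/v1/documents/doc_0123456789abcdef",
        trace_id="0af7651916cd43dd8448eb211c80319c",
    )
    print(problem.model_dump_json())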

32
libs/schemas/health.py Normal file
View File

@@ -0,0 +1,32 @@
"""Health check models."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
from .enums import HealthStatus
class HealthCheck(BaseModel):
"""Health check response"""
status: HealthStatus = Field(..., description="Overall health status")
timestamp: datetime = Field(
default_factory=datetime.utcnow, description="Check timestamp"
)
version: str = Field(..., description="Service version")
checks: dict[str, dict[str, Any]] = Field(
default_factory=dict, description="Individual checks"
)
class ServiceHealth(BaseModel):
"""Individual service health status"""
name: str = Field(..., description="Service name")
status: HealthStatus = Field(..., description="Service health status")
response_time_ms: float | None = Field(
None, description="Response time in milliseconds"
)
error: str | None = Field(None, description="Error message if unhealthy")

65
libs/schemas/requests.py Normal file
View File

@@ -0,0 +1,65 @@
"""API request models."""
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from .enums import DocumentKind
class DocumentUploadRequest(BaseModel):
"""Request model for document upload"""
tenant_id: str = Field(..., description="Tenant identifier")
kind: DocumentKind = Field(..., description="Document type")
source: str = Field(..., description="Document source")
class ExtractionRequest(BaseModel):
"""Request model for document extraction"""
strategy: str = Field(default="hybrid", description="Extraction strategy")
class RAGSearchRequest(BaseModel):
"""Request model for RAG search"""
query: str = Field(..., min_length=1, description="Search query")
tax_year: str | None = Field(None, description="Tax year filter")
jurisdiction: str | None = Field(None, description="Jurisdiction filter")
k: int = Field(default=10, ge=1, le=100, description="Number of results")
class ScheduleComputeRequest(BaseModel):
"""Request model for schedule computation"""
tax_year: str = Field(..., pattern=r"^\d{4}-\d{2}$", description="Tax year")
taxpayer_id: str = Field(..., description="Taxpayer identifier")
schedule_id: str = Field(..., description="Schedule identifier")
class HMRCSubmissionRequest(BaseModel):
"""Request model for HMRC submission"""
tax_year: str = Field(..., pattern=r"^\d{4}-\d{2}$", description="Tax year")
taxpayer_id: str = Field(..., description="Taxpayer identifier")
dry_run: bool = Field(default=True, description="Dry run flag")
class FirmSyncRequest(BaseModel):
"""Request to sync firm data"""
model_config = ConfigDict(extra="forbid")
firm_id: str = Field(..., description="Firm identifier")
system: str = Field(..., description="Practice management system to sync with")
sync_type: str = Field(
default="full", description="Type of sync: full, incremental"
)
force_refresh: bool = Field(
default=False, description="Force refresh of cached data"
)
connection_config: dict[str, Any] = Field(
...,
description="Configuration for connecting to the practice management system",
)

69
libs/schemas/responses.py Normal file
View File

@@ -0,0 +1,69 @@
"""API response models."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
class DocumentUploadResponse(BaseModel):
"""Response model for document upload"""
doc_id: str = Field(..., description="Document identifier")
s3_url: str = Field(..., description="S3 URL")
checksum: str = Field(..., description="Document checksum")
class ExtractionResponse(BaseModel):
"""Response model for document extraction"""
extraction_id: str = Field(..., description="Extraction identifier")
confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
extracted_fields: dict[str, Any] = Field(..., description="Extracted fields")
provenance: list[dict[str, Any]] = Field(..., description="Provenance information")
class RAGSearchResponse(BaseModel):
"""Response model for RAG search"""
chunks: list[dict[str, Any]] = Field(..., description="Retrieved chunks")
citations: list[dict[str, Any]] = Field(..., description="Source citations")
kg_hints: list[dict[str, Any]] = Field(..., description="Knowledge graph hints")
calibrated_confidence: float = Field(
..., ge=0.0, le=1.0, description="Calibrated confidence"
)
class ScheduleComputeResponse(BaseModel):
"""Response model for schedule computation"""
calculation_id: str = Field(..., description="Calculation identifier")
schedule: str = Field(..., description="Schedule identifier")
form_boxes: dict[str, dict[str, Any]] = Field(
..., description="Computed form boxes"
)
evidence_trail: list[dict[str, Any]] = Field(..., description="Evidence trail")
class HMRCSubmissionResponse(BaseModel):
"""Response model for HMRC submission"""
submission_id: str = Field(..., description="Submission identifier")
status: str = Field(..., description="Submission status")
hmrc_reference: str | None = Field(None, description="HMRC reference")
submission_timestamp: datetime = Field(..., description="Submission timestamp")
validation_results: dict[str, Any] = Field(..., description="Validation results")
class FirmSyncResponse(BaseModel):
"""Response from firm sync operation"""
model_config = ConfigDict(extra="forbid")
firm_id: str = Field(..., description="Firm identifier")
status: str = Field(..., description="Sync status: success, error, partial")
message: str = Field(..., description="Status message")
synced_entities: int = Field(default=0, description="Number of entities synced")
errors: list[str] = Field(
default_factory=list, description="List of errors encountered"
)

69
libs/schemas/utils.py Normal file
View File

@@ -0,0 +1,69 @@
"""Utility functions for schema export."""
from typing import Any
from .entities import (
Account,
Calculation,
Document,
Evidence,
ExpenseItem,
FormBox,
IncomeItem,
Party,
Payment,
PropertyAsset,
Rule,
TaxpayerProfile,
)
from .requests import (
DocumentUploadRequest,
ExtractionRequest,
FirmSyncRequest,
HMRCSubmissionRequest,
RAGSearchRequest,
ScheduleComputeRequest,
)
from .responses import (
DocumentUploadResponse,
ExtractionResponse,
FirmSyncResponse,
HMRCSubmissionResponse,
RAGSearchResponse,
ScheduleComputeResponse,
)
def get_entity_schemas() -> dict[str, dict[str, Any]]:
"""Export JSON schemas for all models"""
schemas = {}
# Core entities
schemas["TaxpayerProfile"] = TaxpayerProfile.model_json_schema()
schemas["Document"] = Document.model_json_schema()
schemas["Evidence"] = Evidence.model_json_schema()
schemas["IncomeItem"] = IncomeItem.model_json_schema()
schemas["ExpenseItem"] = ExpenseItem.model_json_schema()
schemas["Party"] = Party.model_json_schema()
schemas["Account"] = Account.model_json_schema()
schemas["PropertyAsset"] = PropertyAsset.model_json_schema()
schemas["Payment"] = Payment.model_json_schema()
schemas["Calculation"] = Calculation.model_json_schema()
schemas["FormBox"] = FormBox.model_json_schema()
schemas["Rule"] = Rule.model_json_schema()
# Request/Response models
schemas["DocumentUploadRequest"] = DocumentUploadRequest.model_json_schema()
schemas["DocumentUploadResponse"] = DocumentUploadResponse.model_json_schema()
schemas["ExtractionRequest"] = ExtractionRequest.model_json_schema()
schemas["ExtractionResponse"] = ExtractionResponse.model_json_schema()
schemas["RAGSearchRequest"] = RAGSearchRequest.model_json_schema()
schemas["RAGSearchResponse"] = RAGSearchResponse.model_json_schema()
schemas["ScheduleComputeRequest"] = ScheduleComputeRequest.model_json_schema()
schemas["ScheduleComputeResponse"] = ScheduleComputeResponse.model_json_schema()
schemas["HMRCSubmissionRequest"] = HMRCSubmissionRequest.model_json_schema()
schemas["HMRCSubmissionResponse"] = HMRCSubmissionResponse.model_json_schema()
schemas["FirmSyncRequest"] = FirmSyncRequest.model_json_schema()
schemas["FirmSyncResponse"] = FirmSyncResponse.model_json_schema()
return schemas
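# Illustrative usage sketch: write the exported schemas to a JSON file, e.g. for
# contract tests or documentation tooling. The output path is an example only.
if __name__ == "__main__":
    import json
    from pathlib import Path

    Path("entity_schemas.json").write_text(
        json.dumps(get_entity_schemas(), indent=2), encoding="utf-8"
    )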

26
libs/security/__init__.py Normal file
View File

@@ -0,0 +1,26 @@
"""Security utilities for authentication, authorization, and encryption."""
from .auth import AuthenticationHeaders
from .dependencies import (
get_current_tenant,
get_current_user,
get_tenant_id,
require_admin_role,
require_reviewer_role,
)
from .middleware import TrustedProxyMiddleware, create_trusted_proxy_middleware
from .utils import is_internal_request
from .vault import VaultTransitHelper
__all__ = [
"VaultTransitHelper",
"AuthenticationHeaders",
"TrustedProxyMiddleware",
"is_internal_request",
"require_admin_role",
"require_reviewer_role",
"get_current_tenant",
"get_current_user",
"get_tenant_id",
"create_trusted_proxy_middleware",
]

61
libs/security/auth.py Normal file
View File

@@ -0,0 +1,61 @@
"""Authentication headers parsing and validation."""
import structlog
from fastapi import HTTPException, Request, status
logger = structlog.get_logger()
class AuthenticationHeaders:
"""Parse and validate authentication headers from Traefik + Authentik"""
def __init__(self, request: Request):
self.request = request
self.headers = request.headers
@property
def authenticated_user(self) -> str | None:
"""Get authenticated user from headers"""
return self.headers.get("X-Authenticated-User")
@property
def authenticated_email(self) -> str | None:
"""Get authenticated email from headers"""
return self.headers.get("X-Authenticated-Email")
@property
def authenticated_groups(self) -> list[str]:
"""Get authenticated groups from headers"""
groups_header = self.headers.get("X-Authenticated-Groups", "")
return [g.strip() for g in groups_header.split(",") if g.strip()]
@property
def authorization_token(self) -> str | None:
"""Get JWT token from Authorization header"""
auth_header = self.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
return auth_header[7:]
return None
def has_role(self, role: str) -> bool:
"""Check if user has specific role"""
return role in self.authenticated_groups
def has_any_role(self, roles: list[str]) -> bool:
"""Check if user has any of the specified roles"""
return any(role in self.authenticated_groups for role in roles)
def require_role(self, role: str) -> None:
"""Require specific role or raise HTTPException"""
if not self.has_role(role):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=f"Role '{role}' required"
)
def require_any_role(self, roles: list[str]) -> None:
"""Require any of the specified roles or raise HTTPException"""
if not self.has_any_role(roles):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"One of roles {roles} required",
)

libs/security/dependencies.py Normal file
View File

@@ -0,0 +1,79 @@
"""FastAPI dependency functions for authentication and authorization."""
from collections.abc import Callable
from typing import Any
from fastapi import HTTPException, Request, status
def require_admin_role(request: Request) -> None:
"""Dependency to require admin role"""
auth = getattr(request.state, "auth", None)
if not auth:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required"
)
auth.require_role("admin")
def require_reviewer_role(request: Request) -> None:
"""Dependency to require reviewer role"""
auth = getattr(request.state, "auth", None)
if not auth:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required"
)
auth.require_any_role(["admin", "reviewer"])
def get_current_tenant(request: Request) -> str | None:
"""Extract tenant ID from user context or headers"""
# This could be extracted from JWT claims, user groups, or custom headers
    # For now, derive it from a "tenant:<id>" entry in the user's roles
user = getattr(request.state, "user", None)
if not user:
return None
# Simple tenant extraction - in production this would be more sophisticated
# Could be from JWT claims, database lookup, or group membership
roles = getattr(request.state, "roles", [])
for role in roles:
if role.startswith("tenant:"):
return str(role.split(":", 1)[1])
# Default tenant for development
return "default"
# Dependency functions for FastAPI
def get_current_user() -> Callable[[Request], dict[str, Any]]:
"""FastAPI dependency to get current user"""
def _get_current_user(request: Request) -> dict[str, Any]:
user = getattr(request.state, "user", None)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required",
)
return {
"sub": user,
"email": getattr(request.state, "email", ""),
"roles": getattr(request.state, "roles", []),
}
return _get_current_user
def get_tenant_id() -> Callable[[Request], str]:
"""FastAPI dependency to get tenant ID"""
def _get_tenant_id(request: Request) -> str:
tenant_id = get_current_tenant(request)
if not tenant_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Tenant ID required"
)
return tenant_id
return _get_tenant_id
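# Illustrative wiring sketch: how a service route might consume these
# dependencies behind the trusted proxy. The path and response shape are
# examples only.
if __name__ == "__main__":
    from fastapi import Depends, FastAPI

    app = FastAPI()

    @app.get("/whoami")
    async def whoami(
        user: dict[str, Any] = Depends(get_current_user()),
        tenant_id: str = Depends(get_tenant_id()),
    ) -> dict[str, Any]:
        return {"user": user, "tenant_id": tenant_id}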

134
libs/security/middleware.py Normal file
View File

@@ -0,0 +1,134 @@
"""Trusted proxy middleware for authentication validation."""
from collections.abc import Callable
from typing import Any
import structlog
from fastapi import HTTPException, Request, status
from starlette.middleware.base import BaseHTTPMiddleware
from .auth import AuthenticationHeaders
from .utils import is_internal_request
logger = structlog.get_logger()
class TrustedProxyMiddleware(
BaseHTTPMiddleware
): # pylint: disable=too-few-public-methods
"""Middleware to validate requests from trusted proxy (Traefik)"""
def __init__(self, app: Any, internal_cidrs: list[str], disable_auth: bool = False):
super().__init__(app)
self.internal_cidrs = internal_cidrs
self.disable_auth = disable_auth
self.public_endpoints = {
"/healthz",
"/readyz",
"/livez",
"/metrics",
"/docs",
"/openapi.json",
}
async def dispatch(self, request: Request, call_next: Callable[..., Any]) -> Any:
"""Process request through middleware"""
# Get client IP (considering proxy headers)
client_ip = self._get_client_ip(request)
# Check if authentication is disabled (development mode)
if self.disable_auth:
# Set development state
request.state.user = "dev-user"
request.state.email = "dev@example.com"
request.state.roles = ["developers"]
request.state.auth_token = "dev-token"
logger.info(
"Development mode: authentication disabled", path=request.url.path
)
return await call_next(request)
# Check if this is a public endpoint
if request.url.path in self.public_endpoints:
# For metrics endpoint, still require internal network
if request.url.path == "/metrics":
if not is_internal_request(client_ip, self.internal_cidrs):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Metrics endpoint only accessible from internal network",
)
# Set minimal state for public endpoints
request.state.user = None
request.state.email = None
request.state.roles = []
return await call_next(request)
# For protected endpoints, validate authentication headers
auth_headers = AuthenticationHeaders(request)
# Require authentication headers
if not auth_headers.authenticated_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing X-Authenticated-User header",
)
if not auth_headers.authenticated_email:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing X-Authenticated-Email header",
)
if not auth_headers.authorization_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing Authorization header",
)
# Set request state
request.state.user = auth_headers.authenticated_user
request.state.email = auth_headers.authenticated_email
request.state.roles = auth_headers.authenticated_groups
request.state.auth_token = auth_headers.authorization_token
# Add authentication helper to request
request.state.auth = auth_headers
logger.info(
"Authenticated request",
user=auth_headers.authenticated_user,
email=auth_headers.authenticated_email,
roles=auth_headers.authenticated_groups,
path=request.url.path,
)
return await call_next(request)
def _get_client_ip(self, request: Request) -> str:
"""Get client IP considering proxy headers"""
# Check X-Forwarded-For header first (set by Traefik)
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# Take the first IP in the chain
return forwarded_for.split(",")[0].strip()
# Check X-Real-IP header
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# Fall back to direct client IP
return request.client.host if request.client else "unknown"
def create_trusted_proxy_middleware(
internal_cidrs: list[str],
) -> Callable[[Any], TrustedProxyMiddleware]:
"""Factory function to create TrustedProxyMiddleware"""
def middleware_factory( # pylint: disable=unused-argument
app: Any,
) -> TrustedProxyMiddleware:
return TrustedProxyMiddleware(app, internal_cidrs)
return middleware_factory
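# Illustrative wiring sketch: attaching the middleware to an application.
# add_middleware forwards keyword arguments to TrustedProxyMiddleware.__init__;
# the CIDR ranges are examples only.
if __name__ == "__main__":
    from fastapi import FastAPI

    app = FastAPI()
    app.add_middleware(
        TrustedProxyMiddleware,
        internal_cidrs=["10.0.0.0/8", "172.16.0.0/12"],
        disable_auth=False,
    )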

20
libs/security/utils.py Normal file
View File

@@ -0,0 +1,20 @@
"""Security utility functions."""
import ipaddress
import structlog
logger = structlog.get_logger()
def is_internal_request(client_ip: str, internal_cidrs: list[str]) -> bool:
"""Check if request comes from internal network"""
try:
client_addr = ipaddress.ip_address(client_ip)
for cidr in internal_cidrs:
if client_addr in ipaddress.ip_network(cidr):
return True
return False
except ValueError:
logger.warning("Invalid client IP address", client_ip=client_ip)
return False
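# Illustrative check: an address inside a private Docker bridge range is
# treated as internal; a public address is not. The ranges are examples only.
if __name__ == "__main__":
    cidrs = ["10.0.0.0/8", "172.16.0.0/12"]
    print(is_internal_request("172.18.0.5", cidrs))    # True
    print(is_internal_request("203.0.113.10", cidrs))  # False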

58
libs/security/vault.py Normal file
View File

@@ -0,0 +1,58 @@
"""Vault Transit encryption/decryption helpers."""
import base64
import hvac
import structlog
logger = structlog.get_logger()
class VaultTransitHelper:
"""Helper for Vault Transit encryption/decryption"""
def __init__(self, vault_client: hvac.Client, mount_point: str = "transit"):
self.vault_client = vault_client
self.mount_point = mount_point
def encrypt_field(self, key_name: str, plaintext: str) -> str:
"""Encrypt a field using Vault Transit"""
try:
# Ensure key exists
self._ensure_key_exists(key_name)
# Encrypt the data
response = self.vault_client.secrets.transit.encrypt_data(
mount_point=self.mount_point,
name=key_name,
plaintext=base64.b64encode(plaintext.encode()).decode(),
)
return str(response["data"]["ciphertext"])
except Exception as e:
logger.error("Failed to encrypt field", key_name=key_name, error=str(e))
raise
def decrypt_field(self, key_name: str, ciphertext: str) -> str:
"""Decrypt a field using Vault Transit"""
try:
response = self.vault_client.secrets.transit.decrypt_data(
mount_point=self.mount_point, name=key_name, ciphertext=ciphertext
)
return base64.b64decode(response["data"]["plaintext"]).decode()
except Exception as e:
logger.error("Failed to decrypt field", key_name=key_name, error=str(e))
raise
def _ensure_key_exists(self, key_name: str) -> None:
"""Ensure encryption key exists in Vault"""
try:
self.vault_client.secrets.transit.read_key(
mount_point=self.mount_point, name=key_name
)
# pylint: disable-next=broad-exception-caught
except Exception: # hvac.exceptions.InvalidPath
# Key doesn't exist, create it
self.vault_client.secrets.transit.create_key(
mount_point=self.mount_point, name=key_name, key_type="aes256-gcm96"
)
logger.info("Created new encryption key", key_name=key_name)

9
libs/storage/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
"""Storage client and document management for MinIO/S3."""
from .client import StorageClient
from .document import DocumentStorage
__all__ = [
"StorageClient",
"DocumentStorage",
]

231
libs/storage/client.py Normal file
View File

@@ -0,0 +1,231 @@
"""MinIO/S3 storage client wrapper."""
from datetime import timedelta
from typing import Any, BinaryIO
import structlog
from minio import Minio
from minio.error import S3Error
logger = structlog.get_logger()
class StorageClient:
"""MinIO/S3 storage client wrapper"""
def __init__(self, minio_client: Minio):
self.client = minio_client
async def ensure_bucket(self, bucket_name: str, region: str = "us-east-1") -> bool:
"""Ensure bucket exists, create if not"""
try:
# Check if bucket exists
if self.client.bucket_exists(bucket_name):
logger.debug("Bucket already exists", bucket=bucket_name)
return True
# Create bucket
self.client.make_bucket(bucket_name, location=region)
logger.info("Created bucket", bucket=bucket_name, region=region)
return True
except S3Error as e:
logger.error("Failed to ensure bucket", bucket=bucket_name, error=str(e))
return False
async def put_object( # pylint: disable=too-many-arguments,too-many-positional-arguments
self,
bucket_name: str,
object_name: str,
data: BinaryIO,
length: int,
content_type: str = "application/octet-stream",
metadata: dict[str, str] | None = None,
) -> bool:
"""Upload object to bucket"""
try:
# Ensure bucket exists
await self.ensure_bucket(bucket_name)
# Upload object
result = self.client.put_object(
bucket_name=bucket_name,
object_name=object_name,
data=data,
length=length,
content_type=content_type,
metadata=metadata or {}, # fmt: skip # pyright: ignore[reportArgumentType]
)
logger.info(
"Object uploaded",
bucket=bucket_name,
object=object_name,
etag=result.etag,
size=length,
)
return True
except S3Error as e:
logger.error(
"Failed to upload object",
bucket=bucket_name,
object=object_name,
error=str(e),
)
return False
async def get_object(self, bucket_name: str, object_name: str) -> bytes | None:
"""Download object from bucket"""
try:
response = self.client.get_object(bucket_name, object_name)
data = response.read()
response.close()
response.release_conn()
logger.debug(
"Object downloaded",
bucket=bucket_name,
object=object_name,
size=len(data),
)
return data # type: ignore
except S3Error as e:
logger.error(
"Failed to download object",
bucket=bucket_name,
object=object_name,
error=str(e),
)
return None
async def get_object_stream(self, bucket_name: str, object_name: str) -> Any:
"""Get object as stream"""
try:
response = self.client.get_object(bucket_name, object_name)
return response
except S3Error as e:
logger.error(
"Failed to get object stream",
bucket=bucket_name,
object=object_name,
error=str(e),
)
return None
async def object_exists(self, bucket_name: str, object_name: str) -> bool:
"""Check if object exists"""
try:
self.client.stat_object(bucket_name, object_name)
return True
except S3Error:
return False
async def delete_object(self, bucket_name: str, object_name: str) -> bool:
"""Delete object from bucket"""
try:
self.client.remove_object(bucket_name, object_name)
logger.info("Object deleted", bucket=bucket_name, object=object_name)
return True
except S3Error as e:
logger.error(
"Failed to delete object",
bucket=bucket_name,
object=object_name,
error=str(e),
)
return False
async def list_objects(
self, bucket_name: str, prefix: str | None = None, recursive: bool = True
) -> list[str]:
"""List objects in bucket"""
try:
objects = self.client.list_objects(
bucket_name, prefix=prefix, recursive=recursive
)
return [obj.object_name for obj in objects if obj.object_name is not None]
except S3Error as e:
logger.error(
"Failed to list objects",
bucket=bucket_name,
prefix=prefix,
error=str(e),
)
return []
async def get_presigned_url(
self,
bucket_name: str,
object_name: str,
expires: timedelta = timedelta(hours=1),
method: str = "GET",
) -> str | None:
"""Generate presigned URL for object access"""
try:
url = self.client.get_presigned_url(
method=method,
bucket_name=bucket_name,
object_name=object_name,
expires=expires,
)
logger.debug(
"Generated presigned URL",
bucket=bucket_name,
object=object_name,
method=method,
expires=expires,
)
return str(url)
except S3Error as e:
logger.error(
"Failed to generate presigned URL",
bucket=bucket_name,
object=object_name,
error=str(e),
)
return None
async def copy_object(
self, source_bucket: str, source_object: str, dest_bucket: str, dest_object: str
) -> bool:
"""Copy object between buckets/locations"""
try:
# pylint: disable=import-outside-toplevel
from minio.commonconfig import CopySource
# Ensure destination bucket exists
await self.ensure_bucket(dest_bucket)
# Copy object
self.client.copy_object(
bucket_name=dest_bucket,
object_name=dest_object,
source=CopySource(source_bucket, source_object),
)
logger.info(
"Object copied",
source_bucket=source_bucket,
source_object=source_object,
dest_bucket=dest_bucket,
dest_object=dest_object,
)
return True
except S3Error as e:
logger.error(
"Failed to copy object",
source_bucket=source_bucket,
source_object=source_object,
dest_bucket=dest_bucket,
dest_object=dest_object,
error=str(e),
)
return False
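

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not referenced by any service code). It assumes a
# StorageClient has already been constructed with reachable MinIO credentials;
# the bucket name, object key and payload below are placeholders. Note that the
# underlying MinIO SDK calls are synchronous, so these awaits do not yield to
# the event loop while the transfer runs.
# ---------------------------------------------------------------------------
async def _example_upload_and_share(
    client: StorageClient, payload: bytes
) -> str | None:
    """Upload a payload and return a short-lived presigned download URL."""
    from io import BytesIO  # local import keeps the sketch self-contained

    key = "examples/hello.bin"
    uploaded = await client.put_object(
        bucket_name="raw-documents",
        object_name=key,
        data=BytesIO(payload),
        length=len(payload),
        content_type="application/octet-stream",
        metadata={"origin": "example"},
    )
    if not uploaded:
        return None
    # Pass method="PUT" instead to hand out an upload URL.
    return await client.get_presigned_url(
        "raw-documents", key, expires=timedelta(minutes=15)
    )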

145
libs/storage/document.py Normal file

@@ -0,0 +1,145 @@
"""High-level document storage operations."""
import hashlib
import json
from io import BytesIO
from typing import Any
import structlog
from .client import StorageClient
logger = structlog.get_logger()
class DocumentStorage:
"""High-level document storage operations"""
def __init__(self, storage_client: StorageClient):
self.storage = storage_client
async def store_document( # pylint: disable=too-many-arguments,too-many-positional-arguments
self,
tenant_id: str,
doc_id: str,
content: bytes,
content_type: str = "application/pdf",
metadata: dict[str, str] | None = None,
bucket_name: str = "raw-documents",
) -> dict[str, Any]:
"""Store document with metadata"""
# Calculate checksum
checksum = hashlib.sha256(content).hexdigest()
# Prepare metadata
doc_metadata = {
"tenant_id": tenant_id,
"doc_id": doc_id,
"checksum": checksum,
"size": str(len(content)),
**(metadata or {}),
}
        # Build the object key: documents live under the tenant's raw/ prefix
        # with a fixed .pdf extension regardless of content_type
object_key = f"tenants/{tenant_id}/raw/{doc_id}.pdf"
# Upload to storage
success = await self.storage.put_object(
bucket_name=bucket_name,
object_name=object_key,
data=BytesIO(content),
length=len(content),
content_type=content_type,
metadata=doc_metadata,
)
if success:
return {
"bucket": bucket_name,
"key": object_key,
"checksum": checksum,
"size": len(content),
"s3_url": f"s3://{bucket_name}/{object_key}",
}
raise RuntimeError("Failed to store document")
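
    # Illustrative call (placeholder IDs and variables): the returned dict
    # records where the bytes landed plus a SHA-256 checksum that callers can
    # persist for later integrity checks, e.g.
    #   info = await docs.store_document("tenant-001", "doc-001", pdf_bytes)
    #   info["s3_url"]  # "s3://raw-documents/tenants/tenant-001/raw/doc-001.pdf"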
async def store_ocr_result(
self, tenant_id: str, doc_id: str, ocr_data: dict[str, Any]
) -> str:
"""Store OCR results as JSON"""
bucket_name = "evidence"
object_key = f"tenants/{tenant_id}/ocr/{doc_id}.json"
# Convert to JSON bytes
json_data = json.dumps(ocr_data, indent=2).encode("utf-8")
# Upload to storage
success = await self.storage.put_object(
bucket_name=bucket_name,
object_name=object_key,
data=BytesIO(json_data),
length=len(json_data),
content_type="application/json",
)
if success:
return f"s3://{bucket_name}/{object_key}"
raise RuntimeError("Failed to store OCR result")
async def store_extraction_result(
self, tenant_id: str, doc_id: str, extraction_data: dict[str, Any]
) -> str:
"""Store extraction results as JSON"""
bucket_name = "evidence"
object_key = f"tenants/{tenant_id}/extractions/{doc_id}.json"
# Convert to JSON bytes
json_data = json.dumps(extraction_data, indent=2).encode("utf-8")
# Upload to storage
success = await self.storage.put_object(
bucket_name=bucket_name,
object_name=object_key,
data=BytesIO(json_data),
length=len(json_data),
content_type="application/json",
)
if success:
return f"s3://{bucket_name}/{object_key}"
raise RuntimeError("Failed to store extraction result")
async def get_document(self, tenant_id: str, doc_id: str) -> bytes | None:
"""Retrieve document content"""
bucket_name = "raw-documents"
object_key = f"tenants/{tenant_id}/raw/{doc_id}.pdf"
return await self.storage.get_object(bucket_name, object_key)
async def get_ocr_result(
self, tenant_id: str, doc_id: str
) -> dict[str, Any] | None:
"""Retrieve OCR results"""
bucket_name = "evidence"
object_key = f"tenants/{tenant_id}/ocr/{doc_id}.json"
data = await self.storage.get_object(bucket_name, object_key)
if data:
return json.loads(data.decode("utf-8")) # type: ignore
return None
async def get_extraction_result(
self, tenant_id: str, doc_id: str
) -> dict[str, Any] | None:
"""Retrieve extraction results"""
bucket_name = "evidence"
object_key = f"tenants/{tenant_id}/extractions/{doc_id}.json"
data = await self.storage.get_object(bucket_name, object_key)
if data:
return json.loads(data.decode("utf-8")) # type: ignore
return None
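

# ---------------------------------------------------------------------------
# Illustrative round trip (not referenced by any service). It assumes `storage`
# wraps a StorageClient pointed at a reachable MinIO instance and that the
# "raw-documents" and "evidence" buckets can be created on demand; the tenant
# and document IDs are placeholders.
# ---------------------------------------------------------------------------
async def _example_document_round_trip(storage: DocumentStorage) -> None:
    pdf_bytes = b"%PDF-1.7 example payload"
    stored = await storage.store_document(
        tenant_id="tenant-001",
        doc_id="doc-001",
        content=pdf_bytes,
        metadata={"source": "example"},
    )
    # The checksum in the response is the SHA-256 of the uploaded bytes.
    assert stored["checksum"] == hashlib.sha256(pdf_bytes).hexdigest()
    await storage.store_ocr_result("tenant-001", "doc-001", {"pages": [], "text": ""})
    ocr = await storage.get_ocr_result("tenant-001", "doc-001")
    raw = await storage.get_document("tenant-001", "doc-001")
    logger.debug(
        "Round trip complete",
        ocr_loaded=ocr is not None,
        raw_size=len(raw or b""),
    )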