completed local setup with compose
Some checks failed
All CI/CD Pipeline jobs for this push were cancelled:
- Generate SBOM, Deploy to Staging, Deploy to Production, Code Quality & Linting, Policy Validation, Test Suite, Notifications
- Build Docker Images: svc-coverage, svc-extract, svc-firm-connectors, svc-forms, svc-hmrc, svc-ingestion, svc-kg, svc-normalize-map, svc-ocr, svc-rag-indexer, svc-rag-retriever, svc-reason, svc-rpa, ui-review
- Security Scanning: svc-coverage, svc-extract, svc-kg, svc-rag-retriever, ui-review

This commit is contained in:
harkon
2025-11-26 13:17:17 +00:00
parent 8fe5e62fee
commit fdba81809f
87 changed files with 5610 additions and 3376 deletions

View File

@@ -1,7 +1,6 @@
"""Configuration management and client factories."""
from .factories import (
EventBusFactory,
MinIOClientFactory,
Neo4jDriverFactory,
QdrantClientFactory,
@@ -28,7 +27,6 @@ __all__ = [
"QdrantClientFactory",
"Neo4jDriverFactory",
"RedisClientFactory",
"EventBusFactory",
"get_settings",
"init_settings",
"create_vault_client",

View File

@@ -2,10 +2,8 @@
from typing import Any
import boto3 # type: ignore
import hvac
import redis.asyncio as redis
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer # type: ignore
from minio import Minio
from neo4j import GraphDatabase
from qdrant_client import QdrantClient
@@ -87,36 +85,3 @@ class RedisClientFactory: # pylint: disable=too-few-public-methods
return redis.from_url(
settings.redis_url, encoding="utf-8", decode_responses=True
)
class EventBusFactory:
"""Factory for creating event bus clients"""
@staticmethod
def create_kafka_producer(settings: BaseAppSettings) -> AIOKafkaProducer:
"""Create Kafka producer"""
return AIOKafkaProducer(
bootstrap_servers=settings.kafka_bootstrap_servers,
value_serializer=lambda v: v.encode("utf-8") if isinstance(v, str) else v,
)
@staticmethod
def create_kafka_consumer(
settings: BaseAppSettings, topics: list[str]
) -> AIOKafkaConsumer:
"""Create Kafka consumer"""
return AIOKafkaConsumer(
*topics,
bootstrap_servers=settings.kafka_bootstrap_servers,
value_deserializer=lambda m: m.decode("utf-8") if m else None,
)
@staticmethod
def create_sqs_client(settings: BaseAppSettings) -> Any:
"""Create SQS client"""
return boto3.client("sqs", region_name=settings.aws_region)
@staticmethod
def create_sns_client(settings: BaseAppSettings) -> Any:
"""Create SNS client"""
return boto3.client("sns", region_name=settings.aws_region)

View File

@@ -8,7 +8,7 @@ class BaseAppSettings(BaseSettings):
"""Base settings class for all services"""
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"
env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="ignore"
)
# Service identification
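
The hunk above flips case_sensitive from True to False, so pydantic-settings now matches environment variable names regardless of case. A minimal sketch of that behaviour; the service_name field is illustrative, not taken from BaseAppSettings:

import os

from pydantic_settings import BaseSettings, SettingsConfigDict


class DemoSettings(BaseSettings):
    # Mirrors the model_config change above; the field name is hypothetical.
    model_config = SettingsConfigDict(case_sensitive=False, extra="ignore")

    service_name: str = "unset"


os.environ["SERVICE_NAME"] = "svc-ingestion"   # upper-case environment variable...
print(DemoSettings().service_name)             # ...still resolves the lower-case field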

View File

@@ -67,27 +67,20 @@ async def create_redis_client(settings: BaseAppSettings) -> "redis.Redis[str]":
def create_event_bus(settings: BaseAppSettings) -> EventBus:
"""Create event bus"""
if settings.event_bus_type.lower() == "kafka":
# pylint: disable=import-outside-toplevel
from ..events import KafkaEventBus
return KafkaEventBus(settings.kafka_bootstrap_servers)
if settings.event_bus_type.lower() == "sqs":
# pylint: disable=import-outside-toplevel
from ..events import SQSEventBus
return SQSEventBus(settings.aws_region)
if settings.event_bus_type.lower() == "memory":
# pylint: disable=import-outside-toplevel
from ..events import MemoryEventBus
return MemoryEventBus()
# Default to memory bus for unknown types
# pylint: disable=import-outside-toplevel
from ..events import MemoryEventBus
from libs.events import create_event_bus as _create_event_bus
return MemoryEventBus()
# Extract NATS servers as a list
nats_servers = [s.strip() for s in settings.nats_servers.split(",")]
return _create_event_bus(
settings.event_bus_type,
servers=nats_servers,
stream_name=settings.nats_stream_name,
consumer_group=settings.nats_consumer_group,
bootstrap_servers=settings.kafka_bootstrap_servers,
region_name=settings.aws_region,
)
def get_default_settings(**overrides: Any) -> BaseAppSettings:
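
The rewritten create_event_bus helper no longer branches per backend itself: it splits the comma-separated settings.nats_servers string and delegates everything to libs.events.create_event_bus. A hedged sketch of the equivalent direct call, using only the keyword arguments visible in this hunk (connection values are illustrative):

from libs.events import create_event_bus

# Same normalisation the wrapper applies to settings.nats_servers.
nats_servers = [s.strip() for s in "nats://nats-0:4222, nats://nats-1:4222".split(",")]

bus = create_event_bus(
    "nats",
    servers=nats_servers,
    stream_name="TAX_AGENT_EVENTS",
    consumer_group="tax-agent",
)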

View File

@@ -1,20 +1,52 @@
"""Event-driven architecture with Kafka, SQS, NATS, and Memory support."""
from libs.schemas.events import (
EVENT_SCHEMA_MAP,
BaseEventData,
CalculationReadyEventData,
DocumentExtractedEventData,
DocumentIngestedEventData,
DocumentOCRReadyEventData,
FirmSyncCompletedEventData,
FormFilledEventData,
HMRCSubmittedEventData,
KGUpsertedEventData,
KGUpsertReadyEventData,
RAGIndexedEventData,
ReviewCompletedEventData,
ReviewRequestedEventData,
get_schema_for_topic,
validate_event_data,
)
from .base import EventBus, EventPayload
from .factory import create_event_bus
from .kafka_bus import KafkaEventBus
from .memory_bus import MemoryEventBus
from .nats_bus import NATSEventBus
from .sqs_bus import SQSEventBus
from .topics import EventTopics
__all__ = [
"EventPayload",
"EventBus",
"KafkaEventBus",
"MemoryEventBus",
"NATSEventBus",
"SQSEventBus",
"create_event_bus",
"EventTopics",
# Event schemas
"BaseEventData",
"DocumentIngestedEventData",
"DocumentOCRReadyEventData",
"DocumentExtractedEventData",
"KGUpsertReadyEventData",
"KGUpsertedEventData",
"RAGIndexedEventData",
"CalculationReadyEventData",
"FormFilledEventData",
"HMRCSubmittedEventData",
"ReviewRequestedEventData",
"ReviewCompletedEventData",
"FirmSyncCompletedEventData",
"EVENT_SCHEMA_MAP",
"validate_event_data",
"get_schema_for_topic",
]
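
With the schema helpers re-exported here, callers can resolve a topic's payload class straight from the package root. A small sketch using only the constants and helpers listed above:

from libs.events import EventTopics, get_schema_for_topic

schema = get_schema_for_topic(EventTopics.KG_UPSERT_READY)
print(schema.__name__)          # -> "KGUpsertReadyEventData"

try:
    get_schema_for_topic("not.a.topic")
except ValueError as exc:       # unknown topics are rejected explicitly
    print(exc)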

View File

@@ -3,7 +3,7 @@
import json
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from datetime import datetime
from datetime import UTC, datetime
from typing import Any
import ulid
@@ -22,7 +22,7 @@ class EventPayload:
schema_version: str = "1.0",
):
self.event_id = str(ulid.new())
self.occurred_at = datetime.utcnow().isoformat() + "Z"
self.occurred_at = datetime.now(UTC).isoformat()
self.actor = actor
self.tenant_id = tenant_id
self.trace_id = trace_id
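
The timestamp change swaps the deprecated, naive datetime.utcnow() for a timezone-aware value, so occurred_at now carries a "+00:00" offset instead of a hand-appended "Z". Both forms denote UTC, as this small comparison shows:

from datetime import UTC, datetime

old_style = datetime.utcnow().isoformat() + "Z"   # naive value, suffix added manually (deprecated API)
new_style = datetime.now(UTC).isoformat()         # aware value, e.g. "...T13:17:17.123456+00:00"
print(old_style, new_style)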

View File

@@ -7,7 +7,7 @@ from collections.abc import Awaitable, Callable
import structlog
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer # type: ignore
from .base import EventBus, EventPayload
from ..base import EventBus, EventPayload
logger = structlog.get_logger()

View File

@@ -9,7 +9,7 @@ import boto3 # type: ignore
import structlog
from botocore.exceptions import ClientError # type: ignore
from .base import EventBus, EventPayload
from ..base import EventBus, EventPayload
logger = structlog.get_logger()

libs/events/dlq.py (new file, 271 lines)
View File

@@ -0,0 +1,271 @@
"""Dead Letter Queue (DLQ) handler for failed event processing."""
import asyncio
import json
from datetime import UTC, datetime
from typing import Any
import structlog
from nats.js import JetStreamContext
from .base import EventPayload
logger = structlog.get_logger()
class DLQHandler:
"""
Dead Letter Queue handler for processing failed events.
Captures events that fail processing after max retries and stores them
in a separate NATS stream for manual review and retry.
"""
def __init__(
self,
js: JetStreamContext,
dlq_stream_name: str = "TAX_AGENT_DLQ",
max_retries: int = 3,
backoff_base_ms: int = 1000,
backoff_multiplier: float = 2.0,
backoff_max_ms: int = 30000,
):
"""
Initialize DLQ handler.
Args:
js: NATS JetStream context
dlq_stream_name: Name of the DLQ stream
max_retries: Maximum number of retry attempts
backoff_base_ms: Base backoff time in milliseconds
backoff_multiplier: Exponential backoff multiplier
backoff_max_ms: Maximum backoff time in milliseconds
"""
self.js = js
self.dlq_stream_name = dlq_stream_name
self.max_retries = max_retries
self.backoff_base_ms = backoff_base_ms
self.backoff_multiplier = backoff_multiplier
self.backoff_max_ms = backoff_max_ms
async def ensure_dlq_stream_exists(self) -> None:
"""Ensure DLQ stream exists in JetStream."""
try:
# Try to get stream info
await self.js.stream_info(self.dlq_stream_name)
logger.debug("DLQ stream already exists", stream=self.dlq_stream_name)
except Exception:
# Stream doesn't exist, create it
try:
await self.js.add_stream(
name=self.dlq_stream_name,
subjects=[f"{self.dlq_stream_name}.>"],
# Keep DLQ messages for 30 days
max_age=30 * 24 * 60 * 60, # 30 days in seconds
)
logger.info("Created DLQ stream", stream=self.dlq_stream_name)
except Exception as e:
logger.error(
"Failed to create DLQ stream",
stream=self.dlq_stream_name,
error=str(e),
)
raise
async def send_to_dlq(
self,
topic: str,
payload: EventPayload,
error: Exception,
retry_count: int,
original_message_data: bytes | None = None,
) -> None:
"""
Send failed event to DLQ.
Args:
topic: Original topic name
payload: Event payload
error: Exception that caused the failure
retry_count: Number of retry attempts made
original_message_data: Original message data (optional, for debugging)
"""
try:
# Create DLQ subject
dlq_subject = f"{self.dlq_stream_name}.{topic}"
# Create DLQ payload with metadata
dlq_payload = {
"original_topic": topic,
"original_payload": payload.to_dict(),
"error": {
"type": type(error).__name__,
"message": str(error),
},
"retry_count": retry_count,
"failed_at": datetime.now(UTC).isoformat(),
"tenant_id": payload.tenant_id,
"event_id": payload.event_id,
"trace_id": payload.trace_id,
}
# Add original message data if available
if original_message_data:
try:
dlq_payload["original_message_data"] = original_message_data.decode(
"utf-8"
)
except UnicodeDecodeError:
dlq_payload["original_message_data"] = "<binary data>"
# Publish to DLQ
headers = {
"original_topic": topic,
"tenant_id": payload.tenant_id,
"event_id": payload.event_id,
"error_type": type(error).__name__,
"retry_count": str(retry_count),
}
await self.js.publish(
subject=dlq_subject,
payload=json.dumps(dlq_payload).encode(),
headers=headers,
)
logger.error(
"Event sent to DLQ",
topic=topic,
event_id=payload.event_id,
error=str(error),
retry_count=retry_count,
dlq_subject=dlq_subject,
)
except Exception as dlq_error:
logger.critical(
"Failed to send event to DLQ - EVENT LOST",
topic=topic,
event_id=payload.event_id,
original_error=str(error),
dlq_error=str(dlq_error),
)
def calculate_backoff(self, retry_count: int) -> float:
"""
Calculate exponential backoff delay.
Args:
retry_count: Current retry attempt (0-indexed)
Returns:
Backoff delay in seconds
"""
# Calculate exponential backoff: base * (multiplier ^ retry_count)
backoff_ms = self.backoff_base_ms * (self.backoff_multiplier**retry_count)
# Cap at maximum backoff
backoff_ms = min(backoff_ms, self.backoff_max_ms)
# Convert to seconds
return backoff_ms / 1000.0
async def retry_with_backoff(
self,
func: Any,
*args: Any,
**kwargs: Any,
) -> tuple[bool, Exception | None]:
"""
Retry a function with exponential backoff.
Args:
func: Async function to retry
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
Tuple of (success: bool, last_error: Exception | None)
"""
last_error: Exception | None = None
for attempt in range(self.max_retries + 1):
try:
await func(*args, **kwargs)
return (True, None)
except Exception as e: # pylint: disable=broad-exception-caught
last_error = e
if attempt < self.max_retries:
# Calculate and apply backoff
backoff_seconds = self.calculate_backoff(attempt)
logger.warning(
"Retry attempt failed, backing off",
attempt=attempt + 1,
max_retries=self.max_retries,
backoff_seconds=backoff_seconds,
error=str(e),
)
await asyncio.sleep(backoff_seconds)
else:
logger.error(
"All retry attempts exhausted",
attempts=self.max_retries + 1,
error=str(e),
)
return (False, last_error)
class DLQMetrics:
"""Metrics for DLQ operations."""
def __init__(self) -> None:
"""Initialize DLQ metrics."""
self.total_dlq_events = 0
self.dlq_events_by_topic: dict[str, int] = {}
self.dlq_events_by_error_type: dict[str, int] = {}
def record_dlq_event(self, topic: str, error_type: str) -> None:
"""
Record a DLQ event.
Args:
topic: Original topic name
error_type: Type of error that caused DLQ
"""
self.total_dlq_events += 1
# Track by topic
if topic not in self.dlq_events_by_topic:
self.dlq_events_by_topic[topic] = 0
self.dlq_events_by_topic[topic] += 1
# Track by error type
if error_type not in self.dlq_events_by_error_type:
self.dlq_events_by_error_type[error_type] = 0
self.dlq_events_by_error_type[error_type] += 1
def get_metrics(self) -> dict[str, Any]:
"""
Get DLQ metrics.
Returns:
Dictionary of metrics
"""
return {
"total_dlq_events": self.total_dlq_events,
"by_topic": self.dlq_events_by_topic.copy(),
"by_error_type": self.dlq_events_by_error_type.copy(),
}
def reset(self) -> None:
"""Reset all metrics to zero."""
self.total_dlq_events = 0
self.dlq_events_by_topic.clear()
self.dlq_events_by_error_type.clear()
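
A rough usage sketch for the new handler, assuming a local NATS server; the topic, the payload values and the deliberately failing handler are illustrative, and the EventPayload keyword arguments follow the constructor shape shown in the base.py hunk:

import asyncio

import nats

from libs.events import EventPayload
from libs.events.dlq import DLQHandler


async def flaky_handler(topic: str, payload: EventPayload) -> None:
    raise RuntimeError("downstream unavailable")   # always fails, to exercise the DLQ path


async def main() -> None:
    nc = await nats.connect(servers=["nats://localhost:4222"])
    dlq = DLQHandler(js=nc.jetstream(), dlq_stream_name="TAX_AGENT_DLQ", max_retries=3)
    await dlq.ensure_dlq_stream_exists()

    payload = EventPayload(
        data={"doc_id": "01HEXAMPLE"},
        actor="svc-ingestion",
        tenant_id="tenant-123",
        trace_id="trace-abc",
    )

    # retry_with_backoff returns (success, last_error); on exhaustion, park the event in the DLQ.
    ok, last_error = await dlq.retry_with_backoff(flaky_handler, "doc.ingested", payload)
    if not ok and last_error is not None:
        await dlq.send_to_dlq(
            topic="doc.ingested",
            payload=payload,
            error=last_error,
            retry_count=dlq.max_retries + 1,
        )
    await nc.close()


asyncio.run(main())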

View File

@@ -3,16 +3,20 @@
from typing import Any
from .base import EventBus
from .kafka_bus import KafkaEventBus
from .nats_bus import NATSEventBus
from .sqs_bus import SQSEventBus
def create_event_bus(bus_type: str, **kwargs: Any) -> EventBus:
"""Factory function to create event bus"""
if bus_type.lower() == "kafka":
# Lazy import to avoid ModuleNotFoundError when aiokafka is not installed
from .contrib.kafka_bus import KafkaEventBus
return KafkaEventBus(kwargs.get("bootstrap_servers", "localhost:9092"))
if bus_type.lower() == "sqs":
# Lazy import to avoid ModuleNotFoundError when boto3 is not installed
from .contrib.sqs_bus import SQSEventBus
return SQSEventBus(kwargs.get("region_name", "us-east-1"))
if bus_type.lower() == "nats":
return NATSEventBus(

libs/events/metrics.py (new file, 225 lines)
View File

@@ -0,0 +1,225 @@
"""Prometheus metrics for event bus monitoring."""
from prometheus_client import Counter, Histogram
from prometheus_client.registry import CollectorRegistry
# Global registry for event metrics
_event_registry = CollectorRegistry()
# Event publishing metrics
event_published_total = Counter(
"event_published_total",
"Total number of events published",
["topic"],
registry=_event_registry,
)
event_publish_errors_total = Counter(
"event_publish_errors_total",
"Total number of event publishing errors",
["topic", "error_type"],
registry=_event_registry,
)
event_publishing_duration_seconds = Histogram(
"event_publishing_duration_seconds",
"Time spent publishing events in seconds",
["topic"],
buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0),
registry=_event_registry,
)
# Event consumption metrics
event_consumed_total = Counter(
"event_consumed_total",
"Total number of events consumed",
["topic", "consumer_group"],
registry=_event_registry,
)
event_processing_duration_seconds = Histogram(
"event_processing_duration_seconds",
"Time spent processing events in seconds",
["topic", "consumer_group"],
buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0),
registry=_event_registry,
)
event_processing_errors_total = Counter(
"event_processing_errors_total",
"Total number of event processing errors",
["topic", "consumer_group", "error_type"],
registry=_event_registry,
)
# DLQ metrics
event_dlq_total = Counter(
"event_dlq_total",
"Total number of events sent to dead letter queue",
["topic", "error_type"],
registry=_event_registry,
)
event_retry_total = Counter(
"event_retry_total",
"Total number of event retry attempts",
["topic", "retry_attempt"],
registry=_event_registry,
)
# Schema validation metrics
event_schema_validation_errors_total = Counter(
"event_schema_validation_errors_total",
"Total number of event schema validation errors",
["topic", "validation_error"],
registry=_event_registry,
)
# NATS JetStream specific metrics
nats_stream_messages_total = Counter(
"nats_stream_messages_total",
"Total messages in NATS stream",
["stream_name"],
registry=_event_registry,
)
nats_consumer_lag_messages = Histogram(
"nats_consumer_lag_messages",
"Number of messages consumer is lagging behind",
["stream_name", "consumer_group"],
buckets=(0, 1, 5, 10, 25, 50, 100, 250, 500, 1000, 5000, 10000),
registry=_event_registry,
)
def get_event_metrics_registry() -> CollectorRegistry:
"""
Get the Prometheus registry for event metrics.
Returns:
CollectorRegistry for event metrics
"""
return _event_registry
class EventMetricsCollector:
"""Helper class for collecting event metrics."""
@staticmethod
def record_publish(
topic: str,
duration_seconds: float,
success: bool = True,
error_type: str | None = None,
) -> None:
"""
Record event publishing metrics.
Args:
topic: Event topic name
duration_seconds: Time taken to publish
success: Whether publishing succeeded
error_type: Type of error if failed
"""
if success:
event_published_total.labels(topic=topic).inc()
else:
event_publish_errors_total.labels(
topic=topic, error_type=error_type or "unknown"
).inc()
event_publishing_duration_seconds.labels(topic=topic).observe(duration_seconds)
@staticmethod
def record_consume(
topic: str,
consumer_group: str,
duration_seconds: float,
success: bool = True,
error_type: str | None = None,
) -> None:
"""
Record event consumption metrics.
Args:
topic: Event topic name
consumer_group: Consumer group name
duration_seconds: Time taken to process event
success: Whether processing succeeded
error_type: Type of error if failed
"""
if success:
event_consumed_total.labels(
topic=topic, consumer_group=consumer_group
).inc()
else:
event_processing_errors_total.labels(
topic=topic,
consumer_group=consumer_group,
error_type=error_type or "unknown",
).inc()
event_processing_duration_seconds.labels(
topic=topic, consumer_group=consumer_group
).observe(duration_seconds)
@staticmethod
def record_dlq(topic: str, error_type: str) -> None:
"""
Record event sent to DLQ.
Args:
topic: Event topic name
error_type: Type of error that caused DLQ
"""
event_dlq_total.labels(topic=topic, error_type=error_type).inc()
@staticmethod
def record_retry(topic: str, retry_attempt: int) -> None:
"""
Record event retry attempt.
Args:
topic: Event topic name
retry_attempt: Retry attempt number (1-indexed)
"""
event_retry_total.labels(topic=topic, retry_attempt=str(retry_attempt)).inc()
@staticmethod
def record_schema_validation_error(topic: str, validation_error: str) -> None:
"""
Record schema validation error.
Args:
topic: Event topic name
validation_error: Type of validation error
"""
event_schema_validation_errors_total.labels(
topic=topic, validation_error=validation_error
).inc()
@staticmethod
def record_nats_stream_message(stream_name: str) -> None:
"""
Record message added to NATS stream.
Args:
stream_name: NATS stream name
"""
nats_stream_messages_total.labels(stream_name=stream_name).inc()
@staticmethod
def record_consumer_lag(
stream_name: str, consumer_group: str, lag_messages: int
) -> None:
"""
Record consumer lag.
Args:
stream_name: NATS stream name
consumer_group: Consumer group name
lag_messages: Number of messages consumer is behind
"""
nats_consumer_lag_messages.labels(
stream_name=stream_name, consumer_group=consumer_group
).observe(lag_messages)
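
The collector is a thin static wrapper over the counters and histograms above, and the metrics live in their own CollectorRegistry rather than the default one, so a scrape endpoint has to pass that registry explicitly. A hedged sketch; the actual publish call is elided:

import time

from prometheus_client import generate_latest

from libs.events.metrics import EventMetricsCollector, get_event_metrics_registry

start = time.perf_counter()
try:
    ...  # await bus.publish(topic, payload) would go here
    EventMetricsCollector.record_publish("doc.ingested", time.perf_counter() - start, success=True)
except Exception as exc:   # broad on purpose, mirroring the bus's own error handling
    EventMetricsCollector.record_publish(
        "doc.ingested", time.perf_counter() - start, success=False, error_type=type(exc).__name__
    )
    raise

# Expose the dedicated registry, e.g. from a /metrics handler.
print(generate_latest(get_event_metrics_registry()).decode())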

View File

@@ -2,6 +2,7 @@
import asyncio
import json
import time
from collections.abc import Awaitable, Callable
from typing import Any
@@ -12,6 +13,8 @@ from nats.js import JetStreamContext
from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy
from .base import EventBus, EventPayload
from .dlq import DLQHandler
from .metrics import EventMetricsCollector
logger = structlog.get_logger()
@@ -24,6 +27,8 @@ class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
servers: str | list[str] = "nats://localhost:4222",
stream_name: str = "TAX_AGENT_EVENTS",
consumer_group: str = "tax-agent",
dlq_stream_name: str = "TAX_AGENT_DLQ",
max_retries: int = 3,
):
if isinstance(servers, str):
self.servers = [servers]
@@ -32,8 +37,13 @@ class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
self.stream_name = stream_name
self.consumer_group = consumer_group
self.dlq_stream_name = dlq_stream_name
self.max_retries = max_retries
self.nc: NATS | None = None
self.js: JetStreamContext | None = None
self.dlq: DLQHandler | None = None
self.handlers: dict[
str, list[Callable[[str, EventPayload], Awaitable[None]]]
] = {}
@@ -48,19 +58,32 @@ class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
try:
# Connect to NATS
self.nc = await nats.connect(servers=self.servers)
self.nc = await nats.connect(
servers=self.servers,
connect_timeout=10,
reconnect_time_wait=1,
)
# Get JetStream context
self.js = self.nc.jetstream()
self.js = self.nc.jetstream(timeout=10)
# Ensure stream exists
# Initialize DLQ handler
self.dlq = DLQHandler(
js=self.js,
dlq_stream_name=self.dlq_stream_name,
max_retries=self.max_retries,
)
# Ensure streams exist
await self._ensure_stream_exists()
await self.dlq.ensure_dlq_stream_exists()
self.running = True
logger.info(
"NATS event bus started",
servers=self.servers,
stream=self.stream_name,
dlq_stream=self.dlq_stream_name,
)
except Exception as e:
@@ -98,6 +121,7 @@ class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
if not self.js:
raise RuntimeError("Event bus not started")
start_time = time.perf_counter()
try:
# Create subject name from topic
subject = f"{self.stream_name}.{topic}"
@@ -117,6 +141,13 @@ class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
headers=headers,
)
duration = time.perf_counter() - start_time
EventMetricsCollector.record_publish(
topic=topic,
duration_seconds=duration,
success=True,
)
logger.info(
"Event published",
topic=topic,
@@ -127,6 +158,14 @@ class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
return True
except Exception as e: # pylint: disable=broad-exception-caught
duration = time.perf_counter() - start_time
EventMetricsCollector.record_publish(
topic=topic,
duration_seconds=duration,
success=False,
error_type=type(e).__name__,
)
logger.error(
"Failed to publish event",
topic=topic,
@@ -152,9 +191,13 @@ class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
subject = f"{self.stream_name}.{topic}"
# Create durable consumer
consumer_name = f"{self.consumer_group}-{topic}"
# Durable names cannot contain dots, so we replace them
safe_topic = topic.replace(".", "-")
consumer_name = f"{self.consumer_group}-{safe_topic}"
# Subscribe with pull-based consumer
# Set max_deliver to max_retries + 1 (initial + retries)
# We handle DLQ manually before NATS gives up
subscription = await self.js.pull_subscribe(
subject=subject,
durable=consumer_name,
@@ -162,7 +205,7 @@ class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
durable_name=consumer_name,
ack_policy=AckPolicy.EXPLICIT,
deliver_policy=DeliverPolicy.NEW,
max_deliver=3,
max_deliver=self.max_retries + 2, # Give us room to handle DLQ
ack_wait=30, # 30 seconds
),
)
@@ -193,13 +236,14 @@ class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
# Try to get stream info
await self.js.stream_info(self.stream_name)
logger.debug("Stream already exists", stream=self.stream_name)
EventMetricsCollector.record_nats_stream_message(self.stream_name)
except Exception:
# Stream doesn't exist, create it
try:
await self.js.add_stream(
name=self.stream_name,
subjects=[f"{self.stream_name}.*"],
subjects=[f"{self.stream_name}.>"],
)
logger.info("Created JetStream stream", stream=self.stream_name)
@@ -214,12 +258,17 @@ class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
while self.running:
try:
# Fetch messages in batches
messages = await subscription.fetch(batch=10, timeout=20)
messages = await subscription.fetch(batch=10, timeout=5)
for message in messages:
start_time = time.perf_counter()
payload = None
try:
print(f"DEBUG: Received message: {message.data}")
# Parse message payload
payload_dict = json.loads(message.data.decode())
print(f"DEBUG: Parsed payload: {payload_dict}")
payload = EventPayload(
data=payload_dict["data"],
@@ -230,38 +279,87 @@ class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
)
payload.event_id = payload_dict["event_id"]
payload.occurred_at = payload_dict["occurred_at"]
print(f"DEBUG: Reconstructed payload: {payload.event_id}")
# Call all handlers for this topic
for handler in self.handlers.get(topic, []):
try:
await handler(topic, payload)
except (
Exception
) as e: # pylint: disable=broad-exception-caught
logger.error(
"Handler failed",
topic=topic,
event_id=payload.event_id,
error=str(e),
)
print(f"DEBUG: Calling handler for topic {topic}")
await handler(topic, payload)
# Acknowledge message
await message.ack()
print("DEBUG: Message acked")
except json.JSONDecodeError as e:
logger.error(
"Failed to decode message", topic=topic, error=str(e)
# Record metrics
duration = time.perf_counter() - start_time
EventMetricsCollector.record_consume(
topic=topic,
consumer_group=self.consumer_group,
duration_seconds=duration,
success=True,
)
await message.nak()
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Failed to process message", topic=topic, error=str(e)
duration = time.perf_counter() - start_time
error_type = type(e).__name__
# Record failure metric
EventMetricsCollector.record_consume(
topic=topic,
consumer_group=self.consumer_group,
duration_seconds=duration,
success=False,
error_type=error_type,
)
await message.nak()
# Check delivery count for DLQ
try:
metadata = message.metadata
num_delivered = (
metadata.sequence.consumer
) # This might be wrong, check docs
# Actually nats-py MsgMetadata has num_delivered
num_delivered = metadata.num_delivered
except Exception:
num_delivered = 1
if num_delivered >= self.max_retries:
logger.error(
"Max retries exceeded, sending to DLQ",
topic=topic,
event_id=payload.event_id if payload else "unknown",
error=str(e),
num_delivered=num_delivered,
)
if self.dlq and payload:
await self.dlq.send_to_dlq(
topic=topic,
payload=payload,
error=e,
retry_count=num_delivered,
original_message_data=message.data,
)
EventMetricsCollector.record_dlq(topic, error_type)
# Ack to remove from main stream
await message.ack()
else:
# Retry (Nak)
logger.warning(
"Processing failed, retrying",
topic=topic,
event_id=payload.event_id if payload else "unknown",
error=str(e),
attempt=num_delivered,
)
EventMetricsCollector.record_retry(topic, num_delivered)
await message.nak()
except TimeoutError:
# No messages available, continue polling
continue
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Consumer error", topic=topic, error=str(e))
await asyncio.sleep(5) # Wait before retrying
await asyncio.sleep(1) # Wait before retrying
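
A rough end-to-end wiring sketch for the DLQ-aware bus. The constructor arguments come from the hunk above; the start/subscribe/publish/stop method names and signatures are assumed from the stored handler type and the log messages, not confirmed by this diff:

import asyncio

from libs.events import EventPayload, EventTopics, NATSEventBus


async def on_doc_ingested(topic: str, payload: EventPayload) -> None:
    # Matches the stored handler shape: Callable[[str, EventPayload], Awaitable[None]]
    print(topic, payload.event_id)


async def main() -> None:
    bus = NATSEventBus(
        servers="nats://localhost:4222",
        stream_name="TAX_AGENT_EVENTS",
        consumer_group="tax-agent",
        dlq_stream_name="TAX_AGENT_DLQ",
        max_retries=3,
    )
    await bus.start()                                              # connects and creates both streams
    await bus.subscribe(EventTopics.DOC_INGESTED, on_doc_ingested)

    payload = EventPayload(
        data={"doc_id": "01HEXAMPLE"},
        actor="svc-ingestion",
        tenant_id="tenant-123",
        trace_id="trace-abc",
    )
    await bus.publish(EventTopics.DOC_INGESTED, payload)

    await asyncio.sleep(6)    # leave time for at least one pull-consumer fetch cycle (timeout=5)
    await bus.stop()


asyncio.run(main())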

View File

@@ -7,6 +7,7 @@ class EventTopics: # pylint: disable=too-few-public-methods
DOC_INGESTED = "doc.ingested"
DOC_OCR_READY = "doc.ocr_ready"
DOC_EXTRACTED = "doc.extracted"
KG_UPSERT_READY = "kg.upsert.ready"
KG_UPSERTED = "kg.upserted"
RAG_INDEXED = "rag.indexed"
CALC_SCHEDULE_READY = "calc.schedule_ready"
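
Subjects are built as f"{stream_name}.{topic}", and every topic here already contains at least one dot, so each subject ends in two or more NATS tokens. A "*" wildcard matches exactly one token while ">" matches one or more, which is why the stream's subject filter in the nats_bus hunk was widened from "TAX_AGENT_EVENTS.*" to "TAX_AGENT_EVENTS.>":

STREAM = "TAX_AGENT_EVENTS"
for topic in ("doc.ingested", "kg.upsert.ready"):
    print(f"{STREAM}.{topic}")
# TAX_AGENT_EVENTS.doc.ingested and TAX_AGENT_EVENTS.kg.upsert.ready are both
# matched by "TAX_AGENT_EVENTS.>" but not by the old single-token "TAX_AGENT_EVENTS.*".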

View File

@@ -11,8 +11,8 @@ psycopg2-binary>=2.9.11
neo4j>=6.0.2
redis[hiredis]>=6.4.0
# Object storage and vector database
minio>=7.2.18
boto3>=1.34.0
qdrant-client>=1.15.1
# Event streaming (NATS only - removed Kafka)
@@ -36,3 +36,13 @@ python-multipart>=0.0.20
python-dateutil>=2.9.0
python-dotenv>=1.1.1
orjson>=3.11.3
jsonschema>=4.20.0
# OpenTelemetry instrumentation (for observability)
opentelemetry-api>=1.21.0
opentelemetry-sdk>=1.21.0
opentelemetry-exporter-otlp-proto-grpc>=1.21.0
opentelemetry-instrumentation-fastapi>=0.42b0
opentelemetry-instrumentation-httpx>=0.42b0
opentelemetry-instrumentation-psycopg2>=0.42b0
opentelemetry-instrumentation-redis>=0.42b0
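
The newly pinned OpenTelemetry packages cover the SDK, an OTLP/gRPC exporter and auto-instrumentation for FastAPI, httpx, psycopg2 and redis. A minimal bootstrap sketch for one service; the collector endpoint and service name are placeholders:

from fastapi import FastAPI
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

provider = TracerProvider(resource=Resource.create({"service.name": "svc-ingestion"}))
provider.add_span_processor(
    BatchSpanProcessor(OTLPSpanExporter(endpoint="http://otel-collector:4317", insecure=True))
)
trace.set_tracer_provider(provider)

app = FastAPI()
FastAPIInstrumentor.instrument_app(app)   # request spans are exported via OTLP/gRPC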

View File

@@ -65,6 +65,26 @@ from .enums import (
# Import error models
from .errors import ErrorResponse, ValidationError, ValidationErrorResponse
# Import event schemas
from .events import (
EVENT_SCHEMA_MAP,
BaseEventData,
CalculationReadyEventData,
DocumentExtractedEventData,
DocumentIngestedEventData,
DocumentOCRReadyEventData,
FirmSyncCompletedEventData,
FormFilledEventData,
HMRCSubmittedEventData,
KGUpsertedEventData,
KGUpsertReadyEventData,
RAGIndexedEventData,
ReviewCompletedEventData,
ReviewRequestedEventData,
get_schema_for_topic,
validate_event_data,
)
# Import health models
from .health import HealthCheck, ServiceHealth
@@ -135,7 +155,7 @@ __all__ = [
"DocumentUploadResponse",
"ExtractionResponse",
"FirmSyncResponse",
"HMRCSubmissionResponse",
"HMRCSubmittedEventData",
"RAGSearchResponse",
"ScheduleComputeResponse",
# Utils
@@ -172,4 +192,21 @@ __all__ = [
"ValidationResult",
"PolicyVersion",
"CoverageAudit",
# Event schemas
"BaseEventData",
"DocumentIngestedEventData",
"DocumentOCRReadyEventData",
"DocumentExtractedEventData",
"KGUpsertReadyEventData",
"KGUpsertedEventData",
"RAGIndexedEventData",
"CalculationReadyEventData",
"FormFilledEventData",
"HMRCSubmittedEventData",
"ReviewRequestedEventData",
"ReviewCompletedEventData",
"FirmSyncCompletedEventData",
"EVENT_SCHEMA_MAP",
"validate_event_data",
"get_schema_for_topic",
]

libs/schemas/events.py (new file, 309 lines)
View File

@@ -0,0 +1,309 @@
"""Typed event payload schemas for validation and type safety."""
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, field_validator
# Base schema for all events
class BaseEventData(BaseModel):
"""Base class for all event data payloads."""
model_config = ConfigDict(
extra="forbid", # Prevent unexpected fields
frozen=True, # Make immutable
)
# Document lifecycle events
class DocumentIngestedEventData(BaseEventData):
"""Event emitted when a document is successfully ingested."""
doc_id: str = Field(..., description="Unique document identifier (ULID)")
filename: str = Field(..., description="Original filename")
mime_type: str = Field(..., description="MIME type of the document")
size_bytes: int = Field(..., ge=0, description="File size in bytes")
checksum_sha256: str = Field(..., description="SHA-256 checksum for integrity")
kind: str = Field(
..., description="Document kind (invoice, receipt, bank_statement, etc.)"
)
source: str = Field(
..., description="Ingestion source (manual_upload, rpa, email, api)"
)
storage_path: str = Field(..., description="MinIO object storage path")
metadata: dict[str, Any] = Field(
default_factory=dict, description="Additional metadata"
)
@field_validator("checksum_sha256")
@classmethod
def validate_checksum(cls, v: str) -> str:
"""Validate SHA-256 checksum format."""
if len(v) != 64 or not all(c in "0123456789abcdef" for c in v.lower()):
raise ValueError("Invalid SHA-256 checksum format")
return v.lower()
class DocumentOCRReadyEventData(BaseEventData):
"""Event emitted when OCR processing is complete."""
doc_id: str = Field(..., description="Document identifier")
ocr_engine: Literal["tesseract", "textract", "azure_ocr"] = Field(
..., description="OCR engine used"
)
page_count: int = Field(..., ge=1, description="Number of pages processed")
confidence_avg: float = Field(
..., ge=0.0, le=1.0, description="Average OCR confidence score"
)
text_length: int = Field(..., ge=0, description="Total extracted text length")
layout_detected: bool = Field(
..., description="Whether document layout was successfully detected"
)
languages_detected: list[str] = Field(
default_factory=list, description="Detected languages (ISO 639-1 codes)"
)
processing_time_ms: int = Field(
..., ge=0, description="Processing time in milliseconds"
)
storage_path: str = Field(..., description="Path to OCR results in storage")
class DocumentExtractedEventData(BaseEventData):
"""Event emitted when field extraction is complete."""
doc_id: str = Field(..., description="Document identifier")
extraction_id: str = Field(..., description="Unique extraction run identifier")
strategy: Literal["llm", "rules", "hybrid"] = Field(
..., description="Extraction strategy used"
)
fields_extracted: int = Field(..., ge=0, description="Number of fields extracted")
confidence_avg: float = Field(
..., ge=0.0, le=1.0, description="Average extraction confidence"
)
calibrated_confidence: float = Field(
..., ge=0.0, le=1.0, description="Calibrated confidence score"
)
model_name: str | None = Field(None, description="LLM model used (if applicable)")
processing_time_ms: int = Field(
..., ge=0, description="Processing time in milliseconds"
)
storage_path: str = Field(..., description="Path to extraction results")
# Knowledge Graph events
class KGUpsertReadyEventData(BaseEventData):
"""Event emitted when KG upsert data is ready."""
doc_id: str = Field(..., description="Source document identifier")
entity_count: int = Field(..., ge=0, description="Number of entities to upsert")
relationship_count: int = Field(
..., ge=0, description="Number of relationships to upsert"
)
tax_year: str = Field(..., description="Tax year (e.g., '2024-25')")
taxpayer_id: str = Field(..., description="Taxpayer identifier")
normalization_id: str = Field(..., description="Normalization run identifier")
storage_path: str = Field(..., description="Path to normalized data")
class KGUpsertedEventData(BaseEventData):
"""Event emitted when KG upsert is complete."""
doc_id: str = Field(..., description="Source document identifier")
entities_created: int = Field(..., ge=0, description="Entities created")
entities_updated: int = Field(..., ge=0, description="Entities updated")
relationships_created: int = Field(..., ge=0, description="Relationships created")
relationships_updated: int = Field(..., ge=0, description="Relationships updated")
shacl_violations: int = Field(
..., ge=0, description="Number of SHACL validation violations"
)
processing_time_ms: int = Field(
..., ge=0, description="Processing time in milliseconds"
)
success: bool = Field(..., description="Whether upsert was successful")
error_message: str | None = Field(None, description="Error message if failed")
# RAG events
class RAGIndexedEventData(BaseEventData):
"""Event emitted when RAG indexing is complete."""
doc_id: str = Field(..., description="Source document identifier")
collection_name: str = Field(..., description="Qdrant collection name")
chunks_indexed: int = Field(..., ge=0, description="Number of chunks indexed")
embedding_model: str = Field(..., description="Embedding model used")
pii_detected: bool = Field(..., description="Whether PII was detected")
pii_redacted: bool = Field(..., description="Whether PII was redacted")
processing_time_ms: int = Field(
..., ge=0, description="Processing time in milliseconds"
)
storage_path: str = Field(..., description="Path to chunked data")
# Calculation events
class CalculationReadyEventData(BaseEventData):
"""Event emitted when tax calculation is complete."""
taxpayer_id: str = Field(..., description="Taxpayer identifier")
tax_year: str = Field(..., description="Tax year (e.g., '2024-25')")
schedule_id: str = Field(..., description="Tax schedule identifier (SA102, SA103)")
calculation_id: str = Field(..., description="Unique calculation run identifier")
boxes_computed: int = Field(..., ge=0, description="Number of form boxes computed")
total_income: float | None = Field(None, description="Total income calculated")
total_tax: float | None = Field(None, description="Total tax calculated")
confidence: float = Field(
..., ge=0.0, le=1.0, description="Calculation confidence score"
)
evidence_count: int = Field(
..., ge=0, description="Number of evidence items supporting calculation"
)
processing_time_ms: int = Field(
..., ge=0, description="Processing time in milliseconds"
)
storage_path: str = Field(..., description="Path to calculation results")
# Form events
class FormFilledEventData(BaseEventData):
"""Event emitted when PDF form filling is complete."""
taxpayer_id: str = Field(..., description="Taxpayer identifier")
tax_year: str = Field(..., description="Tax year (e.g., '2024-25')")
form_id: str = Field(..., description="Form identifier (SA100, SA102, etc.)")
fields_filled: int = Field(..., ge=0, description="Number of fields filled")
pdf_size_bytes: int = Field(..., ge=0, description="Generated PDF size in bytes")
storage_path: str = Field(..., description="Path to filled PDF")
evidence_bundle_path: str | None = Field(
None, description="Path to evidence bundle ZIP"
)
checksum_sha256: str = Field(..., description="PDF checksum for integrity")
# HMRC events
class HMRCSubmittedEventData(BaseEventData):
"""Event emitted when HMRC submission is complete."""
taxpayer_id: str = Field(..., description="Taxpayer identifier")
tax_year: str = Field(..., description="Tax year (e.g., '2024-25')")
submission_id: str = Field(..., description="Unique submission identifier")
hmrc_reference: str | None = Field(None, description="HMRC submission reference")
submission_type: Literal["dry_run", "sandbox", "live"] = Field(
..., description="Submission environment type"
)
success: bool = Field(..., description="Whether submission was successful")
status_code: int | None = Field(None, description="HTTP status code")
error_message: str | None = Field(None, description="Error message if failed")
processing_time_ms: int = Field(
..., ge=0, description="Processing time in milliseconds"
)
# Review events
class ReviewRequestedEventData(BaseEventData):
"""Event emitted when human review is requested."""
doc_id: str = Field(..., description="Document identifier")
review_type: Literal["extraction", "calculation", "submission"] = Field(
..., description="Type of review needed"
)
priority: Literal["low", "medium", "high", "urgent"] = Field(
..., description="Review priority level"
)
reason: str = Field(..., description="Reason for review request")
assigned_to: str | None = Field(None, description="User assigned to review")
due_date: str | None = Field(None, description="Review due date (ISO 8601)")
metadata: dict[str, Any] = Field(
default_factory=dict, description="Additional review metadata"
)
class ReviewCompletedEventData(BaseEventData):
"""Event emitted when human review is completed."""
doc_id: str = Field(..., description="Document identifier")
review_id: str = Field(..., description="Review session identifier")
reviewer: str = Field(..., description="User who completed review")
decision: Literal["approved", "rejected", "needs_revision"] = Field(
..., description="Review decision"
)
changes_made: int = Field(..., ge=0, description="Number of changes made")
comments: str | None = Field(None, description="Reviewer comments")
review_duration_seconds: int = Field(
..., ge=0, description="Time spent in review (seconds)"
)
# Firm sync events
class FirmSyncCompletedEventData(BaseEventData):
"""Event emitted when firm database sync is complete."""
firm_id: str = Field(..., description="Firm identifier")
connector_type: str = Field(
..., description="Connector type (iris, sage, xero, etc.)"
)
sync_id: str = Field(..., description="Unique sync run identifier")
records_synced: int = Field(..., ge=0, description="Number of records synced")
records_created: int = Field(..., ge=0, description="Records created")
records_updated: int = Field(..., ge=0, description="Records updated")
records_failed: int = Field(..., ge=0, description="Records that failed to sync")
success: bool = Field(..., description="Whether sync was successful")
error_message: str | None = Field(None, description="Error message if failed")
processing_time_ms: int = Field(
..., ge=0, description="Processing time in milliseconds"
)
# Schema mapping for topic -> data class
EVENT_SCHEMA_MAP: dict[str, type[BaseEventData]] = {
"doc.ingested": DocumentIngestedEventData,
"doc.ocr_ready": DocumentOCRReadyEventData,
"doc.extracted": DocumentExtractedEventData,
"kg.upsert.ready": KGUpsertReadyEventData,
"kg.upserted": KGUpsertedEventData,
"rag.indexed": RAGIndexedEventData,
"calc.schedule_ready": CalculationReadyEventData,
"form.filled": FormFilledEventData,
"hmrc.submitted": HMRCSubmittedEventData,
"review.requested": ReviewRequestedEventData,
"review.completed": ReviewCompletedEventData,
"firm.sync.completed": FirmSyncCompletedEventData,
}
def validate_event_data(topic: str, data: dict[str, Any]) -> BaseEventData:
"""
Validate event data against the schema for the given topic.
Args:
topic: Event topic name
data: Raw event data dictionary
Returns:
Validated event data model
Raises:
ValueError: If topic is unknown or validation fails
"""
if topic not in EVENT_SCHEMA_MAP:
raise ValueError(f"Unknown event topic: {topic}")
schema_class = EVENT_SCHEMA_MAP[topic]
return schema_class.model_validate(data)
def get_schema_for_topic(topic: str) -> type[BaseEventData]:
"""
Get the Pydantic schema class for a given topic.
Args:
topic: Event topic name
Returns:
Schema class for the topic
Raises:
ValueError: If topic is unknown
"""
if topic not in EVENT_SCHEMA_MAP:
raise ValueError(f"Unknown event topic: {topic}")
return EVENT_SCHEMA_MAP[topic]
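
A short validation sketch against the doc.ingested schema; all field values are illustrative. Because the models use extra="forbid" and a checksum validator, malformed payloads fail at the boundary rather than deep inside a consumer:

from libs.schemas.events import validate_event_data

event = validate_event_data(
    "doc.ingested",
    {
        "doc_id": "01HEXAMPLEULID000000000000",
        "filename": "invoice-042.pdf",
        "mime_type": "application/pdf",
        "size_bytes": 48213,
        "checksum_sha256": "a" * 64,      # must be 64 hex characters
        "kind": "invoice",
        "source": "manual_upload",
        "storage_path": "raw/tenant-123/invoice-042.pdf",
    },
)
print(type(event).__name__)               # -> "DocumentIngestedEventData"

try:
    validate_event_data("doc.ingested", {"doc_id": "x", "unexpected": True})
except Exception as exc:                  # pydantic ValidationError: missing fields plus an extra field
    print(type(exc).__name__)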