completed local setup with compose
Some checks failed
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
@@ -1,7 +1,6 @@
"""Configuration management and client factories."""

from .factories import (
    EventBusFactory,
    MinIOClientFactory,
    Neo4jDriverFactory,
    QdrantClientFactory,
@@ -28,7 +27,6 @@ __all__ = [
    "QdrantClientFactory",
    "Neo4jDriverFactory",
    "RedisClientFactory",
    "EventBusFactory",
    "get_settings",
    "init_settings",
    "create_vault_client",

@@ -2,10 +2,8 @@

from typing import Any

import boto3  # type: ignore
import hvac
import redis.asyncio as redis
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer  # type: ignore
from minio import Minio
from neo4j import GraphDatabase
from qdrant_client import QdrantClient
@@ -87,36 +85,3 @@ class RedisClientFactory:  # pylint: disable=too-few-public-methods
        return redis.from_url(
            settings.redis_url, encoding="utf-8", decode_responses=True
        )


class EventBusFactory:
    """Factory for creating event bus clients"""

    @staticmethod
    def create_kafka_producer(settings: BaseAppSettings) -> AIOKafkaProducer:
        """Create Kafka producer"""
        return AIOKafkaProducer(
            bootstrap_servers=settings.kafka_bootstrap_servers,
            value_serializer=lambda v: v.encode("utf-8") if isinstance(v, str) else v,
        )

    @staticmethod
    def create_kafka_consumer(
        settings: BaseAppSettings, topics: list[str]
    ) -> AIOKafkaConsumer:
        """Create Kafka consumer"""
        return AIOKafkaConsumer(
            *topics,
            bootstrap_servers=settings.kafka_bootstrap_servers,
            value_deserializer=lambda m: m.decode("utf-8") if m else None,
        )

    @staticmethod
    def create_sqs_client(settings: BaseAppSettings) -> Any:
        """Create SQS client"""
        return boto3.client("sqs", region_name=settings.aws_region)

    @staticmethod
    def create_sns_client(settings: BaseAppSettings) -> Any:
        """Create SNS client"""
        return boto3.client("sns", region_name=settings.aws_region)

@@ -8,7 +8,7 @@ class BaseAppSettings(BaseSettings):
    """Base settings class for all services"""

    model_config = SettingsConfigDict(
        env_file=".env", env_file_encoding="utf-8", case_sensitive=True, extra="ignore"
        env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="ignore"
    )

    # Service identification
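A note on the case_sensitive change above: with pydantic-settings, case_sensitive=False lets environment variable names match fields regardless of case. A minimal, self-contained sketch (DemoSettings is a hypothetical stand-in, not part of this diff):

import os
from pydantic_settings import BaseSettings, SettingsConfigDict

class DemoSettings(BaseSettings):
    model_config = SettingsConfigDict(case_sensitive=False)
    redis_url: str = "redis://localhost:6379"

os.environ["REDIS_URL"] = "redis://cache:6379"
print(DemoSettings().redis_url)  # picks up the upper-case env var: redis://cache:6379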
@@ -67,27 +67,20 @@ async def create_redis_client(settings: BaseAppSettings) -> "redis.Redis[str]":

def create_event_bus(settings: BaseAppSettings) -> EventBus:
    """Create event bus"""
    if settings.event_bus_type.lower() == "kafka":
        # pylint: disable=import-outside-toplevel
        from ..events import KafkaEventBus

        return KafkaEventBus(settings.kafka_bootstrap_servers)
    if settings.event_bus_type.lower() == "sqs":
        # pylint: disable=import-outside-toplevel
        from ..events import SQSEventBus

        return SQSEventBus(settings.aws_region)
    if settings.event_bus_type.lower() == "memory":
        # pylint: disable=import-outside-toplevel
        from ..events import MemoryEventBus

        return MemoryEventBus()

    # Default to memory bus for unknown types
    # pylint: disable=import-outside-toplevel
    from ..events import MemoryEventBus
    from libs.events import create_event_bus as _create_event_bus

    return MemoryEventBus()
    # Extract NATS servers as a list
    nats_servers = [s.strip() for s in settings.nats_servers.split(",")]

    return _create_event_bus(
        settings.event_bus_type,
        servers=nats_servers,
        stream_name=settings.nats_stream_name,
        consumer_group=settings.nats_consumer_group,
        bootstrap_servers=settings.kafka_bootstrap_servers,
        region_name=settings.aws_region,
    )


def get_default_settings(**overrides: Any) -> BaseAppSettings:
@@ -1,20 +1,52 @@
"""Event-driven architecture with Kafka, SQS, NATS, and Memory support."""

from libs.schemas.events import (
    EVENT_SCHEMA_MAP,
    BaseEventData,
    CalculationReadyEventData,
    DocumentExtractedEventData,
    DocumentIngestedEventData,
    DocumentOCRReadyEventData,
    FirmSyncCompletedEventData,
    FormFilledEventData,
    HMRCSubmittedEventData,
    KGUpsertedEventData,
    KGUpsertReadyEventData,
    RAGIndexedEventData,
    ReviewCompletedEventData,
    ReviewRequestedEventData,
    get_schema_for_topic,
    validate_event_data,
)

from .base import EventBus, EventPayload
from .factory import create_event_bus
from .kafka_bus import KafkaEventBus
from .memory_bus import MemoryEventBus
from .nats_bus import NATSEventBus
from .sqs_bus import SQSEventBus
from .topics import EventTopics

__all__ = [
    "EventPayload",
    "EventBus",
    "KafkaEventBus",
    "MemoryEventBus",
    "NATSEventBus",
    "SQSEventBus",
    "create_event_bus",
    "EventTopics",
    # Event schemas
    "BaseEventData",
    "DocumentIngestedEventData",
    "DocumentOCRReadyEventData",
    "DocumentExtractedEventData",
    "KGUpsertReadyEventData",
    "KGUpsertedEventData",
    "RAGIndexedEventData",
    "CalculationReadyEventData",
    "FormFilledEventData",
    "HMRCSubmittedEventData",
    "ReviewRequestedEventData",
    "ReviewCompletedEventData",
    "FirmSyncCompletedEventData",
    "EVENT_SCHEMA_MAP",
    "validate_event_data",
    "get_schema_for_topic",
]

@@ -3,7 +3,7 @@
import json
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from datetime import datetime
from datetime import UTC, datetime
from typing import Any

import ulid
@@ -22,7 +22,7 @@ class EventPayload:
        schema_version: str = "1.0",
    ):
        self.event_id = str(ulid.new())
        self.occurred_at = datetime.utcnow().isoformat() + "Z"
        self.occurred_at = datetime.now(UTC).isoformat()
        self.actor = actor
        self.tenant_id = tenant_id
        self.trace_id = trace_id
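One behavioural note on the occurred_at change above: datetime.now(UTC).isoformat() renders the offset as "+00:00" rather than a trailing "Z", while the old form appended "Z" to a naive timestamp. A quick check (plain Python, not part of the diff):

from datetime import UTC, datetime
print(datetime.now(UTC).isoformat())  # e.g. 2024-01-01T12:00:00.000000+00:00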
@@ -7,7 +7,7 @@ from collections.abc import Awaitable, Callable
import structlog
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer  # type: ignore

from .base import EventBus, EventPayload
from ..base import EventBus, EventPayload

logger = structlog.get_logger()

@@ -9,7 +9,7 @@ import boto3  # type: ignore
import structlog
from botocore.exceptions import ClientError  # type: ignore

from .base import EventBus, EventPayload
from ..base import EventBus, EventPayload

logger = structlog.get_logger()
libs/events/dlq.py (new file, 271 lines)
@@ -0,0 +1,271 @@
"""Dead Letter Queue (DLQ) handler for failed event processing."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from nats.js import JetStreamContext
|
||||
|
||||
from .base import EventPayload
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class DLQHandler:
|
||||
"""
|
||||
Dead Letter Queue handler for processing failed events.
|
||||
|
||||
Captures events that fail processing after max retries and stores them
|
||||
in a separate NATS stream for manual review and retry.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
js: JetStreamContext,
|
||||
dlq_stream_name: str = "TAX_AGENT_DLQ",
|
||||
max_retries: int = 3,
|
||||
backoff_base_ms: int = 1000,
|
||||
backoff_multiplier: float = 2.0,
|
||||
backoff_max_ms: int = 30000,
|
||||
):
|
||||
"""
|
||||
Initialize DLQ handler.
|
||||
|
||||
Args:
|
||||
js: NATS JetStream context
|
||||
dlq_stream_name: Name of the DLQ stream
|
||||
max_retries: Maximum number of retry attempts
|
||||
backoff_base_ms: Base backoff time in milliseconds
|
||||
backoff_multiplier: Exponential backoff multiplier
|
||||
backoff_max_ms: Maximum backoff time in milliseconds
|
||||
"""
|
||||
self.js = js
|
||||
self.dlq_stream_name = dlq_stream_name
|
||||
self.max_retries = max_retries
|
||||
self.backoff_base_ms = backoff_base_ms
|
||||
self.backoff_multiplier = backoff_multiplier
|
||||
self.backoff_max_ms = backoff_max_ms
|
||||
|
||||
async def ensure_dlq_stream_exists(self) -> None:
|
||||
"""Ensure DLQ stream exists in JetStream."""
|
||||
try:
|
||||
# Try to get stream info
|
||||
await self.js.stream_info(self.dlq_stream_name)
|
||||
logger.debug("DLQ stream already exists", stream=self.dlq_stream_name)
|
||||
|
||||
except Exception:
|
||||
# Stream doesn't exist, create it
|
||||
try:
|
||||
await self.js.add_stream(
|
||||
name=self.dlq_stream_name,
|
||||
subjects=[f"{self.dlq_stream_name}.>"],
|
||||
# Keep DLQ messages for 30 days
|
||||
max_age=30 * 24 * 60 * 60, # 30 days in seconds
|
||||
)
|
||||
logger.info("Created DLQ stream", stream=self.dlq_stream_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to create DLQ stream",
|
||||
stream=self.dlq_stream_name,
|
||||
error=str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
async def send_to_dlq(
|
||||
self,
|
||||
topic: str,
|
||||
payload: EventPayload,
|
||||
error: Exception,
|
||||
retry_count: int,
|
||||
original_message_data: bytes | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send failed event to DLQ.
|
||||
|
||||
Args:
|
||||
topic: Original topic name
|
||||
payload: Event payload
|
||||
error: Exception that caused the failure
|
||||
retry_count: Number of retry attempts made
|
||||
original_message_data: Original message data (optional, for debugging)
|
||||
"""
|
||||
try:
|
||||
# Create DLQ subject
|
||||
dlq_subject = f"{self.dlq_stream_name}.{topic}"
|
||||
|
||||
# Create DLQ payload with metadata
|
||||
dlq_payload = {
|
||||
"original_topic": topic,
|
||||
"original_payload": payload.to_dict(),
|
||||
"error": {
|
||||
"type": type(error).__name__,
|
||||
"message": str(error),
|
||||
},
|
||||
"retry_count": retry_count,
|
||||
"failed_at": datetime.now(UTC).isoformat(),
|
||||
"tenant_id": payload.tenant_id,
|
||||
"event_id": payload.event_id,
|
||||
"trace_id": payload.trace_id,
|
||||
}
|
||||
|
||||
# Add original message data if available
|
||||
if original_message_data:
|
||||
try:
|
||||
dlq_payload["original_message_data"] = original_message_data.decode(
|
||||
"utf-8"
|
||||
)
|
||||
except UnicodeDecodeError:
|
||||
dlq_payload["original_message_data"] = "<binary data>"
|
||||
|
||||
# Publish to DLQ
|
||||
headers = {
|
||||
"original_topic": topic,
|
||||
"tenant_id": payload.tenant_id,
|
||||
"event_id": payload.event_id,
|
||||
"error_type": type(error).__name__,
|
||||
"retry_count": str(retry_count),
|
||||
}
|
||||
|
||||
await self.js.publish(
|
||||
subject=dlq_subject,
|
||||
payload=json.dumps(dlq_payload).encode(),
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
logger.error(
|
||||
"Event sent to DLQ",
|
||||
topic=topic,
|
||||
event_id=payload.event_id,
|
||||
error=str(error),
|
||||
retry_count=retry_count,
|
||||
dlq_subject=dlq_subject,
|
||||
)
|
||||
|
||||
except Exception as dlq_error:
|
||||
logger.critical(
|
||||
"Failed to send event to DLQ - EVENT LOST",
|
||||
topic=topic,
|
||||
event_id=payload.event_id,
|
||||
original_error=str(error),
|
||||
dlq_error=str(dlq_error),
|
||||
)
|
||||
|
||||
def calculate_backoff(self, retry_count: int) -> float:
|
||||
"""
|
||||
Calculate exponential backoff delay.
|
||||
|
||||
Args:
|
||||
retry_count: Current retry attempt (0-indexed)
|
||||
|
||||
Returns:
|
||||
Backoff delay in seconds
|
||||
"""
|
||||
# Calculate exponential backoff: base * (multiplier ^ retry_count)
|
||||
backoff_ms = self.backoff_base_ms * (self.backoff_multiplier**retry_count)
|
||||
|
||||
# Cap at maximum backoff
|
||||
backoff_ms = min(backoff_ms, self.backoff_max_ms)
|
||||
|
||||
# Convert to seconds
|
||||
return backoff_ms / 1000.0
|
||||
|
||||
async def retry_with_backoff(
|
||||
self,
|
||||
func: Any,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> tuple[bool, Exception | None]:
|
||||
"""
|
||||
Retry a function with exponential backoff.
|
||||
|
||||
Args:
|
||||
func: Async function to retry
|
||||
*args: Position arguments for the function
|
||||
**kwargs: Keyword arguments for the function
|
||||
|
||||
Returns:
|
||||
Tuple of (success: bool, last_error: Exception | None)
|
||||
"""
|
||||
last_error: Exception | None = None
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
await func(*args, **kwargs)
|
||||
return (True, None)
|
||||
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
last_error = e
|
||||
|
||||
if attempt < self.max_retries:
|
||||
# Calculate and apply backoff
|
||||
backoff_seconds = self.calculate_backoff(attempt)
|
||||
|
||||
logger.warning(
|
||||
"Retry attempt failed, backing off",
|
||||
attempt=attempt + 1,
|
||||
max_retries=self.max_retries,
|
||||
backoff_seconds=backoff_seconds,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
await asyncio.sleep(backoff_seconds)
|
||||
else:
|
||||
logger.error(
|
||||
"All retry attempts exhausted",
|
||||
attempts=self.max_retries + 1,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
return (False, last_error)
|
||||
|
||||
|
||||
class DLQMetrics:
|
||||
"""Metrics for DLQ operations."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize DLQ metrics."""
|
||||
self.total_dlq_events = 0
|
||||
self.dlq_events_by_topic: dict[str, int] = {}
|
||||
self.dlq_events_by_error_type: dict[str, int] = {}
|
||||
|
||||
def record_dlq_event(self, topic: str, error_type: str) -> None:
|
||||
"""
|
||||
Record a DLQ event.
|
||||
|
||||
Args:
|
||||
topic: Original topic name
|
||||
error_type: Type of error that caused DLQ
|
||||
"""
|
||||
self.total_dlq_events += 1
|
||||
|
||||
# Track by topic
|
||||
if topic not in self.dlq_events_by_topic:
|
||||
self.dlq_events_by_topic[topic] = 0
|
||||
self.dlq_events_by_topic[topic] += 1
|
||||
|
||||
# Track by error type
|
||||
if error_type not in self.dlq_events_by_error_type:
|
||||
self.dlq_events_by_error_type[error_type] = 0
|
||||
self.dlq_events_by_error_type[error_type] += 1
|
||||
|
||||
def get_metrics(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get DLQ metrics.
|
||||
|
||||
Returns:
|
||||
Dictionary of metrics
|
||||
"""
|
||||
return {
|
||||
"total_dlq_events": self.total_dlq_events,
|
||||
"by_topic": self.dlq_events_by_topic.copy(),
|
||||
"by_error_type": self.dlq_events_by_error_type.copy(),
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all metrics to zero."""
|
||||
self.total_dlq_events = 0
|
||||
self.dlq_events_by_topic.clear()
|
||||
self.dlq_events_by_error_type.clear()
|
||||
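With the defaults above (base 1000 ms, multiplier 2.0, cap 30000 ms), calculate_backoff returns 1.0, 2.0 and 4.0 seconds for attempts 0, 1 and 2. A minimal usage sketch of the handler, assuming the caller already holds a JetStream context; the names below are illustrative, not part of this diff:

# Hypothetical wiring; `js` is an existing JetStream context and `payload`
# an EventPayload instance, neither is constructed here.
async def process_with_retries(js, topic: str, payload: EventPayload) -> None:
    dlq = DLQHandler(js=js, max_retries=3)  # defaults give 1s, 2s, 4s backoff, capped at 30s

    async def handle(t: str, p: EventPayload) -> None:
        ...  # application-specific processing that may raise

    ok, err = await dlq.retry_with_backoff(handle, topic, payload)
    if not ok and err is not None:
        await dlq.send_to_dlq(topic, payload, err, retry_count=dlq.max_retries)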
@@ -3,16 +3,20 @@
from typing import Any

from .base import EventBus
from .kafka_bus import KafkaEventBus
from .nats_bus import NATSEventBus
from .sqs_bus import SQSEventBus


def create_event_bus(bus_type: str, **kwargs: Any) -> EventBus:
    """Factory function to create event bus"""
    if bus_type.lower() == "kafka":
        # Lazy import to avoid ModuleNotFoundError when aiokafka is not installed
        from .contrib.kafka_bus import KafkaEventBus

        return KafkaEventBus(kwargs.get("bootstrap_servers", "localhost:9092"))
    if bus_type.lower() == "sqs":
        # Lazy import to avoid ModuleNotFoundError when boto3 is not installed
        from .contrib.sqs_bus import SQSEventBus

        return SQSEventBus(kwargs.get("region_name", "us-east-1"))
    if bus_type.lower() == "nats":
        return NATSEventBus(
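A short usage sketch for the NATS branch of this factory; the keyword arguments mirror the ones forwarded from the config layer earlier in the diff, and the concrete values are examples only:

bus = create_event_bus(
    "nats",
    servers=["nats://localhost:4222"],
    stream_name="TAX_AGENT_EVENTS",
    consumer_group="tax-agent",
)
# The returned NATSEventBus still has to be started (see the connect logic in
# libs/events/nats_bus.py below) before publish/subscribe calls will work.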
libs/events/metrics.py (new file, 225 lines)
@@ -0,0 +1,225 @@
"""Prometheus metrics for event bus monitoring."""
|
||||
|
||||
from prometheus_client import Counter, Histogram
|
||||
from prometheus_client.registry import CollectorRegistry
|
||||
|
||||
# Global registry for event metrics
|
||||
_event_registry = CollectorRegistry()
|
||||
|
||||
# Event publishing metrics
|
||||
event_published_total = Counter(
|
||||
"event_published_total",
|
||||
"Total number of events published",
|
||||
["topic"],
|
||||
registry=_event_registry,
|
||||
)
|
||||
|
||||
event_publish_errors_total = Counter(
|
||||
"event_publish_errors_total",
|
||||
"Total number of event publishing errors",
|
||||
["topic", "error_type"],
|
||||
registry=_event_registry,
|
||||
)
|
||||
|
||||
event_publishing_duration_seconds = Histogram(
|
||||
"event_publishing_duration_seconds",
|
||||
"Time spent publishing events in seconds",
|
||||
["topic"],
|
||||
buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0),
|
||||
registry=_event_registry,
|
||||
)
|
||||
|
||||
# Event consumption metrics
|
||||
event_consumed_total = Counter(
|
||||
"event_consumed_total",
|
||||
"Total number of events consumed",
|
||||
["topic", "consumer_group"],
|
||||
registry=_event_registry,
|
||||
)
|
||||
|
||||
event_processing_duration_seconds = Histogram(
|
||||
"event_processing_duration_seconds",
|
||||
"Time spent processing events in seconds",
|
||||
["topic", "consumer_group"],
|
||||
buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0),
|
||||
registry=_event_registry,
|
||||
)
|
||||
|
||||
event_processing_errors_total = Counter(
|
||||
"event_processing_errors_total",
|
||||
"Total number of event processing errors",
|
||||
["topic", "consumer_group", "error_type"],
|
||||
registry=_event_registry,
|
||||
)
|
||||
|
||||
# DLQ metrics
|
||||
event_dlq_total = Counter(
|
||||
"event_dlq_total",
|
||||
"Total number of events sent to dead letter queue",
|
||||
["topic", "error_type"],
|
||||
registry=_event_registry,
|
||||
)
|
||||
|
||||
event_retry_total = Counter(
|
||||
"event_retry_total",
|
||||
"Total number of event retry attempts",
|
||||
["topic", "retry_attempt"],
|
||||
registry=_event_registry,
|
||||
)
|
||||
|
||||
# Schema validation metrics
|
||||
event_schema_validation_errors_total = Counter(
|
||||
"event_schema_validation_errors_total",
|
||||
"Total number of event schema validation errors",
|
||||
["topic", "validation_error"],
|
||||
registry=_event_registry,
|
||||
)
|
||||
|
||||
# NATS JetStream specific metrics
|
||||
nats_stream_messages_total = Counter(
|
||||
"nats_stream_messages_total",
|
||||
"Total messages in NATS stream",
|
||||
["stream_name"],
|
||||
registry=_event_registry,
|
||||
)
|
||||
|
||||
nats_consumer_lag_messages = Histogram(
|
||||
"nats_consumer_lag_messages",
|
||||
"Number of messages consumer is lagging behind",
|
||||
["stream_name", "consumer_group"],
|
||||
buckets=(0, 1, 5, 10, 25, 50, 100, 250, 500, 1000, 5000, 10000),
|
||||
registry=_event_registry,
|
||||
)
|
||||
|
||||
|
||||
def get_event_metrics_registry() -> CollectorRegistry:
|
||||
"""
|
||||
Get the Prometheus registry for event metrics.
|
||||
|
||||
Returns:
|
||||
CollectorRegistry for event metrics
|
||||
"""
|
||||
return _event_registry
|
||||
|
||||
|
||||
class EventMetricsCollector:
|
||||
"""Helper class for collecting event metrics."""
|
||||
|
||||
@staticmethod
|
||||
def record_publish(
|
||||
topic: str,
|
||||
duration_seconds: float,
|
||||
success: bool = True,
|
||||
error_type: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Record event publishing metrics.
|
||||
|
||||
Args:
|
||||
topic: Event topic name
|
||||
duration_seconds: Time taken to publish
|
||||
success: Whether publishing succeeded
|
||||
error_type: Type of error if failed
|
||||
"""
|
||||
if success:
|
||||
event_published_total.labels(topic=topic).inc()
|
||||
else:
|
||||
event_publish_errors_total.labels(
|
||||
topic=topic, error_type=error_type or "unknown"
|
||||
).inc()
|
||||
|
||||
event_publishing_duration_seconds.labels(topic=topic).observe(duration_seconds)
|
||||
|
||||
@staticmethod
|
||||
def record_consume(
|
||||
topic: str,
|
||||
consumer_group: str,
|
||||
duration_seconds: float,
|
||||
success: bool = True,
|
||||
error_type: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Record event consumption metrics.
|
||||
|
||||
Args:
|
||||
topic: Event topic name
|
||||
consumer_group: Consumer group name
|
||||
duration_seconds: Time taken to process event
|
||||
success: Whether processing succeeded
|
||||
error_type: Type of error if failed
|
||||
"""
|
||||
if success:
|
||||
event_consumed_total.labels(
|
||||
topic=topic, consumer_group=consumer_group
|
||||
).inc()
|
||||
else:
|
||||
event_processing_errors_total.labels(
|
||||
topic=topic,
|
||||
consumer_group=consumer_group,
|
||||
error_type=error_type or "unknown",
|
||||
).inc()
|
||||
|
||||
event_processing_duration_seconds.labels(
|
||||
topic=topic, consumer_group=consumer_group
|
||||
).observe(duration_seconds)
|
||||
|
||||
@staticmethod
|
||||
def record_dlq(topic: str, error_type: str) -> None:
|
||||
"""
|
||||
Record event sent to DLQ.
|
||||
|
||||
Args:
|
||||
topic: Event topic name
|
||||
error_type: Type of error that caused DLQ
|
||||
"""
|
||||
event_dlq_total.labels(topic=topic, error_type=error_type).inc()
|
||||
|
||||
@staticmethod
|
||||
def record_retry(topic: str, retry_attempt: int) -> None:
|
||||
"""
|
||||
Record event retry attempt.
|
||||
|
||||
Args:
|
||||
topic: Event topic name
|
||||
retry_attempt: Retry attempt number (1-indexed)
|
||||
"""
|
||||
event_retry_total.labels(topic=topic, retry_attempt=str(retry_attempt)).inc()
|
||||
|
||||
@staticmethod
|
||||
def record_schema_validation_error(topic: str, validation_error: str) -> None:
|
||||
"""
|
||||
Record schema validation error.
|
||||
|
||||
Args:
|
||||
topic: Event topic name
|
||||
validation_error: Type of validation error
|
||||
"""
|
||||
event_schema_validation_errors_total.labels(
|
||||
topic=topic, validation_error=validation_error
|
||||
).inc()
|
||||
|
||||
@staticmethod
|
||||
def record_nats_stream_message(stream_name: str) -> None:
|
||||
"""
|
||||
Record message added to NATS stream.
|
||||
|
||||
Args:
|
||||
stream_name: NATS stream name
|
||||
"""
|
||||
nats_stream_messages_total.labels(stream_name=stream_name).inc()
|
||||
|
||||
@staticmethod
|
||||
def record_consumer_lag(
|
||||
stream_name: str, consumer_group: str, lag_messages: int
|
||||
) -> None:
|
||||
"""
|
||||
Record consumer lag.
|
||||
|
||||
Args:
|
||||
stream_name: NATS stream name
|
||||
consumer_group: Consumer group name
|
||||
lag_messages: Number of messages consumer is behind
|
||||
"""
|
||||
nats_consumer_lag_messages.labels(
|
||||
stream_name=stream_name, consumer_group=consumer_group
|
||||
).observe(lag_messages)
|
||||
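These collectors live in their own CollectorRegistry, so a service has to export that registry explicitly. A minimal sketch, assuming a FastAPI app; the route path is an illustration, not defined anywhere in this diff:

from fastapi import FastAPI, Response
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest

from libs.events.metrics import get_event_metrics_registry

app = FastAPI()

@app.get("/metrics/events")  # hypothetical path
def event_metrics() -> Response:
    # Serialize only the event-bus registry created in libs/events/metrics.py.
    payload = generate_latest(get_event_metrics_registry())
    return Response(content=payload, media_type=CONTENT_TYPE_LATEST)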
@@ -2,6 +2,7 @@

import asyncio
import json
import time
from collections.abc import Awaitable, Callable
from typing import Any

@@ -12,6 +13,8 @@ from nats.js import JetStreamContext
from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy

from .base import EventBus, EventPayload
from .dlq import DLQHandler
from .metrics import EventMetricsCollector

logger = structlog.get_logger()

@@ -24,6 +27,8 @@ class NATSEventBus(EventBus):  # pylint: disable=too-many-instance-attributes
        servers: str | list[str] = "nats://localhost:4222",
        stream_name: str = "TAX_AGENT_EVENTS",
        consumer_group: str = "tax-agent",
        dlq_stream_name: str = "TAX_AGENT_DLQ",
        max_retries: int = 3,
    ):
        if isinstance(servers, str):
            self.servers = [servers]
@@ -32,8 +37,13 @@ class NATSEventBus(EventBus):  # pylint: disable=too-many-instance-attributes

        self.stream_name = stream_name
        self.consumer_group = consumer_group
        self.dlq_stream_name = dlq_stream_name
        self.max_retries = max_retries

        self.nc: NATS | None = None
        self.js: JetStreamContext | None = None
        self.dlq: DLQHandler | None = None

        self.handlers: dict[
            str, list[Callable[[str, EventPayload], Awaitable[None]]]
        ] = {}
@@ -48,19 +58,32 @@ class NATSEventBus(EventBus):  # pylint: disable=too-many-instance-attributes

        try:
            # Connect to NATS
            self.nc = await nats.connect(servers=self.servers)
            self.nc = await nats.connect(
                servers=self.servers,
                connect_timeout=10,
                reconnect_time_wait=1,
            )

            # Get JetStream context
            self.js = self.nc.jetstream()
            self.js = self.nc.jetstream(timeout=10)

            # Ensure stream exists
            # Initialize DLQ handler
            self.dlq = DLQHandler(
                js=self.js,
                dlq_stream_name=self.dlq_stream_name,
                max_retries=self.max_retries,
            )

            # Ensure streams exist
            await self._ensure_stream_exists()
            await self.dlq.ensure_dlq_stream_exists()

            self.running = True
            logger.info(
                "NATS event bus started",
                servers=self.servers,
                stream=self.stream_name,
                dlq_stream=self.dlq_stream_name,
            )

        except Exception as e:
@@ -98,6 +121,7 @@ class NATSEventBus(EventBus):  # pylint: disable=too-many-instance-attributes
        if not self.js:
            raise RuntimeError("Event bus not started")

        start_time = time.perf_counter()
        try:
            # Create subject name from topic
            subject = f"{self.stream_name}.{topic}"
@@ -117,6 +141,13 @@ class NATSEventBus(EventBus):  # pylint: disable=too-many-instance-attributes
                headers=headers,
            )

            duration = time.perf_counter() - start_time
            EventMetricsCollector.record_publish(
                topic=topic,
                duration_seconds=duration,
                success=True,
            )

            logger.info(
                "Event published",
                topic=topic,
@@ -127,6 +158,14 @@ class NATSEventBus(EventBus):  # pylint: disable=too-many-instance-attributes
            return True

        except Exception as e:  # pylint: disable=broad-exception-caught
            duration = time.perf_counter() - start_time
            EventMetricsCollector.record_publish(
                topic=topic,
                duration_seconds=duration,
                success=False,
                error_type=type(e).__name__,
            )

            logger.error(
                "Failed to publish event",
                topic=topic,
@@ -152,9 +191,13 @@ class NATSEventBus(EventBus):  # pylint: disable=too-many-instance-attributes
        subject = f"{self.stream_name}.{topic}"

        # Create durable consumer
        consumer_name = f"{self.consumer_group}-{topic}"
        # Durable names cannot contain dots, so we replace them
        safe_topic = topic.replace(".", "-")
        consumer_name = f"{self.consumer_group}-{safe_topic}"

        # Subscribe with pull-based consumer
        # Set max_deliver to max_retries + 1 (initial + retries)
        # We handle DLQ manually before NATS gives up
        subscription = await self.js.pull_subscribe(
            subject=subject,
            durable=consumer_name,
@@ -162,7 +205,7 @@ class NATSEventBus(EventBus):  # pylint: disable=too-many-instance-attributes
                durable_name=consumer_name,
                ack_policy=AckPolicy.EXPLICIT,
                deliver_policy=DeliverPolicy.NEW,
                max_deliver=3,
                max_deliver=self.max_retries + 2,  # Give us room to handle DLQ
                ack_wait=30,  # 30 seconds
            ),
        )
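The safe_topic substitution above matters because JetStream durable consumer names may not contain dots. A quick illustration with values used elsewhere in this diff:

# topic "doc.ingested" with consumer group "tax-agent":
safe_topic = "doc.ingested".replace(".", "-")
consumer_name = f"tax-agent-{safe_topic}"  # -> "tax-agent-doc-ingested"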
@@ -193,13 +236,14 @@ class NATSEventBus(EventBus):  # pylint: disable=too-many-instance-attributes
            # Try to get stream info
            await self.js.stream_info(self.stream_name)
            logger.debug("Stream already exists", stream=self.stream_name)
            EventMetricsCollector.record_nats_stream_message(self.stream_name)

        except Exception:
            # Stream doesn't exist, create it
            try:
                await self.js.add_stream(
                    name=self.stream_name,
                    subjects=[f"{self.stream_name}.*"],
                    subjects=[f"{self.stream_name}.>"],
                )
                logger.info("Created JetStream stream", stream=self.stream_name)
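On the subjects change above: in NATS, "*" matches exactly one subject token while ">" matches one or more, so the new pattern also covers multi-token topics. An illustration with the stream name used in this diff:

# "TAX_AGENT_EVENTS.*"  matches "TAX_AGENT_EVENTS.doc"             (one token only)
# "TAX_AGENT_EVENTS.>"  matches "TAX_AGENT_EVENTS.kg.upsert.ready" (multiple tokens)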
@@ -214,12 +258,17 @@ class NATSEventBus(EventBus):  # pylint: disable=too-many-instance-attributes
        while self.running:
            try:
                # Fetch messages in batches
                messages = await subscription.fetch(batch=10, timeout=20)
                messages = await subscription.fetch(batch=10, timeout=5)

                for message in messages:
                    start_time = time.perf_counter()
                    payload = None

                    try:
                        print(f"DEBUG: Received message: {message.data}")
                        # Parse message payload
                        payload_dict = json.loads(message.data.decode())
                        print(f"DEBUG: Parsed payload: {payload_dict}")

                        payload = EventPayload(
                            data=payload_dict["data"],
@@ -230,38 +279,87 @@ class NATSEventBus(EventBus):  # pylint: disable=too-many-instance-attributes
                        )
                        payload.event_id = payload_dict["event_id"]
                        payload.occurred_at = payload_dict["occurred_at"]
                        print(f"DEBUG: Reconstructed payload: {payload.event_id}")

                        # Call all handlers for this topic
                        for handler in self.handlers.get(topic, []):
                            try:
                                await handler(topic, payload)
                            except (
                                Exception
                            ) as e:  # pylint: disable=broad-exception-caught
                                logger.error(
                                    "Handler failed",
                                    topic=topic,
                                    event_id=payload.event_id,
                                    error=str(e),
                                )
                            print(f"DEBUG: Calling handler for topic {topic}")
                            await handler(topic, payload)

                        # Acknowledge message
                        await message.ack()
                        print("DEBUG: Message acked")

                    except json.JSONDecodeError as e:
                        logger.error(
                            "Failed to decode message", topic=topic, error=str(e)
                        # Record metrics
                        duration = time.perf_counter() - start_time
                        EventMetricsCollector.record_consume(
                            topic=topic,
                            consumer_group=self.consumer_group,
                            duration_seconds=duration,
                            success=True,
                        )
                        await message.nak()

                    except Exception as e:  # pylint: disable=broad-exception-caught
                        logger.error(
                            "Failed to process message", topic=topic, error=str(e)
                        duration = time.perf_counter() - start_time
                        error_type = type(e).__name__

                        # Record failure metric
                        EventMetricsCollector.record_consume(
                            topic=topic,
                            consumer_group=self.consumer_group,
                            duration_seconds=duration,
                            success=False,
                            error_type=error_type,
                        )
                        await message.nak()

                        # Check delivery count for DLQ
                        try:
                            metadata = message.metadata
                            # nats-py MsgMetadata exposes the delivery count as num_delivered
                            num_delivered = metadata.num_delivered
                        except Exception:
                            num_delivered = 1

                        if num_delivered >= self.max_retries:
                            logger.error(
                                "Max retries exceeded, sending to DLQ",
                                topic=topic,
                                event_id=payload.event_id if payload else "unknown",
                                error=str(e),
                                num_delivered=num_delivered,
                            )

                            if self.dlq and payload:
                                await self.dlq.send_to_dlq(
                                    topic=topic,
                                    payload=payload,
                                    error=e,
                                    retry_count=num_delivered,
                                    original_message_data=message.data,
                                )
                                EventMetricsCollector.record_dlq(topic, error_type)

                            # Ack to remove from main stream
                            await message.ack()

                        else:
                            # Retry (Nak)
                            logger.warning(
                                "Processing failed, retrying",
                                topic=topic,
                                event_id=payload.event_id if payload else "unknown",
                                error=str(e),
                                attempt=num_delivered,
                            )
                            EventMetricsCollector.record_retry(topic, num_delivered)
                            await message.nak()

            except TimeoutError:
                # No messages available, continue polling
                continue
            except Exception as e:  # pylint: disable=broad-exception-caught
                logger.error("Consumer error", topic=topic, error=str(e))
                await asyncio.sleep(5)  # Wait before retrying
                await asyncio.sleep(1)  # Wait before retrying
@@ -7,6 +7,7 @@ class EventTopics:  # pylint: disable=too-few-public-methods
    DOC_INGESTED = "doc.ingested"
    DOC_OCR_READY = "doc.ocr_ready"
    DOC_EXTRACTED = "doc.extracted"
    KG_UPSERT_READY = "kg.upsert.ready"
    KG_UPSERTED = "kg.upserted"
    RAG_INDEXED = "rag.indexed"
    CALC_SCHEDULE_READY = "calc.schedule_ready"

@@ -11,8 +11,8 @@ psycopg2-binary>=2.9.11
neo4j>=6.0.2
redis[hiredis]>=6.4.0

# Object storage and vector database
minio>=7.2.18
boto3>=1.34.0
qdrant-client>=1.15.1

# Event streaming (NATS only - removed Kafka)
@@ -36,3 +36,13 @@ python-multipart>=0.0.20
python-dateutil>=2.9.0
python-dotenv>=1.1.1
orjson>=3.11.3
jsonschema>=4.20.0

# OpenTelemetry instrumentation (for observability)
opentelemetry-api>=1.21.0
opentelemetry-sdk>=1.21.0
opentelemetry-exporter-otlp-proto-grpc>=1.21.0
opentelemetry-instrumentation-fastapi>=0.42b0
opentelemetry-instrumentation-httpx>=0.42b0
opentelemetry-instrumentation-psycopg2>=0.42b0
opentelemetry-instrumentation-redis>=0.42b0
@@ -65,6 +65,26 @@ from .enums import (
# Import error models
from .errors import ErrorResponse, ValidationError, ValidationErrorResponse

# Import event schemas
from .events import (
    EVENT_SCHEMA_MAP,
    BaseEventData,
    CalculationReadyEventData,
    DocumentExtractedEventData,
    DocumentIngestedEventData,
    DocumentOCRReadyEventData,
    FirmSyncCompletedEventData,
    FormFilledEventData,
    HMRCSubmittedEventData,
    KGUpsertedEventData,
    KGUpsertReadyEventData,
    RAGIndexedEventData,
    ReviewCompletedEventData,
    ReviewRequestedEventData,
    get_schema_for_topic,
    validate_event_data,
)

# Import health models
from .health import HealthCheck, ServiceHealth

@@ -135,7 +155,7 @@ __all__ = [
    "DocumentUploadResponse",
    "ExtractionResponse",
    "FirmSyncResponse",
    "HMRCSubmissionResponse",
    "HMRCSubmittedEventData",
    "RAGSearchResponse",
    "ScheduleComputeResponse",
    # Utils
@@ -172,4 +192,21 @@ __all__ = [
    "ValidationResult",
    "PolicyVersion",
    "CoverageAudit",
    # Event schemas
    "BaseEventData",
    "DocumentIngestedEventData",
    "DocumentOCRReadyEventData",
    "DocumentExtractedEventData",
    "KGUpsertReadyEventData",
    "KGUpsertedEventData",
    "RAGIndexedEventData",
    "CalculationReadyEventData",
    "FormFilledEventData",
    "HMRCSubmittedEventData",
    "ReviewRequestedEventData",
    "ReviewCompletedEventData",
    "FirmSyncCompletedEventData",
    "EVENT_SCHEMA_MAP",
    "validate_event_data",
    "get_schema_for_topic",
]
libs/schemas/events.py (new file, 309 lines)
@@ -0,0 +1,309 @@
"""Typed event payload schemas for validation and type safety."""
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
# Base schema for all events
|
||||
class BaseEventData(BaseModel):
|
||||
"""Base class for all event data payloads."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid", # Prevent unexpected fields
|
||||
frozen=True, # Make immutable
|
||||
)
|
||||
|
||||
|
||||
# Document lifecycle events
|
||||
class DocumentIngestedEventData(BaseEventData):
|
||||
"""Event emitted when a document is successfully ingested."""
|
||||
|
||||
doc_id: str = Field(..., description="Unique document identifier (ULID)")
|
||||
filename: str = Field(..., description="Original filename")
|
||||
mime_type: str = Field(..., description="MIME type of the document")
|
||||
size_bytes: int = Field(..., ge=0, description="File size in bytes")
|
||||
checksum_sha256: str = Field(..., description="SHA-256 checksum for integrity")
|
||||
kind: str = Field(
|
||||
..., description="Document kind (invoice, receipt, bank_statement, etc.)"
|
||||
)
|
||||
source: str = Field(
|
||||
..., description="Ingestion source (manual_upload, rpa, email, api)"
|
||||
)
|
||||
storage_path: str = Field(..., description="MinIO object storage path")
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Additional metadata"
|
||||
)
|
||||
|
||||
@field_validator("checksum_sha256")
|
||||
@classmethod
|
||||
def validate_checksum(cls, v: str) -> str:
|
||||
"""Validate SHA-256 checksum format."""
|
||||
if len(v) != 64 or not all(c in "0123456789abcdef" for c in v.lower()):
|
||||
raise ValueError("Invalid SHA-256 checksum format")
|
||||
return v.lower()
|
||||
|
||||
|
||||
class DocumentOCRReadyEventData(BaseEventData):
|
||||
"""Event emitted when OCR processing is complete."""
|
||||
|
||||
doc_id: str = Field(..., description="Document identifier")
|
||||
ocr_engine: Literal["tesseract", "textract", "azure_ocr"] = Field(
|
||||
..., description="OCR engine used"
|
||||
)
|
||||
page_count: int = Field(..., ge=1, description="Number of pages processed")
|
||||
confidence_avg: float = Field(
|
||||
..., ge=0.0, le=1.0, description="Average OCR confidence score"
|
||||
)
|
||||
text_length: int = Field(..., ge=0, description="Total extracted text length")
|
||||
layout_detected: bool = Field(
|
||||
..., description="Whether document layout was successfully detected"
|
||||
)
|
||||
languages_detected: list[str] = Field(
|
||||
default_factory=list, description="Detected languages (ISO 639-1 codes)"
|
||||
)
|
||||
processing_time_ms: int = Field(
|
||||
..., ge=0, description="Processing time in milliseconds"
|
||||
)
|
||||
storage_path: str = Field(..., description="Path to OCR results in storage")
|
||||
|
||||
|
||||
class DocumentExtractedEventData(BaseEventData):
|
||||
"""Event emitted when field extraction is complete."""
|
||||
|
||||
doc_id: str = Field(..., description="Document identifier")
|
||||
extraction_id: str = Field(..., description="Unique extraction run identifier")
|
||||
strategy: Literal["llm", "rules", "hybrid"] = Field(
|
||||
..., description="Extraction strategy used"
|
||||
)
|
||||
fields_extracted: int = Field(..., ge=0, description="Number of fields extracted")
|
||||
confidence_avg: float = Field(
|
||||
..., ge=0.0, le=1.0, description="Average extraction confidence"
|
||||
)
|
||||
calibrated_confidence: float = Field(
|
||||
..., ge=0.0, le=1.0, description="Calibrated confidence score"
|
||||
)
|
||||
model_name: str | None = Field(None, description="LLM model used (if applicable)")
|
||||
processing_time_ms: int = Field(
|
||||
..., ge=0, description="Processing time in milliseconds"
|
||||
)
|
||||
storage_path: str = Field(..., description="Path to extraction results")
|
||||
|
||||
|
||||
# Knowledge Graph events
|
||||
class KGUpsertReadyEventData(BaseEventData):
|
||||
"""Event emitted when KG upsert data is ready."""
|
||||
|
||||
doc_id: str = Field(..., description="Source document identifier")
|
||||
entity_count: int = Field(..., ge=0, description="Number of entities to upsert")
|
||||
relationship_count: int = Field(
|
||||
..., ge=0, description="Number of relationships to upsert"
|
||||
)
|
||||
tax_year: str = Field(..., description="Tax year (e.g., '2024-25')")
|
||||
taxpayer_id: str = Field(..., description="Taxpayer identifier")
|
||||
normalization_id: str = Field(..., description="Normalization run identifier")
|
||||
storage_path: str = Field(..., description="Path to normalized data")
|
||||
|
||||
|
||||
class KGUpsertedEventData(BaseEventData):
|
||||
"""Event emitted when KG upsert is complete."""
|
||||
|
||||
doc_id: str = Field(..., description="Source document identifier")
|
||||
entities_created: int = Field(..., ge=0, description="Entities created")
|
||||
entities_updated: int = Field(..., ge=0, description="Entities updated")
|
||||
relationships_created: int = Field(..., ge=0, description="Relationships created")
|
||||
relationships_updated: int = Field(..., ge=0, description="Relationships updated")
|
||||
shacl_violations: int = Field(
|
||||
..., ge=0, description="Number of SHACL validation violations"
|
||||
)
|
||||
processing_time_ms: int = Field(
|
||||
..., ge=0, description="Processing time in milliseconds"
|
||||
)
|
||||
success: bool = Field(..., description="Whether upsert was successful")
|
||||
error_message: str | None = Field(None, description="Error message if failed")
|
||||
|
||||
|
||||
# RAG events
|
||||
class RAGIndexedEventData(BaseEventData):
|
||||
"""Event emitted when RAG indexing is complete."""
|
||||
|
||||
doc_id: str = Field(..., description="Source document identifier")
|
||||
collection_name: str = Field(..., description="Qdrant collection name")
|
||||
chunks_indexed: int = Field(..., ge=0, description="Number of chunks indexed")
|
||||
embedding_model: str = Field(..., description="Embedding model used")
|
||||
pii_detected: bool = Field(..., description="Whether PII was detected")
|
||||
pii_redacted: bool = Field(..., description="Whether PII was redacted")
|
||||
processing_time_ms: int = Field(
|
||||
..., ge=0, description="Processing time in milliseconds"
|
||||
)
|
||||
storage_path: str = Field(..., description="Path to chunked data")
|
||||
|
||||
|
||||
# Calculation events
|
||||
class CalculationReadyEventData(BaseEventData):
|
||||
"""Event emitted when tax calculation is complete."""
|
||||
|
||||
taxpayer_id: str = Field(..., description="Taxpayer identifier")
|
||||
tax_year: str = Field(..., description="Tax year (e.g., '2024-25')")
|
||||
schedule_id: str = Field(..., description="Tax schedule identifier (SA102, SA103)")
|
||||
calculation_id: str = Field(..., description="Unique calculation run identifier")
|
||||
boxes_computed: int = Field(..., ge=0, description="Number of form boxes computed")
|
||||
total_income: float | None = Field(None, description="Total income calculated")
|
||||
total_tax: float | None = Field(None, description="Total tax calculated")
|
||||
confidence: float = Field(
|
||||
..., ge=0.0, le=1.0, description="Calculation confidence score"
|
||||
)
|
||||
evidence_count: int = Field(
|
||||
..., ge=0, description="Number of evidence items supporting calculation"
|
||||
)
|
||||
processing_time_ms: int = Field(
|
||||
..., ge=0, description="Processing time in milliseconds"
|
||||
)
|
||||
storage_path: str = Field(..., description="Path to calculation results")
|
||||
|
||||
|
||||
# Form events
|
||||
class FormFilledEventData(BaseEventData):
|
||||
"""Event emitted when PDF form filling is complete."""
|
||||
|
||||
taxpayer_id: str = Field(..., description="Taxpayer identifier")
|
||||
tax_year: str = Field(..., description="Tax year (e.g., '2024-25')")
|
||||
form_id: str = Field(..., description="Form identifier (SA100, SA102, etc.)")
|
||||
fields_filled: int = Field(..., ge=0, description="Number of fields filled")
|
||||
pdf_size_bytes: int = Field(..., ge=0, description="Generated PDF size in bytes")
|
||||
storage_path: str = Field(..., description="Path to filled PDF")
|
||||
evidence_bundle_path: str | None = Field(
|
||||
None, description="Path to evidence bundle ZIP"
|
||||
)
|
||||
checksum_sha256: str = Field(..., description="PDF checksum for integrity")
|
||||
|
||||
|
||||
# HMRC events
|
||||
class HMRCSubmittedEventData(BaseEventData):
|
||||
"""Event emitted when HMRC submission is complete."""
|
||||
|
||||
taxpayer_id: str = Field(..., description="Taxpayer identifier")
|
||||
tax_year: str = Field(..., description="Tax year (e.g., '2024-25')")
|
||||
submission_id: str = Field(..., description="Unique submission identifier")
|
||||
hmrc_reference: str | None = Field(None, description="HMRC submission reference")
|
||||
submission_type: Literal["dry_run", "sandbox", "live"] = Field(
|
||||
..., description="Submission environment type"
|
||||
)
|
||||
success: bool = Field(..., description="Whether submission was successful")
|
||||
status_code: int | None = Field(None, description="HTTP status code")
|
||||
error_message: str | None = Field(None, description="Error message if failed")
|
||||
processing_time_ms: int = Field(
|
||||
..., ge=0, description="Processing time in milliseconds"
|
||||
)
|
||||
|
||||
|
||||
# Review events
|
||||
class ReviewRequestedEventData(BaseEventData):
|
||||
"""Event emitted when human review is requested."""
|
||||
|
||||
doc_id: str = Field(..., description="Document identifier")
|
||||
review_type: Literal["extraction", "calculation", "submission"] = Field(
|
||||
..., description="Type of review needed"
|
||||
)
|
||||
priority: Literal["low", "medium", "high", "urgent"] = Field(
|
||||
..., description="Review priority level"
|
||||
)
|
||||
reason: str = Field(..., description="Reason for review request")
|
||||
assigned_to: str | None = Field(None, description="User assigned to review")
|
||||
due_date: str | None = Field(None, description="Review due date (ISO 8601)")
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Additional review metadata"
|
||||
)
|
||||
|
||||
|
||||
class ReviewCompletedEventData(BaseEventData):
|
||||
"""Event emitted when human review is completed."""
|
||||
|
||||
doc_id: str = Field(..., description="Document identifier")
|
||||
review_id: str = Field(..., description="Review session identifier")
|
||||
reviewer: str = Field(..., description="User who completed review")
|
||||
decision: Literal["approved", "rejected", "needs_revision"] = Field(
|
||||
..., description="Review decision"
|
||||
)
|
||||
changes_made: int = Field(..., ge=0, description="Number of changes made")
|
||||
comments: str | None = Field(None, description="Reviewer comments")
|
||||
review_duration_seconds: int = Field(
|
||||
..., ge=0, description="Time spent in review (seconds)"
|
||||
)
|
||||
|
||||
|
||||
# Firm sync events
|
||||
class FirmSyncCompletedEventData(BaseEventData):
|
||||
"""Event emitted when firm database sync is complete."""
|
||||
|
||||
firm_id: str = Field(..., description="Firm identifier")
|
||||
connector_type: str = Field(
|
||||
..., description="Connector type (iris, sage, xero, etc.)"
|
||||
)
|
||||
sync_id: str = Field(..., description="Unique sync run identifier")
|
||||
records_synced: int = Field(..., ge=0, description="Number of records synced")
|
||||
records_created: int = Field(..., ge=0, description="Records created")
|
||||
records_updated: int = Field(..., ge=0, description="Records updated")
|
||||
records_failed: int = Field(..., ge=0, description="Records that failed to sync")
|
||||
success: bool = Field(..., description="Whether sync was successful")
|
||||
error_message: str | None = Field(None, description="Error message if failed")
|
||||
processing_time_ms: int = Field(
|
||||
..., ge=0, description="Processing time in milliseconds"
|
||||
)
|
||||
|
||||
|
||||
# Schema mapping for topic -> data class
|
||||
EVENT_SCHEMA_MAP: dict[str, type[BaseEventData]] = {
|
||||
"doc.ingested": DocumentIngestedEventData,
|
||||
"doc.ocr_ready": DocumentOCRReadyEventData,
|
||||
"doc.extracted": DocumentExtractedEventData,
|
||||
"kg.upsert.ready": KGUpsertReadyEventData,
|
||||
"kg.upserted": KGUpsertedEventData,
|
||||
"rag.indexed": RAGIndexedEventData,
|
||||
"calc.schedule_ready": CalculationReadyEventData,
|
||||
"form.filled": FormFilledEventData,
|
||||
"hmrc.submitted": HMRCSubmittedEventData,
|
||||
"review.requested": ReviewRequestedEventData,
|
||||
"review.completed": ReviewCompletedEventData,
|
||||
"firm.sync.completed": FirmSyncCompletedEventData,
|
||||
}
|
||||
|
||||
|
||||
def validate_event_data(topic: str, data: dict[str, Any]) -> BaseEventData:
|
||||
"""
|
||||
Validate event data against the schema for the given topic.
|
||||
|
||||
Args:
|
||||
topic: Event topic name
|
||||
data: Raw event data dictionary
|
||||
|
||||
Returns:
|
||||
Validated event data model
|
||||
|
||||
Raises:
|
||||
ValueError: If topic is unknown or validation fails
|
||||
"""
|
||||
if topic not in EVENT_SCHEMA_MAP:
|
||||
raise ValueError(f"Unknown event topic: {topic}")
|
||||
|
||||
schema_class = EVENT_SCHEMA_MAP[topic]
|
||||
return schema_class.model_validate(data)
|
||||
|
||||
|
||||
def get_schema_for_topic(topic: str) -> type[BaseEventData]:
|
||||
"""
|
||||
Get the Pydantic schema class for a given topic.
|
||||
|
||||
Args:
|
||||
topic: Event topic name
|
||||
|
||||
Returns:
|
||||
Schema class for the topic
|
||||
|
||||
Raises:
|
||||
ValueError: If topic is unknown
|
||||
"""
|
||||
if topic not in EVENT_SCHEMA_MAP:
|
||||
raise ValueError(f"Unknown event topic: {topic}")
|
||||
|
||||
return EVENT_SCHEMA_MAP[topic]
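A brief usage sketch of the validators above; the payload values are made up for illustration and are not taken from this diff:

event = validate_event_data(
    "review.requested",
    {
        "doc_id": "01HEXAMPLE",  # illustrative ULID-style id
        "review_type": "extraction",
        "priority": "high",
        "reason": "low calibrated confidence",
    },
)
assert isinstance(event, ReviewRequestedEventData)

# Unknown topics fail fast with ValueError, and payloads carrying unexpected
# keys are rejected because BaseEventData sets extra="forbid".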