ai-tax-agent/libs/events/nats_bus.py

"""NATS.io with JetStream implementation of EventBus."""
import asyncio
import json
import time
from collections.abc import Awaitable, Callable
from typing import Any
import nats
import structlog
from nats.aio.client import Client as NATS
from nats.js import JetStreamContext
from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy
from .base import EventBus, EventPayload
from .dlq import DLQHandler
from .metrics import EventMetricsCollector
logger = structlog.get_logger()
class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
"""NATS.io with JetStream implementation of EventBus"""
def __init__(
self,
servers: str | list[str] = "nats://localhost:4222",
stream_name: str = "TAX_AGENT_EVENTS",
consumer_group: str = "tax-agent",
dlq_stream_name: str = "TAX_AGENT_DLQ",
max_retries: int = 3,
):
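        """Initialize the bus.

        Args:
            servers: NATS server URL or list of URLs.
            stream_name: JetStream stream carrying live events; also used as
                the subject prefix (``<stream_name>.<topic>``).
            consumer_group: Prefix for durable consumer names.
            dlq_stream_name: Stream that receives dead-lettered events.
            max_retries: Delivery attempts allowed before an event is routed
                to the DLQ.
        """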
if isinstance(servers, str):
self.servers = [servers]
else:
self.servers = servers
self.stream_name = stream_name
self.consumer_group = consumer_group
self.dlq_stream_name = dlq_stream_name
self.max_retries = max_retries
self.nc: NATS | None = None
self.js: JetStreamContext | None = None
self.dlq: DLQHandler | None = None
self.handlers: dict[
str, list[Callable[[str, EventPayload], Awaitable[None]]]
] = {}
self.subscriptions: dict[str, Any] = {}
self.running = False
        self.consumer_tasks: list[asyncio.Task[None]] = []

async def start(self) -> None:
"""Start NATS connection and JetStream context"""
if self.running:
return
try:
# Connect to NATS
self.nc = await nats.connect(
servers=self.servers,
connect_timeout=10,
reconnect_time_wait=1,
)
# Get JetStream context
self.js = self.nc.jetstream(timeout=10)
# Initialize DLQ handler
self.dlq = DLQHandler(
js=self.js,
dlq_stream_name=self.dlq_stream_name,
max_retries=self.max_retries,
)
# Ensure streams exist
await self._ensure_stream_exists()
await self.dlq.ensure_dlq_stream_exists()
self.running = True
logger.info(
"NATS event bus started",
servers=self.servers,
stream=self.stream_name,
dlq_stream=self.dlq_stream_name,
)
except Exception as e:
logger.error("Failed to start NATS event bus", error=str(e))
            raise

async def stop(self) -> None:
"""Stop NATS connection and consumers"""
if not self.running:
return
# Cancel consumer tasks
for task in self.consumer_tasks:
task.cancel()
if self.consumer_tasks:
await asyncio.gather(*self.consumer_tasks, return_exceptions=True)
# Unsubscribe from all subscriptions
for subscription in self.subscriptions.values():
try:
await subscription.unsubscribe()
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Error unsubscribing", error=str(e))
# Close NATS connection
if self.nc:
await self.nc.close()
self.running = False
logger.info("NATS event bus stopped")
async def publish(self, topic: str, payload: EventPayload) -> bool:
"""Publish event to NATS JetStream"""
if not self.js:
raise RuntimeError("Event bus not started")
start_time = time.perf_counter()
try:
# Create subject name from topic
subject = f"{self.stream_name}.{topic}"
# Publish message with headers
headers = {
"event_id": payload.event_id,
"tenant_id": payload.tenant_id,
"actor": payload.actor,
"trace_id": payload.trace_id or "",
"schema_version": payload.schema_version,
}
ack = await self.js.publish(
subject=subject,
payload=payload.to_json().encode(),
headers=headers,
)
duration = time.perf_counter() - start_time
EventMetricsCollector.record_publish(
topic=topic,
duration_seconds=duration,
success=True,
)
logger.info(
"Event published",
topic=topic,
subject=subject,
event_id=payload.event_id,
stream_seq=ack.seq,
)
return True
except Exception as e: # pylint: disable=broad-exception-caught
duration = time.perf_counter() - start_time
EventMetricsCollector.record_publish(
topic=topic,
duration_seconds=duration,
success=False,
error_type=type(e).__name__,
)
logger.error(
"Failed to publish event",
topic=topic,
event_id=payload.event_id,
error=str(e),
)
            return False

async def subscribe(
self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
) -> None:
"""Subscribe to NATS JetStream topic"""
if not self.js:
raise RuntimeError("Event bus not started")
if topic not in self.handlers:
self.handlers[topic] = []
self.handlers[topic].append(handler)
if topic not in self.subscriptions:
try:
# Create subject pattern for topic
subject = f"{self.stream_name}.{topic}"
# Create durable consumer
# Durable names cannot contain dots, so we replace them
safe_topic = topic.replace(".", "-")
consumer_name = f"{self.consumer_group}-{safe_topic}"
# Subscribe with pull-based consumer
# Set max_deliver to max_retries + 1 (initial + retries)
# We handle DLQ manually before NATS gives up
subscription = await self.js.pull_subscribe(
subject=subject,
durable=consumer_name,
config=ConsumerConfig(
durable_name=consumer_name,
ack_policy=AckPolicy.EXPLICIT,
deliver_policy=DeliverPolicy.NEW,
max_deliver=self.max_retries + 2, # Give us room to handle DLQ
ack_wait=30, # 30 seconds
),
)
self.subscriptions[topic] = subscription
# Start consumer task
task = asyncio.create_task(self._consume_messages(topic, subscription))
self.consumer_tasks.append(task)
logger.info(
"Subscribed to topic",
topic=topic,
subject=subject,
consumer=consumer_name,
)
except Exception as e:
logger.error("Failed to subscribe to topic", topic=topic, error=str(e))
                raise

async def _ensure_stream_exists(self) -> None:
"""Ensure JetStream stream exists"""
if not self.js:
return
try:
# Try to get stream info
await self.js.stream_info(self.stream_name)
logger.debug("Stream already exists", stream=self.stream_name)
EventMetricsCollector.record_nats_stream_message(self.stream_name)
except Exception:
# Stream doesn't exist, create it
try:
await self.js.add_stream(
name=self.stream_name,
subjects=[f"{self.stream_name}.>"],
)
logger.info("Created JetStream stream", stream=self.stream_name)
except Exception as e:
logger.error(
"Failed to create stream", stream=self.stream_name, error=str(e)
)
raise
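
    # Consumption flow: each fetched message is decoded, dispatched to every
    # handler registered for the topic, and acked on success. On failure the
    # message is nak'd for redelivery until its delivery count reaches
    # max_retries, at which point it is copied to the DLQ stream and acked
    # so it leaves the main stream.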
async def _consume_messages(self, topic: str, subscription: Any) -> None:
"""Consume messages from NATS JetStream subscription"""
while self.running:
try:
# Fetch messages in batches
messages = await subscription.fetch(batch=10, timeout=5)
                for message in messages:
                    start_time = time.perf_counter()
                    payload = None
                    try:
                        # Parse and reconstruct the event payload
                        payload_dict = json.loads(message.data.decode())
                        payload = EventPayload(
                            data=payload_dict["data"],
                            actor=payload_dict["actor"],
                            tenant_id=payload_dict["tenant_id"],
                            trace_id=payload_dict.get("trace_id"),
                            schema_version=payload_dict.get("schema_version", "1.0"),
                        )
                        payload.event_id = payload_dict["event_id"]
                        payload.occurred_at = payload_dict["occurred_at"]
                        # Call all handlers for this topic
                        for handler in self.handlers.get(topic, []):
                            await handler(topic, payload)
                        # Acknowledge message
                        await message.ack()
                        # Record metrics
                        duration = time.perf_counter() - start_time
                        EventMetricsCollector.record_consume(
                            topic=topic,
                            consumer_group=self.consumer_group,
                            duration_seconds=duration,
                            success=True,
                        )
except Exception as e: # pylint: disable=broad-exception-caught
duration = time.perf_counter() - start_time
error_type = type(e).__name__
# Record failure metric
EventMetricsCollector.record_consume(
topic=topic,
consumer_group=self.consumer_group,
duration_seconds=duration,
success=False,
error_type=error_type,
)
                        # Decide between retry and DLQ based on how many
                        # times NATS has delivered this message; nats-py
                        # exposes this as Msg.Metadata.num_delivered.
                        try:
                            num_delivered = message.metadata.num_delivered
                        except Exception:  # pylint: disable=broad-exception-caught
                            num_delivered = 1
if num_delivered >= self.max_retries:
logger.error(
"Max retries exceeded, sending to DLQ",
topic=topic,
event_id=payload.event_id if payload else "unknown",
error=str(e),
num_delivered=num_delivered,
)
if self.dlq and payload:
await self.dlq.send_to_dlq(
topic=topic,
payload=payload,
error=e,
retry_count=num_delivered,
original_message_data=message.data,
)
EventMetricsCollector.record_dlq(topic, error_type)
# Ack to remove from main stream
await message.ack()
else:
# Retry (Nak)
logger.warning(
"Processing failed, retrying",
topic=topic,
event_id=payload.event_id if payload else "unknown",
error=str(e),
attempt=num_delivered,
)
EventMetricsCollector.record_retry(topic, num_delivered)
await message.nak()
            except asyncio.TimeoutError:
                # No messages available; nats-py raises nats.errors.TimeoutError
                # (an asyncio.TimeoutError subclass), so keep polling.
continue
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Consumer error", topic=topic, error=str(e))
await asyncio.sleep(1) # Wait before retrying
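

# Minimal usage sketch (illustrative only): assumes a local NATS server with
# JetStream enabled at nats://localhost:4222, and that EventPayload accepts
# the keyword arguments used in _consume_messages above (data, actor,
# tenant_id, with trace_id/schema_version optional). The topic name is
# hypothetical.
if __name__ == "__main__":

    async def _demo() -> None:
        bus = NATSEventBus(servers="nats://localhost:4222")
        await bus.start()

        async def on_event(topic: str, payload: EventPayload) -> None:
            logger.info("Handled event", topic=topic, event_id=payload.event_id)

        await bus.subscribe("documents.ingested", on_event)
        await bus.publish(
            "documents.ingested",
            EventPayload(
                data={"document_id": "doc-123"},
                actor="svc-ingestion",
                tenant_id="tenant-1",
            ),
        )
        await asyncio.sleep(2)  # Give the pull consumer time to fetch and ack
        await bus.stop()

    asyncio.run(_demo())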