ai-tax-agent/libs/events/nats_bus.py

"""NATS.io with JetStream implementation of EventBus."""
import asyncio
import json
from collections.abc import Awaitable, Callable
from typing import Any
import nats
import structlog
from nats.aio.client import Client as NATS
from nats.js import JetStreamContext
from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy
from .base import EventBus, EventPayload
logger = structlog.get_logger()
class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes
"""NATS.io with JetStream implementation of EventBus"""
def __init__(
self,
servers: str | list[str] = "nats://localhost:4222",
stream_name: str = "TAX_AGENT_EVENTS",
consumer_group: str = "tax-agent",
):
if isinstance(servers, str):
self.servers = [servers]
else:
self.servers = servers
self.stream_name = stream_name
self.consumer_group = consumer_group
self.nc: NATS | None = None
self.js: JetStreamContext | None = None
self.handlers: dict[
str, list[Callable[[str, EventPayload], Awaitable[None]]]
] = {}
self.subscriptions: dict[str, Any] = {}
self.running = False
self.consumer_tasks: list[asyncio.Task[None]] = []
async def start(self) -> None:
"""Start NATS connection and JetStream context"""
if self.running:
return
try:
# Connect to NATS
self.nc = await nats.connect(servers=self.servers)
# Get JetStream context
self.js = self.nc.jetstream()
# Ensure stream exists
await self._ensure_stream_exists()
self.running = True
logger.info(
"NATS event bus started",
servers=self.servers,
stream=self.stream_name,
)
except Exception as e:
logger.error("Failed to start NATS event bus", error=str(e))
raise
async def stop(self) -> None:
"""Stop NATS connection and consumers"""
if not self.running:
return
# Cancel consumer tasks
for task in self.consumer_tasks:
task.cancel()
if self.consumer_tasks:
await asyncio.gather(*self.consumer_tasks, return_exceptions=True)
# Unsubscribe from all subscriptions
for subscription in self.subscriptions.values():
try:
await subscription.unsubscribe()
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Error unsubscribing", error=str(e))
# Close NATS connection
if self.nc:
await self.nc.close()
self.running = False
logger.info("NATS event bus stopped")
async def publish(self, topic: str, payload: EventPayload) -> bool:
"""Publish event to NATS JetStream"""
if not self.js:
raise RuntimeError("Event bus not started")
try:
# Create subject name from topic
subject = f"{self.stream_name}.{topic}"
# Publish message with headers
headers = {
"event_id": payload.event_id,
"tenant_id": payload.tenant_id,
"actor": payload.actor,
"trace_id": payload.trace_id or "",
"schema_version": payload.schema_version,
}
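            # js.publish waits for JetStream's PubAck; the returned sequence number
            # is logged below so the stored message can be traced in the stream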
ack = await self.js.publish(
subject=subject,
payload=payload.to_json().encode(),
headers=headers,
)
logger.info(
"Event published",
topic=topic,
subject=subject,
event_id=payload.event_id,
stream_seq=ack.seq,
)
return True
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Failed to publish event",
topic=topic,
event_id=payload.event_id,
error=str(e),
)
return False
async def subscribe(
self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
) -> None:
"""Subscribe to NATS JetStream topic"""
if not self.js:
raise RuntimeError("Event bus not started")
if topic not in self.handlers:
self.handlers[topic] = []
self.handlers[topic].append(handler)
if topic not in self.subscriptions:
try:
# Create subject pattern for topic
subject = f"{self.stream_name}.{topic}"
# Create durable consumer
consumer_name = f"{self.consumer_group}-{topic}"
# Subscribe with pull-based consumer
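                # The durable name keeps the consumer's position on the server, so a
                # restarted subscriber resumes from its last acknowledged message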
subscription = await self.js.pull_subscribe(
subject=subject,
durable=consumer_name,
config=ConsumerConfig(
durable_name=consumer_name,
ack_policy=AckPolicy.EXPLICIT,
deliver_policy=DeliverPolicy.NEW,
max_deliver=3,
ack_wait=30, # 30 seconds
),
)
self.subscriptions[topic] = subscription
# Start consumer task
task = asyncio.create_task(self._consume_messages(topic, subscription))
self.consumer_tasks.append(task)
logger.info(
"Subscribed to topic",
topic=topic,
subject=subject,
consumer=consumer_name,
)
except Exception as e:
logger.error("Failed to subscribe to topic", topic=topic, error=str(e))
raise
async def _ensure_stream_exists(self) -> None:
"""Ensure JetStream stream exists"""
if not self.js:
return
try:
# Try to get stream info
await self.js.stream_info(self.stream_name)
logger.debug("Stream already exists", stream=self.stream_name)
except Exception:
# Stream doesn't exist, create it
try:
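                # "<stream>.*" matches a single token after the stream prefix, so
                # topic names are assumed to contain no dots (">" would match nested subjects)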
await self.js.add_stream(
name=self.stream_name,
subjects=[f"{self.stream_name}.*"],
)
logger.info("Created JetStream stream", stream=self.stream_name)
except Exception as e:
logger.error(
"Failed to create stream", stream=self.stream_name, error=str(e)
)
raise
async def _consume_messages(self, topic: str, subscription: Any) -> None:
"""Consume messages from NATS JetStream subscription"""
while self.running:
try:
# Fetch messages in batches
messages = await subscription.fetch(batch=10, timeout=20)
for message in messages:
try:
# Parse message payload
payload_dict = json.loads(message.data.decode())
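                        # Rebuild the EventPayload, carrying over the producer's
                        # original event_id and occurred_at from the message body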
payload = EventPayload(
data=payload_dict["data"],
actor=payload_dict["actor"],
tenant_id=payload_dict["tenant_id"],
trace_id=payload_dict.get("trace_id"),
schema_version=payload_dict.get("schema_version", "1.0"),
)
payload.event_id = payload_dict["event_id"]
payload.occurred_at = payload_dict["occurred_at"]
# Call all handlers for this topic
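                        # A failing handler is logged but does not block the ack
                        # below, so the event is not redelivered to that handler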
for handler in self.handlers.get(topic, []):
try:
await handler(topic, payload)
except (
Exception
) as e: # pylint: disable=broad-exception-caught
logger.error(
"Handler failed",
topic=topic,
event_id=payload.event_id,
error=str(e),
)
# Acknowledge message
await message.ack()
except json.JSONDecodeError as e:
logger.error(
"Failed to decode message", topic=topic, error=str(e)
)
await message.nak()
except Exception as e: # pylint: disable=broad-exception-caught
logger.error(
"Failed to process message", topic=topic, error=str(e)
)
await message.nak()
            except asyncio.TimeoutError:
                # No messages within the fetch timeout (the nats client's TimeoutError
                # derives from asyncio.TimeoutError); keep polling
                continue
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Consumer error", topic=topic, error=str(e))
await asyncio.sleep(5) # Wait before retrying
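

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal publish/subscribe round trip with
# NATSEventBus. It assumes a local NATS server with JetStream enabled at
# nats://localhost:4222; the topic name "document-ingested" and the payload
# contents are made up for the example, and EventPayload is constructed with
# the same keyword arguments this module already uses above.
# ---------------------------------------------------------------------------
async def _demo() -> None:
    bus = NATSEventBus(servers="nats://localhost:4222")
    await bus.start()

    async def on_event(topic: str, payload: EventPayload) -> None:
        logger.info("Received event", topic=topic, event_id=payload.event_id)

    await bus.subscribe("document-ingested", on_event)
    await bus.publish(
        "document-ingested",
        EventPayload(
            data={"document_id": "doc-123"},
            actor="svc-ingestion",
            tenant_id="tenant-1",
            trace_id=None,
            schema_version="1.0",
        ),
    )
    await asyncio.sleep(5)  # give the pull consumer a moment to fetch and ack
    await bus.stop()


if __name__ == "__main__":
    asyncio.run(_demo())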