Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
268 lines
9.2 KiB
Python
268 lines
9.2 KiB
Python
"""NATS.io with JetStream implementation of EventBus."""
|
|
|
|
import asyncio
|
|
import json
|
|
from collections.abc import Awaitable, Callable
|
|
from typing import Any
|
|
|
|
import nats
|
|
import structlog
|
|
from nats.aio.client import Client as NATS
|
|
from nats.js import JetStreamContext
|
|
from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy
|
|
|
|
from .base import EventBus, EventPayload
|
|
|
|
# Module-level structlog logger used for every event-bus log call in this file.
logger = structlog.get_logger()
|
|
|
|
|
|
class NATSEventBus(EventBus):  # pylint: disable=too-many-instance-attributes
    """NATS.io with JetStream implementation of EventBus.

    Events are published to ``{stream_name}.{topic}`` subjects on a single
    JetStream stream. Each subscribed topic gets one durable pull consumer and
    one background task that fetches messages and dispatches them to every
    handler registered for that topic.

    Args:
        servers: A NATS server URL or a list of URLs.
        stream_name: JetStream stream name; also used as the subject prefix.
        consumer_group: Prefix for the durable consumer names created by
            ``subscribe()``.
    """

    def __init__(
        self,
        servers: str | list[str] = "nats://localhost:4222",
        stream_name: str = "TAX_AGENT_EVENTS",
        consumer_group: str = "tax-agent",
    ):
        self.servers: list[str] = [servers] if isinstance(servers, str) else servers
        self.stream_name = stream_name
        self.consumer_group = consumer_group
        self.nc: NATS | None = None
        self.js: JetStreamContext | None = None
        # topic -> handlers invoked, in registration order, for each message.
        self.handlers: dict[
            str, list[Callable[[str, EventPayload], Awaitable[None]]]
        ] = {}
        # topic -> JetStream pull subscription backing that topic.
        self.subscriptions: dict[str, Any] = {}
        self.running = False
        self.consumer_tasks: list[asyncio.Task[None]] = []

    async def start(self) -> None:
        """Connect to NATS, obtain a JetStream context and ensure the stream exists.

        Calling this while already running is a no-op.

        Raises:
            Exception: Whatever the NATS client raises while connecting or
                creating the stream. The connection is closed before
                re-raising, so a failed start does not leak it.
        """
        if self.running:
            return

        try:
            self.nc = await nats.connect(servers=self.servers)
            self.js = self.nc.jetstream()
            await self._ensure_stream_exists()
        except Exception as e:
            logger.error("Failed to start NATS event bus", error=str(e))
            # Don't leave a half-initialised bus behind: a retried start()
            # would otherwise open a second connection and leak this one.
            await self._close_connection()
            raise

        self.running = True
        logger.info(
            "NATS event bus started",
            servers=self.servers,
            stream=self.stream_name,
        )

    async def stop(self) -> None:
        """Stop consumer tasks, unsubscribe and close the NATS connection.

        Calling this while not running is a no-op. All per-run state
        (consumer tasks, subscriptions, handlers, connection) is reset, so a
        later ``start()`` followed by ``subscribe()`` creates fresh consumers.
        Unsubscribe and connection-close errors are logged, not raised.
        """
        if not self.running:
            return

        # Clear the flag first so consumer loops exit on their next check.
        self.running = False

        for task in self.consumer_tasks:
            task.cancel()
        if self.consumer_tasks:
            await asyncio.gather(*self.consumer_tasks, return_exceptions=True)
        self.consumer_tasks = []

        for subscription in self.subscriptions.values():
            try:
                await subscription.unsubscribe()
            except Exception as e:  # pylint: disable=broad-exception-caught
                logger.warning("Error unsubscribing", error=str(e))
        # A stale entry here would make subscribe() skip consumer creation
        # after a restart.
        self.subscriptions = {}
        # Handlers belong to the subscriptions just torn down; keeping them
        # would register duplicates when callers subscribe again.
        self.handlers = {}

        await self._close_connection()
        logger.info("NATS event bus stopped")

    async def _close_connection(self) -> None:
        """Close the NATS connection, if any, and forget the JetStream context.

        Close errors are logged rather than raised, so cleanup never masks the
        error that triggered it.
        """
        nc = self.nc
        self.nc = None
        self.js = None
        if nc is None:
            return
        try:
            await nc.close()
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning("Error closing NATS connection", error=str(e))

    async def publish(self, topic: str, payload: EventPayload) -> bool:
        """Publish *payload* to ``{stream_name}.{topic}`` on JetStream.

        The JSON body is sent together with headers carrying the event id,
        tenant id, actor, trace id (empty string when unset) and schema version.

        Returns:
            True once JetStream acknowledges the message; False if publishing
            failed (the error is logged).

        Raises:
            RuntimeError: If the bus has not been started (or has been stopped).
        """
        if not self.js:
            raise RuntimeError("Event bus not started")

        try:
            subject = f"{self.stream_name}.{topic}"
            headers = {
                "event_id": payload.event_id,
                "tenant_id": payload.tenant_id,
                "actor": payload.actor,
                "trace_id": payload.trace_id or "",
                "schema_version": payload.schema_version,
            }

            ack = await self.js.publish(
                subject=subject,
                payload=payload.to_json().encode(),
                headers=headers,
            )

            logger.info(
                "Event published",
                topic=topic,
                subject=subject,
                event_id=payload.event_id,
                stream_seq=ack.seq,
            )
            return True

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Failed to publish event",
                topic=topic,
                event_id=payload.event_id,
                error=str(e),
            )
            return False

    async def subscribe(
        self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
    ) -> None:
        """Register *handler* for *topic*, creating the consumer on first use.

        The first subscription to a topic creates a durable pull consumer
        (explicit acks, new messages only, at most 3 deliveries, 30 s ack wait)
        and starts a background task that dispatches its messages to the
        topic's handlers. Later subscriptions to the same topic only add another
        handler.

        Raises:
            RuntimeError: If the bus has not been started.
            Exception: Whatever the NATS client raises when the consumer cannot
                be created; the handler is not registered in that case.
        """
        if not self.js:
            raise RuntimeError("Event bus not started")

        if topic in self.subscriptions:
            # The existing consumer task already dispatches to every handler.
            self.handlers.setdefault(topic, []).append(handler)
            return

        subject = f"{self.stream_name}.{topic}"
        consumer_name = self._consumer_name(topic)

        try:
            subscription = await self.js.pull_subscribe(
                subject=subject,
                durable=consumer_name,
                config=ConsumerConfig(
                    durable_name=consumer_name,
                    ack_policy=AckPolicy.EXPLICIT,
                    deliver_policy=DeliverPolicy.NEW,
                    max_deliver=3,
                    ack_wait=30,  # 30 seconds
                ),
            )
        except Exception as e:
            logger.error("Failed to subscribe to topic", topic=topic, error=str(e))
            raise

        # Register only after the consumer exists, so a failed attempt that the
        # caller retries does not leave a duplicate handler behind.
        self.handlers.setdefault(topic, []).append(handler)
        self.subscriptions[topic] = subscription
        self.consumer_tasks.append(
            asyncio.create_task(self._consume_messages(topic, subscription))
        )

        logger.info(
            "Subscribed to topic",
            topic=topic,
            subject=subject,
            consumer=consumer_name,
        )

    def _consumer_name(self, topic: str) -> str:
        """Return the durable consumer name for *topic*.

        NATS does not accept '.', '*', '>' or whitespace in consumer names, so
        those characters are replaced with '-'. Topics without them keep the
        ``{consumer_group}-{topic}`` name, so existing durable consumers are
        reused unchanged.
        """
        safe_topic = "".join(
            "-" if ch in ".*>" or ch.isspace() else ch for ch in topic
        )
        return f"{self.consumer_group}-{safe_topic}"

    async def _ensure_stream_exists(self) -> None:
        """Create the JetStream stream when looking it up fails.

        An already-existing stream is left untouched, including its subject
        list.

        Raises:
            Exception: Whatever the NATS client raises when creating the stream.
        """
        if not self.js:
            return

        try:
            await self.js.stream_info(self.stream_name)
            logger.debug("Stream already exists", stream=self.stream_name)

        except Exception:  # pylint: disable=broad-exception-caught
            # Any lookup failure is treated as "stream missing". If the real
            # cause was something else, add_stream fails and that error is
            # raised instead.
            try:
                await self.js.add_stream(
                    name=self.stream_name,
                    # '>' matches one or more tokens ('*' matches exactly one),
                    # so multi-token topics such as "a.b" are captured too.
                    subjects=[f"{self.stream_name}.>"],
                )
                logger.info("Created JetStream stream", stream=self.stream_name)

            except Exception as e:
                logger.error(
                    "Failed to create stream", stream=self.stream_name, error=str(e)
                )
                raise

    async def _consume_messages(self, topic: str, subscription: Any) -> None:
        """Fetch messages for *topic* in batches of 10 until the bus stops.

        A fetch that times out (no messages within 20 s) simply polls again.
        Any other error is logged and the loop sleeps 5 s before retrying; the
        rest of the batch being processed at that point is left unacknowledged.
        """
        while self.running:
            try:
                messages = await subscription.fetch(batch=10, timeout=20)
                for message in messages:
                    await self._process_message(topic, message)

            # NOTE(review): assumes Python 3.11+, where asyncio.TimeoutError
            # (raised by fetch) is the builtin TimeoutError; confirm target.
            except TimeoutError:
                # No messages available, continue polling
                continue

            except Exception as e:  # pylint: disable=broad-exception-caught
                logger.error("Consumer error", topic=topic, error=str(e))
                await asyncio.sleep(5)  # Wait before retrying

    async def _process_message(self, topic: str, message: Any) -> None:
        """Decode one message, run every handler for *topic*, then settle it.

        Outcomes:
            * Undecodable messages (invalid UTF-8 or JSON, a non-object body,
              or missing required fields) are terminated, because redelivering
              them cannot succeed.
            * Handler exceptions are logged and do not stop the remaining
              handlers. The message is still acknowledged afterwards, so a
              failing handler does not cause redelivery.
            * If acknowledging fails, the message is nak'd.
        """
        try:
            payload = self._decode_payload(message.data)
        except (UnicodeDecodeError, json.JSONDecodeError, KeyError, TypeError) as e:
            logger.error("Failed to decode message", topic=topic, error=str(e))
            await message.term()
            return

        try:
            for handler in self.handlers.get(topic, []):
                try:
                    await handler(topic, payload)
                except Exception as e:  # pylint: disable=broad-exception-caught
                    logger.error(
                        "Handler failed",
                        topic=topic,
                        event_id=payload.event_id,
                        error=str(e),
                    )

            await message.ack()

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error("Failed to process message", topic=topic, error=str(e))
            await message.nak()

    @staticmethod
    def _decode_payload(data: bytes) -> EventPayload:
        """Rebuild an EventPayload from a message body produced by ``publish()``.

        Raises:
            UnicodeDecodeError: If *data* is not valid UTF-8.
            json.JSONDecodeError: If *data* is not valid JSON.
            KeyError: If a required field is missing.
            TypeError: If the JSON body is not an object.
        """
        payload_dict = json.loads(data.decode())

        payload = EventPayload(
            data=payload_dict["data"],
            actor=payload_dict["actor"],
            tenant_id=payload_dict["tenant_id"],
            trace_id=payload_dict.get("trace_id"),
            schema_version=payload_dict.get("schema_version", "1.0"),
        )
        # Keep the publisher's id and timestamp rather than whatever the
        # constructor assigned.
        payload.event_id = payload_dict["event_id"]
        payload.occurred_at = payload_dict["occurred_at"]
        return payload