"""NATS.io with JetStream implementation of EventBus."""

import asyncio
import json
from collections.abc import Awaitable, Callable
from typing import Any

import nats
import structlog
from nats.aio.client import Client as NATS
from nats.js import JetStreamContext
from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy
from nats.js.errors import NotFoundError

from .base import EventBus, EventPayload

logger = structlog.get_logger()


class NATSEventBus(EventBus):  # pylint: disable=too-many-instance-attributes
    """EventBus backed by a NATS JetStream stream.

    Events are published to subjects of the form ``<stream_name>.<topic>``.
    Each subscribed topic gets one durable pull consumer and one background
    task that fetches messages and dispatches them to the registered handlers.
    """

    # Characters that are not allowed in a JetStream durable consumer name.
    _DURABLE_NAME_TRANSLATION = str.maketrans(
        {".": "_", "*": "_", ">": "_", " ": "_"}
    )

    def __init__(
        self,
        servers: str | list[str] = "nats://localhost:4222",
        stream_name: str = "TAX_AGENT_EVENTS",
        consumer_group: str = "tax-agent",
    ):
        """Store configuration; no network activity happens until ``start``.

        Args:
            servers: A single NATS URL or a list of URLs.
            stream_name: JetStream stream name, also used as subject prefix.
            consumer_group: Prefix for the durable consumer names.
        """
        # Normalise to a list so ``nats.connect`` always gets the same shape.
        self.servers = [servers] if isinstance(servers, str) else servers
        self.stream_name = stream_name
        self.consumer_group = consumer_group
        self.nc: NATS | None = None
        self.js: JetStreamContext | None = None
        self.handlers: dict[
            str, list[Callable[[str, EventPayload], Awaitable[None]]]
        ] = {}
        self.subscriptions: dict[str, Any] = {}
        self.running = False
        self.consumer_tasks: list[asyncio.Task[None]] = []

    async def start(self) -> None:
        """Connect to NATS, obtain a JetStream context and ensure the stream.

        Calling ``start`` on an already running bus is a no-op.

        Raises:
            Exception: Whatever connecting or stream setup raised. The
                connection opened during the failed attempt is closed first.
        """
        if self.running:
            return

        try:
            self.nc = await nats.connect(servers=self.servers)
            self.js = self.nc.jetstream()
            await self._ensure_stream_exists()
        except Exception as e:
            logger.error("Failed to start NATS event bus", error=str(e))
            # Do not leak a half-initialised connection when setup fails
            # after ``nats.connect`` succeeded.
            if self.nc is not None:
                try:
                    await self.nc.close()
                except Exception as close_error:  # pylint: disable=broad-exception-caught
                    logger.warning(
                        "Error closing NATS connection",
                        error=str(close_error),
                    )
            self.nc = None
            self.js = None
            raise

        self.running = True
        logger.info(
            "NATS event bus started",
            servers=self.servers,
            stream=self.stream_name,
        )

    async def stop(self) -> None:
        """Cancel consumers, drop subscriptions and close the connection.

        Calling ``stop`` on a bus that is not running is a no-op. All
        registered handlers and subscriptions are forgotten, so callers must
        ``subscribe`` again after a later ``start``.
        """
        if not self.running:
            return

        # Flip the flag first so consumer loops do not start another fetch.
        self.running = False

        for task in self.consumer_tasks:
            task.cancel()
        if self.consumer_tasks:
            await asyncio.gather(*self.consumer_tasks, return_exceptions=True)
        self.consumer_tasks.clear()

        for subscription in self.subscriptions.values():
            try:
                await subscription.unsubscribe()
            except Exception as e:  # pylint: disable=broad-exception-caught
                logger.warning("Error unsubscribing", error=str(e))

        # Stale entries would make ``subscribe`` skip creating a consumer
        # after a restart, so reset the bookkeeping completely.
        self.subscriptions.clear()
        self.handlers.clear()

        if self.nc:
            await self.nc.close()

        logger.info("NATS event bus stopped")

    async def publish(self, topic: str, payload: EventPayload) -> bool:
        """Publish ``payload`` to the subject ``<stream_name>.<topic>``.

        Args:
            topic: Topic name appended to the stream name to form the subject.
            payload: Event to publish; its JSON form is the message body.

        Returns:
            True when JetStream acknowledged the message, False when
            publishing failed (the error is logged, not raised).

        Raises:
            RuntimeError: If the bus has no JetStream context yet.
        """
        if not self.js:
            raise RuntimeError("Event bus not started")

        try:
            subject = f"{self.stream_name}.{topic}"

            # Event metadata is duplicated into headers alongside the body.
            headers = {
                "event_id": payload.event_id,
                "tenant_id": payload.tenant_id,
                "actor": payload.actor,
                "trace_id": payload.trace_id or "",
                "schema_version": payload.schema_version,
            }

            ack = await self.js.publish(
                subject=subject,
                payload=payload.to_json().encode(),
                headers=headers,
            )

            logger.info(
                "Event published",
                topic=topic,
                subject=subject,
                event_id=payload.event_id,
                stream_seq=ack.seq,
            )
            return True

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Failed to publish event",
                topic=topic,
                event_id=payload.event_id,
                error=str(e),
            )
            return False

    async def subscribe(
        self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
    ) -> None:
        """Register ``handler`` for ``topic`` and start consuming if needed.

        The first subscription to a topic creates a durable pull consumer and
        a background task; later calls only add another handler.

        Args:
            topic: Topic name (subject suffix below the stream name).
            handler: Coroutine function called as ``handler(topic, payload)``.

        Raises:
            RuntimeError: If the bus has no JetStream context yet.
            Exception: Whatever creating the pull subscription raised; the
                handler is unregistered again in that case.
        """
        if not self.js:
            raise RuntimeError("Event bus not started")

        topic_handlers = self.handlers.setdefault(topic, [])
        topic_handlers.append(handler)

        if topic in self.subscriptions:
            return

        try:
            subject = f"{self.stream_name}.{topic}"
            consumer_name = self._durable_name(topic)

            subscription = await self.js.pull_subscribe(
                subject=subject,
                durable=consumer_name,
                config=ConsumerConfig(
                    durable_name=consumer_name,
                    ack_policy=AckPolicy.EXPLICIT,
                    deliver_policy=DeliverPolicy.NEW,
                    max_deliver=3,
                    ack_wait=30,  # 30 seconds
                ),
            )

            self.subscriptions[topic] = subscription

            task = asyncio.create_task(self._consume_messages(topic, subscription))
            self.consumer_tasks.append(task)

            logger.info(
                "Subscribed to topic",
                topic=topic,
                subject=subject,
                consumer=consumer_name,
            )

        except Exception as e:
            # Roll back the registration so a retry does not end up with the
            # same handler registered twice.
            topic_handlers.remove(handler)
            if not topic_handlers:
                self.handlers.pop(topic, None)
            logger.error("Failed to subscribe to topic", topic=topic, error=str(e))
            raise

    def _durable_name(self, topic: str) -> str:
        """Return the durable consumer name for ``topic``.

        Characters that JetStream rejects in consumer names (``.``, ``*``,
        ``>`` and spaces) are replaced by ``_``; names for topics without
        such characters are unchanged.
        """
        raw_name = f"{self.consumer_group}-{topic}"
        return raw_name.translate(self._DURABLE_NAME_TRANSLATION)

    async def _ensure_stream_exists(self) -> None:
        """Create the JetStream stream when it does not exist yet.

        Does nothing when no JetStream context is available.

        Raises:
            Exception: If looking up the stream fails for a reason other
                than "not found", or if creating the stream fails.
        """
        if not self.js:
            return

        try:
            await self.js.stream_info(self.stream_name)
        except NotFoundError:
            # Only a missing stream should trigger creation; other errors
            # (timeouts, permissions) propagate to the caller.
            pass
        else:
            logger.debug("Stream already exists", stream=self.stream_name)
            return

        try:
            await self.js.add_stream(
                name=self.stream_name,
                # ">" matches one or more tokens, so dotted topic names
                # such as "doc.ingested" are captured by the stream too.
                subjects=[f"{self.stream_name}.>"],
            )
            logger.info("Created JetStream stream", stream=self.stream_name)
        except Exception as e:
            logger.error(
                "Failed to create stream", stream=self.stream_name, error=str(e)
            )
            raise

    async def _consume_messages(self, topic: str, subscription: Any) -> None:
        """Poll ``subscription`` and dispatch messages while the bus runs.

        Args:
            topic: Topic whose handlers receive the fetched messages.
            subscription: Pull subscription returned by ``pull_subscribe``.
        """
        while self.running:
            try:
                messages = await subscription.fetch(batch=10, timeout=20)
                for message in messages:
                    await self._process_message(topic, message)

            except asyncio.TimeoutError:
                # No messages within the fetch timeout; poll again. Catching
                # asyncio.TimeoutError (not the builtin) also works before
                # Python 3.11, where the two are distinct classes.
                continue
            except Exception as e:  # pylint: disable=broad-exception-caught
                logger.error("Consumer error", topic=topic, error=str(e))
                await asyncio.sleep(5)  # Wait before retrying

    async def _process_message(self, topic: str, message: Any) -> None:
        """Decode one message, run the topic's handlers, then ack or nak it.

        A failing handler is logged and does not prevent the remaining
        handlers from running or the message from being acknowledged.
        Messages that cannot be decoded or rebuilt are negatively
        acknowledged.
        """
        try:
            payload_dict = json.loads(message.data.decode())
            payload = EventPayload(
                data=payload_dict["data"],
                actor=payload_dict["actor"],
                tenant_id=payload_dict["tenant_id"],
                trace_id=payload_dict.get("trace_id"),
                schema_version=payload_dict.get("schema_version", "1.0"),
            )
            # Restore the identity of the original event instead of keeping
            # whatever the EventPayload constructor assigned.
            payload.event_id = payload_dict["event_id"]
            payload.occurred_at = payload_dict["occurred_at"]

            for handler in self.handlers.get(topic, []):
                try:
                    await handler(topic, payload)
                except Exception as e:  # pylint: disable=broad-exception-caught
                    logger.error(
                        "Handler failed",
                        topic=topic,
                        event_id=payload.event_id,
                        error=str(e),
                    )

            await message.ack()

        except json.JSONDecodeError as e:
            logger.error("Failed to decode message", topic=topic, error=str(e))
            await message.nak()
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error("Failed to process message", topic=topic, error=str(e))
            await message.nak()