"""NATS.io with JetStream implementation of EventBus."""

import asyncio
import json
from collections.abc import Awaitable, Callable
from typing import Any

import nats  # type: ignore
import structlog
from nats.aio.client import Client as NATS  # type: ignore
from nats.js import JetStreamContext  # type: ignore

from .base import EventBus, EventPayload

logger = structlog.get_logger()


class NATSEventBus(EventBus):  # pylint: disable=too-many-instance-attributes
    """NATS.io with JetStream implementation of EventBus.

    Events for ``topic`` are published to the subject
    ``f"{stream_name}.{topic}"`` inside the JetStream stream ``stream_name``.
    Each subscribed topic gets one durable pull consumer named
    ``f"{consumer_group}-{topic}"`` and one background task that fetches
    messages and fans each one out to every handler registered for the topic.
    """

    def __init__(
        self,
        servers: str | list[str] = "nats://localhost:4222",
        stream_name: str = "TAX_AGENT_EVENTS",
        consumer_group: str = "tax-agent",
    ):
        """Configure the bus; no network I/O happens until :meth:`start`.

        Args:
            servers: A single NATS server URL or a list of URLs.
            stream_name: JetStream stream name; also used as the subject prefix.
            consumer_group: Prefix for the durable consumer names.
        """
        # Copy a caller-supplied list so later mutation of it does not leak in.
        self.servers: list[str] = (
            [servers] if isinstance(servers, str) else list(servers)
        )
        self.stream_name = stream_name
        self.consumer_group = consumer_group
        self.nc: NATS | None = None
        self.js: JetStreamContext | None = None
        # topic -> handlers called, in registration order, for every message.
        self.handlers: dict[
            str, list[Callable[[str, EventPayload], Awaitable[None]]]
        ] = {}
        # topic -> pull subscription object returned by ``js.pull_subscribe``.
        self.subscriptions: dict[str, Any] = {}
        self.running = False
        self.consumer_tasks: list[asyncio.Task[None]] = []

    async def start(self) -> None:
        """Connect to NATS, obtain a JetStream context and ensure the stream exists.

        Calling this while already running is a no-op.

        Raises:
            Exception: Whatever ``nats.connect`` or stream setup raises. Any
                connection opened so far is closed first, so a failed start
                leaves the bus fully stopped.
        """
        if self.running:
            return

        try:
            self.nc = await nats.connect(servers=self.servers)
            self.js = self.nc.jetstream()
            await self._ensure_stream_exists()
        except Exception as e:
            logger.error("Failed to start NATS event bus", error=str(e))
            # running is still False here, so stop() would not clean up; close
            # the connection now instead of leaking it.
            await self._discard_connection()
            raise

        self.running = True
        logger.info(
            "NATS event bus started",
            servers=self.servers,
            stream=self.stream_name,
        )

    async def _discard_connection(self) -> None:
        """Best-effort close of the current connection; resets ``nc`` and ``js``.

        Close errors are logged, not raised, because this runs while another
        exception is already being propagated.
        """
        nc = self.nc
        self.nc = None
        self.js = None
        if nc is None:
            return
        try:
            await nc.close()
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning("Error closing NATS connection", error=str(e))

    async def stop(self) -> None:
        """Cancel consumer tasks, unsubscribe, and close the NATS connection.

        Calling this while not running is a no-op. Afterwards all per-run state
        (connection, JetStream context, subscriptions, handlers, consumer
        tasks) is reset, so ``publish`` raises until the next ``start`` and a
        later ``subscribe`` creates a fresh consumer.
        """
        if not self.running:
            return

        for task in self.consumer_tasks:
            task.cancel()
        if self.consumer_tasks:
            # return_exceptions=True collects the CancelledErrors instead of
            # raising the first one.
            await asyncio.gather(*self.consumer_tasks, return_exceptions=True)

        for subscription in self.subscriptions.values():
            try:
                await subscription.unsubscribe()
            except Exception as e:  # pylint: disable=broad-exception-caught
                logger.warning("Error unsubscribing", error=str(e))

        try:
            if self.nc:
                await self.nc.close()
        finally:
            # Without this reset, publish() would keep using a closed context,
            # and subscribe() after a restart would find the stale subscription
            # and never start a new consumer task.
            self.nc = None
            self.js = None
            self.subscriptions.clear()
            self.handlers.clear()
            self.consumer_tasks.clear()
            self.running = False

        logger.info("NATS event bus stopped")

    async def publish(self, topic: str, payload: EventPayload) -> bool:
        """Publish ``payload`` to the subject ``f"{stream_name}.{topic}"``.

        Event metadata is also attached as NATS headers.

        Returns:
            True if JetStream acknowledged the message; False if publishing
            raised (the error is logged).

        Raises:
            RuntimeError: If the bus has not been started.
        """
        if not self.js:
            raise RuntimeError("Event bus not started")

        try:
            subject = f"{self.stream_name}.{topic}"
            headers = {
                "event_id": payload.event_id,
                "tenant_id": payload.tenant_id,
                "actor": payload.actor,
                "trace_id": payload.trace_id or "",
                "schema_version": payload.schema_version,
            }
            ack = await self.js.publish(
                subject=subject,
                payload=payload.to_json().encode(),
                headers=headers,
            )
            logger.info(
                "Event published",
                topic=topic,
                subject=subject,
                event_id=payload.event_id,
                stream_seq=ack.seq,
            )
            return True
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Failed to publish event",
                topic=topic,
                event_id=payload.event_id,
                error=str(e),
            )
            return False

    async def subscribe(
        self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
    ) -> None:
        """Register ``handler`` for ``topic``.

        On the first registration for a topic, this creates the durable pull
        consumer and starts its background task. Later registrations only add
        the handler.

        Raises:
            RuntimeError: If the bus has not been started.
            Exception: Whatever consumer creation raises. In that case the
                handler is not left registered, so a retry does not register
                it twice.
        """
        if not self.js:
            raise RuntimeError("Event bus not started")

        handlers = self.handlers.setdefault(topic, [])
        handlers.append(handler)
        if topic in self.subscriptions:
            return

        subject = f"{self.stream_name}.{topic}"
        consumer_name = f"{self.consumer_group}-{topic}"
        try:
            deliver_policy = await self._consumer_deliver_policy(self.js)
            subscription = await self.js.pull_subscribe(
                subject=subject,
                durable=consumer_name,
                config=nats.js.api.ConsumerConfig(
                    durable_name=consumer_name,
                    ack_policy=nats.js.api.AckPolicy.EXPLICIT,
                    deliver_policy=deliver_policy,
                    max_deliver=3,
                    ack_wait=30,  # 30 seconds
                ),
            )
        except Exception as e:
            # Roll back the registration made above.
            if handler in handlers:
                handlers.remove(handler)
            if not handlers:
                self.handlers.pop(topic, None)
            logger.error("Failed to subscribe to topic", topic=topic, error=str(e))
            raise

        self.subscriptions[topic] = subscription
        task = asyncio.create_task(self._consume_messages(topic, subscription))
        self.consumer_tasks.append(task)
        logger.info(
            "Subscribed to topic",
            topic=topic,
            subject=subject,
            consumer=consumer_name,
        )

    async def _consumer_deliver_policy(self, js: JetStreamContext) -> Any:
        """Choose the deliver policy for new durable consumers on the stream.

        JetStream only accepts DeliverPolicy.ALL for consumers on a work-queue
        stream, and ``_ensure_stream_exists`` creates the stream with
        WORK_QUEUE retention. Any other retention policy keeps
        DeliverPolicy.NEW, so a new durable does not replay retained history.
        """
        info = await js.stream_info(self.stream_name)
        retention = getattr(info.config, "retention", None)
        if retention == nats.js.api.RetentionPolicy.WORK_QUEUE:
            return nats.js.api.DeliverPolicy.ALL
        return nats.js.api.DeliverPolicy.NEW

    async def _ensure_stream_exists(self) -> None:
        """Create the JetStream stream if it does not already exist.

        A newly created stream uses work-queue retention, file storage and a
        7-day max age. It captures ``f"{stream_name}.>"``, so topics made of
        several dot-separated tokens are included too. An existing stream is
        left unchanged.

        Raises:
            Exception: Whatever ``add_stream`` raises when creation fails.
        """
        if not self.js:
            return

        try:
            await self.js.stream_info(self.stream_name)
            logger.debug("Stream already exists", stream=self.stream_name)
        except nats.js.errors.NotFoundError:
            try:
                await self.js.add_stream(
                    name=self.stream_name,
                    # '>' matches one or more tokens; '*' matched exactly one,
                    # which rejected topics such as "document.ingested".
                    subjects=[f"{self.stream_name}.>"],
                    retention=nats.js.api.RetentionPolicy.WORK_QUEUE,
                    max_age=7 * 24 * 60 * 60,  # 7 days in seconds
                    storage=nats.js.api.StorageType.FILE,
                )
                logger.info("Created JetStream stream", stream=self.stream_name)
            except Exception as e:
                logger.error(
                    "Failed to create stream", stream=self.stream_name, error=str(e)
                )
                raise

    @staticmethod
    def _decode_payload(data: bytes) -> EventPayload:
        """Rebuild an EventPayload from a JSON message body.

        The body is presumably what ``EventPayload.to_json()`` produced in
        ``publish``; that is not verified here.

        Raises:
            json.JSONDecodeError: If the body is not valid JSON.
            KeyError: If a required field is missing.
        """
        payload_dict = json.loads(data.decode())
        payload = EventPayload(
            data=payload_dict["data"],
            actor=payload_dict["actor"],
            tenant_id=payload_dict["tenant_id"],
            trace_id=payload_dict.get("trace_id"),
            schema_version=payload_dict.get("schema_version", "1.0"),
        )
        payload.event_id = payload_dict["event_id"]
        payload.occurred_at = payload_dict["occurred_at"]
        return payload

    async def _process_message(self, topic: str, message: Any) -> None:
        """Decode one message, run every handler for ``topic``, then ack it.

        Handler errors are logged and do not stop the message from being
        acked. Decode errors, and any other error before or during the ack,
        nak the message for redelivery.
        """
        try:
            payload = self._decode_payload(message.data)
            for handler in self.handlers.get(topic, []):
                try:
                    await handler(topic, payload)
                except Exception as e:  # pylint: disable=broad-exception-caught
                    # NOTE(review): a failing handler still leads to an ack, so
                    # max_deliver=3 never retries handler failures — confirm
                    # this at-most-once behaviour is intended.
                    logger.error(
                        "Handler failed",
                        topic=topic,
                        event_id=payload.event_id,
                        error=str(e),
                    )
            await message.ack()
        except json.JSONDecodeError as e:
            logger.error("Failed to decode message", topic=topic, error=str(e))
            await message.nak()
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error("Failed to process message", topic=topic, error=str(e))
            await message.nak()

    async def _consume_messages(self, topic: str, subscription: Any) -> None:
        """Fetch batches from ``subscription`` while running and process them.

        A fetch timeout (no messages available) simply polls again. Any other
        error is logged, followed by a 5 second back-off. This includes a
        failing nak, which skips the rest of the current batch.
        """
        while self.running:
            try:
                messages = await subscription.fetch(batch=10, timeout=20)
                for message in messages:
                    await self._process_message(topic, message)
            except asyncio.TimeoutError:
                # No messages available, continue polling
                continue
            except Exception as e:  # pylint: disable=broad-exception-caught
                logger.error("Consumer error", topic=topic, error=str(e))
                await asyncio.sleep(5)  # Wait before retrying