"""NATS.io with JetStream implementation of EventBus.""" import asyncio import json import time from collections.abc import Awaitable, Callable from typing import Any import nats import structlog from nats.aio.client import Client as NATS from nats.js import JetStreamContext from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy from .base import EventBus, EventPayload from .dlq import DLQHandler from .metrics import EventMetricsCollector logger = structlog.get_logger() class NATSEventBus(EventBus): # pylint: disable=too-many-instance-attributes """NATS.io with JetStream implementation of EventBus""" def __init__( self, servers: str | list[str] = "nats://localhost:4222", stream_name: str = "TAX_AGENT_EVENTS", consumer_group: str = "tax-agent", dlq_stream_name: str = "TAX_AGENT_DLQ", max_retries: int = 3, ): if isinstance(servers, str): self.servers = [servers] else: self.servers = servers self.stream_name = stream_name self.consumer_group = consumer_group self.dlq_stream_name = dlq_stream_name self.max_retries = max_retries self.nc: NATS | None = None self.js: JetStreamContext | None = None self.dlq: DLQHandler | None = None self.handlers: dict[ str, list[Callable[[str, EventPayload], Awaitable[None]]] ] = {} self.subscriptions: dict[str, Any] = {} self.running = False self.consumer_tasks: list[asyncio.Task[None]] = [] async def start(self) -> None: """Start NATS connection and JetStream context""" if self.running: return try: # Connect to NATS self.nc = await nats.connect( servers=self.servers, connect_timeout=10, reconnect_time_wait=1, ) # Get JetStream context self.js = self.nc.jetstream(timeout=10) # Initialize DLQ handler self.dlq = DLQHandler( js=self.js, dlq_stream_name=self.dlq_stream_name, max_retries=self.max_retries, ) # Ensure streams exist await self._ensure_stream_exists() await self.dlq.ensure_dlq_stream_exists() self.running = True logger.info( "NATS event bus started", servers=self.servers, stream=self.stream_name, dlq_stream=self.dlq_stream_name, ) except Exception as e: logger.error("Failed to start NATS event bus", error=str(e)) raise async def stop(self) -> None: """Stop NATS connection and consumers""" if not self.running: return # Cancel consumer tasks for task in self.consumer_tasks: task.cancel() if self.consumer_tasks: await asyncio.gather(*self.consumer_tasks, return_exceptions=True) # Unsubscribe from all subscriptions for subscription in self.subscriptions.values(): try: await subscription.unsubscribe() except Exception as e: # pylint: disable=broad-exception-caught logger.warning("Error unsubscribing", error=str(e)) # Close NATS connection if self.nc: await self.nc.close() self.running = False logger.info("NATS event bus stopped") async def publish(self, topic: str, payload: EventPayload) -> bool: """Publish event to NATS JetStream""" if not self.js: raise RuntimeError("Event bus not started") start_time = time.perf_counter() try: # Create subject name from topic subject = f"{self.stream_name}.{topic}" # Publish message with headers headers = { "event_id": payload.event_id, "tenant_id": payload.tenant_id, "actor": payload.actor, "trace_id": payload.trace_id or "", "schema_version": payload.schema_version, } ack = await self.js.publish( subject=subject, payload=payload.to_json().encode(), headers=headers, ) duration = time.perf_counter() - start_time EventMetricsCollector.record_publish( topic=topic, duration_seconds=duration, success=True, ) logger.info( "Event published", topic=topic, subject=subject, event_id=payload.event_id, stream_seq=ack.seq, ) return True except Exception as e: # pylint: disable=broad-exception-caught duration = time.perf_counter() - start_time EventMetricsCollector.record_publish( topic=topic, duration_seconds=duration, success=False, error_type=type(e).__name__, ) logger.error( "Failed to publish event", topic=topic, event_id=payload.event_id, error=str(e), ) return False async def subscribe( self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]] ) -> None: """Subscribe to NATS JetStream topic""" if not self.js: raise RuntimeError("Event bus not started") if topic not in self.handlers: self.handlers[topic] = [] self.handlers[topic].append(handler) if topic not in self.subscriptions: try: # Create subject pattern for topic subject = f"{self.stream_name}.{topic}" # Create durable consumer # Durable names cannot contain dots, so we replace them safe_topic = topic.replace(".", "-") consumer_name = f"{self.consumer_group}-{safe_topic}" # Subscribe with pull-based consumer # Set max_deliver to max_retries + 1 (initial + retries) # We handle DLQ manually before NATS gives up subscription = await self.js.pull_subscribe( subject=subject, durable=consumer_name, config=ConsumerConfig( durable_name=consumer_name, ack_policy=AckPolicy.EXPLICIT, deliver_policy=DeliverPolicy.NEW, max_deliver=self.max_retries + 2, # Give us room to handle DLQ ack_wait=30, # 30 seconds ), ) self.subscriptions[topic] = subscription # Start consumer task task = asyncio.create_task(self._consume_messages(topic, subscription)) self.consumer_tasks.append(task) logger.info( "Subscribed to topic", topic=topic, subject=subject, consumer=consumer_name, ) except Exception as e: logger.error("Failed to subscribe to topic", topic=topic, error=str(e)) raise async def _ensure_stream_exists(self) -> None: """Ensure JetStream stream exists""" if not self.js: return try: # Try to get stream info await self.js.stream_info(self.stream_name) logger.debug("Stream already exists", stream=self.stream_name) EventMetricsCollector.record_nats_stream_message(self.stream_name) except Exception: # Stream doesn't exist, create it try: await self.js.add_stream( name=self.stream_name, subjects=[f"{self.stream_name}.>"], ) logger.info("Created JetStream stream", stream=self.stream_name) except Exception as e: logger.error( "Failed to create stream", stream=self.stream_name, error=str(e) ) raise async def _consume_messages(self, topic: str, subscription: Any) -> None: """Consume messages from NATS JetStream subscription""" while self.running: try: # Fetch messages in batches messages = await subscription.fetch(batch=10, timeout=5) for message in messages: start_time = time.perf_counter() payload = None try: print(f"DEBUG: Received message: {message.data}") # Parse message payload payload_dict = json.loads(message.data.decode()) print(f"DEBUG: Parsed payload: {payload_dict}") payload = EventPayload( data=payload_dict["data"], actor=payload_dict["actor"], tenant_id=payload_dict["tenant_id"], trace_id=payload_dict.get("trace_id"), schema_version=payload_dict.get("schema_version", "1.0"), ) payload.event_id = payload_dict["event_id"] payload.occurred_at = payload_dict["occurred_at"] print(f"DEBUG: Reconstructed payload: {payload.event_id}") # Call all handlers for this topic for handler in self.handlers.get(topic, []): print(f"DEBUG: Calling handler for topic {topic}") await handler(topic, payload) # Acknowledge message await message.ack() print("DEBUG: Message acked") # Record metrics duration = time.perf_counter() - start_time EventMetricsCollector.record_consume( topic=topic, consumer_group=self.consumer_group, duration_seconds=duration, success=True, ) except Exception as e: # pylint: disable=broad-exception-caught duration = time.perf_counter() - start_time error_type = type(e).__name__ # Record failure metric EventMetricsCollector.record_consume( topic=topic, consumer_group=self.consumer_group, duration_seconds=duration, success=False, error_type=error_type, ) # Check delivery count for DLQ try: metadata = message.metadata num_delivered = ( metadata.sequence.consumer ) # This might be wrong, check docs # Actually nats-py MsgMetadata has num_delivered num_delivered = metadata.num_delivered except Exception: num_delivered = 1 if num_delivered >= self.max_retries: logger.error( "Max retries exceeded, sending to DLQ", topic=topic, event_id=payload.event_id if payload else "unknown", error=str(e), num_delivered=num_delivered, ) if self.dlq and payload: await self.dlq.send_to_dlq( topic=topic, payload=payload, error=e, retry_count=num_delivered, original_message_data=message.data, ) EventMetricsCollector.record_dlq(topic, error_type) # Ack to remove from main stream await message.ack() else: # Retry (Nak) logger.warning( "Processing failed, retrying", topic=topic, event_id=payload.event_id if payload else "unknown", error=str(e), attempt=num_delivered, ) EventMetricsCollector.record_retry(topic, num_delivered) await message.nak() except TimeoutError: # No messages available, continue polling continue except Exception as e: # pylint: disable=broad-exception-caught logger.error("Consumer error", topic=topic, error=str(e)) await asyncio.sleep(1) # Wait before retrying