"""AWS SQS/SNS implementation of EventBus.

Events are published to SNS topics. Each subscribed topic gets its own SQS
queue (``tax-agent-<topic>``) subscribed to that topic, and a background task
long-polls the queue and dispatches every message to the registered handlers.

boto3 clients are synchronous, so every AWS call is run in a worker thread via
``asyncio.to_thread``; calling them directly from coroutines would block the
event loop (a single ``receive_message`` long poll here waits up to 20 s).
"""

import asyncio
import json
from collections.abc import Awaitable, Callable
from typing import Any

import boto3  # type: ignore
import structlog
from botocore.exceptions import BotoCoreError, ClientError  # type: ignore

from ..base import EventBus, EventPayload

logger = structlog.get_logger()


class SQSEventBus(EventBus):  # pylint: disable=too-many-instance-attributes
    """AWS SQS/SNS implementation of EventBus.

    Args:
        region_name: AWS region used for both the SNS and SQS clients.
    """

    def __init__(self, region_name: str = "us-east-1"):
        self.region_name = region_name
        # Created in start(); None until then.
        self.sns_client: Any = None
        self.sqs_client: Any = None
        # topic name -> SNS topic ARN (cache used by _ensure_topic_exists)
        self.topic_arns: dict[str, str] = {}
        # topic name -> URL of the SQS queue subscribed to that topic
        self.queue_urls: dict[str, str] = {}
        # topic name -> handlers invoked for every message on that topic
        self.handlers: dict[
            str, list[Callable[[str, EventPayload], Awaitable[None]]]
        ] = {}
        self.running = False
        # Strong references to the consumer tasks so they are not
        # garbage-collected while running, and so stop() can cancel them.
        self.consumer_tasks: list[asyncio.Task[None]] = []

    async def start(self) -> None:
        """Create the SNS/SQS clients and (re)start queue consumers.

        Consumers are started for every topic that was already subscribed, so
        ``stop()`` followed by ``start()`` resumes message delivery. No-op if
        the bus is already running.
        """
        if self.running:
            return

        self.sns_client = boto3.client("sns", region_name=self.region_name)
        self.sqs_client = boto3.client("sqs", region_name=self.region_name)
        self.running = True

        for topic, queue_url in self.queue_urls.items():
            self._start_consumer(topic, queue_url)

        logger.info("SQS event bus started", region=self.region_name)

    async def stop(self) -> None:
        """Cancel all consumer tasks and mark the bus as stopped.

        Subscriptions (handlers and queue URLs) are kept so that a later
        ``start()`` can resume consuming. No-op if the bus is not running.
        """
        if not self.running:
            return

        # Clear the flag first so consumer loops exit at their next check.
        self.running = False

        # Detach the task list before awaiting, so tasks spawned by a
        # concurrent start() are not cancelled or dropped by this call.
        tasks, self.consumer_tasks = self.consumer_tasks, []
        for task in tasks:
            task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

        logger.info("SQS event bus stopped")

    async def publish(self, topic: str, payload: EventPayload) -> bool:
        """Publish ``payload`` to the SNS topic ``topic``, creating it if needed.

        Returns:
            True if SNS accepted the message; False if an AWS or botocore error
            occurred (the error is logged).

        Raises:
            RuntimeError: if ``start()`` has not been called.
        """
        if not self.sns_client:
            raise RuntimeError("Event bus not started")

        try:
            topic_arn = await self._ensure_topic_exists(topic)
            response = await asyncio.to_thread(
                self.sns_client.publish,
                TopicArn=topic_arn,
                Message=payload.to_json(),
                MessageAttributes={
                    "event_id": {"DataType": "String", "StringValue": payload.event_id},
                    "tenant_id": {
                        "DataType": "String",
                        "StringValue": payload.tenant_id,
                    },
                    "actor": {"DataType": "String", "StringValue": payload.actor},
                },
            )
        except (BotoCoreError, ClientError) as e:
            # BotoCoreError covers e.g. connection failures and parameter
            # validation errors, which previously escaped despite the bool API.
            logger.error(
                "Failed to publish event",
                topic=topic,
                event_id=payload.event_id,
                error=str(e),
            )
            return False

        logger.info(
            "Event published",
            topic=topic,
            event_id=payload.event_id,
            message_id=response["MessageId"],
        )
        return True

    async def subscribe(
        self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
    ) -> None:
        """Register ``handler`` for ``topic``, wiring an SQS queue on first use.

        The first subscription to a topic creates the ``tax-agent-<topic>``
        queue, subscribes it to the SNS topic, and (if the bus is running)
        starts a consumer task. Later subscriptions only add the handler.

        Raises:
            RuntimeError: if ``start()`` has not been called.
        """
        if not self.sns_client or not self.sqs_client:
            raise RuntimeError("Event bus not started")

        self.handlers.setdefault(topic, []).append(handler)

        if topic in self.queue_urls:
            return

        queue_name = f"tax-agent-{topic}"
        queue_url = await self._ensure_queue_exists(queue_name)
        topic_arn = await self._ensure_topic_exists(topic)
        await self._subscribe_queue_to_topic(queue_url, topic_arn)

        # The AWS calls above yield to the event loop, so a concurrent
        # subscribe() for the same topic may have finished first; keep a
        # single consumer per topic. queue_urls is only recorded after setup
        # succeeded, so a failed setup is retried on the next subscribe().
        if topic in self.queue_urls:
            return
        self.queue_urls[topic] = queue_url

        # When stopped, start() will launch the consumer instead.
        if self.running:
            self._start_consumer(topic, queue_url)

        logger.info("Subscribed to topic", topic=topic, queue_name=queue_name)

    def _start_consumer(self, topic: str, queue_url: str) -> None:
        """Spawn and track a background task consuming ``queue_url``."""
        task = asyncio.create_task(self._consume_messages(topic, queue_url))
        self.consumer_tasks.append(task)

    async def _ensure_topic_exists(self, topic: str) -> str:
        """Create (or look up) the SNS topic and return its ARN, caching it."""
        cached_arn = self.topic_arns.get(topic)
        if cached_arn is not None:
            return cached_arn

        try:
            response = await asyncio.to_thread(self.sns_client.create_topic, Name=topic)
        except (BotoCoreError, ClientError) as e:
            logger.error("Failed to create topic", topic=topic, error=str(e))
            raise

        topic_arn = str(response["TopicArn"])
        self.topic_arns[topic] = topic_arn
        return topic_arn

    async def _ensure_queue_exists(self, queue_name: str) -> str:
        """Create (or look up) the SQS queue ``queue_name`` and return its URL."""
        try:
            response = await asyncio.to_thread(
                self.sqs_client.create_queue, QueueName=queue_name
            )
        except (BotoCoreError, ClientError) as e:
            logger.error("Failed to create queue", queue_name=queue_name, error=str(e))
            raise
        return str(response["QueueUrl"])

    async def _subscribe_queue_to_topic(self, queue_url: str, topic_arn: str) -> None:
        """Allow the SNS topic to write to the queue and subscribe the queue.

        Raises:
            ClientError / BotoCoreError: if reading the queue ARN or creating
                the SNS subscription fails.
        """
        try:
            queue_attrs = await asyncio.to_thread(
                self.sqs_client.get_queue_attributes,
                QueueUrl=queue_url,
                AttributeNames=["QueueArn"],
            )
            queue_arn = queue_attrs["Attributes"]["QueueArn"]

            await self._allow_topic_to_send(queue_url, queue_arn, topic_arn)

            await asyncio.to_thread(
                self.sns_client.subscribe,
                TopicArn=topic_arn,
                Protocol="sqs",
                Endpoint=queue_arn,
            )
        except (BotoCoreError, ClientError) as e:
            logger.error("Failed to subscribe queue to topic", error=str(e))
            raise

    async def _allow_topic_to_send(
        self, queue_url: str, queue_arn: str, topic_arn: str
    ) -> None:
        """Best-effort: set a queue policy letting ``topic_arn`` send messages.

        Without such a policy SNS cannot deliver into the queue. This replaces
        the queue's existing policy (each ``tax-agent-*`` queue here is wired to
        exactly one topic). Failure is logged, not raised, so deployments where
        the policy is managed externally keep working.
        """
        policy = {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Sid": "AllowSNSTopicSendMessage",
                    "Effect": "Allow",
                    "Principal": {"Service": "sns.amazonaws.com"},
                    "Action": "sqs:SendMessage",
                    "Resource": queue_arn,
                    "Condition": {"ArnEquals": {"aws:SourceArn": topic_arn}},
                }
            ],
        }
        try:
            await asyncio.to_thread(
                self.sqs_client.set_queue_attributes,
                QueueUrl=queue_url,
                Attributes={"Policy": json.dumps(policy)},
            )
        except (BotoCoreError, ClientError) as e:
            logger.warning(
                "Could not set SQS queue policy for SNS delivery",
                queue_url=queue_url,
                topic_arn=topic_arn,
                error=str(e),
            )

    async def _consume_messages(self, topic: str, queue_url: str) -> None:
        """Long-poll ``queue_url`` and dispatch messages until the bus stops."""
        while self.running:
            try:
                response = await asyncio.to_thread(
                    self.sqs_client.receive_message,
                    QueueUrl=queue_url,
                    MaxNumberOfMessages=10,
                    WaitTimeSeconds=20,
                )
                for message in response.get("Messages", []):
                    await self._process_message(topic, queue_url, message)
            except Exception as e:  # pylint: disable=broad-exception-caught
                logger.error("Consumer error", topic=topic, error=str(e))
                await asyncio.sleep(5)  # Wait before retrying

    async def _process_message(
        self, topic: str, queue_url: str, message: dict[str, Any]
    ) -> None:
        """Decode one SQS message, run every handler for ``topic``, delete it.

        A failing handler is logged and does not stop the remaining handlers;
        the message is still deleted afterwards. If decoding or deletion fails,
        the error is logged and the message is not deleted (so SQS may deliver
        it again).
        """
        try:
            payload = self._decode_message(message)

            for handler in self.handlers.get(topic, []):
                try:
                    await handler(topic, payload)
                except Exception as e:  # pylint: disable=broad-exception-caught
                    logger.error(
                        "Handler failed",
                        topic=topic,
                        event_id=payload.event_id,
                        error=str(e),
                    )

            await asyncio.to_thread(
                self.sqs_client.delete_message,
                QueueUrl=queue_url,
                ReceiptHandle=message["ReceiptHandle"],
            )
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error("Failed to process message", topic=topic, error=str(e))

    @staticmethod
    def _decode_message(message: dict[str, Any]) -> EventPayload:
        """Rebuild an EventPayload from an SQS message wrapping an SNS envelope.

        The SQS ``Body`` is parsed as the SNS notification JSON; its
        ``Message`` field is expected to hold the JSON that ``publish`` sent
        (``payload.to_json()``).

        Raises:
            KeyError / ValueError: if the body is not in the expected shape.
        """
        sns_message = json.loads(message["Body"])
        payload_dict = json.loads(sns_message["Message"])

        payload = EventPayload(
            data=payload_dict["data"],
            actor=payload_dict["actor"],
            tenant_id=payload_dict["tenant_id"],
            trace_id=payload_dict.get("trace_id"),
            schema_version=payload_dict.get("schema_version", "1.0"),
        )
        # Preserve the publisher's identity/timestamp rather than new values.
        payload.event_id = payload_dict["event_id"]
        payload.occurred_at = payload_dict["occurred_at"]
        return payload