"""Kafka implementation of EventBus."""

import asyncio
import json
from collections.abc import Awaitable, Callable

import structlog
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer  # type: ignore

from .base import EventBus, EventPayload

logger = structlog.get_logger()


class KafkaEventBus(EventBus):
    """Kafka implementation of EventBus.

    One shared producer publishes events. Each subscribed topic gets its own
    consumer (consumer group ``tax-agent-<topic>``) plus a background asyncio
    task that dispatches every received message to all handlers registered
    for that topic.
    """

    def __init__(self, bootstrap_servers: str):
        # Comma-separated broker list, stored split into a list.
        self.bootstrap_servers = bootstrap_servers.split(",")
        # Created by start(); None while the bus is stopped.
        self.producer: AIOKafkaProducer | None = None
        # One consumer per subscribed topic.
        self.consumers: dict[str, AIOKafkaConsumer] = {}
        # Handlers invoked for every message received on a topic.
        self.handlers: dict[
            str, list[Callable[[str, EventPayload], Awaitable[None]]]
        ] = {}
        self.running = False
        # Strong references to the per-topic consume loops. The event loop only
        # keeps weak references to tasks, so an unreferenced task can be
        # garbage-collected while still running. These are also cancelled in stop().
        self._consumer_tasks: dict[str, asyncio.Task[None]] = {}

    async def start(self) -> None:
        """Start the Kafka producer. Calling it while already running does nothing."""
        if self.running:
            return
        self.producer = AIOKafkaProducer(
            bootstrap_servers=",".join(self.bootstrap_servers),
            value_serializer=lambda v: v.encode("utf-8"),
        )
        await self.producer.start()
        self.running = True
        logger.info("Kafka event bus started", bootstrap_servers=self.bootstrap_servers)

    async def stop(self) -> None:
        """Stop the consume loops, all consumers, and the producer.

        All subscriptions are torn down (consumers, loops, and handlers), so
        subscribe() must be called again after a restart. Each consumer is
        stopped independently: one failure does not leak the rest.
        """
        if not self.running and not self.consumers:
            return

        # Cancel the consume loops first so no handler runs against a consumer
        # that is being closed. Skip the current task in case stop() is called
        # from inside a handler; cancelling ourselves would abort this shutdown.
        current = asyncio.current_task()
        tasks = [t for t in self._consumer_tasks.values() if t is not current]
        for task in tasks:
            task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

        try:
            for topic, consumer in self.consumers.items():
                try:
                    await consumer.stop()
                except Exception as e:  # pylint: disable=broad-exception-caught
                    logger.error("Failed to stop consumer", topic=topic, error=str(e))
            if self.producer:
                await self.producer.stop()
        finally:
            # Reset state so a later start()/subscribe() builds fresh clients
            # instead of reusing stopped ones.
            self.producer = None
            self.consumers.clear()
            self._consumer_tasks.clear()
            self.handlers.clear()
            self.running = False
        logger.info("Kafka event bus stopped")

    async def publish(self, topic: str, payload: EventPayload) -> bool:
        """Publish an event to a Kafka topic.

        Returns True once the send is acknowledged and False if the send fails
        (the failure is logged). Raises RuntimeError if the bus is not started.
        """
        if not self.producer:
            raise RuntimeError("Event bus not started")

        try:
            await self.producer.send_and_wait(topic, payload.to_json())
            logger.info(
                "Event published",
                topic=topic,
                event_id=payload.event_id,
                actor=payload.actor,
                tenant_id=payload.tenant_id,
            )
            return True
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error(
                "Failed to publish event",
                topic=topic,
                event_id=payload.event_id,
                error=str(e),
            )
            return False

    async def subscribe(
        self, topic: str, handler: Callable[[str, EventPayload], Awaitable[None]]
    ) -> None:
        """Subscribe a handler to a Kafka topic.

        The first subscription to a topic creates and starts its consumer and
        background consume loop. Later subscriptions only add the handler. If
        the consumer fails to start, the registration is rolled back and the
        exception propagates, so a retry can create the consumer again.
        """
        self.handlers.setdefault(topic, []).append(handler)

        if topic not in self.consumers:
            consumer = AIOKafkaConsumer(
                topic,
                bootstrap_servers=",".join(self.bootstrap_servers),
                value_deserializer=lambda m: m.decode("utf-8"),
                group_id=f"tax-agent-{topic}",
                auto_offset_reset="latest",
            )
            # Register before awaiting start() so a concurrent subscribe to the
            # same topic does not create a second consumer.
            self.consumers[topic] = consumer
            try:
                await consumer.start()
            except BaseException:
                self.consumers.pop(topic, None)
                self.handlers[topic].remove(handler)
                raise

            # Keep a reference so the task is not garbage-collected mid-flight.
            self._consumer_tasks[topic] = asyncio.create_task(
                self._consume_messages(topic, consumer)
            )

        logger.info("Subscribed to topic", topic=topic)

    @staticmethod
    def _decode_payload(raw: str) -> EventPayload:
        """Rebuild an EventPayload from a message's JSON text.

        Presumably the inverse of EventPayload.to_json() as used by publish();
        verify against base. Raises json.JSONDecodeError for malformed JSON and
        KeyError when a required field is missing.
        """
        payload_dict = json.loads(raw)
        payload = EventPayload(
            data=payload_dict["data"],
            actor=payload_dict["actor"],
            tenant_id=payload_dict["tenant_id"],
            trace_id=payload_dict.get("trace_id"),
            schema_version=payload_dict.get("schema_version", "1.0"),
        )
        # Preserve the producer's identity/timestamp instead of the values
        # EventPayload generates on construction.
        payload.event_id = payload_dict["event_id"]
        payload.occurred_at = payload_dict["occurred_at"]
        return payload

    async def _consume_messages(self, topic: str, consumer: AIOKafkaConsumer) -> None:
        """Consume messages from a Kafka topic and dispatch them to its handlers.

        A failure in one handler or one message is logged and does not stop the
        loop. Cancellation (from stop()) propagates, because CancelledError is
        not an Exception subclass.
        """
        try:
            async for message in consumer:
                if message.value is None:
                    continue
                try:
                    payload = self._decode_payload(message.value)

                    # Snapshot the list: handlers await, and a concurrent
                    # subscribe() may append to it meanwhile.
                    for handler in tuple(self.handlers.get(topic, ())):
                        try:
                            await handler(topic, payload)
                        except Exception as e:  # pylint: disable=broad-exception-caught
                            logger.error(
                                "Handler failed",
                                topic=topic,
                                event_id=payload.event_id,
                                # partials / callable objects have no __name__
                                handler=getattr(handler, "__name__", repr(handler)),
                                error=str(e),
                            )

                except json.JSONDecodeError as e:
                    logger.error("Failed to decode message", topic=topic, error=str(e))
                except Exception as e:  # pylint: disable=broad-exception-caught
                    logger.error("Failed to process message", topic=topic, error=str(e))

        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error("Consumer error", topic=topic, error=str(e))