"""Tests for NATS event bus implementation."""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from nats.js.api import ConsumerConfig
from nats.js.errors import NotFoundError

from libs.events.base import EventPayload
from libs.events.nats_bus import NATSEventBus


def _mock_nats_connection(mock_connect, mock_dlq_cls):
    """Wire the ``nats.connect`` and ``DLQHandler`` mocks to fake objects.

    Args:
        mock_connect: The patched ``libs.events.nats_bus.nats.connect``.
        mock_dlq_cls: The patched ``libs.events.nats_bus.DLQHandler`` class.

    Returns:
        A ``(mock_nc, mock_js, mock_dlq_instance)`` tuple so each test can
        further configure the JetStream context or DLQ handler it needs.
    """
    mock_nc = AsyncMock()
    mock_js = AsyncMock()
    # jetstream() is synchronous, so it must return the context directly
    # rather than a coroutine (which AsyncMock would otherwise produce).
    mock_nc.jetstream = MagicMock(return_value=mock_js)
    mock_connect.return_value = mock_nc

    mock_dlq_instance = MagicMock()
    mock_dlq_instance.ensure_dlq_stream_exists = AsyncMock()
    mock_dlq_cls.return_value = mock_dlq_instance

    return mock_nc, mock_js, mock_dlq_instance


def _discard_task(coro, *args, **kwargs):
    """Stand-in for ``asyncio.create_task`` that never schedules the coroutine.

    The coroutine is closed so Python does not emit a "coroutine ... was never
    awaited" RuntimeWarning. A ``MagicMock`` is returned in place of the task,
    matching what a bare ``patch("asyncio.create_task")`` would hand back.
    """
    close = getattr(coro, "close", None)
    if close is not None:
        close()
    return MagicMock()


@pytest.fixture
def event_payload():
    """Create a test event payload."""
    return EventPayload(
        data={"test": "data", "value": 123},
        actor="test-user",
        tenant_id="test-tenant",
        trace_id="test-trace-123",
        schema_version="1.0",
    )


@pytest.fixture
def nats_bus():
    """Create a NATS event bus instance."""
    return NATSEventBus(
        servers="nats://localhost:4222",
        stream_name="TEST_STREAM",
        consumer_group="test-group",
    )


class TestNATSEventBus:
    """Test cases for NATS event bus."""

    @pytest.mark.asyncio
    async def test_initialization(self, nats_bus):
        """Test NATS event bus initialization."""
        assert nats_bus.servers == ["nats://localhost:4222"]
        assert nats_bus.stream_name == "TEST_STREAM"
        assert nats_bus.consumer_group == "test-group"
        assert nats_bus.dlq_stream_name == "TAX_AGENT_DLQ"
        assert nats_bus.max_retries == 3
        assert not nats_bus.running
        assert nats_bus.nc is None
        assert nats_bus.js is None
        assert nats_bus.dlq is None

    @pytest.mark.asyncio
    async def test_initialization_with_multiple_servers(self):
        """Test NATS event bus initialization with multiple servers."""
        servers = ["nats://server1:4222", "nats://server2:4222"]
        bus = NATSEventBus(servers=servers)
        assert bus.servers == servers

    @pytest.mark.asyncio
    @patch("libs.events.nats_bus.nats.connect")
    @patch("libs.events.nats_bus.DLQHandler")
    async def test_start(self, mock_dlq_cls, mock_connect, nats_bus):
        """Test starting the NATS event bus."""
        mock_nc, mock_js, mock_dlq_instance = _mock_nats_connection(
            mock_connect, mock_dlq_cls
        )
        # Mock stream info to simulate an existing stream.
        mock_js.stream_info.return_value = {"name": "TEST_STREAM"}

        await nats_bus.start()

        assert nats_bus.running
        assert nats_bus.nc == mock_nc
        assert nats_bus.js == mock_js
        assert nats_bus.dlq == mock_dlq_instance
        mock_connect.assert_called_once_with(servers=["nats://localhost:4222"])
        mock_dlq_instance.ensure_dlq_stream_exists.assert_called_once()

    @pytest.mark.asyncio
    @patch("libs.events.nats_bus.nats.connect")
    @patch("libs.events.nats_bus.DLQHandler")
    async def test_start_creates_stream_if_not_exists(
        self, mock_dlq_cls, mock_connect, nats_bus
    ):
        """Test that start creates stream if it doesn't exist."""
        _, mock_js, _ = _mock_nats_connection(mock_connect, mock_dlq_cls)
        # stream_info reports a missing stream, so start() must add it.
        mock_js.stream_info.side_effect = NotFoundError
        mock_js.add_stream = AsyncMock()

        await nats_bus.start()

        mock_js.add_stream.assert_called_once()
        assert mock_js.add_stream.call_args.kwargs["subjects"] == ["TEST_STREAM.>"]

    @pytest.mark.asyncio
    @patch("libs.events.nats_bus.nats.connect")
    async def test_start_already_running(self, mock_connect, nats_bus):
        """Test that start does nothing if already running."""
        nats_bus.running = True
        original_nc = nats_bus.nc

        await nats_bus.start()

        # A no-op start must neither reconnect nor replace the connection.
        mock_connect.assert_not_called()
        assert nats_bus.nc == original_nc

    @pytest.mark.asyncio
    async def test_stop(self, nats_bus):
        """Test stopping the NATS event bus."""
        mock_nc = AsyncMock()
        mock_subscription = AsyncMock()

        # Use a real task so consumer_tasks holds a genuine asyncio.Task.
        async def dummy_task():
            pass

        real_task = asyncio.create_task(dummy_task())

        nats_bus.running = True
        nats_bus.nc = mock_nc
        nats_bus.subscriptions = {"test-topic": mock_subscription}
        nats_bus.consumer_tasks = [real_task]

        await nats_bus.stop()

        assert not nats_bus.running
        # A cancelled task is also done, so done() covers both outcomes.
        assert real_task.done()
        mock_subscription.unsubscribe.assert_called_once()
        mock_nc.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_stop_not_running(self, nats_bus):
        """Test that stop does nothing if not running."""
        assert not nats_bus.running
        await nats_bus.stop()
        assert not nats_bus.running

    @pytest.mark.asyncio
    @patch("libs.events.nats_bus.EventMetricsCollector")
    async def test_publish(self, mock_metrics, nats_bus, event_payload):
        """Test publishing an event."""
        mock_js = AsyncMock()
        mock_ack = MagicMock()
        mock_ack.seq = 123
        mock_js.publish.return_value = mock_ack
        nats_bus.js = mock_js

        result = await nats_bus.publish("test-topic", event_payload)

        assert result is True
        mock_js.publish.assert_called_once()
        publish_kwargs = mock_js.publish.call_args.kwargs
        assert publish_kwargs["subject"] == "TEST_STREAM.test-topic"
        assert publish_kwargs["payload"] == event_payload.to_json().encode()

        # Verify metrics recorded a successful publish.
        mock_metrics.record_publish.assert_called_once()
        assert mock_metrics.record_publish.call_args.kwargs["success"] is True

    @pytest.mark.asyncio
    async def test_publish_not_started(self, nats_bus, event_payload):
        """Test publishing when event bus is not started."""
        with pytest.raises(RuntimeError, match="Event bus not started"):
            await nats_bus.publish("test-topic", event_payload)

    @pytest.mark.asyncio
    @patch("libs.events.nats_bus.EventMetricsCollector")
    async def test_publish_failure(self, mock_metrics, nats_bus, event_payload):
        """Test publishing failure."""
        mock_js = AsyncMock()
        mock_js.publish.side_effect = Exception("Publish failed")
        nats_bus.js = mock_js

        result = await nats_bus.publish("test-topic", event_payload)

        assert result is False
        # Verify metrics recorded the failure.
        mock_metrics.record_publish.assert_called_once()
        assert mock_metrics.record_publish.call_args.kwargs["success"] is False

    @pytest.mark.asyncio
    async def test_subscribe(self, nats_bus):
        """Test subscribing to a topic."""
        mock_js = AsyncMock()
        mock_subscription = AsyncMock()
        mock_js.pull_subscribe.return_value = mock_subscription
        nats_bus.js = mock_js

        async def test_handler(topic: str, payload: EventPayload) -> None:
            pass

        with patch(
            "asyncio.create_task", side_effect=_discard_task
        ) as mock_create_task:
            await nats_bus.subscribe("test-topic", test_handler)

        assert "test-topic" in nats_bus.handlers
        assert test_handler in nats_bus.handlers["test-topic"]
        assert "test-topic" in nats_bus.subscriptions
        mock_js.pull_subscribe.assert_called_once()

        # Verify the consumer configuration passed to JetStream.
        config = mock_js.pull_subscribe.call_args.kwargs["config"]
        assert isinstance(config, ConsumerConfig)
        assert config.max_deliver == 5  # 3 retries + 2 buffer

        mock_create_task.assert_called_once()

    @pytest.mark.asyncio
    async def test_subscribe_not_started(self, nats_bus):
        """Test subscribing when event bus is not started."""

        async def test_handler(topic: str, payload: EventPayload) -> None:
            pass

        with pytest.raises(RuntimeError, match="Event bus not started"):
            await nats_bus.subscribe("test-topic", test_handler)

    @pytest.mark.asyncio
    async def test_subscribe_multiple_handlers(self, nats_bus):
        """Test subscribing multiple handlers to the same topic."""
        mock_js = AsyncMock()
        mock_subscription = AsyncMock()
        mock_js.pull_subscribe.return_value = mock_subscription
        nats_bus.js = mock_js

        async def handler1(topic: str, payload: EventPayload) -> None:
            pass

        async def handler2(topic: str, payload: EventPayload) -> None:
            pass

        with patch("asyncio.create_task", side_effect=_discard_task):
            await nats_bus.subscribe("test-topic", handler1)
            await nats_bus.subscribe("test-topic", handler2)

        assert len(nats_bus.handlers["test-topic"]) == 2
        assert handler1 in nats_bus.handlers["test-topic"]
        assert handler2 in nats_bus.handlers["test-topic"]

    @pytest.mark.asyncio
    @patch("libs.events.nats_bus.EventMetricsCollector")
    async def test_consume_messages(self, mock_metrics, nats_bus, event_payload):
        """Test consuming messages from NATS."""
        mock_message = MagicMock()
        mock_message.data.decode.return_value = event_payload.to_json()
        mock_message.ack = AsyncMock()

        fetch_calls = 0

        def fetch_once(*args, **kwargs):
            # Deliver exactly one batch, then stop the bus so the consume loop
            # exits on its next ``running`` check. Patching ``running`` with a
            # side_effect does not work: the attribute is read, never called,
            # and a MagicMock is always truthy.
            nonlocal fetch_calls
            fetch_calls += 1
            if fetch_calls == 1:
                return [mock_message]
            nats_bus.running = False
            return []

        mock_subscription = AsyncMock()
        mock_subscription.fetch.side_effect = fetch_once
        nats_bus.running = True

        handler_called = False
        received_topic = None
        received_payload = None

        async def test_handler(topic: str, payload: EventPayload) -> None:
            nonlocal handler_called, received_topic, received_payload
            handler_called = True
            received_topic = topic
            received_payload = payload

        nats_bus.handlers["test-topic"] = [test_handler]

        # Bound the run so a loop that ignores ``running`` fails instead of hanging.
        await asyncio.wait_for(
            nats_bus._consume_messages("test-topic", mock_subscription),
            timeout=5,
        )

        assert handler_called
        assert received_topic == "test-topic"
        assert received_payload.event_id == event_payload.event_id
        mock_message.ack.assert_called_once()

        # Verify metrics recorded a successful consume.
        mock_metrics.record_consume.assert_called_once()
        assert mock_metrics.record_consume.call_args.kwargs["success"] is True

    @pytest.mark.asyncio
    async def test_factory_integration(self):
        """Test that the factory can create a NATS event bus."""
        from libs.events.factory import create_event_bus

        bus = create_event_bus(
            "nats",
            servers="nats://localhost:4222",
            stream_name="TEST_STREAM",
            consumer_group="test-group",
        )

        assert isinstance(bus, NATSEventBus)
        assert bus.servers == ["nats://localhost:4222"]
        assert bus.stream_name == "TEST_STREAM"
        assert bus.consumer_group == "test-group"