Files
ai-tax-agent/tests/unit/test_nats_bus.py
harkon fdba81809f
Some checks failed
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
completed local setup with compose
2025-11-26 13:17:17 +00:00

324 lines
11 KiB
Python

"""Tests for NATS event bus implementation."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nats.js.api import ConsumerConfig
from libs.events.base import EventPayload
from libs.events.nats_bus import NATSEventBus
@pytest.fixture
def event_payload():
    """Provide a sample ``EventPayload`` for the publish and consume tests.

    Returns:
        An ``EventPayload`` carrying a small data dict plus fixed actor,
        tenant, trace and schema-version values.
    """
    fields = dict(
        data={"test": "data", "value": 123},
        actor="test-user",
        tenant_id="test-tenant",
        trace_id="test-trace-123",
        schema_version="1.0",
    )
    return EventPayload(**fields)
@pytest.fixture
def nats_bus():
    """Provide a ``NATSEventBus`` that has been constructed but not started.

    Returns:
        A bus configured for a single local server, the ``TEST_STREAM``
        stream and the ``test-group`` consumer group. ``start()`` is not
        awaited here; tests that need a live bus mock the connection.
    """
    settings = {
        "servers": "nats://localhost:4222",
        "stream_name": "TEST_STREAM",
        "consumer_group": "test-group",
    }
    return NATSEventBus(**settings)
class TestNATSEventBus:
    """Test cases for NATS event bus.

    Network-facing collaborators (``nats.connect``, JetStream, pull
    subscriptions, the DLQ handler) are replaced with mocks in the tests
    that would otherwise touch them.
    """

    @staticmethod
    def _discard_coroutine(coro, *args, **kwargs):
        """Stand in for ``asyncio.create_task`` without scheduling anything.

        The coroutine handed to ``create_task`` is closed so it is not
        reported as "never awaited" when garbage-collected, and a
        ``MagicMock`` is returned in place of the ``asyncio.Task`` the caller
        expects.

        Args:
            coro: The object passed as the first argument to ``create_task``.

        Returns:
            A fresh ``MagicMock`` standing in for the created task.
        """
        if asyncio.iscoroutine(coro):
            coro.close()
        return MagicMock()

    @pytest.mark.asyncio
    async def test_initialization(self, nats_bus):
        """Test NATS event bus initialization."""
        # A single server string is normalised to a one-element list.
        assert nats_bus.servers == ["nats://localhost:4222"]
        assert nats_bus.stream_name == "TEST_STREAM"
        assert nats_bus.consumer_group == "test-group"
        # Defaults used when the fixture does not pass them explicitly.
        assert nats_bus.dlq_stream_name == "TAX_AGENT_DLQ"
        assert nats_bus.max_retries == 3
        # Nothing is connected until start() is awaited.
        assert not nats_bus.running
        assert nats_bus.nc is None
        assert nats_bus.js is None
        assert nats_bus.dlq is None

    @pytest.mark.asyncio
    async def test_initialization_with_multiple_servers(self):
        """Test NATS event bus initialization with multiple servers."""
        servers = ["nats://server1:4222", "nats://server2:4222"]
        bus = NATSEventBus(servers=servers)
        assert bus.servers == servers

    @pytest.mark.asyncio
    @patch("libs.events.nats_bus.nats.connect")
    @patch("libs.events.nats_bus.DLQHandler")
    async def test_start(self, mock_dlq_cls, mock_connect, nats_bus):
        """Test starting the NATS event bus."""
        # Mock NATS connection and JetStream
        mock_nc = AsyncMock()
        mock_js = AsyncMock()
        # jetstream() is synchronous on the real client, so it must not be an
        # AsyncMock child (which would return a coroutine).
        mock_nc.jetstream = MagicMock(return_value=mock_js)
        mock_connect.return_value = mock_nc

        # Mock DLQ handler
        mock_dlq_instance = MagicMock()
        mock_dlq_instance.ensure_dlq_stream_exists = AsyncMock()
        mock_dlq_cls.return_value = mock_dlq_instance

        # Mock stream info to simulate existing stream
        mock_js.stream_info.return_value = {"name": "TEST_STREAM"}

        await nats_bus.start()

        assert nats_bus.running
        assert nats_bus.nc == mock_nc
        assert nats_bus.js == mock_js
        assert nats_bus.dlq == mock_dlq_instance
        mock_connect.assert_called_once_with(servers=["nats://localhost:4222"])
        mock_dlq_instance.ensure_dlq_stream_exists.assert_called_once()

    @pytest.mark.asyncio
    @patch("libs.events.nats_bus.nats.connect")
    @patch("libs.events.nats_bus.DLQHandler")
    async def test_start_creates_stream_if_not_exists(
        self, mock_dlq_cls, mock_connect, nats_bus
    ):
        """Test that start creates stream if it doesn't exist."""
        # Mock NATS connection and JetStream
        mock_nc = AsyncMock()
        mock_js = AsyncMock()
        mock_nc.jetstream = MagicMock(return_value=mock_js)
        mock_connect.return_value = mock_nc

        # Mock DLQ handler
        mock_dlq_instance = MagicMock()
        mock_dlq_instance.ensure_dlq_stream_exists = AsyncMock()
        mock_dlq_cls.return_value = mock_dlq_instance

        # Mock stream_info to raise NotFoundError, then add_stream
        from nats.js.errors import NotFoundError

        mock_js.stream_info.side_effect = NotFoundError
        mock_js.add_stream = AsyncMock()

        await nats_bus.start()

        mock_js.add_stream.assert_called_once()
        call_args = mock_js.add_stream.call_args
        # The stream must capture every subject under its own prefix.
        assert call_args[1]["subjects"] == ["TEST_STREAM.>"]

    @pytest.mark.asyncio
    @patch("libs.events.nats_bus.nats.connect")
    async def test_start_already_running(self, mock_connect, nats_bus):
        """Test that start does nothing if already running."""
        nats_bus.running = True
        original_nc = nats_bus.nc

        await nats_bus.start()

        # connect is patched so a regression cannot dial a real server; the
        # assertion pins the early return for an already-running bus.
        mock_connect.assert_not_called()
        assert nats_bus.nc == original_nc

    @pytest.mark.asyncio
    async def test_stop(self, nats_bus):
        """Test stopping the NATS event bus."""
        # Setup mock objects
        mock_nc = AsyncMock()
        mock_subscription = AsyncMock()

        # A real task is used so stop() can cancel/await it like a consumer.
        async def dummy_task():
            pass

        real_task = asyncio.create_task(dummy_task())

        nats_bus.running = True
        nats_bus.nc = mock_nc
        nats_bus.subscriptions = {"test-topic": mock_subscription}
        nats_bus.consumer_tasks = [real_task]

        await nats_bus.stop()

        assert not nats_bus.running
        # A cancelled task is also done, so done() covers both outcomes.
        assert real_task.done()
        mock_subscription.unsubscribe.assert_called_once()
        mock_nc.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_stop_not_running(self, nats_bus):
        """Test that stop does nothing if not running."""
        assert not nats_bus.running
        await nats_bus.stop()
        assert not nats_bus.running

    @pytest.mark.asyncio
    @patch("libs.events.nats_bus.EventMetricsCollector")
    async def test_publish(self, mock_metrics, nats_bus, event_payload):
        """Test publishing an event."""
        # Setup mock JetStream
        mock_js = AsyncMock()
        mock_ack = MagicMock()
        mock_ack.seq = 123
        mock_js.publish.return_value = mock_ack
        nats_bus.js = mock_js

        result = await nats_bus.publish("test-topic", event_payload)

        assert result is True
        mock_js.publish.assert_called_once()
        call_args = mock_js.publish.call_args
        # Topics are namespaced under the stream name.
        assert call_args[1]["subject"] == "TEST_STREAM.test-topic"
        assert call_args[1]["payload"] == event_payload.to_json().encode()

        # Verify metrics recorded
        mock_metrics.record_publish.assert_called_once()
        assert mock_metrics.record_publish.call_args[1]["success"] is True

    @pytest.mark.asyncio
    async def test_publish_not_started(self, nats_bus, event_payload):
        """Test publishing when event bus is not started."""
        with pytest.raises(RuntimeError, match="Event bus not started"):
            await nats_bus.publish("test-topic", event_payload)

    @pytest.mark.asyncio
    @patch("libs.events.nats_bus.EventMetricsCollector")
    async def test_publish_failure(self, mock_metrics, nats_bus, event_payload):
        """Test publishing failure."""
        # Setup mock JetStream that raises exception
        mock_js = AsyncMock()
        mock_js.publish.side_effect = Exception("Publish failed")
        nats_bus.js = mock_js

        result = await nats_bus.publish("test-topic", event_payload)

        # Publish errors are reported via the return value, not raised.
        assert result is False

        # Verify metrics recorded failure
        mock_metrics.record_publish.assert_called_once()
        assert mock_metrics.record_publish.call_args[1]["success"] is False

    @pytest.mark.asyncio
    async def test_subscribe(self, nats_bus):
        """Test subscribing to a topic."""
        # Setup mock JetStream
        mock_js = AsyncMock()
        mock_subscription = AsyncMock()
        mock_js.pull_subscribe.return_value = mock_subscription
        nats_bus.js = mock_js

        # Mock handler
        async def test_handler(topic: str, payload: EventPayload) -> None:
            pass

        # Replace create_task so the consumer loop is never actually started;
        # the helper closes the coroutine to avoid a "never awaited" warning.
        with patch(
            "asyncio.create_task", side_effect=self._discard_coroutine
        ) as mock_create_task:
            await nats_bus.subscribe("test-topic", test_handler)

            assert "test-topic" in nats_bus.handlers
            assert test_handler in nats_bus.handlers["test-topic"]
            assert "test-topic" in nats_bus.subscriptions
            mock_js.pull_subscribe.assert_called_once()

            # Verify ConsumerConfig
            call_kwargs = mock_js.pull_subscribe.call_args[1]
            config = call_kwargs["config"]
            assert isinstance(config, ConsumerConfig)
            assert config.max_deliver == 5  # 3 retries + 2 buffer

            mock_create_task.assert_called_once()

    @pytest.mark.asyncio
    async def test_subscribe_not_started(self, nats_bus):
        """Test subscribing when event bus is not started."""

        async def test_handler(topic: str, payload: EventPayload) -> None:
            pass

        with pytest.raises(RuntimeError, match="Event bus not started"):
            await nats_bus.subscribe("test-topic", test_handler)

    @pytest.mark.asyncio
    async def test_subscribe_multiple_handlers(self, nats_bus):
        """Test subscribing multiple handlers to the same topic."""
        # Setup mock JetStream
        mock_js = AsyncMock()
        mock_subscription = AsyncMock()
        mock_js.pull_subscribe.return_value = mock_subscription
        nats_bus.js = mock_js

        # Mock handlers
        async def handler1(topic: str, payload: EventPayload) -> None:
            pass

        async def handler2(topic: str, payload: EventPayload) -> None:
            pass

        # Same create_task replacement as test_subscribe: no real consumer
        # loop runs, and the passed coroutine is closed rather than leaked.
        with patch("asyncio.create_task", side_effect=self._discard_coroutine):
            await nats_bus.subscribe("test-topic", handler1)
            await nats_bus.subscribe("test-topic", handler2)

        assert len(nats_bus.handlers["test-topic"]) == 2
        assert handler1 in nats_bus.handlers["test-topic"]
        assert handler2 in nats_bus.handlers["test-topic"]

    @pytest.mark.asyncio
    @patch("libs.events.nats_bus.EventMetricsCollector")
    async def test_consume_messages(self, mock_metrics, nats_bus, event_payload):
        """Test consuming messages from NATS."""
        # Setup mock subscription and message
        mock_subscription = AsyncMock()
        mock_message = MagicMock()
        mock_message.data.decode.return_value = event_payload.to_json()
        mock_message.ack = AsyncMock()

        # Serve exactly one batch, then clear ``running`` on the next fetch so
        # the consume loop exits. Patching ``running`` with a side_effect mock
        # does not work: the replacement MagicMock is always truthy and its
        # side_effect only fires when it is *called*, so the loop never ends.
        batches = iter([[mock_message]])

        def fetch_next_batch(*args, **kwargs):
            batch = next(batches, None)
            if batch is None:
                nats_bus.running = False
                return []
            return batch

        mock_subscription.fetch.side_effect = fetch_next_batch

        nats_bus.running = True

        # Mock handler
        handler_called = False
        received_topic = None
        received_payload = None

        async def test_handler(topic: str, payload: EventPayload) -> None:
            nonlocal handler_called, received_topic, received_payload
            handler_called = True
            received_topic = topic
            received_payload = payload

        nats_bus.handlers["test-topic"] = [test_handler]

        await nats_bus._consume_messages("test-topic", mock_subscription)

        assert handler_called
        assert received_topic == "test-topic"
        assert received_payload.event_id == event_payload.event_id
        mock_message.ack.assert_called_once()

        # Verify metrics
        mock_metrics.record_consume.assert_called_once()
        assert mock_metrics.record_consume.call_args[1]["success"] is True

    @pytest.mark.asyncio
    async def test_factory_integration(self):
        """Test that the factory can create a NATS event bus."""
        from libs.events.factory import create_event_bus

        bus = create_event_bus(
            "nats",
            servers="nats://localhost:4222",
            stream_name="TEST_STREAM",
            consumer_group="test-group",
        )

        assert isinstance(bus, NATSEventBus)
        assert bus.servers == ["nats://localhost:4222"]
        assert bus.stream_name == "TEST_STREAM"
        assert bus.consumer_group == "test-group"