Completed local setup with Docker Compose
Some checks failed
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled

Commit fdba81809f by harkon on 2025-11-26 13:17:17 +00:00 (parent 8fe5e62fee)
87 changed files with 5610 additions and 3376 deletions


@@ -0,0 +1,76 @@
import asyncio
import httpx
import pytest
from libs.events import EventTopics, NATSEventBus
from libs.schemas.events import DocumentExtractedEventData
# Configuration
INGESTION_URL = "http://localhost:8000"
NATS_URL = "nats://localhost:4222"
TENANT_ID = "tenant_e2e_test"
@pytest.mark.e2e
@pytest.mark.asyncio
async def test_backend_journey():
"""
E2E test for the full backend journey: Ingest -> OCR -> Extract.
"""
# 1. Initialize NATS bus
bus = NATSEventBus(
servers=[NATS_URL],
stream_name="TAX_AGENT_EVENTS",
consumer_group="e2e-test-consumer",
)
await bus.start()
# Future to capture the final event
extraction_future = asyncio.Future()
async def extraction_handler(topic, payload):
if payload.tenant_id == TENANT_ID and not extraction_future.done():
extraction_future.set_result(payload)
# Subscribe to the final event in the chain
await bus.subscribe(EventTopics.DOC_EXTRACTED, extraction_handler)
try:
# 2. Upload a document
async with httpx.AsyncClient() as client:
# Create a dummy PDF file
files = {"file": ("test.pdf", b"%PDF-1.4 mock content", "application/pdf")}
response = await client.post(
f"{INGESTION_URL}/upload",
files=files,
data={"kind": "invoice", "source": "e2e_test"},
headers={"X-Tenant-ID": TENANT_ID, "X-User-ID": "e2e_tester"},
)
assert response.status_code == 200, f"Upload failed: {response.text}"
upload_data = response.json()
doc_id = upload_data["doc_id"]
print(f"Uploaded document: {doc_id}")
# 3. Wait for extraction event (with timeout)
try:
# Give it enough time for the whole chain to process
payload = await asyncio.wait_for(extraction_future, timeout=30.0)
# 4. Verify payload
data = payload.data
assert data["doc_id"] == doc_id
assert data["tenant_id"] == TENANT_ID
assert "extraction_results" in data
# Validate against schema
event_data = DocumentExtractedEventData(**data)
assert event_data.doc_id == doc_id
print("E2E Journey completed successfully!")
except asyncio.TimeoutError:
pytest.fail("Timed out waiting for extraction event")
finally:
await bus.stop()
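This journey test assumes the Compose stack is already running (svc-ingestion on localhost:8000, NATS on localhost:4222) and fails at the upload step otherwise. A small guard in the spirit of the is_nats_available helper used in the integration tests below could skip it cleanly instead; the snippet is only a sketch, and the /health endpoint it probes is an assumption, not something this commit defines.

# Sketch (not part of this commit): skip the journey test when the stack is down.
async def services_available() -> bool:
    try:
        async with httpx.AsyncClient(timeout=2.0) as client:
            resp = await client.get(f"{INGESTION_URL}/health")  # assumed endpoint
            return resp.status_code < 500
    except httpx.HTTPError:
        return False

# Usage, placed before the upload step in test_backend_journey:
#     if not await services_available():
#         pytest.skip("Compose stack not running; skipping e2e journey")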


@@ -0,0 +1,39 @@
import pytest
from libs.events import EventTopics
from libs.schemas.events import DocumentIngestedEventData, validate_event_data
@pytest.mark.integration
def test_doc_ingested_contract():
"""
Contract test for DOC_INGESTED event.
Verifies that the event data schema matches the expected Pydantic model.
"""
# Sample valid payload data
valid_data = {
"doc_id": "doc_01H1V2W3X4Y5Z6",
"filename": "test.pdf",
"kind": "invoice",
"source": "upload",
"checksum_sha256": "a" * 64,
"size_bytes": 1024,
"mime_type": "application/pdf",
"storage_path": "s3://bucket/key.pdf",
}
# 1. Verify it validates against the Pydantic model directly
model = DocumentIngestedEventData(**valid_data)
assert model.doc_id == valid_data["doc_id"]
# 2. Verify it validates using the shared validation utility
validated_model = validate_event_data(EventTopics.DOC_INGESTED, valid_data)
assert isinstance(validated_model, DocumentIngestedEventData)
assert validated_model.doc_id == valid_data["doc_id"]
# 3. Verify invalid data fails
invalid_data = valid_data.copy()
del invalid_data["doc_id"]
with pytest.raises(ValueError):
validate_event_data(EventTopics.DOC_INGESTED, invalid_data)
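For orientation, the two helpers exercised here amount to a topic-to-model lookup over EVENT_SCHEMA_MAP (the mapping imported in the schema tests further down). The sketch below captures only what these tests pin about their behaviour; the real implementation lives in libs/schemas/events and may differ.

# Sketch (not part of this commit) of the lookup behaviour the contract test relies on.
def get_schema_for_topic_sketch(topic: str):
    try:
        return EVENT_SCHEMA_MAP[topic]
    except KeyError:
        raise ValueError(f"Unknown event topic: {topic}") from None

def validate_event_data_sketch(topic: str, data: dict):
    # Pydantic's ValidationError subclasses ValueError, which is why the
    # missing-doc_id case above is caught by pytest.raises(ValueError).
    return get_schema_for_topic_sketch(topic)(**data)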


@@ -0,0 +1,98 @@
import asyncio
import nats  # needed by the stream cleanup at the end of the test
import pytest
from libs.events.base import EventPayload
from libs.events.nats_bus import NATSEventBus
from libs.schemas.events import DocumentIngestedEventData
@pytest.mark.asyncio
async def test_nats_bus_class():
"""Test NATSEventBus class within pytest."""
import time
unique_suffix = int(time.time())
stream_name = f"PYTEST_DEBUG_STREAM_{unique_suffix}"
print(f"\nStarting NATSEventBus with stream {stream_name}...")
bus = NATSEventBus(
servers="nats://localhost:4222",
stream_name=stream_name,
consumer_group="test-debug-group",
)
await bus.start()
print("Bus started.")
# Clean up (just in case)
try:
await bus.js.delete_stream(stream_name)
except Exception:
pass
await bus._ensure_stream_exists()
# Wait for stream to be ready
await asyncio.sleep(2)
try:
info = await bus.js.stream_info(stream_name)
print(f"Stream info: {info.config.subjects}")
except Exception as e:
print(f"Failed to get stream info: {e}")
# Setup subscriber
received_event = asyncio.Future()
async def handler(topic, event):
print(f"Handler received event: {event.event_id}")
if not received_event.done():
received_event.set_result(event)
await bus.subscribe("doc.ingested", handler)
print("Publishing message...")
data = DocumentIngestedEventData(
doc_id="test-doc-123",
filename="test.pdf",
mime_type="application/pdf",
size_bytes=1024,
source="upload",
kind="invoice",
storage_path="s3://test-bucket/test.pdf",
checksum_sha256="a" * 64,
)
payload = EventPayload(
data=data.model_dump(mode="json"),
actor="tester",
tenant_id="tenant-1",
schema_version="1.0",
)
payload.event_id = "evt-debug-1"
success = await bus.publish("doc.ingested", payload)
print(f"Published: {success}")
try:
result = await asyncio.wait_for(received_event, timeout=5.0)
print(f"Received event: {result.event_id}")
assert result.event_id == "evt-debug-1"
assert result.data["doc_id"] == "test-doc-123"
except asyncio.TimeoutError:
print("Timeout waiting for event")
raise
await bus.stop()
print("Bus stopped.")
# Cleanup stream
try:
nc = await nats.connect("nats://localhost:4222")
js = nc.jetstream()
await js.delete_stream(stream_name)
await nc.close()
except Exception:
pass
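EventPayload itself is not shown anywhere in this diff; the dataclass below is a rough sketch of the envelope shape implied by how the tests construct it (keyword fields, an auto-generated event_id that can be overwritten, and to_json() for the wire format). The real class in libs/events/base may differ.

# Sketch (not part of this commit) of the envelope shape the tests rely on.
import json
import uuid
from dataclasses import dataclass, field
from typing import Any

@dataclass
class EventPayloadSketch:
    data: dict[str, Any]
    actor: str
    tenant_id: str
    schema_version: str = "1.0"
    trace_id: str | None = None
    event_id: str = field(default_factory=lambda: f"evt-{uuid.uuid4().hex}")

    def to_json(self) -> str:
        return json.dumps(self.__dict__)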


@@ -0,0 +1,240 @@
import asyncio
import json
import pytest
import pytest_asyncio
from libs.events.base import EventPayload
from libs.events.nats_bus import NATSEventBus
from libs.schemas.events import DocumentIngestedEventData
# Check if NATS is available
async def is_nats_available():
import nats
try:
nc = await nats.connect("nats://localhost:4222")
await nc.close()
return True
except Exception:
return False
@pytest_asyncio.fixture
async def nats_bus():
"""Create and start a NATS event bus for testing."""
if not await is_nats_available():
pytest.skip("NATS server not available at localhost:4222")
bus = NATSEventBus(
servers="nats://localhost:4222",
stream_name="TEST_INTEGRATION_STREAM",
consumer_group="test-integration-group",
dlq_stream_name="TEST_INTEGRATION_DLQ",
max_retries=2,
)
await bus.start()
# Clean up streams before test
try:
await bus.js.delete_stream("TEST_INTEGRATION_STREAM")
await bus.js.delete_stream("TEST_INTEGRATION_DLQ")
except Exception:
pass
# Re-create streams
await bus._ensure_stream_exists()
await bus.dlq.ensure_dlq_stream_exists()
# Allow time for streams to propagate
await asyncio.sleep(2)
yield bus
# Clean up after test
try:
await bus.js.delete_stream("TEST_INTEGRATION_STREAM")
await bus.js.delete_stream("TEST_INTEGRATION_DLQ")
except Exception:
pass
await bus.stop()
@pytest.mark.integration
@pytest.mark.asyncio
async def test_publish_subscribe_flow():
"""Test end-to-end publish and subscribe flow."""
# Instantiate bus directly to debug fixture issues
bus = NATSEventBus(
servers="nats://localhost:4222",
stream_name="TEST_INTEGRATION_STREAM_DIRECT",
consumer_group="test-integration-group-direct",
dlq_stream_name="TEST_INTEGRATION_DLQ_DIRECT",
max_retries=2,
)
await bus.start()
try:
await bus.js.delete_stream("TEST_INTEGRATION_STREAM_DIRECT")
except Exception:
pass
await bus._ensure_stream_exists()
try:
# Create event data
data = DocumentIngestedEventData(
doc_id="test-doc-123",
filename="test.pdf",
mime_type="application/pdf",
size_bytes=1024,
source="upload",
kind="invoice",
storage_path="s3://test-bucket/test.pdf",
checksum_sha256="a" * 64,
)
payload = EventPayload(
data=data.model_dump(mode="json"),
actor="test-user",
tenant_id="test-tenant",
trace_id="trace-123",
schema_version="1.0",
)
payload.event_id = "evt-123"
# Setup subscriber
received_event = asyncio.Future()
async def handler(topic, event):
if not received_event.done():
received_event.set_result(event)
await bus.subscribe("doc.ingested", handler)
# Publish event
success = await bus.publish("doc.ingested", payload)
assert success is True
# Wait for reception
try:
result = await asyncio.wait_for(received_event, timeout=5.0)
assert result.event_id == payload.event_id
assert result.data["doc_id"] == "test-doc-123"
except asyncio.TimeoutError:
pytest.fail("Event not received within timeout")
finally:
await bus.stop()
@pytest.mark.integration
@pytest.mark.asyncio
async def test_dlq_routing(nats_bus):
"""Test that failed events are routed to DLQ after retries."""
# Create event data
data = DocumentIngestedEventData(
doc_id="test-doc-fail",
filename="fail.pdf",
mime_type="application/pdf",
size_bytes=1024,
source="upload",
kind="invoice",
storage_path="s3://test-bucket/fail.pdf",
checksum_sha256="a" * 64,
)
payload = EventPayload(
data=data.model_dump(mode="json"),
actor="test-user",
tenant_id="test-tenant",
trace_id="trace-fail",
schema_version="1.0",
)
# Setup failing handler
failure_count = 0
async def failing_handler(topic, event):
nonlocal failure_count
failure_count += 1
raise ValueError("Simulated processing failure")
await nats_bus.subscribe("doc.fail", failing_handler)
# Publish event
await nats_bus.publish("doc.fail", payload)
# Wait for retries and DLQ routing
await asyncio.sleep(2.0) # Wait for processing
assert failure_count >= 2
# Consume from DLQ to verify
dlq_sub = await nats_bus.js.pull_subscribe(
subject="TEST_INTEGRATION_DLQ.doc.fail", durable="test-dlq-consumer"
)
msgs = await dlq_sub.fetch(batch=1, timeout=5.0)
assert len(msgs) == 1
dlq_msg = msgs[0]
dlq_data = json.loads(dlq_msg.data.decode())
assert dlq_data["original_payload"]["event_id"] == payload.event_id
assert dlq_data["error"]["type"] == "ValueError"
assert dlq_data["error"]["message"] == "Simulated processing failure"
await dlq_msg.ack()
@pytest.mark.integration
@pytest.mark.asyncio
async def test_metrics_recording(nats_bus):
"""Test that metrics are recorded during event processing."""
from libs.events.metrics import event_consumed_total, event_published_total
# Get initial values
initial_published = event_published_total.labels(topic="doc.metrics")._value.get()
initial_consumed = event_consumed_total.labels(
topic="doc.metrics", consumer_group="test-integration-group"
)._value.get()
# Create and publish event
data = DocumentIngestedEventData(
doc_id="test-doc-metrics",
filename="metrics.pdf",
mime_type="application/pdf",
size_bytes=1024,
source="upload",
kind="invoice",
storage_path="s3://test-bucket/metrics.pdf",
checksum_sha256="a" * 64,
)
payload = EventPayload(
data=data.model_dump(mode="json"),
actor="test-user",
tenant_id="test-tenant",
trace_id="trace-metrics",
schema_version="1.0",
)
received_event = asyncio.Future()
async def handler(topic, event):
if not received_event.done():
received_event.set_result(event)
await nats_bus.subscribe("doc.metrics", handler)
await nats_bus.publish("doc.metrics", payload)
await asyncio.wait_for(received_event, timeout=5.0)
# Check metrics increased
final_published = event_published_total.labels(topic="doc.metrics")._value.get()
final_consumed = event_consumed_total.labels(
topic="doc.metrics", consumer_group="test-integration-group"
)._value.get()
assert final_published > initial_published
assert final_consumed > initial_consumed
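test_dlq_routing fixes the DLQ contract: once the handler has failed max_retries times, a JSON record is published on <dlq_stream>.<topic>. The literal below shows that record as these assertions (and the unit tests in tests/unit/test_dlq.py) read it; the values are illustrative and any keys not asserted on anywhere are guesses.

# Sketch (not part of this commit): a DLQ record as the tests consume it.
dlq_record_example = {
    "original_topic": "doc.fail",
    "retry_count": 2,
    "original_payload": {
        "event_id": "<auto-generated event id>",
        "tenant_id": "test-tenant",
        "data": {"doc_id": "test-doc-fail"},
    },
    "error": {"type": "ValueError", "message": "Simulated processing failure"},
    # present only when the raw NATS message bytes were captured:
    "original_message_data": '{"test": "original"}',
}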

tests/unit/test_dlq.py (new file, 317 lines)

@@ -0,0 +1,317 @@
"""Tests for Dead Letter Queue (DLQ) handler."""
import json
from unittest.mock import AsyncMock, patch
import pytest
from libs.events.base import EventPayload
from libs.events.dlq import DLQHandler, DLQMetrics
@pytest.fixture
def event_payload():
"""Create a test event payload."""
return EventPayload(
data={"test": "data", "value": 123},
actor="test-user",
tenant_id="test-tenant",
trace_id="test-trace-123",
schema_version="1.0",
)
@pytest.fixture
def mock_js():
"""Create a mock JetStream context."""
js = AsyncMock()
js.stream_info = AsyncMock()
js.add_stream = AsyncMock()
js.publish = AsyncMock()
return js
class TestDLQHandler:
"""Test cases for DLQ handler."""
@pytest.mark.asyncio
async def test_initialization(self, mock_js):
"""Test DLQ handler initialization."""
handler = DLQHandler(
js=mock_js,
dlq_stream_name="TEST_DLQ",
max_retries=5,
backoff_base_ms=500,
)
assert handler.js == mock_js
assert handler.dlq_stream_name == "TEST_DLQ"
assert handler.max_retries == 5
assert handler.backoff_base_ms == 500
@pytest.mark.asyncio
async def test_ensure_dlq_stream_exists_already_exists(self, mock_js):
"""Test ensuring DLQ stream when it already exists."""
mock_js.stream_info.return_value = {"name": "TEST_DLQ"}
handler = DLQHandler(js=mock_js, dlq_stream_name="TEST_DLQ")
await handler.ensure_dlq_stream_exists()
mock_js.stream_info.assert_called_once_with("TEST_DLQ")
mock_js.add_stream.assert_not_called()
@pytest.mark.asyncio
async def test_ensure_dlq_stream_creates_stream(self, mock_js):
"""Test ensuring DLQ stream when it doesn't exist."""
from nats.js.errors import NotFoundError
mock_js.stream_info.side_effect = NotFoundError
mock_js.add_stream = AsyncMock()
handler = DLQHandler(js=mock_js, dlq_stream_name="TEST_DLQ")
await handler.ensure_dlq_stream_exists()
mock_js.add_stream.assert_called_once()
call_kwargs = mock_js.add_stream.call_args[1]
assert call_kwargs["name"] == "TEST_DLQ"
assert call_kwargs["subjects"] == ["TEST_DLQ.*"]
@pytest.mark.asyncio
async def test_send_to_dlq(self, mock_js, event_payload):
"""Test sending event to DLQ."""
handler = DLQHandler(js=mock_js)
error = ValueError("Test error message")
await handler.send_to_dlq(
topic="test-topic",
payload=event_payload,
error=error,
retry_count=3,
)
mock_js.publish.assert_called_once()
call_kwargs = mock_js.publish.call_args[1]
# Verify subject
assert call_kwargs["subject"] == "TAX_AGENT_DLQ.test-topic"
# Verify payload content
payload_data = json.loads(call_kwargs["payload"].decode())
assert payload_data["original_topic"] == "test-topic"
assert payload_data["retry_count"] == 3
assert payload_data["error"]["type"] == "ValueError"
assert payload_data["error"]["message"] == "Test error message"
# Verify headers
headers = call_kwargs["headers"]
assert headers["original_topic"] == "test-topic"
assert headers["event_id"] == event_payload.event_id
assert headers["error_type"] == "ValueError"
@pytest.mark.asyncio
async def test_send_to_dlq_with_original_message(self, mock_js, event_payload):
"""Test sending event to DLQ with original message data."""
handler = DLQHandler(js=mock_js)
original_message = b'{"test": "original"}'
error = RuntimeError("Processing failed")
await handler.send_to_dlq(
topic="test-topic",
payload=event_payload,
error=error,
retry_count=2,
original_message_data=original_message,
)
call_kwargs = mock_js.publish.call_args[1]
payload_data = json.loads(call_kwargs["payload"].decode())
assert "original_message_data" in payload_data
assert payload_data["original_message_data"] == '{"test": "original"}'
@pytest.mark.asyncio
async def test_send_to_dlq_handles_publish_failure(self, mock_js, event_payload):
"""Test DLQ handler when DLQ publish fails."""
mock_js.publish.side_effect = Exception("DLQ publish failed")
handler = DLQHandler(js=mock_js)
# Should not raise, but log critical error
await handler.send_to_dlq(
topic="test-topic",
payload=event_payload,
error=ValueError("Original error"),
retry_count=1,
)
# Verify publish was attempted
mock_js.publish.assert_called_once()
def test_calculate_backoff(self, mock_js):
"""Test exponential backoff calculation."""
handler = DLQHandler(
js=mock_js,
backoff_base_ms=1000,
backoff_multiplier=2.0,
backoff_max_ms=10000,
)
# First retry: 1000ms * 2^0 = 1000ms = 1s
assert handler.calculate_backoff(0) == 1.0
# Second retry: 1000ms * 2^1 = 2000ms = 2s
assert handler.calculate_backoff(1) == 2.0
# Third retry: 1000ms * 2^2 = 4000ms = 4s
assert handler.calculate_backoff(2) == 4.0
# Fourth retry: 1000ms * 2^3 = 8000ms = 8s
assert handler.calculate_backoff(3) == 8.0
# Fifth retry: would be 16000ms but capped at 10000ms = 10s
assert handler.calculate_backoff(4) == 10.0
@pytest.mark.asyncio
async def test_retry_with_backoff_success_first_attempt(self, mock_js):
"""Test successful operation on first attempt."""
handler = DLQHandler(js=mock_js, max_retries=3)
async def successful_func():
return "success"
success, error = await handler.retry_with_backoff(successful_func)
assert success is True
assert error is None
@pytest.mark.asyncio
async def test_retry_with_backoff_success_after_retries(self, mock_js):
"""Test successful operation after retries."""
handler = DLQHandler(
js=mock_js,
max_retries=3,
backoff_base_ms=100, # Short backoff for testing
)
attempt_count = 0
async def flaky_func():
nonlocal attempt_count
attempt_count += 1
if attempt_count < 3:
raise ValueError(f"Fail attempt {attempt_count}")
return "success"
with patch("asyncio.sleep", new=AsyncMock()): # Speed up test
success, error = await handler.retry_with_backoff(flaky_func)
assert success is True
assert error is None
assert attempt_count == 3
@pytest.mark.asyncio
async def test_retry_with_backoff_all_attempts_fail(self, mock_js):
"""Test operation that fails all retry attempts."""
handler = DLQHandler(
js=mock_js,
max_retries=2,
backoff_base_ms=100,
)
async def always_fails():
raise ValueError("Always fails")
with patch("asyncio.sleep", new=AsyncMock()): # Speed up test
success, error = await handler.retry_with_backoff(always_fails)
assert success is False
assert isinstance(error, ValueError)
assert str(error) == "Always fails"
@pytest.mark.asyncio
async def test_retry_with_backoff_applies_delay(self, mock_js):
"""Test that retry applies backoff delay."""
handler = DLQHandler(
js=mock_js,
max_retries=2,
backoff_base_ms=1000,
backoff_multiplier=2.0,
)
attempt_count = 0
async def failing_func():
nonlocal attempt_count
attempt_count += 1
raise ValueError("Fail")
with patch("asyncio.sleep", new=AsyncMock()) as mock_sleep:
await handler.retry_with_backoff(failing_func)
# Should have called sleep twice (after 1st and 2nd failures)
assert mock_sleep.call_count == 2
# Verify backoff delays
calls = mock_sleep.call_args_list
assert calls[0][0][0] == 1.0 # First retry: 1s
assert calls[1][0][0] == 2.0 # Second retry: 2s
class TestDLQMetrics:
"""Test cases for DLQ metrics."""
def test_initialization(self):
"""Test metrics initialization."""
metrics = DLQMetrics()
assert metrics.total_dlq_events == 0
assert len(metrics.dlq_events_by_topic) == 0
assert len(metrics.dlq_events_by_error_type) == 0
def test_record_dlq_event(self):
"""Test recording DLQ events."""
metrics = DLQMetrics()
metrics.record_dlq_event("topic1", "ValueError")
metrics.record_dlq_event("topic1", "ValueError")
metrics.record_dlq_event("topic2", "RuntimeError")
assert metrics.total_dlq_events == 3
assert metrics.dlq_events_by_topic["topic1"] == 2
assert metrics.dlq_events_by_topic["topic2"] == 1
assert metrics.dlq_events_by_error_type["ValueError"] == 2
assert metrics.dlq_events_by_error_type["RuntimeError"] == 1
def test_get_metrics(self):
"""Test getting metrics snapshot."""
metrics = DLQMetrics()
metrics.record_dlq_event("topic1", "ValueError")
metrics.record_dlq_event("topic1", "RuntimeError")
snapshot = metrics.get_metrics()
assert snapshot["total_dlq_events"] == 2
assert snapshot["by_topic"]["topic1"] == 2
assert snapshot["by_error_type"]["ValueError"] == 1
assert snapshot["by_error_type"]["RuntimeError"] == 1
# Verify it's a copy, not a reference
snapshot["total_dlq_events"] = 999
assert metrics.total_dlq_events == 2
def test_reset(self):
"""Test resetting metrics."""
metrics = DLQMetrics()
metrics.record_dlq_event("topic1", "ValueError")
metrics.record_dlq_event("topic2", "RuntimeError")
assert metrics.total_dlq_events == 2
metrics.reset()
assert metrics.total_dlq_events == 0
assert len(metrics.dlq_events_by_topic) == 0
assert len(metrics.dlq_events_by_error_type) == 0
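The backoff assertions above pin the formula down completely: the delay grows geometrically from backoff_base_ms, is capped at backoff_max_ms, and is returned in seconds. A short sketch consistent with the expected values (1s, 2s, 4s, 8s, then capped at 10s):

# Sketch (not part of this commit) of the backoff the tests assert on.
def calculate_backoff_sketch(retry_count: int, base_ms: int = 1000,
                             multiplier: float = 2.0, max_ms: int = 10000) -> float:
    """Seconds to wait before the (retry_count + 1)-th retry."""
    return min(base_ms * (multiplier ** retry_count), max_ms) / 1000.0

assert [calculate_backoff_sketch(n) for n in range(5)] == [1.0, 2.0, 4.0, 8.0, 10.0]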


@@ -0,0 +1,274 @@
"""Tests for event metrics."""
from unittest.mock import MagicMock, patch
from libs.events.metrics import (
EventMetricsCollector,
event_consumed_total,
event_dlq_total,
event_processing_duration_seconds,
event_processing_errors_total,
event_publish_errors_total,
event_published_total,
event_publishing_duration_seconds,
event_retry_total,
event_schema_validation_errors_total,
get_event_metrics_registry,
nats_consumer_lag_messages,
nats_stream_messages_total,
)
class TestEventMetrics:
"""Test cases for event metrics."""
def test_get_event_metrics_registry(self) -> None:
"""Test getting the metrics registry."""
registry = get_event_metrics_registry()
assert registry is not None
def test_metrics_exist(self) -> None:
"""Test that all expected metrics are defined."""
# Publishing metrics
assert event_published_total is not None
assert event_publish_errors_total is not None
assert event_publishing_duration_seconds is not None
# Consumption metrics
assert event_consumed_total is not None
assert event_processing_duration_seconds is not None
assert event_processing_errors_total is not None
# DLQ metrics
assert event_dlq_total is not None
assert event_retry_total is not None
# Schema validation metrics
assert event_schema_validation_errors_total is not None
# NATS metrics
assert nats_stream_messages_total is not None
assert nats_consumer_lag_messages is not None
class TestEventMetricsCollector:
"""Test cases for EventMetricsCollector."""
def test_record_publish_success(self) -> None:
"""Test recording successful publish."""
with patch.object(event_published_total, "labels") as mock_labels:
mock_counter = MagicMock()
mock_labels.return_value = mock_counter
EventMetricsCollector.record_publish(
topic="test.topic",
duration_seconds=0.05,
success=True,
)
mock_labels.assert_called_once_with(topic="test.topic")
mock_counter.inc.assert_called_once()
def test_record_publish_failure(self) -> None:
"""Test recording failed publish."""
with patch.object(event_publish_errors_total, "labels") as mock_labels:
mock_counter = MagicMock()
mock_labels.return_value = mock_counter
EventMetricsCollector.record_publish(
topic="test.topic",
duration_seconds=0.1,
success=False,
error_type="ConnectionError",
)
mock_labels.assert_called_once_with(
topic="test.topic", error_type="ConnectionError"
)
mock_counter.inc.assert_called_once()
def test_record_publish_duration(self) -> None:
"""Test recording publish duration."""
with patch.object(event_publishing_duration_seconds, "labels") as mock_labels:
mock_histogram = MagicMock()
mock_labels.return_value = mock_histogram
duration = 0.123
EventMetricsCollector.record_publish(
topic="test.topic",
duration_seconds=duration,
success=True,
)
mock_labels.assert_called_once_with(topic="test.topic")
mock_histogram.observe.assert_called_once_with(duration)
def test_record_consume_success(self) -> None:
"""Test recording successful event consumption."""
with patch.object(event_consumed_total, "labels") as mock_labels:
mock_counter = MagicMock()
mock_labels.return_value = mock_counter
EventMetricsCollector.record_consume(
topic="test.topic",
consumer_group="test-group",
duration_seconds=0.5,
success=True,
)
mock_labels.assert_called_once_with(
topic="test.topic", consumer_group="test-group"
)
mock_counter.inc.assert_called_once()
def test_record_consume_failure(self) -> None:
"""Test recording failed event consumption."""
with patch.object(event_processing_errors_total, "labels") as mock_labels:
mock_counter = MagicMock()
mock_labels.return_value = mock_counter
EventMetricsCollector.record_consume(
topic="test.topic",
consumer_group="test-group",
duration_seconds=1.0,
success=False,
error_type="ValidationError",
)
mock_labels.assert_called_once_with(
topic="test.topic",
consumer_group="test-group",
error_type="ValidationError",
)
mock_counter.inc.assert_called_once()
def test_record_consume_duration(self) -> None:
"""Test recording consumption duration."""
with patch.object(event_processing_duration_seconds, "labels") as mock_labels:
mock_histogram = MagicMock()
mock_labels.return_value = mock_histogram
duration = 2.5
EventMetricsCollector.record_consume(
topic="test.topic",
consumer_group="test-group",
duration_seconds=duration,
success=True,
)
mock_labels.assert_called_once_with(
topic="test.topic", consumer_group="test-group"
)
mock_histogram.observe.assert_called_once_with(duration)
def test_record_dlq(self) -> None:
"""Test recording DLQ event."""
with patch.object(event_dlq_total, "labels") as mock_labels:
mock_counter = MagicMock()
mock_labels.return_value = mock_counter
EventMetricsCollector.record_dlq(
topic="test.topic", error_type="TimeoutError"
)
mock_labels.assert_called_once_with(
topic="test.topic", error_type="TimeoutError"
)
mock_counter.inc.assert_called_once()
def test_record_retry(self) -> None:
"""Test recording retry attempt."""
with patch.object(event_retry_total, "labels") as mock_labels:
mock_counter = MagicMock()
mock_labels.return_value = mock_counter
EventMetricsCollector.record_retry(topic="test.topic", retry_attempt=2)
mock_labels.assert_called_once_with(topic="test.topic", retry_attempt="2")
mock_counter.inc.assert_called_once()
def test_record_schema_validation_error(self) -> None:
"""Test recording schema validation error."""
with patch.object(
event_schema_validation_errors_total, "labels"
) as mock_labels:
mock_counter = MagicMock()
mock_labels.return_value = mock_counter
EventMetricsCollector.record_schema_validation_error(
topic="test.topic", validation_error="missing_required_field"
)
mock_labels.assert_called_once_with(
topic="test.topic", validation_error="missing_required_field"
)
mock_counter.inc.assert_called_once()
def test_record_nats_stream_message(self) -> None:
"""Test recording NATS stream message."""
with patch.object(nats_stream_messages_total, "labels") as mock_labels:
mock_counter = MagicMock()
mock_labels.return_value = mock_counter
EventMetricsCollector.record_nats_stream_message(
stream_name="TAX_AGENT_EVENTS"
)
mock_labels.assert_called_once_with(stream_name="TAX_AGENT_EVENTS")
mock_counter.inc.assert_called_once()
def test_record_consumer_lag(self) -> None:
"""Test recording consumer lag."""
with patch.object(nats_consumer_lag_messages, "labels") as mock_labels:
mock_histogram = MagicMock()
mock_labels.return_value = mock_histogram
EventMetricsCollector.record_consumer_lag(
stream_name="TAX_AGENT_EVENTS",
consumer_group="tax-agent",
lag_messages=150,
)
mock_labels.assert_called_once_with(
stream_name="TAX_AGENT_EVENTS", consumer_group="tax-agent"
)
mock_histogram.observe.assert_called_once_with(150)
def test_record_publish_with_default_error_type(self) -> None:
"""Test recording publish failure with default error type."""
with patch.object(event_publish_errors_total, "labels") as mock_labels:
mock_counter = MagicMock()
mock_labels.return_value = mock_counter
EventMetricsCollector.record_publish(
topic="test.topic",
duration_seconds=0.1,
success=False,
error_type=None, # No error type provided
)
mock_labels.assert_called_once_with(
topic="test.topic", error_type="unknown" # Should default to "unknown"
)
mock_counter.inc.assert_called_once()
def test_record_consume_with_default_error_type(self) -> None:
"""Test recording consume failure with default error type."""
with patch.object(event_processing_errors_total, "labels") as mock_labels:
mock_counter = MagicMock()
mock_labels.return_value = mock_counter
EventMetricsCollector.record_consume(
topic="test.topic",
consumer_group="test-group",
duration_seconds=1.0,
success=False,
error_type=None, # No error type provided
)
mock_labels.assert_called_once_with(
topic="test.topic",
consumer_group="test-group",
error_type="unknown", # Should default to "unknown"
)
mock_counter.inc.assert_called_once()
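Taken together, the publish-side tests fix the labelling contract for EventMetricsCollector.record_publish: on success, increment event_published_total{topic} and observe the publish-duration histogram; on failure, increment event_publish_errors_total{topic, error_type} with error_type falling back to "unknown". A sketch of that logic follows; whether the duration histogram is also observed on failure is not asserted here, so it is left out.

# Sketch (not part of this commit) of the behaviour the publish-metric tests assert.
def record_publish_sketch(topic: str, duration_seconds: float,
                          success: bool, error_type: str | None = None) -> None:
    # event_published_total, event_publish_errors_total and
    # event_publishing_duration_seconds are the Prometheus metrics imported
    # from libs.events.metrics at the top of this file.
    if success:
        event_published_total.labels(topic=topic).inc()
        event_publishing_duration_seconds.labels(topic=topic).observe(duration_seconds)
    else:
        event_publish_errors_total.labels(
            topic=topic, error_type=error_type or "unknown"
        ).inc()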


@@ -0,0 +1,500 @@
"""Tests for event schema validation."""
import pytest
from pydantic import ValidationError
from libs.events.topics import EventTopics
from libs.schemas.events import (
EVENT_SCHEMA_MAP,
CalculationReadyEventData,
DocumentExtractedEventData,
DocumentIngestedEventData,
DocumentOCRReadyEventData,
FirmSyncCompletedEventData,
FormFilledEventData,
HMRCSubmittedEventData,
KGUpsertedEventData,
KGUpsertReadyEventData,
RAGIndexedEventData,
ReviewCompletedEventData,
ReviewRequestedEventData,
get_schema_for_topic,
validate_event_data,
)
class TestDocumentIngestedEventData:
"""Test DocumentIngestedEventData schema."""
def test_valid_event(self) -> None:
"""Test creating a valid document ingested event."""
data = DocumentIngestedEventData(
doc_id="01H8Y9Z5M3K7N2P4Q6R8T0V1W3",
filename="invoice_2024.pdf",
mime_type="application/pdf",
size_bytes=102400,
checksum_sha256="a" * 64,
kind="invoice",
source="manual_upload",
storage_path="raw-documents/2024/invoice_2024.pdf",
)
assert data.doc_id == "01H8Y9Z5M3K7N2P4Q6R8T0V1W3"
assert data.size_bytes == 102400
assert len(data.checksum_sha256) == 64
def test_invalid_checksum(self) -> None:
"""Test invalid SHA-256 checksum."""
with pytest.raises(ValidationError) as exc_info:
DocumentIngestedEventData(
doc_id="01H8Y9Z5M3K7N2P4Q6R8T0V1W3",
filename="test.pdf",
mime_type="application/pdf",
size_bytes=1024,
checksum_sha256="invalid", # Too short
kind="invoice",
source="manual_upload",
storage_path="path/to/file",
)
assert "Invalid SHA-256 checksum format" in str(exc_info.value)
def test_negative_size(self) -> None:
"""Test negative file size validation."""
with pytest.raises(ValidationError):
DocumentIngestedEventData(
doc_id="01H8Y9Z5M3K7N2P4Q6R8T0V1W3",
filename="test.pdf",
mime_type="application/pdf",
size_bytes=-1, # Negative size
checksum_sha256="a" * 64,
kind="invoice",
source="manual_upload",
storage_path="path/to/file",
)
def test_immutable(self) -> None:
"""Test that event data is immutable."""
data = DocumentIngestedEventData(
doc_id="01H8Y9Z5M3K7N2P4Q6R8T0V1W3",
filename="test.pdf",
mime_type="application/pdf",
size_bytes=1024,
checksum_sha256="a" * 64,
kind="invoice",
source="manual_upload",
storage_path="path/to/file",
)
with pytest.raises(ValidationError):
data.filename = "changed.pdf" # Should raise because frozen=True
class TestDocumentOCRReadyEventData:
"""Test DocumentOCRReadyEventData schema."""
def test_valid_event(self) -> None:
"""Test creating a valid OCR ready event."""
data = DocumentOCRReadyEventData(
doc_id="01H8Y9Z5M3K7N2P4Q6R8T0V1W3",
ocr_engine="tesseract",
page_count=3,
confidence_avg=0.95,
text_length=5000,
layout_detected=True,
languages_detected=["en"],
processing_time_ms=1500,
storage_path="ocr-results/doc_123.json",
)
assert data.ocr_engine == "tesseract"
assert data.confidence_avg == 0.95
assert 0.0 <= data.confidence_avg <= 1.0
def test_invalid_confidence(self) -> None:
"""Test invalid confidence score."""
with pytest.raises(ValidationError):
DocumentOCRReadyEventData(
doc_id="123",
ocr_engine="tesseract",
page_count=1,
confidence_avg=1.5, # > 1.0
text_length=100,
layout_detected=True,
processing_time_ms=1000,
storage_path="path",
)
def test_invalid_ocr_engine(self) -> None:
"""Test invalid OCR engine value."""
with pytest.raises(ValidationError):
DocumentOCRReadyEventData(
doc_id="123",
ocr_engine="invalid_engine", # Not in allowed values
page_count=1,
confidence_avg=0.9,
text_length=100,
layout_detected=True,
processing_time_ms=1000,
storage_path="path",
)
class TestDocumentExtractedEventData:
"""Test DocumentExtractedEventData schema."""
def test_valid_event(self) -> None:
"""Test creating a valid extraction event."""
data = DocumentExtractedEventData(
doc_id="01H8Y9Z5M3K7N2P4Q6R8T0V1W3",
extraction_id="extr_123",
strategy="hybrid",
fields_extracted=15,
confidence_avg=0.88,
calibrated_confidence=0.91,
model_name="gpt-4",
processing_time_ms=3000,
storage_path="extractions/extr_123.json",
)
assert data.strategy == "hybrid"
assert data.model_name == "gpt-4"
def test_valid_without_model(self) -> None:
"""Test extraction event without model (rules-based)."""
data = DocumentExtractedEventData(
doc_id="123",
extraction_id="extr_456",
strategy="rules",
fields_extracted=10,
confidence_avg=0.95,
calibrated_confidence=0.93,
model_name=None, # No model for rules-based
processing_time_ms=500,
storage_path="path",
)
assert data.model_name is None
assert data.strategy == "rules"
class TestKGEvents:
"""Test Knowledge Graph event schemas."""
def test_kg_upsert_ready(self) -> None:
"""Test KG upsert ready event."""
data = KGUpsertReadyEventData(
doc_id="01H8Y9Z5M3K7N2P4Q6R8T0V1W3",
entity_count=25,
relationship_count=40,
tax_year="2024-25",
taxpayer_id="TP-001",
normalization_id="norm_123",
storage_path="normalized/norm_123.json",
)
assert data.entity_count == 25
assert data.tax_year == "2024-25"
def test_kg_upserted(self) -> None:
"""Test KG upserted event."""
data = KGUpsertedEventData(
doc_id="01H8Y9Z5M3K7N2P4Q6R8T0V1W3",
entities_created=10,
entities_updated=5,
relationships_created=20,
relationships_updated=10,
shacl_violations=0,
processing_time_ms=2000,
success=True,
error_message=None,
)
assert data.success is True
assert data.shacl_violations == 0
def test_kg_upserted_with_violations(self) -> None:
"""Test KG upserted event with SHACL violations."""
data = KGUpsertedEventData(
doc_id="123",
entities_created=5,
entities_updated=0,
relationships_created=8,
relationships_updated=0,
shacl_violations=3,
processing_time_ms=1500,
success=False,
error_message="SHACL validation failed: Missing required property",
)
assert data.success is False
assert data.shacl_violations == 3
assert data.error_message is not None
class TestRAGIndexedEventData:
"""Test RAG indexed event schema."""
def test_valid_event(self) -> None:
"""Test creating a valid RAG indexed event."""
data = RAGIndexedEventData(
doc_id="01H8Y9Z5M3K7N2P4Q6R8T0V1W3",
collection_name="firm_knowledge",
chunks_indexed=45,
embedding_model="bge-small-en-v1.5",
pii_detected=True,
pii_redacted=True,
processing_time_ms=5000,
storage_path="chunks/doc_123.json",
)
assert data.pii_detected is True
assert data.pii_redacted is True
assert data.chunks_indexed == 45
class TestCalculationReadyEventData:
"""Test calculation ready event schema."""
def test_valid_event(self) -> None:
"""Test creating a valid calculation event."""
data = CalculationReadyEventData(
taxpayer_id="TP-001",
tax_year="2024-25",
schedule_id="SA103",
calculation_id="calc_789",
boxes_computed=50,
total_income=85000.50,
total_tax=18500.25,
confidence=0.92,
evidence_count=15,
processing_time_ms=2500,
storage_path="calculations/calc_789.json",
)
assert data.schedule_id == "SA103"
assert data.total_income == 85000.50
assert data.total_tax == 18500.25
def test_valid_without_totals(self) -> None:
"""Test calculation event without totals (partial calculation)."""
data = CalculationReadyEventData(
taxpayer_id="TP-001",
tax_year="2024-25",
schedule_id="SA102",
calculation_id="calc_456",
boxes_computed=20,
total_income=None,
total_tax=None,
confidence=0.85,
evidence_count=10,
processing_time_ms=1000,
storage_path="calculations/calc_456.json",
)
assert data.total_income is None
assert data.total_tax is None
class TestFormFilledEventData:
"""Test form filled event schema."""
def test_valid_event(self) -> None:
"""Test creating a valid form filled event."""
data = FormFilledEventData(
taxpayer_id="TP-001",
tax_year="2024-25",
form_id="SA100",
fields_filled=75,
pdf_size_bytes=524288,
storage_path="forms/SA100_filled.pdf",
evidence_bundle_path="evidence/bundle_123.zip",
checksum_sha256="b" * 64,
)
assert data.form_id == "SA100"
assert data.evidence_bundle_path is not None
class TestHMRCSubmittedEventData:
"""Test HMRC submitted event schema."""
def test_successful_submission(self) -> None:
"""Test successful HMRC submission."""
data = HMRCSubmittedEventData(
taxpayer_id="TP-001",
tax_year="2024-25",
submission_id="sub_999",
hmrc_reference="HMRC-REF-12345",
submission_type="sandbox",
success=True,
status_code=200,
error_message=None,
processing_time_ms=3000,
)
assert data.success is True
assert data.hmrc_reference is not None
def test_failed_submission(self) -> None:
"""Test failed HMRC submission."""
data = HMRCSubmittedEventData(
taxpayer_id="TP-001",
tax_year="2024-25",
submission_id="sub_888",
hmrc_reference=None,
submission_type="live",
success=False,
status_code=400,
error_message="Invalid UTR number",
processing_time_ms=1500,
)
assert data.success is False
assert data.error_message is not None
def test_invalid_submission_type(self) -> None:
"""Test invalid submission type."""
with pytest.raises(ValidationError):
HMRCSubmittedEventData(
taxpayer_id="TP-001",
tax_year="2024-25",
submission_id="sub_777",
hmrc_reference=None,
submission_type="invalid", # Not in allowed values
success=False,
status_code=None,
error_message=None,
processing_time_ms=1000,
)
class TestReviewEvents:
"""Test review event schemas."""
def test_review_requested(self) -> None:
"""Test review requested event."""
data = ReviewRequestedEventData(
doc_id="01H8Y9Z5M3K7N2P4Q6R8T0V1W3",
review_type="extraction",
priority="high",
reason="Low confidence extraction (0.65)",
assigned_to="reviewer@example.com",
due_date="2024-12-01T10:00:00Z",
metadata={"extraction_id": "extr_123"},
)
assert data.priority == "high"
assert data.review_type == "extraction"
def test_review_completed(self) -> None:
"""Test review completed event."""
data = ReviewCompletedEventData(
doc_id="01H8Y9Z5M3K7N2P4Q6R8T0V1W3",
review_id="rev_456",
reviewer="reviewer@example.com",
decision="approved",
changes_made=3,
comments="Fixed vendor name and amount",
review_duration_seconds=180,
)
assert data.decision == "approved"
assert data.changes_made == 3
class TestFirmSyncCompletedEventData:
"""Test firm sync completed event schema."""
def test_successful_sync(self) -> None:
"""Test successful firm sync."""
data = FirmSyncCompletedEventData(
firm_id="FIRM-001",
connector_type="xero",
sync_id="sync_123",
records_synced=150,
records_created=50,
records_updated=100,
records_failed=0,
success=True,
error_message=None,
processing_time_ms=10000,
)
assert data.success is True
assert data.records_failed == 0
def test_partial_sync_failure(self) -> None:
"""Test sync with some failures."""
data = FirmSyncCompletedEventData(
firm_id="FIRM-002",
connector_type="sage",
sync_id="sync_456",
records_synced=90,
records_created=30,
records_updated=60,
records_failed=10,
success=True, # Overall success despite some failures
error_message="10 records failed validation",
processing_time_ms=15000,
)
assert data.records_failed == 10
assert data.error_message is not None
class TestSchemaMapping:
"""Test schema mapping and validation utilities."""
def test_all_topics_have_schemas(self) -> None:
"""Test that all topics in EventTopics have corresponding schemas."""
topic_values = {
getattr(EventTopics, attr)
for attr in dir(EventTopics)
if not attr.startswith("_")
}
schema_topics = set(EVENT_SCHEMA_MAP.keys())
# All event topics should have schemas
missing_schemas = topic_values - schema_topics
assert not missing_schemas, f"Missing schemas for topics: {missing_schemas}"
def test_validate_event_data(self) -> None:
"""Test validate_event_data function."""
valid_data = {
"doc_id": "01H8Y9Z5M3K7N2P4Q6R8T0V1W3",
"filename": "test.pdf",
"mime_type": "application/pdf",
"size_bytes": 1024,
"checksum_sha256": "a" * 64,
"kind": "invoice",
"source": "manual_upload",
"storage_path": "path/to/file",
}
result = validate_event_data("doc.ingested", valid_data)
assert isinstance(result, DocumentIngestedEventData)
assert result.doc_id == "01H8Y9Z5M3K7N2P4Q6R8T0V1W3"
def test_validate_unknown_topic(self) -> None:
"""Test validation with unknown topic."""
with pytest.raises(ValueError, match="Unknown event topic"):
validate_event_data("unknown.topic", {})
def test_validate_invalid_data(self) -> None:
"""Test validation with invalid data."""
invalid_data = {
"doc_id": "123",
"filename": "test.pdf",
# Missing required fields
}
with pytest.raises(ValidationError):
validate_event_data("doc.ingested", invalid_data)
def test_get_schema_for_topic(self) -> None:
"""Test get_schema_for_topic function."""
schema = get_schema_for_topic("doc.ingested")
assert schema == DocumentIngestedEventData
def test_get_schema_unknown_topic(self) -> None:
"""Test get_schema_for_topic with unknown topic."""
with pytest.raises(ValueError, match="Unknown event topic"):
get_schema_for_topic("unknown.topic")
def test_schema_prevents_extra_fields(self) -> None:
"""Test that schemas prevent extra fields (extra='forbid')."""
with pytest.raises(ValidationError) as exc_info:
DocumentIngestedEventData(
doc_id="123",
filename="test.pdf",
mime_type="application/pdf",
size_bytes=1024,
checksum_sha256="a" * 64,
kind="invoice",
source="manual_upload",
storage_path="path",
unexpected_field="should_fail", # Extra field
)
assert "Extra inputs are not permitted" in str(exc_info.value)


@@ -1,10 +1,10 @@
"""Tests for NATS event bus implementation."""
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nats.js.api import ConsumerConfig
from libs.events.base import EventPayload
from libs.events.nats_bus import NATSEventBus
@@ -41,9 +41,12 @@ class TestNATSEventBus:
assert nats_bus.servers == ["nats://localhost:4222"]
assert nats_bus.stream_name == "TEST_STREAM"
assert nats_bus.consumer_group == "test-group"
assert nats_bus.dlq_stream_name == "TAX_AGENT_DLQ"
assert nats_bus.max_retries == 3
assert not nats_bus.running
assert nats_bus.nc is None
assert nats_bus.js is None
assert nats_bus.dlq is None
@pytest.mark.asyncio
async def test_initialization_with_multiple_servers(self):
@@ -54,14 +57,21 @@ class TestNATSEventBus:
@pytest.mark.asyncio
@patch("libs.events.nats_bus.nats.connect")
async def test_start(self, mock_connect, nats_bus):
@patch("libs.events.nats_bus.DLQHandler")
async def test_start(self, mock_dlq_cls, mock_connect, nats_bus):
"""Test starting the NATS event bus."""
# Mock NATS connection and JetStream
mock_nc = AsyncMock()
mock_js = AsyncMock()
mock_nc.jetstream.return_value = mock_js
# jetstream() is synchronous, so we mock it as a MagicMock or just set return value
mock_nc.jetstream = MagicMock(return_value=mock_js)
mock_connect.return_value = mock_nc
# Mock DLQ handler
mock_dlq_instance = MagicMock()
mock_dlq_instance.ensure_dlq_stream_exists = AsyncMock()
mock_dlq_cls.return_value = mock_dlq_instance
# Mock stream info to simulate existing stream
mock_js.stream_info.return_value = {"name": "TEST_STREAM"}
@@ -70,26 +80,40 @@ class TestNATSEventBus:
assert nats_bus.running
assert nats_bus.nc == mock_nc
assert nats_bus.js == mock_js
assert nats_bus.dlq == mock_dlq_instance
mock_connect.assert_called_once_with(servers=["nats://localhost:4222"])
mock_dlq_instance.ensure_dlq_stream_exists.assert_called_once()
@pytest.mark.asyncio
@patch("libs.events.nats_bus.nats.connect")
async def test_start_creates_stream_if_not_exists(self, mock_connect, nats_bus):
@patch("libs.events.nats_bus.DLQHandler")
async def test_start_creates_stream_if_not_exists(
self, mock_dlq_cls, mock_connect, nats_bus
):
"""Test that start creates stream if it doesn't exist."""
# Mock NATS connection and JetStream
mock_nc = AsyncMock()
mock_js = AsyncMock()
mock_nc.jetstream.return_value = mock_js
mock_nc.jetstream = MagicMock(return_value=mock_js)
mock_connect.return_value = mock_nc
# Mock DLQ handler
mock_dlq_instance = MagicMock()
mock_dlq_instance.ensure_dlq_stream_exists = AsyncMock()
mock_dlq_cls.return_value = mock_dlq_instance
# Mock stream_info to raise NotFoundError, then add_stream
from nats.js.errors import NotFoundError
mock_js.stream_info.side_effect = NotFoundError
mock_js.add_stream = AsyncMock()
await nats_bus.start()
mock_js.add_stream.assert_called_once()
call_args = mock_js.add_stream.call_args
assert call_args[1]["subjects"] == ["TEST_STREAM.>"]
@pytest.mark.asyncio
async def test_start_already_running(self, nats_bus):
@@ -107,17 +131,22 @@ class TestNATSEventBus:
# Setup mock objects
mock_nc = AsyncMock()
mock_subscription = AsyncMock()
mock_task = AsyncMock()
# Create a real task for consumer_tasks
async def dummy_task():
pass
real_task = asyncio.create_task(dummy_task())
nats_bus.running = True
nats_bus.nc = mock_nc
nats_bus.subscriptions = {"test-topic": mock_subscription}
nats_bus.consumer_tasks = [mock_task]
nats_bus.consumer_tasks = [real_task]
await nats_bus.stop()
assert not nats_bus.running
mock_task.cancel.assert_called_once()
assert real_task.cancelled() or real_task.done()
mock_subscription.unsubscribe.assert_called_once()
mock_nc.close.assert_called_once()
@@ -129,7 +158,8 @@ class TestNATSEventBus:
assert not nats_bus.running
@pytest.mark.asyncio
async def test_publish(self, nats_bus, event_payload):
@patch("libs.events.nats_bus.EventMetricsCollector")
async def test_publish(self, mock_metrics, nats_bus, event_payload):
"""Test publishing an event."""
# Setup mock JetStream
mock_js = AsyncMock()
@@ -146,6 +176,10 @@ class TestNATSEventBus:
assert call_args[1]["subject"] == "TEST_STREAM.test-topic"
assert call_args[1]["payload"] == event_payload.to_json().encode()
# Verify metrics recorded
mock_metrics.record_publish.assert_called_once()
assert mock_metrics.record_publish.call_args[1]["success"] is True
@pytest.mark.asyncio
async def test_publish_not_started(self, nats_bus, event_payload):
"""Test publishing when event bus is not started."""
@@ -153,7 +187,8 @@ class TestNATSEventBus:
await nats_bus.publish("test-topic", event_payload)
@pytest.mark.asyncio
async def test_publish_failure(self, nats_bus, event_payload):
@patch("libs.events.nats_bus.EventMetricsCollector")
async def test_publish_failure(self, mock_metrics, nats_bus, event_payload):
"""Test publishing failure."""
# Setup mock JetStream that raises exception
mock_js = AsyncMock()
@@ -164,6 +199,10 @@ class TestNATSEventBus:
assert result is False
# Verify metrics recorded failure
mock_metrics.record_publish.assert_called_once()
assert mock_metrics.record_publish.call_args[1]["success"] is False
@pytest.mark.asyncio
async def test_subscribe(self, nats_bus):
"""Test subscribing to a topic."""
@@ -184,11 +223,19 @@ class TestNATSEventBus:
assert test_handler in nats_bus.handlers["test-topic"]
assert "test-topic" in nats_bus.subscriptions
mock_js.pull_subscribe.assert_called_once()
# Verify ConsumerConfig
call_kwargs = mock_js.pull_subscribe.call_args[1]
config = call_kwargs["config"]
assert isinstance(config, ConsumerConfig)
assert config.max_deliver == 5 # 3 retries + 2 buffer
mock_create_task.assert_called_once()
@pytest.mark.asyncio
async def test_subscribe_not_started(self, nats_bus):
"""Test subscribing when event bus is not started."""
async def test_handler(topic: str, payload: EventPayload) -> None:
pass
@@ -220,7 +267,8 @@ class TestNATSEventBus:
assert handler2 in nats_bus.handlers["test-topic"]
@pytest.mark.asyncio
async def test_consume_messages(self, nats_bus, event_payload):
@patch("libs.events.nats_bus.EventMetricsCollector")
async def test_consume_messages(self, mock_metrics, nats_bus, event_payload):
"""Test consuming messages from NATS."""
# Setup mock subscription and message
mock_subscription = AsyncMock()
@@ -253,6 +301,10 @@ class TestNATSEventBus:
assert received_payload.event_id == event_payload.event_id
mock_message.ack.assert_called_once()
# Verify metrics
mock_metrics.record_consume.assert_called_once()
assert mock_metrics.record_consume.call_args[1]["success"] is True
@pytest.mark.asyncio
async def test_factory_integration(self):
"""Test that the factory can create a NATS event bus."""