Initial commit
Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
tests/unit/coverage/test_policy_load_and_merge.py (new file, 346 lines)
@@ -0,0 +1,346 @@
"""Unit tests for policy loading and merging functionality."""

# FILE: tests/unit/coverage/test_policy_load_and_merge.py

import tempfile
from pathlib import Path
from unittest.mock import patch

import pytest
import yaml

from libs.policy import PolicyLoader
from libs.schemas import CoveragePolicy, PolicyError

# pylint: disable=wrong-import-position,import-error,too-few-public-methods,global-statement
# pylint: disable=raise-missing-from,unused-argument,too-many-arguments,too-many-positional-arguments
# pylint: disable=too-many-locals,import-outside-toplevel
# mypy: disable-error-code=union-attr
# mypy: disable-error-code=no-untyped-def


class TestPolicyLoader:
    """Test policy loading and merging functionality"""

    @pytest.fixture
    def temp_config_dir(self):
        """Create temporary config directory with test files"""
        with tempfile.TemporaryDirectory() as temp_dir:
            config_dir = Path(temp_dir)

            # Create baseline policy
            baseline_policy = {
                "version": "1.0",
                "jurisdiction": "UK",
                "tax_year": "2024-25",
                "tax_year_boundary": {"start": "2024-04-06", "end": "2025-04-05"},
                "defaults": {"confidence_thresholds": {"ocr": 0.82, "extract": 0.85}},
                "document_kinds": ["P60", "P11D"],
                "triggers": {
                    "SA102": {"any_of": ["exists(IncomeItem[type='Employment'])"]}
                },
                "schedules": {
                    "SA102": {
                        "evidence": [
                            {"id": "P60", "role": "REQUIRED", "boxes": ["SA102_b1"]}
                        ]
                    }
                },
                "status_classifier": {
                    "present_verified": {"min_ocr": 0.82},
                    "present_unverified": {"min_ocr": 0.60},
                    "conflicting": {"conflict_rules": []},
                    "missing": {"default": True},
                },
                "conflict_resolution": {"precedence": ["P60"]},
                "question_templates": {
                    "default": {
                        "text": "Need {evidence}",
                        "why": "Required for {schedule}",
                    }
                },
                "privacy": {"vector_pii_free": True, "redact_patterns": []},
            }

            with open(config_dir / "coverage.yaml", "w") as f:
                yaml.dump(baseline_policy, f)

            # Create jurisdiction overlay
            jurisdiction_overlay = {
                "defaults": {
                    "confidence_thresholds": {"ocr": 0.85}  # Override threshold
                },
                "document_kinds": ["P60", "P11D", "P45"],  # Add P45
            }

            with open(config_dir / "coverage.UK.2024-25.yaml", "w") as f:
                yaml.dump(jurisdiction_overlay, f)

            # Create tenant overlay
            (config_dir / "overrides").mkdir()
            tenant_overlay = {
                "defaults": {"date_tolerance_days": 60}  # Override tolerance
            }

            with open(config_dir / "overrides" / "tenant123.yaml", "w") as f:
                yaml.dump(tenant_overlay, f)

            yield config_dir

    @pytest.fixture
    def policy_loader(self, temp_config_dir):
        """Create policy loader with temp config"""
        return PolicyLoader(str(temp_config_dir))

    def test_load_baseline_policy(self, policy_loader, temp_config_dir):
        """Test loading baseline policy without overlays"""
        policy = policy_loader.load_policy(
            baseline_path=str(temp_config_dir / "coverage.yaml"),
            jurisdiction="US",  # No overlay exists
            tax_year="2023-24",  # No overlay exists
            tenant_id=None,
        )

        assert isinstance(policy, CoveragePolicy)
        assert policy.version == "1.0"
        assert policy.jurisdiction == "UK"
        assert policy.defaults.confidence_thresholds["ocr"] == 0.82
        assert len(policy.document_kinds) == 2

    def test_load_policy_with_jurisdiction_overlay(self, policy_loader):
        """Test loading policy with jurisdiction overlay applied"""
        policy = policy_loader.load_policy(jurisdiction="UK", tax_year="2024-25")

        # Should have jurisdiction overlay applied
        assert policy.defaults.confidence_thresholds["ocr"] == 0.85  # Overridden
        assert len(policy.document_kinds) == 3  # P45 added
        assert "P45" in policy.document_kinds

    def test_load_policy_with_tenant_overlay(self, policy_loader):
        """Test loading policy with tenant overlay applied"""
        policy = policy_loader.load_policy(
            jurisdiction="UK", tax_year="2024-25", tenant_id="tenant123"
        )

        # Should have both jurisdiction and tenant overlays
        assert policy.defaults.confidence_thresholds["ocr"] == 0.85  # From jurisdiction
        assert policy.defaults.date_tolerance_days == 60  # From tenant
        assert len(policy.document_kinds) == 3  # From jurisdiction

    def test_merge_overlays(self, policy_loader):
        """Test overlay merging logic"""
        base = {"a": 1, "b": {"x": 10, "y": 20}, "c": [1, 2, 3]}

        overlay1 = {
            "b": {"x": 15, "z": 30},  # Merge into b, override x, add z
            "d": 4,  # Add new key
        }

        overlay2 = {
            "b": {"y": 25},  # Override y in b
            "c": [4, 5, 6],  # Replace entire list
        }

        result = policy_loader.merge_overlays(base, overlay1, overlay2)

        assert result["a"] == 1
        assert result["b"]["x"] == 15  # From overlay1
        assert result["b"]["y"] == 25  # From overlay2
        assert result["b"]["z"] == 30  # From overlay1
        assert result["c"] == [4, 5, 6]  # From overlay2
        assert result["d"] == 4  # From overlay1

    def test_compile_predicates(self, policy_loader):
        """Test predicate compilation"""
        policy = policy_loader.load_policy()
        compiled = policy_loader.compile_predicates(policy)

        assert compiled.policy == policy
        assert len(compiled.compiled_predicates) > 0
        assert "exists(IncomeItem[type='Employment'])" in compiled.compiled_predicates
        assert compiled.hash is not None
        assert len(compiled.source_files) > 0

    def test_predicate_execution(self, policy_loader):
        """Test that compiled predicates are callable"""
        policy = policy_loader.load_policy()
        compiled = policy_loader.compile_predicates(policy)

        predicate = compiled.compiled_predicates[
            "exists(IncomeItem[type='Employment'])"
        ]

        # Should be callable and return boolean
        result = predicate("T-001", "2024-25")
        assert isinstance(result, bool)

    def test_invalid_yaml_file(self, temp_config_dir):
        """Test handling of invalid YAML file"""
        # Create invalid YAML
        with open(temp_config_dir / "invalid.yaml", "w") as f:
            f.write("invalid: yaml: content: [")

        loader = PolicyLoader(str(temp_config_dir))

        with pytest.raises(PolicyError, match="Invalid YAML"):
            loader._load_yaml_file(str(temp_config_dir / "invalid.yaml"))

    def test_missing_file(self, temp_config_dir):
        """Test handling of missing file"""
        loader = PolicyLoader(str(temp_config_dir))

        with pytest.raises(PolicyError, match="Policy file not found"):
            loader._load_yaml_file(str(temp_config_dir / "missing.yaml"))

    def test_schema_validation_success(self, policy_loader, temp_config_dir):
        """Test successful schema validation"""
        policy_dict = policy_loader._load_yaml_file(
            str(temp_config_dir / "coverage.yaml")
        )

        # Should not raise exception
        policy_loader._validate_policy(policy_dict)

    def test_schema_validation_failure(self, policy_loader):
        """Test schema validation failure"""
        invalid_policy = {
            "version": "1.0",
            # Missing required fields
        }

        with pytest.raises(Exception):  # ValidationError from jsonschema
            policy_loader._validate_policy(invalid_policy)

    def test_business_rules_validation(self, policy_loader, temp_config_dir):
        """Test business rules validation"""
        policy_dict = policy_loader._load_yaml_file(
            str(temp_config_dir / "coverage.yaml")
        )

        result = policy_loader.validate_policy(policy_dict)
        assert result.ok is True
        assert len(result.errors) == 0

    def test_business_rules_validation_failure(self, policy_loader):
        """Test business rules validation with errors"""
        invalid_policy = {
            "version": "1.0",
            "jurisdiction": "UK",
            "tax_year": "2024-25",
            "tax_year_boundary": {"start": "2024-04-06", "end": "2025-04-05"},
            "defaults": {"confidence_thresholds": {"ocr": 0.82}},
            "document_kinds": ["P60"],
            "triggers": {"SA102": {"any_of": ["test"]}},
            "schedules": {
                "SA102": {
                    "evidence": [
                        {
                            "id": "P11D",  # Not in document_kinds
                            "role": "REQUIRED",
                            "boxes": ["SA102_b1"],
                        }
                    ]
                }
            },
            "status_classifier": {
                "present_verified": {"min_ocr": 0.82},
                "present_unverified": {"min_ocr": 0.60},
                "conflicting": {"conflict_rules": []},
                "missing": {"default": True},
            },
            "conflict_resolution": {"precedence": ["P60"]},
            "question_templates": {"default": {"text": "test", "why": "test"}},
        }

        result = policy_loader.validate_policy(invalid_policy)
        assert result.ok is False
        assert len(result.errors) > 0
        assert any("P11D" in error for error in result.errors)

    def test_apply_feature_flags_placeholder(self, policy_loader):
        """Test feature flags application (placeholder)"""
        policy_dict = {"test": "value"}
        result = policy_loader.apply_feature_flags(policy_dict)

        # Currently just returns unchanged
        assert result == policy_dict

    @patch("libs.policy.utils.get_policy_loader")
    def test_convenience_functions(self, mock_get_loader, policy_loader):
        """Test convenience functions"""
        # Create a valid mock policy for testing
        from unittest.mock import MagicMock

        from libs.schemas import (
            ConflictRules,
            CoveragePolicy,
            Defaults,
            Privacy,
            QuestionTemplates,
            StatusClassifier,
            StatusClassifierConfig,
            TaxYearBoundary,
        )

        mock_policy = CoveragePolicy(
            version="1.0",
            jurisdiction="UK",
            tax_year="2024-25",
            tax_year_boundary=TaxYearBoundary(start="2024-04-06", end="2025-04-05"),
            defaults=Defaults(
                confidence_thresholds={"ocr": 0.82, "extract": 0.85},
                date_tolerance_days=30,
            ),
            document_kinds=["P60"],
            status_classifier=StatusClassifierConfig(
                present_verified=StatusClassifier(min_ocr=0.82, min_extract=0.85),
                present_unverified=StatusClassifier(min_ocr=0.60, min_extract=0.70),
                conflicting=StatusClassifier(),
                missing=StatusClassifier(),
            ),
            triggers={},
            conflict_resolution=ConflictRules(precedence=["P60"]),
            question_templates=QuestionTemplates(
                default={"text": "test", "why": "test"}
            ),
            privacy=Privacy(vector_pii_free=True, redact_patterns=[]),
        )

        # Mock the policy loader to return our test policy
        from datetime import datetime

        from libs.schemas import CompiledCoveragePolicy

        mock_compiled_policy = CompiledCoveragePolicy(
            policy=mock_policy,
            compiled_predicates={},
            compiled_at=datetime.now(),
            hash="test-hash",
            source_files=["test.yaml"],
        )

        mock_loader = MagicMock()
        mock_loader.load_policy.return_value = mock_policy
        mock_loader.merge_overlays.side_effect = lambda base, *overlays: {
            **base,
            **{k: v for overlay in overlays for k, v in overlay.items()},
        }
        mock_loader.compile_predicates.return_value = mock_compiled_policy
        mock_get_loader.return_value = mock_loader

        from libs.policy import compile_predicates, load_policy, merge_overlays

        # Test load_policy - use the mock policy directly since we're testing the convenience function
        policy = load_policy()
        assert isinstance(policy, CoveragePolicy)
        assert policy.version == "1.0"

        # Test merge_overlays
        base = {"a": 1}
        overlay = {"b": 2}
        merged = merge_overlays(base, overlay)
        assert merged == {"a": 1, "b": 2}

        # Test compile_predicates
        compiled = compile_predicates(policy)
        assert compiled.policy == policy
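Note: PolicyLoader itself is not part of this diff, so the merge semantics are visible only through test_merge_overlays above: nested dicts merge recursively with later overlays winning, while scalars and lists are replaced wholesale. A minimal sketch consistent with those assertions (not the actual implementation) looks like this:

# Sketch only: a deep merge consistent with test_merge_overlays above.
# The real PolicyLoader.merge_overlays may differ in detail.
def merge_overlays(base: dict, *overlays: dict) -> dict:
    result = dict(base)
    for overlay in overlays:
        for key, value in overlay.items():
            if isinstance(value, dict) and isinstance(result.get(key), dict):
                # Nested dicts merge recursively (overlay keys win).
                result[key] = merge_overlays(result[key], value)
            else:
                # Scalars and lists are replaced wholesale.
                result[key] = value
    return result

assert merge_overlays(
    {"a": 1, "b": {"x": 10, "y": 20}, "c": [1, 2, 3]},
    {"b": {"x": 15, "z": 30}, "d": 4},
    {"b": {"y": 25}, "c": [4, 5, 6]},
) == {"a": 1, "b": {"x": 15, "y": 25, "z": 30}, "c": [4, 5, 6], "d": 4}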
tests/unit/coverage/test_predicate_compilation.py (new file, 270 lines)
@@ -0,0 +1,270 @@
"""Unit tests for predicate compilation and DSL parsing."""

# FILE: tests/unit/coverage/test_predicate_compilation.py

import pytest

from libs.policy import PolicyLoader

# pylint: disable=wrong-import-position,import-error,too-few-public-methods,global-statement
# pylint: disable=raise-missing-from,unused-argument,too-many-arguments,too-many-positional-arguments
# pylint: disable=too-many-locals,import-outside-toplevel
# mypy: disable-error-code=union-attr
# mypy: disable-error-code=no-untyped-def


class TestPredicateCompilation:
    """Test predicate compilation and DSL parsing"""

    @pytest.fixture
    def policy_loader(self):
        """Create policy loader for testing"""
        return PolicyLoader()

    def test_compile_exists_condition(self, policy_loader):
        """Test compilation of exists() conditions"""
        condition = "exists(IncomeItem[type='Employment'])"
        predicate = policy_loader._compile_condition(condition)

        assert callable(predicate)
        result = predicate("T-001", "2024-25")
        assert isinstance(result, bool)

    def test_compile_exists_condition_with_filters(self, policy_loader):
        """Test exists() with complex filters"""
        condition = "exists(IncomeItem[type='SelfEmployment' AND turnover_lt_vat_threshold=true])"
        predicate = policy_loader._compile_condition(condition)

        assert callable(predicate)
        result = predicate("T-001", "2024-25")
        assert isinstance(result, bool)

    def test_compile_property_conditions(self, policy_loader):
        """Test compilation of property conditions"""
        conditions = [
            "property_joint_ownership",
            "candidate_FHL",
            "claims_FTCR",
            "claims_remittance_basis",
            "received_estate_income",
        ]

        for condition in conditions:
            predicate = policy_loader._compile_condition(condition)
            assert callable(predicate)
            result = predicate("T-001", "2024-25")
            assert isinstance(result, bool)

    def test_compile_computed_conditions(self, policy_loader):
        """Test compilation of computed conditions"""
        conditions = [
            "turnover_lt_vat_threshold",
            "turnover_ge_vat_threshold",
        ]

        for condition in conditions:
            predicate = policy_loader._compile_condition(condition)
            assert callable(predicate)
            result = predicate("T-001", "2024-25")
            assert isinstance(result, bool)

    def test_compile_taxpayer_flag_conditions(self, policy_loader):
        """Test compilation of taxpayer flag conditions"""
        condition = "taxpayer_flag:has_employment"
        predicate = policy_loader._compile_condition(condition)

        assert callable(predicate)
        result = predicate("T-001", "2024-25")
        assert isinstance(result, bool)

    def test_compile_filing_mode_conditions(self, policy_loader):
        """Test compilation of filing mode conditions"""
        condition = "filing_mode:paper"
        predicate = policy_loader._compile_condition(condition)

        assert callable(predicate)
        result = predicate("T-001", "2024-25")
        assert isinstance(result, bool)

    def test_compile_unknown_condition(self, policy_loader):
        """Test compilation of unknown condition defaults to False"""
        condition = "unknown_condition_type"
        predicate = policy_loader._compile_condition(condition)

        assert callable(predicate)
        result = predicate("T-001", "2024-25")
        assert result is False  # Unknown conditions default to False

    def test_exists_predicate_creation(self, policy_loader):
        """Test exists predicate creation with different entity types"""
        entity_types = [
            "IncomeItem",
            "ExpenseItem",
            "PropertyAsset",
            "TrustDistribution",
        ]

        for entity_type in entity_types:
            predicate = policy_loader._create_exists_predicate(
                entity_type, "type='test'"
            )
            assert callable(predicate)
            result = predicate("T-001", "2024-25")
            assert isinstance(result, bool)

    def test_property_predicate_creation(self, policy_loader):
        """Test property predicate creation"""
        properties = [
            "property_joint_ownership",
            "candidate_FHL",
            "claims_FTCR",
        ]

        for prop in properties:
            predicate = policy_loader._create_property_predicate(prop)
            assert callable(predicate)
            result = predicate("T-001", "2024-25")
            assert isinstance(result, bool)

    def test_computed_predicate_creation(self, policy_loader):
        """Test computed predicate creation"""
        computations = [
            "turnover_lt_vat_threshold",
            "turnover_ge_vat_threshold",
        ]

        for comp in computations:
            predicate = policy_loader._create_computed_predicate(comp)
            assert callable(predicate)
            result = predicate("T-001", "2024-25")
            assert isinstance(result, bool)

    def test_flag_predicate_creation(self, policy_loader):
        """Test flag predicate creation"""
        flags = [
            "has_employment",
            "is_self_employed_short",
            "has_property_income",
            "has_foreign_income",
        ]

        for flag in flags:
            predicate = policy_loader._create_flag_predicate(flag)
            assert callable(predicate)
            result = predicate("T-001", "2024-25")
            assert isinstance(result, bool)

    def test_filing_mode_predicate_creation(self, policy_loader):
        """Test filing mode predicate creation"""
        modes = ["paper", "online", "agent"]

        for mode in modes:
            predicate = policy_loader._create_filing_mode_predicate(mode)
            assert callable(predicate)
            result = predicate("T-001", "2024-25")
            assert isinstance(result, bool)

    def test_exists_condition_regex_parsing(self, policy_loader):
        """Test regex parsing of exists conditions"""
        test_cases = [
            (
                "exists(IncomeItem[type='Employment'])",
                "IncomeItem",
                "type='Employment'",
            ),
            (
                "exists(ExpenseItem[category='FinanceCosts'])",
                "ExpenseItem",
                "category='FinanceCosts'",
            ),
            (
                "exists(PropertyAsset[joint_ownership=true])",
                "PropertyAsset",
                "joint_ownership=true",
            ),
        ]

        for condition, expected_entity, expected_filters in test_cases:
            # Test that the regex matches correctly
            import re

            exists_match = re.match(r"exists\((\w+)\[([^\]]+)\]\)", condition)
            assert exists_match is not None
            assert exists_match.group(1) == expected_entity
            assert exists_match.group(2) == expected_filters

    def test_condition_whitespace_handling(self, policy_loader):
        """Test that conditions handle whitespace correctly"""
        conditions_with_whitespace = [
            " exists(IncomeItem[type='Employment']) ",
            "\tproperty_joint_ownership\t",
            "\n taxpayer_flag:has_employment \n",
        ]

        for condition in conditions_with_whitespace:
            predicate = policy_loader._compile_condition(condition)
            assert callable(predicate)
            result = predicate("T-001", "2024-25")
            assert isinstance(result, bool)

    def test_complex_exists_filters(self, policy_loader):
        """Test exists conditions with complex filter expressions"""
        complex_conditions = [
            "exists(IncomeItem[type='SelfEmployment' AND turnover_lt_vat_threshold=true])",
            "exists(ExpenseItem[category='CapitalAllowances'])",
            "exists(IncomeItem[type IN ['ForeignInterest','ForeignDividends']])",
        ]

        for condition in complex_conditions:
            predicate = policy_loader._compile_condition(condition)
            assert callable(predicate)
            result = predicate("T-001", "2024-25")
            assert isinstance(result, bool)

    def test_predicate_consistency(self, policy_loader):
        """Test that predicates return consistent results for same inputs"""
        condition = "exists(IncomeItem[type='Employment'])"
        predicate = policy_loader._compile_condition(condition)

        # Call multiple times with same inputs
        result1 = predicate("T-001", "2024-25")
        result2 = predicate("T-001", "2024-25")
        result3 = predicate("T-001", "2024-25")

        # Should be consistent
        assert result1 == result2 == result3

    def test_predicate_different_inputs(self, policy_loader):
        """Test predicates with different input combinations"""
        condition = "exists(IncomeItem[type='Employment'])"
        predicate = policy_loader._compile_condition(condition)

        # Test with different taxpayer IDs and tax years
        test_inputs = [
            ("T-001", "2024-25"),
            ("T-002", "2024-25"),
            ("T-001", "2023-24"),
            ("T-999", "2025-26"),
        ]

        for taxpayer_id, tax_year in test_inputs:
            result = predicate(taxpayer_id, tax_year)
            assert isinstance(result, bool)

    def test_edge_case_conditions(self, policy_loader):
        """Test edge cases in condition parsing"""
        edge_cases = [
            "",  # Empty string
            " ",  # Whitespace only
            "exists()",  # Empty exists
            "exists(Entity[])",  # Empty filter
            "taxpayer_flag:",  # Empty flag
            "filing_mode:",  # Empty mode
        ]

        for condition in edge_cases:
            predicate = policy_loader._compile_condition(condition)
            assert callable(predicate)
            # Should default to False for malformed conditions
            result = predicate("T-001", "2024-25")
            assert result is False
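These tests fix the observable contract of _compile_condition: every condition string, however malformed, compiles to a callable of (taxpayer_id, tax_year) returning a bool, and unrecognized or malformed input yields a predicate that always returns False. The only parsing detail the tests confirm directly is the exists() grammar, via the regex in test_exists_condition_regex_parsing:

# The exists() grammar asserted above: an entity type followed by a
# bracketed filter expression. Everything else about _compile_condition
# is behavioral (callable in, bool out; malformed input compiles to an
# always-False predicate).
import re

EXISTS_RE = re.compile(r"exists\((\w+)\[([^\]]+)\]\)")

m = EXISTS_RE.match("exists(ExpenseItem[category='FinanceCosts'])")
assert m is not None
assert m.group(1) == "ExpenseItem"  # entity type
assert m.group(2) == "category='FinanceCosts'"  # raw filter expression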
tests/unit/coverage/test_question_templates.py (new file, 272 lines)
@@ -0,0 +1,272 @@
"""Unit tests for question template generation."""

# FILE: tests/unit/coverage/test_question_templates.py

import pytest

from libs.schemas import Citation, ClarifyContext, CoverageGap, Role, UploadOption

# pylint: disable=wrong-import-position,import-error,too-few-public-methods,global-statement
# pylint: disable=raise-missing-from,unused-argument,too-many-arguments,too-many-positional-arguments
# pylint: disable=too-many-locals,import-outside-toplevel
# mypy: disable-error-code=union-attr
# mypy: disable-error-code=no-untyped-def


class TestQuestionTemplates:
    """Test question template generation and formatting"""

    @pytest.fixture
    def sample_gap(self):
        """Create sample coverage gap for testing"""
        return CoverageGap(
            schedule_id="SA102",
            evidence_id="P60",
            role=Role.REQUIRED,
            reason="P60 provides year-end pay and PAYE tax figures",
            boxes=["SA102_b1", "SA102_b2"],
            citations=[
                Citation(
                    rule_id="UK.SA102.P60.Required",
                    doc_id="SA102-Notes-2025",
                    locator="p.3 §1.1",
                )
            ],
            acceptable_alternatives=["P45", "FinalPayslipYTD"],
        )

    @pytest.fixture
    def sample_context(self):
        """Create sample clarify context for testing"""
        return ClarifyContext(
            tax_year="2024-25",
            taxpayer_id="T-001",
            jurisdiction="UK",
        )

    def test_question_text_formatting(self, sample_gap, sample_context):
        """Test basic question text formatting"""
        # Mock the _generate_clarifying_question function behavior
        evidence_name = sample_gap.evidence_id
        schedule_name = sample_gap.schedule_id
        boxes_text = ", ".join(sample_gap.boxes)
        alternatives_text = ", ".join(sample_gap.acceptable_alternatives)

        # Template format
        template_text = "To complete the {schedule} for {tax_year}, we need {evidence}. These documents support boxes {boxes}. If you don't have this, you can provide {alternatives}."

        question_text = template_text.format(
            schedule=schedule_name,
            tax_year=sample_context.tax_year,
            evidence=evidence_name,
            boxes=boxes_text,
            alternatives=alternatives_text,
        )

        expected = "To complete the SA102 for 2024-25, we need P60. These documents support boxes SA102_b1, SA102_b2. If you don't have this, you can provide P45, FinalPayslipYTD."
        assert question_text == expected

    def test_why_text_formatting(self, sample_gap):
        """Test why explanation formatting"""
        template_why = "{why}. See guidance: {guidance_doc}."

        why_text = template_why.format(
            why=sample_gap.reason,
            guidance_doc="policy guidance",
        )

        expected = "P60 provides year-end pay and PAYE tax figures. See guidance: policy guidance."
        assert why_text == expected

    def test_upload_options_generation(self, sample_gap):
        """Test upload options generation"""
        options = []

        # Generate options for alternatives
        for alt in sample_gap.acceptable_alternatives:
            options.append(
                UploadOption(
                    label=f"Upload {alt} (PDF/CSV)",
                    accepted_formats=["pdf", "csv"],
                    upload_endpoint=f"/v1/ingest/upload?tag={alt}",
                )
            )

        assert len(options) == 2
        assert options[0].label == "Upload P45 (PDF/CSV)"
        assert options[0].accepted_formats == ["pdf", "csv"]
        assert options[0].upload_endpoint == "/v1/ingest/upload?tag=P45"
        assert options[1].label == "Upload FinalPayslipYTD (PDF/CSV)"
        assert options[1].upload_endpoint == "/v1/ingest/upload?tag=FinalPayslipYTD"

    def test_upload_options_no_alternatives(self):
        """Test upload options when no alternatives available"""
        gap_no_alternatives = CoverageGap(
            schedule_id="SA102",
            evidence_id="P60",
            role=Role.REQUIRED,
            reason="Required document",
            boxes=["SA102_b1"],
            acceptable_alternatives=[],
        )

        options = []

        # When no alternatives, create option for main evidence
        if not gap_no_alternatives.acceptable_alternatives:
            options.append(
                UploadOption(
                    label=f"Upload {gap_no_alternatives.evidence_id} (PDF/CSV)",
                    accepted_formats=["pdf", "csv"],
                    upload_endpoint=f"/v1/ingest/upload?tag={gap_no_alternatives.evidence_id}",
                )
            )

        assert len(options) == 1
        assert options[0].label == "Upload P60 (PDF/CSV)"
        assert options[0].upload_endpoint == "/v1/ingest/upload?tag=P60"

    def test_blocking_determination(self, sample_gap):
        """Test blocking status determination"""
        # Required evidence should be blocking
        assert sample_gap.role == Role.REQUIRED
        blocking = sample_gap.role.value == "REQUIRED"
        assert blocking is True

        # Optional evidence should not be blocking
        optional_gap = CoverageGap(
            schedule_id="SA102",
            evidence_id="PayslipMonthly",
            role=Role.OPTIONAL,
            reason="Optional supporting document",
            boxes=["SA102_b3"],
        )

        blocking_optional = optional_gap.role.value == "REQUIRED"
        assert blocking_optional is False

    def test_boxes_affected_formatting(self, sample_gap):
        """Test boxes affected list formatting"""
        boxes_affected = sample_gap.boxes
        assert boxes_affected == ["SA102_b1", "SA102_b2"]

        # Test empty boxes
        gap_no_boxes = CoverageGap(
            schedule_id="SA102",
            evidence_id="EmploymentContract",
            role=Role.OPTIONAL,
            reason="Used for disambiguation",
            boxes=[],
        )

        assert gap_no_boxes.boxes == []

    def test_citations_preservation(self, sample_gap):
        """Test that citations are preserved in response"""
        citations = sample_gap.citations
        assert len(citations) == 1
        assert citations[0].rule_id == "UK.SA102.P60.Required"
        assert citations[0].doc_id == "SA102-Notes-2025"
        assert citations[0].locator == "p.3 §1.1"

    def test_multiple_alternatives_formatting(self):
        """Test formatting with multiple alternatives"""
        gap_many_alternatives = CoverageGap(
            schedule_id="SA105",
            evidence_id="LettingAgentStatements",
            role=Role.REQUIRED,
            reason="Evidence of rental income",
            boxes=["SA105_b5", "SA105_b20"],
            acceptable_alternatives=[
                "TenancyLedger",
                "BankStatements",
                "RentalAgreements",
            ],
        )

        alternatives_text = ", ".join(gap_many_alternatives.acceptable_alternatives)
        expected = "TenancyLedger, BankStatements, RentalAgreements"
        assert alternatives_text == expected

    def test_empty_boxes_formatting(self):
        """Test formatting when no boxes specified"""
        gap_no_boxes = CoverageGap(
            schedule_id="SA102",
            evidence_id="EmploymentContract",
            role=Role.OPTIONAL,
            reason="Used for disambiguation",
            boxes=[],
        )

        boxes_text = (
            ", ".join(gap_no_boxes.boxes) if gap_no_boxes.boxes else "relevant boxes"
        )
        assert boxes_text == "relevant boxes"

    def test_special_characters_in_evidence_names(self):
        """Test handling of special characters in evidence names"""
        gap_special_chars = CoverageGap(
            schedule_id="SA106",
            evidence_id="EEA_FHL",
            role=Role.CONDITIONALLY_REQUIRED,
            reason="European Economic Area Furnished Holiday Lettings",
            boxes=["SA106_b14"],
        )

        # Should handle underscores and other characters
        assert gap_special_chars.evidence_id == "EEA_FHL"

        # Upload endpoint should handle special characters
        upload_endpoint = f"/v1/ingest/upload?tag={gap_special_chars.evidence_id}"
        assert upload_endpoint == "/v1/ingest/upload?tag=EEA_FHL"

    def test_long_reason_text(self):
        """Test handling of long reason text"""
        long_reason = "This is a very long reason that explains in great detail why this particular piece of evidence is absolutely essential for completing the tax return accurately and in compliance with HMRC requirements and regulations."

        gap_long_reason = CoverageGap(
            schedule_id="SA108",
            evidence_id="CGT_BrokerAnnualReport",
            role=Role.REQUIRED,
            reason=long_reason,
            boxes=["SA108_b4", "SA108_b5"],
        )

        # Should preserve full reason text
        assert gap_long_reason.reason == long_reason
        assert len(gap_long_reason.reason) > 100

    def test_multiple_upload_formats(self):
        """Test generation of upload options with different formats"""
        evidence_id = "AccountsPAndL"

        # Different evidence types might accept different formats
        formats_map = {
            "AccountsPAndL": ["pdf", "xlsx", "csv"],
            "BankStatements": ["pdf", "csv", "ofx"],
            "P60": ["pdf", "jpg", "png"],
        }

        for evidence, formats in formats_map.items():
            option = UploadOption(
                label=f"Upload {evidence}",
                accepted_formats=formats,
                upload_endpoint=f"/v1/ingest/upload?tag={evidence}",
            )

            assert option.accepted_formats == formats
            assert evidence in option.upload_endpoint

    def test_context_variations(self):
        """Test question generation with different contexts"""
        contexts = [
            ClarifyContext(tax_year="2024-25", taxpayer_id="T-001", jurisdiction="UK"),
            ClarifyContext(tax_year="2023-24", taxpayer_id="T-002", jurisdiction="UK"),
            ClarifyContext(tax_year="2024-25", taxpayer_id="T-003", jurisdiction="US"),
        ]

        for context in contexts:
            # Each context should be valid
            assert context.tax_year.startswith("20")
            assert context.taxpayer_id.startswith("T-")
            assert context.jurisdiction in ["UK", "US", "CA", "AU"]
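The assembly logic these tests exercise piecemeal reads as one function: fill the question template from the gap and context, emit one upload option per acceptable alternative (falling back to the evidence itself when there are none), and mark the question blocking when the role is REQUIRED. A sketch under those assumptions, using plain dicts in place of the real response models:

# Sketch only: assembling a clarifying question from a CoverageGap, following
# the template and upload-option conventions asserted in the tests above.
def build_question(gap, context, template_text):
    boxes = ", ".join(gap.boxes) if gap.boxes else "relevant boxes"
    alternatives = ", ".join(gap.acceptable_alternatives)
    text = template_text.format(
        schedule=gap.schedule_id,
        tax_year=context.tax_year,
        evidence=gap.evidence_id,
        boxes=boxes,
        alternatives=alternatives,
    )
    # One upload option per alternative; fall back to the evidence itself.
    tags = gap.acceptable_alternatives or [gap.evidence_id]
    options = [
        {
            "label": f"Upload {tag} (PDF/CSV)",
            "accepted_formats": ["pdf", "csv"],
            "upload_endpoint": f"/v1/ingest/upload?tag={tag}",
        }
        for tag in tags
    ]
    blocking = gap.role.value == "REQUIRED"
    return {"text": text, "options": options, "blocking": blocking}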
tests/unit/coverage/test_status_classifier.py (new file, 338 lines)
@@ -0,0 +1,338 @@
"""Unit tests for evidence status classification."""

# FILE: tests/unit/coverage/test_status_classifier.py

from datetime import datetime

import pytest

from libs.coverage.evaluator import CoverageEvaluator
from libs.schemas import (
    CompiledCoveragePolicy,
    CoveragePolicy,
    Defaults,
    FoundEvidence,
    Status,
    StatusClassifier,
    StatusClassifierConfig,
    TaxYearBoundary,
)
from libs.schemas.coverage.core import ConflictRules, Privacy, QuestionTemplates

# pylint: disable=wrong-import-position,import-error,too-few-public-methods,global-statement
# pylint: disable=raise-missing-from,unused-argument,too-many-arguments,too-many-positional-arguments
# pylint: disable=too-many-locals,import-outside-toplevel
# mypy: disable-error-code=union-attr
# mypy: disable-error-code=no-untyped-def


class TestStatusClassifier:
    """Test evidence status classification logic"""

    @pytest.fixture
    def mock_policy(self):
        """Create mock compiled policy for testing"""
        policy = CoveragePolicy(
            version="1.0",
            jurisdiction="UK",
            tax_year="2024-25",
            tax_year_boundary=TaxYearBoundary(start="2024-04-06", end="2025-04-05"),
            defaults=Defaults(
                confidence_thresholds={"ocr": 0.82, "extract": 0.85},
                date_tolerance_days=30,
            ),
            document_kinds=["P60"],
            status_classifier=StatusClassifierConfig(
                present_verified=StatusClassifier(
                    min_ocr=0.82,
                    min_extract=0.85,
                    date_in_year=True,
                ),
                present_unverified=StatusClassifier(
                    min_ocr=0.60,
                    min_extract=0.70,
                    date_in_year_or_tolerance=True,
                ),
                conflicting=StatusClassifier(
                    conflict_rules=["Same doc kind, different totals"]
                ),
                missing=StatusClassifier(),
            ),
            conflict_resolution=ConflictRules(precedence=["P60"]),
            question_templates=QuestionTemplates(
                default={"text": "test", "why": "test"}
            ),
            privacy=Privacy(vector_pii_free=True, redact_patterns=[]),
        )

        return CompiledCoveragePolicy(
            policy=policy,
            compiled_predicates={},
            compiled_at=datetime.utcnow(),
            hash="test-hash",
            source_files=["test.yaml"],
        )

    @pytest.fixture
    def evaluator(self):
        """Create coverage evaluator for testing"""
        return CoverageEvaluator()

    def test_classify_missing_evidence(self, evaluator, mock_policy):
        """Test classification when no evidence found"""
        found = []
        status = evaluator.classify_status(found, mock_policy, "2024-25")
        assert status == Status.MISSING

    def test_classify_verified_evidence(self, evaluator, mock_policy):
        """Test classification of verified evidence"""
        found = [
            FoundEvidence(
                doc_id="DOC-001",
                kind="P60",
                ocr_confidence=0.85,
                extract_confidence=0.90,
                date="2024-05-15T10:00:00Z",
            )
        ]

        status = evaluator.classify_status(found, mock_policy, "2024-25")
        assert status == Status.PRESENT_VERIFIED

    def test_classify_unverified_evidence(self, evaluator, mock_policy):
        """Test classification of unverified evidence"""
        found = [
            FoundEvidence(
                doc_id="DOC-001",
                kind="P60",
                ocr_confidence=0.70,  # Below verified threshold
                extract_confidence=0.75,  # Below verified threshold
                date="2024-05-15T10:00:00Z",
            )
        ]

        status = evaluator.classify_status(found, mock_policy, "2024-25")
        assert status == Status.PRESENT_UNVERIFIED

    def test_classify_low_confidence_evidence(self, evaluator, mock_policy):
        """Test classification of very low confidence evidence"""
        found = [
            FoundEvidence(
                doc_id="DOC-001",
                kind="P60",
                ocr_confidence=0.50,  # Below unverified threshold
                extract_confidence=0.55,  # Below unverified threshold
                date="2024-05-15T10:00:00Z",
            )
        ]

        status = evaluator.classify_status(found, mock_policy, "2024-25")
        assert status == Status.MISSING

    def test_classify_conflicting_evidence(self, evaluator, mock_policy):
        """Test classification when multiple conflicting documents found"""
        found = [
            FoundEvidence(
                doc_id="DOC-001",
                kind="P60",
                ocr_confidence=0.85,
                extract_confidence=0.90,
                date="2024-05-15T10:00:00Z",
            ),
            FoundEvidence(
                doc_id="DOC-002",
                kind="P60",
                ocr_confidence=0.85,
                extract_confidence=0.90,
                date="2024-05-20T10:00:00Z",
            ),
        ]

        status = evaluator.classify_status(found, mock_policy, "2024-25")
        assert status == Status.CONFLICTING

    def test_classify_evidence_outside_tax_year(self, evaluator, mock_policy):
        """Test classification of evidence outside tax year"""
        found = [
            FoundEvidence(
                doc_id="DOC-001",
                kind="P60",
                ocr_confidence=0.85,
                extract_confidence=0.90,
                date="2023-03-15T10:00:00Z",  # Outside tax year
            )
        ]

        status = evaluator.classify_status(found, mock_policy, "2024-25")
        # Evidence outside tax year should be unverified even with high confidence
        # This is correct business logic - date validation is part of verification
        assert status == Status.PRESENT_UNVERIFIED

    def test_classify_evidence_no_date(self, evaluator, mock_policy):
        """Test classification of evidence without date"""
        found = [
            FoundEvidence(
                doc_id="DOC-001",
                kind="P60",
                ocr_confidence=0.85,
                extract_confidence=0.90,
                date=None,
            )
        ]

        status = evaluator.classify_status(found, mock_policy, "2024-25")
        # Evidence without date cannot be fully verified, even with high confidence
        # This is correct business logic - date validation is required for verification
        assert status == Status.PRESENT_UNVERIFIED

    def test_parse_tax_year_bounds(self, evaluator):
        """Test parsing of tax year boundary strings"""
        start_str = "2024-04-06"
        end_str = "2025-04-05"

        start, end = evaluator._parse_tax_year_bounds(start_str, end_str)

        assert isinstance(start, datetime)
        assert isinstance(end, datetime)
        assert start.year == 2024
        assert start.month == 4
        assert start.day == 6
        assert end.year == 2025
        assert end.month == 4
        assert end.day == 5

    def test_evidence_within_tax_year(self, evaluator, mock_policy):
        """Test evidence date validation within tax year"""
        # Evidence within tax year
        found = [
            FoundEvidence(
                doc_id="DOC-001",
                kind="P60",
                ocr_confidence=0.85,
                extract_confidence=0.90,
                date="2024-06-15T10:00:00Z",  # Within 2024-25 tax year
            )
        ]

        status = evaluator.classify_status(found, mock_policy, "2024-25")
        assert status == Status.PRESENT_VERIFIED

    def test_evidence_boundary_dates(self, evaluator, mock_policy):
        """Test evidence on tax year boundary dates"""
        # Test start boundary
        found_start = [
            FoundEvidence(
                doc_id="DOC-001",
                kind="P60",
                ocr_confidence=0.85,
                extract_confidence=0.90,
                date="2024-04-06T00:00:00Z",  # Exact start date
            )
        ]

        status = evaluator.classify_status(found_start, mock_policy, "2024-25")
        assert status == Status.PRESENT_VERIFIED

        # Test end boundary
        found_end = [
            FoundEvidence(
                doc_id="DOC-002",
                kind="P60",
                ocr_confidence=0.85,
                extract_confidence=0.90,
                date="2025-04-05T23:59:59Z",  # Exact end date
            )
        ]

        status = evaluator.classify_status(found_end, mock_policy, "2024-25")
        assert status == Status.PRESENT_VERIFIED

    def test_threshold_edge_cases(self, evaluator, mock_policy):
        """Test classification at threshold boundaries"""
        # Exactly at verified threshold
        found_exact = [
            FoundEvidence(
                doc_id="DOC-001",
                kind="P60",
                ocr_confidence=0.82,  # Exactly at threshold
                extract_confidence=0.85,  # Exactly at threshold
                date="2024-06-15T10:00:00Z",
            )
        ]

        status = evaluator.classify_status(found_exact, mock_policy, "2024-25")
        assert status == Status.PRESENT_VERIFIED

        # Just below verified threshold
        found_below = [
            FoundEvidence(
                doc_id="DOC-002",
                kind="P60",
                ocr_confidence=0.81,  # Just below threshold
                extract_confidence=0.84,  # Just below threshold
                date="2024-06-15T10:00:00Z",
            )
        ]

        status = evaluator.classify_status(found_below, mock_policy, "2024-25")
        assert status == Status.PRESENT_UNVERIFIED

    def test_mixed_confidence_levels(self, evaluator, mock_policy):
        """Test classification with mixed OCR and extract confidence"""
        # High OCR, low extract
        found_mixed1 = [
            FoundEvidence(
                doc_id="DOC-001",
                kind="P60",
                ocr_confidence=0.90,  # High
                extract_confidence=0.70,  # Low
                date="2024-06-15T10:00:00Z",
            )
        ]

        status = evaluator.classify_status(found_mixed1, mock_policy, "2024-25")
        assert status == Status.PRESENT_UNVERIFIED  # Both must meet threshold

        # Low OCR, high extract
        found_mixed2 = [
            FoundEvidence(
                doc_id="DOC-002",
                kind="P60",
                ocr_confidence=0.70,  # Low
                extract_confidence=0.90,  # High
                date="2024-06-15T10:00:00Z",
            )
        ]

        status = evaluator.classify_status(found_mixed2, mock_policy, "2024-25")
        assert status == Status.PRESENT_UNVERIFIED  # Both must meet threshold

    def test_zero_confidence_evidence(self, evaluator, mock_policy):
        """Test classification of zero confidence evidence"""
        found = [
            FoundEvidence(
                doc_id="DOC-001",
                kind="P60",
                ocr_confidence=0.0,
                extract_confidence=0.0,
                date="2024-06-15T10:00:00Z",
            )
        ]

        status = evaluator.classify_status(found, mock_policy, "2024-25")
        assert status == Status.MISSING

    def test_perfect_confidence_evidence(self, evaluator, mock_policy):
        """Test classification of perfect confidence evidence"""
        found = [
            FoundEvidence(
                doc_id="DOC-001",
                kind="P60",
                ocr_confidence=1.0,
                extract_confidence=1.0,
                date="2024-06-15T10:00:00Z",
            )
        ]

        status = evaluator.classify_status(found, mock_policy, "2024-25")
        assert status == Status.PRESENT_VERIFIED
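Read together, these assertions imply a decision ladder: no evidence, or confidence below the unverified floor, is MISSING; more than one document of the same kind is CONFLICTING; both confidences at or above the verified thresholds with a date inside the tax year (boundaries inclusive) is PRESENT_VERIFIED; everything else above the floor is PRESENT_UNVERIFIED. A self-contained sketch of that ladder, with the fixture's thresholds hard-coded (the real CoverageEvaluator may structure this differently):

# Sketch only: the classification ladder implied by the tests above.
# Ev is a hypothetical stand-in for FoundEvidence.
from dataclasses import dataclass
from datetime import datetime
from typing import Optional

@dataclass
class Ev:
    ocr: float
    extract: float
    date: Optional[str]

def classify(found: list, start: datetime, end: datetime) -> str:
    if not found:
        return "MISSING"
    if len(found) > 1:
        return "CONFLICTING"  # same doc kind, multiple documents
    ev = found[0]
    if ev.ocr < 0.60 or ev.extract < 0.70:  # below the unverified floor
        return "MISSING"
    in_year = False
    if ev.date is not None:
        dt = datetime.fromisoformat(ev.date.replace("Z", "+00:00"))
        in_year = start <= dt.replace(tzinfo=None) <= end  # boundaries inclusive
    if in_year and ev.ocr >= 0.82 and ev.extract >= 0.85:
        return "PRESENT_VERIFIED"
    return "PRESENT_UNVERIFIED"

start, end = datetime(2024, 4, 6), datetime(2025, 4, 5, 23, 59, 59)
assert classify([Ev(0.82, 0.85, "2024-06-15T10:00:00Z")], start, end) == "PRESENT_VERIFIED"
assert classify([Ev(0.85, 0.90, None)], start, end) == "PRESENT_UNVERIFIED"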
tests/unit/multi-model-calibration.py (new file, 283 lines)
@@ -0,0 +1,283 @@
|
||||
"""Unit tests for multi-model calibration."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.calibration.multi_model import MultiModelCalibrator
|
||||
|
||||
# pylint: disable=wrong-import-position,import-error,too-few-public-methods,global-statement
|
||||
# pylint: disable=raise-missing-from,unused-argument,too-many-arguments,too-many-positional-arguments
|
||||
# pylint: disable=too-many-locals,import-outside-toplevel
|
||||
# mypy: disable-error-code=union-attr
|
||||
# mypy: disable-error-code=no-untyped-def
|
||||
|
||||
|
||||
class TestMultiModelCalibrator:
|
||||
"""Test MultiModelCalibrator"""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data(self):
|
||||
"""Create sample training data"""
|
||||
scores = [0.1, 0.3, 0.5, 0.7, 0.9]
|
||||
labels = [False, False, True, True, True]
|
||||
return scores, labels
|
||||
|
||||
def test_init(self):
|
||||
"""Test initialization"""
|
||||
calibrator = MultiModelCalibrator()
|
||||
assert isinstance(calibrator.calibrators, dict)
|
||||
assert len(calibrator.calibrators) == 0
|
||||
|
||||
def test_add_calibrator_default_method(self):
|
||||
"""Test adding calibrator with default method"""
|
||||
calibrator = MultiModelCalibrator()
|
||||
calibrator.add_calibrator("model_a")
|
||||
|
||||
assert "model_a" in calibrator.calibrators
|
||||
assert calibrator.calibrators["model_a"].method == "temperature"
|
||||
|
||||
def test_add_calibrator_custom_method(self):
|
||||
"""Test adding calibrator with custom method"""
|
||||
calibrator = MultiModelCalibrator()
|
||||
calibrator.add_calibrator("model_b", method="platt")
|
||||
|
||||
assert "model_b" in calibrator.calibrators
|
||||
assert calibrator.calibrators["model_b"].method == "platt"
|
||||
|
||||
def test_fit_existing_calibrator(self, sample_data):
|
||||
"""Test fitting existing calibrator"""
|
||||
scores, labels = sample_data
|
||||
calibrator = MultiModelCalibrator()
|
||||
calibrator.add_calibrator("model_a")
|
||||
|
||||
calibrator.fit("model_a", scores, labels)
|
||||
|
||||
assert calibrator.calibrators["model_a"].is_fitted
|
||||
|
||||
def test_fit_auto_add_calibrator(self, sample_data):
|
||||
"""Test fitting automatically adds calibrator if not exists"""
|
||||
scores, labels = sample_data
|
||||
calibrator = MultiModelCalibrator()
|
||||
|
||||
# Should auto-add calibrator
|
||||
calibrator.fit("model_new", scores, labels)
|
||||
|
||||
assert "model_new" in calibrator.calibrators
|
||||
assert calibrator.calibrators["model_new"].is_fitted
|
||||
|
||||
def test_calibrate_existing_model(self, sample_data):
|
||||
"""Test calibrating with existing fitted model"""
|
||||
scores, labels = sample_data
|
||||
calibrator = MultiModelCalibrator()
|
||||
calibrator.fit("model_a", scores, labels)
|
||||
|
||||
test_scores = [0.2, 0.6, 0.8]
|
||||
result = calibrator.calibrate("model_a", test_scores)
|
||||
|
||||
assert len(result) == len(test_scores)
|
||||
assert all(0 <= p <= 1 for p in result)
|
||||
|
||||
def test_calibrate_nonexistent_model_returns_original(self):
|
||||
"""Test calibrating nonexistent model returns original scores"""
|
||||
calibrator = MultiModelCalibrator()
|
||||
scores = [0.1, 0.5, 0.9]
|
||||
|
||||
# Should return original scores and log warning
|
||||
result = calibrator.calibrate("nonexistent", scores)
|
||||
assert result == scores
|
||||
|
||||
def test_calibrate_unfitted_model_returns_original(self, sample_data):
|
||||
"""Test calibrating unfitted model returns original scores"""
|
||||
calibrator = MultiModelCalibrator()
|
||||
calibrator.add_calibrator("model_a") # Add but don't fit
|
||||
|
||||
test_scores = [0.2, 0.6, 0.8]
|
||||
result = calibrator.calibrate("model_a", test_scores)
|
||||
|
||||
# Should return original scores since not fitted
|
||||
assert result == test_scores
|
||||
|
||||
def test_save_models_creates_directory(self, sample_data):
|
||||
"""Test saving models creates directory"""
|
||||
scores, labels = sample_data
|
||||
calibrator = MultiModelCalibrator()
|
||||
calibrator.fit("model_a", scores, labels)
|
||||
calibrator.fit("model_b", scores, labels)
|
||||
|
||||
with (
|
||||
patch("os.makedirs") as mock_makedirs,
|
||||
patch.object(
|
||||
calibrator.calibrators["model_a"], "save_model"
|
||||
) as mock_save_a,
|
||||
patch.object(
|
||||
calibrator.calibrators["model_b"], "save_model"
|
||||
) as mock_save_b,
|
||||
):
|
||||
|
||||
calibrator.save_models("test_dir")
|
||||
|
||||
mock_makedirs.assert_called_once_with("test_dir", exist_ok=True)
|
||||
mock_save_a.assert_called_once()
|
||||
mock_save_b.assert_called_once()
|
||||
|
||||
def test_load_models_from_directory(self):
|
||||
"""Test loading models from directory"""
|
||||
calibrator = MultiModelCalibrator()
|
        # Mock glob to return some model files
        mock_files = [
            "test_dir/model_a_calibrator.pkl",
            "test_dir/model_b_calibrator.pkl",
        ]

        with (
            patch("libs.calibration.multi_model.glob.glob", return_value=mock_files),
            patch(
                "libs.calibration.multi_model.ConfidenceCalibrator"
            ) as mock_calibrator_class,
        ):
            mock_calibrator_instance = MagicMock()
            mock_calibrator_class.return_value = mock_calibrator_instance

            calibrator.load_models("test_dir")

            # Should have loaded two models
            assert len(calibrator.calibrators) == 2
            assert "model_a" in calibrator.calibrators
            assert "model_b" in calibrator.calibrators

            # Should have called load_model on each
            assert mock_calibrator_instance.load_model.call_count == 2

    def test_load_models_empty_directory(self):
        """Test loading from an empty directory"""
        calibrator = MultiModelCalibrator()

        with patch("libs.calibration.multi_model.glob.glob", return_value=[]):
            calibrator.load_models("empty_dir")

        assert len(calibrator.calibrators) == 0

    def test_get_model_names(self, sample_data):
        """Test getting model names"""
        scores, labels = sample_data
        calibrator = MultiModelCalibrator()
        calibrator.fit("model_a", scores, labels)
        calibrator.fit("model_b", scores, labels)

        names = calibrator.get_model_names()

        assert set(names) == {"model_a", "model_b"}

    def test_get_model_names_empty(self):
        """Test getting model names when empty"""
        calibrator = MultiModelCalibrator()
        names = calibrator.get_model_names()

        assert names == []

    def test_remove_calibrator(self, sample_data):
        """Test removing a calibrator"""
        scores, labels = sample_data
        calibrator = MultiModelCalibrator()
        calibrator.fit("model_a", scores, labels)
        calibrator.fit("model_b", scores, labels)

        assert len(calibrator.calibrators) == 2

        calibrator.remove_calibrator("model_a")

        assert len(calibrator.calibrators) == 1
        assert "model_a" not in calibrator.calibrators
        assert "model_b" in calibrator.calibrators

    def test_remove_nonexistent_calibrator_raises_error(self):
        """Test that removing a nonexistent calibrator raises an error"""
        calibrator = MultiModelCalibrator()

        with pytest.raises(ValueError, match="Model 'nonexistent' not found"):
            calibrator.remove_calibrator("nonexistent")

    def test_has_model(self, sample_data):
        """Test checking whether a model exists"""
        scores, labels = sample_data
        calibrator = MultiModelCalibrator()
        calibrator.fit("model_a", scores, labels)

        assert calibrator.has_model("model_a")
        assert not calibrator.has_model("model_b")

    def test_is_fitted(self, sample_data):
        """Test checking whether a model is fitted"""
        scores, labels = sample_data
        calibrator = MultiModelCalibrator()
        calibrator.add_calibrator("model_a")  # Add but don't fit
        calibrator.fit("model_b", scores, labels)  # Add and fit

        assert not calibrator.is_fitted("model_a")
        assert calibrator.is_fitted("model_b")

    def test_is_fitted_nonexistent_model_raises_error(self):
        """Test that checking the fitted status of a nonexistent model raises an error"""
        calibrator = MultiModelCalibrator()

        with pytest.raises(ValueError, match="Model 'nonexistent' not found"):
            calibrator.is_fitted("nonexistent")

    def test_multiple_models_workflow(self, sample_data):
        """Test a complete workflow with multiple models"""
        scores, labels = sample_data
        calibrator = MultiModelCalibrator()

        # Add different models with different methods
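        # (Method names follow the usual conventions: "temperature" rescales
        # scores with a single learned parameter, "platt" fits a logistic
        # curve, and "isotonic" fits a monotonic step function; the exact
        # behavior is whatever ConfidenceCalibrator implements.)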
        calibrator.add_calibrator("temperature_model", "temperature")
        calibrator.add_calibrator("platt_model", "platt")
        calibrator.add_calibrator("isotonic_model", "isotonic")

        # Fit all models
        calibrator.fit("temperature_model", scores, labels)
        calibrator.fit("platt_model", scores, labels)
        calibrator.fit("isotonic_model", scores, labels)

        # Test calibration for all models
        test_scores = [0.2, 0.6, 0.8]

        temp_result = calibrator.calibrate("temperature_model", test_scores)
        platt_result = calibrator.calibrate("platt_model", test_scores)
        isotonic_result = calibrator.calibrate("isotonic_model", test_scores)

        # All should return valid probabilities
        for result in [temp_result, platt_result, isotonic_result]:
            assert len(result) == len(test_scores)
            assert all(0 <= p <= 1 for p in result)

        # Results should be different (unless by coincidence)
        assert not (temp_result == platt_result == isotonic_result)

    def test_fit_with_different_data_per_model(self):
        """Test fitting different models with different data"""
        calibrator = MultiModelCalibrator()

        # Different data for different models
        scores_a = [0.1, 0.3, 0.7, 0.9]
        labels_a = [False, False, True, True]

        scores_b = [0.2, 0.4, 0.6, 0.8]
        labels_b = [False, True, False, True]

        calibrator.fit("model_a", scores_a, labels_a)
        calibrator.fit("model_b", scores_b, labels_b)

        assert calibrator.is_fitted("model_a")
        assert calibrator.is_fitted("model_b")

        # Both should be able to calibrate
        result_a = calibrator.calibrate("model_a", [0.5])
        result_b = calibrator.calibrate("model_b", [0.5])

        assert len(result_a) == 1
        assert len(result_b) == 1
        assert 0 <= result_a[0] <= 1
        assert 0 <= result_b[0] <= 1

tests/unit/test_calculators.py
@@ -0,0 +1,565 @@
# FILE: tests/unit/test_calculators.py
# Unit tests for tax calculation logic

import os
import sys
from decimal import Decimal
from typing import Any

import pytest

# Add libs to path for testing
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "libs"))

# Mock the calculation functions since they're in the service
# In a real implementation, these would be extracted to shared libs

# pylint: disable=wrong-import-position,import-error,too-few-public-methods,global-statement
# pylint: disable=raise-missing-from,unused-argument,too-many-arguments,too-many-positional-arguments
# pylint: disable=too-many-locals,import-outside-toplevel
# mypy: disable-error-code=union-attr
# mypy: disable-error-code=no-untyped-def


class MockTaxCalculator:
    """Mock tax calculator for testing"""

    def __init__(self, tax_year: str = "2023-24"):
        self.tax_year = tax_year
        self.precision = 2

    def compute_sa103_self_employment(
        self, income_items: list[dict[str, Any]], expense_items: list[dict[str, Any]]
    ) -> dict[str, Any]:
        """Compute SA103 self-employment schedule"""

        total_turnover = Decimal("0")
        total_expenses = Decimal("0")
        evidence_trail = []

        # Sum income
        for income in income_items:
            if income.get("type") == "self_employment":
                amount = Decimal(str(income.get("gross", 0)))
                total_turnover += amount

                evidence_trail.append(
                    {
                        "box": "20",
                        "source_entity": income.get("income_id"),
                        "amount": float(amount),
                        "description": f"Income: {income.get('description', 'Unknown')}",
                    }
                )

        # Sum expenses
        for expense in expense_items:
            if expense.get("allowable", True):
                amount = Decimal(str(expense.get("amount", 0)))
                total_expenses += amount

                evidence_trail.append(
                    {
                        "box": "31",
                        "source_entity": expense.get("expense_id"),
                        "amount": float(amount),
                        "description": f"Expense: {expense.get('description', 'Unknown')}",
                    }
                )

        # Calculate net profit
        net_profit = total_turnover - total_expenses

        # Create form boxes
        form_boxes = {
            "20": {
                "value": float(total_turnover),
                "description": "Total turnover",
                "confidence": 0.9,
            },
            "31": {
                "value": float(total_expenses),
                "description": "Total allowable business expenses",
                "confidence": 0.9,
            },
            "32": {
                "value": float(net_profit),
                "description": "Net profit",
                "confidence": 0.9,
            },
        }

        return {
            "form_boxes": form_boxes,
            "evidence_trail": evidence_trail,
            "total_turnover": float(total_turnover),
            "total_expenses": float(total_expenses),
            "net_profit": float(net_profit),
        }

    def compute_sa105_property(
        self, income_items: list[dict[str, Any]], expense_items: list[dict[str, Any]]
    ) -> dict[str, Any]:
        """Compute SA105 property income schedule"""

        total_rents = Decimal("0")
        total_property_expenses = Decimal("0")
        evidence_trail = []

        # Sum property income
        for income in income_items:
            if income.get("type") == "property":
                amount = Decimal(str(income.get("gross", 0)))
                total_rents += amount

                evidence_trail.append(
                    {
                        "box": "20",
                        "source_entity": income.get("income_id"),
                        "amount": float(amount),
                        "description": f"Property income: {income.get('description', 'Unknown')}",
                    }
                )

        # Sum property expenses
        for expense in expense_items:
            if expense.get("type") == "property" and expense.get("allowable", True):
                amount = Decimal(str(expense.get("amount", 0)))
                total_property_expenses += amount

                # Map to appropriate SA105 box based on expense category
                box = self._map_property_expense_to_box(
                    expense.get("category", "other")
                )

                evidence_trail.append(
                    {
                        "box": box,
                        "source_entity": expense.get("expense_id"),
                        "amount": float(amount),
                        "description": f"Property expense: {expense.get('description', 'Unknown')}",
                    }
                )

        # Calculate net property income
        net_property_income = total_rents - total_property_expenses

        form_boxes = {
            "20": {
                "value": float(total_rents),
                "description": "Total rents and other income",
                "confidence": 0.9,
            },
            "38": {
                "value": float(total_property_expenses),
                "description": "Total property expenses",
                "confidence": 0.9,
            },
            "net_income": {
                "value": float(net_property_income),
                "description": "Net property income",
                "confidence": 0.9,
            },
        }

        return {
            "form_boxes": form_boxes,
            "evidence_trail": evidence_trail,
            "total_rents": float(total_rents),
            "total_expenses": float(total_property_expenses),
            "net_income": float(net_property_income),
        }

    def _map_property_expense_to_box(self, category: str) -> str:
        """Map property expense category to SA105 box"""
        mapping = {
            "rent_rates_insurance": "31",
            "property_management": "32",
            "services_wages": "33",
            "repairs_maintenance": "34",
            "finance_costs": "35",
            "professional_fees": "36",
            "costs_of_services": "37",
            "other": "38",
        }

        return mapping.get(category, "38")


class TestSA103SelfEmployment:
    """Test SA103 self-employment calculations"""

    @pytest.fixture
    def calculator(self):
        return MockTaxCalculator("2023-24")

    @pytest.fixture
    def sample_income_items(self):
        return [
            {
                "income_id": "income_1",
                "type": "self_employment",
                "gross": 75000,
                "description": "Consulting income",
            },
            {
                "income_id": "income_2",
                "type": "self_employment",
                "gross": 25000,
                "description": "Training income",
            },
        ]

    @pytest.fixture
    def sample_expense_items(self):
        return [
            {
                "expense_id": "expense_1",
                "type": "self_employment",
                "amount": 5000,
                "allowable": True,
                "description": "Office rent",
            },
            {
                "expense_id": "expense_2",
                "type": "self_employment",
                "amount": 2000,
                "allowable": True,
                "description": "Equipment",
            },
            {
                "expense_id": "expense_3",
                "type": "self_employment",
                "amount": 1000,
                "allowable": False,
                "description": "Entertainment (not allowable)",
            },
        ]

    def test_basic_calculation(
        self, calculator, sample_income_items, sample_expense_items
    ):
        """Test basic SA103 calculation"""

        result = calculator.compute_sa103_self_employment(
            sample_income_items, sample_expense_items
        )

        # Check totals
        assert result["total_turnover"] == 100000  # 75000 + 25000
        assert result["total_expenses"] == 7000  # 5000 + 2000 (excluding non-allowable)
        assert result["net_profit"] == 93000  # 100000 - 7000

        # Check form boxes
        form_boxes = result["form_boxes"]
        assert form_boxes["20"]["value"] == 100000
        assert form_boxes["31"]["value"] == 7000
        assert form_boxes["32"]["value"] == 93000

        # Check evidence trail
        evidence_trail = result["evidence_trail"]
        assert len(evidence_trail) == 4  # 2 income + 2 allowable expenses

    def test_zero_income(self, calculator):
        """Test calculation with zero income"""

        result = calculator.compute_sa103_self_employment([], [])

        assert result["total_turnover"] == 0
        assert result["total_expenses"] == 0
        assert result["net_profit"] == 0

        form_boxes = result["form_boxes"]
        assert form_boxes["20"]["value"] == 0
        assert form_boxes["31"]["value"] == 0
        assert form_boxes["32"]["value"] == 0

    def test_loss_scenario(self, calculator):
        """Test calculation resulting in a loss"""

        income_items = [
            {
                "income_id": "income_1",
                "type": "self_employment",
                "gross": 10000,
                "description": "Low income year",
            }
        ]

        expense_items = [
            {
                "expense_id": "expense_1",
                "type": "self_employment",
                "amount": 15000,
                "allowable": True,
                "description": "High expenses",
            }
        ]

        result = calculator.compute_sa103_self_employment(income_items, expense_items)

        assert result["total_turnover"] == 10000
        assert result["total_expenses"] == 15000
        assert result["net_profit"] == -5000  # Loss

        form_boxes = result["form_boxes"]
        assert form_boxes["32"]["value"] == -5000

    def test_non_allowable_expenses_excluded(self, calculator, sample_income_items):
        """Test that non-allowable expenses are excluded"""

        expense_items = [
            {
                "expense_id": "expense_1",
                "type": "self_employment",
                "amount": 5000,
                "allowable": True,
                "description": "Allowable expense",
            },
            {
                "expense_id": "expense_2",
                "type": "self_employment",
                "amount": 3000,
                "allowable": False,
                "description": "Non-allowable expense",
            },
        ]

        result = calculator.compute_sa103_self_employment(
            sample_income_items, expense_items
        )

        # Only allowable expenses should be included
        assert result["total_expenses"] == 5000

        # Evidence trail should only include allowable expenses
        expense_evidence = [e for e in result["evidence_trail"] if e["box"] == "31"]
        assert len(expense_evidence) == 1
        assert expense_evidence[0]["amount"] == 5000


class TestSA105Property:
    """Test SA105 property income calculations"""

    @pytest.fixture
    def calculator(self):
        return MockTaxCalculator("2023-24")

    @pytest.fixture
    def sample_property_income(self):
        return [
            {
                "income_id": "prop_income_1",
                "type": "property",
                "gross": 24000,
                "description": "Rental income - Property 1",
            },
            {
                "income_id": "prop_income_2",
                "type": "property",
                "gross": 18000,
                "description": "Rental income - Property 2",
            },
        ]

    @pytest.fixture
    def sample_property_expenses(self):
        return [
            {
                "expense_id": "prop_expense_1",
                "type": "property",
                "amount": 3000,
                "allowable": True,
                "category": "rent_rates_insurance",
                "description": "Insurance and rates",
            },
            {
                "expense_id": "prop_expense_2",
                "type": "property",
                "amount": 2000,
                "allowable": True,
                "category": "repairs_maintenance",
                "description": "Repairs and maintenance",
            },
            {
                "expense_id": "prop_expense_3",
                "type": "property",
                "amount": 1500,
                "allowable": True,
                "category": "property_management",
                "description": "Property management fees",
            },
        ]

    def test_basic_property_calculation(
        self, calculator, sample_property_income, sample_property_expenses
    ):
        """Test basic SA105 property calculation"""

        result = calculator.compute_sa105_property(
            sample_property_income, sample_property_expenses
        )

        # Check totals
        assert result["total_rents"] == 42000  # 24000 + 18000
        assert result["total_expenses"] == 6500  # 3000 + 2000 + 1500
        assert result["net_income"] == 35500  # 42000 - 6500

        # Check form boxes
        form_boxes = result["form_boxes"]
        assert form_boxes["20"]["value"] == 42000
        assert form_boxes["38"]["value"] == 6500
        assert form_boxes["net_income"]["value"] == 35500

    def test_property_expense_mapping(self, calculator):
        """Test property expense category mapping to form boxes"""

        # Test different expense categories
        test_cases = [
            ("rent_rates_insurance", "31"),
            ("property_management", "32"),
            ("services_wages", "33"),
            ("repairs_maintenance", "34"),
            ("finance_costs", "35"),
            ("professional_fees", "36"),
            ("costs_of_services", "37"),
            ("other", "38"),
            ("unknown_category", "38"),  # Should default to 38
        ]

        for category, expected_box in test_cases:
            actual_box = calculator._map_property_expense_to_box(category)
            assert (
                actual_box == expected_box
            ), f"Category {category} should map to box {expected_box}"

    def test_property_loss(self, calculator):
        """Test property calculation resulting in a loss"""

        income_items = [
            {
                "income_id": "prop_income_1",
                "type": "property",
                "gross": 12000,
                "description": "Low rental income",
            }
        ]

        expense_items = [
            {
                "expense_id": "prop_expense_1",
                "type": "property",
                "amount": 15000,
                "allowable": True,
                "category": "repairs_maintenance",
                "description": "Major repairs",
            }
        ]

        result = calculator.compute_sa105_property(income_items, expense_items)

        assert result["total_rents"] == 12000
        assert result["total_expenses"] == 15000
        assert result["net_income"] == -3000  # Loss

        form_boxes = result["form_boxes"]
        assert form_boxes["net_income"]["value"] == -3000


class TestCalculationEdgeCases:
    """Test edge cases and error conditions"""

    @pytest.fixture
    def calculator(self):
        return MockTaxCalculator("2023-24")

    def test_decimal_precision(self, calculator):
        """Test decimal precision handling"""

        income_items = [
            {
                "income_id": "income_1",
                "type": "self_employment",
                "gross": 33333.33,
                "description": "Precise income",
            }
        ]

        expense_items = [
            {
                "expense_id": "expense_1",
                "type": "self_employment",
                "amount": 11111.11,
                "allowable": True,
                "description": "Precise expense",
            }
        ]

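        # Decimal(str(x)) preserves the written decimal digits, so the
        # subtraction below is exactly 22222.22; naive binary floats can
        # drift in the last places (e.g. 0.3 - 0.1 != 0.2 in IEEE-754
        # doubles).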
        result = calculator.compute_sa103_self_employment(income_items, expense_items)

        # Check that calculations maintain precision
        assert result["total_turnover"] == 33333.33
        assert result["total_expenses"] == 11111.11
        assert result["net_profit"] == 22222.22

    def test_string_amounts(self, calculator):
        """Test handling of string amounts"""

        income_items = [
            {
                "income_id": "income_1",
                "type": "self_employment",
                "gross": "50000.00",  # String amount
                "description": "String income",
            }
        ]

        expense_items = [
            {
                "expense_id": "expense_1",
                "type": "self_employment",
                "amount": "10000.00",  # String amount
                "allowable": True,
                "description": "String expense",
            }
        ]

        result = calculator.compute_sa103_self_employment(income_items, expense_items)

        assert result["total_turnover"] == 50000.0
        assert result["total_expenses"] == 10000.0
        assert result["net_profit"] == 40000.0

    def test_missing_fields(self, calculator):
        """Test handling of missing fields"""

        income_items = [
            {
                "income_id": "income_1",
                "type": "self_employment",
                # Missing 'gross' field
                "description": "Income without amount",
            }
        ]

        expense_items = [
            {
                "expense_id": "expense_1",
                "type": "self_employment",
                # Missing 'amount' field
                "allowable": True,
                "description": "Expense without amount",
            }
        ]

        result = calculator.compute_sa103_self_employment(income_items, expense_items)

        # Should handle missing fields gracefully
        assert result["total_turnover"] == 0
        assert result["total_expenses"] == 0
        assert result["net_profit"] == 0


if __name__ == "__main__":
    # Run the tests
    pytest.main([__file__, "-v"])

tests/unit/test_forms.py
@@ -0,0 +1,814 @@
"""
Unit tests for svc-forms service
Tests actual business logic: PDF form filling, evidence pack generation,
currency formatting, and field mapping
"""

import os
import sys
from unittest.mock import AsyncMock, Mock, patch

import pytest

# Add the project root to the path so we can import from apps
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

# Import the actual service code
from apps.svc_forms.main import FormsSettings

# pylint: disable=wrong-import-position,import-error,too-few-public-methods
# pylint: disable=global-statement,raise-missing-from,unused-argument
# pylint: disable=too-many-arguments,too-many-positional-arguments
# pylint: disable=too-many-locals,import-outside-toplevel
# mypy: disable-error-code=union-attr


class TestFormsSettings:
    """Test FormsSettings configuration"""

    def test_default_settings(self) -> None:
        """Test default FormsSettings values"""
        settings = FormsSettings()

        # Test service configuration
        assert settings.service_name == "svc-forms"

        # Test form templates configuration
        assert settings.forms_template_dir == "forms/templates"
        assert settings.output_bucket == "filled-forms"
        assert settings.evidence_packs_bucket == "evidence-packs"

        # Test supported forms
        expected_forms = ["SA100", "SA103", "SA105", "SA106"]
        assert settings.supported_forms == expected_forms

        # Test PDF configuration
        assert settings.pdf_quality == "high"
        assert settings.flatten_forms is True

    def test_custom_settings(self) -> None:
        """Test custom FormsSettings values"""
        custom_settings = FormsSettings(
            forms_template_dir="custom/templates",
            output_bucket="custom-forms",
            evidence_packs_bucket="custom-evidence",
            supported_forms=["SA100", "SA103"],
            pdf_quality="medium",
            flatten_forms=False,
        )

        assert custom_settings.forms_template_dir == "custom/templates"
        assert custom_settings.output_bucket == "custom-forms"
        assert custom_settings.evidence_packs_bucket == "custom-evidence"
        assert custom_settings.supported_forms == ["SA100", "SA103"]
        assert custom_settings.pdf_quality == "medium"
        assert custom_settings.flatten_forms is False


class TestFormSupport:
    """Test form support validation"""

    def test_supported_forms_list(self) -> None:
        """Test supported forms list"""
        settings = FormsSettings()
        supported_forms = settings.supported_forms

        # Test that key UK tax forms are supported
        assert "SA100" in supported_forms  # Main self-assessment form
        assert "SA103" in supported_forms  # Self-employment
        assert "SA105" in supported_forms  # Property income
        assert "SA106" in supported_forms  # Foreign income

    def test_form_validation(self) -> None:
        """Test form ID validation logic"""
        settings = FormsSettings()
        valid_forms = settings.supported_forms

        # Test valid form IDs
        for form_id in valid_forms:
            assert form_id in valid_forms
            assert form_id.startswith("SA")  # UK self-assessment forms
            assert len(form_id) >= 5  # Minimum length

        # Test invalid form IDs
        invalid_forms = ["INVALID", "CT600", "VAT100", ""]
        for invalid_form in invalid_forms:
            assert invalid_form not in valid_forms


class TestPDFConfiguration:
    """Test PDF configuration and quality settings"""

    def test_pdf_quality_options(self) -> None:
        """Test PDF quality configuration"""
        # Test different quality settings
        quality_options = ["low", "medium", "high", "maximum"]

        for quality in quality_options:
            settings = FormsSettings(pdf_quality=quality)
            assert settings.pdf_quality == quality

    def test_flatten_forms_option(self) -> None:
        """Test form flattening configuration"""
        # Test flattening enabled (default)
        settings_flat = FormsSettings(flatten_forms=True)
        assert settings_flat.flatten_forms is True

        # Test flattening disabled
        settings_editable = FormsSettings(flatten_forms=False)
        assert settings_editable.flatten_forms is False

    def test_pdf_configuration_validation(self) -> None:
        """Test PDF configuration validation"""
        settings = FormsSettings()

        # Test that quality is a string
        assert isinstance(settings.pdf_quality, str)
        assert len(settings.pdf_quality) > 0

        # Test that flatten_forms is boolean
        assert isinstance(settings.flatten_forms, bool)


class TestFormFieldMapping:
    """Test form field mapping concepts"""

    def test_sa100_field_mapping(self) -> None:
        """Test SA100 form field mapping structure"""
        # Test the concept of SA100 field mapping
        # In a real implementation, this would test actual field mapping logic

        sa100_fields = {
            # Personal details
            "1.1": "forename",
            "1.2": "surname",
            "1.3": "date_of_birth",
            "1.4": "national_insurance_number",
            # Income summary
            "2.1": "total_income_from_employment",
            "2.2": "total_income_from_self_employment",
            "2.3": "total_income_from_property",
            "2.4": "total_income_from_savings",
            # Tax calculation
            "3.1": "total_income_tax_due",
            "3.2": "total_national_insurance_due",
            "3.3": "total_tax_and_ni_due",
        }

        # Test field mapping structure
        for box_number, field_name in sa100_fields.items():
            assert isinstance(box_number, str)
            assert "." in box_number  # Box numbers have section.item format
            assert isinstance(field_name, str)
            assert len(field_name) > 0

    def test_sa103_field_mapping(self) -> None:
        """Test SA103 (self-employment) field mapping structure"""
        sa103_fields = {
            # Business details
            "3.1": "business_name",
            "3.2": "business_description",
            "3.3": "business_address",
            "3.4": "accounting_period_start",
            "3.5": "accounting_period_end",
            # Income
            "3.11": "turnover",
            "3.12": "other_business_income",
            # Expenses
            "3.13": "cost_of_goods_sold",
            "3.14": "construction_industry_subcontractor_costs",
            "3.15": "other_direct_costs",
            "3.16": "employee_costs",
            "3.17": "premises_costs",
            "3.18": "repairs_and_renewals",
            "3.19": "general_administrative_expenses",
            "3.20": "motor_expenses",
            "3.21": "travel_and_subsistence",
            "3.22": "advertising_and_entertainment",
            "3.23": "legal_and_professional_costs",
            "3.24": "bad_debts",
            "3.25": "interest_and_alternative_finance_payments",
            "3.26": "other_finance_charges",
            "3.27": "depreciation_and_loss_on_disposal",
            "3.28": "other_business_expenses",
            # Profit calculation
            "3.29": "total_expenses",
            "3.30": "net_profit_or_loss",
        }

        # Test field mapping structure
        for box_number, field_name in sa103_fields.items():
            assert isinstance(box_number, str)
            assert box_number.startswith("3.")  # SA103 fields start with 3.
            assert isinstance(field_name, str)
            assert len(field_name) > 0

    def test_currency_formatting(self) -> None:
        """Test currency formatting for form fields"""
        # Test currency formatting concepts
        test_amounts = [
            (1234.56, "1,234.56"),
            (1000000.00, "1,000,000.00"),
            (0.50, "0.50"),
            (0.00, "0.00"),
            (999.99, "999.99"),
        ]

        for amount, expected_format in test_amounts:
            # Test that amounts can be formatted correctly
            formatted = f"{amount:,.2f}"
            assert formatted == expected_format

    def test_date_formatting(self) -> None:
        """Test date formatting for form fields"""
        from datetime import datetime

        # Test date formatting concepts
        test_dates = [
            ("2024-04-05", "05/04/2024"),  # UK date format
            ("2023-12-31", "31/12/2023"),
            ("2024-01-01", "01/01/2024"),
        ]

        for iso_date, expected_format in test_dates:
            # Test that dates can be formatted correctly for UK forms
            date_obj = datetime.fromisoformat(iso_date)
            formatted = date_obj.strftime("%d/%m/%Y")
            assert formatted == expected_format


class TestEvidencePackGeneration:
    """Test evidence pack generation concepts"""

    def test_evidence_pack_structure(self) -> None:
        """Test evidence pack structure"""
        # Test the concept of evidence pack structure
        evidence_pack = {
            "taxpayer_id": "taxpayer_123",
            "tax_year": "2023-24",
            "generated_at": "2024-01-15T10:30:00Z",
            "documents": [
                {
                    "type": "filled_form",
                    "form_id": "SA100",
                    "filename": "SA100_2023-24_taxpayer_123.pdf",
                    "size_bytes": 245760,
                },
                {
                    "type": "supporting_document",
                    "document_type": "bank_statement",
                    "filename": "bank_statement_jan_2024.pdf",
                    "size_bytes": 512000,
                },
                {
                    "type": "supporting_document",
                    "document_type": "receipt",
                    "filename": "office_supplies_receipt.pdf",
                    "size_bytes": 128000,
                },
            ],
            "total_size_bytes": 885760,
            "checksum": "sha256:abc123def456...",
        }

        # Test evidence pack structure
        assert "taxpayer_id" in evidence_pack
        assert "tax_year" in evidence_pack
        assert "generated_at" in evidence_pack
        assert "documents" in evidence_pack
        assert "total_size_bytes" in evidence_pack
        assert "checksum" in evidence_pack

        # Test documents structure
        for document in evidence_pack["documents"]:
            assert "type" in document
            assert "filename" in document
            assert "size_bytes" in document

    def test_evidence_pack_validation(self) -> None:
        """Test evidence pack validation concepts"""
        # Test validation rules for evidence packs
        validation_rules = {
            "max_total_size_mb": 100,  # 100MB limit
            "max_documents": 50,  # Maximum 50 documents
            "allowed_document_types": [
                "filled_form",
                "supporting_document",
                "calculation_summary",
                "audit_trail",
            ],
            "required_forms": ["SA100"],  # SA100 is always required
            "supported_file_formats": [".pdf", ".jpg", ".png"],
        }

        # Test validation rule structure
        assert isinstance(validation_rules["max_total_size_mb"], int)
        assert isinstance(validation_rules["max_documents"], int)
        assert isinstance(validation_rules["allowed_document_types"], list)
        assert isinstance(validation_rules["required_forms"], list)
        assert isinstance(validation_rules["supported_file_formats"], list)

        # Test that SA100 is required
        assert "SA100" in validation_rules["required_forms"]

        # Test that PDF is supported
        assert ".pdf" in validation_rules["supported_file_formats"]


class TestHealthEndpoint:
    """Test health check endpoint"""

    @pytest.mark.asyncio
    async def test_health_check_endpoint(self) -> None:
        """Test health check endpoint returns correct data"""
        from apps.svc_forms.main import health_check

        result = await health_check()

        assert result["status"] == "healthy"
        assert result["service"] == "svc-forms"
        assert "timestamp" in result
        assert "supported_forms" in result
        assert isinstance(result["supported_forms"], list)


class TestFormFilling:
    """Test form filling functionality"""

    @pytest.mark.asyncio
    async def test_fill_form_async_sa100(self) -> None:
        """Test async form filling for SA100"""
        from apps.svc_forms.main import _fill_form_async

        form_id = "SA100"
        field_values = {
            "taxpayer_name": "John Smith",
            "nino": "AB123456C",
            "total_income": "50000.00",
        }
        tenant_id = "tenant1"
        filling_id = "FILL123"
        actor = "user1"

        with (
            patch("apps.svc_forms.main.pdf_form_filler") as mock_pdf_filler,
            patch("apps.svc_forms.main.storage_client") as mock_storage,
            patch("apps.svc_forms.main.event_bus") as mock_event_bus,
            patch("apps.svc_forms.main.metrics") as mock_metrics,
        ):
            # Mock PDF form filler
            mock_pdf_filler.fill_form.return_value = b"mock_filled_pdf_content"

            # Mock storage operations (async)
            mock_storage.put_object = AsyncMock(return_value=True)
            mock_event_bus.publish = AsyncMock(return_value=None)

            # Mock metrics
            mock_counter = Mock()
            mock_counter.labels.return_value = mock_counter
            mock_counter.inc.return_value = None
            mock_metrics.counter.return_value = mock_counter

            # Call the function
            await _fill_form_async(form_id, field_values, tenant_id, filling_id, actor)

            # Verify operations were called
            mock_pdf_filler.fill_form.assert_called_once_with(form_id, field_values)
            mock_storage.put_object.assert_called()
            mock_event_bus.publish.assert_called()

    @pytest.mark.asyncio
    async def test_fill_form_async_error_handling(self) -> None:
        """Test error handling in async form filling"""
        from apps.svc_forms.main import _fill_form_async

        form_id = "SA100"
        field_values = {"taxpayer_name": "John Smith"}
        tenant_id = "tenant1"
        filling_id = "FILL123"
        actor = "user1"

        with (
            patch("apps.svc_forms.main.pdf_form_filler") as mock_pdf_filler,
            patch("apps.svc_forms.main.event_bus") as mock_event_bus,
            patch("apps.svc_forms.main.metrics") as mock_metrics,
        ):
            # Mock PDF processing to raise an error
            mock_pdf_filler.fill_form.side_effect = Exception("PDF processing failed")
            mock_event_bus.publish = AsyncMock(return_value=None)

            # Mock metrics
            mock_counter = Mock()
            mock_counter.labels.return_value = mock_counter
            mock_counter.inc.return_value = None
            mock_metrics.counter.return_value = mock_counter

            # Call the function - should not raise but log the error and update metrics
            await _fill_form_async(form_id, field_values, tenant_id, filling_id, actor)

            # Verify error metrics were updated
            mock_metrics.counter.assert_called_with("form_filling_errors_total")
            mock_counter.labels.assert_called_with(
                tenant_id=tenant_id, form_id=form_id, error_type="Exception"
            )
            mock_counter.inc.assert_called()


class TestEvidencePackCreation:
    """Test evidence pack creation functionality"""

    @pytest.mark.asyncio
    async def test_create_evidence_pack_async(self) -> None:
        """Test async evidence pack creation"""
        from apps.svc_forms.main import _create_evidence_pack_async

        taxpayer_id = "TP123456"
        tax_year = "2023-24"
        scope = "full_submission"
        evidence_items = [
            {
                "type": "calculation",
                "calculation_id": "CALC123",
                "description": "Tax calculation for 2023-24",
            },
            {
                "type": "document",
                "document_id": "DOC456",
                "description": "P60 for 2023-24",
            },
        ]
        tenant_id = "tenant1"
        pack_id = "PACK123"
        actor = "user1"

        with (
            patch("apps.svc_forms.main.evidence_pack_generator") as mock_evidence_gen,
            patch("apps.svc_forms.main.storage_client") as mock_storage,
            patch("apps.svc_forms.main.event_bus") as mock_event_bus,
            patch("apps.svc_forms.main.metrics") as mock_metrics,
        ):
            # Mock evidence pack generator
            mock_evidence_gen.create_evidence_pack = AsyncMock(
                return_value={
                    "pack_size": 1024,
                    "evidence_count": 2,
                    "pack_data": b"mock_pack_data",
                }
            )

            # Mock metrics
            mock_counter = Mock()
            mock_counter.labels.return_value = mock_counter
            mock_counter.inc.return_value = None
            mock_metrics.counter.return_value = mock_counter

            # Call the function
            await _create_evidence_pack_async(
                taxpayer_id, tax_year, scope, evidence_items, tenant_id, pack_id, actor
            )

            # Verify operations were called
            mock_evidence_gen.create_evidence_pack.assert_called_once_with(
                taxpayer_id=taxpayer_id,
                tax_year=tax_year,
                scope=scope,
                evidence_items=evidence_items,
            )
            mock_metrics.counter.assert_called_with("evidence_packs_created_total")
            mock_counter.labels.assert_called_with(tenant_id=tenant_id, scope=scope)
            mock_counter.inc.assert_called()

    @pytest.mark.asyncio
    async def test_create_evidence_pack_async_error_handling(self) -> None:
        """Test error handling in async evidence pack creation"""
        from apps.svc_forms.main import _create_evidence_pack_async

        taxpayer_id = "TP123456"
        tax_year = "2023-24"
        scope = "full_submission"
        evidence_items = [{"type": "calculation", "calculation_id": "CALC123"}]
        tenant_id = "tenant1"
        pack_id = "PACK123"
        actor = "user1"

        with (
            patch("apps.svc_forms.main.evidence_pack_generator") as mock_evidence_gen,
            patch("apps.svc_forms.main.event_bus") as mock_event_bus,
        ):
            # Mock evidence pack generator to raise an error
            mock_evidence_gen.create_evidence_pack = AsyncMock(
                side_effect=Exception("Evidence pack creation failed")
            )
            mock_event_bus.publish = AsyncMock(return_value=None)

            # Call the function - should not raise but log the error
            await _create_evidence_pack_async(
                taxpayer_id, tax_year, scope, evidence_items, tenant_id, pack_id, actor
            )

            # Verify the evidence pack generator was called and failed
            mock_evidence_gen.create_evidence_pack.assert_called_once_with(
                taxpayer_id=taxpayer_id,
                tax_year=tax_year,
                scope=scope,
                evidence_items=evidence_items,
            )


class TestEventHandling:
    """Test event handling functionality"""

    @pytest.mark.asyncio
    async def test_handle_calculation_ready(self) -> None:
        """Test handling calculation ready events"""
        from apps.svc_forms.main import _handle_calculation_ready
        from libs.events import EventPayload

        # Create a mock event payload
        payload = EventPayload(
            actor="user1",
            tenant_id="tenant1",
            data={
                "calculation_id": "CALC123",
                "schedule": "SA100",
                "taxpayer_id": "TP123",
                "tenant_id": "tenant1",
                "actor": "user1",
            },
        )

        with patch("apps.svc_forms.main.BackgroundTasks") as mock_bg_tasks:
            mock_bg_tasks.return_value = Mock()

            # Call the function
            await _handle_calculation_ready("calculation_ready", payload)

            # Should not raise an error
            assert True  # If we get here, the function completed successfully

    @pytest.mark.asyncio
    async def test_handle_calculation_ready_missing_data(self) -> None:
        """Test handling calculation ready events with missing data"""
        from apps.svc_forms.main import _handle_calculation_ready
        from libs.events import EventPayload

        # Create a mock event payload with missing data
        payload = EventPayload(
            data={},  # Missing required fields
            actor="test_user",
            tenant_id="tenant1",
        )

        # Call the function - should handle the missing fields gracefully
        await _handle_calculation_ready("calculation_ready", payload)

        # Should not raise an error
        assert True


class TestHealthEndpoints:
    """Test health check endpoints"""

    @pytest.mark.asyncio
    async def test_health_check_endpoint(self) -> None:
        """Test health check endpoint"""
        from apps.svc_forms.main import health_check

        result = await health_check()

        assert result["status"] == "healthy"
        assert result["service"] == "svc-forms"
        assert "version" in result
        assert "timestamp" in result
        assert "supported_forms" in result

    @pytest.mark.asyncio
    async def test_list_supported_forms_endpoint(self) -> None:
        """Test list supported forms endpoint"""
        from apps.svc_forms.main import list_supported_forms

        # Mock dependencies
        current_user = {"user_id": "test_user"}
        tenant_id = "test_tenant"

        result = await list_supported_forms(current_user, tenant_id)

        assert isinstance(result, dict)
        assert "supported_forms" in result
        assert isinstance(result["supported_forms"], list)
        assert "total_forms" in result


class TestFormValidation:
    """Test form validation business logic"""

    def test_supported_form_validation_sa100(self) -> None:
        """Test validation of the supported SA100 form"""
        from apps.svc_forms.main import settings

        form_id = "SA100"

        # Test that SA100 is in supported forms
        assert form_id in settings.supported_forms

        # Test form validation logic
        is_supported = form_id in settings.supported_forms
        assert is_supported is True

    def test_supported_form_validation_invalid(self) -> None:
        """Test validation of an unsupported form"""
        from apps.svc_forms.main import settings

        form_id = "INVALID_FORM"

        # Test that the invalid form is not supported
        is_supported = form_id in settings.supported_forms
        assert is_supported is False

    def test_field_values_processing_basic(self) -> None:
        """Test basic field values processing"""
        field_values = {
            "taxpayer_name": "John Smith",
            "nino": "AB123456C",
            "total_income": "50000.00",
            "box_1": "25000",
            "box_2": "15000",
        }

        # Test field count
        assert len(field_values) == 5

        # Test field types
        assert isinstance(field_values["taxpayer_name"], str)
        assert isinstance(field_values["total_income"], str)

        # Test box field processing
        box_fields = {k: v for k, v in field_values.items() if k.startswith("box_")}
        assert len(box_fields) == 2
        assert "box_1" in box_fields
        assert "box_2" in box_fields

    def test_form_boxes_to_field_values_conversion(self) -> None:
        """Test conversion from form boxes to field values"""
        form_boxes = {
            "1": {"value": 50000, "description": "Total income"},
            "2": {"value": 5000, "description": "Tax deducted"},
            "3": {"value": 2000, "description": "Other income"},
        }

        # Convert to field values format
        field_values = {}
        for box_id, box_data in form_boxes.items():
            field_values[f"box_{box_id}"] = box_data["value"]

        # Test conversion
        assert len(field_values) == 3
        assert field_values["box_1"] == 50000
        assert field_values["box_2"] == 5000
        assert field_values["box_3"] == 2000


class TestEvidencePackLogic:
    """Test evidence pack business logic"""

    def test_evidence_items_validation_basic(self) -> None:
        """Test basic evidence items validation"""
        evidence_items = [
            {
                "type": "calculation",
                "calculation_id": "CALC123",
                "description": "Tax calculation for 2023-24",
            },
            {
                "type": "document",
                "document_id": "DOC456",
                "description": "P60 for 2023-24",
            },
        ]

        # Test evidence items structure
        assert len(evidence_items) == 2

        # Test first item
        calc_item = evidence_items[0]
        assert calc_item["type"] == "calculation"
        assert "calculation_id" in calc_item
        assert "description" in calc_item

        # Test second item
        doc_item = evidence_items[1]
        assert doc_item["type"] == "document"
        assert "document_id" in doc_item
        assert "description" in doc_item

    def test_evidence_pack_scope_validation(self) -> None:
        """Test evidence pack scope validation"""
        valid_scopes = ["full_submission", "partial_submission", "supporting_docs"]

        for scope in valid_scopes:
            # Test that the scope is a valid, non-empty string
            assert isinstance(scope, str)
            assert len(scope) > 0

        # Test invalid scope
        invalid_scope = ""
        assert len(invalid_scope) == 0

    def test_taxpayer_id_validation(self) -> None:
        """Test taxpayer ID validation"""
        valid_taxpayer_ids = ["TP123456", "TAXPAYER_001", "12345678"]

        for taxpayer_id in valid_taxpayer_ids:
            # Test basic validation
            assert isinstance(taxpayer_id, str)
            assert len(taxpayer_id) > 0
            assert taxpayer_id.strip() == taxpayer_id  # No leading/trailing spaces

    def test_tax_year_format_validation(self) -> None:
        """Test tax year format validation"""
        valid_tax_years = ["2023-24", "2022-23", "2021-22"]
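        # UK tax years run from 6 April to 5 April, hence the split-year
        # "YYYY-YY" label.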
        for tax_year in valid_tax_years:
            # Test format
            assert isinstance(tax_year, str)
            assert len(tax_year) == 7  # Format: YYYY-YY
            assert "-" in tax_year

            # Test year parts
            parts = tax_year.split("-")
            assert len(parts) == 2
            assert len(parts[0]) == 4  # Full year
            assert len(parts[1]) == 2  # Short year


class TestFormFillingLogic:
    """Test form filling business logic"""

    def test_filling_id_generation_format(self) -> None:
        """Test filling ID generation format"""
        import ulid

        # Generate a filling ID like the service does
        filling_id = str(ulid.new())

        # Test format
        assert isinstance(filling_id, str)
        assert len(filling_id) == 26  # ULID length
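        # A ULID packs a 48-bit millisecond timestamp plus 80 random bits
        # into 26 characters of Crockford base32, so IDs sort
        # lexicographically by creation time.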
        # Test uniqueness
        filling_id2 = str(ulid.new())
        assert filling_id != filling_id2

    def test_object_key_generation(self) -> None:
        """Test S3 object key generation"""
        tenant_id = "tenant123"
        filling_id = "01HKQM7XQZX8QZQZQZQZQZQZQZ"

        # Generate an object key like the service does
        object_key = f"tenants/{tenant_id}/filled/{filling_id}.pdf"
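        # The tenants/<tenant_id>/ prefix namespaces every object by tenant,
        # keeping tenant data separated within a shared bucket.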
        # Test format
        assert object_key == "tenants/tenant123/filled/01HKQM7XQZX8QZQZQZQZQZQZQZ.pdf"
        assert object_key.startswith("tenants/")
        assert object_key.endswith(".pdf")
        assert tenant_id in object_key
        assert filling_id in object_key

    def test_form_metadata_generation(self) -> None:
        """Test form metadata generation"""
        from datetime import datetime

        form_id = "SA100"
        filling_id = "FILL123"
        tenant_id = "tenant1"
        calculation_id = "CALC456"

        # Generate metadata like the service does
        metadata = {
            "form_id": form_id,
            "filling_id": filling_id,
            "tenant_id": tenant_id,
            "calculation_id": calculation_id or "",
            "filled_at": datetime.utcnow().isoformat(),
        }

        # Test metadata structure
        assert "form_id" in metadata
        assert "filling_id" in metadata
        assert "tenant_id" in metadata
        assert "calculation_id" in metadata
        assert "filled_at" in metadata

        # Test values
        assert metadata["form_id"] == form_id
        assert metadata["filling_id"] == filling_id
        assert metadata["tenant_id"] == tenant_id
        assert metadata["calculation_id"] == calculation_id


if __name__ == "__main__":
    pytest.main([__file__])

tests/unit/test_kg.py
@@ -0,0 +1,348 @@
"""
Unit tests for svc-kg service
Tests actual business logic: Neo4j operations, SHACL validation,
bitemporal data handling, and RDF export
"""

import os
import sys
from unittest.mock import AsyncMock, patch

import pytest

# Add the project root to the path so we can import from apps
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

# Import the actual service code
from apps.svc_kg.main import KGSettings, _is_safe_query, _validate_node

# pylint: disable=wrong-import-position,import-error,too-few-public-methods
# pylint: disable=global-statement,raise-missing-from,unused-argument
# pylint: disable=too-many-arguments,too-many-positional-arguments
# pylint: disable=too-many-locals,import-outside-toplevel
# mypy: disable-error-code=union-attr


class TestKGSettings:
    """Test KGSettings configuration"""

    def test_default_settings(self) -> None:
        """Test default KGSettings values"""
        settings = KGSettings()

        # Test service configuration
        assert settings.service_name == "svc-kg"

        # Test query limits
        assert settings.max_results == 1000
        assert settings.max_depth == 10
        assert settings.query_timeout == 30

        # Test validation configuration
        assert settings.validate_on_write is True
        assert settings.shapes_file == "schemas/shapes.ttl"

    def test_custom_settings(self) -> None:
        """Test custom KGSettings values"""
        custom_settings = KGSettings(
            max_results=500,
            max_depth=5,
            query_timeout=60,
            validate_on_write=False,
            shapes_file="custom/shapes.ttl",
        )

        assert custom_settings.max_results == 500
        assert custom_settings.max_depth == 5
        assert custom_settings.query_timeout == 60
        assert custom_settings.validate_on_write is False
        assert custom_settings.shapes_file == "custom/shapes.ttl"


class TestQuerySafety:
    """Test query safety validation"""

    def test_safe_queries(self) -> None:
        """Test queries that should be considered safe"""
        safe_queries = [
            "MATCH (n:Person) RETURN n",
            "MATCH (n:Company) WHERE n.name = 'ACME' RETURN n",
            "MATCH (p:Person)-[:WORKS_FOR]->(c:Company) RETURN p, c",
            "CREATE (n:Person {name: 'John', age: 30})",
            "MERGE (n:Company {name: 'ACME'}) RETURN n",
            "MATCH (n:Person) SET n.updated = timestamp() RETURN n",
        ]

        for query in safe_queries:
            assert _is_safe_query(query), f"Query should be safe: {query}"

    def test_unsafe_queries(self) -> None:
        """Test queries that should be considered unsafe"""
        unsafe_queries = [
            "MATCH (n) DELETE n",  # Delete all nodes
            "DROP INDEX ON :Person(name)",  # Schema modification
            "CREATE INDEX ON :Person(name)",  # Schema modification
            "CALL db.schema.visualization()",  # System procedure
            "CALL apoc.export.json.all('file.json', {})",  # APOC procedure
            "LOAD CSV FROM 'file:///etc/passwd' AS line RETURN line",  # File access
            "CALL dbms.procedures()",  # System information
            "MATCH (n) DETACH DELETE n",  # Delete all nodes and relationships
        ]
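        # Between them these cover the classic Cypher abuse vectors: mass
        # deletes, schema changes, system/APOC procedure calls, and file
        # access via LOAD CSV.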
        for query in unsafe_queries:
            assert not _is_safe_query(query), f"Query should be unsafe: {query}"

    def test_query_safety_case_insensitive(self) -> None:
        """Test that query safety is case insensitive"""
        unsafe_queries = [
            "match (n) delete n",
            "MATCH (N) DELETE N",
            "Match (n) Delete n",
            "drop index on :Person(name)",
            "DROP INDEX ON :PERSON(NAME)",
        ]

        for query in unsafe_queries:
            assert not _is_safe_query(query), f"Query should be unsafe: {query}"

    def test_query_safety_with_comments(self) -> None:
        """Test query safety with comments"""
        queries_with_comments = [
            "// This is a comment\nMATCH (n:Person) RETURN n",
            "/* Multi-line comment */\nMATCH (n:Person) RETURN n",
            "MATCH (n:Person) RETURN n // End comment",
        ]

        for query in queries_with_comments:
            # Comments don't affect safety - it depends on the actual query
            result = _is_safe_query(query)
            assert isinstance(result, bool)


class TestNodeValidation:
    """Test SHACL node validation"""

    @pytest.mark.asyncio
    async def test_validate_node_with_validator(self) -> None:
        """Test node validation when a SHACL validator is available"""
        # Mock the SHACL validator
        with patch("apps.svc_kg.main.shacl_validator") as mock_validator:
            mock_validator.validate_graph = AsyncMock(
                return_value={
                    "conforms": True,
                    "violations_count": 0,
                    "results_text": "",
                }
            )

            properties = {"name": "John Doe", "age": 30, "email": "john@example.com"}

            result = await _validate_node("Person", properties)
            assert result is True

            # Verify the validator was called
            mock_validator.validate_graph.assert_called_once()

    @pytest.mark.asyncio
    async def test_validate_node_validation_failure(self) -> None:
        """Test node validation failure"""
        # Mock the SHACL validator to return validation errors
        with patch("apps.svc_kg.main.shacl_validator") as mock_validator:
            mock_validator.validate_graph = AsyncMock(
                return_value={
                    "conforms": False,
                    "violations_count": 1,
                    "results_text": "Name is required",
                }
            )

            properties = {"age": 30}  # Missing required name

            result = await _validate_node("Person", properties)
            assert result is False

    @pytest.mark.asyncio
    async def test_validate_node_no_validator(self) -> None:
        """Test node validation when no SHACL validator is available"""
        # Mock no validator available
        with patch("apps.svc_kg.main.shacl_validator", None):
            properties = {"name": "John Doe", "age": 30}

            result = await _validate_node("Person", properties)
            # Should return True when no validator is available
            assert result is True

    @pytest.mark.asyncio
    async def test_validate_node_validator_exception(self) -> None:
        """Test node validation when the validator raises an exception"""
        # Mock the SHACL validator to raise an exception
        with patch("apps.svc_kg.main.shacl_validator") as mock_validator:
            mock_validator.validate_graph = AsyncMock(
                side_effect=Exception("Validation error")
            )

            properties = {"name": "John Doe", "age": 30}

            result = await _validate_node("Person", properties)
            # Should return True when validation fails with an exception
            # (so that write operations are not blocked)
            assert result is True


class TestBitemporalDataHandling:
    """Test bitemporal data handling concepts"""

    def test_bitemporal_properties(self) -> None:
        """Test bitemporal property structure"""
        # Test the concept of bitemporal properties
        # In a real implementation, this would test actual bitemporal logic

        # Valid time: when the fact was true in reality
        # Transaction time: when the fact was recorded in the database
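        # In a bitemporal store a correction never overwrites history: the
        # old row's transaction interval is closed and a new row is written,
        # so the graph can answer "what did we believe at time X about
        # time Y?"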
bitemporal_properties = {
|
||||
"name": "John Doe",
|
||||
"valid_from": "2024-01-01T00:00:00Z",
|
||||
"valid_to": "9999-12-31T23:59:59Z", # Current/ongoing
|
||||
"transaction_from": "2024-01-15T10:30:00Z",
|
||||
"transaction_to": "9999-12-31T23:59:59Z", # Current version
|
||||
"retracted_at": None, # Not retracted
|
||||
}
|
||||
|
||||
# Test required bitemporal fields are present
|
||||
assert "valid_from" in bitemporal_properties
|
||||
assert "valid_to" in bitemporal_properties
|
||||
assert "transaction_from" in bitemporal_properties
|
||||
assert "transaction_to" in bitemporal_properties
|
||||
assert "retracted_at" in bitemporal_properties
|
||||
|
||||
# Test that current version has future end times
|
||||
assert bitemporal_properties["valid_to"] == "9999-12-31T23:59:59Z"
|
||||
assert bitemporal_properties["transaction_to"] == "9999-12-31T23:59:59Z"
|
||||
assert bitemporal_properties["retracted_at"] is None
|
||||
|
||||
def test_retracted_properties(self) -> None:
|
||||
"""Test retracted bitemporal properties"""
|
||||
retracted_properties = {
|
||||
"name": "John Doe",
|
||||
"valid_from": "2024-01-01T00:00:00Z",
|
||||
"valid_to": "2024-06-30T23:59:59Z", # No longer valid
|
||||
"transaction_from": "2024-01-15T10:30:00Z",
|
||||
"transaction_to": "2024-07-01T09:00:00Z", # Superseded
|
||||
"retracted_at": "2024-07-01T09:00:00Z", # Retracted
|
||||
}
|
||||
|
||||
# Test retracted properties
|
||||
assert retracted_properties["retracted_at"] is not None
|
||||
assert retracted_properties["valid_to"] != "9999-12-31T23:59:59Z"
|
||||
assert retracted_properties["transaction_to"] != "9999-12-31T23:59:59Z"
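
    # A minimal illustrative helper (an assumption about consumer-side logic,
    # not service code): with the sentinel end-date convention used above, a
    # record is "current" iff both intervals are open and it is not retracted
    def test_current_version_helper_sketch(self) -> None:
        """Illustrative sketch: classify a bitemporal record as current"""
        open_end = "9999-12-31T23:59:59Z"

        def is_current(props: dict) -> bool:
            return (
                props.get("valid_to") == open_end
                and props.get("transaction_to") == open_end
                and props.get("retracted_at") is None
            )

        assert is_current(
            {"valid_to": open_end, "transaction_to": open_end, "retracted_at": None}
        )
        assert not is_current(
            {
                "valid_to": "2024-06-30T23:59:59Z",
                "transaction_to": "2024-07-01T09:00:00Z",
                "retracted_at": "2024-07-01T09:00:00Z",
            }
        )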


class TestRDFExportConcepts:
    """Test RDF export format concepts"""

    def test_supported_rdf_formats(self) -> None:
        """Test supported RDF format concepts"""
        # Test RDF format concepts (not actual implementation)
        supported_formats = ["turtle", "rdf/xml", "n-triples", "json-ld"]

        # Test that common RDF formats are supported
        assert "turtle" in supported_formats
        assert "rdf/xml" in supported_formats
        assert "n-triples" in supported_formats
        assert "json-ld" in supported_formats

    def test_rdf_format_validation(self) -> None:
        """Test RDF format validation logic concepts"""
        valid_formats = ["turtle", "rdf/xml", "n-triples", "json-ld"]

        # Test format validation concepts
        for format_name in valid_formats:
            assert format_name in valid_formats

        # Test invalid formats
        invalid_formats = ["invalid", "xml", "json", "yaml"]
        for invalid_format in invalid_formats:
            assert invalid_format not in valid_formats
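
    # A minimal sketch of the normalization step a real format validator would
    # plausibly apply before the membership checks above (illustrative only)
    def test_format_normalization_sketch(self) -> None:
        """Illustrative sketch: normalize a requested format before validating"""
        valid_formats = ["turtle", "rdf/xml", "n-triples", "json-ld"]

        def normalize_format(requested: str) -> str:
            return requested.strip().lower()

        assert normalize_format("  Turtle ") in valid_formats
        assert normalize_format("YAML") not in valid_formats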


class TestKnowledgeGraphConcepts:
    """Test knowledge graph concepts and patterns"""

    def test_entity_relationship_patterns(self) -> None:
        """Test common entity-relationship patterns"""
        # Test typical tax domain entities and relationships

        # Person entity
        person_properties = {
            "id": "person_123",
            "name": "John Doe",
            "type": "Individual",
            "utr": "1234567890",
            "nino": "AB123456C",
        }

        # Company entity
        company_properties = {
            "id": "company_456",
            "name": "ACME Corp Ltd",
            "type": "Company",
            "company_number": "12345678",
            "utr": "0987654321",
        }

        # Income entity
        income_properties = {
            "id": "income_789",
            "amount": 50000.0,
            "currency": "GBP",
            "tax_year": "2023-24",
            "type": "employment_income",
        }

        # Test entity structure
        for entity in [person_properties, company_properties, income_properties]:
            assert "id" in entity
            assert "type" in entity

        # Test relationship concepts
        relationships = [
            {"from": "person_123", "to": "company_456", "type": "EMPLOYED_BY"},
            {"from": "person_123", "to": "income_789", "type": "RECEIVES"},
            {"from": "income_789", "to": "company_456", "type": "PAID_BY"},
        ]

        for relationship in relationships:
            assert "from" in relationship
            assert "to" in relationship
            assert "type" in relationship
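
    # A small illustrative check (an assumption about how a graph loader might
    # use these patterns, not service code): every relationship endpoint should
    # resolve to a known entity id before the edge is written
    def test_relationship_endpoints_resolve_sketch(self) -> None:
        """Illustrative sketch: relationship endpoints reference known entities"""
        known_ids = {"person_123", "company_456", "income_789"}
        relationships = [
            {"from": "person_123", "to": "company_456", "type": "EMPLOYED_BY"},
            {"from": "person_123", "to": "income_789", "type": "RECEIVES"},
            {"from": "income_789", "to": "company_456", "type": "PAID_BY"},
        ]

        for relationship in relationships:
            assert relationship["from"] in known_ids
            assert relationship["to"] in known_ids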

    def test_tax_domain_entities(self) -> None:
        """Test tax domain specific entities"""
        tax_entities = {
            "TaxpayerProfile": {
                "required_fields": ["utr", "name", "tax_year"],
                "optional_fields": ["nino", "address", "phone"],
            },
            "IncomeItem": {
                "required_fields": ["amount", "currency", "tax_year", "source"],
                "optional_fields": ["description", "date_received"],
            },
            "ExpenseItem": {
                "required_fields": ["amount", "currency", "category", "tax_year"],
                "optional_fields": ["description", "receipt_reference"],
            },
            "TaxCalculation": {
                "required_fields": ["tax_year", "total_income", "total_tax"],
                "optional_fields": ["allowances", "reliefs", "schedule"],
            },
        }

        # Test that each entity type has required structure
        for entity_type, schema in tax_entities.items():
            assert "required_fields" in schema
            assert "optional_fields" in schema
            assert len(schema["required_fields"]) > 0
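
    # A minimal sketch of how the schemas above could drive record validation;
    # the missing_fields helper is illustrative, not part of the service code
    def test_required_fields_check_sketch(self) -> None:
        """Illustrative sketch: check a record against a required_fields schema"""

        def missing_fields(record: dict, schema: dict) -> list:
            return [f for f in schema["required_fields"] if f not in record]

        income_schema = {
            "required_fields": ["amount", "currency", "tax_year", "source"],
            "optional_fields": ["description", "date_received"],
        }

        complete = {
            "amount": 100.0,
            "currency": "GBP",
            "tax_year": "2023-24",
            "source": "employer",
        }
        assert missing_fields(complete, income_schema) == []
        assert missing_fields({"amount": 100.0}, income_schema) == [
            "currency",
            "tax_year",
            "source",
        ]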


if __name__ == "__main__":
    pytest.main([__file__])

tests/unit/test_nats_bus.py (new file, 271 lines)
@@ -0,0 +1,271 @@
"""Tests for NATS event bus implementation."""

import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from libs.events.base import EventPayload
from libs.events.nats_bus import NATSEventBus


@pytest.fixture
def event_payload():
    """Create a test event payload."""
    return EventPayload(
        data={"test": "data", "value": 123},
        actor="test-user",
        tenant_id="test-tenant",
        trace_id="test-trace-123",
        schema_version="1.0",
    )


@pytest.fixture
def nats_bus():
    """Create a NATS event bus instance."""
    return NATSEventBus(
        servers="nats://localhost:4222",
        stream_name="TEST_STREAM",
        consumer_group="test-group",
    )


class TestNATSEventBus:
    """Test cases for NATS event bus."""

    @pytest.mark.asyncio
    async def test_initialization(self, nats_bus):
        """Test NATS event bus initialization."""
        assert nats_bus.servers == ["nats://localhost:4222"]
        assert nats_bus.stream_name == "TEST_STREAM"
        assert nats_bus.consumer_group == "test-group"
        assert not nats_bus.running
        assert nats_bus.nc is None
        assert nats_bus.js is None

    @pytest.mark.asyncio
    async def test_initialization_with_multiple_servers(self):
        """Test NATS event bus initialization with multiple servers."""
        servers = ["nats://server1:4222", "nats://server2:4222"]
        bus = NATSEventBus(servers=servers)
        assert bus.servers == servers

    @pytest.mark.asyncio
    @patch("libs.events.nats_bus.nats.connect")
    async def test_start(self, mock_connect, nats_bus):
        """Test starting the NATS event bus."""
        # Mock NATS connection and JetStream
        mock_nc = AsyncMock()
        mock_js = AsyncMock()
        mock_nc.jetstream.return_value = mock_js
        mock_connect.return_value = mock_nc

        # Mock stream info to simulate existing stream
        mock_js.stream_info.return_value = {"name": "TEST_STREAM"}

        await nats_bus.start()

        assert nats_bus.running
        assert nats_bus.nc == mock_nc
        assert nats_bus.js == mock_js
        mock_connect.assert_called_once_with(servers=["nats://localhost:4222"])

    @pytest.mark.asyncio
    @patch("libs.events.nats_bus.nats.connect")
    async def test_start_creates_stream_if_not_exists(self, mock_connect, nats_bus):
        """Test that start creates stream if it doesn't exist."""
        # Mock NATS connection and JetStream
        mock_nc = AsyncMock()
        mock_js = AsyncMock()
        mock_nc.jetstream.return_value = mock_js
        mock_connect.return_value = mock_nc

        # Mock stream_info to raise NotFoundError, then add_stream
        from nats.js.errors import NotFoundError

        mock_js.stream_info.side_effect = NotFoundError
        mock_js.add_stream = AsyncMock()

        await nats_bus.start()

        mock_js.add_stream.assert_called_once()

    @pytest.mark.asyncio
    async def test_start_already_running(self, nats_bus):
        """Test that start does nothing if already running."""
        nats_bus.running = True
        original_nc = nats_bus.nc

        await nats_bus.start()

        assert nats_bus.nc == original_nc

    @pytest.mark.asyncio
    async def test_stop(self, nats_bus):
        """Test stopping the NATS event bus."""
        # Setup mock objects
        mock_nc = AsyncMock()
        mock_subscription = AsyncMock()
        mock_task = AsyncMock()

        nats_bus.running = True
        nats_bus.nc = mock_nc
        nats_bus.subscriptions = {"test-topic": mock_subscription}
        nats_bus.consumer_tasks = [mock_task]

        await nats_bus.stop()

        assert not nats_bus.running
        mock_task.cancel.assert_called_once()
        mock_subscription.unsubscribe.assert_called_once()
        mock_nc.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_stop_not_running(self, nats_bus):
        """Test that stop does nothing if not running."""
        assert not nats_bus.running
        await nats_bus.stop()
        assert not nats_bus.running

    @pytest.mark.asyncio
    async def test_publish(self, nats_bus, event_payload):
        """Test publishing an event."""
        # Setup mock JetStream
        mock_js = AsyncMock()
        mock_ack = MagicMock()
        mock_ack.seq = 123
        mock_js.publish.return_value = mock_ack
        nats_bus.js = mock_js

        result = await nats_bus.publish("test-topic", event_payload)

        assert result is True
        mock_js.publish.assert_called_once()
        call_args = mock_js.publish.call_args
        assert call_args[1]["subject"] == "TEST_STREAM.test-topic"
        assert call_args[1]["payload"] == event_payload.to_json().encode()

    @pytest.mark.asyncio
    async def test_publish_not_started(self, nats_bus, event_payload):
        """Test publishing when event bus is not started."""
        with pytest.raises(RuntimeError, match="Event bus not started"):
            await nats_bus.publish("test-topic", event_payload)

    @pytest.mark.asyncio
    async def test_publish_failure(self, nats_bus, event_payload):
        """Test publishing failure."""
        # Setup mock JetStream that raises exception
        mock_js = AsyncMock()
        mock_js.publish.side_effect = Exception("Publish failed")
        nats_bus.js = mock_js

        result = await nats_bus.publish("test-topic", event_payload)

        assert result is False

    @pytest.mark.asyncio
    async def test_subscribe(self, nats_bus):
        """Test subscribing to a topic."""
        # Setup mock JetStream
        mock_js = AsyncMock()
        mock_subscription = AsyncMock()
        mock_js.pull_subscribe.return_value = mock_subscription
        nats_bus.js = mock_js

        # Mock handler
        async def test_handler(topic: str, payload: EventPayload) -> None:
            pass

        with patch("asyncio.create_task") as mock_create_task:
            await nats_bus.subscribe("test-topic", test_handler)

        assert "test-topic" in nats_bus.handlers
        assert test_handler in nats_bus.handlers["test-topic"]
        assert "test-topic" in nats_bus.subscriptions
        mock_js.pull_subscribe.assert_called_once()
        mock_create_task.assert_called_once()

    @pytest.mark.asyncio
    async def test_subscribe_not_started(self, nats_bus):
        """Test subscribing when event bus is not started."""

        async def test_handler(topic: str, payload: EventPayload) -> None:
            pass

        with pytest.raises(RuntimeError, match="Event bus not started"):
            await nats_bus.subscribe("test-topic", test_handler)

    @pytest.mark.asyncio
    async def test_subscribe_multiple_handlers(self, nats_bus):
        """Test subscribing multiple handlers to the same topic."""
        # Setup mock JetStream
        mock_js = AsyncMock()
        mock_subscription = AsyncMock()
        mock_js.pull_subscribe.return_value = mock_subscription
        nats_bus.js = mock_js

        # Mock handlers
        async def handler1(topic: str, payload: EventPayload) -> None:
            pass

        async def handler2(topic: str, payload: EventPayload) -> None:
            pass

        with patch("asyncio.create_task"):
            await nats_bus.subscribe("test-topic", handler1)
            await nats_bus.subscribe("test-topic", handler2)

        assert len(nats_bus.handlers["test-topic"]) == 2
        assert handler1 in nats_bus.handlers["test-topic"]
        assert handler2 in nats_bus.handlers["test-topic"]

    @pytest.mark.asyncio
    async def test_consume_messages(self, nats_bus, event_payload):
        """Test consuming messages from NATS."""
        # Setup mock subscription and message
        mock_subscription = AsyncMock()
        mock_message = MagicMock()
        mock_message.data.decode.return_value = event_payload.to_json()
        mock_message.ack = AsyncMock()

        # Deliver one message on the first fetch and stop the consume loop by
        # flipping `running` to False, so the test runs exactly one iteration
        # (patching the `running` attribute itself would leave the loop spinning)
        async def fetch_once(*_args, **_kwargs):
            nats_bus.running = False
            return [mock_message]

        mock_subscription.fetch.side_effect = fetch_once
        nats_bus.running = True

        # Mock handler
        handler_called = False
        received_topic = None
        received_payload = None

        async def test_handler(topic: str, payload: EventPayload) -> None:
            nonlocal handler_called, received_topic, received_payload
            handler_called = True
            received_topic = topic
            received_payload = payload

        nats_bus.handlers["test-topic"] = [test_handler]

        # Run one iteration of message consumption
        await nats_bus._consume_messages("test-topic", mock_subscription)

        assert handler_called
        assert received_topic == "test-topic"
        assert received_payload.event_id == event_payload.event_id
        mock_message.ack.assert_called_once()

    @pytest.mark.asyncio
    async def test_factory_integration(self):
        """Test that the factory can create a NATS event bus."""
        from libs.events.factory import create_event_bus

        bus = create_event_bus(
            "nats",
            servers="nats://localhost:4222",
            stream_name="TEST_STREAM",
            consumer_group="test-group",
        )

        assert isinstance(bus, NATSEventBus)
        assert bus.servers == ["nats://localhost:4222"]
        assert bus.stream_name == "TEST_STREAM"
        assert bus.consumer_group == "test-group"
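

# A minimal end-to-end usage sketch of the API exercised above, assuming a
# real NATS server at nats://localhost:4222; it is not executed by the test
# suite, and the topic name and handler are illustrative
async def _usage_sketch() -> None:
    bus = NATSEventBus(
        servers="nats://localhost:4222",
        stream_name="TEST_STREAM",
        consumer_group="test-group",
    )
    await bus.start()

    async def handler(topic: str, payload: EventPayload) -> None:
        print(f"{topic}: {payload.data}")

    await bus.subscribe("test-topic", handler)
    await bus.publish(
        "test-topic",
        EventPayload(
            data={"test": "data"},
            actor="test-user",
            tenant_id="test-tenant",
            trace_id="trace-1",
            schema_version="1.0",
        ),
    )
    await bus.stop()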

tests/unit/test_neo.py (new file, 622 lines)
@@ -0,0 +1,622 @@
# tests/unit/test_neo.py
# Unit tests for libs/neo.py

from datetime import datetime
from unittest.mock import AsyncMock, Mock, patch

import pytest

from libs.neo import Neo4jClient, SHACLValidator, TemporalQueries

# pylint: disable=wrong-import-position,import-error,too-few-public-methods,global-statement
# pylint: disable=raise-missing-from,unused-argument,too-many-arguments,too-many-positional-arguments
# pylint: disable=too-many-locals,import-outside-toplevel
# mypy: disable-error-code=union-attr
# mypy: disable-error-code=no-untyped-def


class TestNeo4jClient:
    """Test Neo4jClient class"""

    def test_neo4j_client_init(self):
        """Test Neo4jClient initialization"""
        mock_driver = Mock()
        client = Neo4jClient(mock_driver)

        assert client.driver == mock_driver

    @pytest.mark.asyncio
    async def test_close(self):
        """Test closing the driver"""
        mock_driver = Mock()
        mock_driver.close = Mock()

        client = Neo4jClient(mock_driver)

        with patch("asyncio.get_event_loop") as mock_get_loop:
            mock_loop = Mock()
            mock_get_loop.return_value = mock_loop
            mock_loop.run_in_executor = AsyncMock()

            await client.close()

            mock_loop.run_in_executor.assert_called_once_with(None, mock_driver.close)

    @pytest.mark.asyncio
    async def test_run_query_success(self):
        """Test successful query execution"""
        mock_driver = Mock()
        mock_session = Mock()
        mock_result = Mock()
        mock_record = Mock()
        mock_record.data.return_value = {"name": "test", "value": 123}
        mock_result.__iter__ = Mock(return_value=iter([mock_record]))

        mock_session.run.return_value = mock_result
        mock_driver.session.return_value.__enter__ = Mock(return_value=mock_session)
        mock_driver.session.return_value.__exit__ = Mock(return_value=None)

        client = Neo4jClient(mock_driver)

        with patch("asyncio.get_event_loop") as mock_get_loop:
            mock_loop = Mock()
            mock_get_loop.return_value = mock_loop
            mock_loop.run_in_executor = AsyncMock(
                return_value=[{"name": "test", "value": 123}]
            )

            result = await client.run_query("MATCH (n) RETURN n", {"param": "value"})

            assert result == [{"name": "test", "value": 123}]
            mock_loop.run_in_executor.assert_called_once()

    @pytest.mark.asyncio
    async def test_run_query_with_retries(self):
        """Test query execution with retries on transient errors"""
        from neo4j.exceptions import TransientError

        mock_driver = Mock()
        client = Neo4jClient(mock_driver)

        with (
            patch("asyncio.get_event_loop") as mock_get_loop,
            patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep,
        ):
            mock_loop = Mock()
            mock_get_loop.return_value = mock_loop

            # First two calls fail, third succeeds
            mock_loop.run_in_executor = AsyncMock(
                side_effect=[
                    TransientError("Connection lost"),
                    TransientError("Connection lost"),
                    [{"result": "success"}],
                ]
            )

            result = await client.run_query("MATCH (n) RETURN n", max_retries=3)

            assert result == [{"result": "success"}]
            assert mock_loop.run_in_executor.call_count == 3
            assert mock_sleep.call_count == 2  # Two retries

    @pytest.mark.asyncio
    async def test_run_query_max_retries_exceeded(self):
        """Test query execution when max retries exceeded"""
        from neo4j.exceptions import TransientError

        mock_driver = Mock()
        client = Neo4jClient(mock_driver)

        with (
            patch("asyncio.get_event_loop") as mock_get_loop,
            patch("asyncio.sleep", new_callable=AsyncMock),
        ):
            mock_loop = Mock()
            mock_get_loop.return_value = mock_loop
            mock_loop.run_in_executor = AsyncMock(
                side_effect=TransientError("Connection lost")
            )

            with pytest.raises(TransientError):
                await client.run_query("MATCH (n) RETURN n", max_retries=2)

            assert mock_loop.run_in_executor.call_count == 2

    @pytest.mark.asyncio
    async def test_run_query_non_retryable_error(self):
        """Test query execution with non-retryable error"""
        mock_driver = Mock()
        client = Neo4jClient(mock_driver)

        with patch("asyncio.get_event_loop") as mock_get_loop:
            mock_loop = Mock()
            mock_get_loop.return_value = mock_loop
            mock_loop.run_in_executor = AsyncMock(
                side_effect=ValueError("Invalid query")
            )

            with pytest.raises(ValueError):
                await client.run_query("INVALID QUERY")

            assert mock_loop.run_in_executor.call_count == 1  # No retries

    @pytest.mark.asyncio
    async def test_run_transaction_success(self):
        """Test successful transaction execution"""
        mock_driver = Mock()
        client = Neo4jClient(mock_driver)

        def mock_transaction_func(tx):
            return {"created": "node"}

        with patch("asyncio.get_event_loop") as mock_get_loop:
            mock_loop = Mock()
            mock_get_loop.return_value = mock_loop
            mock_loop.run_in_executor = AsyncMock(return_value={"created": "node"})

            result = await client.run_transaction(mock_transaction_func)

            assert result == {"created": "node"}
            mock_loop.run_in_executor.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_node(self):
        """Test node creation with temporal properties"""
        mock_driver = Mock()
        client = Neo4jClient(mock_driver)

        properties = {"name": "Test Node", "value": 123}

        with patch.object(client, "run_query") as mock_run_query:
            mock_run_query.return_value = [
                {
                    "n": {
                        "name": "Test Node",
                        "value": 123,
                        "asserted_at": "2023-01-01T00:00:00",
                    }
                }
            ]

            result = await client.create_node("TestLabel", properties)

            assert result == {
                "name": "Test Node",
                "value": 123,
                "asserted_at": "2023-01-01T00:00:00",
            }
            mock_run_query.assert_called_once()

            # Check that asserted_at was added to properties
            call_args = mock_run_query.call_args
            assert "asserted_at" in call_args[0][1]["properties"]

    @pytest.mark.asyncio
    async def test_create_node_with_existing_asserted_at(self):
        """Test node creation when asserted_at already exists"""
        mock_driver = Mock()
        client = Neo4jClient(mock_driver)

        existing_time = datetime(2023, 1, 1, 12, 0, 0)
        properties = {"name": "Test Node", "asserted_at": existing_time}

        with patch.object(client, "run_query") as mock_run_query:
            mock_run_query.return_value = [{"n": properties}]

            result = await client.create_node("TestLabel", properties)

            # Should not modify existing asserted_at
            call_args = mock_run_query.call_args
            assert call_args[0][1]["properties"]["asserted_at"] == existing_time

    @pytest.mark.asyncio
    async def test_update_node(self):
        """Test node update with bitemporal versioning"""
        mock_driver = Mock()
        client = Neo4jClient(mock_driver)

        properties = {"name": "Updated Node", "value": 456}

        with patch.object(client, "run_transaction") as mock_run_transaction:
            mock_run_transaction.return_value = {"name": "Updated Node", "value": 456}

            result = await client.update_node("TestLabel", "node123", properties)

            assert result == {"name": "Updated Node", "value": 456}
            mock_run_transaction.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_relationship(self):
        """Test relationship creation"""
        mock_driver = Mock()
        client = Neo4jClient(mock_driver)

        rel_properties = {"strength": 0.8, "type": "RELATED_TO"}

        with patch.object(client, "run_query") as mock_run_query:
            mock_run_query.return_value = [{"r": rel_properties}]

            result = await client.create_relationship(
                "Person", "person1", "Company", "company1", "WORKS_FOR", rel_properties
            )

            assert result == rel_properties
            mock_run_query.assert_called_once()

            # Check query parameters
            call_args = mock_run_query.call_args
            params = call_args[0][1]
            assert params["from_id"] == "person1"
            assert params["to_id"] == "company1"
            assert "asserted_at" in params["properties"]

    @pytest.mark.asyncio
    async def test_get_node_lineage(self):
        """Test getting node lineage"""
        mock_driver = Mock()
        client = Neo4jClient(mock_driver)

        lineage_data = [
            {"path": "path1", "evidence": {"id": "evidence1"}},
            {"path": "path2", "evidence": {"id": "evidence2"}},
        ]

        with patch.object(client, "run_query") as mock_run_query:
            mock_run_query.return_value = lineage_data

            result = await client.get_node_lineage("node123", max_depth=5)

            assert result == lineage_data
            mock_run_query.assert_called_once()

            # Check query parameters
            call_args = mock_run_query.call_args
            params = call_args[0][1]
            assert params["node_id"] == "node123"
            assert params["max_depth"] == 5

    @pytest.mark.asyncio
    async def test_export_to_rdf_success(self):
        """Test successful RDF export"""
        mock_driver = Mock()
        client = Neo4jClient(mock_driver)

        export_result = [{"triplesCount": 100, "format": "turtle"}]

        with patch.object(client, "run_query") as mock_run_query:
            mock_run_query.return_value = export_result

            result = await client.export_to_rdf("turtle")

            assert result == {"triplesCount": 100, "format": "turtle"}
            mock_run_query.assert_called_once()

    @pytest.mark.asyncio
    async def test_export_to_rdf_fallback(self):
        """Test RDF export with fallback"""
        mock_driver = Mock()
        client = Neo4jClient(mock_driver)

        with (
            patch.object(client, "run_query") as mock_run_query,
            patch.object(client, "_export_rdf_fallback") as mock_fallback,
        ):
            mock_run_query.side_effect = Exception("n10s plugin not available")
            mock_fallback.return_value = "fallback_rdf_data"

            result = await client.export_to_rdf("turtle")

            assert result == {"rdf_data": "fallback_rdf_data", "format": "turtle"}
            mock_fallback.assert_called_once_with("neo4j")

    @pytest.mark.asyncio
    async def test_export_rdf_fallback(self):
        """Test fallback RDF export method"""
        mock_driver = Mock()
        client = Neo4jClient(mock_driver)

        nodes_data = [
            {"labels": ["Person"], "props": {"name": "John"}, "neo_id": 1},
            {"labels": ["Company"], "props": {"name": "Acme"}, "neo_id": 2},
        ]

        rels_data = [{"type": "WORKS_FOR", "props": {}, "from_id": 1, "to_id": 2}]

        with patch.object(client, "run_query") as mock_run_query:
            mock_run_query.side_effect = [nodes_data, rels_data]

            result = await client._export_rdf_fallback()

            assert isinstance(result, str)
            assert (
                "Person" in result or "Company" in result
            )  # Should contain some RDF data
            assert mock_run_query.call_count == 2
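

# A minimal sketch (an assumption, not the library code) of the retry pattern
# the run_query tests above exercise: retry only on TransientError with a
# short backoff, re-raise once max_retries attempts are exhausted, and let
# non-retryable errors propagate immediately
async def _retry_on_transient_sketch(operation, max_retries=3, base_delay=0.1):
    import asyncio

    from neo4j.exceptions import TransientError

    for attempt in range(max_retries):
        try:
            return await operation()
        except TransientError:
            if attempt == max_retries - 1:
                raise
            await asyncio.sleep(base_delay * 2**attempt)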


class TestSHACLValidator:
    """Test SHACLValidator class"""

    def test_shacl_validator_init(self):
        """Test SHACLValidator initialization"""
        validator = SHACLValidator("/path/to/shapes.ttl")

        assert validator.shapes_file == "/path/to/shapes.ttl"

    @pytest.mark.asyncio
    async def test_validate_graph_success(self):
        """Test successful SHACL validation"""
        validator = SHACLValidator("/path/to/shapes.ttl")

        rdf_data = """
        @prefix ex: <http://example.org/> .
        ex:person1 a ex:Person ;
            ex:name "John Doe" ;
            ex:age 30 .
        """

        # Illustrates how the synchronous validation path could be mocked with
        # pySHACL; this helper is not invoked by the test below, which mocks
        # the executor instead
        def mock_validate():
            with (
                patch("pyshacl.validate") as mock_pyshacl,
                patch("rdflib.Graph") as mock_graph_class,
            ):
                mock_data_graph = Mock()
                mock_shapes_graph = Mock()
                mock_results_graph = Mock()
                mock_results_graph.subjects.return_value = []  # No violations

                mock_graph_class.side_effect = [mock_data_graph, mock_shapes_graph]
                mock_pyshacl.return_value = (
                    True,
                    mock_results_graph,
                    "Validation passed",
                )

                return validator._SHACLValidator__validate_sync(rdf_data)

        with patch("asyncio.get_event_loop") as mock_get_loop:
            mock_loop = Mock()
            mock_get_loop.return_value = mock_loop
            mock_loop.run_in_executor = AsyncMock(
                return_value={
                    "conforms": True,
                    "results_text": "Validation passed",
                    "violations_count": 0,
                }
            )

            result = await validator.validate_graph(rdf_data)

            assert result["conforms"] is True
            assert result["violations_count"] == 0
            assert "passed" in result["results_text"]
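
    # A minimal sketch (assuming pySHACL and rdflib are installed) of the real
    # synchronous validation call the mocked executor stands in for above;
    # pyshacl.validate returns a (conforms, results_graph, results_text) tuple
    def _pyshacl_call_sketch(self, rdf_data: str, shapes_file: str) -> dict:
        from pyshacl import validate
        from rdflib import Graph

        data_graph = Graph().parse(data=rdf_data, format="turtle")
        shapes_graph = Graph().parse(shapes_file, format="turtle")
        conforms, _results_graph, results_text = validate(
            data_graph, shacl_graph=shapes_graph
        )
        return {"conforms": conforms, "results_text": results_text}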

    @pytest.mark.asyncio
    async def test_validate_graph_with_violations(self):
        """Test SHACL validation with violations"""
        validator = SHACLValidator("/path/to/shapes.ttl")

        rdf_data = """
        @prefix ex: <http://example.org/> .
        ex:person1 a ex:Person ;
            ex:name "John Doe" .
        """

        with patch("asyncio.get_event_loop") as mock_get_loop:
            mock_loop = Mock()
            mock_get_loop.return_value = mock_loop
            mock_loop.run_in_executor = AsyncMock(
                return_value={
                    "conforms": False,
                    "results_text": "Missing required property: age",
                    "violations_count": 1,
                }
            )

            result = await validator.validate_graph(rdf_data)

            assert result["conforms"] is False
            assert result["violations_count"] == 1
            assert "Missing" in result["results_text"]

    @pytest.mark.asyncio
    async def test_validate_graph_import_error(self):
        """Test SHACL validation when pySHACL not available"""
        validator = SHACLValidator("/path/to/shapes.ttl")

        with patch("asyncio.get_event_loop") as mock_get_loop:
            mock_loop = Mock()
            mock_get_loop.return_value = mock_loop
            mock_loop.run_in_executor = AsyncMock(
                return_value={
                    "conforms": True,
                    "results_text": "SHACL validation skipped (pySHACL not installed)",
                    "violations_count": 0,
                }
            )

            result = await validator.validate_graph(
                "@prefix ex: <http://example.org/> ."
            )

            assert result["conforms"] is True
            assert result["violations_count"] == 0
            assert "skipped" in result["results_text"]

    @pytest.mark.asyncio
    async def test_validate_graph_validation_error(self):
        """Test SHACL validation with validation error"""
        validator = SHACLValidator("/path/to/shapes.ttl")

        with patch("asyncio.get_event_loop") as mock_get_loop:
            mock_loop = Mock()
            mock_get_loop.return_value = mock_loop
            mock_loop.run_in_executor = AsyncMock(
                return_value={
                    "conforms": False,
                    "results_text": "Validation error: Invalid RDF syntax",
                    "violations_count": -1,
                }
            )

            result = await validator.validate_graph("invalid rdf data")

            assert result["conforms"] is False
            assert result["violations_count"] == -1
            assert "error" in result["results_text"]


class TestTemporalQueries:
    """Test TemporalQueries class"""

    def test_get_current_state_query_no_filters(self):
        """Test current state query without filters"""
        query = TemporalQueries.get_current_state_query("Person")

        assert "MATCH (n:Person)" in query
        assert "n.retracted_at IS NULL" in query
        assert "ORDER BY n.asserted_at DESC" in query

    def test_get_current_state_query_with_filters(self):
        """Test current state query with filters"""
        filters = {"name": "John Doe", "age": 30, "active": True}
        query = TemporalQueries.get_current_state_query("Person", filters)

        assert "MATCH (n:Person)" in query
        assert "n.retracted_at IS NULL" in query
        assert "n.name = 'John Doe'" in query
        assert "n.age = 30" in query
        assert "n.active = True" in query
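
    # Taken together, the assertions above imply a generated query of roughly
    # this shape (an illustration, not the builder's exact output):
    #
    #   MATCH (n:Person)
    #   WHERE n.retracted_at IS NULL
    #     AND n.name = 'John Doe' AND n.age = 30 AND n.active = True
    #   RETURN n
    #   ORDER BY n.asserted_at DESC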

    def test_get_historical_state_query_no_filters(self):
        """Test historical state query without filters"""
        as_of_time = datetime(2023, 6, 15, 12, 0, 0)
        query = TemporalQueries.get_historical_state_query("Person", as_of_time)

        assert "MATCH (n:Person)" in query
        assert "n.asserted_at <= datetime('2023-06-15T12:00:00')" in query
        assert (
            "n.retracted_at IS NULL OR n.retracted_at > datetime('2023-06-15T12:00:00')"
            in query
        )
        assert "ORDER BY n.asserted_at DESC" in query

    def test_get_historical_state_query_with_filters(self):
        """Test historical state query with filters"""
        as_of_time = datetime(2023, 6, 15, 12, 0, 0)
        filters = {"department": "Engineering", "level": 5}
        query = TemporalQueries.get_historical_state_query(
            "Employee", as_of_time, filters
        )

        assert "MATCH (n:Employee)" in query
        assert "n.asserted_at <= datetime('2023-06-15T12:00:00')" in query
        assert "n.department = 'Engineering'" in query
        assert "n.level = 5" in query

    def test_get_audit_trail_query(self):
        """Test audit trail query"""
        query = TemporalQueries.get_audit_trail_query("node123")

        assert "MATCH (n {id: 'node123'})" in query
        assert "n.asserted_at as asserted_at" in query
        assert "n.retracted_at as retracted_at" in query
        assert "n.source as source" in query
        assert "n.extractor_version as extractor_version" in query
        assert "properties(n) as properties" in query
        assert "ORDER BY n.asserted_at ASC" in query


class TestIntegration:
    """Test integration scenarios"""

    @pytest.mark.asyncio
    async def test_full_neo4j_workflow(self):
        """Test complete Neo4j workflow"""
        mock_driver = Mock()
        client = Neo4jClient(mock_driver)

        # Mock all the operations
        with (
            patch.object(client, "create_node") as mock_create,
            patch.object(client, "create_relationship") as mock_create_rel,
            patch.object(client, "get_node_lineage") as mock_lineage,
        ):
            mock_create.return_value = {"id": "person1", "name": "John Doe"}
            mock_create_rel.return_value = {"type": "WORKS_FOR", "strength": 0.8}
            mock_lineage.return_value = [{"path": "lineage_path"}]

            # Create nodes
            person = await client.create_node("Person", {"name": "John Doe"})
            company = await client.create_node("Company", {"name": "Acme Corp"})

            # Create relationship
            relationship = await client.create_relationship(
                "Person",
                "person1",
                "Company",
                "company1",
                "WORKS_FOR",
                {"strength": 0.8},
            )

            # Get lineage
            lineage = await client.get_node_lineage("person1")

            assert person["name"] == "John Doe"
            assert relationship["type"] == "WORKS_FOR"
            assert len(lineage) == 1

    @pytest.mark.asyncio
    async def test_temporal_queries_integration(self):
        """Test temporal queries integration"""
        mock_driver = Mock()
        client = Neo4jClient(mock_driver)

        # Test current state query
        current_query = TemporalQueries.get_current_state_query(
            "Person", {"active": True}
        )
        assert "Person" in current_query
        assert "active = True" in current_query

        # Test historical state query
        historical_time = datetime(2023, 1, 1, 0, 0, 0)
        historical_query = TemporalQueries.get_historical_state_query(
            "Person", historical_time
        )
        assert "2023-01-01T00:00:00" in historical_query

        # Test audit trail query
        audit_query = TemporalQueries.get_audit_trail_query("person123")
        assert "person123" in audit_query

    @pytest.mark.asyncio
    async def test_shacl_validation_integration(self):
        """Test SHACL validation integration"""
        validator = SHACLValidator("/path/to/shapes.ttl")

        # Mock the validation process
        with patch("asyncio.get_event_loop") as mock_get_loop:
            mock_loop = Mock()
            mock_get_loop.return_value = mock_loop
            mock_loop.run_in_executor = AsyncMock(
                return_value={
                    "conforms": True,
                    "results_text": "All constraints satisfied",
                    "violations_count": 0,
                }
            )

            rdf_data = "@prefix ex: <http://example.org/> . ex:person1 a ex:Person ."
            result = await validator.validate_graph(rdf_data)

            assert result["conforms"] is True
            assert result["violations_count"] == 0