ai-tax-agent/libs/policy/loader.py

"""Policy loading and management with overlays and hot reload."""
import hashlib
import json
import re
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
from typing import Any
import structlog
import yaml
from jsonschema import ValidationError, validate
from ..schemas import (
CompiledCoveragePolicy,
CoveragePolicy,
PolicyError,
ValidationResult,
)
logger = structlog.get_logger()


class PolicyLoader:
    """Loads and manages coverage policies with overlays and hot reload"""

    def __init__(self, config_dir: str = "config"):
        self.config_dir = Path(config_dir)
        self.schema_path = Path(__file__).parent.parent / "coverage_schema.json"
        self._schema_cache: dict[str, Any] | None = None

    def load_policy(
        self,
        baseline_path: str | None = None,
        jurisdiction: str = "UK",
        tax_year: str = "2024-25",
        tenant_id: str | None = None,
    ) -> CoveragePolicy:
        """Load policy with overlays applied"""
        # Default baseline path
        if baseline_path is None:
            baseline_path = str(self.config_dir / "coverage.yaml")

        # Load baseline policy
        baseline = self._load_yaml_file(baseline_path)

        # Collect overlay files
        overlay_files = []

        # Jurisdiction-specific overlay
        jurisdiction_file = self.config_dir / f"coverage.{jurisdiction}.{tax_year}.yaml"
        if jurisdiction_file.exists():
            overlay_files.append(str(jurisdiction_file))

        # Tenant-specific overlay
        if tenant_id:
            tenant_file = self.config_dir / "overrides" / f"{tenant_id}.yaml"
            if tenant_file.exists():
                overlay_files.append(str(tenant_file))

        # Load overlays
        overlays = [self._load_yaml_file(path) for path in overlay_files]

        # Merge all policies
        merged = self.merge_overlays(baseline, *overlays)

        # Apply feature flags if available
        merged = self.apply_feature_flags(merged)

        # Validate against schema
        self._validate_policy(merged)

        # Convert to Pydantic model
        try:
            policy = CoveragePolicy(**merged)
            logger.info(
                "Policy loaded successfully",
                jurisdiction=jurisdiction,
                tax_year=tax_year,
                tenant_id=tenant_id,
                overlays=len(overlays),
            )
            return policy
        except Exception as e:
            raise PolicyError(f"Failed to parse policy: {str(e)}") from e

    def merge_overlays(
        self, base: dict[str, Any], *overlays: dict[str, Any]
    ) -> dict[str, Any]:
        """Merge base policy with overlays using deep merge"""
        result = base.copy()
        for overlay in overlays:
            result = self._deep_merge(result, overlay)
        return result
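
    # Merge semantics, as implemented by _deep_merge below: nested dicts merge
    # key by key, while non-dict values (including lists) are replaced wholesale
    # by the overlay. Worked example:
    #   merge_overlays({"a": {"x": 1, "y": 2}, "ids": [1, 2]},
    #                  {"a": {"y": 3}, "ids": [9]})
    #   == {"a": {"x": 1, "y": 3}, "ids": [9]}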

    def apply_feature_flags(self, policy: dict[str, Any]) -> dict[str, Any]:
        """Apply feature flags to policy (placeholder for Unleash integration)"""
        # TODO: Integrate with Unleash feature flags
        # For now, just return the policy unchanged
        logger.debug("Feature flags not implemented, returning policy unchanged")
        return policy

    def compile_predicates(self, policy: CoveragePolicy) -> CompiledCoveragePolicy:
        """Compile condition strings into callable predicates"""
        compiled_predicates: dict[str, Callable[[str, str], bool]] = {}

        # Compile trigger conditions
        for schedule_id, trigger in policy.triggers.items():
            for condition in trigger.any_of + trigger.all_of:
                if condition not in compiled_predicates:
                    compiled_predicates[condition] = self._compile_condition(condition)

        # Compile evidence conditions
        for schedule in policy.schedules.values():
            for evidence in schedule.evidence:
                if evidence.condition and evidence.condition not in compiled_predicates:
                    compiled_predicates[evidence.condition] = self._compile_condition(
                        evidence.condition
                    )

        # Calculate hash of source files
        # Note: only the baseline coverage.yaml is hashed; overlay files are not
        # included in the hash.
        source_files = [str(self.config_dir / "coverage.yaml")]
        policy_hash = self._calculate_hash(source_files)

        return CompiledCoveragePolicy(
            policy=policy,
            compiled_predicates=compiled_predicates,
            compiled_at=datetime.utcnow(),
            hash=policy_hash,
            source_files=source_files,
        )
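
    # Illustrative use of the compiled result: every predicate shares the
    # signature Callable[[str, str], bool], taking (taxpayer_id, tax_year).
    # "filing_mode:online" is a hypothetical condition string, assumed here to
    # appear somewhere in the policy's triggers or evidence:
    #   compiled = loader.compile_predicates(policy)
    #   pred = compiled.compiled_predicates["filing_mode:online"]
    #   pred("taxpayer-123", "2024-25")  # -> False (current stubs always False)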

    def validate_policy(self, policy_dict: dict[str, Any]) -> ValidationResult:
        """Validate policy against schema and business rules"""
        errors = []
        warnings = []

        try:
            # JSON Schema validation
            self._validate_policy(policy_dict)

            # Business rule validation
            business_errors, business_warnings = self._validate_business_rules(
                policy_dict
            )
            errors.extend(business_errors)
            warnings.extend(business_warnings)
        except ValidationError as e:
            errors.append(f"Schema validation failed: {e.message}")
        except Exception as e:
            errors.append(f"Validation error: {str(e)}")

        return ValidationResult(ok=len(errors) == 0, errors=errors, warnings=warnings)

    def _load_yaml_file(self, path: str) -> dict[str, Any]:
        """Load YAML file with error handling"""
        try:
            with open(path, encoding="utf-8") as f:
                return yaml.safe_load(f) or {}
        except FileNotFoundError as e:
            raise PolicyError(f"Policy file not found: {path}") from e
        except yaml.YAMLError as e:
            raise PolicyError(f"Invalid YAML in {path}: {str(e)}") from e

    def _deep_merge(
        self, base: dict[str, Any], overlay: dict[str, Any]
    ) -> dict[str, Any]:
        """Deep merge two dictionaries"""
        result = base.copy()
        for key, value in overlay.items():
            if (
                key in result
                and isinstance(result[key], dict)
                and isinstance(value, dict)
            ):
                result[key] = self._deep_merge(result[key], value)
            else:
                result[key] = value
        return result

    def _validate_policy(self, policy_dict: dict[str, Any]) -> None:
        """Validate policy against JSON schema"""
        # Parse and cache the JSON schema on first use
        if self._schema_cache is None:
            with open(self.schema_path, encoding="utf-8") as f:
                self._schema_cache = json.load(f)
        validate(instance=policy_dict, schema=self._schema_cache)  # fmt: skip # pyright: ignore[reportArgumentType]

    def _validate_business_rules(
        self, policy_dict: dict[str, Any]
    ) -> tuple[list[str], list[str]]:
        """Validate business rules beyond schema"""
        errors = []
        warnings = []

        # Check that all evidence IDs are in document_kinds
        document_kinds = set(policy_dict.get("document_kinds", []))
        for schedule_id, schedule in policy_dict.get("schedules", {}).items():
            for evidence in schedule.get("evidence", []):
                evidence_id = evidence.get("id")
                if evidence_id not in document_kinds:
                    # Check if it's in acceptable_alternatives of any evidence
                    found_in_alternatives = False
                    for other_schedule in policy_dict.get("schedules", {}).values():
                        for other_evidence in other_schedule.get("evidence", []):
                            if evidence_id in other_evidence.get(
                                "acceptable_alternatives", []
                            ):
                                found_in_alternatives = True
                                break
                        if found_in_alternatives:
                            break
                    if not found_in_alternatives:
                        errors.append(
                            f"Evidence ID '{evidence_id}' in schedule '{schedule_id}' "
                            f"not found in document_kinds or acceptable_alternatives"
                        )

                # Check acceptable alternatives
                for alt in evidence.get("acceptable_alternatives", []):
                    if alt not in document_kinds:
                        warnings.append(
                            f"Alternative '{alt}' for evidence '{evidence_id}' "
                            f"not found in document_kinds"
                        )

        # Check that all schedules referenced in triggers exist
        triggers = policy_dict.get("triggers", {})
        schedules = policy_dict.get("schedules", {})
        for schedule_id in triggers:
            if schedule_id not in schedules:
                errors.append(
                    f"Trigger for '{schedule_id}' but no schedule definition found"
                )

        return errors, warnings
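
    # Example of what the business-rule pass catches (hypothetical fragment):
    # with document_kinds: ["P60"], a schedule evidence entry with id "P45"
    # produces the "not found in document_kinds or acceptable_alternatives"
    # error, unless another evidence entry lists "P45" among its
    # acceptable_alternatives; an unknown entry inside acceptable_alternatives
    # itself only produces a warning, not an error.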

    def _compile_condition(self, condition: str) -> Callable[[str, str], bool]:
        """Compile a condition string into a callable predicate"""
        # Simple condition parser for the DSL
        condition = condition.strip()

        # Handle exists() conditions
        exists_match = re.match(r"exists\((\w+)\[([^\]]+)\]\)", condition)
        if exists_match:
            entity_type = exists_match.group(1)
            filters = exists_match.group(2)
            return self._create_exists_predicate(entity_type, filters)

        # Handle simple property conditions
        if condition in [
            "property_joint_ownership",
            "candidate_FHL",
            "claims_FTCR",
            "claims_remittance_basis",
            "received_estate_income",
        ]:
            return self._create_property_predicate(condition)

        # Handle computed conditions
        if condition in [
            "turnover_lt_vat_threshold",
            "turnover_ge_vat_threshold",
        ]:
            return self._create_computed_predicate(condition)

        # Handle taxpayer flags
        if condition.startswith("taxpayer_flag:"):
            flag_name = condition.split(":", 1)[1].strip()
            return self._create_flag_predicate(flag_name)

        # Handle filing mode
        if condition.startswith("filing_mode:"):
            mode = condition.split(":", 1)[1].strip()
            return self._create_filing_mode_predicate(mode)

        # Default: always false for unknown conditions
        logger.warning("Unknown condition, defaulting to False", condition=condition)
        return lambda taxpayer_id, tax_year: False
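
    # Summary of the condition DSL accepted above. The grammar comes from this
    # method; the filter, flag, and mode examples are illustrative rather than
    # taken from a real policy file:
    #   exists(<Entity>[<filters>])   e.g. exists(Property[ownership=joint])
    #   <known property name>         e.g. candidate_FHL
    #   <known computed name>         e.g. turnover_ge_vat_threshold
    #   taxpayer_flag:<name>          e.g. taxpayer_flag:non_resident
    #   filing_mode:<mode>            e.g. filing_mode:online
    # Anything else logs a warning and compiles to an always-False predicate.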

    def _create_exists_predicate(
        self, entity_type: str, filters: str
    ) -> Callable[[str, str], bool]:
        """Create predicate for exists() conditions"""

        def predicate(taxpayer_id: str, tax_year: str) -> bool:
            # This would query the KG for the entity with filters
            # For now, return False as placeholder
            logger.debug(
                "Exists predicate called",
                entity_type=entity_type,
                filters=filters,
                taxpayer_id=taxpayer_id,
                tax_year=tax_year,
            )
            return False

        return predicate

    def _create_property_predicate(
        self, property_name: str
    ) -> Callable[[str, str], bool]:
        """Create predicate for property conditions"""

        def predicate(taxpayer_id: str, tax_year: str) -> bool:
            # This would query the KG for the property
            logger.debug(
                "Property predicate called",
                property_name=property_name,
                taxpayer_id=taxpayer_id,
                tax_year=tax_year,
            )
            return False

        return predicate

    def _create_computed_predicate(
        self, computation: str
    ) -> Callable[[str, str], bool]:
        """Create predicate for computed conditions"""

        def predicate(taxpayer_id: str, tax_year: str) -> bool:
            # This would perform the computation
            logger.debug(
                "Computed predicate called",
                computation=computation,
                taxpayer_id=taxpayer_id,
                tax_year=tax_year,
            )
            return False

        return predicate

    def _create_flag_predicate(self, flag_name: str) -> Callable[[str, str], bool]:
        """Create predicate for taxpayer flags"""

        def predicate(taxpayer_id: str, tax_year: str) -> bool:
            # This would check taxpayer flags
            logger.debug(
                "Flag predicate called",
                flag_name=flag_name,
                taxpayer_id=taxpayer_id,
                tax_year=tax_year,
            )
            return False

        return predicate

    def _create_filing_mode_predicate(self, mode: str) -> Callable[[str, str], bool]:
        """Create predicate for filing mode"""

        def predicate(taxpayer_id: str, tax_year: str) -> bool:
            # This would check filing mode preference
            logger.debug(
                "Filing mode predicate called",
                mode=mode,
                taxpayer_id=taxpayer_id,
                tax_year=tax_year,
            )
            return False

        return predicate

    def _calculate_hash(self, file_paths: list[str]) -> str:
        """Calculate hash of policy files"""
        hasher = hashlib.sha256()
        for path in sorted(file_paths):
            try:
                with open(path, "rb") as f:
                    hasher.update(f.read())
            except FileNotFoundError:
                logger.warning("File not found for hashing", path=path)
        return hasher.hexdigest()
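

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module as committed).
# Assumes a config/ directory whose coverage.yaml satisfies coverage_schema.json.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    loader = PolicyLoader(config_dir="config")

    # Validate the raw YAML before building the Pydantic model (the private
    # loader helper is used here purely for demonstration)
    raw = loader._load_yaml_file("config/coverage.yaml")
    result = loader.validate_policy(raw)
    print(f"valid={result.ok} errors={result.errors} warnings={result.warnings}")

    if result.ok:
        # Load with overlays applied, then compile condition strings to predicates
        policy = loader.load_policy(jurisdiction="UK", tax_year="2024-25")
        compiled = loader.compile_predicates(policy)
        print(f"compiled {len(compiled.compiled_predicates)} predicates")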