"""Policy loading and management with overlays and hot reload.""" import hashlib import json import re from collections.abc import Callable from datetime import datetime from pathlib import Path from typing import Any import structlog import yaml from jsonschema import ValidationError, validate from ..schemas import ( CompiledCoveragePolicy, CoveragePolicy, PolicyError, ValidationResult, ) logger = structlog.get_logger() class PolicyLoader: """Loads and manages coverage policies with overlays and hot reload""" def __init__(self, config_dir: str = "config"): self.config_dir = Path(config_dir) self.schema_path = Path(__file__).parent.parent / "coverage_schema.json" self._schema_cache: dict[str, Any] | None = None def load_policy( self, baseline_path: str | None = None, jurisdiction: str = "UK", tax_year: str = "2024-25", tenant_id: str | None = None, ) -> CoveragePolicy: """Load policy with overlays applied""" # Default baseline path if baseline_path is None: baseline_path = str(self.config_dir / "coverage.yaml") # Load baseline policy baseline = self._load_yaml_file(baseline_path) # Collect overlay files overlay_files = [] # Jurisdiction-specific overlay jurisdiction_file = self.config_dir / f"coverage.{jurisdiction}.{tax_year}.yaml" if jurisdiction_file.exists(): overlay_files.append(str(jurisdiction_file)) # Tenant-specific overlay if tenant_id: tenant_file = self.config_dir / "overrides" / f"{tenant_id}.yaml" if tenant_file.exists(): overlay_files.append(str(tenant_file)) # Load overlays overlays = [self._load_yaml_file(path) for path in overlay_files] # Merge all policies merged = self.merge_overlays(baseline, *overlays) # Apply feature flags if available merged = self.apply_feature_flags(merged) # Validate against schema self._validate_policy(merged) # Convert to Pydantic model try: policy = CoveragePolicy(**merged) logger.info( "Policy loaded successfully", jurisdiction=jurisdiction, tax_year=tax_year, tenant_id=tenant_id, overlays=len(overlays), ) return policy except Exception as e: raise PolicyError(f"Failed to parse policy: {str(e)}") from e def merge_overlays( self, base: dict[str, Any], *overlays: dict[str, Any] ) -> dict[str, Any]: """Merge base policy with overlays using deep merge""" result = base.copy() for overlay in overlays: result = self._deep_merge(result, overlay) return result def apply_feature_flags(self, policy: dict[str, Any]) -> dict[str, Any]: """Apply feature flags to policy (placeholder for Unleash integration)""" # TODO: Integrate with Unleash feature flags # For now, just return the policy unchanged logger.debug("Feature flags not implemented, returning policy unchanged") return policy def compile_predicates(self, policy: CoveragePolicy) -> CompiledCoveragePolicy: """Compile condition strings into callable predicates""" compiled_predicates: dict[str, Callable[[str, str], bool]] = {} # Compile trigger conditions for schedule_id, trigger in policy.triggers.items(): for condition in trigger.any_of + trigger.all_of: if condition not in compiled_predicates: compiled_predicates[condition] = self._compile_condition(condition) # Compile evidence conditions for schedule in policy.schedules.values(): for evidence in schedule.evidence: if evidence.condition and evidence.condition not in compiled_predicates: compiled_predicates[evidence.condition] = self._compile_condition( evidence.condition ) # Calculate hash of source files source_files = [str(self.config_dir / "coverage.yaml")] policy_hash = self._calculate_hash(source_files) return CompiledCoveragePolicy( policy=policy, compiled_predicates=compiled_predicates, compiled_at=datetime.utcnow(), hash=policy_hash, source_files=source_files, ) def validate_policy(self, policy_dict: dict[str, Any]) -> ValidationResult: """Validate policy against schema and business rules""" errors = [] warnings = [] try: # JSON Schema validation self._validate_policy(policy_dict) # Business rule validation business_errors, business_warnings = self._validate_business_rules( policy_dict ) errors.extend(business_errors) warnings.extend(business_warnings) except ValidationError as e: errors.append(f"Schema validation failed: {e.message}") except Exception as e: errors.append(f"Validation error: {str(e)}") return ValidationResult(ok=len(errors) == 0, errors=errors, warnings=warnings) def _load_yaml_file(self, path: str) -> dict[str, Any]: """Load YAML file with error handling""" try: with open(path, encoding="utf-8") as f: return yaml.safe_load(f) or {} except FileNotFoundError: raise PolicyError(f"Policy file not found: {path}") except yaml.YAMLError as e: raise PolicyError(f"Invalid YAML in {path}: {str(e)}") def _deep_merge( self, base: dict[str, Any], overlay: dict[str, Any] ) -> dict[str, Any]: """Deep merge two dictionaries""" result = base.copy() for key, value in overlay.items(): if ( key in result and isinstance(result[key], dict) and isinstance(value, dict) ): result[key] = self._deep_merge(result[key], value) else: result[key] = value return result def _validate_policy(self, policy_dict: dict[str, Any]) -> None: """Validate policy against JSON schema""" if self._schema_cache is None: with open(self.schema_path, encoding="utf-8") as f: self._schema_cache = json.load(f) validate(instance=policy_dict, schema=self._schema_cache) # fmt: skip # pyright: ignore[reportArgumentType] def _validate_business_rules( self, policy_dict: dict[str, Any] ) -> tuple[list[str], list[str]]: """Validate business rules beyond schema""" errors = [] warnings = [] # Check that all evidence IDs are in document_kinds document_kinds = set(policy_dict.get("document_kinds", [])) for schedule_id, schedule in policy_dict.get("schedules", {}).items(): for evidence in schedule.get("evidence", []): evidence_id = evidence.get("id") if evidence_id not in document_kinds: # Check if it's in acceptable_alternatives of any evidence found_in_alternatives = False for other_schedule in policy_dict.get("schedules", {}).values(): for other_evidence in other_schedule.get("evidence", []): if evidence_id in other_evidence.get( "acceptable_alternatives", [] ): found_in_alternatives = True break if found_in_alternatives: break if not found_in_alternatives: errors.append( f"Evidence ID '{evidence_id}' in schedule '{schedule_id}' " f"not found in document_kinds or acceptable_alternatives" ) # Check acceptable alternatives for alt in evidence.get("acceptable_alternatives", []): if alt not in document_kinds: warnings.append( f"Alternative '{alt}' for evidence '{evidence_id}' " f"not found in document_kinds" ) # Check that all schedules referenced in triggers exist triggers = policy_dict.get("triggers", {}) schedules = policy_dict.get("schedules", {}) for schedule_id in triggers: if schedule_id not in schedules: errors.append( f"Trigger for '{schedule_id}' but no schedule definition found" ) return errors, warnings def _compile_condition(self, condition: str) -> Callable[[str, str], bool]: """Compile a condition string into a callable predicate""" # Simple condition parser for the DSL condition = condition.strip() # Handle exists() conditions exists_match = re.match(r"exists\((\w+)\[([^\]]+)\]\)", condition) if exists_match: entity_type = exists_match.group(1) filters = exists_match.group(2) return self._create_exists_predicate(entity_type, filters) # Handle simple property conditions if condition in [ "property_joint_ownership", "candidate_FHL", "claims_FTCR", "claims_remittance_basis", "received_estate_income", ]: return self._create_property_predicate(condition) # Handle computed conditions if condition in [ "turnover_lt_vat_threshold", "turnover_ge_vat_threshold", ]: return self._create_computed_predicate(condition) # Handle taxpayer flags if condition.startswith("taxpayer_flag:"): flag_name = condition.split(":", 1)[1].strip() return self._create_flag_predicate(flag_name) # Handle filing mode if condition.startswith("filing_mode:"): mode = condition.split(":", 1)[1].strip() return self._create_filing_mode_predicate(mode) # Default: always false for unknown conditions logger.warning("Unknown condition, defaulting to False", condition=condition) return lambda taxpayer_id, tax_year: False def _create_exists_predicate( self, entity_type: str, filters: str ) -> Callable[[str, str], bool]: """Create predicate for exists() conditions""" def predicate(taxpayer_id: str, tax_year: str) -> bool: # This would query the KG for the entity with filters # For now, return False as placeholder logger.debug( "Exists predicate called", entity_type=entity_type, filters=filters, taxpayer_id=taxpayer_id, tax_year=tax_year, ) return False return predicate def _create_property_predicate( self, property_name: str ) -> Callable[[str, str], bool]: """Create predicate for property conditions""" def predicate(taxpayer_id: str, tax_year: str) -> bool: # This would query the KG for the property logger.debug( "Property predicate called", property_name=property_name, taxpayer_id=taxpayer_id, tax_year=tax_year, ) return False return predicate def _create_computed_predicate( self, computation: str ) -> Callable[[str, str], bool]: """Create predicate for computed conditions""" def predicate(taxpayer_id: str, tax_year: str) -> bool: # This would perform the computation logger.debug( "Computed predicate called", computation=computation, taxpayer_id=taxpayer_id, tax_year=tax_year, ) return False return predicate def _create_flag_predicate(self, flag_name: str) -> Callable[[str, str], bool]: """Create predicate for taxpayer flags""" def predicate(taxpayer_id: str, tax_year: str) -> bool: # This would check taxpayer flags logger.debug( "Flag predicate called", flag_name=flag_name, taxpayer_id=taxpayer_id, tax_year=tax_year, ) return False return predicate def _create_filing_mode_predicate(self, mode: str) -> Callable[[str, str], bool]: """Create predicate for filing mode""" def predicate(taxpayer_id: str, tax_year: str) -> bool: # This would check filing mode preference logger.debug( "Filing mode predicate called", mode=mode, taxpayer_id=taxpayer_id, tax_year=tax_year, ) return False return predicate def _calculate_hash(self, file_paths: list[str]) -> str: """Calculate hash of policy files""" hasher = hashlib.sha256() for path in sorted(file_paths): try: with open(path, "rb") as f: hasher.update(f.read()) except FileNotFoundError: logger.warning("File not found for hashing", path=path) return hasher.hexdigest()