# FILE: pipeline/etl.py
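"""Document ETL pipeline for the AI tax agent.

Runs each document through eight stages (see DocumentETL.process_document):
ingestion/preprocessing, classification, OCR with layout analysis, LLM-based
structured extraction, validation, normalization, knowledge-graph mapping,
and post-processing checks.
"""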
import hashlib
import json
import logging
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import cv2
import numpy as np
import pytesseract
import yaml
from pdf2image import convert_from_path
from .llm_client import LLMClient
from .mappers import GraphMapper
from .normalizers import CurrencyNormalizer, DateNormalizer, PartyNormalizer
from .validators import DocumentValidator, FieldValidator


@dataclass
class ExtractionResult:
    doc_id: str
    classification: str
    confidence: float
    extracted_data: dict
    evidence: list[dict]
    errors: list[str]
    processing_time: float


class DocumentETL:
    def __init__(self, config_path: str):
        with open(config_path) as f:
            self.config = yaml.safe_load(f)
        self.validator = DocumentValidator(self.config)
        self.field_validator = FieldValidator(self.config)
        self.currency_normalizer = CurrencyNormalizer(self.config)
        self.date_normalizer = DateNormalizer(self.config)
        self.party_normalizer = PartyNormalizer(self.config)
        self.graph_mapper = GraphMapper(self.config)
        self.llm_client = LLMClient(self.config)
        self.logger = logging.getLogger(__name__)

    def process_document(self, file_path: str, taxpayer_id: str) -> ExtractionResult:
        """Main ETL pipeline entry point"""
        start_time = datetime.now()
        doc_id = self._generate_doc_id(file_path)
        try:
            # Stage 1: Ingest and preprocess
            images, metadata = self._ingest_document(file_path)
            # Stage 2: Classify document type
            classification, class_confidence = self._classify_document(
                images[0], metadata
            )
            # Stage 3: OCR and layout analysis
            ocr_results = self._perform_ocr(images)
            # Stage 4: Extract structured data using LLM
            extracted_data = self._extract_structured_data(
                ocr_results, classification, doc_id
            )
            # Stage 5: Validate extracted data
            validation_errors = self._validate_extraction(
                extracted_data, classification
            )
            # Stage 6: Normalize and standardize
            normalized_data = self._normalize_data(extracted_data)
            # Stage 7: Map to knowledge graph
            graph_nodes, graph_edges = self._map_to_graph(
                normalized_data, doc_id, taxpayer_id
            )
            # Stage 8: Post-processing checks
            final_errors = self._post_process_checks(
                graph_nodes, graph_edges, validation_errors
            )
            processing_time = (datetime.now() - start_time).total_seconds()
            return ExtractionResult(
                doc_id=doc_id,
                classification=classification,
                confidence=class_confidence,
                extracted_data=normalized_data,
                evidence=self._create_evidence_records(ocr_results, doc_id),
                errors=final_errors,
                processing_time=processing_time,
            )
        except Exception as e:
            self.logger.error(f"ETL pipeline failed for {file_path}: {e}")
            processing_time = (datetime.now() - start_time).total_seconds()
            return ExtractionResult(
                doc_id=doc_id,
                classification="unknown",
                confidence=0.0,
                extracted_data={},
                evidence=[],
                errors=[f"Pipeline failure: {e}"],
                processing_time=processing_time,
            )

    def _generate_doc_id(self, file_path: str) -> str:
        """Generate deterministic document ID"""
        with open(file_path, "rb") as f:
            content = f.read()
        checksum = hashlib.sha256(content).hexdigest()
        return f"doc_{checksum[:16]}"

    def _ingest_document(self, file_path: str) -> tuple[list[np.ndarray], dict]:
        """Convert document to images and extract metadata"""
        file_path = Path(file_path)
        if file_path.suffix.lower() == ".pdf":
            # Convert PDF pages to images; pdf2image yields RGB, so convert to
            # BGR for consistency with the cv2.imread branch below
            pil_images = convert_from_path(str(file_path), dpi=300)
            images = [
                cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) for img in pil_images
            ]
        else:
            # Handle image files
            img = cv2.imread(str(file_path))
            if img is None:
                raise ValueError(f"Could not read image file: {file_path}")
            images = [img]
        # Preprocess images
        processed_images = []
        for img in images:
            # Deskew and rotate
            processed_img = self._deskew_image(img)
            processed_img = self._auto_rotate(processed_img)
            processed_images.append(processed_img)
        metadata = {
            "file_path": str(file_path),
            "file_size": file_path.stat().st_size,
            "mime_type": self._get_mime_type(file_path),
            "pages": len(processed_images),
            "created_at": datetime.now().isoformat(),
        }
        return processed_images, metadata

    def _classify_document(
        self, image: np.ndarray, metadata: dict
    ) -> tuple[str, float]:
        """Classify document type using OCR + LLM"""
        # Quick OCR pass for classification
        text = pytesseract.image_to_string(image)
        # Use LLM for classification; the first 2000 chars are enough to
        # identify the document type
        classification_prompt = self._load_prompt("doc_classify")
        classification_result = self.llm_client.classify_document(
            text[:2000], classification_prompt
        )
        return classification_result["type"], classification_result["confidence"]

    def _perform_ocr(self, images: list[np.ndarray]) -> list[dict]:
        """Perform OCR with layout analysis"""
        ocr_results = []
        for page_num, image in enumerate(images, 1):
            # Get detailed OCR data with bounding boxes
            ocr_data = pytesseract.image_to_data(
                image,
                output_type=pytesseract.Output.DICT,
                config="--psm 6",  # Assume a single uniform block of text
            )
            # Extract text blocks with confidence and position; skip empty
            # tokens and Tesseract's -1 "no confidence" entries (conf values
            # may arrive as strings, so cast via float)
            blocks = []
            for i in range(len(ocr_data["text"])):
                if float(ocr_data["conf"][i]) > 30 and ocr_data["text"][i].strip():
                    blocks.append(
                        {
                            "text": ocr_data["text"][i],
                            "confidence": float(ocr_data["conf"][i]) / 100.0,
                            "bbox": {
                                "x": ocr_data["left"][i],
                                "y": ocr_data["top"][i],
                                "width": ocr_data["width"][i],
                                "height": ocr_data["height"][i],
                            },
                            "page": page_num,
                        }
                    )
            # Detect tables using layout analysis
            tables = self._detect_tables(image, blocks)
            ocr_results.append(
                {
                    "page": page_num,
                    "blocks": blocks,
                    "tables": tables,
                    "full_text": " ".join(b["text"] for b in blocks),
                }
            )
        return ocr_results

    def _extract_structured_data(
        self, ocr_results: list[dict], classification: str, doc_id: str
    ) -> dict:
        """Extract structured data using LLM with schema constraints"""
        # Load the extraction prompt and schema for the document type
        if classification == "bank_statement":
            prompt = self._load_prompt("bank_statement_extract")
            schema = self._load_schema("bank_statement")
        elif classification == "invoice":
            prompt = self._load_prompt("invoice_extract")
            schema = self._load_schema("invoice")
        elif classification == "payslip":
            prompt = self._load_prompt("payslip_extract")
            schema = self._load_schema("payslip")
        else:
            prompt = self._load_prompt("kv_extract")
            schema = self._load_schema("generic")
        # Combine OCR results
        combined_text = "\n".join(
            f"Page {r['page']}:\n{r['full_text']}" for r in ocr_results
        )
        # Extract with retry logic, raising the temperature on retries
        max_retries = 3
        for attempt in range(max_retries):
            try:
                extracted = self.llm_client.extract_structured_data(
                    combined_text,
                    prompt,
                    schema,
                    temperature=0.1 if attempt == 0 else 0.3,
                )
                # Validate against schema
                if self.field_validator.validate_schema(extracted, schema):
                    return extracted
                self.logger.warning(
                    f"Schema validation failed, attempt {attempt + 1}"
                )
            except Exception as e:
                self.logger.warning(
                    f"Extraction attempt {attempt + 1} failed: {e}"
                )
        # Fallback to basic key-value extraction
        return self._fallback_extraction(ocr_results)

    def _normalize_data(self, extracted_data: dict) -> dict:
        """Normalize extracted data to canonical formats"""
        normalized = extracted_data.copy()
        # Normalize currencies
        for field in ["amount", "gross", "net", "tax_withheld"]:
            if field in normalized:
                normalized[field] = self.currency_normalizer.normalize(
                    normalized[field]
                )
        # Normalize dates
        for field in ["date", "period_start", "period_end", "due_date"]:
            if field in normalized:
                normalized[field] = self.date_normalizer.normalize(normalized[field])
        # Normalize party names
        for field in ["payer_name", "employer_name", "supplier_name"]:
            if field in normalized:
                normalized[field] = self.party_normalizer.normalize(normalized[field])
        return normalized

    def _map_to_graph(
        self, normalized_data: dict, doc_id: str, taxpayer_id: str
    ) -> tuple[list[dict], list[dict]]:
        """Map normalized data to knowledge graph nodes and edges"""
        return self.graph_mapper.map_to_graph(normalized_data, doc_id, taxpayer_id)

    def _deskew_image(self, image: np.ndarray) -> np.ndarray:
        """Correct skew in scanned documents"""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 50, 150, apertureSize=3)
        lines = cv2.HoughLines(edges, 1, np.pi / 180, threshold=100)
        if lines is not None:
            angles = []
            # cv2.HoughLines returns an (N, 1, 2) array; unpack each
            # (rho, theta) pair rather than iterating rows directly
            for line in lines[:10]:  # Use first 10 lines
                rho, theta = line[0]
                angle = theta * 180 / np.pi
                if angle < 45:
                    angles.append(angle)
                elif angle > 135:
                    angles.append(angle - 180)
            if angles:
                median_angle = np.median(angles)
                if abs(median_angle) > 0.5:  # Only rotate if skew is significant
                    (h, w) = image.shape[:2]
                    center = (w // 2, h // 2)
                    M = cv2.getRotationMatrix2D(center, median_angle, 1.0)
                    return cv2.warpAffine(
                        image,
                        M,
                        (w, h),
                        flags=cv2.INTER_CUBIC,
                        borderMode=cv2.BORDER_REPLICATE,
                    )
        return image

    def _auto_rotate(self, image: np.ndarray) -> np.ndarray:
        """Auto-rotate image to correct orientation"""
        # Use Tesseract's orientation and script detection (OSD)
        try:
            osd = pytesseract.image_to_osd(image)
            rotation = int(
                [line for line in osd.split("\n") if "Rotate:" in line][0]
                .split(":")[1]
                .strip()
            )
            if rotation != 0:
                # OSD's Rotate value is the clockwise right-angle correction
                # (90/180/270). np.rot90 is lossless here, whereas warpAffine
                # into the original (w, h) canvas would crop 90/270 rotations.
                return np.ascontiguousarray(np.rot90(image, k=-(rotation // 90)))
        except Exception:
            pass  # If OSD fails, return the original image
        return image

    def _detect_tables(self, image: np.ndarray, blocks: list[dict]) -> list[dict]:
        """Detect and extract table structures"""
        # Simple table detection using horizontal/vertical ruling lines
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # Binarize (inverted) so dark ruling lines become foreground pixels;
        # opening raw grayscale would keep the white page background instead
        _, binary = cv2.threshold(
            gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU
        )
        # Detect horizontal lines
        horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
        horizontal_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, horizontal_kernel)
        # Detect vertical lines
        vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
        vertical_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, vertical_kernel)
        # Find table regions
        table_mask = cv2.addWeighted(horizontal_lines, 0.5, vertical_lines, 0.5, 0.0)
        contours, _ = cv2.findContours(
            table_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        tables = []
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            if w > 200 and h > 100:  # Minimum table size
                # Extract text blocks within the table region
                table_blocks = [
                    block
                    for block in blocks
                    if (
                        block["bbox"]["x"] >= x
                        and block["bbox"]["y"] >= y
                        and block["bbox"]["x"] + block["bbox"]["width"] <= x + w
                        and block["bbox"]["y"] + block["bbox"]["height"] <= y + h
                    )
                ]
                tables.append(
                    {
                        "bbox": {"x": x, "y": y, "width": w, "height": h},
                        "blocks": table_blocks,
                    }
                )
        return tables

    def _load_prompt(self, prompt_name: str) -> str:
        """Load LLM prompt template"""
        prompt_path = Path(f"prompts/{prompt_name}.txt")
        with open(prompt_path) as f:
            return f.read()

    def _load_schema(self, schema_name: str) -> dict:
        """Load JSON schema for validation"""
        schema_path = Path(f"schemas/{schema_name}.schema.json")
        with open(schema_path) as f:
            return json.load(f)

    def _create_evidence_records(
        self, ocr_results: list[dict], doc_id: str
    ) -> list[dict]:
        """Create evidence records with provenance"""
        evidence = []
        for page_result in ocr_results:
            for block in page_result["blocks"]:
                evidence.append(
                    {
                        "snippet_id": f"{doc_id}_p{page_result['page']}_{len(evidence)}",
                        "doc_ref": doc_id,
                        "page": page_result["page"],
                        "bbox": block["bbox"],
                        "text_hash": hashlib.sha256(block["text"].encode()).hexdigest(),
                        "ocr_confidence": block["confidence"],
                        "extracted_text": block["text"],
                    }
                )
        return evidence
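
    # ------------------------------------------------------------------
    # The four helpers below (_validate_extraction, _get_mime_type,
    # _post_process_checks, _fallback_extraction) are called above but not
    # shown in this excerpt. These are minimal hedged sketches under assumed
    # behavior so the module stands alone; they are not the original code.
    # ------------------------------------------------------------------

    def _validate_extraction(
        self, extracted_data: dict, classification: str
    ) -> list[str]:
        """Sketch: delegate to the document validator (hypothetical API)."""
        return self.validator.validate(extracted_data, classification)

    def _get_mime_type(self, file_path: Path) -> str:
        """Sketch: best-effort MIME type from the file extension."""
        import mimetypes  # local import; assumed helper, not original code

        mime_type, _ = mimetypes.guess_type(str(file_path))
        return mime_type or "application/octet-stream"

    def _post_process_checks(
        self, graph_nodes: list[dict], graph_edges: list[dict], errors: list[str]
    ) -> list[str]:
        """Sketch: carry validation errors forward and flag dangling edges."""
        final_errors = list(errors)
        node_ids = {node.get("id") for node in graph_nodes}
        for edge in graph_edges:
            if edge.get("source") not in node_ids or edge.get("target") not in node_ids:
                final_errors.append(f"Dangling graph edge: {edge}")
        return final_errors

    def _fallback_extraction(self, ocr_results: list[dict]) -> dict:
        """Sketch: naive pairing of adjacent OCR tokens shaped like 'Label:'."""
        pairs: dict[str, str] = {}
        for page in ocr_results:
            blocks = page["blocks"]
            for i, block in enumerate(blocks[:-1]):
                text = block["text"].strip()
                if text.endswith(":") and len(text) > 1:
                    key = text[:-1].lower().replace(" ", "_")
                    pairs[key] = blocks[i + 1]["text"].strip()
        return pairs


# Minimal usage sketch; the config path, sample document, and taxpayer ID
# below are placeholders, not paths from the repository.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    etl = DocumentETL("config/etl.yaml")  # hypothetical config location
    result = etl.process_document("samples/invoice.pdf", taxpayer_id="tp_demo_001")
    print(result.classification, result.confidence, result.processing_time)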