# FILE: pipeline/etl.py
import hashlib
import json
import logging
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path

import cv2
import numpy as np
import pytesseract
import yaml
from pdf2image import convert_from_path

from .llm_client import LLMClient
from .mappers import GraphMapper
from .normalizers import CurrencyNormalizer, DateNormalizer, PartyNormalizer
from .validators import DocumentValidator, FieldValidator


@dataclass
class ExtractionResult:
    doc_id: str
    classification: str
    confidence: float
    extracted_data: dict
    evidence: list[dict]
    errors: list[str]
    processing_time: float


class DocumentETL:
    def __init__(self, config_path: str):
        with open(config_path) as f:
            self.config = yaml.safe_load(f)

        self.validator = DocumentValidator(self.config)
        self.field_validator = FieldValidator(self.config)
        self.currency_normalizer = CurrencyNormalizer(self.config)
        self.date_normalizer = DateNormalizer(self.config)
        self.party_normalizer = PartyNormalizer(self.config)
        self.graph_mapper = GraphMapper(self.config)
        self.llm_client = LLMClient(self.config)
        self.logger = logging.getLogger(__name__)
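
    # Illustrative config.yaml shape (an assumption, not the documented format --
    # the real keys are whatever DocumentValidator, FieldValidator, the
    # normalizers, GraphMapper and LLMClient read from self.config, none of
    # which are defined in this module):
    #
    #   llm:
    #     model: "<model-name>"
    #     api_key_env: "LLM_API_KEY"
    #   validation:
    #     required_fields:
    #       invoice: ["supplier_name", "amount", "date"]
    #   normalization:
    #     default_currency: "<ISO-4217 code>"
    #     date_format: "ISO-8601"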
    def process_document(self, file_path: str, taxpayer_id: str) -> ExtractionResult:
        """Main ETL pipeline entry point."""
        start_time = datetime.now()
        doc_id = self._generate_doc_id(file_path)

        try:
            # Stage 1: Ingest and preprocess
            images, metadata = self._ingest_document(file_path)

            # Stage 2: Classify document type
            classification, class_confidence = self._classify_document(
                images[0], metadata
            )

            # Stage 3: OCR and layout analysis
            ocr_results = self._perform_ocr(images)

            # Stage 4: Extract structured data using LLM
            extracted_data = self._extract_structured_data(
                ocr_results, classification, doc_id
            )

            # Stage 5: Validate extracted data
            validation_errors = self._validate_extraction(
                extracted_data, classification
            )

            # Stage 6: Normalize and standardize
            normalized_data = self._normalize_data(extracted_data)

            # Stage 7: Map to knowledge graph
            graph_nodes, graph_edges = self._map_to_graph(
                normalized_data, doc_id, taxpayer_id
            )

            # Stage 8: Post-processing checks
            final_errors = self._post_process_checks(
                graph_nodes, graph_edges, validation_errors
            )

            processing_time = (datetime.now() - start_time).total_seconds()

            return ExtractionResult(
                doc_id=doc_id,
                classification=classification,
                confidence=class_confidence,
                extracted_data=normalized_data,
                evidence=self._create_evidence_records(ocr_results, doc_id),
                errors=final_errors,
                processing_time=processing_time,
            )

        except Exception as e:
            self.logger.error(f"ETL pipeline failed for {file_path}: {str(e)}")
            processing_time = (datetime.now() - start_time).total_seconds()
            return ExtractionResult(
                doc_id=doc_id,
                classification="unknown",
                confidence=0.0,
                extracted_data={},
                evidence=[],
                errors=[f"Pipeline failure: {str(e)}"],
                processing_time=processing_time,
            )

    def _generate_doc_id(self, file_path: str) -> str:
        """Generate a deterministic document ID from the file's SHA-256 digest."""
        with open(file_path, "rb") as f:
            content = f.read()
        checksum = hashlib.sha256(content).hexdigest()
        return f"doc_{checksum[:16]}"

    def _ingest_document(self, file_path: str) -> tuple[list[np.ndarray], dict]:
        """Convert the document to images and extract basic metadata."""
        file_path = Path(file_path)

        if file_path.suffix.lower() == ".pdf":
            # Convert PDF pages to images
            pil_images = convert_from_path(str(file_path), dpi=300)
            images = [np.array(img) for img in pil_images]
        else:
            # Handle image files
            img = cv2.imread(str(file_path))
            if img is None:
                raise ValueError(f"Could not read image file: {file_path}")
            images = [img]

        # Preprocess images
        processed_images = []
        for img in images:
            # Deskew and rotate
            processed_img = self._deskew_image(img)
            processed_img = self._auto_rotate(processed_img)
            processed_images.append(processed_img)

        metadata = {
            "file_path": str(file_path),
            "file_size": file_path.stat().st_size,
            "mime_type": self._get_mime_type(file_path),
            "pages": len(processed_images),
            "created_at": datetime.now().isoformat(),
        }

        return processed_images, metadata

    def _classify_document(
        self, image: np.ndarray, metadata: dict
    ) -> tuple[str, float]:
        """Classify document type using OCR + LLM."""
        # Quick OCR pass for classification
        text = pytesseract.image_to_string(image)

        # Use LLM for classification
        classification_prompt = self._load_prompt("doc_classify")
        classification_result = self.llm_client.classify_document(
            text[:2000],  # First 2000 chars for classification
            classification_prompt,
        )

        return classification_result["type"], classification_result["confidence"]

    def _perform_ocr(self, images: list[np.ndarray]) -> list[dict]:
        """Perform OCR with layout analysis."""
        ocr_results = []

        for page_num, image in enumerate(images, 1):
            # Get detailed OCR data with bounding boxes
            ocr_data = pytesseract.image_to_data(
                image,
                output_type=pytesseract.Output.DICT,
                config="--psm 6",  # Uniform block of text
            )

            # Extract text blocks with confidence and position
            blocks = []
            for i in range(len(ocr_data["text"])):
                if int(ocr_data["conf"][i]) > 30:  # Confidence threshold
                    blocks.append(
                        {
                            "text": ocr_data["text"][i],
                            "confidence": int(ocr_data["conf"][i]) / 100.0,
                            "bbox": {
                                "x": ocr_data["left"][i],
                                "y": ocr_data["top"][i],
                                "width": ocr_data["width"][i],
                                "height": ocr_data["height"][i],
                            },
                            "page": page_num,
                        }
                    )

            # Detect tables using layout analysis
            tables = self._detect_tables(image, blocks)

            ocr_results.append(
                {
                    "page": page_num,
                    "blocks": blocks,
                    "tables": tables,
                    "full_text": " ".join([b["text"] for b in blocks]),
                }
            )

        return ocr_results

    def _extract_structured_data(
        self, ocr_results: list[dict], classification: str, doc_id: str
    ) -> dict:
        """Extract structured data using the LLM with schema constraints."""
        # Load the appropriate extraction prompt and schema
        if classification == "bank_statement":
            prompt = self._load_prompt("bank_statement_extract")
            schema = self._load_schema("bank_statement")
        elif classification == "invoice":
            prompt = self._load_prompt("invoice_extract")
            schema = self._load_schema("invoice")
        elif classification == "payslip":
            prompt = self._load_prompt("payslip_extract")
            schema = self._load_schema("payslip")
        else:
            prompt = self._load_prompt("kv_extract")
            schema = self._load_schema("generic")

        # Combine OCR results
        combined_text = "\n".join(
            [f"Page {r['page']}:\n{r['full_text']}" for r in ocr_results]
        )

        # Extract with retry logic
        max_retries = 3
        for attempt in range(max_retries):
            try:
                extracted = self.llm_client.extract_structured_data(
                    combined_text,
                    prompt,
                    schema,
                    temperature=0.1 if attempt == 0 else 0.3,
                )

                # Validate against schema
                if self.field_validator.validate_schema(extracted, schema):
                    return extracted
                else:
                    self.logger.warning(
                        f"Schema validation failed, attempt {attempt + 1}"
                    )
            except Exception as e:
                self.logger.warning(
                    f"Extraction attempt {attempt + 1} failed: {str(e)}"
                )

        # Fall back to basic key-value extraction
        return self._fallback_extraction(ocr_results)
    def _normalize_data(self, extracted_data: dict) -> dict:
        """Normalize extracted data to canonical formats."""
        normalized = extracted_data.copy()

        # Normalize currencies
        for field in ["amount", "gross", "net", "tax_withheld"]:
            if field in normalized:
                normalized[field] = self.currency_normalizer.normalize(
                    normalized[field]
                )

        # Normalize dates
        for field in ["date", "period_start", "period_end", "due_date"]:
            if field in normalized:
                normalized[field] = self.date_normalizer.normalize(normalized[field])

        # Normalize party names
        for field in ["payer_name", "employer_name", "supplier_name"]:
            if field in normalized:
                normalized[field] = self.party_normalizer.normalize(normalized[field])

        return normalized

    def _map_to_graph(
        self, normalized_data: dict, doc_id: str, taxpayer_id: str
    ) -> tuple[list[dict], list[dict]]:
        """Map normalized data to knowledge graph nodes and edges."""
        return self.graph_mapper.map_to_graph(normalized_data, doc_id, taxpayer_id)

    def _deskew_image(self, image: np.ndarray) -> np.ndarray:
        """Correct skew in scanned documents."""
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 50, 150, apertureSize=3)
        lines = cv2.HoughLines(edges, 1, np.pi / 180, threshold=100)

        if lines is not None:
            angles = []
            for line in lines[:10]:  # Use first 10 lines
                # HoughLines returns an (N, 1, 2) array, so unpack the inner pair
                rho, theta = line[0]
                angle = theta * 180 / np.pi
                if angle < 45:
                    angles.append(angle)
                elif angle > 135:
                    angles.append(angle - 180)

            if angles:
                median_angle = np.median(angles)
                if abs(median_angle) > 0.5:  # Only rotate if skew is significant
                    (h, w) = image.shape[:2]
                    center = (w // 2, h // 2)
                    M = cv2.getRotationMatrix2D(center, median_angle, 1.0)
                    return cv2.warpAffine(
                        image,
                        M,
                        (w, h),
                        flags=cv2.INTER_CUBIC,
                        borderMode=cv2.BORDER_REPLICATE,
                    )

        return image

    def _auto_rotate(self, image: np.ndarray) -> np.ndarray:
        """Auto-rotate image to the correct orientation."""
        # Use Tesseract's orientation and script detection (OSD)
        try:
            osd = pytesseract.image_to_osd(image)
            rotation = int(
                [line for line in osd.split("\n") if "Rotate:" in line][0]
                .split(":")[1]
                .strip()
            )
            if rotation != 0:
                (h, w) = image.shape[:2]
                center = (w // 2, h // 2)
                M = cv2.getRotationMatrix2D(center, rotation, 1.0)
                return cv2.warpAffine(image, M, (w, h))
        except Exception:
            pass  # If OSD fails, return the original image

        return image

    def _detect_tables(self, image: np.ndarray, blocks: list[dict]) -> list[dict]:
        """Detect and extract table structures."""
        # Simple table detection using horizontal/vertical lines
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Detect horizontal lines
        horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
        horizontal_lines = cv2.morphologyEx(gray, cv2.MORPH_OPEN, horizontal_kernel)

        # Detect vertical lines
        vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
        vertical_lines = cv2.morphologyEx(gray, cv2.MORPH_OPEN, vertical_kernel)

        # Find table regions
        table_mask = cv2.addWeighted(horizontal_lines, 0.5, vertical_lines, 0.5, 0.0)
        contours, _ = cv2.findContours(
            table_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        tables = []
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            if w > 200 and h > 100:  # Minimum table size
                # Collect text blocks that fall within the table region
                table_blocks = [
                    block
                    for block in blocks
                    if (
                        block["bbox"]["x"] >= x
                        and block["bbox"]["y"] >= y
                        and block["bbox"]["x"] + block["bbox"]["width"] <= x + w
                        and block["bbox"]["y"] + block["bbox"]["height"] <= y + h
                    )
                ]
                tables.append(
                    {
                        "bbox": {"x": x, "y": y, "width": w, "height": h},
                        "blocks": table_blocks,
                    }
                )

        return tables

    def _load_prompt(self, prompt_name: str) -> str:
        """Load an LLM prompt template."""
        prompt_path = Path(f"prompts/{prompt_name}.txt")
        with open(prompt_path) as f:
            return f.read()

    def _load_schema(self, schema_name: str) -> dict:
        """Load a JSON schema for validation."""
        schema_path = Path(f"schemas/{schema_name}.schema.json")
        with open(schema_path) as f:
            return json.load(f)
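
    # --- Minimal sketches of helpers referenced above but not defined in this
    # module (_get_mime_type, _validate_extraction, _post_process_checks,
    # _fallback_extraction). The behaviour shown is assumed, not taken from the
    # original implementation: stdlib mimetype lookup, delegation to
    # DocumentValidator via an assumed validate() method, a dangling-edge check,
    # and a naive regex key-value fallback. ---

    def _get_mime_type(self, file_path: Path) -> str:
        """Sketch: best-effort MIME type lookup via the stdlib."""
        import mimetypes

        mime_type, _ = mimetypes.guess_type(str(file_path))
        return mime_type or "application/octet-stream"

    def _validate_extraction(
        self, extracted_data: dict, classification: str
    ) -> list[str]:
        """Sketch: delegate document-level checks to DocumentValidator."""
        # Assumption: DocumentValidator exposes validate(data, doc_type) -> list[str].
        try:
            return self.validator.validate(extracted_data, classification)
        except AttributeError:
            return []

    def _post_process_checks(
        self,
        graph_nodes: list[dict],
        graph_edges: list[dict],
        validation_errors: list[str],
    ) -> list[str]:
        """Sketch: merge validation errors with basic graph sanity checks."""
        errors = list(validation_errors)
        node_ids = {node.get("id") for node in graph_nodes}
        for edge in graph_edges:
            # Flag edges that reference nodes we never created
            if edge.get("source") not in node_ids or edge.get("target") not in node_ids:
                errors.append(f"Dangling edge: {edge}")
        return errors

    def _fallback_extraction(self, ocr_results: list[dict]) -> dict:
        """Sketch: naive 'Key: value' scraping when LLM extraction fails."""
        import re

        pairs: dict[str, str] = {}
        for result in ocr_results:
            for match in re.finditer(r"([A-Za-z ]{3,30}):\s*(\S+)", result["full_text"]):
                key = match.group(1).strip().lower().replace(" ", "_")
                pairs.setdefault(key, match.group(2))
        return pairs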
    def _create_evidence_records(
        self, ocr_results: list[dict], doc_id: str
    ) -> list[dict]:
        """Create evidence records with provenance."""
        evidence = []

        for page_result in ocr_results:
            for block in page_result["blocks"]:
                evidence.append(
                    {
                        "snippet_id": f"{doc_id}_p{page_result['page']}_{len(evidence)}",
                        "doc_ref": doc_id,
                        "page": page_result["page"],
                        "bbox": block["bbox"],
                        "text_hash": hashlib.sha256(
                            block["text"].encode()
                        ).hexdigest(),
                        "ocr_confidence": block["confidence"],
                        "extracted_text": block["text"],
                    }
                )

        return evidence
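
# Minimal usage sketch (illustrative only): the config path, sample document and
# taxpayer ID below are placeholders, and in practice the pipeline is presumably
# driven by an orchestrator rather than run as a script. Because of the relative
# imports, run it as a module, e.g. `python -m pipeline.etl`.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    etl = DocumentETL("config/etl.yaml")  # hypothetical config location
    result = etl.process_document("samples/invoice_001.pdf", taxpayer_id="TP-0001")

    print(
        f"Document {result.doc_id} classified as {result.classification} "
        f"(confidence {result.confidence:.2f}) in {result.processing_time:.1f}s"
    )
    if result.errors:
        print("Errors:", result.errors)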