"""LLM-based field extraction with confidence scoring and provenance tracking.""" # FILE: apps/svc-extract/main.py # pylint: disable=wrong-import-position,import-error,too-few-public-methods,global-statement # pylint: disable=global-variable-not-assigned,raise-missing-from,unused-argument # pylint: disable=broad-exception-caught,no-else-return,too-many-arguments,too-many-positional-arguments # pylint: disable=too-many-locals,import-outside-toplevel import os # Import shared libraries import sys from datetime import datetime from typing import Any import structlog import ulid from fastapi import BackgroundTasks, Depends, HTTPException, Request from fastapi.responses import JSONResponse sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from libs.app_factory import create_app from libs.calibration import ConfidenceCalibrator from libs.config import BaseAppSettings, create_event_bus, create_minio_client from libs.events import EventBus, EventPayload, EventTopics from libs.observability import get_metrics, get_tracer, setup_observability from libs.schemas import ErrorResponse, ExtractionRequest, ExtractionResponse from libs.security import ( create_trusted_proxy_middleware, get_current_user, get_tenant_id, ) from libs.storage import DocumentStorage, StorageClient logger = structlog.get_logger() class ExtractionSettings(BaseAppSettings): """Settings for extraction service""" service_name: str = "svc-extract" # LLM configuration openai_api_key: str = "" model_name: str = "gpt-4" max_tokens: int = 2000 temperature: float = 0.1 # Extraction configuration confidence_threshold: float = 0.7 max_retries: int = 3 chunk_size: int = 4000 # Prompt templates extraction_prompt_template: str = """ Extract the following fields from this document text: {field_definitions} Document text: {document_text} Return a JSON object with the extracted fields and confidence scores. """ # Create app and settings app, settings = create_app( service_name="svc-extract", title="Tax Agent Extraction Service", description="LLM-based field extraction service", settings_class=ExtractionSettings, ) # Add middleware middleware_factory = create_trusted_proxy_middleware(settings.internal_cidrs) app.add_middleware(middleware_factory) # Global clients storage_client: StorageClient | None = None document_storage: DocumentStorage | None = None event_bus: EventBus | None = None confidence_calibrator: ConfidenceCalibrator | None = None tracer = get_tracer("svc-extract") metrics = get_metrics() @app.on_event("startup") async def startup_event() -> None: """Initialize service dependencies""" global storage_client, document_storage, event_bus, confidence_calibrator logger.info("Starting extraction service") # Setup observability setup_observability(settings) # Initialize MinIO client minio_client = create_minio_client(settings) storage_client = StorageClient(minio_client) document_storage = DocumentStorage(storage_client) # Initialize event bus event_bus = create_event_bus(settings) if not event_bus: raise Exception("Event bus not initialized") await event_bus.start() # Subscribe to OCR completion events await event_bus.subscribe(EventTopics.DOC_OCR_READY, _handle_ocr_ready) # Initialize confidence calibrator confidence_calibrator = ConfidenceCalibrator(method="temperature") logger.info("Extraction service started successfully") @app.on_event("shutdown") async def shutdown_event() -> None: """Cleanup service dependencies""" global event_bus logger.info("Shutting down extraction service") if event_bus: await event_bus.stop() logger.info("Extraction service shutdown complete") @app.get("/healthz") async def health_check() -> dict[str, Any]: """Health check endpoint""" return { "status": "healthy", "service": settings.service_name, "version": settings.service_version, "timestamp": datetime.utcnow().isoformat(), } @app.get("/readyz") async def readiness_check() -> dict[str, Any]: """Readiness check endpoint""" return { "status": "ready", "service": settings.service_name, "version": settings.service_version, "timestamp": datetime.utcnow().isoformat(), } @app.get("/livez") async def liveness_check() -> dict[str, Any]: """Liveness check endpoint""" return { "status": "alive", "service": settings.service_name, "version": settings.service_version, "timestamp": datetime.utcnow().isoformat(), } @app.post("/extract/{doc_id}", response_model=ExtractionResponse) async def extract_fields( doc_id: str, request_data: ExtractionRequest, background_tasks: BackgroundTasks, current_user: dict[str, Any] = Depends(get_current_user()), tenant_id: str = Depends(get_tenant_id()), ) -> ExtractionResponse: """Extract fields from document""" with tracer.start_as_current_span("extract_fields") as span: span.set_attribute("doc_id", doc_id) span.set_attribute("tenant_id", tenant_id) span.set_attribute("strategy", request_data.strategy) try: # Check if OCR results exist ocr_results = ( await document_storage.get_ocr_result(tenant_id, doc_id) if document_storage else None ) if not ocr_results: raise HTTPException(status_code=404, detail="OCR results not found") # Generate extraction ID extraction_id = str(ulid.new()) span.set_attribute("extraction_id", extraction_id) # Start background extraction background_tasks.add_task( _extract_fields_async, doc_id, tenant_id, ocr_results, request_data.strategy, extraction_id, current_user.get("sub", "system"), ) logger.info( "Field extraction started", doc_id=doc_id, extraction_id=extraction_id ) return ExtractionResponse( extraction_id=extraction_id, confidence=0.0, # Will be updated when processing completes extracted_fields={}, provenance=[], ) except HTTPException: raise except Exception as e: logger.error("Failed to start extraction", doc_id=doc_id, error=str(e)) raise HTTPException(status_code=500, detail="Failed to start extraction") @app.get("/results/{doc_id}") async def get_extraction_results( doc_id: str, current_user: dict[str, Any] = Depends(get_current_user()), tenant_id: str = Depends(get_tenant_id()), ) -> ExtractionResponse: """Get extraction results for document""" with tracer.start_as_current_span("get_extraction_results") as span: span.set_attribute("doc_id", doc_id) span.set_attribute("tenant_id", tenant_id) try: # Get extraction results from storage extraction_results = ( await document_storage.get_extraction_result(tenant_id, doc_id) if document_storage else None ) if not extraction_results: raise HTTPException( status_code=404, detail="Extraction results not found" ) # pylint: disable-next=not-a-mapping return ExtractionResponse(**extraction_results) except HTTPException: raise except Exception as e: logger.error( "Failed to get extraction results", doc_id=doc_id, error=str(e) ) raise HTTPException( status_code=500, detail="Failed to get extraction results" ) async def _handle_ocr_ready(topic: str, payload: EventPayload) -> None: """Handle OCR completion events""" try: data = payload.data doc_id = data.get("doc_id") tenant_id = data.get("tenant_id") if not doc_id or not tenant_id: logger.warning("Invalid OCR ready event", data=data) return logger.info("Auto-extracting fields from OCR results", doc_id=doc_id) # Get OCR results ocr_results = data.get("ocr_results") if not ocr_results: ocr_results = ( await document_storage.get_ocr_result(tenant_id, doc_id) if document_storage else None ) if ocr_results: await _extract_fields_async( doc_id=doc_id, tenant_id=tenant_id, ocr_results=ocr_results, strategy="hybrid", extraction_id=str(ulid.new()), actor=payload.actor, ) except Exception as e: logger.error("Failed to handle OCR ready event", error=str(e)) async def _extract_fields_async( doc_id: str, tenant_id: str, ocr_results: dict[str, Any], strategy: str, extraction_id: str, actor: str, ) -> None: """Extract fields asynchronously""" with tracer.start_as_current_span("extract_fields_async") as span: span.set_attribute("doc_id", doc_id) span.set_attribute("extraction_id", extraction_id) span.set_attribute("strategy", strategy) try: # Extract text from OCR results document_text = _extract_text_from_ocr(ocr_results) # Determine field definitions based on document type field_definitions = _get_field_definitions(doc_id, document_text) # Perform extraction if strategy == "llm": extracted_fields, confidence, provenance = await _extract_with_llm( document_text, field_definitions, ocr_results ) elif strategy == "rules": extracted_fields, confidence, provenance = await _extract_with_rules( document_text, field_definitions, ocr_results ) elif strategy == "hybrid": # Combine LLM and rules-based extraction llm_fields, llm_conf, llm_prov = await _extract_with_llm( document_text, field_definitions, ocr_results ) rules_fields, rules_conf, rules_prov = await _extract_with_rules( document_text, field_definitions, ocr_results ) extracted_fields, confidence, provenance = _merge_extractions( llm_fields, llm_conf, llm_prov, rules_fields, rules_conf, rules_prov ) else: raise ValueError(f"Unknown strategy: {strategy}") # Calibrate confidence if confidence_calibrator and confidence_calibrator.is_fitted: calibrated_confidence = confidence_calibrator.calibrate([confidence])[0] else: calibrated_confidence = confidence # Create extraction results extraction_results = { "doc_id": doc_id, "extraction_id": extraction_id, "strategy": strategy, "extracted_at": datetime.utcnow().isoformat(), "confidence": calibrated_confidence, "raw_confidence": confidence, "extracted_fields": extracted_fields, "provenance": provenance, "field_count": len(extracted_fields), } # Store results if document_storage: await document_storage.store_extraction_result( tenant_id, doc_id, extraction_results ) # Update metrics metrics.counter("extractions_completed_total").labels( tenant_id=tenant_id, strategy=strategy ).inc() metrics.histogram("extraction_confidence").labels( strategy=strategy ).observe(calibrated_confidence) # Publish completion event event_payload = EventPayload( data={ "doc_id": doc_id, "tenant_id": tenant_id, "extraction_id": extraction_id, "strategy": strategy, "confidence": calibrated_confidence, "field_count": len(extracted_fields), "extraction_results": extraction_results, }, actor=actor, tenant_id=tenant_id, ) if event_bus: await event_bus.publish(EventTopics.DOC_EXTRACTED, event_payload) logger.info( "Field extraction completed", doc_id=doc_id, fields=len(extracted_fields), confidence=calibrated_confidence, ) except Exception as e: logger.error("Field extraction failed", doc_id=doc_id, error=str(e)) # Update error metrics metrics.counter("extraction_errors_total").labels( tenant_id=tenant_id, strategy=strategy, error_type=type(e).__name__ ).inc() def _extract_text_from_ocr(ocr_results: dict[str, Any]) -> str: """Extract text from OCR results""" text_parts = [] for page in ocr_results.get("pages", []): if "text" in page: text_parts.append(page["text"]) elif "tesseract" in page and "text" in page["tesseract"]: text_parts.append(page["tesseract"]["text"]) return "\n\n".join(text_parts) def _get_field_definitions(doc_id: str, document_text: str) -> dict[str, str]: """Get field definitions based on document type""" # Analyze document text to determine type text_lower = document_text.lower() if "invoice" in text_lower or "bill" in text_lower: return { "invoice_number": "Invoice or bill number", "date": "Invoice date", "supplier_name": "Supplier or vendor name", "total_amount": "Total amount including VAT", "net_amount": "Net amount excluding VAT", "vat_amount": "VAT amount", "description": "Description of goods or services", } elif "bank statement" in text_lower or "account statement" in text_lower: return { "account_number": "Bank account number", "sort_code": "Bank sort code", "statement_period": "Statement period", "opening_balance": "Opening balance", "closing_balance": "Closing balance", "transactions": "List of transactions", } elif "receipt" in text_lower: return { "merchant_name": "Merchant or store name", "date": "Receipt date", "total_amount": "Total amount paid", "payment_method": "Payment method used", "items": "List of items purchased", } else: # Generic fields return { "date": "Any dates mentioned", "amount": "Any monetary amounts", "names": "Any person or company names", "addresses": "Any addresses", "reference_numbers": "Any reference or account numbers", } async def _extract_with_llm( document_text: str, field_definitions: dict[str, str], ocr_results: dict[str, Any] ) -> tuple[dict[str, Any], float, list[dict[str, Any]]]: """Extract fields using LLM""" try: # This would integrate with OpenAI API # For now, return mock extraction logger.warning("LLM extraction not implemented, using mock data") extracted_fields = {} provenance = [] # Mock extraction based on field definitions for field_name, _field_desc in field_definitions.items(): if "amount" in field_name.lower(): extracted_fields[field_name] = "£1,234.56" elif "date" in field_name.lower(): extracted_fields[field_name] = "2024-01-15" elif "name" in field_name.lower(): extracted_fields[field_name] = "Example Company Ltd" else: extracted_fields[field_name] = f"Mock {field_name}" # Add provenance provenance.append( { "field": field_name, "value": extracted_fields[field_name], "confidence": 0.8, "source": "llm", "page": 1, "bbox": [100, 100, 200, 120], } ) return extracted_fields, 0.8, provenance except Exception as e: logger.error("LLM extraction failed", error=str(e)) return {}, 0.0, [] async def _extract_with_rules( document_text: str, field_definitions: dict[str, str], ocr_results: dict[str, Any] ) -> tuple[dict[str, Any], float, list[dict[str, Any]]]: """Extract fields using rules-based approach""" import re extracted_fields = {} provenance = [] # Define extraction patterns patterns = { "amount": r"£\d{1,3}(?:,\d{3})*(?:\.\d{2})?", "date": r"\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b", "invoice_number": r"(?:invoice|inv|bill)\s*#?\s*(\w+)", "account_number": r"\b\d{8}\b", "sort_code": r"\b\d{2}-\d{2}-\d{2}\b", } for field_name, _field_desc in field_definitions.items(): # Find matching pattern pattern_key = None for key in patterns: if key in field_name.lower(): pattern_key = key break if pattern_key: pattern = patterns[pattern_key] matches = re.finditer(pattern, document_text, re.IGNORECASE) for match in matches: value = match.group(1) if match.groups() else match.group(0) extracted_fields[field_name] = value provenance.append( { "field": field_name, "value": value, "confidence": 0.9, "source": "rules", "pattern": pattern, "match_start": match.start(), "match_end": match.end(), } ) break # Take first match confidence = 0.9 if extracted_fields else 0.0 return extracted_fields, confidence, provenance def _merge_extractions( llm_fields: dict[str, Any], llm_conf: float, llm_prov: list[dict[str, Any]], rules_fields: dict[str, Any], rules_conf: float, rules_prov: list[dict[str, Any]], ) -> tuple[dict[str, Any], float, list[dict[str, Any]]]: """Merge LLM and rules-based extractions""" merged_fields = {} merged_provenance = [] # Get all field names all_fields = set(llm_fields.keys()) | set(rules_fields.keys()) for field in all_fields: llm_value = llm_fields.get(field) rules_value = rules_fields.get(field) # Prefer rules-based extraction for structured fields if rules_value and field in ["amount", "date", "account_number", "sort_code"]: merged_fields[field] = rules_value # Find provenance for this field for prov in rules_prov: if prov["field"] == field: merged_provenance.append(prov) break elif llm_value: merged_fields[field] = llm_value # Find provenance for this field for prov in llm_prov: if prov["field"] == field: merged_provenance.append(prov) break # Calculate combined confidence combined_confidence = (llm_conf + rules_conf) / 2 return merged_fields, combined_confidence, merged_provenance @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse: """Handle HTTP exceptions with RFC7807 format""" return JSONResponse( status_code=exc.status_code, content=ErrorResponse( type=f"https://httpstatuses.com/{exc.status_code}", title=exc.detail, status=exc.status_code, detail=exc.detail, instance=str(request.url), trace_id=getattr(request.state, "trace_id", None), ).model_dump(), ) if __name__ == "__main__": import uvicorn uvicorn.run("main:app", host="0.0.0.0", port=8003, reload=True, log_config=None)