"""LLM-based field extraction with confidence scoring and provenance tracking."""
# FILE: apps/svc-extract/main.py
# pylint: disable=wrong-import-position,import-error,too-few-public-methods,global-statement
# pylint: disable=global-variable-not-assigned,raise-missing-from,unused-argument
# pylint: disable=broad-exception-caught,no-else-return,too-many-arguments,too-many-positional-arguments
# pylint: disable=too-many-locals,import-outside-toplevel
import os
# Import shared libraries
import sys
from datetime import datetime
from typing import Any
import structlog
import ulid
from fastapi import BackgroundTasks, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from libs.app_factory import create_app
from libs.calibration import ConfidenceCalibrator
from libs.config import BaseAppSettings, create_event_bus, create_minio_client
from libs.events import EventBus, EventPayload, EventTopics
from libs.observability import get_metrics, get_tracer, setup_observability
from libs.schemas import ErrorResponse, ExtractionRequest, ExtractionResponse
from libs.security import (
create_trusted_proxy_middleware,
get_current_user,
get_tenant_id,
)
from libs.storage import DocumentStorage, StorageClient
logger = structlog.get_logger()


class ExtractionSettings(BaseAppSettings):
    """Settings for extraction service"""

    service_name: str = "svc-extract"

    # LLM configuration
    openai_api_key: str = ""
    model_name: str = "gpt-4"
    max_tokens: int = 2000
    temperature: float = 0.1

    # Extraction configuration
    confidence_threshold: float = 0.7
    max_retries: int = 3
    chunk_size: int = 4000

    # Prompt templates
    extraction_prompt_template: str = """
Extract the following fields from this document text:
{field_definitions}
Document text:
{document_text}
Return a JSON object with the extracted fields and confidence scores.
"""

# Create app and settings
app, settings = create_app(
    service_name="svc-extract",
    title="Tax Agent Extraction Service",
    description="LLM-based field extraction service",
    settings_class=ExtractionSettings,
)

# Add middleware
middleware_factory = create_trusted_proxy_middleware(settings.internal_cidrs)
app.add_middleware(middleware_factory)

# Global clients
storage_client: StorageClient | None = None
document_storage: DocumentStorage | None = None
event_bus: EventBus | None = None
confidence_calibrator: ConfidenceCalibrator | None = None

tracer = get_tracer("svc-extract")
metrics = get_metrics()
@app.on_event("startup")
async def startup_event() -> None:
"""Initialize service dependencies"""
global storage_client, document_storage, event_bus, confidence_calibrator
logger.info("Starting extraction service")
# Setup observability
setup_observability(settings)
# Initialize MinIO client
minio_client = create_minio_client(settings)
storage_client = StorageClient(minio_client)
document_storage = DocumentStorage(storage_client)
# Initialize event bus
event_bus = create_event_bus(settings)
if not event_bus:
raise Exception("Event bus not initialized")
await event_bus.start()
# Subscribe to OCR completion events
await event_bus.subscribe(EventTopics.DOC_OCR_READY, _handle_ocr_ready)
# Initialize confidence calibrator
confidence_calibrator = ConfidenceCalibrator(method="temperature")
logger.info("Extraction service started successfully")
@app.on_event("shutdown")
async def shutdown_event() -> None:
"""Cleanup service dependencies"""
global event_bus
logger.info("Shutting down extraction service")
if event_bus:
await event_bus.stop()
logger.info("Extraction service shutdown complete")
@app.get("/healthz")
async def health_check() -> dict[str, Any]:
"""Health check endpoint"""
return {
"status": "healthy",
"service": settings.service_name,
"version": settings.service_version,
"timestamp": datetime.utcnow().isoformat(),
}
@app.get("/readyz")
async def readiness_check() -> dict[str, Any]:
"""Readiness check endpoint"""
return {
"status": "ready",
"service": settings.service_name,
"version": settings.service_version,
"timestamp": datetime.utcnow().isoformat(),
}
@app.get("/livez")
async def liveness_check() -> dict[str, Any]:
"""Liveness check endpoint"""
return {
"status": "alive",
"service": settings.service_name,
"version": settings.service_version,
"timestamp": datetime.utcnow().isoformat(),
}
@app.post("/extract/{doc_id}", response_model=ExtractionResponse)
async def extract_fields(
doc_id: str,
request_data: ExtractionRequest,
background_tasks: BackgroundTasks,
current_user: dict[str, Any] = Depends(get_current_user()),
tenant_id: str = Depends(get_tenant_id()),
) -> ExtractionResponse:
"""Extract fields from document"""
with tracer.start_as_current_span("extract_fields") as span:
span.set_attribute("doc_id", doc_id)
span.set_attribute("tenant_id", tenant_id)
span.set_attribute("strategy", request_data.strategy)
try:
# Check if OCR results exist
ocr_results = (
await document_storage.get_ocr_result(tenant_id, doc_id)
if document_storage
else None
)
if not ocr_results:
raise HTTPException(status_code=404, detail="OCR results not found")
# Generate extraction ID
extraction_id = str(ulid.new())
span.set_attribute("extraction_id", extraction_id)
# Start background extraction
background_tasks.add_task(
_extract_fields_async,
doc_id,
tenant_id,
ocr_results,
request_data.strategy,
extraction_id,
current_user.get("sub", "system"),
)
logger.info(
"Field extraction started", doc_id=doc_id, extraction_id=extraction_id
)
return ExtractionResponse(
extraction_id=extraction_id,
confidence=0.0, # Will be updated when processing completes
extracted_fields={},
provenance=[],
)
except HTTPException:
raise
except Exception as e:
logger.error("Failed to start extraction", doc_id=doc_id, error=str(e))
raise HTTPException(status_code=500, detail="Failed to start extraction")
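
# The endpoint above returns immediately with an empty placeholder response;
# the actual extraction runs in the background task, and clients fetch the
# finished result via GET /results/{doc_id} once it has been stored.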
@app.get("/results/{doc_id}")
async def get_extraction_results(
doc_id: str,
current_user: dict[str, Any] = Depends(get_current_user()),
tenant_id: str = Depends(get_tenant_id()),
) -> ExtractionResponse:
"""Get extraction results for document"""
with tracer.start_as_current_span("get_extraction_results") as span:
span.set_attribute("doc_id", doc_id)
span.set_attribute("tenant_id", tenant_id)
try:
# Get extraction results from storage
extraction_results = (
await document_storage.get_extraction_result(tenant_id, doc_id)
if document_storage
else None
)
if not extraction_results:
raise HTTPException(
status_code=404, detail="Extraction results not found"
)
# pylint: disable-next=not-a-mapping
return ExtractionResponse(**extraction_results)
except HTTPException:
raise
except Exception as e:
logger.error(
"Failed to get extraction results", doc_id=doc_id, error=str(e)
)
raise HTTPException(
status_code=500, detail="Failed to get extraction results"
)
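
# A hypothetical client interaction, assuming bearer-token auth resolved by
# libs.security and a request body matching ExtractionRequest (the "strategy"
# field is the only one this module reads); the doc ID is a placeholder:
#
#   curl -X POST -H "Authorization: Bearer $TOKEN" \
#        -H "Content-Type: application/json" \
#        -d '{"strategy": "hybrid"}' \
#        http://svc-extract:8003/extract/01HV...    # kicks off extraction
#   curl -H "Authorization: Bearer $TOKEN" \
#        http://svc-extract:8003/results/01HV...    # polls for the result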


async def _handle_ocr_ready(topic: str, payload: EventPayload) -> None:
    """Handle OCR completion events"""
    try:
        data = payload.data
        doc_id = data.get("doc_id")
        tenant_id = data.get("tenant_id")

        if not doc_id or not tenant_id:
            logger.warning("Invalid OCR ready event", data=data)
            return

        logger.info("Auto-extracting fields from OCR results", doc_id=doc_id)

        # Get OCR results
        ocr_results = data.get("ocr_results")
        if not ocr_results:
            ocr_results = (
                await document_storage.get_ocr_result(tenant_id, doc_id)
                if document_storage
                else None
            )

        if ocr_results:
            await _extract_fields_async(
                doc_id=doc_id,
                tenant_id=tenant_id,
                ocr_results=ocr_results,
                strategy="hybrid",
                extraction_id=str(ulid.new()),
                actor=payload.actor,
            )

    except Exception as e:
        logger.error("Failed to handle OCR ready event", error=str(e))


async def _extract_fields_async(
    doc_id: str,
    tenant_id: str,
    ocr_results: dict[str, Any],
    strategy: str,
    extraction_id: str,
    actor: str,
) -> None:
    """Extract fields asynchronously"""
    with tracer.start_as_current_span("extract_fields_async") as span:
        span.set_attribute("doc_id", doc_id)
        span.set_attribute("extraction_id", extraction_id)
        span.set_attribute("strategy", strategy)

        try:
            # Extract text from OCR results
            document_text = _extract_text_from_ocr(ocr_results)

            # Determine field definitions based on document type
            field_definitions = _get_field_definitions(doc_id, document_text)

            # Perform extraction
            if strategy == "llm":
                extracted_fields, confidence, provenance = await _extract_with_llm(
                    document_text, field_definitions, ocr_results
                )
            elif strategy == "rules":
                extracted_fields, confidence, provenance = await _extract_with_rules(
                    document_text, field_definitions, ocr_results
                )
            elif strategy == "hybrid":
                # Combine LLM and rules-based extraction
                llm_fields, llm_conf, llm_prov = await _extract_with_llm(
                    document_text, field_definitions, ocr_results
                )
                rules_fields, rules_conf, rules_prov = await _extract_with_rules(
                    document_text, field_definitions, ocr_results
                )
                extracted_fields, confidence, provenance = _merge_extractions(
                    llm_fields, llm_conf, llm_prov, rules_fields, rules_conf, rules_prov
                )
            else:
                raise ValueError(f"Unknown strategy: {strategy}")

            # Calibrate confidence
            if confidence_calibrator and confidence_calibrator.is_fitted:
                calibrated_confidence = confidence_calibrator.calibrate([confidence])[0]
            else:
                calibrated_confidence = confidence

            # Create extraction results
            extraction_results = {
                "doc_id": doc_id,
                "extraction_id": extraction_id,
                "strategy": strategy,
                "extracted_at": datetime.utcnow().isoformat(),
                "confidence": calibrated_confidence,
                "raw_confidence": confidence,
                "extracted_fields": extracted_fields,
                "provenance": provenance,
                "field_count": len(extracted_fields),
            }

            # Store results
            if document_storage:
                await document_storage.store_extraction_result(
                    tenant_id, doc_id, extraction_results
                )

            # Update metrics
            metrics.counter("extractions_completed_total").labels(
                tenant_id=tenant_id, strategy=strategy
            ).inc()
            metrics.histogram("extraction_confidence").labels(
                strategy=strategy
            ).observe(calibrated_confidence)

            # Publish completion event
            event_payload = EventPayload(
                data={
                    "doc_id": doc_id,
                    "tenant_id": tenant_id,
                    "extraction_id": extraction_id,
                    "strategy": strategy,
                    "confidence": calibrated_confidence,
                    "field_count": len(extracted_fields),
                    "extraction_results": extraction_results,
                },
                actor=actor,
                tenant_id=tenant_id,
            )
            if event_bus:
                await event_bus.publish(EventTopics.DOC_EXTRACTED, event_payload)

            logger.info(
                "Field extraction completed",
                doc_id=doc_id,
                fields=len(extracted_fields),
                confidence=calibrated_confidence,
            )

        except Exception as e:
            logger.error("Field extraction failed", doc_id=doc_id, error=str(e))

            # Update error metrics
            metrics.counter("extraction_errors_total").labels(
                tenant_id=tenant_id, strategy=strategy, error_type=type(e).__name__
            ).inc()


def _extract_text_from_ocr(ocr_results: dict[str, Any]) -> str:
    """Extract text from OCR results"""
    text_parts = []

    for page in ocr_results.get("pages", []):
        if "text" in page:
            text_parts.append(page["text"])
        elif "tesseract" in page and "text" in page["tesseract"]:
            text_parts.append(page["tesseract"]["text"])

    return "\n\n".join(text_parts)


def _get_field_definitions(doc_id: str, document_text: str) -> dict[str, str]:
    """Get field definitions based on document type"""
    # Analyze document text to determine type
    text_lower = document_text.lower()

    if "invoice" in text_lower or "bill" in text_lower:
        return {
            "invoice_number": "Invoice or bill number",
            "date": "Invoice date",
            "supplier_name": "Supplier or vendor name",
            "total_amount": "Total amount including VAT",
            "net_amount": "Net amount excluding VAT",
            "vat_amount": "VAT amount",
            "description": "Description of goods or services",
        }
    elif "bank statement" in text_lower or "account statement" in text_lower:
        return {
            "account_number": "Bank account number",
            "sort_code": "Bank sort code",
            "statement_period": "Statement period",
            "opening_balance": "Opening balance",
            "closing_balance": "Closing balance",
            "transactions": "List of transactions",
        }
    elif "receipt" in text_lower:
        return {
            "merchant_name": "Merchant or store name",
            "date": "Receipt date",
            "total_amount": "Total amount paid",
            "payment_method": "Payment method used",
            "items": "List of items purchased",
        }
    else:
        # Generic fields
        return {
            "date": "Any dates mentioned",
            "amount": "Any monetary amounts",
            "names": "Any person or company names",
            "addresses": "Any addresses",
            "reference_numbers": "Any reference or account numbers",
        }
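
# Document typing is a plain keyword check: any text containing "invoice" or
# "bill" gets the invoice field set, and so on, with unmatched text falling
# back to the generic fields. Note the doc_id argument is currently unused.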


async def _extract_with_llm(
    document_text: str, field_definitions: dict[str, str], ocr_results: dict[str, Any]
) -> tuple[dict[str, Any], float, list[dict[str, Any]]]:
    """Extract fields using LLM"""
    try:
        # This would integrate with OpenAI API
        # For now, return mock extraction
        logger.warning("LLM extraction not implemented, using mock data")

        extracted_fields = {}
        provenance = []

        # Mock extraction based on field definitions
        for field_name, _field_desc in field_definitions.items():
            if "amount" in field_name.lower():
                extracted_fields[field_name] = "£1,234.56"
            elif "date" in field_name.lower():
                extracted_fields[field_name] = "2024-01-15"
            elif "name" in field_name.lower():
                extracted_fields[field_name] = "Example Company Ltd"
            else:
                extracted_fields[field_name] = f"Mock {field_name}"

            # Add provenance
            provenance.append(
                {
                    "field": field_name,
                    "value": extracted_fields[field_name],
                    "confidence": 0.8,
                    "source": "llm",
                    "page": 1,
                    "bbox": [100, 100, 200, 120],
                }
            )

        return extracted_fields, 0.8, provenance

    except Exception as e:
        logger.error("LLM extraction failed", error=str(e))
        return {}, 0.0, []
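

# A minimal sketch of what the real LLM call could look like, assuming the
# OpenAI Python SDK (>= 1.0) as a dependency. The helper name, prompt shape,
# and JSON-only response assumption are all illustrative; it is not wired into
# _extract_with_llm yet.
async def _call_llm_for_fields(
    document_text: str, field_definitions: dict[str, str]
) -> dict[str, Any]:
    """Sketch: ask the configured chat model for a JSON object of fields."""
    import json

    from openai import AsyncOpenAI  # assumed dependency, not in the code above

    client = AsyncOpenAI(api_key=settings.openai_api_key)
    field_list = "\n".join(
        f"- {name}: {desc}" for name, desc in field_definitions.items()
    )
    prompt = settings.extraction_prompt_template.format(
        field_definitions=field_list,
        document_text=document_text[: settings.chunk_size],
    )
    response = await client.chat.completions.create(
        model=settings.model_name,
        max_tokens=settings.max_tokens,
        temperature=settings.temperature,
        messages=[{"role": "user", "content": prompt}],
    )
    # The prompt asks for JSON, but free-form output can still break parsing;
    # a production version would validate and retry up to settings.max_retries.
    return json.loads(response.choices[0].message.content or "{}")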


async def _extract_with_rules(
    document_text: str, field_definitions: dict[str, str], ocr_results: dict[str, Any]
) -> tuple[dict[str, Any], float, list[dict[str, Any]]]:
    """Extract fields using rules-based approach"""
    import re

    extracted_fields = {}
    provenance = []

    # Define extraction patterns
    patterns = {
        "amount": r"£\d{1,3}(?:,\d{3})*(?:\.\d{2})?",
        "date": r"\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b",
        "invoice_number": r"(?:invoice|inv|bill)\s*#?\s*(\w+)",
        "account_number": r"\b\d{8}\b",
        "sort_code": r"\b\d{2}-\d{2}-\d{2}\b",
    }

    for field_name, _field_desc in field_definitions.items():
        # Find matching pattern
        pattern_key = None
        for key in patterns:
            if key in field_name.lower():
                pattern_key = key
                break

        if pattern_key:
            pattern = patterns[pattern_key]
            matches = re.finditer(pattern, document_text, re.IGNORECASE)

            for match in matches:
                value = match.group(1) if match.groups() else match.group(0)
                extracted_fields[field_name] = value

                provenance.append(
                    {
                        "field": field_name,
                        "value": value,
                        "confidence": 0.9,
                        "source": "rules",
                        "pattern": pattern,
                        "match_start": match.start(),
                        "match_end": match.end(),
                    }
                )
                break  # Take first match

    confidence = 0.9 if extracted_fields else 0.0
    return extracted_fields, confidence, provenance
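
# Worked example: for the text "Invoice 1001 dated 15/01/2024, total £1,234.56"
# with the invoice field set, the loop above yields invoice_number="1001",
# date="15/01/2024", and total_amount="£1,234.56". Because pattern lookup is a
# substring match on the field name, net_amount and vat_amount also key on
# "amount" and capture the same first £ value; that is a known limitation of
# the first-match strategy.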


def _merge_extractions(
    llm_fields: dict[str, Any],
    llm_conf: float,
    llm_prov: list[dict[str, Any]],
    rules_fields: dict[str, Any],
    rules_conf: float,
    rules_prov: list[dict[str, Any]],
) -> tuple[dict[str, Any], float, list[dict[str, Any]]]:
    """Merge LLM and rules-based extractions"""
    merged_fields = {}
    merged_provenance = []

    # Get all field names
    all_fields = set(llm_fields.keys()) | set(rules_fields.keys())

    for field in all_fields:
        llm_value = llm_fields.get(field)
        rules_value = rules_fields.get(field)

        # Prefer rules-based extraction for structured fields
        if rules_value and field in ["amount", "date", "account_number", "sort_code"]:
            merged_fields[field] = rules_value
            # Find provenance for this field
            for prov in rules_prov:
                if prov["field"] == field:
                    merged_provenance.append(prov)
                    break
        elif llm_value:
            merged_fields[field] = llm_value
            # Find provenance for this field
            for prov in llm_prov:
                if prov["field"] == field:
                    merged_provenance.append(prov)
                    break

    # Calculate combined confidence
    combined_confidence = (llm_conf + rules_conf) / 2

    return merged_fields, combined_confidence, merged_provenance
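
# Merge semantics, traced from the code above: the rules value wins only when
# the field name is exactly "amount", "date", "account_number", or "sort_code"
# (so for "total_amount" the LLM value wins even when the rules pass matched
# it); LLM-only fields pass through unchanged; and the combined confidence is
# the plain mean of the two strategy confidences.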


@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Handle HTTP exceptions with RFC7807 format"""
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            type=f"https://httpstatuses.com/{exc.status_code}",
            title=exc.detail,
            status=exc.status_code,
            detail=exc.detail,
            instance=str(request.url),
            trace_id=getattr(request.state, "trace_id", None),
        ).model_dump(),
    )
if __name__ == "__main__":
import uvicorn
uvicorn.run("main:app", host="0.0.0.0", port=8003, reload=True, log_config=None)