Files
ai-tax-agent/apps/svc_kg/main.py
harkon a99754b86c
Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
e2e backend test
2025-12-01 13:58:38 +02:00

236 lines
7.7 KiB
Python

import os
import sys
from typing import Any, cast
import structlog
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from pyshacl import validate
from rdflib import Graph, Literal, URIRef
from rdflib.namespace import RDF
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from libs.app_factory import create_app
from libs.config import BaseAppSettings, create_event_bus, create_neo4j_client
from libs.events import EventBus, EventPayload, EventTopics
from libs.neo import Neo4jClient
from libs.observability import get_metrics, get_tracer, setup_observability
from libs.schemas import ErrorResponse
logger = structlog.get_logger()
class KGSettings(BaseAppSettings):
    """Configuration for the knowledge-graph (KG) service.

    Adds the KG-specific options on top of the shared ``BaseAppSettings``.
    """

    # Name this service identifies itself with.
    service_name: str = "svc-kg"

    # Turtle file parsed into the SHACL shapes graph at startup. The default is
    # a relative path, so it presumably resolves against the working directory.
    shacl_shapes_path: str = "schemas/shapes.ttl"
# Global clients
# Module-level singletons shared by the handlers in this file. The three
# optional ones start as None and are assigned by init_dependencies().
neo4j_client: Neo4jClient | None = None  # graph database client
event_bus: EventBus | None = None  # bus used to subscribe to and publish events
shapes_graph: Graph | None = None  # parsed SHACL shapes; None if loading failed
# Annotation only (no value bound here): init_dependencies() assigns it, so
# reading `settings` before startup raises NameError.
settings: KGSettings
async def init_dependencies(app_settings: KGSettings) -> None:
    """Initialize the module-level dependencies of the KG service.

    Stores the settings, configures observability, creates the Neo4j client,
    loads the SHACL shapes, and finally starts the event bus and subscribes
    the upsert handler to ``EventTopics.KG_UPSERT_READY``.

    The shapes are loaded *before* the subscription is made: validation is
    skipped (and reported as conforming) while ``shapes_graph`` is unset, so
    subscribing first would let events delivered during startup reach the
    graph unvalidated.

    Args:
        app_settings: Settings used to build the clients; its
            ``shacl_shapes_path`` names the Turtle file with the shapes.

    Raises:
        HTTPException: With status 500 if no event bus could be created.
    """
    global neo4j_client, event_bus, settings, shapes_graph

    settings = app_settings
    logger.info("Starting KG service")
    setup_observability(settings)

    neo4j_driver = create_neo4j_client(settings)
    neo4j_client = Neo4jClient(neo4j_driver)

    # Load SHACL shapes first so the handler never sees a half-initialized
    # service. Loading is best effort: on failure the service still starts,
    # but validation is skipped until the shapes are available.
    try:
        shapes_graph = Graph().parse(settings.shacl_shapes_path, format="turtle")
        logger.info("SHACL shapes loaded successfully")
    except Exception as e:
        logger.error("Failed to load SHACL shapes", error=str(e))
        shapes_graph = None

    event_bus = create_event_bus(settings)
    if not event_bus:
        raise HTTPException(status_code=500, detail="Event bus not initialized")

    await event_bus.start()
    # Subscribe last: from here on events may be dispatched to the handler.
    await event_bus.subscribe(EventTopics.KG_UPSERT_READY, _handle_kg_upsert_ready)
async def startup_event() -> None:
    """Startup hook that wires up the service's dependencies.

    Passes the settings object returned by ``create_app`` to
    ``init_dependencies``, narrowed to ``KGSettings`` for the type checker.
    """
    kg_settings = cast(KGSettings, _settings)
    await init_dependencies(kg_settings)
# Build the application through the shared factory. The call yields the app
# together with a settings object (presumably an instance of `settings_class`;
# startup_event casts it to KGSettings). startup_event is registered as a
# startup hook, so dependencies are initialized when the app starts.
app, _settings = create_app(
    service_name="svc-kg",
    title="Tax Agent Knowledge Graph Service",
    description="Service for managing and validating the Knowledge Graph",
    settings_class=KGSettings,
    startup_hooks=[startup_event],
)
# Tracing and metrics handles used by the upsert event handler.
tracer = get_tracer("svc-kg")
metrics = get_metrics()
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Release the service's dependencies when the application shuts down.

    Stops the event bus and closes the Neo4j client, skipping whichever of
    the two was never initialized. The Neo4j client is closed even if
    stopping the event bus raises; that exception still propagates.
    """
    logger.info("Shutting down KG service")

    try:
        if event_bus:
            await event_bus.stop()
    finally:
        # Must run regardless of the bus outcome, otherwise a failing stop()
        # would leave the database connection open.
        if neo4j_client:
            await neo4j_client.close()

    logger.info("KG service shutdown complete")
async def _handle_kg_upsert_ready(topic: str, payload: EventPayload) -> None:
    """Validate and persist the graph data of a ``KG_UPSERT_READY`` event.

    Steps, all inside an ``upsert_kg_data`` span:

    1. Validate the nodes and relationships against the SHACL shapes. On a
       validation failure the error is logged and counted, and nothing is
       written.
    2. Create every node, then every relationship, in Neo4j.
    3. Publish a ``KG_UPSERTED`` event summarising the upsert.

    Failures in these steps are logged and counted in
    ``kg_upsert_errors_total`` but not re-raised, so this handler never
    propagates an error to the event bus.

    Args:
        topic: Topic the event arrived on (unused).
        payload: Event whose ``data`` carries ``nodes``, ``relationships``,
            ``doc_id``, ``tenant_id`` and optionally ``taxpayer_id`` and
            ``tax_year``.
    """
    data = payload.data
    # `or []` also covers keys that are present but explicitly null, which
    # would otherwise break the len() calls below.
    nodes = data.get("nodes") or []
    relationships = data.get("relationships") or []
    doc_id = data.get("doc_id")
    tenant_id = data.get("tenant_id")

    if not nodes and not relationships:
        logger.warning("No nodes or relationships to upsert", data=data)
        return

    with tracer.start_as_current_span("upsert_kg_data") as span:
        # OpenTelemetry-style spans do not accept None as an attribute value,
        # so absent ids are simply left unset.
        if doc_id is not None:
            span.set_attribute("doc_id", doc_id)
        if tenant_id is not None:
            span.set_attribute("tenant_id", tenant_id)
        span.set_attribute("node_count", len(nodes))
        span.set_attribute("relationship_count", len(relationships))

        try:
            # Check both dependencies up front so a missing event bus is not
            # discovered only after the graph has already been written.
            client, bus = neo4j_client, event_bus
            if client is None or bus is None:
                raise RuntimeError("KG service dependencies not initialized")

            # 1. Validate data against SHACL schema
            conforms, validation_report = await _validate_with_shacl(
                nodes, relationships
            )
            if not conforms:
                logger.error(
                    "SHACL validation failed",
                    doc_id=doc_id,
                    validation_report=validation_report,
                )
                metrics.counter(
                    "kg_validation_errors_total", labelnames=["tenant_id"]
                ).labels(tenant_id=tenant_id).inc()
                return

            # 2. Write data to Neo4j
            for node in nodes:
                await client.create_node(node["type"], node["properties"])

            for rel in relationships:
                await client.create_relationship(
                    rel["sourceId"],
                    rel["targetId"],
                    rel["type"],
                    # Validation does not require relationship properties, so
                    # tolerate their absence instead of failing mid-write.
                    rel.get("properties", {}),
                )

            # 3. Publish kg.upserted event
            event_payload = EventPayload(
                data={
                    "doc_id": doc_id,
                    "tenant_id": tenant_id,
                    "taxpayer_id": data.get("taxpayer_id"),
                    "tax_year": data.get("tax_year"),
                    "node_count": len(nodes),
                    "relationship_count": len(relationships),
                    "success": True,
                },
                actor=payload.actor,
                tenant_id=str(tenant_id),
                trace_id=str(span.get_span_context().trace_id),
            )
            await bus.publish(EventTopics.KG_UPSERTED, event_payload)

            metrics.counter("kg_upserts_total", labelnames=["tenant_id"]).labels(
                tenant_id=tenant_id
            ).inc()
            logger.info("KG upsert completed", doc_id=doc_id, tenant_id=tenant_id)

        except Exception as e:
            # Deliberate catch-all at the handler boundary: record the failure
            # and swallow it rather than propagating into the event bus.
            logger.error("Failed to upsert KG data", doc_id=doc_id, error=str(e))
            metrics.counter(
                "kg_upsert_errors_total", labelnames=["tenant_id", "error_type"]
            ).labels(tenant_id=tenant_id, error_type=type(e).__name__).inc()
async def _validate_with_shacl(
    nodes: list[dict[str, Any]], relationships: list[dict[str, Any]]
) -> tuple[bool, str]:
    """Validate nodes and relationships against the loaded SHACL shapes.

    The input is converted into a temporary RDF graph under the
    ``http://ai-tax-agent.com/ontology/`` namespace: each node becomes a
    resource typed with its ``type`` and carrying one literal per non-null
    property; each relationship becomes a triple from source to target.

    Args:
        nodes: Dicts with ``id``, ``type`` and ``properties`` keys.
        relationships: Dicts with ``sourceId``, ``targetId`` and ``type`` keys.

    Returns:
        ``(conforms, report)``. When no shapes are loaded, validation is
        skipped and ``(True, "SHACL shapes not loaded")`` is returned. If the
        validator itself raises, ``(False, <error text>)`` is returned.

    Raises:
        KeyError: If a node or relationship lacks one of the keys above.
    """
    # Compare with None explicitly: an rdflib Graph defines __len__, so a
    # truthiness test would count its triples and also treat a loaded but
    # empty shapes graph as "not loaded".
    if shapes_graph is None:
        logger.warning("SHACL shapes not loaded, skipping validation.")
        return True, "SHACL shapes not loaded"

    data_graph = Graph()
    namespace = "http://ai-tax-agent.com/ontology/"

    for node in nodes:
        node_uri = URIRef(f"{namespace}{node['id']}")
        data_graph.add((node_uri, RDF.type, URIRef(f"{namespace}{node['type']}")))
        for key, value in node["properties"].items():
            # Null properties carry no information and are left out.
            if value is not None:
                data_graph.add((node_uri, URIRef(f"{namespace}{key}"), Literal(value)))

    for rel in relationships:
        source_uri = URIRef(f"{namespace}{rel['sourceId']}")
        target_uri = URIRef(f"{namespace}{rel['targetId']}")
        rel_uri = URIRef(f"{namespace}{rel['type']}")
        data_graph.add((source_uri, rel_uri, target_uri))

    try:
        conforms, _results_graph, results_text = validate(
            data_graph,
            shacl_graph=shapes_graph,
            ont_graph=None,  # No ontology graph
            inference="rdfs",
            abort_on_first=False,
            allow_infos=False,
            meta_shacl=False,
            advanced=False,
            js=False,
            debug=False,
        )
        return conforms, results_text
    except Exception as e:
        # A validator crash is reported as non-conforming so that unchecked
        # data is never written.
        logger.error("Error during SHACL validation", error=str(e))
        return False, str(e)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Render an ``HTTPException`` as an RFC 7807 style JSON error body.

    Args:
        request: The request being handled; supplies the ``instance`` URL and,
            if set on ``request.state``, the ``trace_id``.
        exc: The raised exception; supplies status code, detail and headers.

    Returns:
        A ``JSONResponse`` with the exception's status code, the serialized
        ``ErrorResponse`` as body, and any headers attached to the exception.
    """
    problem = ErrorResponse(
        type=f"https://httpstatuses.com/{exc.status_code}",
        title=exc.detail,
        status=exc.status_code,
        detail=exc.detail,
        instance=str(request.url),
        trace_id=getattr(request.state, "trace_id", None),
    )
    return JSONResponse(
        status_code=exc.status_code,
        content=problem.model_dump(),
        # Forward headers carried by the exception (e.g. WWW-Authenticate or
        # Retry-After); without this they would be silently dropped.
        headers=getattr(exc, "headers", None),
    )
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on port 8005,
    # with auto-reload enabled and uvicorn's own log configuration disabled.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": 8005,
        "reload": True,
        "log_config": None,
    }
    uvicorn.run("main:app", **server_options)