import os
import sys
from typing import Any, cast

import structlog
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from pyshacl import validate
from rdflib import Graph, Literal, URIRef
from rdflib.namespace import RDF

# Extend the import path two directories up so the `libs.*` imports below
# resolve (presumably that is where the shared `libs` package lives).
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from libs.app_factory import create_app
from libs.config import BaseAppSettings, create_event_bus, create_neo4j_client
from libs.events import EventBus, EventPayload, EventTopics
from libs.neo import Neo4jClient
from libs.observability import get_metrics, get_tracer, setup_observability
from libs.schemas import ErrorResponse

logger = structlog.get_logger()


class KGSettings(BaseAppSettings):
    """Settings for KG service"""

    service_name: str = "svc-kg"
    # Path of the Turtle file holding the SHACL shapes used for validation.
    shacl_shapes_path: str = "schemas/shapes.ttl"


# Global clients (populated by init_dependencies at application startup)
neo4j_client: Neo4jClient | None = None
event_bus: EventBus | None = None
shapes_graph: Graph | None = None
settings: KGSettings


async def init_dependencies(app_settings: KGSettings) -> None:
    """Initialize service dependencies.

    Sets up observability, creates the Neo4j client and the event bus,
    subscribes to ``EventTopics.KG_UPSERT_READY`` and loads the SHACL shapes
    graph. A failure to load the shapes is logged and leaves ``shapes_graph``
    as ``None``; it does not abort startup.

    Args:
        app_settings: Settings object stored in the module-level ``settings``.

    Raises:
        HTTPException: If ``create_event_bus`` returns ``None``.
    """
    global neo4j_client, event_bus, settings, shapes_graph

    settings = app_settings
    logger.info("Starting KG service")

    setup_observability(settings)

    neo4j_driver = create_neo4j_client(settings)
    neo4j_client = Neo4jClient(neo4j_driver)

    event_bus = create_event_bus(settings)
    # Identity check: only a missing bus is an error, not a bus object that
    # merely happens to evaluate as falsy.
    if event_bus is None:
        raise HTTPException(status_code=500, detail="Event bus not initialized")
    await event_bus.start()

    await event_bus.subscribe(EventTopics.KG_UPSERT_READY, _handle_kg_upsert_ready)

    # Load SHACL shapes
    try:
        shapes_graph = Graph().parse(settings.shacl_shapes_path, format="turtle")
        logger.info("SHACL shapes loaded successfully")
    except Exception as e:
        # Best effort: validation is skipped later when no shapes are loaded.
        logger.error("Failed to load SHACL shapes", error=str(e))
        shapes_graph = None


app, _settings = create_app(
    service_name="svc-kg",
    title="Tax Agent Knowledge Graph Service",
    description="Service for managing and validating the Knowledge Graph",
    settings_class=KGSettings,
)


# Initialize dependencies immediately
@app.on_event("startup")
async def startup_event() -> None:
    """Initialize the service dependencies when the application starts."""
    await init_dependencies(cast(KGSettings, _settings))


tracer = get_tracer("svc-kg")
metrics = get_metrics()


@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Cleanup service dependencies"""
    logger.info("Shutting down KG service")

    if event_bus is not None:
        await event_bus.stop()

    if neo4j_client is not None:
        await neo4j_client.close()

    logger.info("KG service shutdown complete")


async def _handle_kg_upsert_ready(topic: str, payload: EventPayload) -> None:
    """Handle KG upsert ready events.

    Validates the nodes and relationships carried by the event against the
    SHACL shapes, writes them to Neo4j and publishes
    ``EventTopics.KG_UPSERTED``. Validation failures and any exception raised
    while processing are logged and counted in metrics; they are not
    re-raised.

    Args:
        topic: Topic the event was received on (unused).
        payload: Event whose ``data`` mapping is read for ``nodes``,
            ``relationships``, ``document_id``, ``tenant_id``,
            ``taxpayer_id`` and ``tax_year``.
    """
    data = payload.data
    # `or []` also covers keys that are present but explicitly null, which
    # would otherwise make len() below fail outside the try block.
    nodes = data.get("nodes") or []
    relationships = data.get("relationships") or []
    document_id = data.get("document_id")
    tenant_id = data.get("tenant_id")

    if not nodes and not relationships:
        logger.warning("No nodes or relationships to upsert", data=data)
        return

    with tracer.start_as_current_span("upsert_kg_data") as span:
        # Skip absent ids instead of passing None as a span attribute
        # (presumably an OpenTelemetry span, which does not accept None
        # values - confirm against libs.observability).
        for attr_name, attr_value in (
            ("document_id", document_id),
            ("tenant_id", tenant_id),
        ):
            if attr_value is not None:
                span.set_attribute(attr_name, attr_value)
        span.set_attribute("node_count", len(nodes))
        span.set_attribute("relationship_count", len(relationships))

        try:
            # 1. Validate data against SHACL schema
            conforms, validation_report = await _validate_with_shacl(
                nodes, relationships
            )
            if not conforms:
                logger.error(
                    "SHACL validation failed",
                    document_id=document_id,
                    validation_report=validation_report,
                )
                metrics.counter("kg_validation_errors_total").labels(
                    tenant_id=tenant_id
                ).inc()
                return

            # 2. Write data to Neo4j
            # NOTE(review): writes are issued one call at a time; a failure
            # part-way through leaves the earlier writes in place unless
            # Neo4jClient wraps them in a transaction - confirm.
            for node in nodes:
                await neo4j_client.create_node(node["type"], node["properties"])  # type: ignore

            for rel in relationships:
                await neo4j_client.create_relationship(  # type: ignore
                    rel["sourceId"],
                    rel["targetId"],
                    rel["type"],
                    rel["properties"],
                )

            # 3. Publish kg.upserted event
            # NOTE(review): str() of the trace id yields a decimal string if
            # trace_id is an int (as in OpenTelemetry) - confirm consumers
            # expect that rather than the hex form.
            event_payload = EventPayload(
                data={
                    "document_id": document_id,
                    "tenant_id": tenant_id,
                    "taxpayer_id": data.get("taxpayer_id"),
                    "tax_year": data.get("tax_year"),
                    "node_count": len(nodes),
                    "relationship_count": len(relationships),
                },
                actor=payload.actor,
                tenant_id=tenant_id,
                trace_id=str(span.get_span_context().trace_id),
            )

            await event_bus.publish(EventTopics.KG_UPSERTED, event_payload)  # type: ignore

            metrics.counter("kg_upserts_total").labels(tenant_id=tenant_id).inc()
            logger.info(
                "KG upsert completed", document_id=document_id, tenant_id=tenant_id
            )

        except Exception as e:
            # Top-level boundary of the event handler: log with the traceback
            # and count the failure instead of propagating the exception.
            logger.error(
                "Failed to upsert KG data",
                document_id=document_id,
                error=str(e),
                exc_info=True,
            )
            metrics.counter("kg_upsert_errors_total").labels(
                tenant_id=tenant_id, error_type=type(e).__name__
            ).inc()


async def _validate_with_shacl(
    nodes: list[dict[str, Any]], relationships: list[dict[str, Any]]
) -> tuple[bool, str]:
    """Validate data against SHACL shapes.

    Builds an RDF graph from the given nodes and relationships and runs
    pyshacl's ``validate`` against the module-level ``shapes_graph``.

    Args:
        nodes: Dicts read for the keys ``id``, ``type`` and ``properties``.
        relationships: Dicts read for the keys ``sourceId``, ``targetId``
            and ``type``.

    Returns:
        A ``(conforms, report)`` tuple. When no shapes graph is loaded the
        result is ``(True, "SHACL shapes not loaded")``; when ``validate``
        raises, the result is ``(False, str(error))``.
    """
    # `is None` rather than truthiness: an rdflib Graph with zero triples is
    # falsy, and a loaded-but-empty shapes file is not the same as "not loaded".
    if shapes_graph is None:
        logger.warning("SHACL shapes not loaded, skipping validation.")
        return True, "SHACL shapes not loaded"

    data_graph = Graph()
    namespace = "http://ai-tax-agent.com/ontology/"

    for node in nodes:
        node_uri = URIRef(f"{namespace}{node['id']}")
        data_graph.add((node_uri, RDF.type, URIRef(f"{namespace}{node['type']}")))
        for key, value in node["properties"].items():
            # None-valued properties are left out of the graph.
            if value is not None:
                data_graph.add((node_uri, URIRef(f"{namespace}{key}"), Literal(value)))

    for rel in relationships:
        source_uri = URIRef(f"{namespace}{rel['sourceId']}")
        target_uri = URIRef(f"{namespace}{rel['targetId']}")
        rel_uri = URIRef(f"{namespace}{rel['type']}")
        data_graph.add((source_uri, rel_uri, target_uri))

    try:
        # NOTE(review): validate() is a plain synchronous call inside this
        # coroutine, so it runs on the event loop for its whole duration.
        conforms, _results_graph, results_text = validate(
            data_graph,
            shacl_graph=shapes_graph,
            ont_graph=None,  # No ontology graph
            inference="rdfs",
            abort_on_first=False,
            allow_infos=False,
            meta_shacl=False,
            advanced=False,
            js=False,
            debug=False,
        )
        return conforms, results_text
    except Exception as e:
        logger.error("Error during SHACL validation", error=str(e))
        return False, str(e)


@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Handle HTTP exceptions with RFC7807 format.

    Args:
        request: Incoming request; its URL becomes ``instance`` and
            ``request.state.trace_id`` (if set) becomes ``trace_id``.
        exc: The raised exception; its ``detail`` is used for both ``title``
            and ``detail``.

    Returns:
        A JSON response carrying the dumped ``ErrorResponse`` model with the
        exception's status code.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            type=f"https://httpstatuses.com/{exc.status_code}",
            title=exc.detail,
            status=exc.status_code,
            detail=exc.detail,
            instance=str(request.url),
            trace_id=getattr(request.state, "trace_id", None),
        ).model_dump(),
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8005, reload=True, log_config=None)