Files
ai-tax-agent/libs/neo/client.py
harkon eea46ac89c
Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
deployment, linting and infra configuration
2025-10-14 07:42:31 +01:00

351 lines
12 KiB
Python

"""Neo4j session helpers, Cypher runner with retry, SHACL validator invoker."""
import asyncio
from collections.abc import Mapping
from datetime import datetime, timezone
from typing import Any

import structlog
from neo4j import Transaction
from neo4j.exceptions import ServiceUnavailable, TransientError
# Module-level structlog logger; Neo4jClient uses it to report retries and failures.
logger = structlog.get_logger()
class Neo4jClient:
"""Neo4j client with session management and retry logic"""
def __init__(self, driver: Any) -> None:
self.driver = driver
async def __aenter__(self) -> "Neo4jClient":
"""Async context manager entry"""
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Async context manager exit"""
await self.close()
async def close(self) -> None:
"""Close the driver"""
await asyncio.get_event_loop().run_in_executor(None, self.driver.close)
async def run_query(
self,
query: str,
parameters: dict[str, Any] | None = None,
database: str = "neo4j",
max_retries: int = 3,
) -> list[dict[str, Any]]:
"""Run Cypher query with retry logic"""
def _run_query() -> list[dict[str, Any]]:
with self.driver.session(database=database) as session:
result = session.run(query, parameters or {})
return [record.data() for record in result]
for attempt in range(max_retries):
try:
return await asyncio.get_event_loop().run_in_executor(None, _run_query)
except (TransientError, ServiceUnavailable) as e:
if attempt == max_retries - 1:
logger.error(
"Query failed after retries",
query=query[:100],
attempt=attempt + 1,
error=str(e),
)
raise
wait_time = 2**attempt # Exponential backoff
logger.warning(
"Query failed, retrying",
query=query[:100],
attempt=attempt + 1,
wait_time=wait_time,
error=str(e),
)
await asyncio.sleep(wait_time)
except Exception as e:
logger.error(
"Query failed with non-retryable error",
query=query[:100],
error=str(e),
)
raise
# This should never be reached due to the raise statements above
return []
async def run_transaction(
self, transaction_func: Any, database: str = "neo4j", max_retries: int = 3
) -> Any:
"""Run transaction with retry logic"""
def _run_transaction() -> Any:
with self.driver.session(database=database) as session:
return session.execute_write(transaction_func)
for attempt in range(max_retries):
try:
return await asyncio.get_event_loop().run_in_executor(
None, _run_transaction
)
except (TransientError, ServiceUnavailable) as e:
if attempt == max_retries - 1:
logger.error(
"Transaction failed after retries",
attempt=attempt + 1,
error=str(e),
)
raise
wait_time = 2**attempt
logger.warning(
"Transaction failed, retrying",
attempt=attempt + 1,
wait_time=wait_time,
error=str(e),
)
await asyncio.sleep(wait_time)
except Exception as e:
logger.error(
"Transaction failed with non-retryable error", error=str(e)
)
raise
async def create_node(
self, label: str, properties: dict[str, Any], database: str = "neo4j"
) -> dict[str, Any]:
"""Create a node with temporal properties"""
# Add temporal properties if not present
if "asserted_at" not in properties:
properties["asserted_at"] = datetime.utcnow()
query = f"""
CREATE (n:{label} $properties)
RETURN n
"""
result = await self.run_query(query, {"properties": properties}, database)
node = result[0]["n"] if result else {}
# Return node ID if available, otherwise return the full node
return node.get("id", node) # type: ignore
async def update_node(
self,
label: str,
node_id: str,
properties: dict[str, Any],
database: str = "neo4j",
) -> dict[str, Any]:
"""Update node with bitemporal versioning"""
def _update_transaction(tx: Transaction) -> Any:
# First, retract the current version
retract_query = f"""
MATCH (n:{label} {{id: $node_id}})
WHERE n.retracted_at IS NULL
SET n.retracted_at = datetime()
RETURN n
"""
tx.run(retract_query, {"node_id": node_id}) # fmt: skip # pyright: ignore[reportArgumentType]
# Create new version
new_properties = properties.copy()
new_properties["id"] = node_id
new_properties["asserted_at"] = datetime.utcnow()
create_query = f"""
CREATE (n:{label} $properties)
RETURN n
"""
result = tx.run(create_query, {"properties": new_properties}) # fmt: skip # pyright: ignore[reportArgumentType]
record = result.single()
return record["n"] if record else None
result = await self.run_transaction(_update_transaction, database)
return result if isinstance(result, dict) else {}
async def create_relationship( # pylint: disable=too-many-arguments,too-many-positional-arguments
self,
from_label: str | None = None,
from_id: str | None = None,
to_label: str | None = None,
to_id: str | None = None,
relationship_type: str | None = None,
properties: dict[str, Any] | None = None,
database: str = "neo4j",
# Alternative signature for tests
from_node_id: int | None = None,
to_node_id: int | None = None,
) -> dict[str, Any]:
"""Create relationship between nodes"""
# Handle alternative signature for tests (using node IDs)
if from_node_id is not None and to_node_id is not None:
rel_properties = properties or {}
if "asserted_at" not in rel_properties:
rel_properties["asserted_at"] = datetime.utcnow()
query = f"""
MATCH (from) WHERE id(from) = $from_id
MATCH (to) WHERE id(to) = $to_id
CREATE (from)-[r:{relationship_type} $properties]->(to)
RETURN r
"""
result = await self.run_query(
query,
{
"from_id": from_node_id,
"to_id": to_node_id,
"properties": rel_properties,
},
database,
)
rel = result[0]["r"] if result else {}
return rel.get("id", rel) # type: ignore
# Original signature (using labels and IDs)
rel_properties = properties or {}
if "asserted_at" not in rel_properties:
rel_properties["asserted_at"] = datetime.utcnow()
query = f"""
MATCH (from:{from_label} {{id: $from_id}})
MATCH (to:{to_label} {{id: $to_id}})
WHERE from.retracted_at IS NULL AND to.retracted_at IS NULL
CREATE (from)-[r:{relationship_type} $properties]->(to)
RETURN r
"""
result = await self.run_query(
query,
{"from_id": from_id, "to_id": to_id, "properties": rel_properties},
database,
)
rel = result[0]["r"] if result else {}
# Return relationship ID if available, otherwise return the full relationship
return rel.get("id", rel) # type: ignore
async def get_node_lineage(
self, node_id: str, max_depth: int = 10, database: str = "neo4j"
) -> list[dict[str, Any]]:
"""Get complete lineage for a node"""
query = """
MATCH path = (n {id: $node_id})-[:DERIVED_FROM*1..10]->(evidence:Evidence)
WHERE n.retracted_at IS NULL
RETURN path, evidence
ORDER BY length(path) DESC
LIMIT 100
"""
return await self.run_query(
query, {"node_id": node_id, "max_depth": max_depth}, database
)
async def export_to_rdf( # pylint: disable=redefined-builtin
self,
format: str = "turtle",
database: str = "neo4j",
) -> dict[str, Any]:
"""Export graph data to RDF format"""
query = """
CALL n10s.rdf.export.cypher(
'MATCH (n) WHERE n.retracted_at IS NULL RETURN n',
$format,
{}
) YIELD triplesCount, format
RETURN triplesCount, format
"""
try:
result = await self.run_query(query, {"format": format}, database)
return result[0] if result else {}
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("RDF export failed, using fallback", error=str(e))
fallback_result = await self._export_rdf_fallback(database)
return {"rdf_data": fallback_result, "format": format}
async def _export_rdf_fallback(self, database: str = "neo4j") -> str:
"""Fallback RDF export without n10s plugin"""
# Get all nodes and relationships
nodes_query = """
MATCH (n) WHERE n.retracted_at IS NULL
RETURN labels(n) as labels, properties(n) as props, id(n) as neo_id
"""
rels_query = """
MATCH (a)-[r]->(b)
WHERE a.retracted_at IS NULL AND b.retracted_at IS NULL
RETURN type(r) as type, properties(r) as props,
id(a) as from_id, id(b) as to_id
"""
nodes = await self.run_query(nodes_query, database=database)
relationships = await self.run_query(rels_query, database=database)
# Convert to simple Turtle format
rdf_lines = ["@prefix tax: <https://tax-kg.example.com/> ."]
for node in nodes:
node_uri = f"tax:node_{node['neo_id']}"
for label in node["labels"]:
rdf_lines.append(f"{node_uri} a tax:{label} .")
for prop, value in node["props"].items():
if isinstance(value, str):
rdf_lines.append(f'{node_uri} tax:{prop} "{value}" .')
else:
rdf_lines.append(f"{node_uri} tax:{prop} {value} .")
for rel in relationships:
from_uri = f"tax:node_{rel['from_id']}"
to_uri = f"tax:node_{rel['to_id']}"
rdf_lines.append(f"{from_uri} tax:{rel['type']} {to_uri} .")
return "\n".join(rdf_lines)
async def find_nodes(
self, label: str, properties: dict[str, Any], database: str = "neo4j"
) -> list[dict[str, Any]]:
"""Find nodes matching label and properties"""
where_clause, params = self._build_properties_clause(properties)
query = f"MATCH (n:{label}) WHERE {where_clause} RETURN n"
result = await self.run_query(query, params, database)
return [record["n"] for record in result]
async def execute_query(
self,
query: str,
parameters: dict[str, Any] | None = None,
database: str = "neo4j",
) -> list[dict[str, Any]]:
"""Execute a custom Cypher query"""
return await self.run_query(query, parameters, database)
def _build_properties_clause(
self, properties: dict[str, Any]
) -> tuple[str, dict[str, Any]]:
"""Build WHERE clause and parameters for properties"""
if not properties:
return "true", {}
clauses = []
params = {}
for i, (key, value) in enumerate(properties.items()):
param_name = f"prop_{i}"
clauses.append(f"n.{key} = ${param_name}")
params[param_name] = value
return " AND ".join(clauses), params