Some checks failed
CI/CD Pipeline / Code Quality & Linting (push) Has been cancelled
CI/CD Pipeline / Policy Validation (push) Has been cancelled
CI/CD Pipeline / Test Suite (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-firm-connectors) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-forms) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-hmrc) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ingestion) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-normalize-map) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-ocr) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-indexer) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-reason) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (svc-rpa) (push) Has been cancelled
CI/CD Pipeline / Build Docker Images (ui-review) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-coverage) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-extract) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-kg) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (svc-rag-retriever) (push) Has been cancelled
CI/CD Pipeline / Security Scanning (ui-review) (push) Has been cancelled
CI/CD Pipeline / Generate SBOM (push) Has been cancelled
CI/CD Pipeline / Deploy to Staging (push) Has been cancelled
CI/CD Pipeline / Deploy to Production (push) Has been cancelled
CI/CD Pipeline / Notifications (push) Has been cancelled
351 lines
12 KiB
Python
351 lines
12 KiB
Python
"""Neo4j session helpers, Cypher runner with retry, SHACL validator invoker."""
|
|
|
|
import asyncio
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
import structlog
|
|
from neo4j import Transaction
|
|
from neo4j.exceptions import ServiceUnavailable, TransientError
|
|
|
|
# Module-level structlog logger used by Neo4jClient for retry/failure events.
logger = structlog.get_logger()
|
|
|
|
|
|
class Neo4jClient:
|
|
"""Neo4j client with session management and retry logic"""
|
|
|
|
def __init__(self, driver: Any) -> None:
|
|
self.driver = driver
|
|
|
|
async def __aenter__(self) -> "Neo4jClient":
|
|
"""Async context manager entry"""
|
|
return self
|
|
|
|
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
|
"""Async context manager exit"""
|
|
await self.close()
|
|
|
|
async def close(self) -> None:
|
|
"""Close the driver"""
|
|
await asyncio.get_event_loop().run_in_executor(None, self.driver.close)
|
|
|
|
async def run_query(
|
|
self,
|
|
query: str,
|
|
parameters: dict[str, Any] | None = None,
|
|
database: str = "neo4j",
|
|
max_retries: int = 3,
|
|
) -> list[dict[str, Any]]:
|
|
"""Run Cypher query with retry logic"""
|
|
|
|
def _run_query() -> list[dict[str, Any]]:
|
|
with self.driver.session(database=database) as session:
|
|
result = session.run(query, parameters or {})
|
|
return [record.data() for record in result]
|
|
|
|
for attempt in range(max_retries):
|
|
try:
|
|
return await asyncio.get_event_loop().run_in_executor(None, _run_query)
|
|
|
|
except (TransientError, ServiceUnavailable) as e:
|
|
if attempt == max_retries - 1:
|
|
logger.error(
|
|
"Query failed after retries",
|
|
query=query[:100],
|
|
attempt=attempt + 1,
|
|
error=str(e),
|
|
)
|
|
raise
|
|
|
|
wait_time = 2**attempt # Exponential backoff
|
|
logger.warning(
|
|
"Query failed, retrying",
|
|
query=query[:100],
|
|
attempt=attempt + 1,
|
|
wait_time=wait_time,
|
|
error=str(e),
|
|
)
|
|
await asyncio.sleep(wait_time)
|
|
|
|
except Exception as e:
|
|
logger.error(
|
|
"Query failed with non-retryable error",
|
|
query=query[:100],
|
|
error=str(e),
|
|
)
|
|
raise
|
|
|
|
# This should never be reached due to the raise statements above
|
|
return []
|
|
|
|
async def run_transaction(
|
|
self, transaction_func: Any, database: str = "neo4j", max_retries: int = 3
|
|
) -> Any:
|
|
"""Run transaction with retry logic"""
|
|
|
|
def _run_transaction() -> Any:
|
|
with self.driver.session(database=database) as session:
|
|
return session.execute_write(transaction_func)
|
|
|
|
for attempt in range(max_retries):
|
|
try:
|
|
return await asyncio.get_event_loop().run_in_executor(
|
|
None, _run_transaction
|
|
)
|
|
|
|
except (TransientError, ServiceUnavailable) as e:
|
|
if attempt == max_retries - 1:
|
|
logger.error(
|
|
"Transaction failed after retries",
|
|
attempt=attempt + 1,
|
|
error=str(e),
|
|
)
|
|
raise
|
|
|
|
wait_time = 2**attempt
|
|
logger.warning(
|
|
"Transaction failed, retrying",
|
|
attempt=attempt + 1,
|
|
wait_time=wait_time,
|
|
error=str(e),
|
|
)
|
|
await asyncio.sleep(wait_time)
|
|
|
|
except Exception as e:
|
|
logger.error(
|
|
"Transaction failed with non-retryable error", error=str(e)
|
|
)
|
|
raise
|
|
|
|
async def create_node(
|
|
self, label: str, properties: dict[str, Any], database: str = "neo4j"
|
|
) -> dict[str, Any]:
|
|
"""Create a node with temporal properties"""
|
|
|
|
# Add temporal properties if not present
|
|
if "asserted_at" not in properties:
|
|
properties["asserted_at"] = datetime.utcnow()
|
|
|
|
query = f"""
|
|
CREATE (n:{label} $properties)
|
|
RETURN n
|
|
"""
|
|
|
|
result = await self.run_query(query, {"properties": properties}, database)
|
|
node = result[0]["n"] if result else {}
|
|
# Return node ID if available, otherwise return the full node
|
|
return node.get("id", node) # type: ignore
|
|
|
|
async def update_node(
|
|
self,
|
|
label: str,
|
|
node_id: str,
|
|
properties: dict[str, Any],
|
|
database: str = "neo4j",
|
|
) -> dict[str, Any]:
|
|
"""Update node with bitemporal versioning"""
|
|
|
|
def _update_transaction(tx: Transaction) -> Any:
|
|
# First, retract the current version
|
|
retract_query = f"""
|
|
MATCH (n:{label} {{id: $node_id}})
|
|
WHERE n.retracted_at IS NULL
|
|
SET n.retracted_at = datetime()
|
|
RETURN n
|
|
"""
|
|
tx.run(retract_query, {"node_id": node_id}) # fmt: skip # pyright: ignore[reportArgumentType]
|
|
|
|
# Create new version
|
|
new_properties = properties.copy()
|
|
new_properties["id"] = node_id
|
|
new_properties["asserted_at"] = datetime.utcnow()
|
|
|
|
create_query = f"""
|
|
CREATE (n:{label} $properties)
|
|
RETURN n
|
|
"""
|
|
result = tx.run(create_query, {"properties": new_properties}) # fmt: skip # pyright: ignore[reportArgumentType]
|
|
record = result.single()
|
|
return record["n"] if record else None
|
|
|
|
result = await self.run_transaction(_update_transaction, database)
|
|
return result if isinstance(result, dict) else {}
|
|
|
|
async def create_relationship( # pylint: disable=too-many-arguments,too-many-positional-arguments
|
|
self,
|
|
from_label: str | None = None,
|
|
from_id: str | None = None,
|
|
to_label: str | None = None,
|
|
to_id: str | None = None,
|
|
relationship_type: str | None = None,
|
|
properties: dict[str, Any] | None = None,
|
|
database: str = "neo4j",
|
|
# Alternative signature for tests
|
|
from_node_id: int | None = None,
|
|
to_node_id: int | None = None,
|
|
) -> dict[str, Any]:
|
|
"""Create relationship between nodes"""
|
|
|
|
# Handle alternative signature for tests (using node IDs)
|
|
if from_node_id is not None and to_node_id is not None:
|
|
rel_properties = properties or {}
|
|
if "asserted_at" not in rel_properties:
|
|
rel_properties["asserted_at"] = datetime.utcnow()
|
|
|
|
query = f"""
|
|
MATCH (from) WHERE id(from) = $from_id
|
|
MATCH (to) WHERE id(to) = $to_id
|
|
CREATE (from)-[r:{relationship_type} $properties]->(to)
|
|
RETURN r
|
|
"""
|
|
|
|
result = await self.run_query(
|
|
query,
|
|
{
|
|
"from_id": from_node_id,
|
|
"to_id": to_node_id,
|
|
"properties": rel_properties,
|
|
},
|
|
database,
|
|
)
|
|
rel = result[0]["r"] if result else {}
|
|
return rel.get("id", rel) # type: ignore
|
|
|
|
# Original signature (using labels and IDs)
|
|
rel_properties = properties or {}
|
|
if "asserted_at" not in rel_properties:
|
|
rel_properties["asserted_at"] = datetime.utcnow()
|
|
|
|
query = f"""
|
|
MATCH (from:{from_label} {{id: $from_id}})
|
|
MATCH (to:{to_label} {{id: $to_id}})
|
|
WHERE from.retracted_at IS NULL AND to.retracted_at IS NULL
|
|
CREATE (from)-[r:{relationship_type} $properties]->(to)
|
|
RETURN r
|
|
"""
|
|
|
|
result = await self.run_query(
|
|
query,
|
|
{"from_id": from_id, "to_id": to_id, "properties": rel_properties},
|
|
database,
|
|
)
|
|
rel = result[0]["r"] if result else {}
|
|
# Return relationship ID if available, otherwise return the full relationship
|
|
return rel.get("id", rel) # type: ignore
|
|
|
|
async def get_node_lineage(
|
|
self, node_id: str, max_depth: int = 10, database: str = "neo4j"
|
|
) -> list[dict[str, Any]]:
|
|
"""Get complete lineage for a node"""
|
|
|
|
query = """
|
|
MATCH path = (n {id: $node_id})-[:DERIVED_FROM*1..10]->(evidence:Evidence)
|
|
WHERE n.retracted_at IS NULL
|
|
RETURN path, evidence
|
|
ORDER BY length(path) DESC
|
|
LIMIT 100
|
|
"""
|
|
|
|
return await self.run_query(
|
|
query, {"node_id": node_id, "max_depth": max_depth}, database
|
|
)
|
|
|
|
async def export_to_rdf( # pylint: disable=redefined-builtin
|
|
self,
|
|
format: str = "turtle",
|
|
database: str = "neo4j",
|
|
) -> dict[str, Any]:
|
|
"""Export graph data to RDF format"""
|
|
|
|
query = """
|
|
CALL n10s.rdf.export.cypher(
|
|
'MATCH (n) WHERE n.retracted_at IS NULL RETURN n',
|
|
$format,
|
|
{}
|
|
) YIELD triplesCount, format
|
|
RETURN triplesCount, format
|
|
"""
|
|
|
|
try:
|
|
result = await self.run_query(query, {"format": format}, database)
|
|
return result[0] if result else {}
|
|
except Exception as e: # pylint: disable=broad-exception-caught
|
|
logger.warning("RDF export failed, using fallback", error=str(e))
|
|
fallback_result = await self._export_rdf_fallback(database)
|
|
return {"rdf_data": fallback_result, "format": format}
|
|
|
|
async def _export_rdf_fallback(self, database: str = "neo4j") -> str:
|
|
"""Fallback RDF export without n10s plugin"""
|
|
|
|
# Get all nodes and relationships
|
|
nodes_query = """
|
|
MATCH (n) WHERE n.retracted_at IS NULL
|
|
RETURN labels(n) as labels, properties(n) as props, id(n) as neo_id
|
|
"""
|
|
|
|
rels_query = """
|
|
MATCH (a)-[r]->(b)
|
|
WHERE a.retracted_at IS NULL AND b.retracted_at IS NULL
|
|
RETURN type(r) as type, properties(r) as props,
|
|
id(a) as from_id, id(b) as to_id
|
|
"""
|
|
|
|
nodes = await self.run_query(nodes_query, database=database)
|
|
relationships = await self.run_query(rels_query, database=database)
|
|
|
|
# Convert to simple Turtle format
|
|
rdf_lines = ["@prefix tax: <https://tax-kg.example.com/> ."]
|
|
|
|
for node in nodes:
|
|
node_uri = f"tax:node_{node['neo_id']}"
|
|
for label in node["labels"]:
|
|
rdf_lines.append(f"{node_uri} a tax:{label} .")
|
|
|
|
for prop, value in node["props"].items():
|
|
if isinstance(value, str):
|
|
rdf_lines.append(f'{node_uri} tax:{prop} "{value}" .')
|
|
else:
|
|
rdf_lines.append(f"{node_uri} tax:{prop} {value} .")
|
|
|
|
for rel in relationships:
|
|
from_uri = f"tax:node_{rel['from_id']}"
|
|
to_uri = f"tax:node_{rel['to_id']}"
|
|
rdf_lines.append(f"{from_uri} tax:{rel['type']} {to_uri} .")
|
|
|
|
return "\n".join(rdf_lines)
|
|
|
|
async def find_nodes(
|
|
self, label: str, properties: dict[str, Any], database: str = "neo4j"
|
|
) -> list[dict[str, Any]]:
|
|
"""Find nodes matching label and properties"""
|
|
where_clause, params = self._build_properties_clause(properties)
|
|
query = f"MATCH (n:{label}) WHERE {where_clause} RETURN n"
|
|
|
|
result = await self.run_query(query, params, database)
|
|
return [record["n"] for record in result]
|
|
|
|
async def execute_query(
|
|
self,
|
|
query: str,
|
|
parameters: dict[str, Any] | None = None,
|
|
database: str = "neo4j",
|
|
) -> list[dict[str, Any]]:
|
|
"""Execute a custom Cypher query"""
|
|
return await self.run_query(query, parameters, database)
|
|
|
|
def _build_properties_clause(
|
|
self, properties: dict[str, Any]
|
|
) -> tuple[str, dict[str, Any]]:
|
|
"""Build WHERE clause and parameters for properties"""
|
|
if not properties:
|
|
return "true", {}
|
|
|
|
clauses = []
|
|
params = {}
|
|
for i, (key, value) in enumerate(properties.items()):
|
|
param_name = f"prop_{i}"
|
|
clauses.append(f"n.{key} = ${param_name}")
|
|
params[param_name] = value
|
|
|
|
return " AND ".join(clauses), params
|