"""Neo4j session helpers, Cypher runner with retry, SHACL validator invoker."""

import asyncio
from collections.abc import Callable
from datetime import datetime, timezone
from typing import Any

import structlog
from neo4j import Transaction
from neo4j.exceptions import ServiceUnavailable, TransientError

logger = structlog.get_logger()

# Characters that must be escaped inside a double-quoted Turtle string literal.
_TURTLE_STRING_ESCAPES = str.maketrans(
    {"\\": "\\\\", '"': '\\"', "\n": "\\n", "\r": "\\r", "\t": "\\t"}
)


class Neo4jClient:
    """Neo4j client with session management and retry logic.

    The wrapped driver is used through its synchronous API (plain ``with``
    sessions, blocking ``close()``); every blocking driver call is pushed onto
    the running event loop's default executor so the public methods can be
    awaited.
    """

    def __init__(self, driver: Any) -> None:
        # Driver object exposing synchronous ``session(database=...)`` and ``close()``.
        self.driver = driver

    async def __aenter__(self) -> "Neo4jClient":
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Async context manager exit: close the driver."""
        await self.close()

    async def close(self) -> None:
        """Close the driver without blocking the event loop."""
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self.driver.close)

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current UTC time as a *naive* datetime.

        Same value as the deprecated ``datetime.utcnow()``. The result is kept
        naive so the type of stored ``asserted_at`` properties does not change.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    async def _run_with_retry(
        self,
        func: Callable[[], Any],
        max_retries: int,
        action: str,
        log_context: dict[str, Any],
    ) -> Any:
        """Run blocking ``func`` in the default executor, retrying transient errors.

        Args:
            func: Zero-argument blocking callable to execute.
            max_retries: Total number of attempts (at least one is always made).
            action: Log-message prefix, e.g. ``"Query"`` or ``"Transaction"``.
            log_context: Extra structured fields attached to every log line.

        Returns:
            Whatever ``func`` returns on the first successful attempt.

        Raises:
            TransientError, ServiceUnavailable: if the final attempt fails.
            Exception: any other error, logged and re-raised without retry.
        """
        loop = asyncio.get_running_loop()
        # Previously a non-positive max_retries skipped execution entirely and
        # silently returned an empty result; always run at least once.
        attempts = max(1, max_retries)
        for attempt in range(attempts):
            try:
                return await loop.run_in_executor(None, func)
            except (TransientError, ServiceUnavailable) as e:
                if attempt == attempts - 1:
                    logger.error(
                        f"{action} failed after retries",
                        **log_context,
                        attempt=attempt + 1,
                        error=str(e),
                    )
                    raise
                wait_time = 2**attempt  # Exponential backoff: 1s, 2s, 4s, ...
                logger.warning(
                    f"{action} failed, retrying",
                    **log_context,
                    attempt=attempt + 1,
                    wait_time=wait_time,
                    error=str(e),
                )
                await asyncio.sleep(wait_time)
            except Exception as e:
                logger.error(
                    f"{action} failed with non-retryable error",
                    **log_context,
                    error=str(e),
                )
                raise
        # Unreachable: the last attempt either returns or raises.
        raise RuntimeError("retry loop exited without a result")  # pragma: no cover

    async def run_query(
        self,
        query: str,
        parameters: dict[str, Any] | None = None,
        database: str = "neo4j",
        max_retries: int = 3,
    ) -> list[dict[str, Any]]:
        """Run a Cypher query with retry logic.

        Args:
            query: Cypher text.
            parameters: Query parameters; ``None`` means no parameters.
            database: Target database name.
            max_retries: Total number of attempts for transient failures
                (at least one attempt is always made).

        Returns:
            One dict per result record, as produced by ``Record.data()``.

        Raises:
            TransientError, ServiceUnavailable: if every attempt fails.
            Exception: any other driver error, re-raised without retry.
        """

        def _run_query() -> list[dict[str, Any]]:
            with self.driver.session(database=database) as session:
                result = session.run(query, parameters or {})
                # Consume the result while the session is still open.
                return [record.data() for record in result]

        return await self._run_with_retry(
            _run_query, max_retries, "Query", {"query": query[:100]}
        )

    async def run_transaction(
        self, transaction_func: Any, database: str = "neo4j", max_retries: int = 3
    ) -> Any:
        """Run ``transaction_func`` inside a write transaction, with retry logic.

        ``transaction_func`` is passed to ``session.execute_write`` and its
        return value is returned. Retry semantics match ``run_query``.
        """

        def _run_transaction() -> Any:
            with self.driver.session(database=database) as session:
                return session.execute_write(transaction_func)

        return await self._run_with_retry(
            _run_transaction, max_retries, "Transaction", {}
        )

    async def create_node(
        self, label: str, properties: dict[str, Any], database: str = "neo4j"
    ) -> dict[str, Any]:
        """Create a node with temporal properties.

        An ``asserted_at`` timestamp (naive UTC) is added when absent. The
        caller's ``properties`` dict is not modified.

        Returns:
            The created node's ``id`` property if present, otherwise the node
            as a dict (``{}`` if the query returned no rows).
        """
        # Copy so adding asserted_at does not leak into the caller's dict.
        node_properties = dict(properties or {})
        if "asserted_at" not in node_properties:
            node_properties["asserted_at"] = self._utc_now()

        # NOTE: ``label`` is interpolated into the Cypher text; it must come
        # from trusted code, never from user input.
        query = f"""
        CREATE (n:{label} $properties)
        RETURN n
        """
        result = await self.run_query(
            query, {"properties": node_properties}, database
        )
        node = result[0]["n"] if result else {}
        # Return node ID if available, otherwise return the full node
        return node.get("id", node)  # type: ignore

    async def update_node(
        self,
        label: str,
        node_id: str,
        properties: dict[str, Any],
        database: str = "neo4j",
    ) -> dict[str, Any]:
        """Update a node using bitemporal versioning.

        In a single write transaction, every current (non-retracted) version
        of the ``label`` node with this ``id`` gets ``retracted_at = datetime()``.
        Then a new version is created from ``properties`` with the same ``id``
        and a fresh ``asserted_at``.

        Returns:
            The new version's properties as a dict, or ``{}`` if nothing was
            created.
        """

        def _update_transaction(tx: Transaction) -> dict[str, Any] | None:
            # First, retract the current version.
            # NOTE(review): retracted_at uses the server's zoned datetime() while
            # asserted_at is a naive client-side UTC value — confirm intended.
            retract_query = f"""
            MATCH (n:{label} {{id: $node_id}})
            WHERE n.retracted_at IS NULL
            SET n.retracted_at = datetime()
            RETURN n
            """
            tx.run(retract_query, {"node_id": node_id})  # fmt: skip # pyright: ignore[reportArgumentType]

            # Create new version
            new_properties = properties.copy()
            new_properties["id"] = node_id
            new_properties["asserted_at"] = self._utc_now()
            create_query = f"""
            CREATE (n:{label} $properties)
            RETURN n
            """
            result = tx.run(create_query, {"properties": new_properties})  # fmt: skip # pyright: ignore[reportArgumentType]
            record = result.single()
            # record["n"] is presumably a driver Node (not a dict subclass).
            # Convert it to a plain dict so the isinstance check below keeps it.
            return dict(record["n"]) if record else None

        result = await self.run_transaction(_update_transaction, database)
        return result if isinstance(result, dict) else {}

    async def create_relationship(  # pylint: disable=too-many-arguments,too-many-positional-arguments
        self,
        from_label: str | None = None,
        from_id: str | None = None,
        to_label: str | None = None,
        to_id: str | None = None,
        relationship_type: str | None = None,
        properties: dict[str, Any] | None = None,
        database: str = "neo4j",
        # Alternative signature for tests
        from_node_id: int | None = None,
        to_node_id: int | None = None,
    ) -> dict[str, Any]:
        """Create a relationship between two nodes.

        Two addressing modes:
          * ``from_node_id``/``to_node_id`` (both given): match by internal ``id()``.
          * otherwise ``from_label``/``from_id`` and ``to_label``/``to_id``: match
            by ``id`` property, only against non-retracted versions.

        An ``asserted_at`` timestamp is added to a copy of ``properties`` when
        absent; the caller's dict is not modified.

        Returns:
            The relationship's ``id`` property if present, otherwise its
            property map (``{}`` if nothing was created).

        Raises:
            ValueError: if ``relationship_type`` is missing or empty.
        """
        if not relationship_type:
            # Interpolating None would silently create a relationship typed "None".
            raise ValueError("relationship_type is required")

        # Copy so adding asserted_at does not leak into the caller's dict.
        rel_properties = dict(properties or {})
        if "asserted_at" not in rel_properties:
            rel_properties["asserted_at"] = self._utc_now()

        # properties(r) yields a plain map, so ``.get`` below works whatever
        # shape Record.data() gives a raw relationship value.
        if from_node_id is not None and to_node_id is not None:
            query = f"""
            MATCH (from) WHERE id(from) = $from_id
            MATCH (to) WHERE id(to) = $to_id
            CREATE (from)-[r:{relationship_type} $properties]->(to)
            RETURN properties(r) AS r
            """
            params: dict[str, Any] = {
                "from_id": from_node_id,
                "to_id": to_node_id,
                "properties": rel_properties,
            }
        else:
            query = f"""
            MATCH (from:{from_label} {{id: $from_id}})
            MATCH (to:{to_label} {{id: $to_id}})
            WHERE from.retracted_at IS NULL AND to.retracted_at IS NULL
            CREATE (from)-[r:{relationship_type} $properties]->(to)
            RETURN properties(r) AS r
            """
            params = {"from_id": from_id, "to_id": to_id, "properties": rel_properties}

        result = await self.run_query(query, params, database)
        rel = result[0]["r"] if result else {}
        # Return relationship ID if available, otherwise the property map
        return rel.get("id", rel)  # type: ignore

    async def get_node_lineage(
        self, node_id: str, max_depth: int = 10, database: str = "neo4j"
    ) -> list[dict[str, Any]]:
        """Get the lineage of a node back to its ``Evidence`` nodes.

        Follows ``DERIVED_FROM`` relationships 1..``max_depth`` hops from the
        non-retracted node with this ``id``.

        Returns:
            Up to 100 ``{"path", "evidence"}`` rows, longest paths first.

        Raises:
            ValueError: if ``max_depth`` is not an integer >= 1.
        """
        depth = int(max_depth)
        if depth < 1:
            raise ValueError("max_depth must be >= 1")
        # Cypher does not accept parameters for variable-length bounds, so the
        # validated integer is interpolated (max_depth was previously ignored).
        query = f"""
        MATCH path = (n {{id: $node_id}})-[:DERIVED_FROM*1..{depth}]->(evidence:Evidence)
        WHERE n.retracted_at IS NULL
        RETURN path, evidence
        ORDER BY length(path) DESC
        LIMIT 100
        """
        return await self.run_query(query, {"node_id": node_id}, database)

    async def export_to_rdf(  # pylint: disable=redefined-builtin
        self,
        format: str = "turtle",
        database: str = "neo4j",
    ) -> dict[str, Any]:
        """Export non-retracted graph data to RDF.

        Tries the n10s ``n10s.rdf.export.cypher`` procedure first. On any error
        (e.g. the plugin is unavailable) it falls back to ``_export_rdf_fallback``.

        Returns:
            The first row of the n10s call (``{}`` if no rows), or
            ``{"rdf_data": <text>, "format": format}`` from the fallback.
        """
        query = """
        CALL n10s.rdf.export.cypher(
            'MATCH (n) WHERE n.retracted_at IS NULL RETURN n',
            $format,
            {}
        )
        YIELD triplesCount, format
        RETURN triplesCount, format
        """
        try:
            result = await self.run_query(query, {"format": format}, database)
            return result[0] if result else {}
        except Exception as e:  # pylint: disable=broad-exception-caught
            # Deliberate best-effort: any failure switches to the fallback.
            # NOTE(review): the fallback always emits Turtle-style text even
            # when ``format`` is something else.
            logger.warning("RDF export failed, using fallback", error=str(e))
            fallback_result = await self._export_rdf_fallback(database)
            return {"rdf_data": fallback_result, "format": format}

    @staticmethod
    def _turtle_literal(value: Any) -> str:
        """Render a property value as a Turtle literal.

        Booleans become ``true``/``false``, ints and floats are written bare,
        and everything else is written as an escaped double-quoted string.
        """
        if isinstance(value, bool):
            return "true" if value else "false"
        if isinstance(value, (int, float)):
            return str(value)
        return '"' + str(value).translate(_TURTLE_STRING_ESCAPES) + '"'

    async def _export_rdf_fallback(self, database: str = "neo4j") -> str:
        """Fallback RDF export without the n10s plugin.

        Serialises every non-retracted node (labels and properties) and every
        relationship between non-retracted nodes as Turtle-style triples.
        Nodes are identified by their internal ``id()``.
        """
        # Get all nodes and relationships
        nodes_query = """
        MATCH (n) WHERE n.retracted_at IS NULL
        RETURN labels(n) as labels, properties(n) as props, id(n) as neo_id
        """
        rels_query = """
        MATCH (a)-[r]->(b)
        WHERE a.retracted_at IS NULL AND b.retracted_at IS NULL
        RETURN type(r) as type, properties(r) as props, id(a) as from_id, id(b) as to_id
        """
        nodes = await self.run_query(nodes_query, database=database)
        relationships = await self.run_query(rels_query, database=database)

        # NOTE(review): this prefix declaration has no namespace IRI. Turtle
        # requires ``@prefix tax: <iri> .``, so parsers will reject the output
        # until the intended IRI is filled in.
        rdf_lines = ["@prefix tax: ."]
        for node in nodes:
            node_uri = f"tax:node_{node['neo_id']}"
            rdf_lines.extend(f"{node_uri} a tax:{label} ." for label in node["labels"])
            rdf_lines.extend(
                f"{node_uri} tax:{prop} {self._turtle_literal(value)} ."
                for prop, value in node["props"].items()
            )
        for rel in relationships:
            from_uri = f"tax:node_{rel['from_id']}"
            to_uri = f"tax:node_{rel['to_id']}"
            rdf_lines.append(f"{from_uri} tax:{rel['type']} {to_uri} .")
        return "\n".join(rdf_lines)

    async def find_nodes(
        self, label: str, properties: dict[str, Any], database: str = "neo4j"
    ) -> list[dict[str, Any]]:
        """Find ``label`` nodes whose properties match ``properties``.

        Each key is matched by equality; a ``None`` value matches nodes that
        lack the property. Retracted versions are NOT filtered out.
        """
        where_clause, params = self._build_properties_clause(properties)
        query = f"MATCH (n:{label}) WHERE {where_clause} RETURN n"
        result = await self.run_query(query, params, database)
        return [record["n"] for record in result]

    async def execute_query(
        self,
        query: str,
        parameters: dict[str, Any] | None = None,
        database: str = "neo4j",
    ) -> list[dict[str, Any]]:
        """Execute a custom Cypher query (thin alias for ``run_query``)."""
        return await self.run_query(query, parameters, database)

    @staticmethod
    def _quote_identifier(name: Any) -> str:
        """Backtick-quote a Cypher identifier, doubling embedded backticks."""
        return "`" + str(name).replace("`", "``") + "`"

    def _build_properties_clause(
        self, properties: dict[str, Any]
    ) -> tuple[str, dict[str, Any]]:
        """Build a WHERE clause and parameters matching ``properties`` on ``n``.

        Keys are backtick-quoted so they cannot alter the query text. Values
        are passed as ``$prop_<i>`` parameters. ``None`` values become
        ``IS NULL``, because ``= null`` never matches in Cypher.

        Returns:
            ``(clause, params)``; ``("true", {})`` for empty ``properties``.
        """
        if not properties:
            return "true", {}

        clauses = []
        params: dict[str, Any] = {}
        for i, (key, value) in enumerate(properties.items()):
            prop_ref = f"n.{self._quote_identifier(key)}"
            if value is None:
                clauses.append(f"{prop_ref} IS NULL")
                continue
            param_name = f"prop_{i}"
            clauses.append(f"{prop_ref} = ${param_name}")
            params[param_name] = value
        return " AND ".join(clauses), params