Source code for kgx.sink.neo_sink

from typing import List, Union, Any

from neo4jrestclient.client import GraphDatabase
from neo4jrestclient.query import CypherException

from kgx.config import get_logger
from kgx.sink.sink import Sink
from kgx.utils.kgx_utils import DEFAULT_NODE_CATEGORY

log = get_logger()


[docs]class NeoSink(Sink): """ NeoSink is responsible for writing data as records to a Neo4j instance. Parameters ---------- uri: str The URI for the Neo4j instance. For example, http://localhost:7474 username: str The username password: str The password kwargs: Any Any additional arguments """ CACHE_SIZE = 100000 node_cache = {} edge_cache = {} node_count = 0 edge_count = 0 CATEGORY_DELIMITER = "|" CYPHER_CATEGORY_DELIMITER = ":" _seen_categories = set() def __init__(self, uri: str, username: str, password: str, **kwargs: Any): super().__init__() if "cache_size" in kwargs: self.CACHE_SIZE = kwargs["cache_size"] self.http_driver: GraphDatabase = GraphDatabase( uri, username=username, password=password ) def _flush_node_cache(self): self._write_node_cache() self.node_cache.clear() self.node_count = 0
[docs] def write_node(self, record) -> None: """ Cache a node record that is to be written to Neo4j. This method writes a cache of node records when the total number of records exceeds ``CACHE_SIZE`` Parameters ---------- record: Dict A node record """ sanitized_category = self.sanitize_category(record["category"]) category = self.CATEGORY_DELIMITER.join(sanitized_category) if self.node_count >= self.CACHE_SIZE: self._flush_node_cache() if category not in self.node_cache: self.node_cache[category] = [record] else: self.node_cache[category].append(record) self.node_count += 1
def _write_node_cache(self) -> None: """ Write cached node records to Neo4j. """ batch_size = 10000 categories = self.node_cache.keys() filtered_categories = [x for x in categories if x not in self._seen_categories] self.create_constraints(filtered_categories) for category in self.node_cache.keys(): log.debug("Generating UNWIND for category: {}".format(category)) cypher_category = category.replace( self.CATEGORY_DELIMITER, self.CYPHER_CATEGORY_DELIMITER ) query = self.generate_unwind_node_query(cypher_category) log.debug(query) nodes = self.node_cache[category] for x in range(0, len(nodes), batch_size): y = min(x + batch_size, len(nodes)) log.debug(f"Batch {x} - {y}") batch = nodes[x:y] try: self.http_driver.query(query, params={"nodes": batch}) except CypherException as ce: log.error(ce) def _flush_edge_cache(self): self._flush_node_cache() self._write_edge_cache() self.edge_cache.clear() self.edge_count = 0
[docs] def write_edge(self, record) -> None: """ Cache an edge record that is to be written to Neo4j. This method writes a cache of edge records when the total number of records exceeds ``CACHE_SIZE`` Parameters ---------- record: Dict An edge record """ if self.edge_count >= self.CACHE_SIZE: self._flush_edge_cache() # self.validate_edge(data) edge_predicate = record["predicate"] if edge_predicate in self.edge_cache: self.edge_cache[edge_predicate].append(record) else: self.edge_cache[edge_predicate] = [record] self.edge_count += 1
def _write_edge_cache(self) -> None: """ Write cached edge records to Neo4j. """ batch_size = 10000 for predicate in self.edge_cache.keys(): query = self.generate_unwind_edge_query(predicate) log.debug(query) edges = self.edge_cache[predicate] for x in range(0, len(edges), batch_size): y = min(x + batch_size, len(edges)) batch = edges[x:y] log.debug(f"Batch {x} - {y}") try: self.http_driver.query( query, params={"relationship": predicate, "edges": batch} ) except CypherException as ce: log.error(ce)
[docs] def finalize(self) -> None: """ Write any remaining cached node and/or edge records. """ self._write_node_cache() self._write_edge_cache()
[docs] @staticmethod def sanitize_category(category: List) -> List: """ Sanitize category for use in UNWIND cypher clause. This method adds escape characters to each element in category list to ensure the category is processed correctly. Parameters ---------- category: List Category Returns ------- List Sanitized category list """ return [f"`{x}`" for x in category]
[docs] @staticmethod def generate_unwind_node_query(category: str) -> str: """ Generate UNWIND cypher query for saving nodes into Neo4j. There should be a CONSTRAINT in Neo4j for ``self.DEFAULT_NODE_CATEGORY``. The query uses ``self.DEFAULT_NODE_CATEGORY`` as the node label to increase speed for adding nodes. The query also sets label to ``self.DEFAULT_NODE_CATEGORY`` for any node to make sure that the CONSTRAINT applies. Parameters ---------- category: str Node category Returns ------- str The UNWIND cypher query """ query = f""" UNWIND $nodes AS node MERGE (n:`{DEFAULT_NODE_CATEGORY}` {{id: node.id}}) ON CREATE SET n += node, n:{category} ON MATCH SET n += node, n:{category} """ return query
[docs] @staticmethod def generate_unwind_edge_query(edge_predicate: str) -> str: """ Generate UNWIND cypher query for saving edges into Neo4j. Query uses ``self.DEFAULT_NODE_CATEGORY`` to quickly lookup the required subject and object node. Parameters ---------- edge_predicate: str Edge label as string Returns ------- str The UNWIND cypher query """ query = f""" UNWIND $edges AS edge MATCH (s:`{DEFAULT_NODE_CATEGORY}` {{id: edge.subject}}), (o:`{DEFAULT_NODE_CATEGORY}` {{id: edge.object}}) MERGE (s)-[r:`{edge_predicate}`]->(o) SET r += edge """ return query
[docs] def create_constraints(self, categories: Union[set, list]) -> None: """ Create a unique constraint on node 'id' for all ``categories`` in Neo4j. Parameters ---------- categories: Union[set, list] Set of categories """ categories_set = set(categories) categories_set.add(f"`{DEFAULT_NODE_CATEGORY}`") for category in categories_set: if self.CATEGORY_DELIMITER in category: subcategories = category.split(self.CATEGORY_DELIMITER) self.create_constraints(subcategories) else: query = NeoSink.create_constraint_query(category) try: self.http_driver.query(query) self._seen_categories.add(category) except CypherException as ce: log.error(ce)
[docs] @staticmethod def create_constraint_query(category: str) -> str: """ Create a Cypher CONSTRAINT query Parameters ---------- category: str The category to create a constraint on Returns ------- str The Cypher CONSTRAINT query """ query = f"CREATE CONSTRAINT ON (n:{category}) ASSERT n.id IS UNIQUE" return query