From 435274cc50c523b47e1e6b7919a6f40e78b25bf7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florentin=20D=C3=B6rre?=
Date: Wed, 16 Apr 2025 10:54:20 +0200
Subject: [PATCH] WIP use execute query

---
 graphdatascience/graph/graph_cypher_runner.py |  1 +
 .../graph/graph_entity_ops_runner.py          |  1 +
 graphdatascience/graph_data_science.py        |  1 +
 .../query_runner/cypher_graph_constructor.py  |  2 +
 .../query_runner/neo4j_query_runner.py        | 41 +++++++++++++++----
 .../topological_lp_alpha_runner.py            |  4 +-
 .../utils/direct_util_endpoints.py            |  1 +
 graphdatascience/utils/util_proc_runner.py    |  2 +
 .../utils/util_remote_proc_runner.py          |  2 +
 9 files changed, 46 insertions(+), 9 deletions(-)

diff --git a/graphdatascience/graph/graph_cypher_runner.py b/graphdatascience/graph/graph_cypher_runner.py
index ef55a64ef..19d736f23 100644
--- a/graphdatascience/graph/graph_cypher_runner.py
+++ b/graphdatascience/graph/graph_cypher_runner.py
@@ -45,6 +45,7 @@ def project(
 
         GraphCypherRunner._verify_query_ends_with_return_clause(self._namespace, query)
 
+        # TODO use execute_query here
         result: Optional[dict[str, Any]] = self._query_runner.run_cypher(query, params, database, False).squeeze()
 
         if not result:
diff --git a/graphdatascience/graph/graph_entity_ops_runner.py b/graphdatascience/graph/graph_entity_ops_runner.py
index d3f688489..d8f1e6aa2 100644
--- a/graphdatascience/graph/graph_entity_ops_runner.py
+++ b/graphdatascience/graph/graph_entity_ops_runner.py
@@ -162,6 +162,7 @@ def _process_result(
         )
 
         unique_node_ids = result["nodeId"].drop_duplicates().tolist()
+        # TODO use execute_query
         db_properties_df = query_runner.run_cypher(
             GraphNodePropertiesRunner._build_query(db_node_properties), {"ids": unique_node_ids}
         )
diff --git a/graphdatascience/graph_data_science.py b/graphdatascience/graph_data_science.py
index 1052f57af..fc92a892a 100644
--- a/graphdatascience/graph_data_science.py
+++ b/graphdatascience/graph_data_science.py
@@ -218,6 +218,7 @@ def run_cypher(
         if isinstance(self._query_runner, ArrowQueryRunner):
             qr = self._query_runner.fallback_query_runner()
 
+        # not using qr.execute_query as we don't know if it can be retried
         return qr.run_cypher(query, params, database, False)
 
     def driver_config(self) -> dict[str, Any]:
diff --git a/graphdatascience/query_runner/cypher_graph_constructor.py b/graphdatascience/query_runner/cypher_graph_constructor.py
index 14b62ff55..5e55c22e9 100644
--- a/graphdatascience/query_runner/cypher_graph_constructor.py
+++ b/graphdatascience/query_runner/cypher_graph_constructor.py
@@ -104,6 +104,7 @@ def graph_construct_error_multidf(element: str) -> str:
 
     def _should_warn_about_arrow_missing(self) -> bool:
         try:
+            # TODO use execute_query
             license: str = self._query_runner.run_cypher(
                 "CALL gds.debug.sysInfo() YIELD key, value WHERE key = 'gdsEdition' RETURN value", custom_error=False
             ).squeeze()
@@ -210,6 +211,7 @@ def run(self, node_dfs: list[DataFrame], relationship_dfs: list[DataFrame]) -> N
             "undirectedRelationshipTypes": self._undirected_relationship_types,
         }
 
+        # TODO use execute_query here
         self._query_runner.run_cypher(
             query,
             {
diff --git a/graphdatascience/query_runner/neo4j_query_runner.py b/graphdatascience/query_runner/neo4j_query_runner.py
index 4a8f552d5..f0e49430a 100644
--- a/graphdatascience/query_runner/neo4j_query_runner.py
+++ b/graphdatascience/query_runner/neo4j_query_runner.py
@@ -121,8 +121,10 @@ def __init__(
     def __run_cypher_simplified_for_query_progress_logger(self, query: str, database: Optional[str]) -> DataFrame:
         # progress logging should not retry a lot as it periodically fetches the latest progress anyway
         connectivity_retry_config = Neo4jQueryRunner.ConnectivityRetriesConfig(max_retries=2)
+        # not using execute_query as failing is okay
         return self.run_cypher(query=query, database=database, connectivity_retry_config=connectivity_retry_config)
 
+    # only use for user-defined queries
     def run_cypher(
         self,
         query: str,
@@ -166,12 +168,33 @@ def run_cypher(
 
         return df
 
-    def call_function(self, endpoint: str, params: Optional[CallParameters] = None) -> Any:
+    # better retry mechanism than run_cypher
+    def execute_query(
+        self,
+        query: str,
+        params: Optional[dict[str, Any]] = None,
+        database: Optional[str] = None,
+        custom_error: bool = True,
+        connectivity_retry_config: Optional[ConnectivityRetriesConfig] = None,
+    ) -> DataFrame:
+        if not database:
+            database = self._database
+
+        if self._NEO4J_DRIVER_VERSION < SemanticVersion(5, 5, 0):
+            return self.run_cypher(query, params, database, custom_error, connectivity_retry_config)
+
+        try:
+            return self._driver.execute_query(
+                query_=query,
+                parameters_=params,
+                database_=database,
+                result_transformer_=neo4j.Result.to_df,
+                routing_=neo4j.RoutingControl.READ,
+            )
+        except Exception as e:
+            if custom_error:
+                Neo4jQueryRunner.handle_driver_exception(self._driver, e)
+            else:
+                raise e
+
+    def call_function(self, endpoint: str, params: Optional[CallParameters] = None, custom_error: bool = True) -> Any:
         if params is None:
             params = CallParameters()
 
         query = f"RETURN {endpoint}({params.placeholder_str()})"
-        return self.run_cypher(query, params).squeeze()
+        return self.execute_query(query, params, custom_error=custom_error).squeeze()
 
     def call_procedure(
         self,
@@ -189,7 +212,7 @@ def call_procedure(
         query = f"CALL {endpoint}({params.placeholder_str()}){yields_clause}"
 
         def run_cypher_query() -> DataFrame:
-            return self.run_cypher(query, params, database, custom_error)
+            return self.execute_query(query, params, database, custom_error)
 
         job_id = None if not params else params.get_job_id()
         if self._resolve_show_progress(logging) and job_id:
@@ -205,7 +228,7 @@ def server_version(self) -> ServerVersion:
             return self._server_version
 
         try:
-            server_version_string = self.run_cypher("RETURN gds.version()", custom_error=False).squeeze()
+            server_version_string = self.call_function("gds.version", custom_error=False)
             server_version = ServerVersion.from_string(server_version_string)
             self._server_version = server_version
             return server_version
@@ -280,7 +303,7 @@ def set_show_progress(self, show_progress: bool) -> None:
         self._show_progress = show_progress
 
     @staticmethod
-    def handle_driver_exception(session: neo4j.Session, e: Exception) -> None:
+    def handle_driver_exception(cypher_executor: Union[neo4j.Session, neo4j.Driver], e: Exception) -> None:
         reg_gds_hit = re.search(
             r"There is no procedure with the name `(gds(?:\.\w+)+)` registered for this database instance",
             str(e),
@@ -290,8 +313,12 @@ def handle_driver_exception(session: neo4j.Session, e: Exception) -> None:
 
         requested_endpoint = reg_gds_hit.group(1)
 
-        list_result = session.run("CALL gds.list() YIELD name")
-        all_endpoints = list_result.to_df()["name"].tolist()
+        if isinstance(cypher_executor, neo4j.Session):
+            list_result = cypher_executor.run("CALL gds.list() YIELD name")
+            all_endpoints = list_result.to_df()["name"].tolist()
+        else:
+            result = cypher_executor.execute_query(
+                "CALL gds.list() YIELD name", result_transformer_=neo4j.Result.to_df
+            )
+            all_endpoints = result["name"].tolist()
+
         raise SyntaxError(generate_suggestive_error_message(requested_endpoint, all_endpoints)) from e
diff --git a/graphdatascience/topological_lp/topological_lp_alpha_runner.py b/graphdatascience/topological_lp/topological_lp_alpha_runner.py
index a21557f04..1e0b65f21 100644
--- a/graphdatascience/topological_lp/topological_lp_alpha_runner.py
+++ b/graphdatascience/topological_lp/topological_lp_alpha_runner.py
@@ -16,7 +16,7 @@ def _run_standard_function(self, node1: int, node2: int, config: dict[str, Any])
         RETURN {self._namespace}(n1, n2, $config) AS score
         """
         params = {"config": config}
-
+        # TODO use execute_query
         return self._query_runner.run_cypher(query, params)["score"].squeeze()  # type: ignore
 
     def adamicAdar(self, node1: int, node2: int, **config: Any) -> float:
@@ -45,7 +45,7 @@ def sameCommunity(self, node1: int, node2: int, communityProperty: Optional[str]
         MATCH (n2) WHERE id(n2) = {node2}
         RETURN {self._namespace}(n1, n2{community_property}) AS score
         """
-
+        # TODO use execute_query
         return self._query_runner.run_cypher(query)["score"].squeeze()  # type: ignore
 
     def totalNeighbors(self, node1: int, node2: int, **config: Any) -> float:
diff --git a/graphdatascience/utils/direct_util_endpoints.py b/graphdatascience/utils/direct_util_endpoints.py
index 581a17038..04cd4c42c 100644
--- a/graphdatascience/utils/direct_util_endpoints.py
+++ b/graphdatascience/utils/direct_util_endpoints.py
@@ -48,6 +48,7 @@ def find_node_id(self, labels: list[str] = [], properties: dict[str, Any] = {})
         else:
             query = "MATCH (n) RETURN id(n) AS id"
 
+        # TODO use execute_query
         node_match = self._query_runner.run_cypher(query, custom_error=False)
 
         if len(node_match) != 1:
diff --git a/graphdatascience/utils/util_proc_runner.py b/graphdatascience/utils/util_proc_runner.py
index a435e2265..faf10338f 100644
--- a/graphdatascience/utils/util_proc_runner.py
+++ b/graphdatascience/utils/util_proc_runner.py
@@ -18,6 +18,7 @@ def asNode(self, node_id: int) -> Any:
         """
         self._namespace += ".asNode"
 
+        # TODO use call_function?
         result = self._query_runner.run_cypher(f"RETURN {self._namespace}({node_id}) AS node")
 
         return result.iat[0, 0]
@@ -34,6 +35,7 @@ def asNodes(self, node_ids: list[int]) -> list[Any]:
         """
         self._namespace += ".asNodes"
 
+        # TODO use call_function?
         result = self._query_runner.run_cypher(f"RETURN {self._namespace}({node_ids}) AS nodes")
 
         return result.iat[0, 0]  # type: ignore
diff --git a/graphdatascience/utils/util_remote_proc_runner.py b/graphdatascience/utils/util_remote_proc_runner.py
index b51438dcc..cabdda772 100644
--- a/graphdatascience/utils/util_remote_proc_runner.py
+++ b/graphdatascience/utils/util_remote_proc_runner.py
@@ -24,6 +24,7 @@ def asNode(self, node_id: int) -> Any:
         query = "MATCH (n) WHERE id(n) = $nodeId RETURN n"
         params = {"nodeId": node_id}
 
+        # TODO use execute_query
         return self._query_runner.run_cypher(query=query, params=params).squeeze()
 
     @filter_id_func_deprecation_warning()
@@ -41,6 +42,7 @@ def asNodes(self, node_ids: list[int]) -> list[Any]:
         query = "MATCH (n) WHERE id(n) IN $nodeIds RETURN collect(n)"
         params = {"nodeIds": node_ids}
 
+        # TODO use execute_query
         return self._query_runner.run_cypher(query=query, params=params).squeeze()  # type: ignore
 
     @property
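
Note (not part of the diff): the "better retry mechanism than run_cypher" that the new Neo4jQueryRunner.execute_query relies on comes from the driver-level neo4j.Driver.execute_query API, available from Python driver 5.5. It runs the query inside a managed transaction function, so transient failures are retried by the driver itself, and its configuration keywords carry a trailing underscore (e.g. database_) so they are not mistaken for Cypher query parameters. A minimal standalone sketch, assuming a local server with GDS installed and placeholder URI/credentials:

    # Standalone sketch (assumptions: local bolt URI, placeholder credentials,
    # GDS installed so that gds.version() exists). Requires pandas for to_df.
    import neo4j

    driver = neo4j.GraphDatabase.driver("neo4j://localhost:7687", auth=("neo4j", "password"))

    # Driver.execute_query (driver >= 5.5) wraps the query in a managed transaction
    # function, so transient errors are retried automatically, unlike a bare session.run().
    # Configuration keywords end with "_"; plain keywords would be forwarded as query parameters.
    df = driver.execute_query(
        "RETURN gds.version() AS version",
        database_="neo4j",
        routing_=neo4j.RoutingControl.READ,
        result_transformer_=neo4j.Result.to_df,  # eagerly convert the result to a pandas DataFrame
    )
    print(df["version"].squeeze())

    driver.close()

Because the retry handling lives in the driver on this path, the connectivity_retry_config argument is only forwarded on the run_cypher fallback used for driver versions older than 5.5.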