Skip to content

WIP use execute query #868

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions graphdatascience/graph/graph_cypher_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def project(

GraphCypherRunner._verify_query_ends_with_return_clause(self._namespace, query)

# TODO use execute_query here
result: Optional[dict[str, Any]] = self._query_runner.run_cypher(query, params, database, False).squeeze()

if not result:
Expand Down
1 change: 1 addition & 0 deletions graphdatascience/graph/graph_entity_ops_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def _process_result(
)

unique_node_ids = result["nodeId"].drop_duplicates().tolist()
# TODO use execute_query
db_properties_df = query_runner.run_cypher(
GraphNodePropertiesRunner._build_query(db_node_properties), {"ids": unique_node_ids}
)
Expand Down
1 change: 1 addition & 0 deletions graphdatascience/graph_data_science.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def run_cypher(
if isinstance(self._query_runner, ArrowQueryRunner):
qr = self._query_runner.fallback_query_runner()

# not using qr.execute_query as we don't know if it can be retried
return qr.run_cypher(query, params, database, False)

def driver_config(self) -> dict[str, Any]:
Expand Down
2 changes: 2 additions & 0 deletions graphdatascience/query_runner/cypher_graph_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def graph_construct_error_multidf(element: str) -> str:

def _should_warn_about_arrow_missing(self) -> bool:
try:
# TODO use execute_query
license: str = self._query_runner.run_cypher(
"CALL gds.debug.sysInfo() YIELD key, value WHERE key = 'gdsEdition' RETURN value", custom_error=False
).squeeze()
Expand Down Expand Up @@ -210,6 +211,7 @@ def run(self, node_dfs: list[DataFrame], relationship_dfs: list[DataFrame]) -> N
"undirectedRelationshipTypes": self._undirected_relationship_types,
}

# TODO use execute_query here
self._query_runner.run_cypher(
query,
{
Expand Down
41 changes: 34 additions & 7 deletions graphdatascience/query_runner/neo4j_query_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ def __init__(
def __run_cypher_simplified_for_query_progress_logger(self, query: str, database: Optional[str]) -> DataFrame:
# progress logging should not retry a lot as it periodically fetches the latest progress anyway
connectivity_retry_config = Neo4jQueryRunner.ConnectivityRetriesConfig(max_retries=2)
# not using execute_query as failing is okay
return self.run_cypher(query=query, database=database, connectivity_retry_config=connectivity_retry_config)

# only use for user defined queries
def run_cypher(
self,
query: str,
Expand Down Expand Up @@ -166,12 +168,33 @@ def run_cypher(

return df

def call_function(self, endpoint: str, params: Optional[CallParameters] = None) -> Any:
# better retry mechanism than run_cypher
def execute_query(
    self,
    query: str,
    params: Optional[dict[str, Any]] = None,
    database: Optional[str] = None,
    custom_error: bool = True,
    connectivity_retry_config: Optional[ConnectivityRetriesConfig] = None,
) -> DataFrame:
    """Run a Cypher query through the driver-level ``execute_query`` API.

    Args:
        query: The Cypher query text.
        params: Query parameters, or ``None`` for a parameterless query.
        database: Database to run against; falls back to ``self._database``
            when empty or ``None``.
        custom_error: When true, a failure is first passed to
            ``Neo4jQueryRunner.handle_driver_exception`` before being
            re-raised.
        connectivity_retry_config: Only forwarded to ``run_cypher`` on the
            old-driver fallback path; it is not used when the driver's
            ``execute_query`` is called.

    Returns:
        The query result converted with ``neo4j.Result.to_df``.

    Raises:
        Exception: Whatever the driver raised, or the error raised by
            ``handle_driver_exception`` when ``custom_error`` is true.
    """
    if not database:
        database = self._database

    # Drivers older than 5.5 take the run_cypher path instead.
    if self._NEO4J_DRIVER_VERSION < SemanticVersion(5, 5, 0):
        return self.run_cypher(query, params, database, custom_error, connectivity_retry_config)

    try:
        # The driver's own options all end in an underscore; any other keyword
        # is sent as an extra query parameter. The database must therefore be
        # passed as ``database_``, otherwise it is silently ignored.
        # No ``routing_`` is passed so the driver default applies: this method
        # also serves call_procedure, whose endpoints are not known to be
        # read-only, so forcing READ routing here would be unsafe.
        return self._driver.execute_query(
            query_=query,
            parameters_=params,
            database_=database,
            result_transformer_=neo4j.Result.to_df,
        )
    except Exception as e:
        if custom_error:
            Neo4jQueryRunner.handle_driver_exception(self._driver, e)
        # Reached when custom_error is false, or if the handler returned
        # without raising; never swallow the failure and return None.
        raise

def call_function(self, endpoint: str, params: Optional[CallParameters] = None, custom_error: bool = True) -> Any:
if params is None:
params = CallParameters()
query = f"RETURN {endpoint}({params.placeholder_str()})"

return self.run_cypher(query, params).squeeze()
return self.execute_query(query, params, custom_error=custom_error).squeeze()

def call_procedure(
self,
Expand All @@ -189,7 +212,7 @@ def call_procedure(
query = f"CALL {endpoint}({params.placeholder_str()}){yields_clause}"

def run_cypher_query() -> DataFrame:
return self.run_cypher(query, params, database, custom_error)
return self.execute_query(query, params, database, custom_error)

job_id = None if not params else params.get_job_id()
if self._resolve_show_progress(logging) and job_id:
Expand All @@ -205,7 +228,7 @@ def server_version(self) -> ServerVersion:
return self._server_version

try:
server_version_string = self.run_cypher("RETURN gds.version()", custom_error=False).squeeze()
server_version_string = self.call_function("gds.version", custom_error=False).squeeze()
server_version = ServerVersion.from_string(server_version_string)
self._server_version = server_version
return server_version
Expand Down Expand Up @@ -280,7 +303,7 @@ def set_show_progress(self, show_progress: bool) -> None:
self._show_progress = show_progress

@staticmethod
def handle_driver_exception(session: neo4j.Session, e: Exception) -> None:
def handle_driver_exception(cypher_executor: Union[neo4j.Session, neo4j.Driver], e: Exception) -> None:
reg_gds_hit = re.search(
r"There is no procedure with the name `(gds(?:\.\w+)+)` registered for this database instance",
str(e),
Expand All @@ -290,8 +313,12 @@ def handle_driver_exception(session: neo4j.Session, e: Exception) -> None:

requested_endpoint = reg_gds_hit.group(1)

list_result = session.run("CALL gds.list() YIELD name")
all_endpoints = list_result.to_df()["name"].tolist()
if isinstance(cypher_executor, neo4j.Session):
list_result = cypher_executor.run("CALL gds.list() YIELD name")
all_endpoints = list_result.to_df()["name"].tolist()
else:
result = cypher_executor.execute_query("CALL gds.list() YIELD name").


raise SyntaxError(generate_suggestive_error_message(requested_endpoint, all_endpoints)) from e

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _run_standard_function(self, node1: int, node2: int, config: dict[str, Any])
RETURN {self._namespace}(n1, n2, $config) AS score
"""
params = {"config": config}

# TODO use execute query
return self._query_runner.run_cypher(query, params)["score"].squeeze() # type: ignore

def adamicAdar(self, node1: int, node2: int, **config: Any) -> float:
Expand Down Expand Up @@ -45,7 +45,7 @@ def sameCommunity(self, node1: int, node2: int, communityProperty: Optional[str]
MATCH (n2) WHERE id(n2) = {node2}
RETURN {self._namespace}(n1, n2{community_property}) AS score
"""

# TODO use execute query
return self._query_runner.run_cypher(query)["score"].squeeze() # type: ignore

def totalNeighbors(self, node1: int, node2: int, **config: Any) -> float:
Expand Down
1 change: 1 addition & 0 deletions graphdatascience/utils/direct_util_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def find_node_id(self, labels: list[str] = [], properties: dict[str, Any] = {})
else:
query = "MATCH (n) RETURN id(n) AS id"

# TODO use execute query
node_match = self._query_runner.run_cypher(query, custom_error=False)

if len(node_match) != 1:
Expand Down
2 changes: 2 additions & 0 deletions graphdatascience/utils/util_proc_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def asNode(self, node_id: int) -> Any:

"""
self._namespace += ".asNode"
# TODO use call_function?
result = self._query_runner.run_cypher(f"RETURN {self._namespace}({node_id}) AS node")

return result.iat[0, 0]
Expand All @@ -34,6 +35,7 @@ def asNodes(self, node_ids: list[int]) -> list[Any]:

"""
self._namespace += ".asNodes"
# TODO use call_function?
result = self._query_runner.run_cypher(f"RETURN {self._namespace}({node_ids}) AS nodes")

return result.iat[0, 0] # type: ignore
Expand Down
2 changes: 2 additions & 0 deletions graphdatascience/utils/util_remote_proc_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def asNode(self, node_id: int) -> Any:
query = "MATCH (n) WHERE id(n) = $nodeId RETURN n"
params = {"nodeId": node_id}

# TODO use execute_query
return self._query_runner.run_cypher(query=query, params=params).squeeze()

@filter_id_func_deprecation_warning()
Expand All @@ -41,6 +42,7 @@ def asNodes(self, node_ids: list[int]) -> list[Any]:
query = "MATCH (n) WHERE id(n) IN $nodeIds RETURN collect(n)"
params = {"nodeIds": node_ids}

# TODO use execute_query
return self._query_runner.run_cypher(query=query, params=params).squeeze() # type: ignore

@property
Expand Down