Skip to content

Refactor(optimizer)!: remove recursion from simplify #4988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions sqlglot/executor/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sqlglot.executor.context import Context
from sqlglot.executor.env import ENV
from sqlglot.executor.table import RowReader, Table
from sqlglot.helper import csv_reader, ensure_list, subclasses
from sqlglot.helper import csv_reader, subclasses


class PythonExecutor:
Expand Down Expand Up @@ -370,8 +370,8 @@ def _rename(self, e):
return self.func(e.key, *values)

if isinstance(e, exp.Func) and e.is_var_len_args:
*head, tail = values
return self.func(e.key, *head, *ensure_list(tail))
args = itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in values)
return self.func(e.key, *args)

return self.func(e.key, *values)
except Exception as ex:
Expand Down
25 changes: 20 additions & 5 deletions sqlglot/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,16 +211,31 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) ->
Returns:
The transformed expression.
"""
end_hash: t.Optional[int] = None

while True:
for n in reversed(tuple(expression.walk())):
n._hash = hash(n)
# No need to walk the AST– we've already cached the hashes in the previous iteration
if end_hash is None:
for n in reversed(tuple(expression.walk())):
n._hash = hash(n)

start = hash(expression)
start_hash = hash(expression)
expression = func(expression)

for n in expression.walk():
expression_nodes = tuple(expression.walk())

# Uncache previous caches so we can recompute them
for n in reversed(expression_nodes):
n._hash = None
if start == hash(expression):
n._hash = hash(n)

end_hash = hash(expression)

if start_hash == end_hash:
# ... and reset the hash so we don't risk it becoming out of date if a mutation happens
for n in expression_nodes:
n._hash = None

break

return expression
Expand Down
145 changes: 79 additions & 66 deletions sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,74 +61,87 @@ def simplify(

dialect = Dialect.get_or_raise(dialect)

def _simplify(expression, root=True):
if (
max_depth
and isinstance(expression, exp.Connector)
and not isinstance(expression.parent, exp.Connector)
):
depth = connector_depth(expression)
if depth > max_depth:
logger.info(
f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
)
return expression
def _simplify(expression):
pre_transformation_stack = [expression]
post_transformation_stack = []

if expression.meta.get(FINAL):
return expression
while pre_transformation_stack:
node = pre_transformation_stack.pop()

if node.meta.get(FINAL):
continue

# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
group = node.args.get("group")

if group and hasattr(node, "selects"):
groups = set(group.expressions)
group.meta[FINAL] = True

for s in node.selects:
for n in s.walk():
if n in groups:
s.meta[FINAL] = True
break

having = node.args.get("having")
if having:
for n in having.walk():
if n in groups:
having.meta[FINAL] = True
break

parent = node.parent
root = node is expression

new_node = rewrite_between(node)
new_node = uniq_sort(new_node, root)
new_node = absorb_and_eliminate(new_node, root)
new_node = simplify_concat(new_node)
new_node = simplify_conditionals(new_node)

if constant_propagation:
new_node = propagate_constants(new_node, root)

if new_node is not node:
node.replace(new_node)

pre_transformation_stack.extend(
n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
)
post_transformation_stack.append((new_node, parent))

while post_transformation_stack:
node, parent = post_transformation_stack.pop()
root = node is expression

# Resets parent, arg_key, index pointers– this is needed because some of the
# previous transformations mutate the AST, leading to an inconsistent state
for k, v in tuple(node.args.items()):
node.set(k, v)

# Post-order transformations
new_node = simplify_not(node)
new_node = flatten(new_node)
new_node = simplify_connectors(new_node, root)
new_node = remove_complements(new_node, root)
new_node = simplify_coalesce(new_node, dialect)

new_node.parent = parent

new_node = simplify_literals(new_node, root)
new_node = simplify_equality(new_node)
new_node = simplify_parens(new_node)
new_node = simplify_datetrunc(new_node, dialect)
new_node = sort_comparison(new_node)
new_node = simplify_startswith(new_node)

if new_node is not node:
node.replace(new_node)

# group by expressions cannot be simplified, for example
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
# the projection must exactly match the group by key
group = expression.args.get("group")

if group and hasattr(expression, "selects"):
groups = set(group.expressions)
group.meta[FINAL] = True

for e in expression.selects:
for node in e.walk():
if node in groups:
e.meta[FINAL] = True
break

having = expression.args.get("having")
if having:
for node in having.walk():
if node in groups:
having.meta[FINAL] = True
break

# Pre-order transformations
node = expression
node = rewrite_between(node)
node = uniq_sort(node, root)
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)
node = simplify_conditionals(node)

if constant_propagation:
node = propagate_constants(node, root)

exp.replace_children(node, lambda e: _simplify(e, False))

# Post-order transformations
node = simplify_not(node)
node = flatten(node)
node = simplify_connectors(node, root)
node = remove_complements(node, root)
node = simplify_coalesce(node, dialect)
node.parent = expression.parent
node = simplify_literals(node, root)
node = simplify_equality(node)
node = simplify_parens(node)
node = simplify_datetrunc(node, dialect)
node = sort_comparison(node)
node = simplify_startswith(node)

if root:
expression.replace(node)
return node
return new_node

expression = while_changing(expression, _simplify)
remove_where_true(expression)
Expand Down
6 changes: 6 additions & 0 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,12 @@ def test_scalar_functions(self):
result = execute(f"SELECT {sql}")
self.assertEqual(result.rows, [(expected,)])

result = execute(
"WITH t AS (SELECT 'a' AS c1, 'b' AS c2) SELECT NVL(c1, c2) FROM t",
dialect="oracle",
)
self.assertEqual(result.rows, [("a",)])

def test_case_sensitivity(self):
result = execute("SELECT A AS A FROM X", tables={"x": [{"a": 1}]})
self.assertEqual(result.columns, ("a",))
Expand Down
5 changes: 5 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,11 @@ def test_pushdown_projection(self):
def test_simplify(self):
self.check_file("simplify", simplify)

# Stress test with huge unios
union_sql = "SELECT 1 UNION ALL " * 1000 + "SELECT 1"
expression = parse_one(union_sql)
self.assertEqual(simplify(expression).sql(), union_sql)

# Ensure simplify mutates the AST properly
expression = parse_one("SELECT 1 + 2")
simplify(expression.selects[0])
Expand Down