Skip to content

Commit fb83fac

Browse files
authored
Refactor(optimizer)!: remove recursion from simplify (#4988)
* Refactor: remove recursion from optimizer.simplify * Test Oracle edge case * Update comment * Refactor while_changing to avoid recursive hash computation * Add stress test * Reduce number of AST walks in while_changing * Refactor expression.dfs into manual stack logic
1 parent aae9aa8 commit fb83fac

File tree

5 files changed

+113
-74
lines changed

5 files changed

+113
-74
lines changed

sqlglot/executor/python.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sqlglot.executor.context import Context
1010
from sqlglot.executor.env import ENV
1111
from sqlglot.executor.table import RowReader, Table
12-
from sqlglot.helper import csv_reader, ensure_list, subclasses
12+
from sqlglot.helper import csv_reader, subclasses
1313

1414

1515
class PythonExecutor:
@@ -370,8 +370,8 @@ def _rename(self, e):
370370
return self.func(e.key, *values)
371371

372372
if isinstance(e, exp.Func) and e.is_var_len_args:
373-
*head, tail = values
374-
return self.func(e.key, *head, *ensure_list(tail))
373+
args = itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in values)
374+
return self.func(e.key, *args)
375375

376376
return self.func(e.key, *values)
377377
except Exception as ex:

sqlglot/helper.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,16 +211,31 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) ->
211211
Returns:
212212
The transformed expression.
213213
"""
214+
end_hash: t.Optional[int] = None
215+
214216
while True:
215-
for n in reversed(tuple(expression.walk())):
216-
n._hash = hash(n)
217+
# No need to walk the AST– we've already cached the hashes in the previous iteration
218+
if end_hash is None:
219+
for n in reversed(tuple(expression.walk())):
220+
n._hash = hash(n)
217221

218-
start = hash(expression)
222+
start_hash = hash(expression)
219223
expression = func(expression)
220224

221-
for n in expression.walk():
225+
expression_nodes = tuple(expression.walk())
226+
227+
# Uncache previous caches so we can recompute them
228+
for n in reversed(expression_nodes):
222229
n._hash = None
223-
if start == hash(expression):
230+
n._hash = hash(n)
231+
232+
end_hash = hash(expression)
233+
234+
if start_hash == end_hash:
235+
# ... and reset the hash so we don't risk it becoming out of date if a mutation happens
236+
for n in expression_nodes:
237+
n._hash = None
238+
224239
break
225240

226241
return expression

sqlglot/optimizer/simplify.py

Lines changed: 79 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -61,74 +61,87 @@ def simplify(
6161

6262
dialect = Dialect.get_or_raise(dialect)
6363

64-
def _simplify(expression, root=True):
65-
if (
66-
max_depth
67-
and isinstance(expression, exp.Connector)
68-
and not isinstance(expression.parent, exp.Connector)
69-
):
70-
depth = connector_depth(expression)
71-
if depth > max_depth:
72-
logger.info(
73-
f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
74-
)
75-
return expression
64+
def _simplify(expression):
65+
pre_transformation_stack = [expression]
66+
post_transformation_stack = []
7667

77-
if expression.meta.get(FINAL):
78-
return expression
68+
while pre_transformation_stack:
69+
node = pre_transformation_stack.pop()
70+
71+
if node.meta.get(FINAL):
72+
continue
73+
74+
# group by expressions cannot be simplified, for example
75+
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
76+
# the projection must exactly match the group by key
77+
group = node.args.get("group")
78+
79+
if group and hasattr(node, "selects"):
80+
groups = set(group.expressions)
81+
group.meta[FINAL] = True
82+
83+
for s in node.selects:
84+
for n in s.walk():
85+
if n in groups:
86+
s.meta[FINAL] = True
87+
break
88+
89+
having = node.args.get("having")
90+
if having:
91+
for n in having.walk():
92+
if n in groups:
93+
having.meta[FINAL] = True
94+
break
95+
96+
parent = node.parent
97+
root = node is expression
98+
99+
new_node = rewrite_between(node)
100+
new_node = uniq_sort(new_node, root)
101+
new_node = absorb_and_eliminate(new_node, root)
102+
new_node = simplify_concat(new_node)
103+
new_node = simplify_conditionals(new_node)
104+
105+
if constant_propagation:
106+
new_node = propagate_constants(new_node, root)
107+
108+
if new_node is not node:
109+
node.replace(new_node)
110+
111+
pre_transformation_stack.extend(
112+
n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
113+
)
114+
post_transformation_stack.append((new_node, parent))
115+
116+
while post_transformation_stack:
117+
node, parent = post_transformation_stack.pop()
118+
root = node is expression
119+
120+
# Resets parent, arg_key, index pointers– this is needed because some of the
121+
# previous transformations mutate the AST, leading to an inconsistent state
122+
for k, v in tuple(node.args.items()):
123+
node.set(k, v)
124+
125+
# Post-order transformations
126+
new_node = simplify_not(node)
127+
new_node = flatten(new_node)
128+
new_node = simplify_connectors(new_node, root)
129+
new_node = remove_complements(new_node, root)
130+
new_node = simplify_coalesce(new_node, dialect)
131+
132+
new_node.parent = parent
133+
134+
new_node = simplify_literals(new_node, root)
135+
new_node = simplify_equality(new_node)
136+
new_node = simplify_parens(new_node)
137+
new_node = simplify_datetrunc(new_node, dialect)
138+
new_node = sort_comparison(new_node)
139+
new_node = simplify_startswith(new_node)
140+
141+
if new_node is not node:
142+
node.replace(new_node)
79143

80-
# group by expressions cannot be simplified, for example
81-
# select x + 1 + 1 FROM y GROUP BY x + 1 + 1
82-
# the projection must exactly match the group by key
83-
group = expression.args.get("group")
84-
85-
if group and hasattr(expression, "selects"):
86-
groups = set(group.expressions)
87-
group.meta[FINAL] = True
88-
89-
for e in expression.selects:
90-
for node in e.walk():
91-
if node in groups:
92-
e.meta[FINAL] = True
93-
break
94-
95-
having = expression.args.get("having")
96-
if having:
97-
for node in having.walk():
98-
if node in groups:
99-
having.meta[FINAL] = True
100-
break
101-
102-
# Pre-order transformations
103-
node = expression
104-
node = rewrite_between(node)
105-
node = uniq_sort(node, root)
106-
node = absorb_and_eliminate(node, root)
107-
node = simplify_concat(node)
108-
node = simplify_conditionals(node)
109-
110-
if constant_propagation:
111-
node = propagate_constants(node, root)
112-
113-
exp.replace_children(node, lambda e: _simplify(e, False))
114-
115-
# Post-order transformations
116-
node = simplify_not(node)
117-
node = flatten(node)
118-
node = simplify_connectors(node, root)
119-
node = remove_complements(node, root)
120-
node = simplify_coalesce(node, dialect)
121-
node.parent = expression.parent
122-
node = simplify_literals(node, root)
123-
node = simplify_equality(node)
124-
node = simplify_parens(node)
125-
node = simplify_datetrunc(node, dialect)
126-
node = sort_comparison(node)
127-
node = simplify_startswith(node)
128-
129-
if root:
130-
expression.replace(node)
131-
return node
144+
return new_node
132145

133146
expression = while_changing(expression, _simplify)
134147
remove_where_true(expression)

tests/test_executor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,12 @@ def test_scalar_functions(self):
728728
result = execute(f"SELECT {sql}")
729729
self.assertEqual(result.rows, [(expected,)])
730730

731+
result = execute(
732+
"WITH t AS (SELECT 'a' AS c1, 'b' AS c2) SELECT NVL(c1, c2) FROM t",
733+
dialect="oracle",
734+
)
735+
self.assertEqual(result.rows, [("a",)])
736+
731737
def test_case_sensitivity(self):
732738
result = execute("SELECT A AS A FROM X", tables={"x": [{"a": 1}]})
733739
self.assertEqual(result.columns, ("a",))

tests/test_optimizer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,11 @@ def test_pushdown_projection(self):
538538
def test_simplify(self):
539539
self.check_file("simplify", simplify)
540540

541+
# Stress test with huge unios
542+
union_sql = "SELECT 1 UNION ALL " * 1000 + "SELECT 1"
543+
expression = parse_one(union_sql)
544+
self.assertEqual(simplify(expression).sql(), union_sql)
545+
541546
# Ensure simplify mutates the AST properly
542547
expression = parse_one("SELECT 1 + 2")
543548
simplify(expression.selects[0])

0 commit comments

Comments
 (0)