
Commit 82fd7bb

kelvinjian-db authored and cloud-fan committed
[SPARK-48771][SQL] Speed up LogicalPlanIntegrity.validateExprIdUniqueness for large query plans
### What changes were proposed in this pull request?

This PR rewrites `LogicalPlanIntegrity.hasUniqueExprIdsForOutput` to traverse the query plan only once and to avoid expensive Scala collections operations like `.flatten`, `.groupBy`, and `.distinct`.

### Why are the changes needed?

Speeds up query compilation when plan validation is enabled.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Made sure existing UTs pass.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47170 from kelvinjian-db/SPARK-48771-speed-up.

Authored-by: Kelvin Jiang <kelvin.jiang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
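As a rough illustration of the optimization described above, here is a simplified, REPL-style Scala sketch (not the actual Catalyst code; `Node`, `slowCheck`, `fastCheck`, and the toy `(exprId, dataType)` pairs are invented for this example). The old shape builds intermediate collections with `.flatten`/`.groupBy`/`.distinct`; the new shape accumulates into one mutable map during a single traversal:

```scala
import scala.collection.mutable

// Toy stand-in for a plan node: `output` is a list of (exprId, dataType) pairs.
case class Node(output: Seq[(Long, String)], children: Seq[Node] = Nil)

// Old shape: build intermediate collections, then group and de-duplicate.
def slowCheck(root: Node): Option[Long] = {
  def collectAll(n: Node): Seq[Node] = n +: n.children.flatMap(collectAll)
  val pairs = collectAll(root).map(_.output).flatten
  pairs.groupBy(_._1).values.map(_.distinct).collectFirst {
    case group if group.length > 1 => group.head._1
  }
}

// New shape: a single traversal that accumulates into one mutable map.
def fastCheck(root: Node): Option[Long] = {
  val typesById = mutable.HashMap.empty[Long, mutable.HashSet[String]]
  def visit(n: Node): Unit = {
    n.output.foreach { case (id, dt) =>
      typesById.getOrElseUpdate(id, mutable.HashSet.empty[String]) += dt
    }
    n.children.foreach(visit)
  }
  visit(root)
  typesById.collectFirst { case (id, types) if types.size > 1 => id }
}

// exprId 1 appears with two different types, so both checks report it.
val plan = Node(Seq(1L -> "int"), Seq(Node(Seq(1L -> "string"))))
assert(slowCheck(plan) == Some(1L) && fastCheck(plan) == Some(1L))
```

Both versions return the same answer; the second avoids materializing the flattened pair list and the grouped map, which is what matters for very large plans.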
1 parent aa3c369 commit 82fd7bb

File tree

1 file changed: +25 −19 lines

  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 25 additions & 19 deletions
```diff
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
+import scala.collection.mutable
+
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
@@ -29,7 +31,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{LOGICAL_QUERY_STAGE, Tre
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
 
 
 abstract class LogicalPlan
@@ -314,31 +316,35 @@ object LogicalPlanIntegrity {
    * in plan output. Returns the error message if the check does not pass.
    */
   def hasUniqueExprIdsForOutput(plan: LogicalPlan): Option[String] = {
-    val exprIds = plan.collect { case p if canGetOutputAttrs(p) =>
-      // NOTE: we still need to filter resolved expressions here because the output of
-      // some resolved logical plans can have unresolved references,
-      // e.g., outer references in `ExistenceJoin`.
-      p.output.filter(_.resolved).map { a => (a.exprId, a.dataType.asNullable) }
-    }.flatten
+    // SPARK-48771: rewritten using mutable collections to improve this function's performance and
+    // avoid unnecessary traversals of the query plan.
+    val exprIds = mutable.HashMap.empty[ExprId, mutable.HashSet[DataType]]
+    val ignoredExprIds = mutable.HashSet.empty[ExprId]
 
-    val ignoredExprIds = plan.collect {
+    plan.foreach {
       // NOTE: `Union` currently reuses input `ExprId`s for output references, but we cannot
       // simply modify the code for assigning new `ExprId`s in `Union#output` because
       // the modification will make breaking changes (See SPARK-32741(#29585)).
       // So, this check just ignores the `exprId`s of `Union` output.
-      case u: Union if u.resolved => u.output.map(_.exprId)
-    }.flatten.toSet
-
-    val groupedDataTypesByExprId = exprIds.filterNot { case (exprId, _) =>
-      ignoredExprIds.contains(exprId)
-    }.groupBy(_._1).values.map(_.distinct)
+      case u: Union if u.resolved =>
+        u.output.foreach(ignoredExprIds += _.exprId)
+      case p if canGetOutputAttrs(p) =>
+        p.output.foreach { a =>
+          // NOTE: we still need to filter resolved expressions here because the output of
+          // some resolved logical plans can have unresolved references,
+          // e.g., outer references in `ExistenceJoin`.
+          if (a.resolved) {
+            val prevTypes = exprIds.getOrElseUpdate(a.exprId, mutable.HashSet.empty[DataType])
+            prevTypes += a.dataType.asNullable
+          }
+        }
+      case _ =>
+    }
 
-    groupedDataTypesByExprId.collectFirst {
-      case group if group.length > 1 =>
-        val exprId = group.head._1
-        val types = group.map(_._2.sql)
+    exprIds.collectFirst {
+      case (exprId, types) if types.size > 1 && !ignoredExprIds.contains(exprId) =>
         s"Multiple attributes have the same expression ID ${exprId.id} but different data types: " +
-          types.mkString(", ") + ". The plan tree:\n" + plan.treeString
+          types.map(_.sql).mkString(", ") + ". The plan tree:\n" + plan.treeString
     }
   }
 
```
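To make the validated invariant concrete, below is a rough REPL-style sketch (untested, and written against Catalyst's internal constructors for `AttributeReference`, `Alias`, `LocalRelation`, and `Project`, so treat the exact parameter lists as assumptions) of the kind of plan this check is meant to flag: two output attributes share expression ID 1 but carry different data types.

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExprId, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlanIntegrity, Project}
import org.apache.spark.sql.types.IntegerType

// An INT attribute that owns expression ID 1.
val a = AttributeReference("a", IntegerType)(exprId = ExprId(1))
// A STRING-typed alias that (incorrectly) reuses the same expression ID.
val badAlias = Alias(Literal("x"), "a")(exprId = a.exprId)

// The Project's output and the LocalRelation's output now expose ExprId 1 with two types.
val plan = Project(Seq(badAlias), LocalRelation(a))

// Expected to be defined, with a message along the lines of
// "Multiple attributes have the same expression ID 1 but different data types: ...".
LogicalPlanIntegrity.hasUniqueExprIdsForOutput(plan)
```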

0 commit comments
