Skip to content

Commit 940c30f

Browse files
authored
fix(cubesql): Disable filter pushdown over Filter(CrossJoin) (#9474)
This should help rewrite filters on top of complex ungrouped-grouped joins as subquery joins.
1 parent 85c27a9 commit 940c30f

8 files changed

+139
-69
lines changed

rust/cubesql/cubesql/src/compile/engine/df/optimizers/filter_push_down.rs

+17-69
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,15 @@ fn filter_push_down(
111111
)
112112
}
113113
LogicalPlan::Filter(Filter { predicate, input }) => {
114+
// Current DataFusion version plans complex joins as Filter(CrossJoin)
115+
// So for a query like `SELECT ... FROM ... JOIN ... ON complex_condition WHERE predicate`
116+
// the plan can look like Filter(predicate, Filter(join_condition, CrossJoin))
117+
// This optimizer can mess with filter predicates, and break join detection later in rewrites
118+
// So, for now, it just completely pessimizes plans like Filter(CrossJoin)
119+
if let LogicalPlan::CrossJoin(_) = input.as_ref() {
120+
return issue_filter(predicates, plan.clone());
121+
}
122+
114123
// When encountering a filter, collect it to our list of predicates,
115124
// remove the filter from the plan and continue down the plan.
116125

@@ -692,15 +701,10 @@ mod tests {
692701
};
693702
use datafusion::logical_plan::{binary_expr, col, count, lit, sum, LogicalPlanBuilder};
694703

695-
fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
704+
fn optimize(plan: &LogicalPlan) -> LogicalPlan {
696705
let rule = FilterPushDown::new();
697706
rule.optimize(plan, &OptimizerConfig::new())
698-
}
699-
700-
fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) {
701-
let optimized_plan = optimize(&plan).expect("failed to optimize plan");
702-
let formatted_plan = format!("{:?}", optimized_plan);
703-
assert_eq!(formatted_plan, expected);
707+
.expect("failed to optimize plan")
704708
}
705709

706710
#[test]
@@ -714,14 +718,7 @@ mod tests {
714718
.filter(col("t2.n2").gt(lit(5i32)))?
715719
.build()?;
716720

717-
let expected = "\
718-
Projection: #t1.c1 AS n1, #t1.c3 AS n2, alias=t2\
719-
\n Filter: #t1.c3 > Int32(5)\
720-
\n Projection: #t1.c1, #t1.c3\
721-
\n TableScan: t1 projection=None\
722-
";
723-
724-
assert_optimized_plan_eq(plan, expected);
721+
insta::assert_debug_snapshot!(optimize(&plan));
725722
Ok(())
726723
}
727724

@@ -752,17 +749,7 @@ mod tests {
752749
.project(vec![col("c7"), col("c5"), col("c9")])?
753750
.build()?;
754751

755-
let expected = "\
756-
Projection: #t3.c7, #t3.c5, #c9\
757-
\n Projection: #t3.c7, #t3.c5, #t3.c8 AS c9\
758-
\n Projection: #t2.c4 AS c7, #t2.c5, #t2.c6 AS c8, alias=t3\
759-
\n Projection: #t1.c1 AS c4, #t1.c2 AS c5, #t1.c3 AS c6, alias=t2\
760-
\n Filter: #t1.c2 > Int32(5) AND #t1.c2 <= Int32(10) AND #t1.c3 = Int32(0) AND NOT #t1.c1 < Int32(0)\
761-
\n Projection: #t1.c1, #t1.c2, #t1.c3\
762-
\n TableScan: t1 projection=None\
763-
";
764-
765-
assert_optimized_plan_eq(plan, expected);
752+
insta::assert_debug_snapshot!(optimize(&plan));
766753
Ok(())
767754
}
768755

@@ -782,18 +769,7 @@ mod tests {
782769
.project(vec![col("c1"), col("c2"), col("c3")])?
783770
.build()?;
784771

785-
let expected = "\
786-
Projection: #t1.c1, #c2, #t1.c3\
787-
\n Filter: #t1.c1 > #t1.c3\
788-
\n Projection: #t1.c1, #c2, #t1.c3\
789-
\n Filter: #c2 = Int32(5)\
790-
\n Projection: #t1.c1, #t1.c2 + Int32(5) AS c2, #t1.c3\
791-
\n Filter: #t1.c3 < Int32(5)\
792-
\n Projection: #t1.c1, #t1.c2, #t1.c3\
793-
\n TableScan: t1 projection=None\
794-
";
795-
796-
assert_optimized_plan_eq(plan, expected);
772+
insta::assert_debug_snapshot!(optimize(&plan));
797773
Ok(())
798774
}
799775

@@ -847,16 +823,7 @@ mod tests {
847823
.filter(col("c3").eq(lit(0i32)))?
848824
.build()?;
849825

850-
let expected = "\
851-
Projection: #t1.c1, #SUM(t1.c2) AS c2_sum, #t1.c3\
852-
\n Filter: #SUM(t1.c2) > Int32(10)\
853-
\n Aggregate: groupBy=[[#t1.c1, #t1.c3]], aggr=[[SUM(#t1.c2)]]\
854-
\n Filter: #t1.c3 = Int32(0)\
855-
\n Projection: #t1.c1, #t1.c2, #t1.c3\
856-
\n TableScan: t1 projection=None\
857-
";
858-
859-
assert_optimized_plan_eq(plan, expected);
826+
insta::assert_debug_snapshot!(optimize(&plan));
860827
Ok(())
861828
}
862829

@@ -897,14 +864,7 @@ mod tests {
897864
.filter(col("c3").eq(lit(5i32)))?
898865
.build()?;
899866

900-
let expected = "\
901-
Sort: #t1.c2\
902-
\n Filter: #t1.c3 = Int32(5)\
903-
\n Projection: #t1.c1, #t1.c2, #t1.c3\
904-
\n TableScan: t1 projection=None\
905-
";
906-
907-
assert_optimized_plan_eq(plan, expected);
867+
insta::assert_debug_snapshot!(optimize(&plan));
908868
Ok(())
909869
}
910870

@@ -998,19 +958,7 @@ mod tests {
998958
.filter(col("c2").eq(lit(10i32)))?
999959
.build()?;
1000960

1001-
let expected = "\
1002-
Filter: #j2.c2 = Int32(10)\
1003-
\n CrossJoin:\
1004-
\n Filter: #j1.c1 = Int32(5)\
1005-
\n Projection: #j1.c1\
1006-
\n TableScan: j1 projection=None\
1007-
\n Projection: #COUNT(UInt8(1)) AS c2, alias=j2\
1008-
\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
1009-
\n Projection: #j2.c2\
1010-
\n TableScan: j2 projection=None\
1011-
";
1012-
1013-
assert_optimized_plan_eq(plan, expected);
961+
insta::assert_debug_snapshot!(optimize(&plan));
1014962
Ok(())
1015963
}
1016964

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
---
2+
source: cubesql/src/compile/engine/df/optimizers/filter_push_down.rs
3+
expression: optimize(&plan)
4+
---
5+
Filter: #j2.c2 = Int32(10)
6+
Filter: #j1.c1 = Int32(5)
7+
CrossJoin:
8+
Projection: #j1.c1
9+
TableScan: j1 projection=None
10+
Projection: #COUNT(UInt8(1)) AS c2, alias=j2
11+
Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]
12+
Projection: #j2.c2
13+
TableScan: j2 projection=None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
---
2+
source: cubesql/src/compile/engine/df/optimizers/filter_push_down.rs
3+
expression: optimize(&plan)
4+
---
5+
Projection: #t1.c1 AS n1, #t1.c3 AS n2, alias=t2
6+
Filter: #t1.c3 > Int32(5)
7+
Projection: #t1.c1, #t1.c3
8+
TableScan: t1 projection=None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
---
2+
source: cubesql/src/compile/engine/df/optimizers/filter_push_down.rs
3+
expression: optimize(&plan)
4+
---
5+
Sort: #t1.c2
6+
Filter: #t1.c3 = Int32(5)
7+
Projection: #t1.c1, #t1.c2, #t1.c3
8+
TableScan: t1 projection=None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
---
2+
source: cubesql/src/compile/engine/df/optimizers/filter_push_down.rs
3+
expression: optimize(&plan)
4+
---
5+
Projection: #t1.c1, #SUM(t1.c2) AS c2_sum, #t1.c3
6+
Filter: #SUM(t1.c2) > Int32(10)
7+
Aggregate: groupBy=[[#t1.c1, #t1.c3]], aggr=[[SUM(#t1.c2)]]
8+
Filter: #t1.c3 = Int32(0)
9+
Projection: #t1.c1, #t1.c2, #t1.c3
10+
TableScan: t1 projection=None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
---
2+
source: cubesql/src/compile/engine/df/optimizers/filter_push_down.rs
3+
expression: optimize(&plan)
4+
---
5+
Projection: #t3.c7, #t3.c5, #c9
6+
Projection: #t3.c7, #t3.c5, #t3.c8 AS c9
7+
Projection: #t2.c4 AS c7, #t2.c5, #t2.c6 AS c8, alias=t3
8+
Projection: #t1.c1 AS c4, #t1.c2 AS c5, #t1.c3 AS c6, alias=t2
9+
Filter: #t1.c2 > Int32(5) AND #t1.c2 <= Int32(10) AND #t1.c3 = Int32(0) AND NOT #t1.c1 < Int32(0)
10+
Projection: #t1.c1, #t1.c2, #t1.c3
11+
TableScan: t1 projection=None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
---
2+
source: cubesql/src/compile/engine/df/optimizers/filter_push_down.rs
3+
expression: optimize(&plan)
4+
---
5+
Projection: #t1.c1, #c2, #t1.c3
6+
Filter: #t1.c1 > #t1.c3
7+
Projection: #t1.c1, #c2, #t1.c3
8+
Filter: #c2 = Int32(5)
9+
Projection: #t1.c1, #t1.c2 + Int32(5) AS c2, #t1.c3
10+
Filter: #t1.c3 < Int32(5)
11+
Projection: #t1.c1, #t1.c2, #t1.c3
12+
TableScan: t1 projection=None

rust/cubesql/cubesql/src/compile/test/test_cube_join_grouped.rs

+60
Original file line numberDiff line numberDiff line change
@@ -899,3 +899,63 @@ LIMIT 1
899899
.on
900900
.contains(r#"${MultiTypeCube.dim_str0} IS NOT DISTINCT FROM \"t0\".\"dim_str0\""#));
901901
}
902+
903+
/// Filter on top of ungrouped-grouped join with complex condition should be rewritten as well
904+
#[tokio::test]
905+
async fn test_join_ungrouped_grouped_with_filter_and_measure() {
906+
if !Rewriter::sql_push_down_enabled() {
907+
return;
908+
}
909+
init_testing_logger();
910+
911+
let query_plan = convert_select_to_query_plan(
912+
// language=PostgreSQL
913+
r#"
914+
SELECT "t0"."measure"
915+
FROM
916+
MultiTypeCube
917+
INNER JOIN (
918+
SELECT
919+
dim_str0,
920+
AVG(avgPrice) AS "measure"
921+
FROM
922+
MultiTypeCube
923+
GROUP BY 1
924+
) "t0"
925+
ON (MultiTypeCube.dim_str0 IS NOT DISTINCT FROM "t0".dim_str0)
926+
WHERE ("t0"."measure" IS NULL)
927+
LIMIT 1
928+
;
929+
"#
930+
.to_string(),
931+
DatabaseProtocol::PostgreSQL,
932+
)
933+
.await;
934+
935+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
936+
println!(
937+
"Physical plan: {}",
938+
displayable(physical_plan.as_ref()).indent()
939+
);
940+
941+
let request = query_plan
942+
.as_logical_plan()
943+
.find_cube_scan_wrapped_sql()
944+
.request;
945+
946+
assert_eq!(request.ungrouped, Some(true));
947+
948+
assert_eq!(request.subquery_joins.as_ref().unwrap().len(), 1);
949+
950+
let subquery = &request.subquery_joins.unwrap()[0];
951+
952+
assert!(!subquery.sql.contains("ungrouped"));
953+
assert_eq!(subquery.join_type, "INNER");
954+
assert!(subquery
955+
.on
956+
.contains(r#"${MultiTypeCube.dim_str0} IS NOT DISTINCT FROM \"t0\".\"dim_str0\""#));
957+
958+
// Outer filter
959+
assert_eq!(request.segments.as_ref().unwrap().len(), 1);
960+
assert!(request.segments.as_ref().unwrap()[0].contains(r#"\"t0\".\"measure\" IS NULL"#));
961+
}

0 commit comments

Comments
 (0)