@@ -111,6 +111,15 @@ fn filter_push_down(
111
111
)
112
112
}
113
113
LogicalPlan :: Filter ( Filter { predicate, input } ) => {
114
+ // Current DataFusion version plans complex joins as Filter(CrossJoin)
115
+ // So for query like `SELECT ... FROM ... JOIN ... ON complex_condition WHERE predicate`
116
+ // Plan can look like Filter(predicate, Filter(join_condition, CrossJoin))
117
+ // This optimizer can mess with filter predicates, and break join detection later in rewrites
118
+ // So, for now, it just completely pessimizes plans like Filter(CrossJoin)
119
+ if let LogicalPlan :: CrossJoin ( _) = input. as_ref ( ) {
120
+ return issue_filter ( predicates, plan. clone ( ) ) ;
121
+ }
122
+
114
123
// When encountering a filter, collect it to our list of predicates,
115
124
// remove the filter from the plan and continue down the plan.
116
125
@@ -692,15 +701,10 @@ mod tests {
692
701
} ;
693
702
use datafusion:: logical_plan:: { binary_expr, col, count, lit, sum, LogicalPlanBuilder } ;
694
703
695
- fn optimize ( plan : & LogicalPlan ) -> Result < LogicalPlan > {
704
+ fn optimize ( plan : & LogicalPlan ) -> LogicalPlan {
696
705
let rule = FilterPushDown :: new ( ) ;
697
706
rule. optimize ( plan, & OptimizerConfig :: new ( ) )
698
- }
699
-
700
- fn assert_optimized_plan_eq ( plan : LogicalPlan , expected : & str ) {
701
- let optimized_plan = optimize ( & plan) . expect ( "failed to optimize plan" ) ;
702
- let formatted_plan = format ! ( "{:?}" , optimized_plan) ;
703
- assert_eq ! ( formatted_plan, expected) ;
707
+ . expect ( "failed to optimize plan" )
704
708
}
705
709
706
710
#[ test]
@@ -714,14 +718,7 @@ mod tests {
714
718
. filter ( col ( "t2.n2" ) . gt ( lit ( 5i32 ) ) ) ?
715
719
. build ( ) ?;
716
720
717
- let expected = "\
718
- Projection: #t1.c1 AS n1, #t1.c3 AS n2, alias=t2\
719
- \n Filter: #t1.c3 > Int32(5)\
720
- \n Projection: #t1.c1, #t1.c3\
721
- \n TableScan: t1 projection=None\
722
- ";
723
-
724
- assert_optimized_plan_eq ( plan, expected) ;
721
+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
725
722
Ok ( ( ) )
726
723
}
727
724
@@ -752,17 +749,7 @@ mod tests {
752
749
. project ( vec ! [ col( "c7" ) , col( "c5" ) , col( "c9" ) ] ) ?
753
750
. build ( ) ?;
754
751
755
- let expected = "\
756
- Projection: #t3.c7, #t3.c5, #c9\
757
- \n Projection: #t3.c7, #t3.c5, #t3.c8 AS c9\
758
- \n Projection: #t2.c4 AS c7, #t2.c5, #t2.c6 AS c8, alias=t3\
759
- \n Projection: #t1.c1 AS c4, #t1.c2 AS c5, #t1.c3 AS c6, alias=t2\
760
- \n Filter: #t1.c2 > Int32(5) AND #t1.c2 <= Int32(10) AND #t1.c3 = Int32(0) AND NOT #t1.c1 < Int32(0)\
761
- \n Projection: #t1.c1, #t1.c2, #t1.c3\
762
- \n TableScan: t1 projection=None\
763
- ";
764
-
765
- assert_optimized_plan_eq ( plan, expected) ;
752
+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
766
753
Ok ( ( ) )
767
754
}
768
755
@@ -782,18 +769,7 @@ mod tests {
782
769
. project ( vec ! [ col( "c1" ) , col( "c2" ) , col( "c3" ) ] ) ?
783
770
. build ( ) ?;
784
771
785
- let expected = "\
786
- Projection: #t1.c1, #c2, #t1.c3\
787
- \n Filter: #t1.c1 > #t1.c3\
788
- \n Projection: #t1.c1, #c2, #t1.c3\
789
- \n Filter: #c2 = Int32(5)\
790
- \n Projection: #t1.c1, #t1.c2 + Int32(5) AS c2, #t1.c3\
791
- \n Filter: #t1.c3 < Int32(5)\
792
- \n Projection: #t1.c1, #t1.c2, #t1.c3\
793
- \n TableScan: t1 projection=None\
794
- ";
795
-
796
- assert_optimized_plan_eq ( plan, expected) ;
772
+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
797
773
Ok ( ( ) )
798
774
}
799
775
@@ -847,16 +823,7 @@ mod tests {
847
823
. filter ( col ( "c3" ) . eq ( lit ( 0i32 ) ) ) ?
848
824
. build ( ) ?;
849
825
850
- let expected = "\
851
- Projection: #t1.c1, #SUM(t1.c2) AS c2_sum, #t1.c3\
852
- \n Filter: #SUM(t1.c2) > Int32(10)\
853
- \n Aggregate: groupBy=[[#t1.c1, #t1.c3]], aggr=[[SUM(#t1.c2)]]\
854
- \n Filter: #t1.c3 = Int32(0)\
855
- \n Projection: #t1.c1, #t1.c2, #t1.c3\
856
- \n TableScan: t1 projection=None\
857
- ";
858
-
859
- assert_optimized_plan_eq ( plan, expected) ;
826
+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
860
827
Ok ( ( ) )
861
828
}
862
829
@@ -897,14 +864,7 @@ mod tests {
897
864
. filter ( col ( "c3" ) . eq ( lit ( 5i32 ) ) ) ?
898
865
. build ( ) ?;
899
866
900
- let expected = "\
901
- Sort: #t1.c2\
902
- \n Filter: #t1.c3 = Int32(5)\
903
- \n Projection: #t1.c1, #t1.c2, #t1.c3\
904
- \n TableScan: t1 projection=None\
905
- ";
906
-
907
- assert_optimized_plan_eq ( plan, expected) ;
867
+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
908
868
Ok ( ( ) )
909
869
}
910
870
@@ -998,19 +958,7 @@ mod tests {
998
958
. filter ( col ( "c2" ) . eq ( lit ( 10i32 ) ) ) ?
999
959
. build ( ) ?;
1000
960
1001
- let expected = "\
1002
- Filter: #j2.c2 = Int32(10)\
1003
- \n CrossJoin:\
1004
- \n Filter: #j1.c1 = Int32(5)\
1005
- \n Projection: #j1.c1\
1006
- \n TableScan: j1 projection=None\
1007
- \n Projection: #COUNT(UInt8(1)) AS c2, alias=j2\
1008
- \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
1009
- \n Projection: #j2.c2\
1010
- \n TableScan: j2 projection=None\
1011
- ";
1012
-
1013
- assert_optimized_plan_eq ( plan, expected) ;
961
+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
1014
962
Ok ( ( ) )
1015
963
}
1016
964
0 commit comments