1 | #include "duckdb/common/types/hugeint.hpp" |
2 | #include "duckdb/optimizer/statistics_propagator.hpp" |
3 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
4 | #include "duckdb/planner/operator/logical_any_join.hpp" |
5 | #include "duckdb/planner/operator/logical_comparison_join.hpp" |
6 | #include "duckdb/planner/operator/logical_cross_product.hpp" |
7 | #include "duckdb/planner/operator/logical_join.hpp" |
8 | #include "duckdb/planner/operator/logical_limit.hpp" |
9 | #include "duckdb/planner/operator/logical_positional_join.hpp" |
10 | |
11 | namespace duckdb { |
12 | |
13 | void StatisticsPropagator::PropagateStatistics(LogicalComparisonJoin &join, unique_ptr<LogicalOperator> *node_ptr) { |
14 | for (idx_t i = 0; i < join.conditions.size(); i++) { |
15 | auto &condition = join.conditions[i]; |
16 | auto stats_left = PropagateExpression(expr&: condition.left); |
17 | auto stats_right = PropagateExpression(expr&: condition.right); |
18 | if (stats_left && stats_right) { |
19 | if ((condition.comparison == ExpressionType::COMPARE_DISTINCT_FROM || |
20 | condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) && |
21 | stats_left->CanHaveNull() && stats_right->CanHaveNull()) { |
22 | // null values are equal in this join, and both sides can have null values |
23 | // nothing to do here |
24 | continue; |
25 | } |
26 | auto prune_result = PropagateComparison(left&: *stats_left, right&: *stats_right, comparison: condition.comparison); |
27 | // Add stats to logical_join for perfect hash join |
28 | join.join_stats.push_back(x: std::move(stats_left)); |
29 | join.join_stats.push_back(x: std::move(stats_right)); |
30 | switch (prune_result) { |
31 | case FilterPropagateResult::FILTER_FALSE_OR_NULL: |
32 | case FilterPropagateResult::FILTER_ALWAYS_FALSE: |
33 | // filter is always false or null, none of the join conditions matter |
34 | switch (join.join_type) { |
35 | case JoinType::SEMI: |
36 | case JoinType::INNER: |
37 | // semi or inner join on false; entire node can be pruned |
38 | ReplaceWithEmptyResult(node&: *node_ptr); |
39 | return; |
40 | case JoinType::ANTI: { |
41 | // when the right child has data, return the left child |
42 | // when the right child has no data, return an empty set |
43 | auto limit = make_uniq<LogicalLimit>(args: 1, args: 0, args: nullptr, args: nullptr); |
44 | limit->AddChild(child: std::move(join.children[1])); |
45 | auto cross_product = LogicalCrossProduct::Create(left: std::move(join.children[0]), right: std::move(limit)); |
46 | *node_ptr = std::move(cross_product); |
47 | return; |
48 | } |
49 | case JoinType::LEFT: |
50 | // anti/left outer join: replace right side with empty node |
51 | ReplaceWithEmptyResult(node&: join.children[1]); |
52 | return; |
53 | case JoinType::RIGHT: |
54 | // right outer join: replace left side with empty node |
55 | ReplaceWithEmptyResult(node&: join.children[0]); |
56 | return; |
57 | default: |
58 | // other join types: can't do much meaningful with this information |
59 | // full outer join requires both sides anyway; we can skip the execution of the actual join, but eh |
60 | // mark/single join requires knowing if the rhs has null values or not |
61 | break; |
62 | } |
63 | break; |
64 | case FilterPropagateResult::FILTER_ALWAYS_TRUE: |
65 | // filter is always true |
66 | if (join.conditions.size() > 1) { |
67 | // there are multiple conditions: erase this condition |
68 | join.conditions.erase(position: join.conditions.begin() + i); |
69 | // remove the corresponding statistics |
70 | join.join_stats.clear(); |
71 | i--; |
72 | continue; |
73 | } else { |
74 | // this is the only condition and it is always true: all conditions are true |
75 | switch (join.join_type) { |
76 | case JoinType::SEMI: { |
77 | // when the right child has data, return the left child |
78 | // when the right child has no data, return an empty set |
79 | auto limit = make_uniq<LogicalLimit>(args: 1, args: 0, args: nullptr, args: nullptr); |
80 | limit->AddChild(child: std::move(join.children[1])); |
81 | auto cross_product = LogicalCrossProduct::Create(left: std::move(join.children[0]), right: std::move(limit)); |
82 | *node_ptr = std::move(cross_product); |
83 | return; |
84 | } |
85 | case JoinType::INNER: { |
86 | // inner, replace with cross product |
87 | auto cross_product = |
88 | LogicalCrossProduct::Create(left: std::move(join.children[0]), right: std::move(join.children[1])); |
89 | *node_ptr = std::move(cross_product); |
90 | return; |
91 | } |
92 | case JoinType::ANTI: |
93 | // anti join on true: empty result |
94 | ReplaceWithEmptyResult(node&: *node_ptr); |
95 | return; |
96 | default: |
97 | // we don't handle mark/single join here yet |
98 | break; |
99 | } |
100 | } |
101 | break; |
102 | default: |
103 | break; |
104 | } |
105 | } |
106 | // after we have propagated, we can update the statistics on both sides |
107 | // note that it is fine to do this now, even if the same column is used again later |
108 | // e.g. if we have i=j AND i=k, and the stats for j and k are disjoint, we know there are no results |
109 | // so if we have e.g. i: [0, 100], j: [0, 25], k: [75, 100] |
110 | // we can set i: [0, 25] after the first comparison, and statically determine that the second comparison is fals |
111 | |
112 | // note that we can't update statistics the same for all join types |
113 | // mark and single joins don't filter any tuples -> so there is no propagation possible |
114 | // anti joins have inverse statistics propagation |
115 | // (i.e. if we have an anti join on i: [0, 100] and j: [0, 25], the resulting stats are i:[25,100]) |
116 | // for now we don't handle anti joins |
117 | if (condition.comparison == ExpressionType::COMPARE_DISTINCT_FROM || |
118 | condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { |
119 | // skip update when null values are equal (for now?) |
120 | continue; |
121 | } |
122 | switch (join.join_type) { |
123 | case JoinType::INNER: |
124 | case JoinType::SEMI: { |
125 | UpdateFilterStatistics(left&: *condition.left, right&: *condition.right, comparison_type: condition.comparison); |
126 | auto stats_left = PropagateExpression(expr&: condition.left); |
127 | auto stats_right = PropagateExpression(expr&: condition.right); |
128 | // Update join_stats when is already part of the join |
129 | if (join.join_stats.size() == 2) { |
130 | join.join_stats[0] = std::move(stats_left); |
131 | join.join_stats[1] = std::move(stats_right); |
132 | } |
133 | break; |
134 | } |
135 | default: |
136 | break; |
137 | } |
138 | } |
139 | } |
140 | |
141 | void StatisticsPropagator::PropagateStatistics(LogicalAnyJoin &join, unique_ptr<LogicalOperator> *node_ptr) { |
142 | // propagate the expression into the join condition |
143 | PropagateExpression(expr&: join.condition); |
144 | } |
145 | |
146 | void StatisticsPropagator::MultiplyCardinalities(unique_ptr<NodeStatistics> &stats, NodeStatistics &new_stats) { |
147 | if (!stats->has_estimated_cardinality || !new_stats.has_estimated_cardinality || !stats->has_max_cardinality || |
148 | !new_stats.has_max_cardinality) { |
149 | stats = nullptr; |
150 | return; |
151 | } |
152 | stats->estimated_cardinality = MaxValue<idx_t>(a: stats->estimated_cardinality, b: new_stats.estimated_cardinality); |
153 | auto new_max = Hugeint::Multiply(lhs: stats->max_cardinality, rhs: new_stats.max_cardinality); |
154 | if (new_max < NumericLimits<int64_t>::Maximum()) { |
155 | int64_t result; |
156 | if (!Hugeint::TryCast<int64_t>(input: new_max, result)) { |
157 | throw InternalException("Overflow in cast in statistics propagation" ); |
158 | } |
159 | D_ASSERT(result >= 0); |
160 | stats->max_cardinality = idx_t(result); |
161 | } else { |
162 | stats = nullptr; |
163 | } |
164 | } |
165 | |
166 | unique_ptr<NodeStatistics> StatisticsPropagator::PropagateStatistics(LogicalJoin &join, |
167 | unique_ptr<LogicalOperator> *node_ptr) { |
168 | // first propagate through the children of the join |
169 | node_stats = PropagateStatistics(node_ptr&: join.children[0]); |
170 | for (idx_t child_idx = 1; child_idx < join.children.size(); child_idx++) { |
171 | auto child_stats = PropagateStatistics(node_ptr&: join.children[child_idx]); |
172 | if (!child_stats) { |
173 | node_stats = nullptr; |
174 | } else if (node_stats) { |
175 | MultiplyCardinalities(stats&: node_stats, new_stats&: *child_stats); |
176 | } |
177 | } |
178 | |
179 | auto join_type = join.join_type; |
180 | // depending on the join type, we might need to alter the statistics |
181 | // LEFT, FULL, RIGHT OUTER and SINGLE joins can introduce null values |
182 | // this requires us to alter the statistics after this point in the query plan |
183 | bool adds_null_on_left = IsRightOuterJoin(type: join_type); |
184 | bool adds_null_on_right = IsLeftOuterJoin(type: join_type) || join_type == JoinType::SINGLE; |
185 | |
186 | vector<ColumnBinding> left_bindings, right_bindings; |
187 | if (adds_null_on_left) { |
188 | left_bindings = join.children[0]->GetColumnBindings(); |
189 | } |
190 | if (adds_null_on_right) { |
191 | right_bindings = join.children[1]->GetColumnBindings(); |
192 | } |
193 | |
194 | // then propagate into the join conditions |
195 | switch (join.type) { |
196 | case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: |
197 | case LogicalOperatorType::LOGICAL_DELIM_JOIN: |
198 | case LogicalOperatorType::LOGICAL_ASOF_JOIN: |
199 | PropagateStatistics(join&: join.Cast<LogicalComparisonJoin>(), node_ptr); |
200 | break; |
201 | case LogicalOperatorType::LOGICAL_ANY_JOIN: |
202 | PropagateStatistics(join&: join.Cast<LogicalAnyJoin>(), node_ptr); |
203 | break; |
204 | default: |
205 | break; |
206 | } |
207 | |
208 | if (adds_null_on_right) { |
209 | // left or full outer join: set IsNull() to true for all rhs statistics |
210 | for (auto &binding : right_bindings) { |
211 | auto stats = statistics_map.find(x: binding); |
212 | if (stats != statistics_map.end()) { |
213 | stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES); |
214 | } |
215 | } |
216 | } |
217 | if (adds_null_on_left) { |
218 | // right or full outer join: set IsNull() to true for all lhs statistics |
219 | for (auto &binding : left_bindings) { |
220 | auto stats = statistics_map.find(x: binding); |
221 | if (stats != statistics_map.end()) { |
222 | stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES); |
223 | } |
224 | } |
225 | } |
226 | return std::move(node_stats); |
227 | } |
228 | |
229 | static void MaxCardinalities(unique_ptr<NodeStatistics> &stats, NodeStatistics &new_stats) { |
230 | if (!stats->has_estimated_cardinality || !new_stats.has_estimated_cardinality || !stats->has_max_cardinality || |
231 | !new_stats.has_max_cardinality) { |
232 | stats = nullptr; |
233 | return; |
234 | } |
235 | stats->estimated_cardinality = MaxValue<idx_t>(a: stats->estimated_cardinality, b: new_stats.estimated_cardinality); |
236 | stats->max_cardinality = MaxValue<idx_t>(a: stats->max_cardinality, b: new_stats.max_cardinality); |
237 | } |
238 | |
239 | unique_ptr<NodeStatistics> StatisticsPropagator::PropagateStatistics(LogicalPositionalJoin &join, |
240 | unique_ptr<LogicalOperator> *node_ptr) { |
241 | D_ASSERT(join.type == LogicalOperatorType::LOGICAL_POSITIONAL_JOIN); |
242 | |
243 | // first propagate through the children of the join |
244 | node_stats = PropagateStatistics(node_ptr&: join.children[0]); |
245 | for (idx_t child_idx = 1; child_idx < join.children.size(); child_idx++) { |
246 | auto child_stats = PropagateStatistics(node_ptr&: join.children[child_idx]); |
247 | if (!child_stats) { |
248 | node_stats = nullptr; |
249 | } else if (node_stats) { |
250 | if (!node_stats->has_estimated_cardinality || !child_stats->has_estimated_cardinality || |
251 | !node_stats->has_max_cardinality || !child_stats->has_max_cardinality) { |
252 | node_stats = nullptr; |
253 | } else { |
254 | MaxCardinalities(stats&: node_stats, new_stats&: *child_stats); |
255 | } |
256 | } |
257 | } |
258 | |
259 | // No conditions. |
260 | |
261 | // Positional Joins are always FULL OUTER |
262 | |
263 | // set IsNull() to true for all lhs statistics |
264 | auto left_bindings = join.children[0]->GetColumnBindings(); |
265 | for (auto &binding : left_bindings) { |
266 | auto stats = statistics_map.find(x: binding); |
267 | if (stats != statistics_map.end()) { |
268 | stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES); |
269 | } |
270 | } |
271 | |
272 | // set IsNull() to true for all rhs statistics |
273 | auto right_bindings = join.children[1]->GetColumnBindings(); |
274 | for (auto &binding : right_bindings) { |
275 | auto stats = statistics_map.find(x: binding); |
276 | if (stats != statistics_map.end()) { |
277 | stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES); |
278 | } |
279 | } |
280 | |
281 | return std::move(node_stats); |
282 | } |
283 | |
284 | } // namespace duckdb |
285 | |