1#include "duckdb/common/types/hugeint.hpp"
2#include "duckdb/optimizer/statistics_propagator.hpp"
3#include "duckdb/planner/expression/bound_columnref_expression.hpp"
4#include "duckdb/planner/operator/logical_any_join.hpp"
5#include "duckdb/planner/operator/logical_comparison_join.hpp"
6#include "duckdb/planner/operator/logical_cross_product.hpp"
7#include "duckdb/planner/operator/logical_join.hpp"
8#include "duckdb/planner/operator/logical_limit.hpp"
9#include "duckdb/planner/operator/logical_positional_join.hpp"
10
11namespace duckdb {
12
13void StatisticsPropagator::PropagateStatistics(LogicalComparisonJoin &join, unique_ptr<LogicalOperator> *node_ptr) {
14 for (idx_t i = 0; i < join.conditions.size(); i++) {
15 auto &condition = join.conditions[i];
16 auto stats_left = PropagateExpression(expr&: condition.left);
17 auto stats_right = PropagateExpression(expr&: condition.right);
18 if (stats_left && stats_right) {
19 if ((condition.comparison == ExpressionType::COMPARE_DISTINCT_FROM ||
20 condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) &&
21 stats_left->CanHaveNull() && stats_right->CanHaveNull()) {
22 // null values are equal in this join, and both sides can have null values
23 // nothing to do here
24 continue;
25 }
26 auto prune_result = PropagateComparison(left&: *stats_left, right&: *stats_right, comparison: condition.comparison);
27 // Add stats to logical_join for perfect hash join
28 join.join_stats.push_back(x: std::move(stats_left));
29 join.join_stats.push_back(x: std::move(stats_right));
30 switch (prune_result) {
31 case FilterPropagateResult::FILTER_FALSE_OR_NULL:
32 case FilterPropagateResult::FILTER_ALWAYS_FALSE:
33 // filter is always false or null, none of the join conditions matter
34 switch (join.join_type) {
35 case JoinType::SEMI:
36 case JoinType::INNER:
37 // semi or inner join on false; entire node can be pruned
38 ReplaceWithEmptyResult(node&: *node_ptr);
39 return;
40 case JoinType::ANTI: {
41 // when the right child has data, return the left child
42 // when the right child has no data, return an empty set
43 auto limit = make_uniq<LogicalLimit>(args: 1, args: 0, args: nullptr, args: nullptr);
44 limit->AddChild(child: std::move(join.children[1]));
45 auto cross_product = LogicalCrossProduct::Create(left: std::move(join.children[0]), right: std::move(limit));
46 *node_ptr = std::move(cross_product);
47 return;
48 }
49 case JoinType::LEFT:
50 // anti/left outer join: replace right side with empty node
51 ReplaceWithEmptyResult(node&: join.children[1]);
52 return;
53 case JoinType::RIGHT:
54 // right outer join: replace left side with empty node
55 ReplaceWithEmptyResult(node&: join.children[0]);
56 return;
57 default:
58 // other join types: can't do much meaningful with this information
59 // full outer join requires both sides anyway; we can skip the execution of the actual join, but eh
60 // mark/single join requires knowing if the rhs has null values or not
61 break;
62 }
63 break;
64 case FilterPropagateResult::FILTER_ALWAYS_TRUE:
65 // filter is always true
66 if (join.conditions.size() > 1) {
67 // there are multiple conditions: erase this condition
68 join.conditions.erase(position: join.conditions.begin() + i);
69 // remove the corresponding statistics
70 join.join_stats.clear();
71 i--;
72 continue;
73 } else {
74 // this is the only condition and it is always true: all conditions are true
75 switch (join.join_type) {
76 case JoinType::SEMI: {
77 // when the right child has data, return the left child
78 // when the right child has no data, return an empty set
79 auto limit = make_uniq<LogicalLimit>(args: 1, args: 0, args: nullptr, args: nullptr);
80 limit->AddChild(child: std::move(join.children[1]));
81 auto cross_product = LogicalCrossProduct::Create(left: std::move(join.children[0]), right: std::move(limit));
82 *node_ptr = std::move(cross_product);
83 return;
84 }
85 case JoinType::INNER: {
86 // inner, replace with cross product
87 auto cross_product =
88 LogicalCrossProduct::Create(left: std::move(join.children[0]), right: std::move(join.children[1]));
89 *node_ptr = std::move(cross_product);
90 return;
91 }
92 case JoinType::ANTI:
93 // anti join on true: empty result
94 ReplaceWithEmptyResult(node&: *node_ptr);
95 return;
96 default:
97 // we don't handle mark/single join here yet
98 break;
99 }
100 }
101 break;
102 default:
103 break;
104 }
105 }
106 // after we have propagated, we can update the statistics on both sides
107 // note that it is fine to do this now, even if the same column is used again later
108 // e.g. if we have i=j AND i=k, and the stats for j and k are disjoint, we know there are no results
109 // so if we have e.g. i: [0, 100], j: [0, 25], k: [75, 100]
110 // we can set i: [0, 25] after the first comparison, and statically determine that the second comparison is fals
111
112 // note that we can't update statistics the same for all join types
113 // mark and single joins don't filter any tuples -> so there is no propagation possible
114 // anti joins have inverse statistics propagation
115 // (i.e. if we have an anti join on i: [0, 100] and j: [0, 25], the resulting stats are i:[25,100])
116 // for now we don't handle anti joins
117 if (condition.comparison == ExpressionType::COMPARE_DISTINCT_FROM ||
118 condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) {
119 // skip update when null values are equal (for now?)
120 continue;
121 }
122 switch (join.join_type) {
123 case JoinType::INNER:
124 case JoinType::SEMI: {
125 UpdateFilterStatistics(left&: *condition.left, right&: *condition.right, comparison_type: condition.comparison);
126 auto stats_left = PropagateExpression(expr&: condition.left);
127 auto stats_right = PropagateExpression(expr&: condition.right);
128 // Update join_stats when is already part of the join
129 if (join.join_stats.size() == 2) {
130 join.join_stats[0] = std::move(stats_left);
131 join.join_stats[1] = std::move(stats_right);
132 }
133 break;
134 }
135 default:
136 break;
137 }
138 }
139}
140
141void StatisticsPropagator::PropagateStatistics(LogicalAnyJoin &join, unique_ptr<LogicalOperator> *node_ptr) {
142 // propagate the expression into the join condition
143 PropagateExpression(expr&: join.condition);
144}
145
146void StatisticsPropagator::MultiplyCardinalities(unique_ptr<NodeStatistics> &stats, NodeStatistics &new_stats) {
147 if (!stats->has_estimated_cardinality || !new_stats.has_estimated_cardinality || !stats->has_max_cardinality ||
148 !new_stats.has_max_cardinality) {
149 stats = nullptr;
150 return;
151 }
152 stats->estimated_cardinality = MaxValue<idx_t>(a: stats->estimated_cardinality, b: new_stats.estimated_cardinality);
153 auto new_max = Hugeint::Multiply(lhs: stats->max_cardinality, rhs: new_stats.max_cardinality);
154 if (new_max < NumericLimits<int64_t>::Maximum()) {
155 int64_t result;
156 if (!Hugeint::TryCast<int64_t>(input: new_max, result)) {
157 throw InternalException("Overflow in cast in statistics propagation");
158 }
159 D_ASSERT(result >= 0);
160 stats->max_cardinality = idx_t(result);
161 } else {
162 stats = nullptr;
163 }
164}
165
166unique_ptr<NodeStatistics> StatisticsPropagator::PropagateStatistics(LogicalJoin &join,
167 unique_ptr<LogicalOperator> *node_ptr) {
168 // first propagate through the children of the join
169 node_stats = PropagateStatistics(node_ptr&: join.children[0]);
170 for (idx_t child_idx = 1; child_idx < join.children.size(); child_idx++) {
171 auto child_stats = PropagateStatistics(node_ptr&: join.children[child_idx]);
172 if (!child_stats) {
173 node_stats = nullptr;
174 } else if (node_stats) {
175 MultiplyCardinalities(stats&: node_stats, new_stats&: *child_stats);
176 }
177 }
178
179 auto join_type = join.join_type;
180 // depending on the join type, we might need to alter the statistics
181 // LEFT, FULL, RIGHT OUTER and SINGLE joins can introduce null values
182 // this requires us to alter the statistics after this point in the query plan
183 bool adds_null_on_left = IsRightOuterJoin(type: join_type);
184 bool adds_null_on_right = IsLeftOuterJoin(type: join_type) || join_type == JoinType::SINGLE;
185
186 vector<ColumnBinding> left_bindings, right_bindings;
187 if (adds_null_on_left) {
188 left_bindings = join.children[0]->GetColumnBindings();
189 }
190 if (adds_null_on_right) {
191 right_bindings = join.children[1]->GetColumnBindings();
192 }
193
194 // then propagate into the join conditions
195 switch (join.type) {
196 case LogicalOperatorType::LOGICAL_COMPARISON_JOIN:
197 case LogicalOperatorType::LOGICAL_DELIM_JOIN:
198 case LogicalOperatorType::LOGICAL_ASOF_JOIN:
199 PropagateStatistics(join&: join.Cast<LogicalComparisonJoin>(), node_ptr);
200 break;
201 case LogicalOperatorType::LOGICAL_ANY_JOIN:
202 PropagateStatistics(join&: join.Cast<LogicalAnyJoin>(), node_ptr);
203 break;
204 default:
205 break;
206 }
207
208 if (adds_null_on_right) {
209 // left or full outer join: set IsNull() to true for all rhs statistics
210 for (auto &binding : right_bindings) {
211 auto stats = statistics_map.find(x: binding);
212 if (stats != statistics_map.end()) {
213 stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES);
214 }
215 }
216 }
217 if (adds_null_on_left) {
218 // right or full outer join: set IsNull() to true for all lhs statistics
219 for (auto &binding : left_bindings) {
220 auto stats = statistics_map.find(x: binding);
221 if (stats != statistics_map.end()) {
222 stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES);
223 }
224 }
225 }
226 return std::move(node_stats);
227}
228
229static void MaxCardinalities(unique_ptr<NodeStatistics> &stats, NodeStatistics &new_stats) {
230 if (!stats->has_estimated_cardinality || !new_stats.has_estimated_cardinality || !stats->has_max_cardinality ||
231 !new_stats.has_max_cardinality) {
232 stats = nullptr;
233 return;
234 }
235 stats->estimated_cardinality = MaxValue<idx_t>(a: stats->estimated_cardinality, b: new_stats.estimated_cardinality);
236 stats->max_cardinality = MaxValue<idx_t>(a: stats->max_cardinality, b: new_stats.max_cardinality);
237}
238
239unique_ptr<NodeStatistics> StatisticsPropagator::PropagateStatistics(LogicalPositionalJoin &join,
240 unique_ptr<LogicalOperator> *node_ptr) {
241 D_ASSERT(join.type == LogicalOperatorType::LOGICAL_POSITIONAL_JOIN);
242
243 // first propagate through the children of the join
244 node_stats = PropagateStatistics(node_ptr&: join.children[0]);
245 for (idx_t child_idx = 1; child_idx < join.children.size(); child_idx++) {
246 auto child_stats = PropagateStatistics(node_ptr&: join.children[child_idx]);
247 if (!child_stats) {
248 node_stats = nullptr;
249 } else if (node_stats) {
250 if (!node_stats->has_estimated_cardinality || !child_stats->has_estimated_cardinality ||
251 !node_stats->has_max_cardinality || !child_stats->has_max_cardinality) {
252 node_stats = nullptr;
253 } else {
254 MaxCardinalities(stats&: node_stats, new_stats&: *child_stats);
255 }
256 }
257 }
258
259 // No conditions.
260
261 // Positional Joins are always FULL OUTER
262
263 // set IsNull() to true for all lhs statistics
264 auto left_bindings = join.children[0]->GetColumnBindings();
265 for (auto &binding : left_bindings) {
266 auto stats = statistics_map.find(x: binding);
267 if (stats != statistics_map.end()) {
268 stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES);
269 }
270 }
271
272 // set IsNull() to true for all rhs statistics
273 auto right_bindings = join.children[1]->GetColumnBindings();
274 for (auto &binding : right_bindings) {
275 auto stats = statistics_map.find(x: binding);
276 if (stats != statistics_map.end()) {
277 stats->second->Set(StatsInfo::CAN_HAVE_NULL_VALUES);
278 }
279 }
280
281 return std::move(node_stats);
282}
283
284} // namespace duckdb
285