1 | #include "duckdb/optimizer/statistics_propagator.hpp" |
2 | |
3 | #include "duckdb/main/client_context.hpp" |
4 | #include "duckdb/planner/expression_iterator.hpp" |
5 | #include "duckdb/planner/logical_operator.hpp" |
6 | #include "duckdb/planner/operator/logical_aggregate.hpp" |
7 | #include "duckdb/planner/operator/logical_empty_result.hpp" |
8 | #include "duckdb/planner/operator/logical_cross_product.hpp" |
9 | #include "duckdb/planner/operator/logical_filter.hpp" |
10 | #include "duckdb/planner/operator/logical_get.hpp" |
11 | #include "duckdb/planner/operator/logical_join.hpp" |
12 | #include "duckdb/planner/operator/logical_projection.hpp" |
13 | #include "duckdb/planner/operator/logical_positional_join.hpp" |
14 | #include "duckdb/planner/operator/logical_set_operation.hpp" |
15 | #include "duckdb/planner/operator/logical_order.hpp" |
16 | #include "duckdb/planner/operator/logical_window.hpp" |
17 | #include "duckdb/planner/expression/list.hpp" |
18 | |
19 | namespace duckdb { |
20 | |
21 | StatisticsPropagator::StatisticsPropagator(ClientContext &context) : context(context) { |
22 | } |
23 | |
24 | void StatisticsPropagator::ReplaceWithEmptyResult(unique_ptr<LogicalOperator> &node) { |
25 | node = make_uniq<LogicalEmptyResult>(args: std::move(node)); |
26 | } |
27 | |
28 | unique_ptr<NodeStatistics> StatisticsPropagator::PropagateChildren(LogicalOperator &node, |
29 | unique_ptr<LogicalOperator> *node_ptr) { |
30 | for (idx_t child_idx = 0; child_idx < node.children.size(); child_idx++) { |
31 | PropagateStatistics(node_ptr&: node.children[child_idx]); |
32 | } |
33 | return nullptr; |
34 | } |
35 | |
36 | unique_ptr<NodeStatistics> StatisticsPropagator::PropagateStatistics(LogicalOperator &node, |
37 | unique_ptr<LogicalOperator> *node_ptr) { |
38 | switch (node.type) { |
39 | case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: |
40 | return PropagateStatistics(op&: node.Cast<LogicalAggregate>(), node_ptr); |
41 | case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: |
42 | return PropagateStatistics(op&: node.Cast<LogicalCrossProduct>(), node_ptr); |
43 | case LogicalOperatorType::LOGICAL_FILTER: |
44 | return PropagateStatistics(op&: node.Cast<LogicalFilter>(), node_ptr); |
45 | case LogicalOperatorType::LOGICAL_GET: |
46 | return PropagateStatistics(op&: node.Cast<LogicalGet>(), node_ptr); |
47 | case LogicalOperatorType::LOGICAL_PROJECTION: |
48 | return PropagateStatistics(op&: node.Cast<LogicalProjection>(), node_ptr); |
49 | case LogicalOperatorType::LOGICAL_ANY_JOIN: |
50 | case LogicalOperatorType::LOGICAL_ASOF_JOIN: |
51 | case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: |
52 | case LogicalOperatorType::LOGICAL_JOIN: |
53 | case LogicalOperatorType::LOGICAL_DELIM_JOIN: |
54 | return PropagateStatistics(op&: node.Cast<LogicalJoin>(), node_ptr); |
55 | case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: |
56 | return PropagateStatistics(op&: node.Cast<LogicalPositionalJoin>(), node_ptr); |
57 | case LogicalOperatorType::LOGICAL_UNION: |
58 | case LogicalOperatorType::LOGICAL_EXCEPT: |
59 | case LogicalOperatorType::LOGICAL_INTERSECT: |
60 | return PropagateStatistics(op&: node.Cast<LogicalSetOperation>(), node_ptr); |
61 | case LogicalOperatorType::LOGICAL_ORDER_BY: |
62 | return PropagateStatistics(op&: node.Cast<LogicalOrder>(), node_ptr); |
63 | case LogicalOperatorType::LOGICAL_WINDOW: |
64 | return PropagateStatistics(op&: node.Cast<LogicalWindow>(), node_ptr); |
65 | default: |
66 | return PropagateChildren(node, node_ptr); |
67 | } |
68 | } |
69 | |
70 | unique_ptr<NodeStatistics> StatisticsPropagator::PropagateStatistics(unique_ptr<LogicalOperator> &node_ptr) { |
71 | return PropagateStatistics(node&: *node_ptr, node_ptr: &node_ptr); |
72 | } |
73 | |
74 | unique_ptr<BaseStatistics> StatisticsPropagator::PropagateExpression(Expression &expr, |
75 | unique_ptr<Expression> *expr_ptr) { |
76 | switch (expr.GetExpressionClass()) { |
77 | case ExpressionClass::BOUND_AGGREGATE: |
78 | return PropagateExpression(expr&: expr.Cast<BoundAggregateExpression>(), expr_ptr); |
79 | case ExpressionClass::BOUND_BETWEEN: |
80 | return PropagateExpression(expr&: expr.Cast<BoundBetweenExpression>(), expr_ptr); |
81 | case ExpressionClass::BOUND_CASE: |
82 | return PropagateExpression(expr&: expr.Cast<BoundCaseExpression>(), expr_ptr); |
83 | case ExpressionClass::BOUND_CONJUNCTION: |
84 | return PropagateExpression(expr&: expr.Cast<BoundConjunctionExpression>(), expr_ptr); |
85 | case ExpressionClass::BOUND_FUNCTION: |
86 | return PropagateExpression(expr&: expr.Cast<BoundFunctionExpression>(), expr_ptr); |
87 | case ExpressionClass::BOUND_CAST: |
88 | return PropagateExpression(expr&: expr.Cast<BoundCastExpression>(), expr_ptr); |
89 | case ExpressionClass::BOUND_COMPARISON: |
90 | return PropagateExpression(expr&: expr.Cast<BoundComparisonExpression>(), expr_ptr); |
91 | case ExpressionClass::BOUND_CONSTANT: |
92 | return PropagateExpression(expr&: expr.Cast<BoundConstantExpression>(), expr_ptr); |
93 | case ExpressionClass::BOUND_COLUMN_REF: |
94 | return PropagateExpression(expr&: expr.Cast<BoundColumnRefExpression>(), expr_ptr); |
95 | case ExpressionClass::BOUND_OPERATOR: |
96 | return PropagateExpression(expr&: expr.Cast<BoundOperatorExpression>(), expr_ptr); |
97 | default: |
98 | break; |
99 | } |
100 | ExpressionIterator::EnumerateChildren(expression&: expr, callback: [&](unique_ptr<Expression> &child) { PropagateExpression(expr&: child); }); |
101 | return nullptr; |
102 | } |
103 | |
104 | unique_ptr<BaseStatistics> StatisticsPropagator::PropagateExpression(unique_ptr<Expression> &expr) { |
105 | auto stats = PropagateExpression(expr&: *expr, expr_ptr: &expr); |
106 | if (ClientConfig::GetConfig(context).query_verification_enabled && stats) { |
107 | expr->verification_stats = stats->ToUnique(); |
108 | } |
109 | return stats; |
110 | } |
111 | |
112 | } // namespace duckdb |
113 | |