1#include "duckdb/optimizer/statistics_propagator.hpp"
2
3#include "duckdb/main/client_context.hpp"
4#include "duckdb/planner/expression_iterator.hpp"
5#include "duckdb/planner/logical_operator.hpp"
6#include "duckdb/planner/operator/logical_aggregate.hpp"
7#include "duckdb/planner/operator/logical_empty_result.hpp"
8#include "duckdb/planner/operator/logical_cross_product.hpp"
9#include "duckdb/planner/operator/logical_filter.hpp"
10#include "duckdb/planner/operator/logical_get.hpp"
11#include "duckdb/planner/operator/logical_join.hpp"
12#include "duckdb/planner/operator/logical_projection.hpp"
13#include "duckdb/planner/operator/logical_positional_join.hpp"
14#include "duckdb/planner/operator/logical_set_operation.hpp"
15#include "duckdb/planner/operator/logical_order.hpp"
16#include "duckdb/planner/operator/logical_window.hpp"
17#include "duckdb/planner/expression/list.hpp"
18
19namespace duckdb {
20
21StatisticsPropagator::StatisticsPropagator(ClientContext &context) : context(context) {
22}
23
24void StatisticsPropagator::ReplaceWithEmptyResult(unique_ptr<LogicalOperator> &node) {
25 node = make_uniq<LogicalEmptyResult>(args: std::move(node));
26}
27
28unique_ptr<NodeStatistics> StatisticsPropagator::PropagateChildren(LogicalOperator &node,
29 unique_ptr<LogicalOperator> *node_ptr) {
30 for (idx_t child_idx = 0; child_idx < node.children.size(); child_idx++) {
31 PropagateStatistics(node_ptr&: node.children[child_idx]);
32 }
33 return nullptr;
34}
35
36unique_ptr<NodeStatistics> StatisticsPropagator::PropagateStatistics(LogicalOperator &node,
37 unique_ptr<LogicalOperator> *node_ptr) {
38 switch (node.type) {
39 case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY:
40 return PropagateStatistics(op&: node.Cast<LogicalAggregate>(), node_ptr);
41 case LogicalOperatorType::LOGICAL_CROSS_PRODUCT:
42 return PropagateStatistics(op&: node.Cast<LogicalCrossProduct>(), node_ptr);
43 case LogicalOperatorType::LOGICAL_FILTER:
44 return PropagateStatistics(op&: node.Cast<LogicalFilter>(), node_ptr);
45 case LogicalOperatorType::LOGICAL_GET:
46 return PropagateStatistics(op&: node.Cast<LogicalGet>(), node_ptr);
47 case LogicalOperatorType::LOGICAL_PROJECTION:
48 return PropagateStatistics(op&: node.Cast<LogicalProjection>(), node_ptr);
49 case LogicalOperatorType::LOGICAL_ANY_JOIN:
50 case LogicalOperatorType::LOGICAL_ASOF_JOIN:
51 case LogicalOperatorType::LOGICAL_COMPARISON_JOIN:
52 case LogicalOperatorType::LOGICAL_JOIN:
53 case LogicalOperatorType::LOGICAL_DELIM_JOIN:
54 return PropagateStatistics(op&: node.Cast<LogicalJoin>(), node_ptr);
55 case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN:
56 return PropagateStatistics(op&: node.Cast<LogicalPositionalJoin>(), node_ptr);
57 case LogicalOperatorType::LOGICAL_UNION:
58 case LogicalOperatorType::LOGICAL_EXCEPT:
59 case LogicalOperatorType::LOGICAL_INTERSECT:
60 return PropagateStatistics(op&: node.Cast<LogicalSetOperation>(), node_ptr);
61 case LogicalOperatorType::LOGICAL_ORDER_BY:
62 return PropagateStatistics(op&: node.Cast<LogicalOrder>(), node_ptr);
63 case LogicalOperatorType::LOGICAL_WINDOW:
64 return PropagateStatistics(op&: node.Cast<LogicalWindow>(), node_ptr);
65 default:
66 return PropagateChildren(node, node_ptr);
67 }
68}
69
70unique_ptr<NodeStatistics> StatisticsPropagator::PropagateStatistics(unique_ptr<LogicalOperator> &node_ptr) {
71 return PropagateStatistics(node&: *node_ptr, node_ptr: &node_ptr);
72}
73
74unique_ptr<BaseStatistics> StatisticsPropagator::PropagateExpression(Expression &expr,
75 unique_ptr<Expression> *expr_ptr) {
76 switch (expr.GetExpressionClass()) {
77 case ExpressionClass::BOUND_AGGREGATE:
78 return PropagateExpression(expr&: expr.Cast<BoundAggregateExpression>(), expr_ptr);
79 case ExpressionClass::BOUND_BETWEEN:
80 return PropagateExpression(expr&: expr.Cast<BoundBetweenExpression>(), expr_ptr);
81 case ExpressionClass::BOUND_CASE:
82 return PropagateExpression(expr&: expr.Cast<BoundCaseExpression>(), expr_ptr);
83 case ExpressionClass::BOUND_CONJUNCTION:
84 return PropagateExpression(expr&: expr.Cast<BoundConjunctionExpression>(), expr_ptr);
85 case ExpressionClass::BOUND_FUNCTION:
86 return PropagateExpression(expr&: expr.Cast<BoundFunctionExpression>(), expr_ptr);
87 case ExpressionClass::BOUND_CAST:
88 return PropagateExpression(expr&: expr.Cast<BoundCastExpression>(), expr_ptr);
89 case ExpressionClass::BOUND_COMPARISON:
90 return PropagateExpression(expr&: expr.Cast<BoundComparisonExpression>(), expr_ptr);
91 case ExpressionClass::BOUND_CONSTANT:
92 return PropagateExpression(expr&: expr.Cast<BoundConstantExpression>(), expr_ptr);
93 case ExpressionClass::BOUND_COLUMN_REF:
94 return PropagateExpression(expr&: expr.Cast<BoundColumnRefExpression>(), expr_ptr);
95 case ExpressionClass::BOUND_OPERATOR:
96 return PropagateExpression(expr&: expr.Cast<BoundOperatorExpression>(), expr_ptr);
97 default:
98 break;
99 }
100 ExpressionIterator::EnumerateChildren(expression&: expr, callback: [&](unique_ptr<Expression> &child) { PropagateExpression(expr&: child); });
101 return nullptr;
102}
103
104unique_ptr<BaseStatistics> StatisticsPropagator::PropagateExpression(unique_ptr<Expression> &expr) {
105 auto stats = PropagateExpression(expr&: *expr, expr_ptr: &expr);
106 if (ClientConfig::GetConfig(context).query_verification_enabled && stats) {
107 expr->verification_stats = stats->ToUnique();
108 }
109 return stats;
110}
111
112} // namespace duckdb
113