1 | #include "duckdb/optimizer/cse_optimizer.hpp" |
2 | |
3 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
4 | #include "duckdb/planner/expression_iterator.hpp" |
5 | #include "duckdb/planner/operator/logical_filter.hpp" |
6 | #include "duckdb/planner/operator/logical_projection.hpp" |
7 | #include "duckdb/planner/column_binding_map.hpp" |
8 | #include "duckdb/planner/binder.hpp" |
9 | |
10 | namespace duckdb { |
11 | |
12 | //! The CSENode contains information about a common subexpression; how many times it occurs, and the column index in the |
13 | //! underlying projection |
14 | struct CSENode { |
15 | idx_t count; |
16 | idx_t column_index; |
17 | |
18 | CSENode() : count(1), column_index(DConstants::INVALID_INDEX) { |
19 | } |
20 | }; |
21 | |
22 | //! The CSEReplacementState |
23 | struct CSEReplacementState { |
24 | //! The projection index of the new projection |
25 | idx_t projection_index; |
26 | //! Map of expression -> CSENode |
27 | expression_map_t<CSENode> expression_count; |
28 | //! Map of column bindings to column indexes in the projection expression list |
29 | column_binding_map_t<idx_t> column_map; |
30 | //! The set of expressions of the resulting projection |
31 | vector<unique_ptr<Expression>> expressions; |
32 | //! Cached expressions that are kept around so the expression_map always contains valid expressions |
33 | vector<unique_ptr<Expression>> cached_expressions; |
34 | }; |
35 | |
36 | void CommonSubExpressionOptimizer::VisitOperator(LogicalOperator &op) { |
37 | switch (op.type) { |
38 | case LogicalOperatorType::LOGICAL_PROJECTION: |
39 | case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: |
40 | ExtractCommonSubExpresions(op); |
41 | break; |
42 | default: |
43 | break; |
44 | } |
45 | LogicalOperatorVisitor::VisitOperator(op); |
46 | } |
47 | |
48 | void CommonSubExpressionOptimizer::CountExpressions(Expression &expr, CSEReplacementState &state) { |
49 | // we only consider expressions with children for CSE elimination |
50 | switch (expr.expression_class) { |
51 | case ExpressionClass::BOUND_COLUMN_REF: |
52 | case ExpressionClass::BOUND_CONSTANT: |
53 | case ExpressionClass::BOUND_PARAMETER: |
54 | // skip conjunctions and case, since short-circuiting might be incorrectly disabled otherwise |
55 | case ExpressionClass::BOUND_CONJUNCTION: |
56 | case ExpressionClass::BOUND_CASE: |
57 | return; |
58 | default: |
59 | break; |
60 | } |
61 | if (expr.expression_class != ExpressionClass::BOUND_AGGREGATE && !expr.HasSideEffects()) { |
62 | // we can't move aggregates to a projection, so we only consider the children of the aggregate |
63 | auto node = state.expression_count.find(x: expr); |
64 | if (node == state.expression_count.end()) { |
65 | // first time we encounter this expression, insert this node with [count = 1] |
66 | state.expression_count[expr] = CSENode(); |
67 | } else { |
68 | // we encountered this expression before, increment the occurrence count |
69 | node->second.count++; |
70 | } |
71 | } |
72 | // recursively count the children |
73 | ExpressionIterator::EnumerateChildren(expression&: expr, callback: [&](Expression &child) { CountExpressions(expr&: child, state); }); |
74 | } |
75 | |
76 | void CommonSubExpressionOptimizer::PerformCSEReplacement(unique_ptr<Expression> &expr_ptr, CSEReplacementState &state) { |
77 | Expression &expr = *expr_ptr; |
78 | if (expr.expression_class == ExpressionClass::BOUND_COLUMN_REF) { |
79 | auto &bound_column_ref = expr.Cast<BoundColumnRefExpression>(); |
80 | // bound column ref, check if this one has already been recorded in the expression list |
81 | auto column_entry = state.column_map.find(x: bound_column_ref.binding); |
82 | if (column_entry == state.column_map.end()) { |
83 | // not there yet: push the expression |
84 | idx_t new_column_index = state.expressions.size(); |
85 | state.column_map[bound_column_ref.binding] = new_column_index; |
86 | state.expressions.push_back(x: make_uniq<BoundColumnRefExpression>( |
87 | args&: bound_column_ref.alias, args&: bound_column_ref.return_type, args&: bound_column_ref.binding)); |
88 | bound_column_ref.binding = ColumnBinding(state.projection_index, new_column_index); |
89 | } else { |
90 | // else: just update the column binding! |
91 | bound_column_ref.binding = ColumnBinding(state.projection_index, column_entry->second); |
92 | } |
93 | return; |
94 | } |
95 | // check if this child is eligible for CSE elimination |
96 | bool can_cse = expr.expression_class != ExpressionClass::BOUND_CONJUNCTION && |
97 | expr.expression_class != ExpressionClass::BOUND_CASE; |
98 | if (can_cse && state.expression_count.find(x: expr) != state.expression_count.end()) { |
99 | auto &node = state.expression_count[expr]; |
100 | if (node.count > 1) { |
101 | // this expression occurs more than once! push it into the projection |
102 | // check if it has already been pushed into the projection |
103 | auto alias = expr.alias; |
104 | auto type = expr.return_type; |
105 | if (node.column_index == DConstants::INVALID_INDEX) { |
106 | // has not been pushed yet: push it |
107 | node.column_index = state.expressions.size(); |
108 | state.expressions.push_back(x: std::move(expr_ptr)); |
109 | } else { |
110 | state.cached_expressions.push_back(x: std::move(expr_ptr)); |
111 | } |
112 | // replace the original expression with a bound column ref |
113 | expr_ptr = make_uniq<BoundColumnRefExpression>(args&: alias, args&: type, |
114 | args: ColumnBinding(state.projection_index, node.column_index)); |
115 | return; |
116 | } |
117 | } |
118 | // this expression only occurs once, we can't perform CSE elimination |
119 | // look into the children to see if we can replace them |
120 | ExpressionIterator::EnumerateChildren(expression&: expr, |
121 | callback: [&](unique_ptr<Expression> &child) { PerformCSEReplacement(expr_ptr&: child, state); }); |
122 | } |
123 | |
124 | void CommonSubExpressionOptimizer::(LogicalOperator &op) { |
125 | D_ASSERT(op.children.size() == 1); |
126 | |
127 | // first we count for each expression with children how many types it occurs |
128 | CSEReplacementState state; |
129 | LogicalOperatorVisitor::EnumerateExpressions( |
130 | op, callback: [&](unique_ptr<Expression> *child) { CountExpressions(expr&: **child, state); }); |
131 | // check if there are any expressions to extract |
132 | bool perform_replacement = false; |
133 | for (auto &expr : state.expression_count) { |
134 | if (expr.second.count > 1) { |
135 | perform_replacement = true; |
136 | break; |
137 | } |
138 | } |
139 | if (!perform_replacement) { |
140 | // no CSEs to extract |
141 | return; |
142 | } |
143 | state.projection_index = binder.GenerateTableIndex(); |
144 | // we found common subexpressions to extract |
145 | // now we iterate over all the expressions and perform the actual CSE elimination |
146 | |
147 | LogicalOperatorVisitor::EnumerateExpressions( |
148 | op, callback: [&](unique_ptr<Expression> *child) { PerformCSEReplacement(expr_ptr&: *child, state); }); |
149 | D_ASSERT(state.expressions.size() > 0); |
150 | // create a projection node as the child of this node |
151 | auto projection = make_uniq<LogicalProjection>(args&: state.projection_index, args: std::move(state.expressions)); |
152 | projection->children.push_back(x: std::move(op.children[0])); |
153 | op.children[0] = std::move(projection); |
154 | } |
155 | |
156 | } // namespace duckdb |
157 | |