| 1 | #include "duckdb/optimizer/cse_optimizer.hpp" |
| 2 | |
| 3 | #include "duckdb/planner/expression/bound_columnref_expression.hpp" |
| 4 | #include "duckdb/planner/expression_iterator.hpp" |
| 5 | #include "duckdb/planner/operator/logical_filter.hpp" |
| 6 | #include "duckdb/planner/operator/logical_projection.hpp" |
| 7 | #include "duckdb/planner/column_binding_map.hpp" |
| 8 | #include "duckdb/planner/binder.hpp" |
| 9 | |
| 10 | namespace duckdb { |
| 11 | |
| 12 | //! The CSENode contains information about a common subexpression; how many times it occurs, and the column index in the |
| 13 | //! underlying projection |
| 14 | struct CSENode { |
| 15 | idx_t count; |
| 16 | idx_t column_index; |
| 17 | |
| 18 | CSENode() : count(1), column_index(DConstants::INVALID_INDEX) { |
| 19 | } |
| 20 | }; |
| 21 | |
| 22 | //! The CSEReplacementState |
| 23 | struct CSEReplacementState { |
| 24 | //! The projection index of the new projection |
| 25 | idx_t projection_index; |
| 26 | //! Map of expression -> CSENode |
| 27 | expression_map_t<CSENode> expression_count; |
| 28 | //! Map of column bindings to column indexes in the projection expression list |
| 29 | column_binding_map_t<idx_t> column_map; |
| 30 | //! The set of expressions of the resulting projection |
| 31 | vector<unique_ptr<Expression>> expressions; |
| 32 | //! Cached expressions that are kept around so the expression_map always contains valid expressions |
| 33 | vector<unique_ptr<Expression>> cached_expressions; |
| 34 | }; |
| 35 | |
| 36 | void CommonSubExpressionOptimizer::VisitOperator(LogicalOperator &op) { |
| 37 | switch (op.type) { |
| 38 | case LogicalOperatorType::LOGICAL_PROJECTION: |
| 39 | case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: |
| 40 | ExtractCommonSubExpresions(op); |
| 41 | break; |
| 42 | default: |
| 43 | break; |
| 44 | } |
| 45 | LogicalOperatorVisitor::VisitOperator(op); |
| 46 | } |
| 47 | |
| 48 | void CommonSubExpressionOptimizer::CountExpressions(Expression &expr, CSEReplacementState &state) { |
| 49 | // we only consider expressions with children for CSE elimination |
| 50 | switch (expr.expression_class) { |
| 51 | case ExpressionClass::BOUND_COLUMN_REF: |
| 52 | case ExpressionClass::BOUND_CONSTANT: |
| 53 | case ExpressionClass::BOUND_PARAMETER: |
| 54 | // skip conjunctions and case, since short-circuiting might be incorrectly disabled otherwise |
| 55 | case ExpressionClass::BOUND_CONJUNCTION: |
| 56 | case ExpressionClass::BOUND_CASE: |
| 57 | return; |
| 58 | default: |
| 59 | break; |
| 60 | } |
| 61 | if (expr.expression_class != ExpressionClass::BOUND_AGGREGATE && !expr.HasSideEffects()) { |
| 62 | // we can't move aggregates to a projection, so we only consider the children of the aggregate |
| 63 | auto node = state.expression_count.find(x: expr); |
| 64 | if (node == state.expression_count.end()) { |
| 65 | // first time we encounter this expression, insert this node with [count = 1] |
| 66 | state.expression_count[expr] = CSENode(); |
| 67 | } else { |
| 68 | // we encountered this expression before, increment the occurrence count |
| 69 | node->second.count++; |
| 70 | } |
| 71 | } |
| 72 | // recursively count the children |
| 73 | ExpressionIterator::EnumerateChildren(expression&: expr, callback: [&](Expression &child) { CountExpressions(expr&: child, state); }); |
| 74 | } |
| 75 | |
| 76 | void CommonSubExpressionOptimizer::PerformCSEReplacement(unique_ptr<Expression> &expr_ptr, CSEReplacementState &state) { |
| 77 | Expression &expr = *expr_ptr; |
| 78 | if (expr.expression_class == ExpressionClass::BOUND_COLUMN_REF) { |
| 79 | auto &bound_column_ref = expr.Cast<BoundColumnRefExpression>(); |
| 80 | // bound column ref, check if this one has already been recorded in the expression list |
| 81 | auto column_entry = state.column_map.find(x: bound_column_ref.binding); |
| 82 | if (column_entry == state.column_map.end()) { |
| 83 | // not there yet: push the expression |
| 84 | idx_t new_column_index = state.expressions.size(); |
| 85 | state.column_map[bound_column_ref.binding] = new_column_index; |
| 86 | state.expressions.push_back(x: make_uniq<BoundColumnRefExpression>( |
| 87 | args&: bound_column_ref.alias, args&: bound_column_ref.return_type, args&: bound_column_ref.binding)); |
| 88 | bound_column_ref.binding = ColumnBinding(state.projection_index, new_column_index); |
| 89 | } else { |
| 90 | // else: just update the column binding! |
| 91 | bound_column_ref.binding = ColumnBinding(state.projection_index, column_entry->second); |
| 92 | } |
| 93 | return; |
| 94 | } |
| 95 | // check if this child is eligible for CSE elimination |
| 96 | bool can_cse = expr.expression_class != ExpressionClass::BOUND_CONJUNCTION && |
| 97 | expr.expression_class != ExpressionClass::BOUND_CASE; |
| 98 | if (can_cse && state.expression_count.find(x: expr) != state.expression_count.end()) { |
| 99 | auto &node = state.expression_count[expr]; |
| 100 | if (node.count > 1) { |
| 101 | // this expression occurs more than once! push it into the projection |
| 102 | // check if it has already been pushed into the projection |
| 103 | auto alias = expr.alias; |
| 104 | auto type = expr.return_type; |
| 105 | if (node.column_index == DConstants::INVALID_INDEX) { |
| 106 | // has not been pushed yet: push it |
| 107 | node.column_index = state.expressions.size(); |
| 108 | state.expressions.push_back(x: std::move(expr_ptr)); |
| 109 | } else { |
| 110 | state.cached_expressions.push_back(x: std::move(expr_ptr)); |
| 111 | } |
| 112 | // replace the original expression with a bound column ref |
| 113 | expr_ptr = make_uniq<BoundColumnRefExpression>(args&: alias, args&: type, |
| 114 | args: ColumnBinding(state.projection_index, node.column_index)); |
| 115 | return; |
| 116 | } |
| 117 | } |
| 118 | // this expression only occurs once, we can't perform CSE elimination |
| 119 | // look into the children to see if we can replace them |
| 120 | ExpressionIterator::EnumerateChildren(expression&: expr, |
| 121 | callback: [&](unique_ptr<Expression> &child) { PerformCSEReplacement(expr_ptr&: child, state); }); |
| 122 | } |
| 123 | |
| 124 | void CommonSubExpressionOptimizer::(LogicalOperator &op) { |
| 125 | D_ASSERT(op.children.size() == 1); |
| 126 | |
| 127 | // first we count for each expression with children how many types it occurs |
| 128 | CSEReplacementState state; |
| 129 | LogicalOperatorVisitor::EnumerateExpressions( |
| 130 | op, callback: [&](unique_ptr<Expression> *child) { CountExpressions(expr&: **child, state); }); |
| 131 | // check if there are any expressions to extract |
| 132 | bool perform_replacement = false; |
| 133 | for (auto &expr : state.expression_count) { |
| 134 | if (expr.second.count > 1) { |
| 135 | perform_replacement = true; |
| 136 | break; |
| 137 | } |
| 138 | } |
| 139 | if (!perform_replacement) { |
| 140 | // no CSEs to extract |
| 141 | return; |
| 142 | } |
| 143 | state.projection_index = binder.GenerateTableIndex(); |
| 144 | // we found common subexpressions to extract |
| 145 | // now we iterate over all the expressions and perform the actual CSE elimination |
| 146 | |
| 147 | LogicalOperatorVisitor::EnumerateExpressions( |
| 148 | op, callback: [&](unique_ptr<Expression> *child) { PerformCSEReplacement(expr_ptr&: *child, state); }); |
| 149 | D_ASSERT(state.expressions.size() > 0); |
| 150 | // create a projection node as the child of this node |
| 151 | auto projection = make_uniq<LogicalProjection>(args&: state.projection_index, args: std::move(state.expressions)); |
| 152 | projection->children.push_back(x: std::move(op.children[0])); |
| 153 | op.children[0] = std::move(projection); |
| 154 | } |
| 155 | |
| 156 | } // namespace duckdb |
| 157 | |