1#include "duckdb/optimizer/cse_optimizer.hpp"
2
3#include "duckdb/planner/expression/bound_columnref_expression.hpp"
4#include "duckdb/planner/expression_iterator.hpp"
5#include "duckdb/planner/operator/logical_filter.hpp"
6#include "duckdb/planner/operator/logical_projection.hpp"
7#include "duckdb/planner/column_binding_map.hpp"
8#include "duckdb/planner/binder.hpp"
9
10namespace duckdb {
11
12//! The CSENode contains information about a common subexpression; how many times it occurs, and the column index in the
13//! underlying projection
14struct CSENode {
15 idx_t count;
16 idx_t column_index;
17
18 CSENode() : count(1), column_index(DConstants::INVALID_INDEX) {
19 }
20};
21
22//! The CSEReplacementState
23struct CSEReplacementState {
24 //! The projection index of the new projection
25 idx_t projection_index;
26 //! Map of expression -> CSENode
27 expression_map_t<CSENode> expression_count;
28 //! Map of column bindings to column indexes in the projection expression list
29 column_binding_map_t<idx_t> column_map;
30 //! The set of expressions of the resulting projection
31 vector<unique_ptr<Expression>> expressions;
32 //! Cached expressions that are kept around so the expression_map always contains valid expressions
33 vector<unique_ptr<Expression>> cached_expressions;
34};
35
36void CommonSubExpressionOptimizer::VisitOperator(LogicalOperator &op) {
37 switch (op.type) {
38 case LogicalOperatorType::LOGICAL_PROJECTION:
39 case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY:
40 ExtractCommonSubExpresions(op);
41 break;
42 default:
43 break;
44 }
45 LogicalOperatorVisitor::VisitOperator(op);
46}
47
48void CommonSubExpressionOptimizer::CountExpressions(Expression &expr, CSEReplacementState &state) {
49 // we only consider expressions with children for CSE elimination
50 switch (expr.expression_class) {
51 case ExpressionClass::BOUND_COLUMN_REF:
52 case ExpressionClass::BOUND_CONSTANT:
53 case ExpressionClass::BOUND_PARAMETER:
54 // skip conjunctions and case, since short-circuiting might be incorrectly disabled otherwise
55 case ExpressionClass::BOUND_CONJUNCTION:
56 case ExpressionClass::BOUND_CASE:
57 return;
58 default:
59 break;
60 }
61 if (expr.expression_class != ExpressionClass::BOUND_AGGREGATE && !expr.HasSideEffects()) {
62 // we can't move aggregates to a projection, so we only consider the children of the aggregate
63 auto node = state.expression_count.find(x: expr);
64 if (node == state.expression_count.end()) {
65 // first time we encounter this expression, insert this node with [count = 1]
66 state.expression_count[expr] = CSENode();
67 } else {
68 // we encountered this expression before, increment the occurrence count
69 node->second.count++;
70 }
71 }
72 // recursively count the children
73 ExpressionIterator::EnumerateChildren(expression&: expr, callback: [&](Expression &child) { CountExpressions(expr&: child, state); });
74}
75
76void CommonSubExpressionOptimizer::PerformCSEReplacement(unique_ptr<Expression> &expr_ptr, CSEReplacementState &state) {
77 Expression &expr = *expr_ptr;
78 if (expr.expression_class == ExpressionClass::BOUND_COLUMN_REF) {
79 auto &bound_column_ref = expr.Cast<BoundColumnRefExpression>();
80 // bound column ref, check if this one has already been recorded in the expression list
81 auto column_entry = state.column_map.find(x: bound_column_ref.binding);
82 if (column_entry == state.column_map.end()) {
83 // not there yet: push the expression
84 idx_t new_column_index = state.expressions.size();
85 state.column_map[bound_column_ref.binding] = new_column_index;
86 state.expressions.push_back(x: make_uniq<BoundColumnRefExpression>(
87 args&: bound_column_ref.alias, args&: bound_column_ref.return_type, args&: bound_column_ref.binding));
88 bound_column_ref.binding = ColumnBinding(state.projection_index, new_column_index);
89 } else {
90 // else: just update the column binding!
91 bound_column_ref.binding = ColumnBinding(state.projection_index, column_entry->second);
92 }
93 return;
94 }
95 // check if this child is eligible for CSE elimination
96 bool can_cse = expr.expression_class != ExpressionClass::BOUND_CONJUNCTION &&
97 expr.expression_class != ExpressionClass::BOUND_CASE;
98 if (can_cse && state.expression_count.find(x: expr) != state.expression_count.end()) {
99 auto &node = state.expression_count[expr];
100 if (node.count > 1) {
101 // this expression occurs more than once! push it into the projection
102 // check if it has already been pushed into the projection
103 auto alias = expr.alias;
104 auto type = expr.return_type;
105 if (node.column_index == DConstants::INVALID_INDEX) {
106 // has not been pushed yet: push it
107 node.column_index = state.expressions.size();
108 state.expressions.push_back(x: std::move(expr_ptr));
109 } else {
110 state.cached_expressions.push_back(x: std::move(expr_ptr));
111 }
112 // replace the original expression with a bound column ref
113 expr_ptr = make_uniq<BoundColumnRefExpression>(args&: alias, args&: type,
114 args: ColumnBinding(state.projection_index, node.column_index));
115 return;
116 }
117 }
118 // this expression only occurs once, we can't perform CSE elimination
119 // look into the children to see if we can replace them
120 ExpressionIterator::EnumerateChildren(expression&: expr,
121 callback: [&](unique_ptr<Expression> &child) { PerformCSEReplacement(expr_ptr&: child, state); });
122}
123
124void CommonSubExpressionOptimizer::ExtractCommonSubExpresions(LogicalOperator &op) {
125 D_ASSERT(op.children.size() == 1);
126
127 // first we count for each expression with children how many types it occurs
128 CSEReplacementState state;
129 LogicalOperatorVisitor::EnumerateExpressions(
130 op, callback: [&](unique_ptr<Expression> *child) { CountExpressions(expr&: **child, state); });
131 // check if there are any expressions to extract
132 bool perform_replacement = false;
133 for (auto &expr : state.expression_count) {
134 if (expr.second.count > 1) {
135 perform_replacement = true;
136 break;
137 }
138 }
139 if (!perform_replacement) {
140 // no CSEs to extract
141 return;
142 }
143 state.projection_index = binder.GenerateTableIndex();
144 // we found common subexpressions to extract
145 // now we iterate over all the expressions and perform the actual CSE elimination
146
147 LogicalOperatorVisitor::EnumerateExpressions(
148 op, callback: [&](unique_ptr<Expression> *child) { PerformCSEReplacement(expr_ptr&: *child, state); });
149 D_ASSERT(state.expressions.size() > 0);
150 // create a projection node as the child of this node
151 auto projection = make_uniq<LogicalProjection>(args&: state.projection_index, args: std::move(state.expressions));
152 projection->children.push_back(x: std::move(op.children[0]));
153 op.children[0] = std::move(projection);
154}
155
156} // namespace duckdb
157