1 | #include "duckdb/optimizer/expression_heuristics.hpp" |
2 | #include "duckdb/planner/expression/list.hpp" |
3 | |
4 | namespace duckdb { |
5 | |
6 | unique_ptr<LogicalOperator> ExpressionHeuristics::Rewrite(unique_ptr<LogicalOperator> op) { |
7 | VisitOperator(op&: *op); |
8 | return op; |
9 | } |
10 | |
11 | void ExpressionHeuristics::VisitOperator(LogicalOperator &op) { |
12 | if (op.type == LogicalOperatorType::LOGICAL_FILTER) { |
13 | // reorder all filter expressions |
14 | if (op.expressions.size() > 1) { |
15 | ReorderExpressions(expressions&: op.expressions); |
16 | } |
17 | } |
18 | |
19 | // traverse recursively through the operator tree |
20 | VisitOperatorChildren(op); |
21 | VisitOperatorExpressions(op); |
22 | } |
23 | |
24 | unique_ptr<Expression> ExpressionHeuristics::VisitReplace(BoundConjunctionExpression &expr, |
25 | unique_ptr<Expression> *expr_ptr) { |
26 | ReorderExpressions(expressions&: expr.children); |
27 | return nullptr; |
28 | } |
29 | |
30 | void ExpressionHeuristics::ReorderExpressions(vector<unique_ptr<Expression>> &expressions) { |
31 | |
32 | struct ExpressionCosts { |
33 | unique_ptr<Expression> expr; |
34 | idx_t cost; |
35 | |
36 | bool operator==(const ExpressionCosts &p) const { |
37 | return cost == p.cost; |
38 | } |
39 | bool operator<(const ExpressionCosts &p) const { |
40 | return cost < p.cost; |
41 | } |
42 | }; |
43 | |
44 | vector<ExpressionCosts> expression_costs; |
45 | expression_costs.reserve(n: expressions.size()); |
46 | // iterate expressions, get cost for each one |
47 | for (idx_t i = 0; i < expressions.size(); i++) { |
48 | idx_t cost = Cost(expr&: *expressions[i]); |
49 | expression_costs.push_back(x: {.expr: std::move(expressions[i]), .cost: cost}); |
50 | } |
51 | |
52 | // sort by cost and put back in place |
53 | sort(first: expression_costs.begin(), last: expression_costs.end()); |
54 | for (idx_t i = 0; i < expression_costs.size(); i++) { |
55 | expressions[i] = std::move(expression_costs[i].expr); |
56 | } |
57 | } |
58 | |
59 | idx_t ExpressionHeuristics::ExpressionCost(BoundBetweenExpression &expr) { |
60 | return Cost(expr&: *expr.input) + Cost(expr&: *expr.lower) + Cost(expr&: *expr.upper) + 10; |
61 | } |
62 | |
63 | idx_t ExpressionHeuristics::ExpressionCost(BoundCaseExpression &expr) { |
64 | // CASE WHEN check THEN result_if_true ELSE result_if_false END |
65 | idx_t case_cost = 0; |
66 | for (auto &case_check : expr.case_checks) { |
67 | case_cost += Cost(expr&: *case_check.then_expr); |
68 | case_cost += Cost(expr&: *case_check.when_expr); |
69 | } |
70 | case_cost += Cost(expr&: *expr.else_expr); |
71 | return case_cost; |
72 | } |
73 | |
74 | idx_t ExpressionHeuristics::ExpressionCost(BoundCastExpression &expr) { |
75 | // OPERATOR_CAST |
76 | // determine cast cost by comparing cast_expr.source_type and cast_expr_target_type |
77 | idx_t cast_cost = 0; |
78 | if (expr.return_type != expr.source_type()) { |
79 | // if cast from or to varchar |
80 | // TODO: we might want to add more cases |
81 | if (expr.return_type.id() == LogicalTypeId::VARCHAR || expr.source_type().id() == LogicalTypeId::VARCHAR || |
82 | expr.return_type.id() == LogicalTypeId::BLOB || expr.source_type().id() == LogicalTypeId::BLOB) { |
83 | cast_cost = 200; |
84 | } else { |
85 | cast_cost = 5; |
86 | } |
87 | } |
88 | return Cost(expr&: *expr.child) + cast_cost; |
89 | } |
90 | |
91 | idx_t ExpressionHeuristics::ExpressionCost(BoundComparisonExpression &expr) { |
92 | // COMPARE_EQUAL, COMPARE_NOTEQUAL, COMPARE_GREATERTHAN, COMPARE_GREATERTHANOREQUALTO, COMPARE_LESSTHAN, |
93 | // COMPARE_LESSTHANOREQUALTO |
94 | return Cost(expr&: *expr.left) + 5 + Cost(expr&: *expr.right); |
95 | } |
96 | |
97 | idx_t ExpressionHeuristics::ExpressionCost(BoundConjunctionExpression &expr) { |
98 | // CONJUNCTION_AND, CONJUNCTION_OR |
99 | idx_t cost = 5; |
100 | for (auto &child : expr.children) { |
101 | cost += Cost(expr&: *child); |
102 | } |
103 | return cost; |
104 | } |
105 | |
106 | idx_t ExpressionHeuristics::ExpressionCost(BoundFunctionExpression &expr) { |
107 | idx_t cost_children = 0; |
108 | for (auto &child : expr.children) { |
109 | cost_children += Cost(expr&: *child); |
110 | } |
111 | |
112 | auto cost_function = function_costs.find(x: expr.function.name); |
113 | if (cost_function != function_costs.end()) { |
114 | return cost_children + cost_function->second; |
115 | } else { |
116 | return cost_children + 1000; |
117 | } |
118 | } |
119 | |
120 | idx_t ExpressionHeuristics::ExpressionCost(BoundOperatorExpression &expr, ExpressionType &expr_type) { |
121 | idx_t sum = 0; |
122 | for (auto &child : expr.children) { |
123 | sum += Cost(expr&: *child); |
124 | } |
125 | |
126 | // OPERATOR_IS_NULL, OPERATOR_IS_NOT_NULL |
127 | if (expr_type == ExpressionType::OPERATOR_IS_NULL || expr_type == ExpressionType::OPERATOR_IS_NOT_NULL) { |
128 | return sum + 5; |
129 | } else if (expr_type == ExpressionType::COMPARE_IN || expr_type == ExpressionType::COMPARE_NOT_IN) { |
130 | // COMPARE_IN, COMPARE_NOT_IN |
131 | return sum + (expr.children.size() - 1) * 100; |
132 | } else if (expr_type == ExpressionType::OPERATOR_NOT) { |
133 | // OPERATOR_NOT |
134 | return sum + 10; // TODO: evaluate via measured runtimes |
135 | } else { |
136 | return sum + 1000; |
137 | } |
138 | } |
139 | |
140 | idx_t ExpressionHeuristics::ExpressionCost(PhysicalType return_type, idx_t multiplier) { |
141 | // TODO: ajust values according to benchmark results |
142 | switch (return_type) { |
143 | case PhysicalType::VARCHAR: |
144 | return 5 * multiplier; |
145 | case PhysicalType::FLOAT: |
146 | case PhysicalType::DOUBLE: |
147 | return 2 * multiplier; |
148 | default: |
149 | return 1 * multiplier; |
150 | } |
151 | } |
152 | |
153 | idx_t ExpressionHeuristics::Cost(Expression &expr) { |
154 | switch (expr.expression_class) { |
155 | case ExpressionClass::BOUND_CASE: { |
156 | auto &case_expr = expr.Cast<BoundCaseExpression>(); |
157 | return ExpressionCost(expr&: case_expr); |
158 | } |
159 | case ExpressionClass::BOUND_BETWEEN: { |
160 | auto &between_expr = expr.Cast<BoundBetweenExpression>(); |
161 | return ExpressionCost(expr&: between_expr); |
162 | } |
163 | case ExpressionClass::BOUND_CAST: { |
164 | auto &cast_expr = expr.Cast<BoundCastExpression>(); |
165 | return ExpressionCost(expr&: cast_expr); |
166 | } |
167 | case ExpressionClass::BOUND_COMPARISON: { |
168 | auto &comp_expr = expr.Cast<BoundComparisonExpression>(); |
169 | return ExpressionCost(expr&: comp_expr); |
170 | } |
171 | case ExpressionClass::BOUND_CONJUNCTION: { |
172 | auto &conj_expr = expr.Cast<BoundConjunctionExpression>(); |
173 | return ExpressionCost(expr&: conj_expr); |
174 | } |
175 | case ExpressionClass::BOUND_FUNCTION: { |
176 | auto &func_expr = expr.Cast<BoundFunctionExpression>(); |
177 | return ExpressionCost(expr&: func_expr); |
178 | } |
179 | case ExpressionClass::BOUND_OPERATOR: { |
180 | auto &op_expr = expr.Cast<BoundOperatorExpression>(); |
181 | return ExpressionCost(expr&: op_expr, expr_type&: expr.type); |
182 | } |
183 | case ExpressionClass::BOUND_COLUMN_REF: { |
184 | auto &col_expr = expr.Cast<BoundColumnRefExpression>(); |
185 | return ExpressionCost(return_type: col_expr.return_type.InternalType(), multiplier: 8); |
186 | } |
187 | case ExpressionClass::BOUND_CONSTANT: { |
188 | auto &const_expr = expr.Cast<BoundConstantExpression>(); |
189 | return ExpressionCost(return_type: const_expr.return_type.InternalType(), multiplier: 1); |
190 | } |
191 | case ExpressionClass::BOUND_PARAMETER: { |
192 | auto &const_expr = expr.Cast<BoundParameterExpression>(); |
193 | return ExpressionCost(return_type: const_expr.return_type.InternalType(), multiplier: 1); |
194 | } |
195 | case ExpressionClass::BOUND_REF: { |
196 | auto &col_expr = expr.Cast<BoundColumnRefExpression>(); |
197 | return ExpressionCost(return_type: col_expr.return_type.InternalType(), multiplier: 8); |
198 | } |
199 | default: { |
200 | break; |
201 | } |
202 | } |
203 | |
204 | // return a very high value if nothing matches |
205 | return 1000; |
206 | } |
207 | |
208 | } // namespace duckdb |
209 | |