1#include "duckdb/optimizer/expression_heuristics.hpp"
2#include "duckdb/planner/expression/list.hpp"
3
4namespace duckdb {
5
6unique_ptr<LogicalOperator> ExpressionHeuristics::Rewrite(unique_ptr<LogicalOperator> op) {
7 VisitOperator(op&: *op);
8 return op;
9}
10
11void ExpressionHeuristics::VisitOperator(LogicalOperator &op) {
12 if (op.type == LogicalOperatorType::LOGICAL_FILTER) {
13 // reorder all filter expressions
14 if (op.expressions.size() > 1) {
15 ReorderExpressions(expressions&: op.expressions);
16 }
17 }
18
19 // traverse recursively through the operator tree
20 VisitOperatorChildren(op);
21 VisitOperatorExpressions(op);
22}
23
24unique_ptr<Expression> ExpressionHeuristics::VisitReplace(BoundConjunctionExpression &expr,
25 unique_ptr<Expression> *expr_ptr) {
26 ReorderExpressions(expressions&: expr.children);
27 return nullptr;
28}
29
30void ExpressionHeuristics::ReorderExpressions(vector<unique_ptr<Expression>> &expressions) {
31
32 struct ExpressionCosts {
33 unique_ptr<Expression> expr;
34 idx_t cost;
35
36 bool operator==(const ExpressionCosts &p) const {
37 return cost == p.cost;
38 }
39 bool operator<(const ExpressionCosts &p) const {
40 return cost < p.cost;
41 }
42 };
43
44 vector<ExpressionCosts> expression_costs;
45 expression_costs.reserve(n: expressions.size());
46 // iterate expressions, get cost for each one
47 for (idx_t i = 0; i < expressions.size(); i++) {
48 idx_t cost = Cost(expr&: *expressions[i]);
49 expression_costs.push_back(x: {.expr: std::move(expressions[i]), .cost: cost});
50 }
51
52 // sort by cost and put back in place
53 sort(first: expression_costs.begin(), last: expression_costs.end());
54 for (idx_t i = 0; i < expression_costs.size(); i++) {
55 expressions[i] = std::move(expression_costs[i].expr);
56 }
57}
58
59idx_t ExpressionHeuristics::ExpressionCost(BoundBetweenExpression &expr) {
60 return Cost(expr&: *expr.input) + Cost(expr&: *expr.lower) + Cost(expr&: *expr.upper) + 10;
61}
62
63idx_t ExpressionHeuristics::ExpressionCost(BoundCaseExpression &expr) {
64 // CASE WHEN check THEN result_if_true ELSE result_if_false END
65 idx_t case_cost = 0;
66 for (auto &case_check : expr.case_checks) {
67 case_cost += Cost(expr&: *case_check.then_expr);
68 case_cost += Cost(expr&: *case_check.when_expr);
69 }
70 case_cost += Cost(expr&: *expr.else_expr);
71 return case_cost;
72}
73
74idx_t ExpressionHeuristics::ExpressionCost(BoundCastExpression &expr) {
75 // OPERATOR_CAST
76 // determine cast cost by comparing cast_expr.source_type and cast_expr_target_type
77 idx_t cast_cost = 0;
78 if (expr.return_type != expr.source_type()) {
79 // if cast from or to varchar
80 // TODO: we might want to add more cases
81 if (expr.return_type.id() == LogicalTypeId::VARCHAR || expr.source_type().id() == LogicalTypeId::VARCHAR ||
82 expr.return_type.id() == LogicalTypeId::BLOB || expr.source_type().id() == LogicalTypeId::BLOB) {
83 cast_cost = 200;
84 } else {
85 cast_cost = 5;
86 }
87 }
88 return Cost(expr&: *expr.child) + cast_cost;
89}
90
91idx_t ExpressionHeuristics::ExpressionCost(BoundComparisonExpression &expr) {
92 // COMPARE_EQUAL, COMPARE_NOTEQUAL, COMPARE_GREATERTHAN, COMPARE_GREATERTHANOREQUALTO, COMPARE_LESSTHAN,
93 // COMPARE_LESSTHANOREQUALTO
94 return Cost(expr&: *expr.left) + 5 + Cost(expr&: *expr.right);
95}
96
97idx_t ExpressionHeuristics::ExpressionCost(BoundConjunctionExpression &expr) {
98 // CONJUNCTION_AND, CONJUNCTION_OR
99 idx_t cost = 5;
100 for (auto &child : expr.children) {
101 cost += Cost(expr&: *child);
102 }
103 return cost;
104}
105
106idx_t ExpressionHeuristics::ExpressionCost(BoundFunctionExpression &expr) {
107 idx_t cost_children = 0;
108 for (auto &child : expr.children) {
109 cost_children += Cost(expr&: *child);
110 }
111
112 auto cost_function = function_costs.find(x: expr.function.name);
113 if (cost_function != function_costs.end()) {
114 return cost_children + cost_function->second;
115 } else {
116 return cost_children + 1000;
117 }
118}
119
120idx_t ExpressionHeuristics::ExpressionCost(BoundOperatorExpression &expr, ExpressionType &expr_type) {
121 idx_t sum = 0;
122 for (auto &child : expr.children) {
123 sum += Cost(expr&: *child);
124 }
125
126 // OPERATOR_IS_NULL, OPERATOR_IS_NOT_NULL
127 if (expr_type == ExpressionType::OPERATOR_IS_NULL || expr_type == ExpressionType::OPERATOR_IS_NOT_NULL) {
128 return sum + 5;
129 } else if (expr_type == ExpressionType::COMPARE_IN || expr_type == ExpressionType::COMPARE_NOT_IN) {
130 // COMPARE_IN, COMPARE_NOT_IN
131 return sum + (expr.children.size() - 1) * 100;
132 } else if (expr_type == ExpressionType::OPERATOR_NOT) {
133 // OPERATOR_NOT
134 return sum + 10; // TODO: evaluate via measured runtimes
135 } else {
136 return sum + 1000;
137 }
138}
139
140idx_t ExpressionHeuristics::ExpressionCost(PhysicalType return_type, idx_t multiplier) {
141 // TODO: ajust values according to benchmark results
142 switch (return_type) {
143 case PhysicalType::VARCHAR:
144 return 5 * multiplier;
145 case PhysicalType::FLOAT:
146 case PhysicalType::DOUBLE:
147 return 2 * multiplier;
148 default:
149 return 1 * multiplier;
150 }
151}
152
153idx_t ExpressionHeuristics::Cost(Expression &expr) {
154 switch (expr.expression_class) {
155 case ExpressionClass::BOUND_CASE: {
156 auto &case_expr = expr.Cast<BoundCaseExpression>();
157 return ExpressionCost(expr&: case_expr);
158 }
159 case ExpressionClass::BOUND_BETWEEN: {
160 auto &between_expr = expr.Cast<BoundBetweenExpression>();
161 return ExpressionCost(expr&: between_expr);
162 }
163 case ExpressionClass::BOUND_CAST: {
164 auto &cast_expr = expr.Cast<BoundCastExpression>();
165 return ExpressionCost(expr&: cast_expr);
166 }
167 case ExpressionClass::BOUND_COMPARISON: {
168 auto &comp_expr = expr.Cast<BoundComparisonExpression>();
169 return ExpressionCost(expr&: comp_expr);
170 }
171 case ExpressionClass::BOUND_CONJUNCTION: {
172 auto &conj_expr = expr.Cast<BoundConjunctionExpression>();
173 return ExpressionCost(expr&: conj_expr);
174 }
175 case ExpressionClass::BOUND_FUNCTION: {
176 auto &func_expr = expr.Cast<BoundFunctionExpression>();
177 return ExpressionCost(expr&: func_expr);
178 }
179 case ExpressionClass::BOUND_OPERATOR: {
180 auto &op_expr = expr.Cast<BoundOperatorExpression>();
181 return ExpressionCost(expr&: op_expr, expr_type&: expr.type);
182 }
183 case ExpressionClass::BOUND_COLUMN_REF: {
184 auto &col_expr = expr.Cast<BoundColumnRefExpression>();
185 return ExpressionCost(return_type: col_expr.return_type.InternalType(), multiplier: 8);
186 }
187 case ExpressionClass::BOUND_CONSTANT: {
188 auto &const_expr = expr.Cast<BoundConstantExpression>();
189 return ExpressionCost(return_type: const_expr.return_type.InternalType(), multiplier: 1);
190 }
191 case ExpressionClass::BOUND_PARAMETER: {
192 auto &const_expr = expr.Cast<BoundParameterExpression>();
193 return ExpressionCost(return_type: const_expr.return_type.InternalType(), multiplier: 1);
194 }
195 case ExpressionClass::BOUND_REF: {
196 auto &col_expr = expr.Cast<BoundColumnRefExpression>();
197 return ExpressionCost(return_type: col_expr.return_type.InternalType(), multiplier: 8);
198 }
199 default: {
200 break;
201 }
202 }
203
204 // return a very high value if nothing matches
205 return 1000;
206}
207
208} // namespace duckdb
209