1 | #include "duckdb/execution/expression_executor.hpp" |
2 | |
3 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
4 | #include "duckdb/execution/execution_context.hpp" |
5 | #include "duckdb/storage/statistics/base_statistics.hpp" |
6 | #include "duckdb/planner/expression/list.hpp" |
7 | |
8 | namespace duckdb { |
9 | |
10 | ExpressionExecutor::ExpressionExecutor(ClientContext &context) : context(&context) { |
11 | } |
12 | |
13 | ExpressionExecutor::ExpressionExecutor(ClientContext &context, const Expression *expression) |
14 | : ExpressionExecutor(context) { |
15 | D_ASSERT(expression); |
16 | AddExpression(expr: *expression); |
17 | } |
18 | |
19 | ExpressionExecutor::ExpressionExecutor(ClientContext &context, const Expression &expression) |
20 | : ExpressionExecutor(context) { |
21 | AddExpression(expr: expression); |
22 | } |
23 | |
24 | ExpressionExecutor::ExpressionExecutor(ClientContext &context, const vector<unique_ptr<Expression>> &exprs) |
25 | : ExpressionExecutor(context) { |
26 | D_ASSERT(exprs.size() > 0); |
27 | for (auto &expr : exprs) { |
28 | AddExpression(expr: *expr); |
29 | } |
30 | } |
31 | |
32 | ExpressionExecutor::ExpressionExecutor(const vector<unique_ptr<Expression>> &exprs) : context(nullptr) { |
33 | D_ASSERT(exprs.size() > 0); |
34 | for (auto &expr : exprs) { |
35 | AddExpression(expr: *expr); |
36 | } |
37 | } |
38 | |
39 | ExpressionExecutor::ExpressionExecutor() : context(nullptr) { |
40 | } |
41 | |
42 | bool ExpressionExecutor::HasContext() { |
43 | return context; |
44 | } |
45 | |
46 | ClientContext &ExpressionExecutor::GetContext() { |
47 | if (!context) { |
48 | throw InternalException("Calling ExpressionExecutor::GetContext on an expression executor without a context" ); |
49 | } |
50 | return *context; |
51 | } |
52 | |
53 | Allocator &ExpressionExecutor::GetAllocator() { |
54 | return context ? Allocator::Get(context&: *context) : Allocator::DefaultAllocator(); |
55 | } |
56 | |
57 | void ExpressionExecutor::AddExpression(const Expression &expr) { |
58 | expressions.push_back(x: &expr); |
59 | auto state = make_uniq<ExpressionExecutorState>(); |
60 | Initialize(expr, state&: *state); |
61 | state->Verify(); |
62 | states.push_back(x: std::move(state)); |
63 | } |
64 | |
65 | void ExpressionExecutor::Initialize(const Expression &expression, ExpressionExecutorState &state) { |
66 | state.executor = this; |
67 | state.root_state = InitializeState(expr: expression, state); |
68 | } |
69 | |
70 | void ExpressionExecutor::Execute(DataChunk *input, DataChunk &result) { |
71 | SetChunk(input); |
72 | D_ASSERT(expressions.size() == result.ColumnCount()); |
73 | D_ASSERT(!expressions.empty()); |
74 | |
75 | for (idx_t i = 0; i < expressions.size(); i++) { |
76 | ExecuteExpression(expr_idx: i, result&: result.data[i]); |
77 | } |
78 | result.SetCardinality(input ? input->size() : 1); |
79 | result.Verify(); |
80 | } |
81 | |
82 | void ExpressionExecutor::ExecuteExpression(DataChunk &input, Vector &result) { |
83 | SetChunk(&input); |
84 | ExecuteExpression(result); |
85 | } |
86 | |
87 | idx_t ExpressionExecutor::SelectExpression(DataChunk &input, SelectionVector &sel) { |
88 | D_ASSERT(expressions.size() == 1); |
89 | SetChunk(&input); |
90 | states[0]->profiler.BeginSample(); |
91 | idx_t selected_tuples = Select(expr: *expressions[0], state: states[0]->root_state.get(), sel: nullptr, count: input.size(), true_sel: &sel, false_sel: nullptr); |
92 | states[0]->profiler.EndSample(chunk_size: chunk ? chunk->size() : 0); |
93 | return selected_tuples; |
94 | } |
95 | |
96 | void ExpressionExecutor::ExecuteExpression(Vector &result) { |
97 | D_ASSERT(expressions.size() == 1); |
98 | ExecuteExpression(expr_idx: 0, result); |
99 | } |
100 | |
101 | void ExpressionExecutor::ExecuteExpression(idx_t expr_idx, Vector &result) { |
102 | D_ASSERT(expr_idx < expressions.size()); |
103 | D_ASSERT(result.GetType().id() == expressions[expr_idx]->return_type.id()); |
104 | states[expr_idx]->profiler.BeginSample(); |
105 | Execute(expr: *expressions[expr_idx], state: states[expr_idx]->root_state.get(), sel: nullptr, count: chunk ? chunk->size() : 1, result); |
106 | states[expr_idx]->profiler.EndSample(chunk_size: chunk ? chunk->size() : 0); |
107 | } |
108 | |
109 | Value ExpressionExecutor::EvaluateScalar(ClientContext &context, const Expression &expr, bool allow_unfoldable) { |
110 | D_ASSERT(allow_unfoldable || expr.IsFoldable()); |
111 | D_ASSERT(expr.IsScalar()); |
112 | // use an ExpressionExecutor to execute the expression |
113 | ExpressionExecutor executor(context, expr); |
114 | |
115 | Vector result(expr.return_type); |
116 | executor.ExecuteExpression(result); |
117 | |
118 | D_ASSERT(allow_unfoldable || result.GetVectorType() == VectorType::CONSTANT_VECTOR); |
119 | auto result_value = result.GetValue(index: 0); |
120 | D_ASSERT(result_value.type().InternalType() == expr.return_type.InternalType()); |
121 | return result_value; |
122 | } |
123 | |
124 | bool ExpressionExecutor::TryEvaluateScalar(ClientContext &context, const Expression &expr, Value &result) { |
125 | try { |
126 | result = EvaluateScalar(context, expr); |
127 | return true; |
128 | } catch (InternalException &ex) { |
129 | throw ex; |
130 | } catch (...) { |
131 | return false; |
132 | } |
133 | } |
134 | |
135 | void ExpressionExecutor::Verify(const Expression &expr, Vector &vector, idx_t count) { |
136 | D_ASSERT(expr.return_type.id() == vector.GetType().id()); |
137 | vector.Verify(count); |
138 | if (expr.verification_stats) { |
139 | expr.verification_stats->Verify(vector, count); |
140 | } |
141 | } |
142 | |
143 | unique_ptr<ExpressionState> ExpressionExecutor::InitializeState(const Expression &expr, |
144 | ExpressionExecutorState &state) { |
145 | switch (expr.expression_class) { |
146 | case ExpressionClass::BOUND_REF: |
147 | return InitializeState(expr: expr.Cast<BoundReferenceExpression>(), state); |
148 | case ExpressionClass::BOUND_BETWEEN: |
149 | return InitializeState(expr: expr.Cast<BoundBetweenExpression>(), state); |
150 | case ExpressionClass::BOUND_CASE: |
151 | return InitializeState(expr: expr.Cast<BoundCaseExpression>(), state); |
152 | case ExpressionClass::BOUND_CAST: |
153 | return InitializeState(expr: expr.Cast<BoundCastExpression>(), state); |
154 | case ExpressionClass::BOUND_COMPARISON: |
155 | return InitializeState(expr: expr.Cast<BoundComparisonExpression>(), state); |
156 | case ExpressionClass::BOUND_CONJUNCTION: |
157 | return InitializeState(expr: expr.Cast<BoundConjunctionExpression>(), state); |
158 | case ExpressionClass::BOUND_CONSTANT: |
159 | return InitializeState(expr: expr.Cast<BoundConstantExpression>(), state); |
160 | case ExpressionClass::BOUND_FUNCTION: |
161 | return InitializeState(expr: expr.Cast<BoundFunctionExpression>(), state); |
162 | case ExpressionClass::BOUND_OPERATOR: |
163 | return InitializeState(expr: expr.Cast<BoundOperatorExpression>(), state); |
164 | case ExpressionClass::BOUND_PARAMETER: |
165 | return InitializeState(expr: expr.Cast<BoundParameterExpression>(), state); |
166 | default: |
167 | throw InternalException("Attempting to initialize state of expression of unknown type!" ); |
168 | } |
169 | } |
170 | |
171 | void ExpressionExecutor::Execute(const Expression &expr, ExpressionState *state, const SelectionVector *sel, |
172 | idx_t count, Vector &result) { |
173 | #ifdef DEBUG |
174 | //! The result Vector must be "clean" |
175 | if (result.GetVectorType() == VectorType::FLAT_VECTOR) { |
176 | D_ASSERT(FlatVector::Validity(result).CheckAllValid(count)); |
177 | } |
178 | #endif |
179 | |
180 | if (count == 0) { |
181 | return; |
182 | } |
183 | if (result.GetType().id() != expr.return_type.id()) { |
184 | throw InternalException( |
185 | "ExpressionExecutor::Execute called with a result vector of type %s that does not match expression type %s" , |
186 | result.GetType(), expr.return_type); |
187 | } |
188 | switch (expr.expression_class) { |
189 | case ExpressionClass::BOUND_BETWEEN: |
190 | Execute(expr: expr.Cast<BoundBetweenExpression>(), state, sel, count, result); |
191 | break; |
192 | case ExpressionClass::BOUND_REF: |
193 | Execute(expr: expr.Cast<BoundReferenceExpression>(), state, sel, count, result); |
194 | break; |
195 | case ExpressionClass::BOUND_CASE: |
196 | Execute(expr: expr.Cast<BoundCaseExpression>(), state, sel, count, result); |
197 | break; |
198 | case ExpressionClass::BOUND_CAST: |
199 | Execute(expr: expr.Cast<BoundCastExpression>(), state, sel, count, result); |
200 | break; |
201 | case ExpressionClass::BOUND_COMPARISON: |
202 | Execute(expr: expr.Cast<BoundComparisonExpression>(), state, sel, count, result); |
203 | break; |
204 | case ExpressionClass::BOUND_CONJUNCTION: |
205 | Execute(expr: expr.Cast<BoundConjunctionExpression>(), state, sel, count, result); |
206 | break; |
207 | case ExpressionClass::BOUND_CONSTANT: |
208 | Execute(expr: expr.Cast<BoundConstantExpression>(), state, sel, count, result); |
209 | break; |
210 | case ExpressionClass::BOUND_FUNCTION: |
211 | Execute(expr: expr.Cast<BoundFunctionExpression>(), state, sel, count, result); |
212 | break; |
213 | case ExpressionClass::BOUND_OPERATOR: |
214 | Execute(expr: expr.Cast<BoundOperatorExpression>(), state, sel, count, result); |
215 | break; |
216 | case ExpressionClass::BOUND_PARAMETER: |
217 | Execute(expr: expr.Cast<BoundParameterExpression>(), state, sel, count, result); |
218 | break; |
219 | default: |
220 | throw InternalException("Attempting to execute expression of unknown type!" ); |
221 | } |
222 | Verify(expr, vector&: result, count); |
223 | } |
224 | |
225 | idx_t ExpressionExecutor::Select(const Expression &expr, ExpressionState *state, const SelectionVector *sel, |
226 | idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { |
227 | if (count == 0) { |
228 | return 0; |
229 | } |
230 | D_ASSERT(true_sel || false_sel); |
231 | D_ASSERT(expr.return_type.id() == LogicalTypeId::BOOLEAN); |
232 | switch (expr.expression_class) { |
233 | case ExpressionClass::BOUND_BETWEEN: |
234 | return Select(expr: expr.Cast<BoundBetweenExpression>(), state, sel, count, true_sel, false_sel); |
235 | case ExpressionClass::BOUND_COMPARISON: |
236 | return Select(expr: expr.Cast<BoundComparisonExpression>(), state, sel, count, true_sel, false_sel); |
237 | case ExpressionClass::BOUND_CONJUNCTION: |
238 | return Select(expr: expr.Cast<BoundConjunctionExpression>(), state, sel, count, true_sel, false_sel); |
239 | default: |
240 | return DefaultSelect(expr, state, sel, count, true_sel, false_sel); |
241 | } |
242 | } |
243 | |
244 | template <bool NO_NULL, bool HAS_TRUE_SEL, bool HAS_FALSE_SEL> |
245 | static inline idx_t DefaultSelectLoop(const SelectionVector *bsel, const uint8_t *__restrict bdata, ValidityMask &mask, |
246 | const SelectionVector *sel, idx_t count, SelectionVector *true_sel, |
247 | SelectionVector *false_sel) { |
248 | idx_t true_count = 0, false_count = 0; |
249 | for (idx_t i = 0; i < count; i++) { |
250 | auto bidx = bsel->get_index(idx: i); |
251 | auto result_idx = sel->get_index(idx: i); |
252 | if (bdata[bidx] > 0 && (NO_NULL || mask.RowIsValid(row_idx: bidx))) { |
253 | if (HAS_TRUE_SEL) { |
254 | true_sel->set_index(idx: true_count++, loc: result_idx); |
255 | } |
256 | } else { |
257 | if (HAS_FALSE_SEL) { |
258 | false_sel->set_index(idx: false_count++, loc: result_idx); |
259 | } |
260 | } |
261 | } |
262 | if (HAS_TRUE_SEL) { |
263 | return true_count; |
264 | } else { |
265 | return count - false_count; |
266 | } |
267 | } |
268 | |
269 | template <bool NO_NULL> |
270 | static inline idx_t DefaultSelectSwitch(UnifiedVectorFormat &idata, const SelectionVector *sel, idx_t count, |
271 | SelectionVector *true_sel, SelectionVector *false_sel) { |
272 | if (true_sel && false_sel) { |
273 | return DefaultSelectLoop<NO_NULL, true, true>(idata.sel, UnifiedVectorFormat::GetData<uint8_t>(format: idata), |
274 | idata.validity, sel, count, true_sel, false_sel); |
275 | } else if (true_sel) { |
276 | return DefaultSelectLoop<NO_NULL, true, false>(idata.sel, UnifiedVectorFormat::GetData<uint8_t>(format: idata), |
277 | idata.validity, sel, count, true_sel, false_sel); |
278 | } else { |
279 | D_ASSERT(false_sel); |
280 | return DefaultSelectLoop<NO_NULL, false, true>(idata.sel, UnifiedVectorFormat::GetData<uint8_t>(format: idata), |
281 | idata.validity, sel, count, true_sel, false_sel); |
282 | } |
283 | } |
284 | |
285 | idx_t ExpressionExecutor::DefaultSelect(const Expression &expr, ExpressionState *state, const SelectionVector *sel, |
286 | idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { |
287 | // generic selection of boolean expression: |
288 | // resolve the true/false expression first |
289 | // then use that to generate the selection vector |
290 | bool intermediate_bools[STANDARD_VECTOR_SIZE]; |
291 | Vector intermediate(LogicalType::BOOLEAN, data_ptr_cast(src: intermediate_bools)); |
292 | Execute(expr, state, sel, count, result&: intermediate); |
293 | |
294 | UnifiedVectorFormat idata; |
295 | intermediate.ToUnifiedFormat(count, data&: idata); |
296 | |
297 | if (!sel) { |
298 | sel = FlatVector::IncrementalSelectionVector(); |
299 | } |
300 | if (!idata.validity.AllValid()) { |
301 | return DefaultSelectSwitch<false>(idata, sel, count, true_sel, false_sel); |
302 | } else { |
303 | return DefaultSelectSwitch<true>(idata, sel, count, true_sel, false_sel); |
304 | } |
305 | } |
306 | |
307 | vector<unique_ptr<ExpressionExecutorState>> &ExpressionExecutor::GetStates() { |
308 | return states; |
309 | } |
310 | |
311 | } // namespace duckdb |
312 | |