1#include "duckdb/execution/expression_executor.hpp"
2
3#include "duckdb/common/vector_operations/vector_operations.hpp"
4#include "duckdb/execution/execution_context.hpp"
5#include "duckdb/storage/statistics/base_statistics.hpp"
6#include "duckdb/planner/expression/list.hpp"
7
8namespace duckdb {
9
10ExpressionExecutor::ExpressionExecutor(ClientContext &context) : context(&context) {
11}
12
13ExpressionExecutor::ExpressionExecutor(ClientContext &context, const Expression *expression)
14 : ExpressionExecutor(context) {
15 D_ASSERT(expression);
16 AddExpression(expr: *expression);
17}
18
19ExpressionExecutor::ExpressionExecutor(ClientContext &context, const Expression &expression)
20 : ExpressionExecutor(context) {
21 AddExpression(expr: expression);
22}
23
24ExpressionExecutor::ExpressionExecutor(ClientContext &context, const vector<unique_ptr<Expression>> &exprs)
25 : ExpressionExecutor(context) {
26 D_ASSERT(exprs.size() > 0);
27 for (auto &expr : exprs) {
28 AddExpression(expr: *expr);
29 }
30}
31
32ExpressionExecutor::ExpressionExecutor(const vector<unique_ptr<Expression>> &exprs) : context(nullptr) {
33 D_ASSERT(exprs.size() > 0);
34 for (auto &expr : exprs) {
35 AddExpression(expr: *expr);
36 }
37}
38
39ExpressionExecutor::ExpressionExecutor() : context(nullptr) {
40}
41
42bool ExpressionExecutor::HasContext() {
43 return context;
44}
45
46ClientContext &ExpressionExecutor::GetContext() {
47 if (!context) {
48 throw InternalException("Calling ExpressionExecutor::GetContext on an expression executor without a context");
49 }
50 return *context;
51}
52
53Allocator &ExpressionExecutor::GetAllocator() {
54 return context ? Allocator::Get(context&: *context) : Allocator::DefaultAllocator();
55}
56
57void ExpressionExecutor::AddExpression(const Expression &expr) {
58 expressions.push_back(x: &expr);
59 auto state = make_uniq<ExpressionExecutorState>();
60 Initialize(expr, state&: *state);
61 state->Verify();
62 states.push_back(x: std::move(state));
63}
64
65void ExpressionExecutor::Initialize(const Expression &expression, ExpressionExecutorState &state) {
66 state.executor = this;
67 state.root_state = InitializeState(expr: expression, state);
68}
69
70void ExpressionExecutor::Execute(DataChunk *input, DataChunk &result) {
71 SetChunk(input);
72 D_ASSERT(expressions.size() == result.ColumnCount());
73 D_ASSERT(!expressions.empty());
74
75 for (idx_t i = 0; i < expressions.size(); i++) {
76 ExecuteExpression(expr_idx: i, result&: result.data[i]);
77 }
78 result.SetCardinality(input ? input->size() : 1);
79 result.Verify();
80}
81
82void ExpressionExecutor::ExecuteExpression(DataChunk &input, Vector &result) {
83 SetChunk(&input);
84 ExecuteExpression(result);
85}
86
87idx_t ExpressionExecutor::SelectExpression(DataChunk &input, SelectionVector &sel) {
88 D_ASSERT(expressions.size() == 1);
89 SetChunk(&input);
90 states[0]->profiler.BeginSample();
91 idx_t selected_tuples = Select(expr: *expressions[0], state: states[0]->root_state.get(), sel: nullptr, count: input.size(), true_sel: &sel, false_sel: nullptr);
92 states[0]->profiler.EndSample(chunk_size: chunk ? chunk->size() : 0);
93 return selected_tuples;
94}
95
96void ExpressionExecutor::ExecuteExpression(Vector &result) {
97 D_ASSERT(expressions.size() == 1);
98 ExecuteExpression(expr_idx: 0, result);
99}
100
101void ExpressionExecutor::ExecuteExpression(idx_t expr_idx, Vector &result) {
102 D_ASSERT(expr_idx < expressions.size());
103 D_ASSERT(result.GetType().id() == expressions[expr_idx]->return_type.id());
104 states[expr_idx]->profiler.BeginSample();
105 Execute(expr: *expressions[expr_idx], state: states[expr_idx]->root_state.get(), sel: nullptr, count: chunk ? chunk->size() : 1, result);
106 states[expr_idx]->profiler.EndSample(chunk_size: chunk ? chunk->size() : 0);
107}
108
109Value ExpressionExecutor::EvaluateScalar(ClientContext &context, const Expression &expr, bool allow_unfoldable) {
110 D_ASSERT(allow_unfoldable || expr.IsFoldable());
111 D_ASSERT(expr.IsScalar());
112 // use an ExpressionExecutor to execute the expression
113 ExpressionExecutor executor(context, expr);
114
115 Vector result(expr.return_type);
116 executor.ExecuteExpression(result);
117
118 D_ASSERT(allow_unfoldable || result.GetVectorType() == VectorType::CONSTANT_VECTOR);
119 auto result_value = result.GetValue(index: 0);
120 D_ASSERT(result_value.type().InternalType() == expr.return_type.InternalType());
121 return result_value;
122}
123
124bool ExpressionExecutor::TryEvaluateScalar(ClientContext &context, const Expression &expr, Value &result) {
125 try {
126 result = EvaluateScalar(context, expr);
127 return true;
128 } catch (InternalException &ex) {
129 throw ex;
130 } catch (...) {
131 return false;
132 }
133}
134
135void ExpressionExecutor::Verify(const Expression &expr, Vector &vector, idx_t count) {
136 D_ASSERT(expr.return_type.id() == vector.GetType().id());
137 vector.Verify(count);
138 if (expr.verification_stats) {
139 expr.verification_stats->Verify(vector, count);
140 }
141}
142
143unique_ptr<ExpressionState> ExpressionExecutor::InitializeState(const Expression &expr,
144 ExpressionExecutorState &state) {
145 switch (expr.expression_class) {
146 case ExpressionClass::BOUND_REF:
147 return InitializeState(expr: expr.Cast<BoundReferenceExpression>(), state);
148 case ExpressionClass::BOUND_BETWEEN:
149 return InitializeState(expr: expr.Cast<BoundBetweenExpression>(), state);
150 case ExpressionClass::BOUND_CASE:
151 return InitializeState(expr: expr.Cast<BoundCaseExpression>(), state);
152 case ExpressionClass::BOUND_CAST:
153 return InitializeState(expr: expr.Cast<BoundCastExpression>(), state);
154 case ExpressionClass::BOUND_COMPARISON:
155 return InitializeState(expr: expr.Cast<BoundComparisonExpression>(), state);
156 case ExpressionClass::BOUND_CONJUNCTION:
157 return InitializeState(expr: expr.Cast<BoundConjunctionExpression>(), state);
158 case ExpressionClass::BOUND_CONSTANT:
159 return InitializeState(expr: expr.Cast<BoundConstantExpression>(), state);
160 case ExpressionClass::BOUND_FUNCTION:
161 return InitializeState(expr: expr.Cast<BoundFunctionExpression>(), state);
162 case ExpressionClass::BOUND_OPERATOR:
163 return InitializeState(expr: expr.Cast<BoundOperatorExpression>(), state);
164 case ExpressionClass::BOUND_PARAMETER:
165 return InitializeState(expr: expr.Cast<BoundParameterExpression>(), state);
166 default:
167 throw InternalException("Attempting to initialize state of expression of unknown type!");
168 }
169}
170
171void ExpressionExecutor::Execute(const Expression &expr, ExpressionState *state, const SelectionVector *sel,
172 idx_t count, Vector &result) {
173#ifdef DEBUG
174 //! The result Vector must be "clean"
175 if (result.GetVectorType() == VectorType::FLAT_VECTOR) {
176 D_ASSERT(FlatVector::Validity(result).CheckAllValid(count));
177 }
178#endif
179
180 if (count == 0) {
181 return;
182 }
183 if (result.GetType().id() != expr.return_type.id()) {
184 throw InternalException(
185 "ExpressionExecutor::Execute called with a result vector of type %s that does not match expression type %s",
186 result.GetType(), expr.return_type);
187 }
188 switch (expr.expression_class) {
189 case ExpressionClass::BOUND_BETWEEN:
190 Execute(expr: expr.Cast<BoundBetweenExpression>(), state, sel, count, result);
191 break;
192 case ExpressionClass::BOUND_REF:
193 Execute(expr: expr.Cast<BoundReferenceExpression>(), state, sel, count, result);
194 break;
195 case ExpressionClass::BOUND_CASE:
196 Execute(expr: expr.Cast<BoundCaseExpression>(), state, sel, count, result);
197 break;
198 case ExpressionClass::BOUND_CAST:
199 Execute(expr: expr.Cast<BoundCastExpression>(), state, sel, count, result);
200 break;
201 case ExpressionClass::BOUND_COMPARISON:
202 Execute(expr: expr.Cast<BoundComparisonExpression>(), state, sel, count, result);
203 break;
204 case ExpressionClass::BOUND_CONJUNCTION:
205 Execute(expr: expr.Cast<BoundConjunctionExpression>(), state, sel, count, result);
206 break;
207 case ExpressionClass::BOUND_CONSTANT:
208 Execute(expr: expr.Cast<BoundConstantExpression>(), state, sel, count, result);
209 break;
210 case ExpressionClass::BOUND_FUNCTION:
211 Execute(expr: expr.Cast<BoundFunctionExpression>(), state, sel, count, result);
212 break;
213 case ExpressionClass::BOUND_OPERATOR:
214 Execute(expr: expr.Cast<BoundOperatorExpression>(), state, sel, count, result);
215 break;
216 case ExpressionClass::BOUND_PARAMETER:
217 Execute(expr: expr.Cast<BoundParameterExpression>(), state, sel, count, result);
218 break;
219 default:
220 throw InternalException("Attempting to execute expression of unknown type!");
221 }
222 Verify(expr, vector&: result, count);
223}
224
225idx_t ExpressionExecutor::Select(const Expression &expr, ExpressionState *state, const SelectionVector *sel,
226 idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) {
227 if (count == 0) {
228 return 0;
229 }
230 D_ASSERT(true_sel || false_sel);
231 D_ASSERT(expr.return_type.id() == LogicalTypeId::BOOLEAN);
232 switch (expr.expression_class) {
233 case ExpressionClass::BOUND_BETWEEN:
234 return Select(expr: expr.Cast<BoundBetweenExpression>(), state, sel, count, true_sel, false_sel);
235 case ExpressionClass::BOUND_COMPARISON:
236 return Select(expr: expr.Cast<BoundComparisonExpression>(), state, sel, count, true_sel, false_sel);
237 case ExpressionClass::BOUND_CONJUNCTION:
238 return Select(expr: expr.Cast<BoundConjunctionExpression>(), state, sel, count, true_sel, false_sel);
239 default:
240 return DefaultSelect(expr, state, sel, count, true_sel, false_sel);
241 }
242}
243
244template <bool NO_NULL, bool HAS_TRUE_SEL, bool HAS_FALSE_SEL>
245static inline idx_t DefaultSelectLoop(const SelectionVector *bsel, const uint8_t *__restrict bdata, ValidityMask &mask,
246 const SelectionVector *sel, idx_t count, SelectionVector *true_sel,
247 SelectionVector *false_sel) {
248 idx_t true_count = 0, false_count = 0;
249 for (idx_t i = 0; i < count; i++) {
250 auto bidx = bsel->get_index(idx: i);
251 auto result_idx = sel->get_index(idx: i);
252 if (bdata[bidx] > 0 && (NO_NULL || mask.RowIsValid(row_idx: bidx))) {
253 if (HAS_TRUE_SEL) {
254 true_sel->set_index(idx: true_count++, loc: result_idx);
255 }
256 } else {
257 if (HAS_FALSE_SEL) {
258 false_sel->set_index(idx: false_count++, loc: result_idx);
259 }
260 }
261 }
262 if (HAS_TRUE_SEL) {
263 return true_count;
264 } else {
265 return count - false_count;
266 }
267}
268
269template <bool NO_NULL>
270static inline idx_t DefaultSelectSwitch(UnifiedVectorFormat &idata, const SelectionVector *sel, idx_t count,
271 SelectionVector *true_sel, SelectionVector *false_sel) {
272 if (true_sel && false_sel) {
273 return DefaultSelectLoop<NO_NULL, true, true>(idata.sel, UnifiedVectorFormat::GetData<uint8_t>(format: idata),
274 idata.validity, sel, count, true_sel, false_sel);
275 } else if (true_sel) {
276 return DefaultSelectLoop<NO_NULL, true, false>(idata.sel, UnifiedVectorFormat::GetData<uint8_t>(format: idata),
277 idata.validity, sel, count, true_sel, false_sel);
278 } else {
279 D_ASSERT(false_sel);
280 return DefaultSelectLoop<NO_NULL, false, true>(idata.sel, UnifiedVectorFormat::GetData<uint8_t>(format: idata),
281 idata.validity, sel, count, true_sel, false_sel);
282 }
283}
284
285idx_t ExpressionExecutor::DefaultSelect(const Expression &expr, ExpressionState *state, const SelectionVector *sel,
286 idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) {
287 // generic selection of boolean expression:
288 // resolve the true/false expression first
289 // then use that to generate the selection vector
290 bool intermediate_bools[STANDARD_VECTOR_SIZE];
291 Vector intermediate(LogicalType::BOOLEAN, data_ptr_cast(src: intermediate_bools));
292 Execute(expr, state, sel, count, result&: intermediate);
293
294 UnifiedVectorFormat idata;
295 intermediate.ToUnifiedFormat(count, data&: idata);
296
297 if (!sel) {
298 sel = FlatVector::IncrementalSelectionVector();
299 }
300 if (!idata.validity.AllValid()) {
301 return DefaultSelectSwitch<false>(idata, sel, count, true_sel, false_sel);
302 } else {
303 return DefaultSelectSwitch<true>(idata, sel, count, true_sel, false_sel);
304 }
305}
306
307vector<unique_ptr<ExpressionExecutorState>> &ExpressionExecutor::GetStates() {
308 return states;
309}
310
311} // namespace duckdb
312