1#include "duckdb/common/vector_operations/vector_operations.hpp"
2#include "duckdb/execution/expression_executor.hpp"
3#include "duckdb/planner/expression/bound_operator_expression.hpp"
4
5namespace duckdb {
6
7unique_ptr<ExpressionState> ExpressionExecutor::InitializeState(const BoundOperatorExpression &expr,
8 ExpressionExecutorState &root) {
9 auto result = make_uniq<ExpressionState>(args: expr, args&: root);
10 for (auto &child : expr.children) {
11 result->AddChild(expr: child.get());
12 }
13 result->Finalize();
14 return result;
15}
16
17void ExpressionExecutor::Execute(const BoundOperatorExpression &expr, ExpressionState *state,
18 const SelectionVector *sel, idx_t count, Vector &result) {
19 // special handling for special snowflake 'IN'
20 // IN has n children
21 if (expr.type == ExpressionType::COMPARE_IN || expr.type == ExpressionType::COMPARE_NOT_IN) {
22 if (expr.children.size() < 2) {
23 throw InvalidInputException("IN needs at least two children");
24 }
25
26 Vector left(expr.children[0]->return_type);
27 // eval left side
28 Execute(expr: *expr.children[0], state: state->child_states[0].get(), sel, count, result&: left);
29
30 // init result to false
31 Vector intermediate(LogicalType::BOOLEAN);
32 Value false_val = Value::BOOLEAN(value: false);
33 intermediate.Reference(value: false_val);
34
35 // in rhs is a list of constants
36 // for every child, OR the result of the comparision with the left
37 // to get the overall result.
38 for (idx_t child = 1; child < expr.children.size(); child++) {
39 Vector vector_to_check(expr.children[child]->return_type);
40 Vector comp_res(LogicalType::BOOLEAN);
41
42 Execute(expr: *expr.children[child], state: state->child_states[child].get(), sel, count, result&: vector_to_check);
43 VectorOperations::Equals(left, right&: vector_to_check, result&: comp_res, count);
44
45 if (child == 1) {
46 // first child: move to result
47 intermediate.Reference(other&: comp_res);
48 } else {
49 // otherwise OR together
50 Vector new_result(LogicalType::BOOLEAN, true, false);
51 VectorOperations::Or(left&: intermediate, right&: comp_res, result&: new_result, count);
52 intermediate.Reference(other&: new_result);
53 }
54 }
55 if (expr.type == ExpressionType::COMPARE_NOT_IN) {
56 // NOT IN: invert result
57 VectorOperations::Not(left&: intermediate, result, count);
58 } else {
59 // directly use the result
60 result.Reference(other&: intermediate);
61 }
62 } else if (expr.type == ExpressionType::OPERATOR_COALESCE) {
63 SelectionVector sel_a(count);
64 SelectionVector sel_b(count);
65 SelectionVector slice_sel(count);
66 SelectionVector result_sel(count);
67 SelectionVector *next_sel = &sel_a;
68 const SelectionVector *current_sel = sel;
69 idx_t remaining_count = count;
70 idx_t next_count;
71 for (idx_t child = 0; child < expr.children.size(); child++) {
72 Vector vector_to_check(expr.children[child]->return_type);
73 Execute(expr: *expr.children[child], state: state->child_states[child].get(), sel: current_sel, count: remaining_count,
74 result&: vector_to_check);
75
76 UnifiedVectorFormat vdata;
77 vector_to_check.ToUnifiedFormat(count: remaining_count, data&: vdata);
78
79 idx_t result_count = 0;
80 next_count = 0;
81 for (idx_t i = 0; i < remaining_count; i++) {
82 auto base_idx = current_sel ? current_sel->get_index(idx: i) : i;
83 auto idx = vdata.sel->get_index(idx: i);
84 if (vdata.validity.RowIsValid(row_idx: idx)) {
85 slice_sel.set_index(idx: result_count, loc: i);
86 result_sel.set_index(idx: result_count++, loc: base_idx);
87 } else {
88 next_sel->set_index(idx: next_count++, loc: base_idx);
89 }
90 }
91 if (result_count > 0) {
92 vector_to_check.Slice(sel: slice_sel, count: result_count);
93 FillSwitch(vector&: vector_to_check, result, sel: result_sel, count: result_count);
94 }
95 current_sel = next_sel;
96 next_sel = next_sel == &sel_a ? &sel_b : &sel_a;
97 remaining_count = next_count;
98 if (next_count == 0) {
99 break;
100 }
101 }
102 if (remaining_count > 0) {
103 for (idx_t i = 0; i < remaining_count; i++) {
104 FlatVector::SetNull(vector&: result, idx: current_sel->get_index(idx: i), is_null: true);
105 }
106 }
107 if (sel) {
108 result.Slice(sel: *sel, count);
109 } else if (count == 1) {
110 result.SetVectorType(VectorType::CONSTANT_VECTOR);
111 }
112 } else if (expr.children.size() == 1) {
113 state->intermediate_chunk.Reset();
114 auto &child = state->intermediate_chunk.data[0];
115
116 Execute(expr: *expr.children[0], state: state->child_states[0].get(), sel, count, result&: child);
117 switch (expr.type) {
118 case ExpressionType::OPERATOR_NOT: {
119 VectorOperations::Not(left&: child, result, count);
120 break;
121 }
122 case ExpressionType::OPERATOR_IS_NULL: {
123 VectorOperations::IsNull(input&: child, result, count);
124 break;
125 }
126 case ExpressionType::OPERATOR_IS_NOT_NULL: {
127 VectorOperations::IsNotNull(arg&: child, result, count);
128 break;
129 }
130 default:
131 throw NotImplementedException("Unsupported operator type with 1 child!");
132 }
133 } else {
134 throw NotImplementedException("operator");
135 }
136}
137
138} // namespace duckdb
139