1 | #include "duckdb/common/pair.hpp" |
2 | #include "duckdb/common/string_util.hpp" |
3 | #include "duckdb/common/types/chunk_collection.hpp" |
4 | #include "duckdb/common/types/data_chunk.hpp" |
5 | #include "duckdb/common/vector_operations/binary_executor.hpp" |
6 | #include "duckdb/function/scalar/nested_functions.hpp" |
7 | #include "duckdb/function/scalar/string_functions.hpp" |
8 | #include "duckdb/parser/expression/bound_expression.hpp" |
9 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
10 | #include "duckdb/storage/statistics/list_stats.hpp" |
11 | |
12 | namespace duckdb { |
13 | |
14 | template <class T, bool HEAP_REF = false, bool VALIDITY_ONLY = false> |
15 | void (idx_t count, UnifiedVectorFormat &list_data, UnifiedVectorFormat &offsets_data, |
16 | Vector &child_vector, idx_t list_size, Vector &result) { |
17 | UnifiedVectorFormat child_format; |
18 | child_vector.ToUnifiedFormat(count: list_size, data&: child_format); |
19 | |
20 | T *result_data; |
21 | |
22 | result.SetVectorType(VectorType::FLAT_VECTOR); |
23 | if (!VALIDITY_ONLY) { |
24 | result_data = FlatVector::GetData<T>(result); |
25 | } |
26 | auto &result_mask = FlatVector::Validity(vector&: result); |
27 | |
28 | // heap-ref once |
29 | if (HEAP_REF) { |
30 | StringVector::AddHeapReference(vector&: result, other&: child_vector); |
31 | } |
32 | |
33 | // this is lifted from ExecuteGenericLoop because we can't push the list child data into this otherwise |
34 | // should have gone with GetValue perhaps |
35 | auto child_data = UnifiedVectorFormat::GetData<T>(child_format); |
36 | for (idx_t i = 0; i < count; i++) { |
37 | auto list_index = list_data.sel->get_index(idx: i); |
38 | auto offsets_index = offsets_data.sel->get_index(idx: i); |
39 | if (!list_data.validity.RowIsValid(row_idx: list_index)) { |
40 | result_mask.SetInvalid(i); |
41 | continue; |
42 | } |
43 | if (!offsets_data.validity.RowIsValid(row_idx: offsets_index)) { |
44 | result_mask.SetInvalid(i); |
45 | continue; |
46 | } |
47 | auto list_entry = (UnifiedVectorFormat::GetData<list_entry_t>(format: list_data))[list_index]; |
48 | auto offsets_entry = (UnifiedVectorFormat::GetData<int64_t>(format: offsets_data))[offsets_index]; |
49 | |
50 | // 1-based indexing |
51 | if (offsets_entry == 0) { |
52 | result_mask.SetInvalid(i); |
53 | continue; |
54 | } |
55 | offsets_entry = (offsets_entry > 0) ? offsets_entry - 1 : offsets_entry; |
56 | |
57 | idx_t child_offset; |
58 | if (offsets_entry < 0) { |
59 | if (offsets_entry < -int64_t(list_entry.length)) { |
60 | result_mask.SetInvalid(i); |
61 | continue; |
62 | } |
63 | child_offset = list_entry.offset + list_entry.length + offsets_entry; |
64 | } else { |
65 | if ((idx_t)offsets_entry >= list_entry.length) { |
66 | result_mask.SetInvalid(i); |
67 | continue; |
68 | } |
69 | child_offset = list_entry.offset + offsets_entry; |
70 | } |
71 | auto child_index = child_format.sel->get_index(idx: child_offset); |
72 | if (child_format.validity.RowIsValid(row_idx: child_index)) { |
73 | if (!VALIDITY_ONLY) { |
74 | result_data[i] = child_data[child_index]; |
75 | } |
76 | } else { |
77 | result_mask.SetInvalid(i); |
78 | } |
79 | } |
80 | if (count == 1) { |
81 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
82 | } |
83 | } |
84 | static void (const idx_t count, UnifiedVectorFormat &list, UnifiedVectorFormat &offsets, |
85 | Vector &child_vector, idx_t list_size, Vector &result) { |
86 | D_ASSERT(child_vector.GetType() == result.GetType()); |
87 | switch (result.GetType().InternalType()) { |
88 | case PhysicalType::BOOL: |
89 | case PhysicalType::INT8: |
90 | ListExtractTemplate<int8_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
91 | break; |
92 | case PhysicalType::INT16: |
93 | ListExtractTemplate<int16_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
94 | break; |
95 | case PhysicalType::INT32: |
96 | ListExtractTemplate<int32_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
97 | break; |
98 | case PhysicalType::INT64: |
99 | ListExtractTemplate<int64_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
100 | break; |
101 | case PhysicalType::INT128: |
102 | ListExtractTemplate<hugeint_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
103 | break; |
104 | case PhysicalType::UINT8: |
105 | ListExtractTemplate<uint8_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
106 | break; |
107 | case PhysicalType::UINT16: |
108 | ListExtractTemplate<uint16_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
109 | break; |
110 | case PhysicalType::UINT32: |
111 | ListExtractTemplate<uint32_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
112 | break; |
113 | case PhysicalType::UINT64: |
114 | ListExtractTemplate<uint64_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
115 | break; |
116 | case PhysicalType::FLOAT: |
117 | ListExtractTemplate<float>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
118 | break; |
119 | case PhysicalType::DOUBLE: |
120 | ListExtractTemplate<double>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
121 | break; |
122 | case PhysicalType::VARCHAR: |
123 | ListExtractTemplate<string_t, true>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
124 | break; |
125 | case PhysicalType::INTERVAL: |
126 | ListExtractTemplate<interval_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
127 | break; |
128 | case PhysicalType::STRUCT: { |
129 | auto &entries = StructVector::GetEntries(vector&: child_vector); |
130 | auto &result_entries = StructVector::GetEntries(vector&: result); |
131 | D_ASSERT(entries.size() == result_entries.size()); |
132 | // extract the child entries of the struct |
133 | for (idx_t i = 0; i < entries.size(); i++) { |
134 | ExecuteListExtractInternal(count, list, offsets, child_vector&: *entries[i], list_size, result&: *result_entries[i]); |
135 | } |
136 | // extract the validity mask |
137 | ListExtractTemplate<bool, false, true>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
138 | break; |
139 | } |
140 | case PhysicalType::LIST: { |
141 | // nested list: we have to reference the child |
142 | auto &child_child_list = ListVector::GetEntry(vector&: child_vector); |
143 | |
144 | ListVector::GetEntry(vector&: result).Reference(other&: child_child_list); |
145 | ListVector::SetListSize(vec&: result, size: ListVector::GetListSize(vector: child_vector)); |
146 | ListExtractTemplate<list_entry_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result); |
147 | break; |
148 | } |
149 | default: |
150 | throw NotImplementedException("Unimplemented type for LIST_EXTRACT" ); |
151 | } |
152 | } |
153 | |
154 | static void (Vector &result, Vector &list, Vector &offsets, const idx_t count) { |
155 | D_ASSERT(list.GetType().id() == LogicalTypeId::LIST); |
156 | UnifiedVectorFormat list_data; |
157 | UnifiedVectorFormat offsets_data; |
158 | |
159 | list.ToUnifiedFormat(count, data&: list_data); |
160 | offsets.ToUnifiedFormat(count, data&: offsets_data); |
161 | ExecuteListExtractInternal(count, list&: list_data, offsets&: offsets_data, child_vector&: ListVector::GetEntry(vector&: list), |
162 | list_size: ListVector::GetListSize(vector: list), result); |
163 | result.Verify(count); |
164 | } |
165 | |
166 | static void (Vector &result, Vector &input_vector, Vector &subscript_vector, const idx_t count) { |
167 | BinaryExecutor::Execute<string_t, int64_t, string_t>( |
168 | left&: input_vector, right&: subscript_vector, result, count, fun: [&](string_t input_string, int64_t subscript) { |
169 | return SubstringFun::SubstringUnicode(result, input: input_string, offset: subscript, length: 1); |
170 | }); |
171 | } |
172 | |
173 | static void (DataChunk &args, ExpressionState &state, Vector &result) { |
174 | D_ASSERT(args.ColumnCount() == 2); |
175 | auto count = args.size(); |
176 | |
177 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
178 | for (idx_t i = 0; i < args.ColumnCount(); i++) { |
179 | if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { |
180 | result.SetVectorType(VectorType::FLAT_VECTOR); |
181 | } |
182 | } |
183 | |
184 | Vector &base = args.data[0]; |
185 | Vector &subscript = args.data[1]; |
186 | |
187 | switch (base.GetType().id()) { |
188 | case LogicalTypeId::LIST: |
189 | ExecuteListExtract(result, list&: base, offsets&: subscript, count); |
190 | break; |
191 | case LogicalTypeId::VARCHAR: |
192 | ExecuteStringExtract(result, input_vector&: base, subscript_vector&: subscript, count); |
193 | break; |
194 | case LogicalTypeId::SQLNULL: |
195 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
196 | ConstantVector::SetNull(vector&: result, is_null: true); |
197 | break; |
198 | default: |
199 | throw NotImplementedException("Specifier type not implemented" ); |
200 | } |
201 | } |
202 | |
203 | static unique_ptr<FunctionData> (ClientContext &context, ScalarFunction &bound_function, |
204 | vector<unique_ptr<Expression>> &arguments) { |
205 | D_ASSERT(bound_function.arguments.size() == 2); |
206 | D_ASSERT(LogicalTypeId::LIST == arguments[0]->return_type.id()); |
207 | // list extract returns the child type of the list as return type |
208 | bound_function.return_type = ListType::GetChildType(type: arguments[0]->return_type); |
209 | return make_uniq<VariableReturnBindData>(args&: bound_function.return_type); |
210 | } |
211 | |
212 | static unique_ptr<BaseStatistics> (ClientContext &context, FunctionStatisticsInput &input) { |
213 | auto &child_stats = input.child_stats; |
214 | auto &list_child_stats = ListStats::GetChildStats(stats&: child_stats[0]); |
215 | auto child_copy = list_child_stats.Copy(); |
216 | // list_extract always pushes a NULL, since if the offset is out of range for a list it inserts a null |
217 | child_copy.Set(StatsInfo::CAN_HAVE_NULL_VALUES); |
218 | return child_copy.ToUnique(); |
219 | } |
220 | |
221 | void ListExtractFun::(BuiltinFunctions &set) { |
222 | // the arguments and return types are actually set in the binder function |
223 | ScalarFunction lfun({LogicalType::LIST(child: LogicalType::ANY), LogicalType::BIGINT}, LogicalType::ANY, |
224 | ListExtractFunction, ListExtractBind, nullptr, ListExtractStats); |
225 | |
226 | ScalarFunction sfun({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, ListExtractFunction); |
227 | |
228 | ScalarFunctionSet ("list_extract" ); |
229 | list_extract.AddFunction(function: lfun); |
230 | list_extract.AddFunction(function: sfun); |
231 | set.AddFunction(set: list_extract); |
232 | |
233 | ScalarFunctionSet list_element("list_element" ); |
234 | list_element.AddFunction(function: lfun); |
235 | list_element.AddFunction(function: sfun); |
236 | set.AddFunction(set: list_element); |
237 | |
238 | ScalarFunctionSet ("array_extract" ); |
239 | array_extract.AddFunction(function: lfun); |
240 | array_extract.AddFunction(function: sfun); |
241 | array_extract.AddFunction(function: StructExtractFun::GetFunction()); |
242 | set.AddFunction(set: array_extract); |
243 | } |
244 | |
245 | } // namespace duckdb |
246 | |