1#include "duckdb/common/pair.hpp"
2#include "duckdb/common/string_util.hpp"
3#include "duckdb/common/types/chunk_collection.hpp"
4#include "duckdb/common/types/data_chunk.hpp"
5#include "duckdb/common/vector_operations/binary_executor.hpp"
6#include "duckdb/function/scalar/nested_functions.hpp"
7#include "duckdb/function/scalar/string_functions.hpp"
8#include "duckdb/parser/expression/bound_expression.hpp"
9#include "duckdb/planner/expression/bound_function_expression.hpp"
10#include "duckdb/storage/statistics/list_stats.hpp"
11
12namespace duckdb {
13
14template <class T, bool HEAP_REF = false, bool VALIDITY_ONLY = false>
15void ListExtractTemplate(idx_t count, UnifiedVectorFormat &list_data, UnifiedVectorFormat &offsets_data,
16 Vector &child_vector, idx_t list_size, Vector &result) {
17 UnifiedVectorFormat child_format;
18 child_vector.ToUnifiedFormat(count: list_size, data&: child_format);
19
20 T *result_data;
21
22 result.SetVectorType(VectorType::FLAT_VECTOR);
23 if (!VALIDITY_ONLY) {
24 result_data = FlatVector::GetData<T>(result);
25 }
26 auto &result_mask = FlatVector::Validity(vector&: result);
27
28 // heap-ref once
29 if (HEAP_REF) {
30 StringVector::AddHeapReference(vector&: result, other&: child_vector);
31 }
32
33 // this is lifted from ExecuteGenericLoop because we can't push the list child data into this otherwise
34 // should have gone with GetValue perhaps
35 auto child_data = UnifiedVectorFormat::GetData<T>(child_format);
36 for (idx_t i = 0; i < count; i++) {
37 auto list_index = list_data.sel->get_index(idx: i);
38 auto offsets_index = offsets_data.sel->get_index(idx: i);
39 if (!list_data.validity.RowIsValid(row_idx: list_index)) {
40 result_mask.SetInvalid(i);
41 continue;
42 }
43 if (!offsets_data.validity.RowIsValid(row_idx: offsets_index)) {
44 result_mask.SetInvalid(i);
45 continue;
46 }
47 auto list_entry = (UnifiedVectorFormat::GetData<list_entry_t>(format: list_data))[list_index];
48 auto offsets_entry = (UnifiedVectorFormat::GetData<int64_t>(format: offsets_data))[offsets_index];
49
50 // 1-based indexing
51 if (offsets_entry == 0) {
52 result_mask.SetInvalid(i);
53 continue;
54 }
55 offsets_entry = (offsets_entry > 0) ? offsets_entry - 1 : offsets_entry;
56
57 idx_t child_offset;
58 if (offsets_entry < 0) {
59 if (offsets_entry < -int64_t(list_entry.length)) {
60 result_mask.SetInvalid(i);
61 continue;
62 }
63 child_offset = list_entry.offset + list_entry.length + offsets_entry;
64 } else {
65 if ((idx_t)offsets_entry >= list_entry.length) {
66 result_mask.SetInvalid(i);
67 continue;
68 }
69 child_offset = list_entry.offset + offsets_entry;
70 }
71 auto child_index = child_format.sel->get_index(idx: child_offset);
72 if (child_format.validity.RowIsValid(row_idx: child_index)) {
73 if (!VALIDITY_ONLY) {
74 result_data[i] = child_data[child_index];
75 }
76 } else {
77 result_mask.SetInvalid(i);
78 }
79 }
80 if (count == 1) {
81 result.SetVectorType(VectorType::CONSTANT_VECTOR);
82 }
83}
84static void ExecuteListExtractInternal(const idx_t count, UnifiedVectorFormat &list, UnifiedVectorFormat &offsets,
85 Vector &child_vector, idx_t list_size, Vector &result) {
86 D_ASSERT(child_vector.GetType() == result.GetType());
87 switch (result.GetType().InternalType()) {
88 case PhysicalType::BOOL:
89 case PhysicalType::INT8:
90 ListExtractTemplate<int8_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result);
91 break;
92 case PhysicalType::INT16:
93 ListExtractTemplate<int16_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result);
94 break;
95 case PhysicalType::INT32:
96 ListExtractTemplate<int32_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result);
97 break;
98 case PhysicalType::INT64:
99 ListExtractTemplate<int64_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result);
100 break;
101 case PhysicalType::INT128:
102 ListExtractTemplate<hugeint_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result);
103 break;
104 case PhysicalType::UINT8:
105 ListExtractTemplate<uint8_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result);
106 break;
107 case PhysicalType::UINT16:
108 ListExtractTemplate<uint16_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result);
109 break;
110 case PhysicalType::UINT32:
111 ListExtractTemplate<uint32_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result);
112 break;
113 case PhysicalType::UINT64:
114 ListExtractTemplate<uint64_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result);
115 break;
116 case PhysicalType::FLOAT:
117 ListExtractTemplate<float>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result);
118 break;
119 case PhysicalType::DOUBLE:
120 ListExtractTemplate<double>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result);
121 break;
122 case PhysicalType::VARCHAR:
123 ListExtractTemplate<string_t, true>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result);
124 break;
125 case PhysicalType::INTERVAL:
126 ListExtractTemplate<interval_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result);
127 break;
128 case PhysicalType::STRUCT: {
129 auto &entries = StructVector::GetEntries(vector&: child_vector);
130 auto &result_entries = StructVector::GetEntries(vector&: result);
131 D_ASSERT(entries.size() == result_entries.size());
132 // extract the child entries of the struct
133 for (idx_t i = 0; i < entries.size(); i++) {
134 ExecuteListExtractInternal(count, list, offsets, child_vector&: *entries[i], list_size, result&: *result_entries[i]);
135 }
136 // extract the validity mask
137 ListExtractTemplate<bool, false, true>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result);
138 break;
139 }
140 case PhysicalType::LIST: {
141 // nested list: we have to reference the child
142 auto &child_child_list = ListVector::GetEntry(vector&: child_vector);
143
144 ListVector::GetEntry(vector&: result).Reference(other&: child_child_list);
145 ListVector::SetListSize(vec&: result, size: ListVector::GetListSize(vector: child_vector));
146 ListExtractTemplate<list_entry_t>(count, list_data&: list, offsets_data&: offsets, child_vector, list_size, result);
147 break;
148 }
149 default:
150 throw NotImplementedException("Unimplemented type for LIST_EXTRACT");
151 }
152}
153
154static void ExecuteListExtract(Vector &result, Vector &list, Vector &offsets, const idx_t count) {
155 D_ASSERT(list.GetType().id() == LogicalTypeId::LIST);
156 UnifiedVectorFormat list_data;
157 UnifiedVectorFormat offsets_data;
158
159 list.ToUnifiedFormat(count, data&: list_data);
160 offsets.ToUnifiedFormat(count, data&: offsets_data);
161 ExecuteListExtractInternal(count, list&: list_data, offsets&: offsets_data, child_vector&: ListVector::GetEntry(vector&: list),
162 list_size: ListVector::GetListSize(vector: list), result);
163 result.Verify(count);
164}
165
166static void ExecuteStringExtract(Vector &result, Vector &input_vector, Vector &subscript_vector, const idx_t count) {
167 BinaryExecutor::Execute<string_t, int64_t, string_t>(
168 left&: input_vector, right&: subscript_vector, result, count, fun: [&](string_t input_string, int64_t subscript) {
169 return SubstringFun::SubstringUnicode(result, input: input_string, offset: subscript, length: 1);
170 });
171}
172
173static void ListExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) {
174 D_ASSERT(args.ColumnCount() == 2);
175 auto count = args.size();
176
177 result.SetVectorType(VectorType::CONSTANT_VECTOR);
178 for (idx_t i = 0; i < args.ColumnCount(); i++) {
179 if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) {
180 result.SetVectorType(VectorType::FLAT_VECTOR);
181 }
182 }
183
184 Vector &base = args.data[0];
185 Vector &subscript = args.data[1];
186
187 switch (base.GetType().id()) {
188 case LogicalTypeId::LIST:
189 ExecuteListExtract(result, list&: base, offsets&: subscript, count);
190 break;
191 case LogicalTypeId::VARCHAR:
192 ExecuteStringExtract(result, input_vector&: base, subscript_vector&: subscript, count);
193 break;
194 case LogicalTypeId::SQLNULL:
195 result.SetVectorType(VectorType::CONSTANT_VECTOR);
196 ConstantVector::SetNull(vector&: result, is_null: true);
197 break;
198 default:
199 throw NotImplementedException("Specifier type not implemented");
200 }
201}
202
203static unique_ptr<FunctionData> ListExtractBind(ClientContext &context, ScalarFunction &bound_function,
204 vector<unique_ptr<Expression>> &arguments) {
205 D_ASSERT(bound_function.arguments.size() == 2);
206 D_ASSERT(LogicalTypeId::LIST == arguments[0]->return_type.id());
207 // list extract returns the child type of the list as return type
208 bound_function.return_type = ListType::GetChildType(type: arguments[0]->return_type);
209 return make_uniq<VariableReturnBindData>(args&: bound_function.return_type);
210}
211
212static unique_ptr<BaseStatistics> ListExtractStats(ClientContext &context, FunctionStatisticsInput &input) {
213 auto &child_stats = input.child_stats;
214 auto &list_child_stats = ListStats::GetChildStats(stats&: child_stats[0]);
215 auto child_copy = list_child_stats.Copy();
216 // list_extract always pushes a NULL, since if the offset is out of range for a list it inserts a null
217 child_copy.Set(StatsInfo::CAN_HAVE_NULL_VALUES);
218 return child_copy.ToUnique();
219}
220
221void ListExtractFun::RegisterFunction(BuiltinFunctions &set) {
222 // the arguments and return types are actually set in the binder function
223 ScalarFunction lfun({LogicalType::LIST(child: LogicalType::ANY), LogicalType::BIGINT}, LogicalType::ANY,
224 ListExtractFunction, ListExtractBind, nullptr, ListExtractStats);
225
226 ScalarFunction sfun({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, ListExtractFunction);
227
228 ScalarFunctionSet list_extract("list_extract");
229 list_extract.AddFunction(function: lfun);
230 list_extract.AddFunction(function: sfun);
231 set.AddFunction(set: list_extract);
232
233 ScalarFunctionSet list_element("list_element");
234 list_element.AddFunction(function: lfun);
235 list_element.AddFunction(function: sfun);
236 set.AddFunction(set: list_element);
237
238 ScalarFunctionSet array_extract("array_extract");
239 array_extract.AddFunction(function: lfun);
240 array_extract.AddFunction(function: sfun);
241 array_extract.AddFunction(function: StructExtractFun::GetFunction());
242 set.AddFunction(set: array_extract);
243}
244
245} // namespace duckdb
246