1 | #include "duckdb/common/arrow/arrow.hpp" |
2 | |
3 | #include "duckdb.hpp" |
4 | #include "duckdb/common/arrow/arrow_wrapper.hpp" |
5 | #include "duckdb/common/limits.hpp" |
6 | #include "duckdb/common/to_string.hpp" |
7 | #include "duckdb/common/types/date.hpp" |
8 | #include "duckdb/common/types/vector_buffer.hpp" |
9 | #include "duckdb/function/table/arrow.hpp" |
10 | #include "duckdb/function/table_function.hpp" |
11 | #include "duckdb/parser/parsed_data/create_table_function_info.hpp" |
12 | #include "utf8proc_wrapper.hpp" |
13 | |
14 | namespace duckdb { |
15 | |
16 | LogicalType ArrowTableFunction::GetArrowLogicalType( |
17 | ArrowSchema &schema, std::unordered_map<idx_t, unique_ptr<ArrowConvertData>> &arrow_convert_data, idx_t col_idx) { |
18 | auto format = string(schema.format); |
19 | if (arrow_convert_data.find(x: col_idx) == arrow_convert_data.end()) { |
20 | arrow_convert_data[col_idx] = make_uniq<ArrowConvertData>(); |
21 | } |
22 | auto &convert_data = *arrow_convert_data[col_idx]; |
23 | if (format == "n" ) { |
24 | return LogicalType::SQLNULL; |
25 | } else if (format == "b" ) { |
26 | return LogicalType::BOOLEAN; |
27 | } else if (format == "c" ) { |
28 | return LogicalType::TINYINT; |
29 | } else if (format == "s" ) { |
30 | return LogicalType::SMALLINT; |
31 | } else if (format == "i" ) { |
32 | return LogicalType::INTEGER; |
33 | } else if (format == "l" ) { |
34 | return LogicalType::BIGINT; |
35 | } else if (format == "C" ) { |
36 | return LogicalType::UTINYINT; |
37 | } else if (format == "S" ) { |
38 | return LogicalType::USMALLINT; |
39 | } else if (format == "I" ) { |
40 | return LogicalType::UINTEGER; |
41 | } else if (format == "L" ) { |
42 | return LogicalType::UBIGINT; |
43 | } else if (format == "f" ) { |
44 | return LogicalType::FLOAT; |
45 | } else if (format == "g" ) { |
46 | return LogicalType::DOUBLE; |
47 | } else if (format[0] == 'd') { //! this can be either decimal128 or decimal 256 (e.g., d:38,0) |
48 | std::string parameters = format.substr(pos: format.find(c: ':')); |
49 | uint8_t width = std::stoi(str: parameters.substr(pos: 1, n: parameters.find(c: ','))); |
50 | uint8_t scale = std::stoi(str: parameters.substr(pos: parameters.find(c: ',') + 1)); |
51 | if (width > 38) { |
52 | throw NotImplementedException("Unsupported Internal Arrow Type for Decimal %s" , format); |
53 | } |
54 | return LogicalType::DECIMAL(width, scale); |
55 | } else if (format == "u" ) { |
56 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::NORMAL, args: 0); |
57 | return LogicalType::VARCHAR; |
58 | } else if (format == "U" ) { |
59 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::SUPER_SIZE, args: 0); |
60 | return LogicalType::VARCHAR; |
61 | } else if (format == "tsn:" ) { |
62 | return LogicalTypeId::TIMESTAMP_NS; |
63 | } else if (format == "tsu:" ) { |
64 | return LogicalTypeId::TIMESTAMP; |
65 | } else if (format == "tsm:" ) { |
66 | return LogicalTypeId::TIMESTAMP_MS; |
67 | } else if (format == "tss:" ) { |
68 | return LogicalTypeId::TIMESTAMP_SEC; |
69 | } else if (format == "tdD" ) { |
70 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::DAYS); |
71 | return LogicalType::DATE; |
72 | } else if (format == "tdm" ) { |
73 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MILLISECONDS); |
74 | return LogicalType::DATE; |
75 | } else if (format == "tts" ) { |
76 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::SECONDS); |
77 | return LogicalType::TIME; |
78 | } else if (format == "ttm" ) { |
79 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MILLISECONDS); |
80 | return LogicalType::TIME; |
81 | } else if (format == "ttu" ) { |
82 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MICROSECONDS); |
83 | return LogicalType::TIME; |
84 | } else if (format == "ttn" ) { |
85 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::NANOSECONDS); |
86 | return LogicalType::TIME; |
87 | } else if (format == "tDs" ) { |
88 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::SECONDS); |
89 | return LogicalType::INTERVAL; |
90 | } else if (format == "tDm" ) { |
91 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MILLISECONDS); |
92 | return LogicalType::INTERVAL; |
93 | } else if (format == "tDu" ) { |
94 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MICROSECONDS); |
95 | return LogicalType::INTERVAL; |
96 | } else if (format == "tDn" ) { |
97 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::NANOSECONDS); |
98 | return LogicalType::INTERVAL; |
99 | } else if (format == "tiD" ) { |
100 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::DAYS); |
101 | return LogicalType::INTERVAL; |
102 | } else if (format == "tiM" ) { |
103 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MONTHS); |
104 | return LogicalType::INTERVAL; |
105 | } else if (format == "tin" ) { |
106 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MONTH_DAY_NANO); |
107 | return LogicalType::INTERVAL; |
108 | } else if (format == "+l" ) { |
109 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::NORMAL, args: 0); |
110 | auto child_type = GetArrowLogicalType(schema&: *schema.children[0], arrow_convert_data, col_idx); |
111 | return LogicalType::LIST(child: child_type); |
112 | } else if (format == "+L" ) { |
113 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::SUPER_SIZE, args: 0); |
114 | auto child_type = GetArrowLogicalType(schema&: *schema.children[0], arrow_convert_data, col_idx); |
115 | return LogicalType::LIST(child: child_type); |
116 | } else if (format[0] == '+' && format[1] == 'w') { |
117 | std::string parameters = format.substr(pos: format.find(c: ':') + 1); |
118 | idx_t fixed_size = std::stoi(str: parameters); |
119 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::FIXED_SIZE, args&: fixed_size); |
120 | auto child_type = GetArrowLogicalType(schema&: *schema.children[0], arrow_convert_data, col_idx); |
121 | return LogicalType::LIST(child: child_type); |
122 | } else if (format == "+s" ) { |
123 | child_list_t<LogicalType> child_types; |
124 | for (idx_t type_idx = 0; type_idx < (idx_t)schema.n_children; type_idx++) { |
125 | auto child_type = GetArrowLogicalType(schema&: *schema.children[type_idx], arrow_convert_data, col_idx); |
126 | child_types.push_back(x: {schema.children[type_idx]->name, child_type}); |
127 | } |
128 | return LogicalType::STRUCT(children: child_types); |
129 | |
130 | } else if (format == "+m" ) { |
131 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::NORMAL, args: 0); |
132 | |
133 | auto &arrow_struct_type = *schema.children[0]; |
134 | D_ASSERT(arrow_struct_type.n_children == 2); |
135 | auto key_type = GetArrowLogicalType(schema&: *arrow_struct_type.children[0], arrow_convert_data, col_idx); |
136 | auto value_type = GetArrowLogicalType(schema&: *arrow_struct_type.children[1], arrow_convert_data, col_idx); |
137 | return LogicalType::MAP(key: key_type, value: value_type); |
138 | } else if (format == "z" ) { |
139 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::NORMAL, args: 0); |
140 | return LogicalType::BLOB; |
141 | } else if (format == "Z" ) { |
142 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::SUPER_SIZE, args: 0); |
143 | return LogicalType::BLOB; |
144 | } else if (format[0] == 'w') { |
145 | std::string parameters = format.substr(pos: format.find(c: ':') + 1); |
146 | idx_t fixed_size = std::stoi(str: parameters); |
147 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::FIXED_SIZE, args&: fixed_size); |
148 | return LogicalType::BLOB; |
149 | } else if (format[0] == 't' && format[1] == 's') { |
150 | // Timestamp with Timezone |
151 | if (format[2] == 'n') { |
152 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::NANOSECONDS); |
153 | } else if (format[2] == 'u') { |
154 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MICROSECONDS); |
155 | } else if (format[2] == 'm') { |
156 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MILLISECONDS); |
157 | } else if (format[2] == 's') { |
158 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::SECONDS); |
159 | } else { |
160 | throw NotImplementedException(" Timestamptz precision of not accepted" ); |
161 | } |
162 | // TODO right now we just get the UTC value. We probably want to support this properly in the future |
163 | return LogicalType::TIMESTAMP_TZ; |
164 | } else { |
165 | throw NotImplementedException("Unsupported Internal Arrow Type %s" , format); |
166 | } |
167 | } |
168 | |
169 | void ArrowTableFunction::RenameArrowColumns(vector<string> &names) { |
170 | unordered_map<string, idx_t> name_map; |
171 | for (auto &column_name : names) { |
172 | // put it all lower_case |
173 | auto low_column_name = StringUtil::Lower(str: column_name); |
174 | if (name_map.find(x: low_column_name) == name_map.end()) { |
175 | // Name does not exist yet |
176 | name_map[low_column_name]++; |
177 | } else { |
178 | // Name already exists, we add _x where x is the repetition number |
179 | string new_column_name = column_name + "_" + std::to_string(val: name_map[low_column_name]); |
180 | auto new_column_name_low = StringUtil::Lower(str: new_column_name); |
181 | while (name_map.find(x: new_column_name_low) != name_map.end()) { |
182 | // This name is already here due to a previous definition |
183 | name_map[low_column_name]++; |
184 | new_column_name = column_name + "_" + std::to_string(val: name_map[low_column_name]); |
185 | new_column_name_low = StringUtil::Lower(str: new_column_name); |
186 | } |
187 | column_name = new_column_name; |
188 | name_map[new_column_name_low]++; |
189 | } |
190 | } |
191 | } |
192 | |
193 | unique_ptr<FunctionData> ArrowTableFunction::ArrowScanBind(ClientContext &context, TableFunctionBindInput &input, |
194 | vector<LogicalType> &return_types, vector<string> &names) { |
195 | auto stream_factory_ptr = input.inputs[0].GetPointer(); |
196 | auto stream_factory_produce = (stream_factory_produce_t)input.inputs[1].GetPointer(); // NOLINT |
197 | auto stream_factory_get_schema = (stream_factory_get_schema_t)input.inputs[2].GetPointer(); // NOLINT |
198 | |
199 | auto res = make_uniq<ArrowScanFunctionData>(args&: stream_factory_produce, args&: stream_factory_ptr); |
200 | |
201 | auto &data = *res; |
202 | stream_factory_get_schema(stream_factory_ptr, data.schema_root); |
203 | for (idx_t col_idx = 0; col_idx < (idx_t)data.schema_root.arrow_schema.n_children; col_idx++) { |
204 | auto &schema = *data.schema_root.arrow_schema.children[col_idx]; |
205 | if (!schema.release) { |
206 | throw InvalidInputException("arrow_scan: released schema passed" ); |
207 | } |
208 | if (schema.dictionary) { |
209 | auto logical_type = GetArrowLogicalType(schema, arrow_convert_data&: res->arrow_convert_data, col_idx); |
210 | res->arrow_convert_data[col_idx] = make_uniq<ArrowConvertData>(args: std::move(logical_type)); |
211 | return_types.emplace_back(args: GetArrowLogicalType(schema&: *schema.dictionary, arrow_convert_data&: res->arrow_convert_data, col_idx)); |
212 | } else { |
213 | return_types.emplace_back(args: GetArrowLogicalType(schema, arrow_convert_data&: res->arrow_convert_data, col_idx)); |
214 | } |
215 | auto format = string(schema.format); |
216 | auto name = string(schema.name); |
217 | if (name.empty()) { |
218 | name = string("v" ) + to_string(val: col_idx); |
219 | } |
220 | names.push_back(x: name); |
221 | } |
222 | RenameArrowColumns(names); |
223 | res->all_types = return_types; |
224 | return std::move(res); |
225 | } |
226 | |
227 | unique_ptr<ArrowArrayStreamWrapper> ProduceArrowScan(const ArrowScanFunctionData &function, |
228 | const vector<column_t> &column_ids, TableFilterSet *filters) { |
229 | //! Generate Projection Pushdown Vector |
230 | ArrowStreamParameters parameters; |
231 | D_ASSERT(!column_ids.empty()); |
232 | for (idx_t idx = 0; idx < column_ids.size(); idx++) { |
233 | auto col_idx = column_ids[idx]; |
234 | if (col_idx != COLUMN_IDENTIFIER_ROW_ID) { |
235 | auto &schema = *function.schema_root.arrow_schema.children[col_idx]; |
236 | parameters.projected_columns.projection_map[idx] = schema.name; |
237 | parameters.projected_columns.columns.emplace_back(args&: schema.name); |
238 | } |
239 | } |
240 | parameters.filters = filters; |
241 | return function.scanner_producer(function.stream_factory_ptr, parameters); |
242 | } |
243 | |
244 | idx_t ArrowTableFunction::ArrowScanMaxThreads(ClientContext &context, const FunctionData *bind_data_p) { |
245 | return context.db->NumberOfThreads(); |
246 | } |
247 | |
248 | bool ArrowTableFunction::ArrowScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p, |
249 | ArrowScanLocalState &state, ArrowScanGlobalState ¶llel_state) { |
250 | lock_guard<mutex> parallel_lock(parallel_state.main_mutex); |
251 | if (parallel_state.done) { |
252 | return false; |
253 | } |
254 | state.chunk_offset = 0; |
255 | state.batch_index = ++parallel_state.batch_index; |
256 | |
257 | auto current_chunk = parallel_state.stream->GetNextChunk(); |
258 | while (current_chunk->arrow_array.length == 0 && current_chunk->arrow_array.release) { |
259 | current_chunk = parallel_state.stream->GetNextChunk(); |
260 | } |
261 | state.chunk = std::move(current_chunk); |
262 | //! have we run out of chunks? we are done |
263 | if (!state.chunk->arrow_array.release) { |
264 | parallel_state.done = true; |
265 | return false; |
266 | } |
267 | return true; |
268 | } |
269 | |
270 | unique_ptr<GlobalTableFunctionState> ArrowTableFunction::ArrowScanInitGlobal(ClientContext &context, |
271 | TableFunctionInitInput &input) { |
272 | auto &bind_data = input.bind_data->Cast<ArrowScanFunctionData>(); |
273 | auto result = make_uniq<ArrowScanGlobalState>(); |
274 | result->stream = ProduceArrowScan(function: bind_data, column_ids: input.column_ids, filters: input.filters.get()); |
275 | result->max_threads = ArrowScanMaxThreads(context, bind_data_p: input.bind_data.get()); |
276 | if (input.CanRemoveFilterColumns()) { |
277 | result->projection_ids = input.projection_ids; |
278 | for (const auto &col_idx : input.column_ids) { |
279 | if (col_idx == COLUMN_IDENTIFIER_ROW_ID) { |
280 | result->scanned_types.emplace_back(args: LogicalType::ROW_TYPE); |
281 | } else { |
282 | result->scanned_types.push_back(x: bind_data.all_types[col_idx]); |
283 | } |
284 | } |
285 | } |
286 | return std::move(result); |
287 | } |
288 | |
289 | unique_ptr<LocalTableFunctionState> |
290 | ArrowTableFunction::ArrowScanInitLocalInternal(ClientContext &context, TableFunctionInitInput &input, |
291 | GlobalTableFunctionState *global_state_p) { |
292 | auto &global_state = global_state_p->Cast<ArrowScanGlobalState>(); |
293 | auto current_chunk = make_uniq<ArrowArrayWrapper>(); |
294 | auto result = make_uniq<ArrowScanLocalState>(args: std::move(current_chunk)); |
295 | result->column_ids = input.column_ids; |
296 | result->filters = input.filters.get(); |
297 | if (input.CanRemoveFilterColumns()) { |
298 | auto &asgs = global_state_p->Cast<ArrowScanGlobalState>(); |
299 | result->all_columns.Initialize(context, types: asgs.scanned_types); |
300 | } |
301 | if (!ArrowScanParallelStateNext(context, bind_data_p: input.bind_data.get(), state&: *result, parallel_state&: global_state)) { |
302 | return nullptr; |
303 | } |
304 | return std::move(result); |
305 | } |
306 | |
307 | unique_ptr<LocalTableFunctionState> ArrowTableFunction::ArrowScanInitLocal(ExecutionContext &context, |
308 | TableFunctionInitInput &input, |
309 | GlobalTableFunctionState *global_state_p) { |
310 | return ArrowScanInitLocalInternal(context&: context.client, input, global_state_p); |
311 | } |
312 | |
313 | void ArrowTableFunction::ArrowScanFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { |
314 | if (!data_p.local_state) { |
315 | return; |
316 | } |
317 | auto &data = data_p.bind_data->CastNoConst<ArrowScanFunctionData>(); // FIXME |
318 | auto &state = data_p.local_state->Cast<ArrowScanLocalState>(); |
319 | auto &global_state = data_p.global_state->Cast<ArrowScanGlobalState>(); |
320 | |
321 | //! Out of tuples in this chunk |
322 | if (state.chunk_offset >= (idx_t)state.chunk->arrow_array.length) { |
323 | if (!ArrowScanParallelStateNext(context, bind_data_p: data_p.bind_data.get(), state, parallel_state&: global_state)) { |
324 | return; |
325 | } |
326 | } |
327 | int64_t output_size = MinValue<int64_t>(STANDARD_VECTOR_SIZE, b: state.chunk->arrow_array.length - state.chunk_offset); |
328 | data.lines_read += output_size; |
329 | if (global_state.CanRemoveFilterColumns()) { |
330 | state.all_columns.Reset(); |
331 | state.all_columns.SetCardinality(output_size); |
332 | ArrowToDuckDB(scan_state&: state, arrow_convert_data&: data.arrow_convert_data, output&: state.all_columns, start: data.lines_read - output_size); |
333 | output.ReferenceColumns(other&: state.all_columns, column_ids: global_state.projection_ids); |
334 | } else { |
335 | output.SetCardinality(output_size); |
336 | ArrowToDuckDB(scan_state&: state, arrow_convert_data&: data.arrow_convert_data, output, start: data.lines_read - output_size); |
337 | } |
338 | |
339 | output.Verify(); |
340 | state.chunk_offset += output.size(); |
341 | } |
342 | |
343 | unique_ptr<NodeStatistics> ArrowTableFunction::ArrowScanCardinality(ClientContext &context, const FunctionData *data) { |
344 | return make_uniq<NodeStatistics>(); |
345 | } |
346 | |
347 | idx_t ArrowTableFunction::ArrowGetBatchIndex(ClientContext &context, const FunctionData *bind_data_p, |
348 | LocalTableFunctionState *local_state, |
349 | GlobalTableFunctionState *global_state) { |
350 | auto &state = local_state->Cast<ArrowScanLocalState>(); |
351 | return state.batch_index; |
352 | } |
353 | |
354 | void ArrowTableFunction::RegisterFunction(BuiltinFunctions &set) { |
355 | TableFunction arrow("arrow_scan" , {LogicalType::POINTER, LogicalType::POINTER, LogicalType::POINTER}, |
356 | ArrowScanFunction, ArrowScanBind, ArrowScanInitGlobal, ArrowScanInitLocal); |
357 | arrow.cardinality = ArrowScanCardinality; |
358 | arrow.get_batch_index = ArrowGetBatchIndex; |
359 | arrow.projection_pushdown = true; |
360 | arrow.filter_pushdown = true; |
361 | arrow.filter_prune = true; |
362 | set.AddFunction(function: arrow); |
363 | |
364 | TableFunction arrow_dumb("arrow_scan_dumb" , {LogicalType::POINTER, LogicalType::POINTER, LogicalType::POINTER}, |
365 | ArrowScanFunction, ArrowScanBind, ArrowScanInitGlobal, ArrowScanInitLocal); |
366 | arrow_dumb.cardinality = ArrowScanCardinality; |
367 | arrow_dumb.get_batch_index = ArrowGetBatchIndex; |
368 | arrow_dumb.projection_pushdown = false; |
369 | arrow_dumb.filter_pushdown = false; |
370 | arrow_dumb.filter_prune = false; |
371 | set.AddFunction(function: arrow_dumb); |
372 | } |
373 | |
374 | void BuiltinFunctions::RegisterArrowFunctions() { |
375 | ArrowTableFunction::RegisterFunction(set&: *this); |
376 | } |
377 | } // namespace duckdb |
378 | |