1#include "duckdb/common/arrow/arrow.hpp"
2
3#include "duckdb.hpp"
4#include "duckdb/common/arrow/arrow_wrapper.hpp"
5#include "duckdb/common/limits.hpp"
6#include "duckdb/common/to_string.hpp"
7#include "duckdb/common/types/date.hpp"
8#include "duckdb/common/types/vector_buffer.hpp"
9#include "duckdb/function/table/arrow.hpp"
10#include "duckdb/function/table_function.hpp"
11#include "duckdb/parser/parsed_data/create_table_function_info.hpp"
12#include "utf8proc_wrapper.hpp"
13
14namespace duckdb {
15
16LogicalType ArrowTableFunction::GetArrowLogicalType(
17 ArrowSchema &schema, std::unordered_map<idx_t, unique_ptr<ArrowConvertData>> &arrow_convert_data, idx_t col_idx) {
18 auto format = string(schema.format);
19 if (arrow_convert_data.find(x: col_idx) == arrow_convert_data.end()) {
20 arrow_convert_data[col_idx] = make_uniq<ArrowConvertData>();
21 }
22 auto &convert_data = *arrow_convert_data[col_idx];
23 if (format == "n") {
24 return LogicalType::SQLNULL;
25 } else if (format == "b") {
26 return LogicalType::BOOLEAN;
27 } else if (format == "c") {
28 return LogicalType::TINYINT;
29 } else if (format == "s") {
30 return LogicalType::SMALLINT;
31 } else if (format == "i") {
32 return LogicalType::INTEGER;
33 } else if (format == "l") {
34 return LogicalType::BIGINT;
35 } else if (format == "C") {
36 return LogicalType::UTINYINT;
37 } else if (format == "S") {
38 return LogicalType::USMALLINT;
39 } else if (format == "I") {
40 return LogicalType::UINTEGER;
41 } else if (format == "L") {
42 return LogicalType::UBIGINT;
43 } else if (format == "f") {
44 return LogicalType::FLOAT;
45 } else if (format == "g") {
46 return LogicalType::DOUBLE;
47 } else if (format[0] == 'd') { //! this can be either decimal128 or decimal 256 (e.g., d:38,0)
48 std::string parameters = format.substr(pos: format.find(c: ':'));
49 uint8_t width = std::stoi(str: parameters.substr(pos: 1, n: parameters.find(c: ',')));
50 uint8_t scale = std::stoi(str: parameters.substr(pos: parameters.find(c: ',') + 1));
51 if (width > 38) {
52 throw NotImplementedException("Unsupported Internal Arrow Type for Decimal %s", format);
53 }
54 return LogicalType::DECIMAL(width, scale);
55 } else if (format == "u") {
56 convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::NORMAL, args: 0);
57 return LogicalType::VARCHAR;
58 } else if (format == "U") {
59 convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::SUPER_SIZE, args: 0);
60 return LogicalType::VARCHAR;
61 } else if (format == "tsn:") {
62 return LogicalTypeId::TIMESTAMP_NS;
63 } else if (format == "tsu:") {
64 return LogicalTypeId::TIMESTAMP;
65 } else if (format == "tsm:") {
66 return LogicalTypeId::TIMESTAMP_MS;
67 } else if (format == "tss:") {
68 return LogicalTypeId::TIMESTAMP_SEC;
69 } else if (format == "tdD") {
70 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::DAYS);
71 return LogicalType::DATE;
72 } else if (format == "tdm") {
73 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MILLISECONDS);
74 return LogicalType::DATE;
75 } else if (format == "tts") {
76 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::SECONDS);
77 return LogicalType::TIME;
78 } else if (format == "ttm") {
79 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MILLISECONDS);
80 return LogicalType::TIME;
81 } else if (format == "ttu") {
82 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MICROSECONDS);
83 return LogicalType::TIME;
84 } else if (format == "ttn") {
85 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::NANOSECONDS);
86 return LogicalType::TIME;
87 } else if (format == "tDs") {
88 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::SECONDS);
89 return LogicalType::INTERVAL;
90 } else if (format == "tDm") {
91 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MILLISECONDS);
92 return LogicalType::INTERVAL;
93 } else if (format == "tDu") {
94 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MICROSECONDS);
95 return LogicalType::INTERVAL;
96 } else if (format == "tDn") {
97 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::NANOSECONDS);
98 return LogicalType::INTERVAL;
99 } else if (format == "tiD") {
100 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::DAYS);
101 return LogicalType::INTERVAL;
102 } else if (format == "tiM") {
103 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MONTHS);
104 return LogicalType::INTERVAL;
105 } else if (format == "tin") {
106 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MONTH_DAY_NANO);
107 return LogicalType::INTERVAL;
108 } else if (format == "+l") {
109 convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::NORMAL, args: 0);
110 auto child_type = GetArrowLogicalType(schema&: *schema.children[0], arrow_convert_data, col_idx);
111 return LogicalType::LIST(child: child_type);
112 } else if (format == "+L") {
113 convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::SUPER_SIZE, args: 0);
114 auto child_type = GetArrowLogicalType(schema&: *schema.children[0], arrow_convert_data, col_idx);
115 return LogicalType::LIST(child: child_type);
116 } else if (format[0] == '+' && format[1] == 'w') {
117 std::string parameters = format.substr(pos: format.find(c: ':') + 1);
118 idx_t fixed_size = std::stoi(str: parameters);
119 convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::FIXED_SIZE, args&: fixed_size);
120 auto child_type = GetArrowLogicalType(schema&: *schema.children[0], arrow_convert_data, col_idx);
121 return LogicalType::LIST(child: child_type);
122 } else if (format == "+s") {
123 child_list_t<LogicalType> child_types;
124 for (idx_t type_idx = 0; type_idx < (idx_t)schema.n_children; type_idx++) {
125 auto child_type = GetArrowLogicalType(schema&: *schema.children[type_idx], arrow_convert_data, col_idx);
126 child_types.push_back(x: {schema.children[type_idx]->name, child_type});
127 }
128 return LogicalType::STRUCT(children: child_types);
129
130 } else if (format == "+m") {
131 convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::NORMAL, args: 0);
132
133 auto &arrow_struct_type = *schema.children[0];
134 D_ASSERT(arrow_struct_type.n_children == 2);
135 auto key_type = GetArrowLogicalType(schema&: *arrow_struct_type.children[0], arrow_convert_data, col_idx);
136 auto value_type = GetArrowLogicalType(schema&: *arrow_struct_type.children[1], arrow_convert_data, col_idx);
137 return LogicalType::MAP(key: key_type, value: value_type);
138 } else if (format == "z") {
139 convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::NORMAL, args: 0);
140 return LogicalType::BLOB;
141 } else if (format == "Z") {
142 convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::SUPER_SIZE, args: 0);
143 return LogicalType::BLOB;
144 } else if (format[0] == 'w') {
145 std::string parameters = format.substr(pos: format.find(c: ':') + 1);
146 idx_t fixed_size = std::stoi(str: parameters);
147 convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::FIXED_SIZE, args&: fixed_size);
148 return LogicalType::BLOB;
149 } else if (format[0] == 't' && format[1] == 's') {
150 // Timestamp with Timezone
151 if (format[2] == 'n') {
152 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::NANOSECONDS);
153 } else if (format[2] == 'u') {
154 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MICROSECONDS);
155 } else if (format[2] == 'm') {
156 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MILLISECONDS);
157 } else if (format[2] == 's') {
158 convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::SECONDS);
159 } else {
160 throw NotImplementedException(" Timestamptz precision of not accepted");
161 }
162 // TODO right now we just get the UTC value. We probably want to support this properly in the future
163 return LogicalType::TIMESTAMP_TZ;
164 } else {
165 throw NotImplementedException("Unsupported Internal Arrow Type %s", format);
166 }
167}
168
169void ArrowTableFunction::RenameArrowColumns(vector<string> &names) {
170 unordered_map<string, idx_t> name_map;
171 for (auto &column_name : names) {
172 // put it all lower_case
173 auto low_column_name = StringUtil::Lower(str: column_name);
174 if (name_map.find(x: low_column_name) == name_map.end()) {
175 // Name does not exist yet
176 name_map[low_column_name]++;
177 } else {
178 // Name already exists, we add _x where x is the repetition number
179 string new_column_name = column_name + "_" + std::to_string(val: name_map[low_column_name]);
180 auto new_column_name_low = StringUtil::Lower(str: new_column_name);
181 while (name_map.find(x: new_column_name_low) != name_map.end()) {
182 // This name is already here due to a previous definition
183 name_map[low_column_name]++;
184 new_column_name = column_name + "_" + std::to_string(val: name_map[low_column_name]);
185 new_column_name_low = StringUtil::Lower(str: new_column_name);
186 }
187 column_name = new_column_name;
188 name_map[new_column_name_low]++;
189 }
190 }
191}
192
193unique_ptr<FunctionData> ArrowTableFunction::ArrowScanBind(ClientContext &context, TableFunctionBindInput &input,
194 vector<LogicalType> &return_types, vector<string> &names) {
195 auto stream_factory_ptr = input.inputs[0].GetPointer();
196 auto stream_factory_produce = (stream_factory_produce_t)input.inputs[1].GetPointer(); // NOLINT
197 auto stream_factory_get_schema = (stream_factory_get_schema_t)input.inputs[2].GetPointer(); // NOLINT
198
199 auto res = make_uniq<ArrowScanFunctionData>(args&: stream_factory_produce, args&: stream_factory_ptr);
200
201 auto &data = *res;
202 stream_factory_get_schema(stream_factory_ptr, data.schema_root);
203 for (idx_t col_idx = 0; col_idx < (idx_t)data.schema_root.arrow_schema.n_children; col_idx++) {
204 auto &schema = *data.schema_root.arrow_schema.children[col_idx];
205 if (!schema.release) {
206 throw InvalidInputException("arrow_scan: released schema passed");
207 }
208 if (schema.dictionary) {
209 auto logical_type = GetArrowLogicalType(schema, arrow_convert_data&: res->arrow_convert_data, col_idx);
210 res->arrow_convert_data[col_idx] = make_uniq<ArrowConvertData>(args: std::move(logical_type));
211 return_types.emplace_back(args: GetArrowLogicalType(schema&: *schema.dictionary, arrow_convert_data&: res->arrow_convert_data, col_idx));
212 } else {
213 return_types.emplace_back(args: GetArrowLogicalType(schema, arrow_convert_data&: res->arrow_convert_data, col_idx));
214 }
215 auto format = string(schema.format);
216 auto name = string(schema.name);
217 if (name.empty()) {
218 name = string("v") + to_string(val: col_idx);
219 }
220 names.push_back(x: name);
221 }
222 RenameArrowColumns(names);
223 res->all_types = return_types;
224 return std::move(res);
225}
226
227unique_ptr<ArrowArrayStreamWrapper> ProduceArrowScan(const ArrowScanFunctionData &function,
228 const vector<column_t> &column_ids, TableFilterSet *filters) {
229 //! Generate Projection Pushdown Vector
230 ArrowStreamParameters parameters;
231 D_ASSERT(!column_ids.empty());
232 for (idx_t idx = 0; idx < column_ids.size(); idx++) {
233 auto col_idx = column_ids[idx];
234 if (col_idx != COLUMN_IDENTIFIER_ROW_ID) {
235 auto &schema = *function.schema_root.arrow_schema.children[col_idx];
236 parameters.projected_columns.projection_map[idx] = schema.name;
237 parameters.projected_columns.columns.emplace_back(args&: schema.name);
238 }
239 }
240 parameters.filters = filters;
241 return function.scanner_producer(function.stream_factory_ptr, parameters);
242}
243
244idx_t ArrowTableFunction::ArrowScanMaxThreads(ClientContext &context, const FunctionData *bind_data_p) {
245 return context.db->NumberOfThreads();
246}
247
248bool ArrowTableFunction::ArrowScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p,
249 ArrowScanLocalState &state, ArrowScanGlobalState &parallel_state) {
250 lock_guard<mutex> parallel_lock(parallel_state.main_mutex);
251 if (parallel_state.done) {
252 return false;
253 }
254 state.chunk_offset = 0;
255 state.batch_index = ++parallel_state.batch_index;
256
257 auto current_chunk = parallel_state.stream->GetNextChunk();
258 while (current_chunk->arrow_array.length == 0 && current_chunk->arrow_array.release) {
259 current_chunk = parallel_state.stream->GetNextChunk();
260 }
261 state.chunk = std::move(current_chunk);
262 //! have we run out of chunks? we are done
263 if (!state.chunk->arrow_array.release) {
264 parallel_state.done = true;
265 return false;
266 }
267 return true;
268}
269
270unique_ptr<GlobalTableFunctionState> ArrowTableFunction::ArrowScanInitGlobal(ClientContext &context,
271 TableFunctionInitInput &input) {
272 auto &bind_data = input.bind_data->Cast<ArrowScanFunctionData>();
273 auto result = make_uniq<ArrowScanGlobalState>();
274 result->stream = ProduceArrowScan(function: bind_data, column_ids: input.column_ids, filters: input.filters.get());
275 result->max_threads = ArrowScanMaxThreads(context, bind_data_p: input.bind_data.get());
276 if (input.CanRemoveFilterColumns()) {
277 result->projection_ids = input.projection_ids;
278 for (const auto &col_idx : input.column_ids) {
279 if (col_idx == COLUMN_IDENTIFIER_ROW_ID) {
280 result->scanned_types.emplace_back(args: LogicalType::ROW_TYPE);
281 } else {
282 result->scanned_types.push_back(x: bind_data.all_types[col_idx]);
283 }
284 }
285 }
286 return std::move(result);
287}
288
289unique_ptr<LocalTableFunctionState>
290ArrowTableFunction::ArrowScanInitLocalInternal(ClientContext &context, TableFunctionInitInput &input,
291 GlobalTableFunctionState *global_state_p) {
292 auto &global_state = global_state_p->Cast<ArrowScanGlobalState>();
293 auto current_chunk = make_uniq<ArrowArrayWrapper>();
294 auto result = make_uniq<ArrowScanLocalState>(args: std::move(current_chunk));
295 result->column_ids = input.column_ids;
296 result->filters = input.filters.get();
297 if (input.CanRemoveFilterColumns()) {
298 auto &asgs = global_state_p->Cast<ArrowScanGlobalState>();
299 result->all_columns.Initialize(context, types: asgs.scanned_types);
300 }
301 if (!ArrowScanParallelStateNext(context, bind_data_p: input.bind_data.get(), state&: *result, parallel_state&: global_state)) {
302 return nullptr;
303 }
304 return std::move(result);
305}
306
307unique_ptr<LocalTableFunctionState> ArrowTableFunction::ArrowScanInitLocal(ExecutionContext &context,
308 TableFunctionInitInput &input,
309 GlobalTableFunctionState *global_state_p) {
310 return ArrowScanInitLocalInternal(context&: context.client, input, global_state_p);
311}
312
313void ArrowTableFunction::ArrowScanFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
314 if (!data_p.local_state) {
315 return;
316 }
317 auto &data = data_p.bind_data->CastNoConst<ArrowScanFunctionData>(); // FIXME
318 auto &state = data_p.local_state->Cast<ArrowScanLocalState>();
319 auto &global_state = data_p.global_state->Cast<ArrowScanGlobalState>();
320
321 //! Out of tuples in this chunk
322 if (state.chunk_offset >= (idx_t)state.chunk->arrow_array.length) {
323 if (!ArrowScanParallelStateNext(context, bind_data_p: data_p.bind_data.get(), state, parallel_state&: global_state)) {
324 return;
325 }
326 }
327 int64_t output_size = MinValue<int64_t>(STANDARD_VECTOR_SIZE, b: state.chunk->arrow_array.length - state.chunk_offset);
328 data.lines_read += output_size;
329 if (global_state.CanRemoveFilterColumns()) {
330 state.all_columns.Reset();
331 state.all_columns.SetCardinality(output_size);
332 ArrowToDuckDB(scan_state&: state, arrow_convert_data&: data.arrow_convert_data, output&: state.all_columns, start: data.lines_read - output_size);
333 output.ReferenceColumns(other&: state.all_columns, column_ids: global_state.projection_ids);
334 } else {
335 output.SetCardinality(output_size);
336 ArrowToDuckDB(scan_state&: state, arrow_convert_data&: data.arrow_convert_data, output, start: data.lines_read - output_size);
337 }
338
339 output.Verify();
340 state.chunk_offset += output.size();
341}
342
343unique_ptr<NodeStatistics> ArrowTableFunction::ArrowScanCardinality(ClientContext &context, const FunctionData *data) {
344 return make_uniq<NodeStatistics>();
345}
346
347idx_t ArrowTableFunction::ArrowGetBatchIndex(ClientContext &context, const FunctionData *bind_data_p,
348 LocalTableFunctionState *local_state,
349 GlobalTableFunctionState *global_state) {
350 auto &state = local_state->Cast<ArrowScanLocalState>();
351 return state.batch_index;
352}
353
354void ArrowTableFunction::RegisterFunction(BuiltinFunctions &set) {
355 TableFunction arrow("arrow_scan", {LogicalType::POINTER, LogicalType::POINTER, LogicalType::POINTER},
356 ArrowScanFunction, ArrowScanBind, ArrowScanInitGlobal, ArrowScanInitLocal);
357 arrow.cardinality = ArrowScanCardinality;
358 arrow.get_batch_index = ArrowGetBatchIndex;
359 arrow.projection_pushdown = true;
360 arrow.filter_pushdown = true;
361 arrow.filter_prune = true;
362 set.AddFunction(function: arrow);
363
364 TableFunction arrow_dumb("arrow_scan_dumb", {LogicalType::POINTER, LogicalType::POINTER, LogicalType::POINTER},
365 ArrowScanFunction, ArrowScanBind, ArrowScanInitGlobal, ArrowScanInitLocal);
366 arrow_dumb.cardinality = ArrowScanCardinality;
367 arrow_dumb.get_batch_index = ArrowGetBatchIndex;
368 arrow_dumb.projection_pushdown = false;
369 arrow_dumb.filter_pushdown = false;
370 arrow_dumb.filter_prune = false;
371 set.AddFunction(function: arrow_dumb);
372}
373
374void BuiltinFunctions::RegisterArrowFunctions() {
375 ArrowTableFunction::RegisterFunction(set&: *this);
376}
377} // namespace duckdb
378