| 1 | #include "duckdb/common/arrow/arrow.hpp" |
| 2 | |
| 3 | #include "duckdb.hpp" |
| 4 | #include "duckdb/common/arrow/arrow_wrapper.hpp" |
| 5 | #include "duckdb/common/limits.hpp" |
| 6 | #include "duckdb/common/to_string.hpp" |
| 7 | #include "duckdb/common/types/date.hpp" |
| 8 | #include "duckdb/common/types/vector_buffer.hpp" |
| 9 | #include "duckdb/function/table/arrow.hpp" |
| 10 | #include "duckdb/function/table_function.hpp" |
| 11 | #include "duckdb/parser/parsed_data/create_table_function_info.hpp" |
| 12 | #include "utf8proc_wrapper.hpp" |
| 13 | |
| 14 | namespace duckdb { |
| 15 | |
| 16 | LogicalType ArrowTableFunction::GetArrowLogicalType( |
| 17 | ArrowSchema &schema, std::unordered_map<idx_t, unique_ptr<ArrowConvertData>> &arrow_convert_data, idx_t col_idx) { |
| 18 | auto format = string(schema.format); |
| 19 | if (arrow_convert_data.find(x: col_idx) == arrow_convert_data.end()) { |
| 20 | arrow_convert_data[col_idx] = make_uniq<ArrowConvertData>(); |
| 21 | } |
| 22 | auto &convert_data = *arrow_convert_data[col_idx]; |
| 23 | if (format == "n" ) { |
| 24 | return LogicalType::SQLNULL; |
| 25 | } else if (format == "b" ) { |
| 26 | return LogicalType::BOOLEAN; |
| 27 | } else if (format == "c" ) { |
| 28 | return LogicalType::TINYINT; |
| 29 | } else if (format == "s" ) { |
| 30 | return LogicalType::SMALLINT; |
| 31 | } else if (format == "i" ) { |
| 32 | return LogicalType::INTEGER; |
| 33 | } else if (format == "l" ) { |
| 34 | return LogicalType::BIGINT; |
| 35 | } else if (format == "C" ) { |
| 36 | return LogicalType::UTINYINT; |
| 37 | } else if (format == "S" ) { |
| 38 | return LogicalType::USMALLINT; |
| 39 | } else if (format == "I" ) { |
| 40 | return LogicalType::UINTEGER; |
| 41 | } else if (format == "L" ) { |
| 42 | return LogicalType::UBIGINT; |
| 43 | } else if (format == "f" ) { |
| 44 | return LogicalType::FLOAT; |
| 45 | } else if (format == "g" ) { |
| 46 | return LogicalType::DOUBLE; |
| 47 | } else if (format[0] == 'd') { //! this can be either decimal128 or decimal 256 (e.g., d:38,0) |
| 48 | std::string parameters = format.substr(pos: format.find(c: ':')); |
| 49 | uint8_t width = std::stoi(str: parameters.substr(pos: 1, n: parameters.find(c: ','))); |
| 50 | uint8_t scale = std::stoi(str: parameters.substr(pos: parameters.find(c: ',') + 1)); |
| 51 | if (width > 38) { |
| 52 | throw NotImplementedException("Unsupported Internal Arrow Type for Decimal %s" , format); |
| 53 | } |
| 54 | return LogicalType::DECIMAL(width, scale); |
| 55 | } else if (format == "u" ) { |
| 56 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::NORMAL, args: 0); |
| 57 | return LogicalType::VARCHAR; |
| 58 | } else if (format == "U" ) { |
| 59 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::SUPER_SIZE, args: 0); |
| 60 | return LogicalType::VARCHAR; |
| 61 | } else if (format == "tsn:" ) { |
| 62 | return LogicalTypeId::TIMESTAMP_NS; |
| 63 | } else if (format == "tsu:" ) { |
| 64 | return LogicalTypeId::TIMESTAMP; |
| 65 | } else if (format == "tsm:" ) { |
| 66 | return LogicalTypeId::TIMESTAMP_MS; |
| 67 | } else if (format == "tss:" ) { |
| 68 | return LogicalTypeId::TIMESTAMP_SEC; |
| 69 | } else if (format == "tdD" ) { |
| 70 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::DAYS); |
| 71 | return LogicalType::DATE; |
| 72 | } else if (format == "tdm" ) { |
| 73 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MILLISECONDS); |
| 74 | return LogicalType::DATE; |
| 75 | } else if (format == "tts" ) { |
| 76 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::SECONDS); |
| 77 | return LogicalType::TIME; |
| 78 | } else if (format == "ttm" ) { |
| 79 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MILLISECONDS); |
| 80 | return LogicalType::TIME; |
| 81 | } else if (format == "ttu" ) { |
| 82 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MICROSECONDS); |
| 83 | return LogicalType::TIME; |
| 84 | } else if (format == "ttn" ) { |
| 85 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::NANOSECONDS); |
| 86 | return LogicalType::TIME; |
| 87 | } else if (format == "tDs" ) { |
| 88 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::SECONDS); |
| 89 | return LogicalType::INTERVAL; |
| 90 | } else if (format == "tDm" ) { |
| 91 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MILLISECONDS); |
| 92 | return LogicalType::INTERVAL; |
| 93 | } else if (format == "tDu" ) { |
| 94 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MICROSECONDS); |
| 95 | return LogicalType::INTERVAL; |
| 96 | } else if (format == "tDn" ) { |
| 97 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::NANOSECONDS); |
| 98 | return LogicalType::INTERVAL; |
| 99 | } else if (format == "tiD" ) { |
| 100 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::DAYS); |
| 101 | return LogicalType::INTERVAL; |
| 102 | } else if (format == "tiM" ) { |
| 103 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MONTHS); |
| 104 | return LogicalType::INTERVAL; |
| 105 | } else if (format == "tin" ) { |
| 106 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MONTH_DAY_NANO); |
| 107 | return LogicalType::INTERVAL; |
| 108 | } else if (format == "+l" ) { |
| 109 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::NORMAL, args: 0); |
| 110 | auto child_type = GetArrowLogicalType(schema&: *schema.children[0], arrow_convert_data, col_idx); |
| 111 | return LogicalType::LIST(child: child_type); |
| 112 | } else if (format == "+L" ) { |
| 113 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::SUPER_SIZE, args: 0); |
| 114 | auto child_type = GetArrowLogicalType(schema&: *schema.children[0], arrow_convert_data, col_idx); |
| 115 | return LogicalType::LIST(child: child_type); |
| 116 | } else if (format[0] == '+' && format[1] == 'w') { |
| 117 | std::string parameters = format.substr(pos: format.find(c: ':') + 1); |
| 118 | idx_t fixed_size = std::stoi(str: parameters); |
| 119 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::FIXED_SIZE, args&: fixed_size); |
| 120 | auto child_type = GetArrowLogicalType(schema&: *schema.children[0], arrow_convert_data, col_idx); |
| 121 | return LogicalType::LIST(child: child_type); |
| 122 | } else if (format == "+s" ) { |
| 123 | child_list_t<LogicalType> child_types; |
| 124 | for (idx_t type_idx = 0; type_idx < (idx_t)schema.n_children; type_idx++) { |
| 125 | auto child_type = GetArrowLogicalType(schema&: *schema.children[type_idx], arrow_convert_data, col_idx); |
| 126 | child_types.push_back(x: {schema.children[type_idx]->name, child_type}); |
| 127 | } |
| 128 | return LogicalType::STRUCT(children: child_types); |
| 129 | |
| 130 | } else if (format == "+m" ) { |
| 131 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::NORMAL, args: 0); |
| 132 | |
| 133 | auto &arrow_struct_type = *schema.children[0]; |
| 134 | D_ASSERT(arrow_struct_type.n_children == 2); |
| 135 | auto key_type = GetArrowLogicalType(schema&: *arrow_struct_type.children[0], arrow_convert_data, col_idx); |
| 136 | auto value_type = GetArrowLogicalType(schema&: *arrow_struct_type.children[1], arrow_convert_data, col_idx); |
| 137 | return LogicalType::MAP(key: key_type, value: value_type); |
| 138 | } else if (format == "z" ) { |
| 139 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::NORMAL, args: 0); |
| 140 | return LogicalType::BLOB; |
| 141 | } else if (format == "Z" ) { |
| 142 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::SUPER_SIZE, args: 0); |
| 143 | return LogicalType::BLOB; |
| 144 | } else if (format[0] == 'w') { |
| 145 | std::string parameters = format.substr(pos: format.find(c: ':') + 1); |
| 146 | idx_t fixed_size = std::stoi(str: parameters); |
| 147 | convert_data.variable_sz_type.emplace_back(args: ArrowVariableSizeType::FIXED_SIZE, args&: fixed_size); |
| 148 | return LogicalType::BLOB; |
| 149 | } else if (format[0] == 't' && format[1] == 's') { |
| 150 | // Timestamp with Timezone |
| 151 | if (format[2] == 'n') { |
| 152 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::NANOSECONDS); |
| 153 | } else if (format[2] == 'u') { |
| 154 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MICROSECONDS); |
| 155 | } else if (format[2] == 'm') { |
| 156 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::MILLISECONDS); |
| 157 | } else if (format[2] == 's') { |
| 158 | convert_data.date_time_precision.emplace_back(args: ArrowDateTimeType::SECONDS); |
| 159 | } else { |
| 160 | throw NotImplementedException(" Timestamptz precision of not accepted" ); |
| 161 | } |
| 162 | // TODO right now we just get the UTC value. We probably want to support this properly in the future |
| 163 | return LogicalType::TIMESTAMP_TZ; |
| 164 | } else { |
| 165 | throw NotImplementedException("Unsupported Internal Arrow Type %s" , format); |
| 166 | } |
| 167 | } |
| 168 | |
| 169 | void ArrowTableFunction::RenameArrowColumns(vector<string> &names) { |
| 170 | unordered_map<string, idx_t> name_map; |
| 171 | for (auto &column_name : names) { |
| 172 | // put it all lower_case |
| 173 | auto low_column_name = StringUtil::Lower(str: column_name); |
| 174 | if (name_map.find(x: low_column_name) == name_map.end()) { |
| 175 | // Name does not exist yet |
| 176 | name_map[low_column_name]++; |
| 177 | } else { |
| 178 | // Name already exists, we add _x where x is the repetition number |
| 179 | string new_column_name = column_name + "_" + std::to_string(val: name_map[low_column_name]); |
| 180 | auto new_column_name_low = StringUtil::Lower(str: new_column_name); |
| 181 | while (name_map.find(x: new_column_name_low) != name_map.end()) { |
| 182 | // This name is already here due to a previous definition |
| 183 | name_map[low_column_name]++; |
| 184 | new_column_name = column_name + "_" + std::to_string(val: name_map[low_column_name]); |
| 185 | new_column_name_low = StringUtil::Lower(str: new_column_name); |
| 186 | } |
| 187 | column_name = new_column_name; |
| 188 | name_map[new_column_name_low]++; |
| 189 | } |
| 190 | } |
| 191 | } |
| 192 | |
| 193 | unique_ptr<FunctionData> ArrowTableFunction::ArrowScanBind(ClientContext &context, TableFunctionBindInput &input, |
| 194 | vector<LogicalType> &return_types, vector<string> &names) { |
| 195 | auto stream_factory_ptr = input.inputs[0].GetPointer(); |
| 196 | auto stream_factory_produce = (stream_factory_produce_t)input.inputs[1].GetPointer(); // NOLINT |
| 197 | auto stream_factory_get_schema = (stream_factory_get_schema_t)input.inputs[2].GetPointer(); // NOLINT |
| 198 | |
| 199 | auto res = make_uniq<ArrowScanFunctionData>(args&: stream_factory_produce, args&: stream_factory_ptr); |
| 200 | |
| 201 | auto &data = *res; |
| 202 | stream_factory_get_schema(stream_factory_ptr, data.schema_root); |
| 203 | for (idx_t col_idx = 0; col_idx < (idx_t)data.schema_root.arrow_schema.n_children; col_idx++) { |
| 204 | auto &schema = *data.schema_root.arrow_schema.children[col_idx]; |
| 205 | if (!schema.release) { |
| 206 | throw InvalidInputException("arrow_scan: released schema passed" ); |
| 207 | } |
| 208 | if (schema.dictionary) { |
| 209 | auto logical_type = GetArrowLogicalType(schema, arrow_convert_data&: res->arrow_convert_data, col_idx); |
| 210 | res->arrow_convert_data[col_idx] = make_uniq<ArrowConvertData>(args: std::move(logical_type)); |
| 211 | return_types.emplace_back(args: GetArrowLogicalType(schema&: *schema.dictionary, arrow_convert_data&: res->arrow_convert_data, col_idx)); |
| 212 | } else { |
| 213 | return_types.emplace_back(args: GetArrowLogicalType(schema, arrow_convert_data&: res->arrow_convert_data, col_idx)); |
| 214 | } |
| 215 | auto format = string(schema.format); |
| 216 | auto name = string(schema.name); |
| 217 | if (name.empty()) { |
| 218 | name = string("v" ) + to_string(val: col_idx); |
| 219 | } |
| 220 | names.push_back(x: name); |
| 221 | } |
| 222 | RenameArrowColumns(names); |
| 223 | res->all_types = return_types; |
| 224 | return std::move(res); |
| 225 | } |
| 226 | |
| 227 | unique_ptr<ArrowArrayStreamWrapper> ProduceArrowScan(const ArrowScanFunctionData &function, |
| 228 | const vector<column_t> &column_ids, TableFilterSet *filters) { |
| 229 | //! Generate Projection Pushdown Vector |
| 230 | ArrowStreamParameters parameters; |
| 231 | D_ASSERT(!column_ids.empty()); |
| 232 | for (idx_t idx = 0; idx < column_ids.size(); idx++) { |
| 233 | auto col_idx = column_ids[idx]; |
| 234 | if (col_idx != COLUMN_IDENTIFIER_ROW_ID) { |
| 235 | auto &schema = *function.schema_root.arrow_schema.children[col_idx]; |
| 236 | parameters.projected_columns.projection_map[idx] = schema.name; |
| 237 | parameters.projected_columns.columns.emplace_back(args&: schema.name); |
| 238 | } |
| 239 | } |
| 240 | parameters.filters = filters; |
| 241 | return function.scanner_producer(function.stream_factory_ptr, parameters); |
| 242 | } |
| 243 | |
| 244 | idx_t ArrowTableFunction::ArrowScanMaxThreads(ClientContext &context, const FunctionData *bind_data_p) { |
| 245 | return context.db->NumberOfThreads(); |
| 246 | } |
| 247 | |
| 248 | bool ArrowTableFunction::ArrowScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p, |
| 249 | ArrowScanLocalState &state, ArrowScanGlobalState ¶llel_state) { |
| 250 | lock_guard<mutex> parallel_lock(parallel_state.main_mutex); |
| 251 | if (parallel_state.done) { |
| 252 | return false; |
| 253 | } |
| 254 | state.chunk_offset = 0; |
| 255 | state.batch_index = ++parallel_state.batch_index; |
| 256 | |
| 257 | auto current_chunk = parallel_state.stream->GetNextChunk(); |
| 258 | while (current_chunk->arrow_array.length == 0 && current_chunk->arrow_array.release) { |
| 259 | current_chunk = parallel_state.stream->GetNextChunk(); |
| 260 | } |
| 261 | state.chunk = std::move(current_chunk); |
| 262 | //! have we run out of chunks? we are done |
| 263 | if (!state.chunk->arrow_array.release) { |
| 264 | parallel_state.done = true; |
| 265 | return false; |
| 266 | } |
| 267 | return true; |
| 268 | } |
| 269 | |
| 270 | unique_ptr<GlobalTableFunctionState> ArrowTableFunction::ArrowScanInitGlobal(ClientContext &context, |
| 271 | TableFunctionInitInput &input) { |
| 272 | auto &bind_data = input.bind_data->Cast<ArrowScanFunctionData>(); |
| 273 | auto result = make_uniq<ArrowScanGlobalState>(); |
| 274 | result->stream = ProduceArrowScan(function: bind_data, column_ids: input.column_ids, filters: input.filters.get()); |
| 275 | result->max_threads = ArrowScanMaxThreads(context, bind_data_p: input.bind_data.get()); |
| 276 | if (input.CanRemoveFilterColumns()) { |
| 277 | result->projection_ids = input.projection_ids; |
| 278 | for (const auto &col_idx : input.column_ids) { |
| 279 | if (col_idx == COLUMN_IDENTIFIER_ROW_ID) { |
| 280 | result->scanned_types.emplace_back(args: LogicalType::ROW_TYPE); |
| 281 | } else { |
| 282 | result->scanned_types.push_back(x: bind_data.all_types[col_idx]); |
| 283 | } |
| 284 | } |
| 285 | } |
| 286 | return std::move(result); |
| 287 | } |
| 288 | |
| 289 | unique_ptr<LocalTableFunctionState> |
| 290 | ArrowTableFunction::ArrowScanInitLocalInternal(ClientContext &context, TableFunctionInitInput &input, |
| 291 | GlobalTableFunctionState *global_state_p) { |
| 292 | auto &global_state = global_state_p->Cast<ArrowScanGlobalState>(); |
| 293 | auto current_chunk = make_uniq<ArrowArrayWrapper>(); |
| 294 | auto result = make_uniq<ArrowScanLocalState>(args: std::move(current_chunk)); |
| 295 | result->column_ids = input.column_ids; |
| 296 | result->filters = input.filters.get(); |
| 297 | if (input.CanRemoveFilterColumns()) { |
| 298 | auto &asgs = global_state_p->Cast<ArrowScanGlobalState>(); |
| 299 | result->all_columns.Initialize(context, types: asgs.scanned_types); |
| 300 | } |
| 301 | if (!ArrowScanParallelStateNext(context, bind_data_p: input.bind_data.get(), state&: *result, parallel_state&: global_state)) { |
| 302 | return nullptr; |
| 303 | } |
| 304 | return std::move(result); |
| 305 | } |
| 306 | |
| 307 | unique_ptr<LocalTableFunctionState> ArrowTableFunction::ArrowScanInitLocal(ExecutionContext &context, |
| 308 | TableFunctionInitInput &input, |
| 309 | GlobalTableFunctionState *global_state_p) { |
| 310 | return ArrowScanInitLocalInternal(context&: context.client, input, global_state_p); |
| 311 | } |
| 312 | |
| 313 | void ArrowTableFunction::ArrowScanFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { |
| 314 | if (!data_p.local_state) { |
| 315 | return; |
| 316 | } |
| 317 | auto &data = data_p.bind_data->CastNoConst<ArrowScanFunctionData>(); // FIXME |
| 318 | auto &state = data_p.local_state->Cast<ArrowScanLocalState>(); |
| 319 | auto &global_state = data_p.global_state->Cast<ArrowScanGlobalState>(); |
| 320 | |
| 321 | //! Out of tuples in this chunk |
| 322 | if (state.chunk_offset >= (idx_t)state.chunk->arrow_array.length) { |
| 323 | if (!ArrowScanParallelStateNext(context, bind_data_p: data_p.bind_data.get(), state, parallel_state&: global_state)) { |
| 324 | return; |
| 325 | } |
| 326 | } |
| 327 | int64_t output_size = MinValue<int64_t>(STANDARD_VECTOR_SIZE, b: state.chunk->arrow_array.length - state.chunk_offset); |
| 328 | data.lines_read += output_size; |
| 329 | if (global_state.CanRemoveFilterColumns()) { |
| 330 | state.all_columns.Reset(); |
| 331 | state.all_columns.SetCardinality(output_size); |
| 332 | ArrowToDuckDB(scan_state&: state, arrow_convert_data&: data.arrow_convert_data, output&: state.all_columns, start: data.lines_read - output_size); |
| 333 | output.ReferenceColumns(other&: state.all_columns, column_ids: global_state.projection_ids); |
| 334 | } else { |
| 335 | output.SetCardinality(output_size); |
| 336 | ArrowToDuckDB(scan_state&: state, arrow_convert_data&: data.arrow_convert_data, output, start: data.lines_read - output_size); |
| 337 | } |
| 338 | |
| 339 | output.Verify(); |
| 340 | state.chunk_offset += output.size(); |
| 341 | } |
| 342 | |
| 343 | unique_ptr<NodeStatistics> ArrowTableFunction::ArrowScanCardinality(ClientContext &context, const FunctionData *data) { |
| 344 | return make_uniq<NodeStatistics>(); |
| 345 | } |
| 346 | |
| 347 | idx_t ArrowTableFunction::ArrowGetBatchIndex(ClientContext &context, const FunctionData *bind_data_p, |
| 348 | LocalTableFunctionState *local_state, |
| 349 | GlobalTableFunctionState *global_state) { |
| 350 | auto &state = local_state->Cast<ArrowScanLocalState>(); |
| 351 | return state.batch_index; |
| 352 | } |
| 353 | |
| 354 | void ArrowTableFunction::RegisterFunction(BuiltinFunctions &set) { |
| 355 | TableFunction arrow("arrow_scan" , {LogicalType::POINTER, LogicalType::POINTER, LogicalType::POINTER}, |
| 356 | ArrowScanFunction, ArrowScanBind, ArrowScanInitGlobal, ArrowScanInitLocal); |
| 357 | arrow.cardinality = ArrowScanCardinality; |
| 358 | arrow.get_batch_index = ArrowGetBatchIndex; |
| 359 | arrow.projection_pushdown = true; |
| 360 | arrow.filter_pushdown = true; |
| 361 | arrow.filter_prune = true; |
| 362 | set.AddFunction(function: arrow); |
| 363 | |
| 364 | TableFunction arrow_dumb("arrow_scan_dumb" , {LogicalType::POINTER, LogicalType::POINTER, LogicalType::POINTER}, |
| 365 | ArrowScanFunction, ArrowScanBind, ArrowScanInitGlobal, ArrowScanInitLocal); |
| 366 | arrow_dumb.cardinality = ArrowScanCardinality; |
| 367 | arrow_dumb.get_batch_index = ArrowGetBatchIndex; |
| 368 | arrow_dumb.projection_pushdown = false; |
| 369 | arrow_dumb.filter_pushdown = false; |
| 370 | arrow_dumb.filter_prune = false; |
| 371 | set.AddFunction(function: arrow_dumb); |
| 372 | } |
| 373 | |
| 374 | void BuiltinFunctions::RegisterArrowFunctions() { |
| 375 | ArrowTableFunction::RegisterFunction(set&: *this); |
| 376 | } |
| 377 | } // namespace duckdb |
| 378 | |