| 1 | //===----------------------------------------------------------------------===// |
| 2 | // DuckDB |
| 3 | // |
| 4 | // duckdb/function/table/arrow.hpp |
| 5 | // |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #pragma once |
| 10 | |
#include "duckdb/function/table_function.hpp"
#include "duckdb/common/arrow/arrow_wrapper.hpp"
#include "duckdb/common/atomic.hpp"
#include "duckdb/common/mutex.hpp"
#include "duckdb/common/pair.hpp"
#include "duckdb/common/thread.hpp"
#include "duckdb/common/unordered_map.hpp"
#include "duckdb/function/built_in_functions.hpp"

#include <utility>
| 19 | |
| 20 | namespace duckdb { |
| 21 | //===--------------------------------------------------------------------===// |
| 22 | // Arrow Variable Size Types |
| 23 | //===--------------------------------------------------------------------===// |
//! Size class recorded for Arrow variable size types (e.g., strings, blobs, lists);
//! stored per column in ArrowConvertData::variable_sz_type.
//! NOTE(review): SUPER_SIZE presumably maps to Arrow's 64-bit offset ("large") layouts -- confirm.
enum class ArrowVariableSizeType : uint8_t {
	FIXED_SIZE = 0,
	NORMAL = 1,
	SUPER_SIZE = 2
};
| 25 | |
| 26 | //===--------------------------------------------------------------------===// |
| 27 | // Arrow Time/Date Types |
| 28 | //===--------------------------------------------------------------------===// |
//! Precision/unit tag recorded for Arrow date/time columns;
//! stored per column in ArrowConvertData::date_time_precision.
//! The numeric values are spelled out explicitly and are not ordered by magnitude.
enum class ArrowDateTimeType : uint8_t {
	// Sub-day resolutions
	MILLISECONDS = 0,
	MICROSECONDS = 1,
	NANOSECONDS = 2,
	SECONDS = 3,
	// Calendar resolutions
	DAYS = 4,
	MONTHS = 5,
	// NOTE(review): presumably the (months, days, nanoseconds) layout of ArrowInterval -- confirm
	MONTH_DAY_NANO = 6
};
| 38 | |
//! Interval value made of three independent components: months, days and nanoseconds
struct ArrowInterval {
	int32_t months;
	int32_t days;
	int64_t nanoseconds;

	//! Component-wise equality: all three fields must match exactly.
	//! No normalization between components is performed.
	inline bool operator==(const ArrowInterval &rhs) const {
		if (days != rhs.days) {
			return false;
		}
		if (months != rhs.months) {
			return false;
		}
		return nanoseconds == rhs.nanoseconds;
	}
};
| 48 | |
| 49 | struct ArrowConvertData { |
| 50 | ArrowConvertData(LogicalType type) : dictionary_type(type) {}; |
| 51 | ArrowConvertData() {}; |
| 52 | |
| 53 | //! Hold type of dictionary |
| 54 | LogicalType dictionary_type; |
| 55 | //! If its a variable size type (e.g., strings, blobs, lists) holds which type it is |
| 56 | vector<pair<ArrowVariableSizeType, idx_t>> variable_sz_type; |
| 57 | //! If this is a date/time holds its precision |
| 58 | vector<ArrowDateTimeType> date_time_precision; |
| 59 | }; |
| 60 | |
| 61 | struct ArrowProjectedColumns { |
| 62 | unordered_map<idx_t, string> projection_map; |
| 63 | vector<string> columns; |
| 64 | }; |
| 65 | |
| 66 | struct ArrowStreamParameters { |
| 67 | ArrowProjectedColumns projected_columns; |
| 68 | TableFilterSet *filters; |
| 69 | }; |
| 70 | |
| 71 | typedef unique_ptr<ArrowArrayStreamWrapper> (*stream_factory_produce_t)(uintptr_t stream_factory_ptr, |
| 72 | ArrowStreamParameters ¶meters); |
| 73 | typedef void (*stream_factory_get_schema_t)(uintptr_t stream_factory_ptr, ArrowSchemaWrapper &schema); |
| 74 | |
| 75 | struct ArrowScanFunctionData : public PyTableFunctionData { |
| 76 | ArrowScanFunctionData(stream_factory_produce_t scanner_producer_p, uintptr_t stream_factory_ptr_p) |
| 77 | : lines_read(0), stream_factory_ptr(stream_factory_ptr_p), scanner_producer(scanner_producer_p) { |
| 78 | } |
| 79 | //! This holds the original list type (col_idx, [ArrowListType,size]) |
| 80 | unordered_map<idx_t, unique_ptr<ArrowConvertData>> arrow_convert_data; |
| 81 | vector<LogicalType> all_types; |
| 82 | atomic<idx_t> lines_read; |
| 83 | ArrowSchemaWrapper schema_root; |
| 84 | idx_t rows_per_thread; |
| 85 | //! Pointer to the scanner factory |
| 86 | uintptr_t stream_factory_ptr; |
| 87 | //! Pointer to the scanner factory produce |
| 88 | stream_factory_produce_t scanner_producer; |
| 89 | }; |
| 90 | |
| 91 | struct ArrowScanLocalState : public LocalTableFunctionState { |
| 92 | explicit ArrowScanLocalState(unique_ptr<ArrowArrayWrapper> current_chunk) : chunk(current_chunk.release()) { |
| 93 | } |
| 94 | |
| 95 | unique_ptr<ArrowArrayStreamWrapper> stream; |
| 96 | shared_ptr<ArrowArrayWrapper> chunk; |
| 97 | idx_t chunk_offset = 0; |
| 98 | idx_t batch_index = 0; |
| 99 | vector<column_t> column_ids; |
| 100 | //! Store child vectors for Arrow Dictionary Vectors (col-idx,vector) |
| 101 | unordered_map<idx_t, unique_ptr<Vector>> arrow_dictionary_vectors; |
| 102 | TableFilterSet *filters = nullptr; |
| 103 | //! The DataChunk containing all read columns (even filter columns that are immediately removed) |
| 104 | DataChunk all_columns; |
| 105 | }; |
| 106 | |
| 107 | struct ArrowScanGlobalState : public GlobalTableFunctionState { |
| 108 | unique_ptr<ArrowArrayStreamWrapper> stream; |
| 109 | mutex main_mutex; |
| 110 | idx_t max_threads = 1; |
| 111 | idx_t batch_index = 0; |
| 112 | bool done = false; |
| 113 | |
| 114 | vector<idx_t> projection_ids; |
| 115 | vector<LogicalType> scanned_types; |
| 116 | |
| 117 | idx_t MaxThreads() const override { |
| 118 | return max_threads; |
| 119 | } |
| 120 | |
| 121 | bool CanRemoveFilterColumns() const { |
| 122 | return !projection_ids.empty(); |
| 123 | } |
| 124 | }; |
| 125 | |
| 126 | struct ArrowTableFunction { |
| 127 | public: |
| 128 | static void RegisterFunction(BuiltinFunctions &set); |
| 129 | |
| 130 | public: |
| 131 | //! Binds an arrow table |
| 132 | static unique_ptr<FunctionData> ArrowScanBind(ClientContext &context, TableFunctionBindInput &input, |
| 133 | vector<LogicalType> &return_types, vector<string> &names); |
| 134 | //! Actual conversion from Arrow to DuckDB |
| 135 | static void ArrowToDuckDB(ArrowScanLocalState &scan_state, |
| 136 | std::unordered_map<idx_t, unique_ptr<ArrowConvertData>> &arrow_convert_data, |
| 137 | DataChunk &output, idx_t start, bool arrow_scan_is_projected = true); |
| 138 | |
| 139 | //! Get next scan state |
| 140 | static bool ArrowScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p, |
| 141 | ArrowScanLocalState &state, ArrowScanGlobalState ¶llel_state); |
| 142 | |
| 143 | //! Initialize Global State |
| 144 | static unique_ptr<GlobalTableFunctionState> ArrowScanInitGlobal(ClientContext &context, |
| 145 | TableFunctionInitInput &input); |
| 146 | |
| 147 | //! Initialize Local State |
| 148 | static unique_ptr<LocalTableFunctionState> ArrowScanInitLocalInternal(ClientContext &context, |
| 149 | TableFunctionInitInput &input, |
| 150 | GlobalTableFunctionState *global_state); |
| 151 | static unique_ptr<LocalTableFunctionState> ArrowScanInitLocal(ExecutionContext &context, |
| 152 | TableFunctionInitInput &input, |
| 153 | GlobalTableFunctionState *global_state); |
| 154 | |
| 155 | //! Scan Function |
| 156 | static void ArrowScanFunction(ClientContext &context, TableFunctionInput &data, DataChunk &output); |
| 157 | |
| 158 | protected: |
| 159 | //! Defines Maximum Number of Threads |
| 160 | static idx_t ArrowScanMaxThreads(ClientContext &context, const FunctionData *bind_data); |
| 161 | |
| 162 | //! Allows parallel Create Table / Insertion |
| 163 | static idx_t ArrowGetBatchIndex(ClientContext &context, const FunctionData *bind_data_p, |
| 164 | LocalTableFunctionState *local_state, GlobalTableFunctionState *global_state); |
| 165 | |
| 166 | //! -----Utility Functions:----- |
| 167 | //! Gets Arrow Table's Cardinality |
| 168 | static unique_ptr<NodeStatistics> ArrowScanCardinality(ClientContext &context, const FunctionData *bind_data); |
| 169 | //! Gets the progress on the table scan, used for Progress Bars |
| 170 | static double ArrowProgress(ClientContext &context, const FunctionData *bind_data, |
| 171 | const GlobalTableFunctionState *global_state); |
| 172 | //! Renames repeated columns and case sensitive columns |
| 173 | static void RenameArrowColumns(vector<string> &names); |
| 174 | //! Helper function to get the DuckDB logical type |
| 175 | static LogicalType GetArrowLogicalType(ArrowSchema &schema, |
| 176 | std::unordered_map<idx_t, unique_ptr<ArrowConvertData>> &arrow_convert_data, |
| 177 | idx_t col_idx); |
| 178 | }; |
| 179 | |
| 180 | } // namespace duckdb |
| 181 | |