1 | //===----------------------------------------------------------------------===// |
2 | // DuckDB |
3 | // |
4 | // duckdb/function/table/arrow.hpp |
5 | // |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #pragma once |
10 | |
11 | #include "duckdb/function/table_function.hpp" |
12 | #include "duckdb/common/arrow/arrow_wrapper.hpp" |
13 | #include "duckdb/common/atomic.hpp" |
14 | #include "duckdb/common/mutex.hpp" |
15 | #include "duckdb/common/pair.hpp" |
16 | #include "duckdb/common/thread.hpp" |
17 | #include "duckdb/common/unordered_map.hpp" |
18 | #include "duckdb/function/built_in_functions.hpp" |
19 | |
20 | namespace duckdb { |
//===--------------------------------------------------------------------===//
// Arrow Variable Size Types
//===--------------------------------------------------------------------===//
//! Layout tag for Arrow variable-size types (strings, blobs, lists).
//! Stored per column in ArrowConvertData::variable_sz_type alongside an idx_t.
//! NOTE(review): presumably NORMAL = 32-bit offsets and SUPER_SIZE = 64-bit
//! ("large") offsets, FIXED_SIZE = fixed-width elements — confirm in arrow.cpp.
enum class ArrowVariableSizeType : uint8_t {
	FIXED_SIZE = 0,
	NORMAL = 1,
	SUPER_SIZE = 2
};
25 | |
//===--------------------------------------------------------------------===//
// Arrow Time/Date Types
//===--------------------------------------------------------------------===//
//! Unit/precision tag for Arrow date, time, timestamp and interval columns.
//! Stored per column in ArrowConvertData::date_time_precision.
//! The numeric values are explicit, so the declaration order below is not significant.
enum class ArrowDateTimeType : uint8_t {
	MILLISECONDS = 0,   // 1e-3 s
	MICROSECONDS = 1,   // 1e-6 s
	NANOSECONDS = 2,    // 1e-9 s
	SECONDS = 3,        // whole seconds
	DAYS = 4,           // whole days
	MONTHS = 5,         // whole months
	MONTH_DAY_NANO = 6  // composite months/days/nanoseconds (see ArrowInterval)
};
38 | |
39 | struct ArrowInterval { |
40 | int32_t months; |
41 | int32_t days; |
42 | int64_t nanoseconds; |
43 | |
44 | inline bool operator==(const ArrowInterval &rhs) const { |
45 | return this->days == rhs.days && this->months == rhs.months && this->nanoseconds == rhs.nanoseconds; |
46 | } |
47 | }; |
48 | |
49 | struct ArrowConvertData { |
50 | ArrowConvertData(LogicalType type) : dictionary_type(type) {}; |
51 | ArrowConvertData() {}; |
52 | |
53 | //! Hold type of dictionary |
54 | LogicalType dictionary_type; |
55 | //! If its a variable size type (e.g., strings, blobs, lists) holds which type it is |
56 | vector<pair<ArrowVariableSizeType, idx_t>> variable_sz_type; |
57 | //! If this is a date/time holds its precision |
58 | vector<ArrowDateTimeType> date_time_precision; |
59 | }; |
60 | |
61 | struct ArrowProjectedColumns { |
62 | unordered_map<idx_t, string> projection_map; |
63 | vector<string> columns; |
64 | }; |
65 | |
66 | struct ArrowStreamParameters { |
67 | ArrowProjectedColumns projected_columns; |
68 | TableFilterSet *filters; |
69 | }; |
70 | |
71 | typedef unique_ptr<ArrowArrayStreamWrapper> (*stream_factory_produce_t)(uintptr_t stream_factory_ptr, |
72 | ArrowStreamParameters ¶meters); |
73 | typedef void (*stream_factory_get_schema_t)(uintptr_t stream_factory_ptr, ArrowSchemaWrapper &schema); |
74 | |
75 | struct ArrowScanFunctionData : public PyTableFunctionData { |
76 | ArrowScanFunctionData(stream_factory_produce_t scanner_producer_p, uintptr_t stream_factory_ptr_p) |
77 | : lines_read(0), stream_factory_ptr(stream_factory_ptr_p), scanner_producer(scanner_producer_p) { |
78 | } |
79 | //! This holds the original list type (col_idx, [ArrowListType,size]) |
80 | unordered_map<idx_t, unique_ptr<ArrowConvertData>> arrow_convert_data; |
81 | vector<LogicalType> all_types; |
82 | atomic<idx_t> lines_read; |
83 | ArrowSchemaWrapper schema_root; |
84 | idx_t rows_per_thread; |
85 | //! Pointer to the scanner factory |
86 | uintptr_t stream_factory_ptr; |
87 | //! Pointer to the scanner factory produce |
88 | stream_factory_produce_t scanner_producer; |
89 | }; |
90 | |
91 | struct ArrowScanLocalState : public LocalTableFunctionState { |
92 | explicit ArrowScanLocalState(unique_ptr<ArrowArrayWrapper> current_chunk) : chunk(current_chunk.release()) { |
93 | } |
94 | |
95 | unique_ptr<ArrowArrayStreamWrapper> stream; |
96 | shared_ptr<ArrowArrayWrapper> chunk; |
97 | idx_t chunk_offset = 0; |
98 | idx_t batch_index = 0; |
99 | vector<column_t> column_ids; |
100 | //! Store child vectors for Arrow Dictionary Vectors (col-idx,vector) |
101 | unordered_map<idx_t, unique_ptr<Vector>> arrow_dictionary_vectors; |
102 | TableFilterSet *filters = nullptr; |
103 | //! The DataChunk containing all read columns (even filter columns that are immediately removed) |
104 | DataChunk all_columns; |
105 | }; |
106 | |
107 | struct ArrowScanGlobalState : public GlobalTableFunctionState { |
108 | unique_ptr<ArrowArrayStreamWrapper> stream; |
109 | mutex main_mutex; |
110 | idx_t max_threads = 1; |
111 | idx_t batch_index = 0; |
112 | bool done = false; |
113 | |
114 | vector<idx_t> projection_ids; |
115 | vector<LogicalType> scanned_types; |
116 | |
117 | idx_t MaxThreads() const override { |
118 | return max_threads; |
119 | } |
120 | |
121 | bool CanRemoveFilterColumns() const { |
122 | return !projection_ids.empty(); |
123 | } |
124 | }; |
125 | |
126 | struct ArrowTableFunction { |
127 | public: |
128 | static void RegisterFunction(BuiltinFunctions &set); |
129 | |
130 | public: |
131 | //! Binds an arrow table |
132 | static unique_ptr<FunctionData> ArrowScanBind(ClientContext &context, TableFunctionBindInput &input, |
133 | vector<LogicalType> &return_types, vector<string> &names); |
134 | //! Actual conversion from Arrow to DuckDB |
135 | static void ArrowToDuckDB(ArrowScanLocalState &scan_state, |
136 | std::unordered_map<idx_t, unique_ptr<ArrowConvertData>> &arrow_convert_data, |
137 | DataChunk &output, idx_t start, bool arrow_scan_is_projected = true); |
138 | |
139 | //! Get next scan state |
140 | static bool ArrowScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p, |
141 | ArrowScanLocalState &state, ArrowScanGlobalState ¶llel_state); |
142 | |
143 | //! Initialize Global State |
144 | static unique_ptr<GlobalTableFunctionState> ArrowScanInitGlobal(ClientContext &context, |
145 | TableFunctionInitInput &input); |
146 | |
147 | //! Initialize Local State |
148 | static unique_ptr<LocalTableFunctionState> ArrowScanInitLocalInternal(ClientContext &context, |
149 | TableFunctionInitInput &input, |
150 | GlobalTableFunctionState *global_state); |
151 | static unique_ptr<LocalTableFunctionState> ArrowScanInitLocal(ExecutionContext &context, |
152 | TableFunctionInitInput &input, |
153 | GlobalTableFunctionState *global_state); |
154 | |
155 | //! Scan Function |
156 | static void ArrowScanFunction(ClientContext &context, TableFunctionInput &data, DataChunk &output); |
157 | |
158 | protected: |
159 | //! Defines Maximum Number of Threads |
160 | static idx_t ArrowScanMaxThreads(ClientContext &context, const FunctionData *bind_data); |
161 | |
162 | //! Allows parallel Create Table / Insertion |
163 | static idx_t ArrowGetBatchIndex(ClientContext &context, const FunctionData *bind_data_p, |
164 | LocalTableFunctionState *local_state, GlobalTableFunctionState *global_state); |
165 | |
166 | //! -----Utility Functions:----- |
167 | //! Gets Arrow Table's Cardinality |
168 | static unique_ptr<NodeStatistics> ArrowScanCardinality(ClientContext &context, const FunctionData *bind_data); |
169 | //! Gets the progress on the table scan, used for Progress Bars |
170 | static double ArrowProgress(ClientContext &context, const FunctionData *bind_data, |
171 | const GlobalTableFunctionState *global_state); |
172 | //! Renames repeated columns and case sensitive columns |
173 | static void RenameArrowColumns(vector<string> &names); |
174 | //! Helper function to get the DuckDB logical type |
175 | static LogicalType GetArrowLogicalType(ArrowSchema &schema, |
176 | std::unordered_map<idx_t, unique_ptr<ArrowConvertData>> &arrow_convert_data, |
177 | idx_t col_idx); |
178 | }; |
179 | |
180 | } // namespace duckdb |
181 | |