1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #include "arrow/array/builder_dict.h" |
19 | |
20 | #include <algorithm> |
21 | #include <cstdint> |
22 | #include <limits> |
23 | #include <sstream> |
24 | #include <type_traits> |
25 | #include <utility> |
26 | #include <vector> |
27 | |
28 | #include "arrow/array.h" |
29 | #include "arrow/buffer.h" |
30 | #include "arrow/status.h" |
31 | #include "arrow/type.h" |
32 | #include "arrow/type_traits.h" |
33 | #include "arrow/util/checked_cast.h" |
34 | #include "arrow/util/hashing.h" |
35 | #include "arrow/util/logging.h" |
36 | #include "arrow/visitor_inline.h" |
37 | |
38 | namespace arrow { |
39 | |
40 | using internal::checked_cast; |
41 | |
42 | // ---------------------------------------------------------------------- |
43 | // DictionaryType unification |
44 | |
45 | struct UnifyDictionaryValues { |
46 | MemoryPool* pool_; |
47 | std::shared_ptr<DataType> value_type_; |
48 | const std::vector<const DictionaryType*>& types_; |
49 | std::shared_ptr<Array>* out_values_; |
50 | std::vector<std::vector<int32_t>>* out_transpose_maps_; |
51 | |
52 | Status Visit(const DataType&, void* = nullptr) { |
53 | // Default implementation for non-dictionary-supported datatypes |
54 | std::stringstream ss; |
55 | ss << "Unification of " << value_type_->ToString() |
56 | << " dictionaries is not implemented" ; |
57 | return Status::NotImplemented(ss.str()); |
58 | } |
59 | |
60 | template <typename T> |
61 | Status Visit(const T&, |
62 | typename internal::DictionaryTraits<T>::MemoTableType* = nullptr) { |
63 | using ArrayType = typename TypeTraits<T>::ArrayType; |
64 | using DictTraits = typename internal::DictionaryTraits<T>; |
65 | using MemoTableType = typename DictTraits::MemoTableType; |
66 | |
67 | MemoTableType memo_table; |
68 | if (out_transpose_maps_ != nullptr) { |
69 | out_transpose_maps_->clear(); |
70 | out_transpose_maps_->reserve(types_.size()); |
71 | } |
72 | // Build up the unified dictionary values and the transpose maps |
73 | for (const auto& type : types_) { |
74 | const ArrayType& values = checked_cast<const ArrayType&>(*type->dictionary()); |
75 | if (out_transpose_maps_ != nullptr) { |
76 | std::vector<int32_t> transpose_map; |
77 | transpose_map.reserve(values.length()); |
78 | for (int64_t i = 0; i < values.length(); ++i) { |
79 | int32_t dict_index = memo_table.GetOrInsert(values.GetView(i)); |
80 | transpose_map.push_back(dict_index); |
81 | } |
82 | out_transpose_maps_->push_back(std::move(transpose_map)); |
83 | } else { |
84 | for (int64_t i = 0; i < values.length(); ++i) { |
85 | memo_table.GetOrInsert(values.GetView(i)); |
86 | } |
87 | } |
88 | } |
89 | // Build unified dictionary array |
90 | std::shared_ptr<ArrayData> data; |
91 | RETURN_NOT_OK(DictTraits::GetDictionaryArrayData(pool_, value_type_, memo_table, |
92 | 0 /* start_offset */, &data)); |
93 | *out_values_ = MakeArray(data); |
94 | return Status::OK(); |
95 | } |
96 | }; |
97 | |
98 | Status DictionaryType::Unify(MemoryPool* pool, const std::vector<const DataType*>& types, |
99 | std::shared_ptr<DataType>* out_type, |
100 | std::vector<std::vector<int32_t>>* out_transpose_maps) { |
101 | if (types.size() == 0) { |
102 | return Status::Invalid("need at least one input type" ); |
103 | } |
104 | std::vector<const DictionaryType*> dict_types; |
105 | dict_types.reserve(types.size()); |
106 | for (const auto& type : types) { |
107 | if (type->id() != Type::DICTIONARY) { |
108 | return Status::TypeError("input types must be dictionary types" ); |
109 | } |
110 | dict_types.push_back(checked_cast<const DictionaryType*>(type)); |
111 | } |
112 | |
113 | // XXX Should we check the ordered flag? |
114 | auto value_type = dict_types[0]->dictionary()->type(); |
115 | for (const auto& type : dict_types) { |
116 | auto values = type->dictionary(); |
117 | if (!values->type()->Equals(value_type)) { |
118 | return Status::TypeError("input types have different value types" ); |
119 | } |
120 | if (values->null_count() != 0) { |
121 | return Status::TypeError("input types have null values" ); |
122 | } |
123 | } |
124 | |
125 | std::shared_ptr<Array> values; |
126 | { |
127 | UnifyDictionaryValues visitor{pool, value_type, dict_types, &values, |
128 | out_transpose_maps}; |
129 | RETURN_NOT_OK(VisitTypeInline(*value_type, &visitor)); |
130 | } |
131 | |
132 | // Build unified dictionary type with the right index type |
133 | std::shared_ptr<DataType> index_type; |
134 | if (values->length() <= std::numeric_limits<int8_t>::max()) { |
135 | index_type = int8(); |
136 | } else if (values->length() <= std::numeric_limits<int16_t>::max()) { |
137 | index_type = int16(); |
138 | } else if (values->length() <= std::numeric_limits<int32_t>::max()) { |
139 | index_type = int32(); |
140 | } else { |
141 | index_type = int64(); |
142 | } |
143 | *out_type = arrow::dictionary(index_type, values); |
144 | return Status::OK(); |
145 | } |
146 | |
147 | // ---------------------------------------------------------------------- |
148 | // DictionaryBuilder |
149 | |
// Definition of the memo table type that DictionaryBuilder<T> owns through
// memo_table_.  It adds nothing to the memo table selected by
// internal::HashTraits<T>; it only gives that type a name nested in the
// builder (presumably so the header can refer to it without knowing the
// concrete memo table type -- confirm against the class declaration).
template <typename T>
class DictionaryBuilder<T>::MemoTableImpl
    : public internal::HashTraits<T>::MemoTableType {
 public:
  using MemoTableType = typename internal::HashTraits<T>::MemoTableType;
  // Inherit the constructors of the underlying memo table as-is
  using MemoTableType::MemoTableType;
};
157 | |
158 | template <typename T> |
159 | DictionaryBuilder<T>::~DictionaryBuilder() {} |
160 | |
161 | template <typename T> |
162 | DictionaryBuilder<T>::DictionaryBuilder(const std::shared_ptr<DataType>& type, |
163 | MemoryPool* pool) |
164 | : ArrayBuilder(type, pool), |
165 | memo_table_(new MemoTableImpl(0)), |
166 | delta_offset_(0), |
167 | byte_width_(-1), |
168 | values_builder_(pool) { |
169 | DCHECK_EQ(T::type_id, type->id()) << "inconsistent type passed to DictionaryBuilder" ; |
170 | } |
171 | |
172 | DictionaryBuilder<NullType>::DictionaryBuilder(const std::shared_ptr<DataType>& type, |
173 | MemoryPool* pool) |
174 | : ArrayBuilder(type, pool), values_builder_(pool) { |
175 | DCHECK_EQ(Type::NA, type->id()) << "inconsistent type passed to DictionaryBuilder" ; |
176 | } |
177 | |
178 | template <> |
179 | DictionaryBuilder<FixedSizeBinaryType>::DictionaryBuilder( |
180 | const std::shared_ptr<DataType>& type, MemoryPool* pool) |
181 | : ArrayBuilder(type, pool), |
182 | memo_table_(new MemoTableImpl(0)), |
183 | delta_offset_(0), |
184 | byte_width_(checked_cast<const FixedSizeBinaryType&>(*type).byte_width()) {} |
185 | |
186 | template <typename T> |
187 | void DictionaryBuilder<T>::Reset() { |
188 | ArrayBuilder::Reset(); |
189 | values_builder_.Reset(); |
190 | memo_table_.reset(new MemoTableImpl(0)); |
191 | delta_offset_ = 0; |
192 | } |
193 | |
194 | template <typename T> |
195 | Status DictionaryBuilder<T>::Resize(int64_t capacity) { |
196 | RETURN_NOT_OK(CheckCapacity(capacity, capacity_)); |
197 | capacity = std::max(capacity, kMinBuilderCapacity); |
198 | |
199 | if (capacity_ == 0) { |
200 | // Initialize hash table |
201 | // XXX should we let the user pass additional size heuristics? |
202 | delta_offset_ = 0; |
203 | } |
204 | RETURN_NOT_OK(values_builder_.Resize(capacity)); |
205 | return ArrayBuilder::Resize(capacity); |
206 | } |
207 | |
208 | Status DictionaryBuilder<NullType>::Resize(int64_t capacity) { |
209 | RETURN_NOT_OK(CheckCapacity(capacity, capacity_)); |
210 | capacity = std::max(capacity, kMinBuilderCapacity); |
211 | |
212 | RETURN_NOT_OK(values_builder_.Resize(capacity)); |
213 | return ArrayBuilder::Resize(capacity); |
214 | } |
215 | |
216 | template <typename T> |
217 | Status DictionaryBuilder<T>::Append(const Scalar& value) { |
218 | RETURN_NOT_OK(Reserve(1)); |
219 | |
220 | auto memo_index = memo_table_->GetOrInsert(value); |
221 | RETURN_NOT_OK(values_builder_.Append(memo_index)); |
222 | length_ += 1; |
223 | |
224 | return Status::OK(); |
225 | } |
226 | |
227 | template <typename T> |
228 | Status DictionaryBuilder<T>::AppendNull() { |
229 | length_ += 1; |
230 | null_count_ += 1; |
231 | |
232 | return values_builder_.AppendNull(); |
233 | } |
234 | |
235 | Status DictionaryBuilder<NullType>::AppendNull() { |
236 | length_ += 1; |
237 | null_count_ += 1; |
238 | |
239 | return values_builder_.AppendNull(); |
240 | } |
241 | |
242 | template <typename T> |
243 | Status DictionaryBuilder<T>::AppendArray(const Array& array) { |
244 | using ArrayType = typename TypeTraits<T>::ArrayType; |
245 | |
246 | const auto& concrete_array = checked_cast<const ArrayType&>(array); |
247 | for (int64_t i = 0; i < array.length(); i++) { |
248 | if (array.IsNull(i)) { |
249 | RETURN_NOT_OK(AppendNull()); |
250 | } else { |
251 | RETURN_NOT_OK(Append(concrete_array.GetView(i))); |
252 | } |
253 | } |
254 | return Status::OK(); |
255 | } |
256 | |
257 | template <> |
258 | Status DictionaryBuilder<FixedSizeBinaryType>::AppendArray(const Array& array) { |
259 | if (!type_->Equals(*array.type())) { |
260 | return Status::Invalid("Cannot append FixedSizeBinary array with non-matching type" ); |
261 | } |
262 | |
263 | const auto& typed_array = checked_cast<const FixedSizeBinaryArray&>(array); |
264 | for (int64_t i = 0; i < array.length(); i++) { |
265 | if (array.IsNull(i)) { |
266 | RETURN_NOT_OK(AppendNull()); |
267 | } else { |
268 | RETURN_NOT_OK(Append(typed_array.GetValue(i))); |
269 | } |
270 | } |
271 | return Status::OK(); |
272 | } |
273 | |
274 | Status DictionaryBuilder<NullType>::AppendArray(const Array& array) { |
275 | for (int64_t i = 0; i < array.length(); i++) { |
276 | RETURN_NOT_OK(AppendNull()); |
277 | } |
278 | return Status::OK(); |
279 | } |
280 | |
281 | template <typename T> |
282 | Status DictionaryBuilder<T>::FinishInternal(std::shared_ptr<ArrayData>* out) { |
283 | // Finalize indices array |
284 | RETURN_NOT_OK(values_builder_.FinishInternal(out)); |
285 | |
286 | // Generate dictionary array from hash table contents |
287 | std::shared_ptr<Array> dictionary; |
288 | std::shared_ptr<ArrayData> dictionary_data; |
289 | |
290 | RETURN_NOT_OK(internal::DictionaryTraits<T>::GetDictionaryArrayData( |
291 | pool_, type_, *memo_table_, delta_offset_, &dictionary_data)); |
292 | dictionary = MakeArray(dictionary_data); |
293 | |
294 | // Set type of array data to the right dictionary type |
295 | (*out)->type = std::make_shared<DictionaryType>((*out)->type, dictionary); |
296 | |
297 | // Update internals for further uses of this DictionaryBuilder |
298 | delta_offset_ = memo_table_->size(); |
299 | values_builder_.Reset(); |
300 | |
301 | return Status::OK(); |
302 | } |
303 | |
304 | Status DictionaryBuilder<NullType>::FinishInternal(std::shared_ptr<ArrayData>* out) { |
305 | std::shared_ptr<Array> dictionary = std::make_shared<NullArray>(0); |
306 | |
307 | RETURN_NOT_OK(values_builder_.FinishInternal(out)); |
308 | (*out)->type = std::make_shared<DictionaryType>((*out)->type, dictionary); |
309 | |
310 | return Status::OK(); |
311 | } |
312 | |
// Explicit instantiations of DictionaryBuilder for each value type listed
// below.  The member templates above are defined in this file, so these
// instantiations are what emits their code for the listed types.
// DictionaryBuilder<NullType> is not listed: its members are defined above as
// ordinary (non-template) functions.
template class DictionaryBuilder<UInt8Type>;
template class DictionaryBuilder<UInt16Type>;
template class DictionaryBuilder<UInt32Type>;
template class DictionaryBuilder<UInt64Type>;
template class DictionaryBuilder<Int8Type>;
template class DictionaryBuilder<Int16Type>;
template class DictionaryBuilder<Int32Type>;
template class DictionaryBuilder<Int64Type>;
template class DictionaryBuilder<Date32Type>;
template class DictionaryBuilder<Date64Type>;
template class DictionaryBuilder<Time32Type>;
template class DictionaryBuilder<Time64Type>;
template class DictionaryBuilder<TimestampType>;
template class DictionaryBuilder<FloatType>;
template class DictionaryBuilder<DoubleType>;
template class DictionaryBuilder<FixedSizeBinaryType>;
template class DictionaryBuilder<BinaryType>;
template class DictionaryBuilder<StringType>;
331 | |
332 | } // namespace arrow |
333 | |