1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | // Functions for comparing Arrow data structures |
19 | |
20 | #include "arrow/compare.h" |
21 | |
22 | #include <climits> |
23 | #include <cmath> |
24 | #include <cstdint> |
25 | #include <cstring> |
26 | #include <memory> |
27 | #include <string> |
28 | #include <type_traits> |
29 | #include <vector> |
30 | |
31 | #include "arrow/array.h" |
32 | #include "arrow/buffer.h" |
33 | #include "arrow/sparse_tensor.h" |
34 | #include "arrow/status.h" |
35 | #include "arrow/tensor.h" |
36 | #include "arrow/type.h" |
37 | #include "arrow/util/bit-util.h" |
38 | #include "arrow/util/checked_cast.h" |
39 | #include "arrow/util/logging.h" |
40 | #include "arrow/util/macros.h" |
41 | #include "arrow/visitor_inline.h" |
42 | |
43 | namespace arrow { |
44 | |
45 | using internal::BitmapEquals; |
46 | using internal::checked_cast; |
47 | |
48 | // ---------------------------------------------------------------------- |
49 | // Public method implementations |
50 | |
51 | namespace internal { |
52 | |
// Visitor comparing slots [left_start_idx, left_end_idx) of the visited ("left")
// array against the slots of `right` starting at right_start_idx. The outcome is
// stored in result_ and read back through result(); every Visit() overload below
// returns Status::OK().
//
// NOTE(review): no bounds checking is done on the requested ranges, and `right` is
// downcast with checked_cast to the visited array's class, so callers must ensure both
// arrays have the same concrete type and that the ranges are valid — confirm callers.
class RangeEqualsVisitor {
 public:
  // `right` is stored by reference and must outlive the visitor.
  RangeEqualsVisitor(const Array& right, int64_t left_start_idx, int64_t left_end_idx,
                     int64_t right_start_idx)
      : right_(right),
        left_start_idx_(left_start_idx),
        left_end_idx_(left_end_idx),
        right_start_idx_(right_start_idx),
        result_(false) {}

  // Slot-by-slot comparison for arrays exposing IsNull(i) and Value(i): two slots
  // match when both are null, or both are non-null with equal Value().
  template <typename ArrayType>
  inline Status CompareValues(const ArrayType& left) {
    const auto& right = checked_cast<const ArrayType&>(right_);

    for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
         ++i, ++o_i) {
      const bool is_null = left.IsNull(i);
      if (is_null != right.IsNull(o_i) ||
          (!is_null && left.Value(i) != right.Value(o_i))) {
        result_ = false;
        return Status::OK();
      }
    }
    result_ = true;
    return Status::OK();
  }

  // Variable-length binary comparison: validity must match, and each non-null slot
  // must have the same byte length and identical bytes (checked with memcmp).
  bool CompareBinaryRange(const BinaryArray& left) const {
    const auto& right = checked_cast<const BinaryArray&>(right_);

    for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
         ++i, ++o_i) {
      const bool is_null = left.IsNull(i);
      if (is_null != right.IsNull(o_i)) {
        return false;
      }
      if (is_null) continue;
      const int32_t begin_offset = left.value_offset(i);
      const int32_t end_offset = left.value_offset(i + 1);
      const int32_t right_begin_offset = right.value_offset(o_i);
      const int32_t right_end_offset = right.value_offset(o_i + 1);
      // Underlying can't be equal if the size isn't equal
      if (end_offset - begin_offset != right_end_offset - right_begin_offset) {
        return false;
      }

      // Zero-length values are skipped so memcmp is only called with data to compare.
      if (end_offset - begin_offset > 0 &&
          std::memcmp(left.value_data()->data() + begin_offset,
                      right.value_data()->data() + right_begin_offset,
                      static_cast<size_t>(end_offset - begin_offset))) {
        return false;
      }
    }
    return true;
  }

  // List comparison: validity and list length must match per slot, and the child
  // value ranges referenced by each non-null slot must be RangeEquals-equal.
  bool CompareLists(const ListArray& left) {
    const auto& right = checked_cast<const ListArray&>(right_);

    const std::shared_ptr<Array>& left_values = left.values();
    const std::shared_ptr<Array>& right_values = right.values();

    for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
         ++i, ++o_i) {
      const bool is_null = left.IsNull(i);
      if (is_null != right.IsNull(o_i)) {
        return false;
      }
      if (is_null) continue;
      const int32_t begin_offset = left.value_offset(i);
      const int32_t end_offset = left.value_offset(i + 1);
      const int32_t right_begin_offset = right.value_offset(o_i);
      const int32_t right_end_offset = right.value_offset(o_i + 1);
      // Underlying can't be equal if the size isn't equal
      if (end_offset - begin_offset != right_end_offset - right_begin_offset) {
        return false;
      }
      // RangeEquals takes (start, end, other_start, other): end is an index, not a length.
      if (!left_values->RangeEquals(begin_offset, end_offset, right_begin_offset,
                                    right_values)) {
        return false;
      }
    }
    return true;
  }

  // Struct comparison: validity must match per slot, and for each non-null slot every
  // child field must be equal at that single position.
  // NOTE(review): iterates left.num_fields() and indexes right.field(j) without
  // checking the right array has as many fields; assumes the types were already
  // verified equal — confirm for ArrayRangeEquals, which only compares type ids.
  bool CompareStructs(const StructArray& left) {
    const auto& right = checked_cast<const StructArray&>(right_);
    bool equal_fields = true;
    for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
         ++i, ++o_i) {
      if (left.IsNull(i) != right.IsNull(o_i)) {
        return false;
      }
      if (left.IsNull(i)) continue;
      for (int j = 0; j < left.num_fields(); ++j) {
        // TODO: really we should be comparing stretches of non-null data rather
        // than looking at one value at a time.
        equal_fields = left.field(j)->RangeEquals(i, i + 1, o_i, right.field(j));
        if (!equal_fields) {
          return false;
        }
      }
    }
    return true;
  }

  // Union comparison: modes must match; for every non-null slot the type ids must be
  // equal and the selected child must be equal at that slot. Sparse unions index the
  // child with the union slot index, dense unions with raw_value_offsets().
  bool CompareUnions(const UnionArray& left) const {
    const auto& right = checked_cast<const UnionArray&>(right_);

    const UnionMode::type union_mode = left.mode();
    if (union_mode != right.mode()) {
      return false;
    }

    const auto& left_type = checked_cast<const UnionType&>(*left.type());

    // Define a mapping from the type id to child number
    uint8_t max_code = 0;

    const std::vector<uint8_t>& type_codes = left_type.type_codes();
    for (size_t i = 0; i < type_codes.size(); ++i) {
      const uint8_t code = type_codes[i];
      if (code > max_code) {
        max_code = code;
      }
    }

    // Store mapping in a vector for constant time lookups
    // (built from the left type only; the right type's codes are not consulted).
    std::vector<uint8_t> type_id_to_child_num(max_code + 1);
    for (uint8_t i = 0; i < static_cast<uint8_t>(type_codes.size()); ++i) {
      type_id_to_child_num[type_codes[i]] = i;
    }

    const uint8_t* left_ids = left.raw_type_ids();
    const uint8_t* right_ids = right.raw_type_ids();

    uint8_t id, child_num;
    for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
         ++i, ++o_i) {
      if (left.IsNull(i) != right.IsNull(o_i)) {
        return false;
      }
      if (left.IsNull(i)) continue;
      if (left_ids[i] != right_ids[o_i]) {
        return false;
      }

      id = left_ids[i];
      child_num = type_id_to_child_num[id];

      // TODO(wesm): really we should be comparing stretches of non-null data
      // rather than looking at one value at a time.
      if (union_mode == UnionMode::SPARSE) {
        if (!left.child(child_num)->RangeEquals(i, i + 1, o_i, right.child(child_num))) {
          return false;
        }
      } else {
        const int32_t offset = left.raw_value_offsets()[i];
        const int32_t o_offset = right.raw_value_offsets()[o_i];
        if (!left.child(child_num)->RangeEquals(offset, offset + 1, o_offset,
                                                right.child(child_num))) {
          return false;
        }
      }
    }
    return true;
  }

  Status Visit(const BinaryArray& left) {
    result_ = CompareBinaryRange(left);
    return Status::OK();
  }

  // Fixed-width binary comparison: validity must match, and non-null slots are
  // compared with memcmp over byte_width() bytes. The data pointers are only read
  // when the corresponding values() buffer is present.
  Status Visit(const FixedSizeBinaryArray& left) {
    const auto& right = checked_cast<const FixedSizeBinaryArray&>(right_);

    int32_t width = left.byte_width();

    const uint8_t* left_data = nullptr;
    const uint8_t* right_data = nullptr;

    if (left.values()) {
      left_data = left.raw_values();
    }

    if (right.values()) {
      right_data = right.raw_values();
    }

    for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
         ++i, ++o_i) {
      const bool is_null = left.IsNull(i);
      if (is_null != right.IsNull(o_i)) {
        result_ = false;
        return Status::OK();
      }
      if (is_null) continue;

      if (std::memcmp(left_data + width * i, right_data + width * o_i, width)) {
        result_ = false;
        return Status::OK();
      }
    }
    result_ = true;
    return Status::OK();
  }

  // Decimal128 values are compared through their FixedSizeBinaryArray representation.
  Status Visit(const Decimal128Array& left) {
    return Visit(checked_cast<const FixedSizeBinaryArray&>(left));
  }

  // Ranges of a NullArray are always reported equal.
  Status Visit(const NullArray& left) {
    ARROW_UNUSED(left);
    result_ = true;
    return Status::OK();
  }

  // Generic path for every PrimitiveArray subclass without a more specific overload.
  template <typename T>
  typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value, Status>::type Visit(
      const T& left) {
    return CompareValues<T>(left);
  }

  Status Visit(const ListArray& left) {
    result_ = CompareLists(left);
    return Status::OK();
  }

  Status Visit(const StructArray& left) {
    result_ = CompareStructs(left);
    return Status::OK();
  }

  Status Visit(const UnionArray& left) {
    result_ = CompareUnions(left);
    return Status::OK();
  }

  // Dictionaries must be fully equal; the index arrays are then range-compared.
  Status Visit(const DictionaryArray& left) {
    const auto& right = checked_cast<const DictionaryArray&>(right_);
    if (!left.dictionary()->Equals(right.dictionary())) {
      result_ = false;
      return Status::OK();
    }
    result_ = left.indices()->RangeEquals(left_start_idx_, left_end_idx_,
                                          right_start_idx_, right.indices());
    return Status::OK();
  }

  // Result of the last Visit() call (false until a Visit() has run).
  bool result() const { return result_; }

 protected:
  // Array compared against; held by reference (not owned).
  const Array& right_;
  // Half-open range [left_start_idx_, left_end_idx_) in the visited array.
  int64_t left_start_idx_;
  int64_t left_end_idx_;
  // First slot of right_ matched against left_start_idx_.
  int64_t right_start_idx_;

  bool result_;
};
312 | |
313 | static bool IsEqualPrimitive(const PrimitiveArray& left, const PrimitiveArray& right) { |
314 | const auto& size_meta = dynamic_cast<const FixedWidthType&>(*left.type()); |
315 | const int byte_width = size_meta.bit_width() / CHAR_BIT; |
316 | |
317 | const uint8_t* left_data = nullptr; |
318 | const uint8_t* right_data = nullptr; |
319 | |
320 | if (left.values()) { |
321 | left_data = left.values()->data() + left.offset() * byte_width; |
322 | } |
323 | |
324 | if (right.values()) { |
325 | right_data = right.values()->data() + right.offset() * byte_width; |
326 | } |
327 | |
328 | if (byte_width == 0) { |
329 | // Special case 0-width data, as the data pointers may be null |
330 | for (int64_t i = 0; i < left.length(); ++i) { |
331 | if (left.IsNull(i) != right.IsNull(i)) { |
332 | return false; |
333 | } |
334 | } |
335 | return true; |
336 | } else if (left.null_count() > 0) { |
337 | for (int64_t i = 0; i < left.length(); ++i) { |
338 | const bool left_null = left.IsNull(i); |
339 | const bool right_null = right.IsNull(i); |
340 | if (left_null != right_null) { |
341 | return false; |
342 | } |
343 | if (!left_null && memcmp(left_data, right_data, byte_width) != 0) { |
344 | return false; |
345 | } |
346 | left_data += byte_width; |
347 | right_data += byte_width; |
348 | } |
349 | return true; |
350 | } else { |
351 | auto number_of_bytes_to_compare = static_cast<size_t>(byte_width * left.length()); |
352 | return memcmp(left_data, right_data, number_of_bytes_to_compare) == 0; |
353 | } |
354 | } |
355 | |
356 | class ArrayEqualsVisitor : public RangeEqualsVisitor { |
357 | public: |
358 | explicit ArrayEqualsVisitor(const Array& right) |
359 | : RangeEqualsVisitor(right, 0, right.length(), 0) {} |
360 | |
361 | Status Visit(const NullArray& left) { |
362 | ARROW_UNUSED(left); |
363 | result_ = true; |
364 | return Status::OK(); |
365 | } |
366 | |
367 | Status Visit(const BooleanArray& left) { |
368 | const auto& right = checked_cast<const BooleanArray&>(right_); |
369 | |
370 | if (left.null_count() > 0) { |
371 | const uint8_t* left_data = left.values()->data(); |
372 | const uint8_t* right_data = right.values()->data(); |
373 | |
374 | for (int64_t i = 0; i < left.length(); ++i) { |
375 | if (left.IsValid(i) && BitUtil::GetBit(left_data, i + left.offset()) != |
376 | BitUtil::GetBit(right_data, i + right.offset())) { |
377 | result_ = false; |
378 | return Status::OK(); |
379 | } |
380 | } |
381 | result_ = true; |
382 | } else { |
383 | result_ = BitmapEquals(left.values()->data(), left.offset(), right.values()->data(), |
384 | right.offset(), left.length()); |
385 | } |
386 | return Status::OK(); |
387 | } |
388 | |
389 | template <typename T> |
390 | typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value && |
391 | !std::is_base_of<BooleanArray, T>::value, |
392 | Status>::type |
393 | Visit(const T& left) { |
394 | result_ = IsEqualPrimitive(left, checked_cast<const PrimitiveArray&>(right_)); |
395 | return Status::OK(); |
396 | } |
397 | |
398 | template <typename ArrayType> |
399 | bool ValueOffsetsEqual(const ArrayType& left) { |
400 | const auto& right = checked_cast<const ArrayType&>(right_); |
401 | |
402 | if (left.offset() == 0 && right.offset() == 0) { |
403 | return left.value_offsets()->Equals(*right.value_offsets(), |
404 | (left.length() + 1) * sizeof(int32_t)); |
405 | } else { |
406 | // One of the arrays is sliced; logic is more complicated because the |
407 | // value offsets are not both 0-based |
408 | auto left_offsets = |
409 | reinterpret_cast<const int32_t*>(left.value_offsets()->data()) + left.offset(); |
410 | auto right_offsets = |
411 | reinterpret_cast<const int32_t*>(right.value_offsets()->data()) + |
412 | right.offset(); |
413 | |
414 | for (int64_t i = 0; i < left.length() + 1; ++i) { |
415 | if (left_offsets[i] - left_offsets[0] != right_offsets[i] - right_offsets[0]) { |
416 | return false; |
417 | } |
418 | } |
419 | return true; |
420 | } |
421 | } |
422 | |
423 | bool CompareBinary(const BinaryArray& left) { |
424 | const auto& right = checked_cast<const BinaryArray&>(right_); |
425 | |
426 | bool equal_offsets = ValueOffsetsEqual<BinaryArray>(left); |
427 | if (!equal_offsets) { |
428 | return false; |
429 | } |
430 | |
431 | if (!left.value_data() && !(right.value_data())) { |
432 | return true; |
433 | } |
434 | if (left.value_offset(left.length()) == 0) { |
435 | return true; |
436 | } |
437 | |
438 | const uint8_t* left_data = left.value_data()->data(); |
439 | const uint8_t* right_data = right.value_data()->data(); |
440 | |
441 | if (left.null_count() == 0) { |
442 | // Fast path for null count 0, single memcmp |
443 | if (left.offset() == 0 && right.offset() == 0) { |
444 | return std::memcmp(left_data, right_data, |
445 | left.raw_value_offsets()[left.length()]) == 0; |
446 | } else { |
447 | const int64_t total_bytes = |
448 | left.value_offset(left.length()) - left.value_offset(0); |
449 | return std::memcmp(left_data + left.value_offset(0), |
450 | right_data + right.value_offset(0), |
451 | static_cast<size_t>(total_bytes)) == 0; |
452 | } |
453 | } else { |
454 | // ARROW-537: Only compare data in non-null slots |
455 | const int32_t* left_offsets = left.raw_value_offsets(); |
456 | const int32_t* right_offsets = right.raw_value_offsets(); |
457 | for (int64_t i = 0; i < left.length(); ++i) { |
458 | if (left.IsNull(i)) { |
459 | continue; |
460 | } |
461 | if (std::memcmp(left_data + left_offsets[i], right_data + right_offsets[i], |
462 | left.value_length(i))) { |
463 | return false; |
464 | } |
465 | } |
466 | return true; |
467 | } |
468 | } |
469 | |
470 | Status Visit(const BinaryArray& left) { |
471 | result_ = CompareBinary(left); |
472 | return Status::OK(); |
473 | } |
474 | |
475 | Status Visit(const ListArray& left) { |
476 | const auto& right = checked_cast<const ListArray&>(right_); |
477 | bool equal_offsets = ValueOffsetsEqual<ListArray>(left); |
478 | if (!equal_offsets) { |
479 | result_ = false; |
480 | return Status::OK(); |
481 | } |
482 | |
483 | result_ = left.values()->RangeEquals( |
484 | left.value_offset(0), left.value_offset(left.length()) - left.value_offset(0), |
485 | right.value_offset(0), right.values()); |
486 | return Status::OK(); |
487 | } |
488 | |
489 | Status Visit(const DictionaryArray& left) { |
490 | const auto& right = checked_cast<const DictionaryArray&>(right_); |
491 | if (!left.dictionary()->Equals(right.dictionary())) { |
492 | result_ = false; |
493 | } else { |
494 | result_ = left.indices()->Equals(right.indices()); |
495 | } |
496 | return Status::OK(); |
497 | } |
498 | |
499 | template <typename T> |
500 | typename std::enable_if<std::is_base_of<NestedType, typename T::TypeClass>::value, |
501 | Status>::type |
502 | Visit(const T& left) { |
503 | return RangeEqualsVisitor::Visit(left); |
504 | } |
505 | }; |
506 | |
507 | template <typename TYPE> |
508 | inline bool FloatingApproxEquals(const NumericArray<TYPE>& left, |
509 | const NumericArray<TYPE>& right) { |
510 | using T = typename TYPE::c_type; |
511 | |
512 | const T* left_data = left.raw_values(); |
513 | const T* right_data = right.raw_values(); |
514 | |
515 | static constexpr T EPSILON = static_cast<T>(1E-5); |
516 | |
517 | if (left.null_count() > 0) { |
518 | for (int64_t i = 0; i < left.length(); ++i) { |
519 | if (left.IsNull(i)) continue; |
520 | if (fabs(left_data[i] - right_data[i]) > EPSILON) { |
521 | return false; |
522 | } |
523 | } |
524 | } else { |
525 | for (int64_t i = 0; i < left.length(); ++i) { |
526 | if (fabs(left_data[i] - right_data[i]) > EPSILON) { |
527 | return false; |
528 | } |
529 | } |
530 | } |
531 | return true; |
532 | } |
533 | |
534 | class ApproxEqualsVisitor : public ArrayEqualsVisitor { |
535 | public: |
536 | using ArrayEqualsVisitor::ArrayEqualsVisitor; |
537 | using ArrayEqualsVisitor::Visit; |
538 | |
539 | Status Visit(const FloatArray& left) { |
540 | result_ = |
541 | FloatingApproxEquals<FloatType>(left, checked_cast<const FloatArray&>(right_)); |
542 | return Status::OK(); |
543 | } |
544 | |
545 | Status Visit(const DoubleArray& left) { |
546 | result_ = |
547 | FloatingApproxEquals<DoubleType>(left, checked_cast<const DoubleArray&>(right_)); |
548 | return Status::OK(); |
549 | } |
550 | }; |
551 | |
552 | static bool BaseDataEquals(const Array& left, const Array& right) { |
553 | if (left.length() != right.length() || left.null_count() != right.null_count() || |
554 | left.type_id() != right.type_id()) { |
555 | return false; |
556 | } |
557 | // ARROW-2567: Ensure that not only the type id but also the type equality |
558 | // itself is checked. |
559 | if (!TypeEquals(*left.type(), *right.type())) { |
560 | return false; |
561 | } |
562 | if (left.null_count() > 0 && left.null_count() < left.length()) { |
563 | return BitmapEquals(left.null_bitmap()->data(), left.offset(), |
564 | right.null_bitmap()->data(), right.offset(), left.length()); |
565 | } |
566 | return true; |
567 | } |
568 | |
569 | template <typename VISITOR> |
570 | inline bool ArrayEqualsImpl(const Array& left, const Array& right) { |
571 | bool are_equal; |
572 | // The arrays are the same object |
573 | if (&left == &right) { |
574 | are_equal = true; |
575 | } else if (!BaseDataEquals(left, right)) { |
576 | are_equal = false; |
577 | } else if (left.length() == 0) { |
578 | are_equal = true; |
579 | } else if (left.null_count() == left.length()) { |
580 | are_equal = true; |
581 | } else { |
582 | VISITOR visitor(right); |
583 | auto error = VisitArrayInline(left, &visitor); |
584 | if (!error.ok()) { |
585 | DCHECK(false) << "Arrays are not comparable: " << error.ToString(); |
586 | } |
587 | are_equal = visitor.result(); |
588 | } |
589 | return are_equal; |
590 | } |
591 | |
592 | class TypeEqualsVisitor { |
593 | public: |
594 | explicit TypeEqualsVisitor(const DataType& right) : right_(right), result_(false) {} |
595 | |
596 | Status VisitChildren(const DataType& left) { |
597 | if (left.num_children() != right_.num_children()) { |
598 | result_ = false; |
599 | return Status::OK(); |
600 | } |
601 | |
602 | for (int i = 0; i < left.num_children(); ++i) { |
603 | if (!left.child(i)->Equals(right_.child(i))) { |
604 | result_ = false; |
605 | return Status::OK(); |
606 | } |
607 | } |
608 | result_ = true; |
609 | return Status::OK(); |
610 | } |
611 | |
612 | template <typename T> |
613 | typename std::enable_if<std::is_base_of<NoExtraMeta, T>::value || |
614 | std::is_base_of<PrimitiveCType, T>::value, |
615 | Status>::type |
616 | Visit(const T&) { |
617 | result_ = true; |
618 | return Status::OK(); |
619 | } |
620 | |
621 | template <typename T> |
622 | typename std::enable_if<std::is_base_of<TimeType, T>::value || |
623 | std::is_base_of<DateType, T>::value, |
624 | Status>::type |
625 | Visit(const T& left) { |
626 | const auto& right = checked_cast<const T&>(right_); |
627 | result_ = left.unit() == right.unit(); |
628 | return Status::OK(); |
629 | } |
630 | |
631 | Status Visit(const TimestampType& left) { |
632 | const auto& right = checked_cast<const TimestampType&>(right_); |
633 | result_ = left.unit() == right.unit() && left.timezone() == right.timezone(); |
634 | return Status::OK(); |
635 | } |
636 | |
637 | Status Visit(const FixedSizeBinaryType& left) { |
638 | const auto& right = checked_cast<const FixedSizeBinaryType&>(right_); |
639 | result_ = left.byte_width() == right.byte_width(); |
640 | return Status::OK(); |
641 | } |
642 | |
643 | Status Visit(const Decimal128Type& left) { |
644 | const auto& right = checked_cast<const Decimal128Type&>(right_); |
645 | result_ = left.precision() == right.precision() && left.scale() == right.scale(); |
646 | return Status::OK(); |
647 | } |
648 | |
649 | Status Visit(const ListType& left) { return VisitChildren(left); } |
650 | |
651 | Status Visit(const StructType& left) { return VisitChildren(left); } |
652 | |
653 | Status Visit(const UnionType& left) { |
654 | const auto& right = checked_cast<const UnionType&>(right_); |
655 | |
656 | if (left.mode() != right.mode() || |
657 | left.type_codes().size() != right.type_codes().size()) { |
658 | result_ = false; |
659 | return Status::OK(); |
660 | } |
661 | |
662 | const std::vector<uint8_t>& left_codes = left.type_codes(); |
663 | const std::vector<uint8_t>& right_codes = right.type_codes(); |
664 | |
665 | for (size_t i = 0; i < left_codes.size(); ++i) { |
666 | if (left_codes[i] != right_codes[i]) { |
667 | result_ = false; |
668 | return Status::OK(); |
669 | } |
670 | } |
671 | |
672 | for (int i = 0; i < left.num_children(); ++i) { |
673 | if (!left.child(i)->Equals(right_.child(i))) { |
674 | result_ = false; |
675 | return Status::OK(); |
676 | } |
677 | } |
678 | |
679 | result_ = true; |
680 | return Status::OK(); |
681 | } |
682 | |
683 | Status Visit(const DictionaryType& left) { |
684 | const auto& right = checked_cast<const DictionaryType&>(right_); |
685 | result_ = left.index_type()->Equals(right.index_type()) && |
686 | left.dictionary()->Equals(right.dictionary()) && |
687 | (left.ordered() == right.ordered()); |
688 | return Status::OK(); |
689 | } |
690 | |
691 | bool result() const { return result_; } |
692 | |
693 | protected: |
694 | const DataType& right_; |
695 | bool result_; |
696 | }; |
697 | |
698 | } // namespace internal |
699 | |
700 | bool ArrayEquals(const Array& left, const Array& right) { |
701 | return internal::ArrayEqualsImpl<internal::ArrayEqualsVisitor>(left, right); |
702 | } |
703 | |
704 | bool ArrayApproxEquals(const Array& left, const Array& right) { |
705 | return internal::ArrayEqualsImpl<internal::ApproxEqualsVisitor>(left, right); |
706 | } |
707 | |
708 | bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx, |
709 | int64_t left_end_idx, int64_t right_start_idx) { |
710 | bool are_equal; |
711 | if (&left == &right) { |
712 | are_equal = true; |
713 | } else if (left.type_id() != right.type_id()) { |
714 | are_equal = false; |
715 | } else if (left.length() == 0) { |
716 | are_equal = true; |
717 | } else { |
718 | internal::RangeEqualsVisitor visitor(right, left_start_idx, left_end_idx, |
719 | right_start_idx); |
720 | auto error = VisitArrayInline(left, &visitor); |
721 | if (!error.ok()) { |
722 | DCHECK(false) << "Arrays are not comparable: " << error.ToString(); |
723 | } |
724 | are_equal = visitor.result(); |
725 | } |
726 | return are_equal; |
727 | } |
728 | |
729 | bool StridedTensorContentEquals(int dim_index, int64_t left_offset, int64_t right_offset, |
730 | int elem_size, const Tensor& left, const Tensor& right) { |
731 | if (dim_index == left.ndim() - 1) { |
732 | for (int64_t i = 0; i < left.shape()[dim_index]; ++i) { |
733 | if (memcmp(left.raw_data() + left_offset + i * left.strides()[dim_index], |
734 | right.raw_data() + right_offset + i * right.strides()[dim_index], |
735 | elem_size) != 0) { |
736 | return false; |
737 | } |
738 | } |
739 | return true; |
740 | } |
741 | for (int64_t i = 0; i < left.shape()[dim_index]; ++i) { |
742 | if (!StridedTensorContentEquals(dim_index + 1, left_offset, right_offset, elem_size, |
743 | left, right)) { |
744 | return false; |
745 | } |
746 | left_offset += left.strides()[dim_index]; |
747 | right_offset += right.strides()[dim_index]; |
748 | } |
749 | return true; |
750 | } |
751 | |
752 | bool TensorEquals(const Tensor& left, const Tensor& right) { |
753 | bool are_equal; |
754 | // The arrays are the same object |
755 | if (&left == &right) { |
756 | are_equal = true; |
757 | } else if (left.type_id() != right.type_id()) { |
758 | are_equal = false; |
759 | } else if (left.size() == 0) { |
760 | are_equal = true; |
761 | } else { |
762 | if (!left.is_contiguous() || !right.is_contiguous()) { |
763 | const auto& shape = left.shape(); |
764 | if (shape != right.shape()) { |
765 | are_equal = false; |
766 | } else { |
767 | const auto& type = checked_cast<const FixedWidthType&>(*left.type()); |
768 | are_equal = |
769 | StridedTensorContentEquals(0, 0, 0, type.bit_width() / 8, left, right); |
770 | } |
771 | } else { |
772 | const auto& size_meta = dynamic_cast<const FixedWidthType&>(*left.type()); |
773 | const int byte_width = size_meta.bit_width() / CHAR_BIT; |
774 | DCHECK_GT(byte_width, 0); |
775 | |
776 | const uint8_t* left_data = left.data()->data(); |
777 | const uint8_t* right_data = right.data()->data(); |
778 | |
779 | are_equal = memcmp(left_data, right_data, |
780 | static_cast<size_t>(byte_width * left.size())) == 0; |
781 | } |
782 | } |
783 | return are_equal; |
784 | } |
785 | |
786 | namespace { |
787 | |
788 | template <typename LeftSparseIndexType, typename RightSparseIndexType> |
789 | struct SparseTensorEqualsImpl { |
790 | static bool Compare(const SparseTensorImpl<LeftSparseIndexType>& left, |
791 | const SparseTensorImpl<RightSparseIndexType>& right) { |
792 | // TODO(mrkn): should we support the equality among different formats? |
793 | return false; |
794 | } |
795 | }; |
796 | |
797 | template <typename SparseIndexType> |
798 | struct SparseTensorEqualsImpl<SparseIndexType, SparseIndexType> { |
799 | static bool Compare(const SparseTensorImpl<SparseIndexType>& left, |
800 | const SparseTensorImpl<SparseIndexType>& right) { |
801 | DCHECK(left.type()->id() == right.type()->id()); |
802 | DCHECK(left.shape() == right.shape()); |
803 | DCHECK(left.non_zero_length() == right.non_zero_length()); |
804 | |
805 | const auto& left_index = checked_cast<const SparseIndexType&>(*left.sparse_index()); |
806 | const auto& right_index = checked_cast<const SparseIndexType&>(*right.sparse_index()); |
807 | |
808 | if (!left_index.Equals(right_index)) { |
809 | return false; |
810 | } |
811 | |
812 | const auto& size_meta = dynamic_cast<const FixedWidthType&>(*left.type()); |
813 | const int byte_width = size_meta.bit_width() / CHAR_BIT; |
814 | DCHECK_GT(byte_width, 0); |
815 | |
816 | const uint8_t* left_data = left.data()->data(); |
817 | const uint8_t* right_data = right.data()->data(); |
818 | |
819 | return memcmp(left_data, right_data, |
820 | static_cast<size_t>(byte_width * left.non_zero_length())); |
821 | } |
822 | }; |
823 | |
824 | template <typename SparseIndexType> |
825 | inline bool SparseTensorEqualsImplDispatch(const SparseTensorImpl<SparseIndexType>& left, |
826 | const SparseTensor& right) { |
827 | switch (right.format_id()) { |
828 | case SparseTensorFormat::COO: { |
829 | const auto& right_coo = |
830 | checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(right); |
831 | return SparseTensorEqualsImpl<SparseIndexType, SparseCOOIndex>::Compare(left, |
832 | right_coo); |
833 | } |
834 | |
835 | case SparseTensorFormat::CSR: { |
836 | const auto& right_csr = |
837 | checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(right); |
838 | return SparseTensorEqualsImpl<SparseIndexType, SparseCSRIndex>::Compare(left, |
839 | right_csr); |
840 | } |
841 | |
842 | default: |
843 | return false; |
844 | } |
845 | } |
846 | |
847 | } // namespace |
848 | |
849 | bool SparseTensorEquals(const SparseTensor& left, const SparseTensor& right) { |
850 | if (&left == &right) { |
851 | return true; |
852 | } else if (left.type()->id() != right.type()->id()) { |
853 | return false; |
854 | } else if (left.size() == 0) { |
855 | return true; |
856 | } else if (left.shape() != right.shape()) { |
857 | return false; |
858 | } else if (left.non_zero_length() != right.non_zero_length()) { |
859 | return false; |
860 | } |
861 | |
862 | switch (left.format_id()) { |
863 | case SparseTensorFormat::COO: { |
864 | const auto& left_coo = checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(left); |
865 | return SparseTensorEqualsImplDispatch(left_coo, right); |
866 | } |
867 | |
868 | case SparseTensorFormat::CSR: { |
869 | const auto& left_csr = checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(left); |
870 | return SparseTensorEqualsImplDispatch(left_csr, right); |
871 | } |
872 | |
873 | default: |
874 | return false; |
875 | } |
876 | } |
877 | |
878 | bool TypeEquals(const DataType& left, const DataType& right) { |
879 | bool are_equal; |
880 | // The arrays are the same object |
881 | if (&left == &right) { |
882 | are_equal = true; |
883 | } else if (left.id() != right.id()) { |
884 | are_equal = false; |
885 | } else { |
886 | internal::TypeEqualsVisitor visitor(right); |
887 | auto error = VisitTypeInline(left, &visitor); |
888 | if (!error.ok()) { |
889 | DCHECK(false) << "Types are not comparable: " << error.ToString(); |
890 | } |
891 | are_equal = visitor.result(); |
892 | } |
893 | return are_equal; |
894 | } |
895 | |
896 | } // namespace arrow |
897 | |