1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18// Functions for comparing Arrow data structures
19
20#include "arrow/compare.h"
21
22#include <climits>
23#include <cmath>
24#include <cstdint>
25#include <cstring>
26#include <memory>
27#include <string>
28#include <type_traits>
29#include <vector>
30
31#include "arrow/array.h"
32#include "arrow/buffer.h"
33#include "arrow/sparse_tensor.h"
34#include "arrow/status.h"
35#include "arrow/tensor.h"
36#include "arrow/type.h"
37#include "arrow/util/bit-util.h"
38#include "arrow/util/checked_cast.h"
39#include "arrow/util/logging.h"
40#include "arrow/util/macros.h"
41#include "arrow/visitor_inline.h"
42
43namespace arrow {
44
45using internal::BitmapEquals;
46using internal::checked_cast;
47
48// ----------------------------------------------------------------------
49// Public method implementations
50
51namespace internal {
52
53class RangeEqualsVisitor {
54 public:
55 RangeEqualsVisitor(const Array& right, int64_t left_start_idx, int64_t left_end_idx,
56 int64_t right_start_idx)
57 : right_(right),
58 left_start_idx_(left_start_idx),
59 left_end_idx_(left_end_idx),
60 right_start_idx_(right_start_idx),
61 result_(false) {}
62
63 template <typename ArrayType>
64 inline Status CompareValues(const ArrayType& left) {
65 const auto& right = checked_cast<const ArrayType&>(right_);
66
67 for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
68 ++i, ++o_i) {
69 const bool is_null = left.IsNull(i);
70 if (is_null != right.IsNull(o_i) ||
71 (!is_null && left.Value(i) != right.Value(o_i))) {
72 result_ = false;
73 return Status::OK();
74 }
75 }
76 result_ = true;
77 return Status::OK();
78 }
79
80 bool CompareBinaryRange(const BinaryArray& left) const {
81 const auto& right = checked_cast<const BinaryArray&>(right_);
82
83 for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
84 ++i, ++o_i) {
85 const bool is_null = left.IsNull(i);
86 if (is_null != right.IsNull(o_i)) {
87 return false;
88 }
89 if (is_null) continue;
90 const int32_t begin_offset = left.value_offset(i);
91 const int32_t end_offset = left.value_offset(i + 1);
92 const int32_t right_begin_offset = right.value_offset(o_i);
93 const int32_t right_end_offset = right.value_offset(o_i + 1);
94 // Underlying can't be equal if the size isn't equal
95 if (end_offset - begin_offset != right_end_offset - right_begin_offset) {
96 return false;
97 }
98
99 if (end_offset - begin_offset > 0 &&
100 std::memcmp(left.value_data()->data() + begin_offset,
101 right.value_data()->data() + right_begin_offset,
102 static_cast<size_t>(end_offset - begin_offset))) {
103 return false;
104 }
105 }
106 return true;
107 }
108
109 bool CompareLists(const ListArray& left) {
110 const auto& right = checked_cast<const ListArray&>(right_);
111
112 const std::shared_ptr<Array>& left_values = left.values();
113 const std::shared_ptr<Array>& right_values = right.values();
114
115 for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
116 ++i, ++o_i) {
117 const bool is_null = left.IsNull(i);
118 if (is_null != right.IsNull(o_i)) {
119 return false;
120 }
121 if (is_null) continue;
122 const int32_t begin_offset = left.value_offset(i);
123 const int32_t end_offset = left.value_offset(i + 1);
124 const int32_t right_begin_offset = right.value_offset(o_i);
125 const int32_t right_end_offset = right.value_offset(o_i + 1);
126 // Underlying can't be equal if the size isn't equal
127 if (end_offset - begin_offset != right_end_offset - right_begin_offset) {
128 return false;
129 }
130 if (!left_values->RangeEquals(begin_offset, end_offset, right_begin_offset,
131 right_values)) {
132 return false;
133 }
134 }
135 return true;
136 }
137
138 bool CompareStructs(const StructArray& left) {
139 const auto& right = checked_cast<const StructArray&>(right_);
140 bool equal_fields = true;
141 for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
142 ++i, ++o_i) {
143 if (left.IsNull(i) != right.IsNull(o_i)) {
144 return false;
145 }
146 if (left.IsNull(i)) continue;
147 for (int j = 0; j < left.num_fields(); ++j) {
148 // TODO: really we should be comparing stretches of non-null data rather
149 // than looking at one value at a time.
150 equal_fields = left.field(j)->RangeEquals(i, i + 1, o_i, right.field(j));
151 if (!equal_fields) {
152 return false;
153 }
154 }
155 }
156 return true;
157 }
158
159 bool CompareUnions(const UnionArray& left) const {
160 const auto& right = checked_cast<const UnionArray&>(right_);
161
162 const UnionMode::type union_mode = left.mode();
163 if (union_mode != right.mode()) {
164 return false;
165 }
166
167 const auto& left_type = checked_cast<const UnionType&>(*left.type());
168
169 // Define a mapping from the type id to child number
170 uint8_t max_code = 0;
171
172 const std::vector<uint8_t>& type_codes = left_type.type_codes();
173 for (size_t i = 0; i < type_codes.size(); ++i) {
174 const uint8_t code = type_codes[i];
175 if (code > max_code) {
176 max_code = code;
177 }
178 }
179
180 // Store mapping in a vector for constant time lookups
181 std::vector<uint8_t> type_id_to_child_num(max_code + 1);
182 for (uint8_t i = 0; i < static_cast<uint8_t>(type_codes.size()); ++i) {
183 type_id_to_child_num[type_codes[i]] = i;
184 }
185
186 const uint8_t* left_ids = left.raw_type_ids();
187 const uint8_t* right_ids = right.raw_type_ids();
188
189 uint8_t id, child_num;
190 for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
191 ++i, ++o_i) {
192 if (left.IsNull(i) != right.IsNull(o_i)) {
193 return false;
194 }
195 if (left.IsNull(i)) continue;
196 if (left_ids[i] != right_ids[o_i]) {
197 return false;
198 }
199
200 id = left_ids[i];
201 child_num = type_id_to_child_num[id];
202
203 // TODO(wesm): really we should be comparing stretches of non-null data
204 // rather than looking at one value at a time.
205 if (union_mode == UnionMode::SPARSE) {
206 if (!left.child(child_num)->RangeEquals(i, i + 1, o_i, right.child(child_num))) {
207 return false;
208 }
209 } else {
210 const int32_t offset = left.raw_value_offsets()[i];
211 const int32_t o_offset = right.raw_value_offsets()[o_i];
212 if (!left.child(child_num)->RangeEquals(offset, offset + 1, o_offset,
213 right.child(child_num))) {
214 return false;
215 }
216 }
217 }
218 return true;
219 }
220
221 Status Visit(const BinaryArray& left) {
222 result_ = CompareBinaryRange(left);
223 return Status::OK();
224 }
225
226 Status Visit(const FixedSizeBinaryArray& left) {
227 const auto& right = checked_cast<const FixedSizeBinaryArray&>(right_);
228
229 int32_t width = left.byte_width();
230
231 const uint8_t* left_data = nullptr;
232 const uint8_t* right_data = nullptr;
233
234 if (left.values()) {
235 left_data = left.raw_values();
236 }
237
238 if (right.values()) {
239 right_data = right.raw_values();
240 }
241
242 for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
243 ++i, ++o_i) {
244 const bool is_null = left.IsNull(i);
245 if (is_null != right.IsNull(o_i)) {
246 result_ = false;
247 return Status::OK();
248 }
249 if (is_null) continue;
250
251 if (std::memcmp(left_data + width * i, right_data + width * o_i, width)) {
252 result_ = false;
253 return Status::OK();
254 }
255 }
256 result_ = true;
257 return Status::OK();
258 }
259
260 Status Visit(const Decimal128Array& left) {
261 return Visit(checked_cast<const FixedSizeBinaryArray&>(left));
262 }
263
264 Status Visit(const NullArray& left) {
265 ARROW_UNUSED(left);
266 result_ = true;
267 return Status::OK();
268 }
269
270 template <typename T>
271 typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value, Status>::type Visit(
272 const T& left) {
273 return CompareValues<T>(left);
274 }
275
276 Status Visit(const ListArray& left) {
277 result_ = CompareLists(left);
278 return Status::OK();
279 }
280
281 Status Visit(const StructArray& left) {
282 result_ = CompareStructs(left);
283 return Status::OK();
284 }
285
286 Status Visit(const UnionArray& left) {
287 result_ = CompareUnions(left);
288 return Status::OK();
289 }
290
291 Status Visit(const DictionaryArray& left) {
292 const auto& right = checked_cast<const DictionaryArray&>(right_);
293 if (!left.dictionary()->Equals(right.dictionary())) {
294 result_ = false;
295 return Status::OK();
296 }
297 result_ = left.indices()->RangeEquals(left_start_idx_, left_end_idx_,
298 right_start_idx_, right.indices());
299 return Status::OK();
300 }
301
302 bool result() const { return result_; }
303
304 protected:
305 const Array& right_;
306 int64_t left_start_idx_;
307 int64_t left_end_idx_;
308 int64_t right_start_idx_;
309
310 bool result_;
311};
312
313static bool IsEqualPrimitive(const PrimitiveArray& left, const PrimitiveArray& right) {
314 const auto& size_meta = dynamic_cast<const FixedWidthType&>(*left.type());
315 const int byte_width = size_meta.bit_width() / CHAR_BIT;
316
317 const uint8_t* left_data = nullptr;
318 const uint8_t* right_data = nullptr;
319
320 if (left.values()) {
321 left_data = left.values()->data() + left.offset() * byte_width;
322 }
323
324 if (right.values()) {
325 right_data = right.values()->data() + right.offset() * byte_width;
326 }
327
328 if (byte_width == 0) {
329 // Special case 0-width data, as the data pointers may be null
330 for (int64_t i = 0; i < left.length(); ++i) {
331 if (left.IsNull(i) != right.IsNull(i)) {
332 return false;
333 }
334 }
335 return true;
336 } else if (left.null_count() > 0) {
337 for (int64_t i = 0; i < left.length(); ++i) {
338 const bool left_null = left.IsNull(i);
339 const bool right_null = right.IsNull(i);
340 if (left_null != right_null) {
341 return false;
342 }
343 if (!left_null && memcmp(left_data, right_data, byte_width) != 0) {
344 return false;
345 }
346 left_data += byte_width;
347 right_data += byte_width;
348 }
349 return true;
350 } else {
351 auto number_of_bytes_to_compare = static_cast<size_t>(byte_width * left.length());
352 return memcmp(left_data, right_data, number_of_bytes_to_compare) == 0;
353 }
354}
355
356class ArrayEqualsVisitor : public RangeEqualsVisitor {
357 public:
358 explicit ArrayEqualsVisitor(const Array& right)
359 : RangeEqualsVisitor(right, 0, right.length(), 0) {}
360
361 Status Visit(const NullArray& left) {
362 ARROW_UNUSED(left);
363 result_ = true;
364 return Status::OK();
365 }
366
367 Status Visit(const BooleanArray& left) {
368 const auto& right = checked_cast<const BooleanArray&>(right_);
369
370 if (left.null_count() > 0) {
371 const uint8_t* left_data = left.values()->data();
372 const uint8_t* right_data = right.values()->data();
373
374 for (int64_t i = 0; i < left.length(); ++i) {
375 if (left.IsValid(i) && BitUtil::GetBit(left_data, i + left.offset()) !=
376 BitUtil::GetBit(right_data, i + right.offset())) {
377 result_ = false;
378 return Status::OK();
379 }
380 }
381 result_ = true;
382 } else {
383 result_ = BitmapEquals(left.values()->data(), left.offset(), right.values()->data(),
384 right.offset(), left.length());
385 }
386 return Status::OK();
387 }
388
389 template <typename T>
390 typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value &&
391 !std::is_base_of<BooleanArray, T>::value,
392 Status>::type
393 Visit(const T& left) {
394 result_ = IsEqualPrimitive(left, checked_cast<const PrimitiveArray&>(right_));
395 return Status::OK();
396 }
397
398 template <typename ArrayType>
399 bool ValueOffsetsEqual(const ArrayType& left) {
400 const auto& right = checked_cast<const ArrayType&>(right_);
401
402 if (left.offset() == 0 && right.offset() == 0) {
403 return left.value_offsets()->Equals(*right.value_offsets(),
404 (left.length() + 1) * sizeof(int32_t));
405 } else {
406 // One of the arrays is sliced; logic is more complicated because the
407 // value offsets are not both 0-based
408 auto left_offsets =
409 reinterpret_cast<const int32_t*>(left.value_offsets()->data()) + left.offset();
410 auto right_offsets =
411 reinterpret_cast<const int32_t*>(right.value_offsets()->data()) +
412 right.offset();
413
414 for (int64_t i = 0; i < left.length() + 1; ++i) {
415 if (left_offsets[i] - left_offsets[0] != right_offsets[i] - right_offsets[0]) {
416 return false;
417 }
418 }
419 return true;
420 }
421 }
422
423 bool CompareBinary(const BinaryArray& left) {
424 const auto& right = checked_cast<const BinaryArray&>(right_);
425
426 bool equal_offsets = ValueOffsetsEqual<BinaryArray>(left);
427 if (!equal_offsets) {
428 return false;
429 }
430
431 if (!left.value_data() && !(right.value_data())) {
432 return true;
433 }
434 if (left.value_offset(left.length()) == 0) {
435 return true;
436 }
437
438 const uint8_t* left_data = left.value_data()->data();
439 const uint8_t* right_data = right.value_data()->data();
440
441 if (left.null_count() == 0) {
442 // Fast path for null count 0, single memcmp
443 if (left.offset() == 0 && right.offset() == 0) {
444 return std::memcmp(left_data, right_data,
445 left.raw_value_offsets()[left.length()]) == 0;
446 } else {
447 const int64_t total_bytes =
448 left.value_offset(left.length()) - left.value_offset(0);
449 return std::memcmp(left_data + left.value_offset(0),
450 right_data + right.value_offset(0),
451 static_cast<size_t>(total_bytes)) == 0;
452 }
453 } else {
454 // ARROW-537: Only compare data in non-null slots
455 const int32_t* left_offsets = left.raw_value_offsets();
456 const int32_t* right_offsets = right.raw_value_offsets();
457 for (int64_t i = 0; i < left.length(); ++i) {
458 if (left.IsNull(i)) {
459 continue;
460 }
461 if (std::memcmp(left_data + left_offsets[i], right_data + right_offsets[i],
462 left.value_length(i))) {
463 return false;
464 }
465 }
466 return true;
467 }
468 }
469
470 Status Visit(const BinaryArray& left) {
471 result_ = CompareBinary(left);
472 return Status::OK();
473 }
474
475 Status Visit(const ListArray& left) {
476 const auto& right = checked_cast<const ListArray&>(right_);
477 bool equal_offsets = ValueOffsetsEqual<ListArray>(left);
478 if (!equal_offsets) {
479 result_ = false;
480 return Status::OK();
481 }
482
483 result_ = left.values()->RangeEquals(
484 left.value_offset(0), left.value_offset(left.length()) - left.value_offset(0),
485 right.value_offset(0), right.values());
486 return Status::OK();
487 }
488
489 Status Visit(const DictionaryArray& left) {
490 const auto& right = checked_cast<const DictionaryArray&>(right_);
491 if (!left.dictionary()->Equals(right.dictionary())) {
492 result_ = false;
493 } else {
494 result_ = left.indices()->Equals(right.indices());
495 }
496 return Status::OK();
497 }
498
499 template <typename T>
500 typename std::enable_if<std::is_base_of<NestedType, typename T::TypeClass>::value,
501 Status>::type
502 Visit(const T& left) {
503 return RangeEqualsVisitor::Visit(left);
504 }
505};
506
507template <typename TYPE>
508inline bool FloatingApproxEquals(const NumericArray<TYPE>& left,
509 const NumericArray<TYPE>& right) {
510 using T = typename TYPE::c_type;
511
512 const T* left_data = left.raw_values();
513 const T* right_data = right.raw_values();
514
515 static constexpr T EPSILON = static_cast<T>(1E-5);
516
517 if (left.null_count() > 0) {
518 for (int64_t i = 0; i < left.length(); ++i) {
519 if (left.IsNull(i)) continue;
520 if (fabs(left_data[i] - right_data[i]) > EPSILON) {
521 return false;
522 }
523 }
524 } else {
525 for (int64_t i = 0; i < left.length(); ++i) {
526 if (fabs(left_data[i] - right_data[i]) > EPSILON) {
527 return false;
528 }
529 }
530 }
531 return true;
532}
533
534class ApproxEqualsVisitor : public ArrayEqualsVisitor {
535 public:
536 using ArrayEqualsVisitor::ArrayEqualsVisitor;
537 using ArrayEqualsVisitor::Visit;
538
539 Status Visit(const FloatArray& left) {
540 result_ =
541 FloatingApproxEquals<FloatType>(left, checked_cast<const FloatArray&>(right_));
542 return Status::OK();
543 }
544
545 Status Visit(const DoubleArray& left) {
546 result_ =
547 FloatingApproxEquals<DoubleType>(left, checked_cast<const DoubleArray&>(right_));
548 return Status::OK();
549 }
550};
551
552static bool BaseDataEquals(const Array& left, const Array& right) {
553 if (left.length() != right.length() || left.null_count() != right.null_count() ||
554 left.type_id() != right.type_id()) {
555 return false;
556 }
557 // ARROW-2567: Ensure that not only the type id but also the type equality
558 // itself is checked.
559 if (!TypeEquals(*left.type(), *right.type())) {
560 return false;
561 }
562 if (left.null_count() > 0 && left.null_count() < left.length()) {
563 return BitmapEquals(left.null_bitmap()->data(), left.offset(),
564 right.null_bitmap()->data(), right.offset(), left.length());
565 }
566 return true;
567}
568
569template <typename VISITOR>
570inline bool ArrayEqualsImpl(const Array& left, const Array& right) {
571 bool are_equal;
572 // The arrays are the same object
573 if (&left == &right) {
574 are_equal = true;
575 } else if (!BaseDataEquals(left, right)) {
576 are_equal = false;
577 } else if (left.length() == 0) {
578 are_equal = true;
579 } else if (left.null_count() == left.length()) {
580 are_equal = true;
581 } else {
582 VISITOR visitor(right);
583 auto error = VisitArrayInline(left, &visitor);
584 if (!error.ok()) {
585 DCHECK(false) << "Arrays are not comparable: " << error.ToString();
586 }
587 are_equal = visitor.result();
588 }
589 return are_equal;
590}
591
592class TypeEqualsVisitor {
593 public:
594 explicit TypeEqualsVisitor(const DataType& right) : right_(right), result_(false) {}
595
596 Status VisitChildren(const DataType& left) {
597 if (left.num_children() != right_.num_children()) {
598 result_ = false;
599 return Status::OK();
600 }
601
602 for (int i = 0; i < left.num_children(); ++i) {
603 if (!left.child(i)->Equals(right_.child(i))) {
604 result_ = false;
605 return Status::OK();
606 }
607 }
608 result_ = true;
609 return Status::OK();
610 }
611
612 template <typename T>
613 typename std::enable_if<std::is_base_of<NoExtraMeta, T>::value ||
614 std::is_base_of<PrimitiveCType, T>::value,
615 Status>::type
616 Visit(const T&) {
617 result_ = true;
618 return Status::OK();
619 }
620
621 template <typename T>
622 typename std::enable_if<std::is_base_of<TimeType, T>::value ||
623 std::is_base_of<DateType, T>::value,
624 Status>::type
625 Visit(const T& left) {
626 const auto& right = checked_cast<const T&>(right_);
627 result_ = left.unit() == right.unit();
628 return Status::OK();
629 }
630
631 Status Visit(const TimestampType& left) {
632 const auto& right = checked_cast<const TimestampType&>(right_);
633 result_ = left.unit() == right.unit() && left.timezone() == right.timezone();
634 return Status::OK();
635 }
636
637 Status Visit(const FixedSizeBinaryType& left) {
638 const auto& right = checked_cast<const FixedSizeBinaryType&>(right_);
639 result_ = left.byte_width() == right.byte_width();
640 return Status::OK();
641 }
642
643 Status Visit(const Decimal128Type& left) {
644 const auto& right = checked_cast<const Decimal128Type&>(right_);
645 result_ = left.precision() == right.precision() && left.scale() == right.scale();
646 return Status::OK();
647 }
648
649 Status Visit(const ListType& left) { return VisitChildren(left); }
650
651 Status Visit(const StructType& left) { return VisitChildren(left); }
652
653 Status Visit(const UnionType& left) {
654 const auto& right = checked_cast<const UnionType&>(right_);
655
656 if (left.mode() != right.mode() ||
657 left.type_codes().size() != right.type_codes().size()) {
658 result_ = false;
659 return Status::OK();
660 }
661
662 const std::vector<uint8_t>& left_codes = left.type_codes();
663 const std::vector<uint8_t>& right_codes = right.type_codes();
664
665 for (size_t i = 0; i < left_codes.size(); ++i) {
666 if (left_codes[i] != right_codes[i]) {
667 result_ = false;
668 return Status::OK();
669 }
670 }
671
672 for (int i = 0; i < left.num_children(); ++i) {
673 if (!left.child(i)->Equals(right_.child(i))) {
674 result_ = false;
675 return Status::OK();
676 }
677 }
678
679 result_ = true;
680 return Status::OK();
681 }
682
683 Status Visit(const DictionaryType& left) {
684 const auto& right = checked_cast<const DictionaryType&>(right_);
685 result_ = left.index_type()->Equals(right.index_type()) &&
686 left.dictionary()->Equals(right.dictionary()) &&
687 (left.ordered() == right.ordered());
688 return Status::OK();
689 }
690
691 bool result() const { return result_; }
692
693 protected:
694 const DataType& right_;
695 bool result_;
696};
697
698} // namespace internal
699
700bool ArrayEquals(const Array& left, const Array& right) {
701 return internal::ArrayEqualsImpl<internal::ArrayEqualsVisitor>(left, right);
702}
703
704bool ArrayApproxEquals(const Array& left, const Array& right) {
705 return internal::ArrayEqualsImpl<internal::ApproxEqualsVisitor>(left, right);
706}
707
708bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx,
709 int64_t left_end_idx, int64_t right_start_idx) {
710 bool are_equal;
711 if (&left == &right) {
712 are_equal = true;
713 } else if (left.type_id() != right.type_id()) {
714 are_equal = false;
715 } else if (left.length() == 0) {
716 are_equal = true;
717 } else {
718 internal::RangeEqualsVisitor visitor(right, left_start_idx, left_end_idx,
719 right_start_idx);
720 auto error = VisitArrayInline(left, &visitor);
721 if (!error.ok()) {
722 DCHECK(false) << "Arrays are not comparable: " << error.ToString();
723 }
724 are_equal = visitor.result();
725 }
726 return are_equal;
727}
728
729bool StridedTensorContentEquals(int dim_index, int64_t left_offset, int64_t right_offset,
730 int elem_size, const Tensor& left, const Tensor& right) {
731 if (dim_index == left.ndim() - 1) {
732 for (int64_t i = 0; i < left.shape()[dim_index]; ++i) {
733 if (memcmp(left.raw_data() + left_offset + i * left.strides()[dim_index],
734 right.raw_data() + right_offset + i * right.strides()[dim_index],
735 elem_size) != 0) {
736 return false;
737 }
738 }
739 return true;
740 }
741 for (int64_t i = 0; i < left.shape()[dim_index]; ++i) {
742 if (!StridedTensorContentEquals(dim_index + 1, left_offset, right_offset, elem_size,
743 left, right)) {
744 return false;
745 }
746 left_offset += left.strides()[dim_index];
747 right_offset += right.strides()[dim_index];
748 }
749 return true;
750}
751
752bool TensorEquals(const Tensor& left, const Tensor& right) {
753 bool are_equal;
754 // The arrays are the same object
755 if (&left == &right) {
756 are_equal = true;
757 } else if (left.type_id() != right.type_id()) {
758 are_equal = false;
759 } else if (left.size() == 0) {
760 are_equal = true;
761 } else {
762 if (!left.is_contiguous() || !right.is_contiguous()) {
763 const auto& shape = left.shape();
764 if (shape != right.shape()) {
765 are_equal = false;
766 } else {
767 const auto& type = checked_cast<const FixedWidthType&>(*left.type());
768 are_equal =
769 StridedTensorContentEquals(0, 0, 0, type.bit_width() / 8, left, right);
770 }
771 } else {
772 const auto& size_meta = dynamic_cast<const FixedWidthType&>(*left.type());
773 const int byte_width = size_meta.bit_width() / CHAR_BIT;
774 DCHECK_GT(byte_width, 0);
775
776 const uint8_t* left_data = left.data()->data();
777 const uint8_t* right_data = right.data()->data();
778
779 are_equal = memcmp(left_data, right_data,
780 static_cast<size_t>(byte_width * left.size())) == 0;
781 }
782 }
783 return are_equal;
784}
785
786namespace {
787
788template <typename LeftSparseIndexType, typename RightSparseIndexType>
789struct SparseTensorEqualsImpl {
790 static bool Compare(const SparseTensorImpl<LeftSparseIndexType>& left,
791 const SparseTensorImpl<RightSparseIndexType>& right) {
792 // TODO(mrkn): should we support the equality among different formats?
793 return false;
794 }
795};
796
797template <typename SparseIndexType>
798struct SparseTensorEqualsImpl<SparseIndexType, SparseIndexType> {
799 static bool Compare(const SparseTensorImpl<SparseIndexType>& left,
800 const SparseTensorImpl<SparseIndexType>& right) {
801 DCHECK(left.type()->id() == right.type()->id());
802 DCHECK(left.shape() == right.shape());
803 DCHECK(left.non_zero_length() == right.non_zero_length());
804
805 const auto& left_index = checked_cast<const SparseIndexType&>(*left.sparse_index());
806 const auto& right_index = checked_cast<const SparseIndexType&>(*right.sparse_index());
807
808 if (!left_index.Equals(right_index)) {
809 return false;
810 }
811
812 const auto& size_meta = dynamic_cast<const FixedWidthType&>(*left.type());
813 const int byte_width = size_meta.bit_width() / CHAR_BIT;
814 DCHECK_GT(byte_width, 0);
815
816 const uint8_t* left_data = left.data()->data();
817 const uint8_t* right_data = right.data()->data();
818
819 return memcmp(left_data, right_data,
820 static_cast<size_t>(byte_width * left.non_zero_length()));
821 }
822};
823
824template <typename SparseIndexType>
825inline bool SparseTensorEqualsImplDispatch(const SparseTensorImpl<SparseIndexType>& left,
826 const SparseTensor& right) {
827 switch (right.format_id()) {
828 case SparseTensorFormat::COO: {
829 const auto& right_coo =
830 checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(right);
831 return SparseTensorEqualsImpl<SparseIndexType, SparseCOOIndex>::Compare(left,
832 right_coo);
833 }
834
835 case SparseTensorFormat::CSR: {
836 const auto& right_csr =
837 checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(right);
838 return SparseTensorEqualsImpl<SparseIndexType, SparseCSRIndex>::Compare(left,
839 right_csr);
840 }
841
842 default:
843 return false;
844 }
845}
846
847} // namespace
848
849bool SparseTensorEquals(const SparseTensor& left, const SparseTensor& right) {
850 if (&left == &right) {
851 return true;
852 } else if (left.type()->id() != right.type()->id()) {
853 return false;
854 } else if (left.size() == 0) {
855 return true;
856 } else if (left.shape() != right.shape()) {
857 return false;
858 } else if (left.non_zero_length() != right.non_zero_length()) {
859 return false;
860 }
861
862 switch (left.format_id()) {
863 case SparseTensorFormat::COO: {
864 const auto& left_coo = checked_cast<const SparseTensorImpl<SparseCOOIndex>&>(left);
865 return SparseTensorEqualsImplDispatch(left_coo, right);
866 }
867
868 case SparseTensorFormat::CSR: {
869 const auto& left_csr = checked_cast<const SparseTensorImpl<SparseCSRIndex>&>(left);
870 return SparseTensorEqualsImplDispatch(left_csr, right);
871 }
872
873 default:
874 return false;
875 }
876}
877
878bool TypeEquals(const DataType& left, const DataType& right) {
879 bool are_equal;
880 // The arrays are the same object
881 if (&left == &right) {
882 are_equal = true;
883 } else if (left.id() != right.id()) {
884 are_equal = false;
885 } else {
886 internal::TypeEqualsVisitor visitor(right);
887 auto error = VisitTypeInline(left, &visitor);
888 if (!error.ok()) {
889 DCHECK(false) << "Types are not comparable: " << error.ToString();
890 }
891 are_equal = visitor.result();
892 }
893 return are_equal;
894}
895
896} // namespace arrow
897