1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#include "arrow/tensor.h"
19
20#include <cstddef>
21#include <cstdint>
22#include <functional>
23#include <memory>
24#include <numeric>
25#include <string>
26#include <vector>
27
28#include "arrow/compare.h"
29#include "arrow/type.h"
30#include "arrow/type_traits.h"
31#include "arrow/util/checked_cast.h"
32#include "arrow/util/logging.h"
33
34namespace arrow {
35
36using internal::checked_cast;
37
38static void ComputeRowMajorStrides(const FixedWidthType& type,
39 const std::vector<int64_t>& shape,
40 std::vector<int64_t>* strides) {
41 int64_t remaining = type.bit_width() / 8;
42 for (int64_t dimsize : shape) {
43 remaining *= dimsize;
44 }
45
46 if (remaining == 0) {
47 strides->assign(shape.size(), type.bit_width() / 8);
48 return;
49 }
50
51 for (int64_t dimsize : shape) {
52 remaining /= dimsize;
53 strides->push_back(remaining);
54 }
55}
56
57static void ComputeColumnMajorStrides(const FixedWidthType& type,
58 const std::vector<int64_t>& shape,
59 std::vector<int64_t>* strides) {
60 int64_t total = type.bit_width() / 8;
61 for (int64_t dimsize : shape) {
62 if (dimsize == 0) {
63 strides->assign(shape.size(), type.bit_width() / 8);
64 return;
65 }
66 }
67 for (int64_t dimsize : shape) {
68 strides->push_back(total);
69 total *= dimsize;
70 }
71}
72
73/// Constructor with strides and dimension names
74Tensor::Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
75 const std::vector<int64_t>& shape, const std::vector<int64_t>& strides,
76 const std::vector<std::string>& dim_names)
77 : type_(type), data_(data), shape_(shape), strides_(strides), dim_names_(dim_names) {
78 DCHECK(is_tensor_supported(type->id()));
79 if (shape.size() > 0 && strides.size() == 0) {
80 ComputeRowMajorStrides(checked_cast<const FixedWidthType&>(*type_), shape, &strides_);
81 }
82}
83
84Tensor::Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
85 const std::vector<int64_t>& shape, const std::vector<int64_t>& strides)
86 : Tensor(type, data, shape, strides, {}) {}
87
88Tensor::Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
89 const std::vector<int64_t>& shape)
90 : Tensor(type, data, shape, {}, {}) {}
91
92const std::string& Tensor::dim_name(int i) const {
93 static const std::string kEmpty = "";
94 if (dim_names_.size() == 0) {
95 return kEmpty;
96 } else {
97 DCHECK_LT(i, static_cast<int>(dim_names_.size()));
98 return dim_names_[i];
99 }
100}
101
102int64_t Tensor::size() const {
103 return std::accumulate(shape_.begin(), shape_.end(), 1LL, std::multiplies<int64_t>());
104}
105
106bool Tensor::is_contiguous() const { return is_row_major() || is_column_major(); }
107
108bool Tensor::is_row_major() const {
109 std::vector<int64_t> c_strides;
110 const auto& fw_type = checked_cast<const FixedWidthType&>(*type_);
111 ComputeRowMajorStrides(fw_type, shape_, &c_strides);
112 return strides_ == c_strides;
113}
114
115bool Tensor::is_column_major() const {
116 std::vector<int64_t> f_strides;
117 const auto& fw_type = checked_cast<const FixedWidthType&>(*type_);
118 ComputeColumnMajorStrides(fw_type, shape_, &f_strides);
119 return strides_ == f_strides;
120}
121
122Type::type Tensor::type_id() const { return type_->id(); }
123
124bool Tensor::Equals(const Tensor& other) const { return TensorEquals(*this, other); }
125
126} // namespace arrow
127