1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #include "arrow/record_batch.h" |
19 | |
20 | #include <algorithm> |
21 | #include <cstdlib> |
22 | #include <memory> |
23 | #include <sstream> |
24 | #include <string> |
25 | #include <utility> |
26 | |
27 | #include "arrow/array.h" |
28 | #include "arrow/status.h" |
29 | #include "arrow/table.h" |
30 | #include "arrow/type.h" |
31 | #include "arrow/util/logging.h" |
32 | #include "arrow/util/stl.h" |
33 | |
34 | namespace arrow { |
35 | |
36 | Status RecordBatch::AddColumn(int i, const std::string& field_name, |
37 | const std::shared_ptr<Array>& column, |
38 | std::shared_ptr<RecordBatch>* out) const { |
39 | auto field = ::arrow::field(field_name, column->type()); |
40 | return AddColumn(i, field, column, out); |
41 | } |
42 | |
43 | /// \class SimpleRecordBatch |
44 | /// \brief A basic, non-lazy in-memory record batch |
45 | class SimpleRecordBatch : public RecordBatch { |
46 | public: |
47 | SimpleRecordBatch(const std::shared_ptr<Schema>& schema, int64_t num_rows, |
48 | const std::vector<std::shared_ptr<Array>>& columns) |
49 | : RecordBatch(schema, num_rows) { |
50 | columns_.resize(columns.size()); |
51 | boxed_columns_.resize(schema->num_fields()); |
52 | for (size_t i = 0; i < columns.size(); ++i) { |
53 | columns_[i] = columns[i]->data(); |
54 | } |
55 | } |
56 | |
57 | SimpleRecordBatch(const std::shared_ptr<Schema>& schema, int64_t num_rows, |
58 | std::vector<std::shared_ptr<Array>>&& columns) |
59 | : RecordBatch(schema, num_rows) { |
60 | columns_.resize(columns.size()); |
61 | boxed_columns_.resize(schema->num_fields()); |
62 | for (size_t i = 0; i < columns.size(); ++i) { |
63 | columns_[i] = columns[i]->data(); |
64 | } |
65 | } |
66 | |
67 | SimpleRecordBatch(const std::shared_ptr<Schema>& schema, int64_t num_rows, |
68 | std::vector<std::shared_ptr<ArrayData>>&& columns) |
69 | : RecordBatch(schema, num_rows) { |
70 | columns_ = std::move(columns); |
71 | boxed_columns_.resize(schema->num_fields()); |
72 | } |
73 | |
74 | SimpleRecordBatch(const std::shared_ptr<Schema>& schema, int64_t num_rows, |
75 | const std::vector<std::shared_ptr<ArrayData>>& columns) |
76 | : RecordBatch(schema, num_rows) { |
77 | columns_ = columns; |
78 | boxed_columns_.resize(schema->num_fields()); |
79 | } |
80 | |
81 | std::shared_ptr<Array> column(int i) const override { |
82 | if (!boxed_columns_[i]) { |
83 | boxed_columns_[i] = MakeArray(columns_[i]); |
84 | } |
85 | DCHECK(boxed_columns_[i]); |
86 | return boxed_columns_[i]; |
87 | } |
88 | |
89 | std::shared_ptr<ArrayData> column_data(int i) const override { return columns_[i]; } |
90 | |
91 | Status AddColumn(int i, const std::shared_ptr<Field>& field, |
92 | const std::shared_ptr<Array>& column, |
93 | std::shared_ptr<RecordBatch>* out) const override { |
94 | DCHECK(field != nullptr); |
95 | DCHECK(column != nullptr); |
96 | |
97 | if (!field->type()->Equals(column->type())) { |
98 | return Status::Invalid("Column data type " , field->type()->name(), |
99 | " does not match field data type " , column->type()->name()); |
100 | } |
101 | if (column->length() != num_rows_) { |
102 | return Status::Invalid( |
103 | "Added column's length must match record batch's length. Expected length " , |
104 | num_rows_, " but got length " , column->length()); |
105 | } |
106 | |
107 | std::shared_ptr<Schema> new_schema; |
108 | RETURN_NOT_OK(schema_->AddField(i, field, &new_schema)); |
109 | |
110 | *out = RecordBatch::Make(new_schema, num_rows_, |
111 | internal::AddVectorElement(columns_, i, column->data())); |
112 | return Status::OK(); |
113 | } |
114 | |
115 | Status RemoveColumn(int i, std::shared_ptr<RecordBatch>* out) const override { |
116 | std::shared_ptr<Schema> new_schema; |
117 | RETURN_NOT_OK(schema_->RemoveField(i, &new_schema)); |
118 | |
119 | *out = RecordBatch::Make(new_schema, num_rows_, |
120 | internal::DeleteVectorElement(columns_, i)); |
121 | return Status::OK(); |
122 | } |
123 | |
124 | std::shared_ptr<RecordBatch> ReplaceSchemaMetadata( |
125 | const std::shared_ptr<const KeyValueMetadata>& metadata) const override { |
126 | auto new_schema = schema_->AddMetadata(metadata); |
127 | return RecordBatch::Make(new_schema, num_rows_, columns_); |
128 | } |
129 | |
130 | std::shared_ptr<RecordBatch> Slice(int64_t offset, int64_t length) const override { |
131 | std::vector<std::shared_ptr<ArrayData>> arrays; |
132 | arrays.reserve(num_columns()); |
133 | for (const auto& field : columns_) { |
134 | int64_t col_length = std::min(field->length - offset, length); |
135 | int64_t col_offset = field->offset + offset; |
136 | |
137 | auto new_data = std::make_shared<ArrayData>(*field); |
138 | new_data->length = col_length; |
139 | new_data->offset = col_offset; |
140 | new_data->null_count = kUnknownNullCount; |
141 | arrays.emplace_back(new_data); |
142 | } |
143 | int64_t num_rows = std::min(num_rows_ - offset, length); |
144 | return std::make_shared<SimpleRecordBatch>(schema_, num_rows, std::move(arrays)); |
145 | } |
146 | |
147 | Status Validate() const override { |
148 | if (static_cast<int>(columns_.size()) != schema_->num_fields()) { |
149 | return Status::Invalid("Number of columns did not match schema" ); |
150 | } |
151 | return RecordBatch::Validate(); |
152 | } |
153 | |
154 | private: |
155 | std::vector<std::shared_ptr<ArrayData>> columns_; |
156 | |
157 | // Caching boxed array data |
158 | mutable std::vector<std::shared_ptr<Array>> boxed_columns_; |
159 | }; |
160 | |
161 | RecordBatch::RecordBatch(const std::shared_ptr<Schema>& schema, int64_t num_rows) |
162 | : schema_(schema), num_rows_(num_rows) {} |
163 | |
164 | std::shared_ptr<RecordBatch> RecordBatch::Make( |
165 | const std::shared_ptr<Schema>& schema, int64_t num_rows, |
166 | const std::vector<std::shared_ptr<Array>>& columns) { |
167 | return std::make_shared<SimpleRecordBatch>(schema, num_rows, columns); |
168 | } |
169 | |
170 | std::shared_ptr<RecordBatch> RecordBatch::Make( |
171 | const std::shared_ptr<Schema>& schema, int64_t num_rows, |
172 | std::vector<std::shared_ptr<Array>>&& columns) { |
173 | return std::make_shared<SimpleRecordBatch>(schema, num_rows, std::move(columns)); |
174 | } |
175 | |
176 | std::shared_ptr<RecordBatch> RecordBatch::Make( |
177 | const std::shared_ptr<Schema>& schema, int64_t num_rows, |
178 | std::vector<std::shared_ptr<ArrayData>>&& columns) { |
179 | return std::make_shared<SimpleRecordBatch>(schema, num_rows, std::move(columns)); |
180 | } |
181 | |
182 | std::shared_ptr<RecordBatch> RecordBatch::Make( |
183 | const std::shared_ptr<Schema>& schema, int64_t num_rows, |
184 | const std::vector<std::shared_ptr<ArrayData>>& columns) { |
185 | return std::make_shared<SimpleRecordBatch>(schema, num_rows, columns); |
186 | } |
187 | |
188 | const std::string& RecordBatch::column_name(int i) const { |
189 | return schema_->field(i)->name(); |
190 | } |
191 | |
192 | bool RecordBatch::Equals(const RecordBatch& other) const { |
193 | if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) { |
194 | return false; |
195 | } |
196 | |
197 | for (int i = 0; i < num_columns(); ++i) { |
198 | if (!column(i)->Equals(other.column(i))) { |
199 | return false; |
200 | } |
201 | } |
202 | |
203 | return true; |
204 | } |
205 | |
206 | bool RecordBatch::ApproxEquals(const RecordBatch& other) const { |
207 | if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) { |
208 | return false; |
209 | } |
210 | |
211 | for (int i = 0; i < num_columns(); ++i) { |
212 | if (!column(i)->ApproxEquals(other.column(i))) { |
213 | return false; |
214 | } |
215 | } |
216 | |
217 | return true; |
218 | } |
219 | |
220 | std::shared_ptr<RecordBatch> RecordBatch::Slice(int64_t offset) const { |
221 | return Slice(offset, this->num_rows() - offset); |
222 | } |
223 | |
224 | Status RecordBatch::Validate() const { |
225 | for (int i = 0; i < num_columns(); ++i) { |
226 | auto arr_shared = this->column_data(i); |
227 | const ArrayData& arr = *arr_shared; |
228 | if (arr.length != num_rows_) { |
229 | return Status::Invalid("Number of rows in column " , i, |
230 | " did not match batch: " , arr.length, " vs " , num_rows_); |
231 | } |
232 | const auto& schema_type = *schema_->field(i)->type(); |
233 | if (!arr.type->Equals(schema_type)) { |
234 | return Status::Invalid("Column " , i, |
235 | " type not match schema: " , arr.type->ToString(), " vs " , |
236 | schema_type.ToString()); |
237 | } |
238 | } |
239 | return Status::OK(); |
240 | } |
241 | |
242 | // ---------------------------------------------------------------------- |
243 | // Base record batch reader |
244 | |
245 | RecordBatchReader::~RecordBatchReader() {} |
246 | |
247 | Status RecordBatchReader::ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches) { |
248 | while (true) { |
249 | std::shared_ptr<RecordBatch> batch; |
250 | RETURN_NOT_OK(ReadNext(&batch)); |
251 | if (!batch) { |
252 | break; |
253 | } |
254 | batches->emplace_back(std::move(batch)); |
255 | } |
256 | return Status::OK(); |
257 | } |
258 | |
259 | Status RecordBatchReader::ReadAll(std::shared_ptr<Table>* table) { |
260 | std::vector<std::shared_ptr<RecordBatch>> batches; |
261 | RETURN_NOT_OK(ReadAll(&batches)); |
262 | return Table::FromRecordBatches(schema(), batches, table); |
263 | } |
264 | |
265 | } // namespace arrow |
266 | |